diff --git a/apps/blockchain-node/src/aitbc_chain/app.py b/apps/blockchain-node/src/aitbc_chain/app.py index 50f885e7..d0fe6f9c 100755 --- a/apps/blockchain-node/src/aitbc_chain/app.py +++ b/apps/blockchain-node/src/aitbc_chain/app.py @@ -93,7 +93,7 @@ async def lifespan(app: FastAPI): init_db() init_mempool( backend=settings.mempool_backend, - db_path=str(settings.db_path.parent / "mempool.db"), + db_url=settings.mempool_db_url, max_size=settings.mempool_max_size, min_fee=settings.min_fee, ) diff --git a/apps/blockchain-node/src/aitbc_chain/config.py b/apps/blockchain-node/src/aitbc_chain/config.py index ec2173b8..852e9252 100755 --- a/apps/blockchain-node/src/aitbc_chain/config.py +++ b/apps/blockchain-node/src/aitbc_chain/config.py @@ -81,6 +81,7 @@ class ChainSettings(BaseSettings): # Mempool settings mempool_backend: str = "database" # "database" or "memory" (database recommended for persistence) + mempool_db_url: str = "postgresql+psycopg://aitbc_mempool:password@localhost:5432/aitbc_mempool" # PostgreSQL URL for mempool (can be overridden by MEMPOOL_DB_URL env var) mempool_max_size: int = 10_000 mempool_eviction_interval: int = 60 # seconds diff --git a/apps/blockchain-node/src/aitbc_chain/database.py b/apps/blockchain-node/src/aitbc_chain/database.py index e06682dd..89e0c621 100755 --- a/apps/blockchain-node/src/aitbc_chain/database.py +++ b/apps/blockchain-node/src/aitbc_chain/database.py @@ -72,9 +72,16 @@ def get_engine(chain_id: str = "") -> object: @event.listens_for(engine, "connect") def set_encryption_key(dbapi_connection, connection_record): dbapi_connection.execute(f"PRAGMA key = '{key_hex}'") + dbapi_connection.execute("PRAGMA journal_mode=WAL") + dbapi_connection.execute("PRAGMA synchronous=NORMAL") else: # Use standard SQLite engine = create_engine(f"sqlite:///{db_path}", echo=False) + + @event.listens_for(engine, "connect") + def set_wal_mode(dbapi_connection, connection_record): + dbapi_connection.execute("PRAGMA journal_mode=WAL") + 
dbapi_connection.execute("PRAGMA synchronous=NORMAL") _engines[resolved_chain_id] = engine @@ -87,8 +94,7 @@ _engine = create_engine(f"sqlite:///{settings.db_path}", echo=False) @event.listens_for(_engine, "connect") def set_sqlite_pragma(dbapi_connection, connection_record): cursor = dbapi_connection.cursor() - # WAL mode disabled due to issues on btrfs raid (CoW already disabled on data directory) - # cursor.execute("PRAGMA journal_mode=WAL") + cursor.execute("PRAGMA journal_mode=WAL") cursor.execute("PRAGMA synchronous=NORMAL") cursor.execute("PRAGMA cache_size=-64000") cursor.execute("PRAGMA temp_store=MEMORY") diff --git a/apps/blockchain-node/src/aitbc_chain/main.py b/apps/blockchain-node/src/aitbc_chain/main.py index fb7b145f..fe42f479 100755 --- a/apps/blockchain-node/src/aitbc_chain/main.py +++ b/apps/blockchain-node/src/aitbc_chain/main.py @@ -96,12 +96,46 @@ class BlockchainNode: self._stop_event = asyncio.Event() self._proposers: dict[str, PoAProposer] = {} + @staticmethod + def _env_value(*names: str) -> Optional[str]: + for name in names: + value = os.getenv(name) + if value is not None: + return value + return None + + def _block_production_enabled(self) -> bool: + override = self._env_value("AITBC_FORCE_ENABLE_BLOCK_PRODUCTION", "ENABLE_BLOCK_PRODUCTION", "enable_block_production") + if override is not None: + return override.strip().lower() in {"1", "true", "yes", "on"} + return bool(getattr(settings, "enable_block_production", True)) + + def _supported_chains(self) -> list[str]: + chains_str = getattr(settings, 'supported_chains', settings.chain_id) + chains = [c.strip() for c in chains_str.split(",") if c.strip()] + if not chains and settings.chain_id: + chains = [settings.chain_id] + return chains + + def _proposer_config(self, chain_id: str) -> ProposerConfig: + return ProposerConfig( + chain_id=chain_id, + proposer_id=settings.proposer_id, + interval_seconds=settings.block_time_seconds, + max_block_size_bytes=settings.max_block_size_bytes, + 
max_txs_per_block=settings.max_txs_per_block, + ) + + async def _ensure_genesis_for_chains(self) -> None: + for chain_id in self._supported_chains(): + proposer = PoAProposer(config=self._proposer_config(chain_id), session_factory=session_scope) + await proposer._ensure_genesis_block() + async def _setup_gossip_subscribers(self) -> None: logger.info("Setting up gossip subscribers") # Parse supported chains - chains_str = getattr(settings, 'supported_chains', settings.chain_id) - chains = [c.strip() for c in chains_str.split(",") if c.strip()] + chains = self._supported_chains() # Transactions (single topic for all chains) try: @@ -213,8 +247,7 @@ class BlockchainNode: logger.info("Gossip backend initialized successfully") # Parse supported chains - chains_str = getattr(settings, 'supported_chains', settings.chain_id) - chains = [c.strip() for c in chains_str.split(",") if c.strip()] + chains = self._supported_chains() logger.info(f"Initializing databases for chains: {chains}") # Initialize database for each supported chain @@ -224,12 +257,13 @@ class BlockchainNode: init_mempool( backend=settings.mempool_backend, - db_path=str(settings.db_path.parent / "mempool.db"), + db_url=settings.mempool_db_url, max_size=settings.mempool_max_size, min_fee=settings.min_fee, ) + await self._ensure_genesis_for_chains() # Start proposers only if enabled (followers set enable_block_production=False) - if getattr(settings, "enable_block_production", True): + if self._block_production_enabled(): self._start_proposers() else: logger.info("Block production disabled on this node", extra={"proposer_id": settings.proposer_id}) @@ -245,11 +279,12 @@ class BlockchainNode: await self._shutdown() def _start_proposers(self) -> None: - chains_str = getattr(settings, 'supported_chains', settings.chain_id) - chains = [c.strip() for c in chains_str.split(",") if c.strip()] + chains = self._supported_chains() # Get chains that should produce blocks (if specified, otherwise all supported chains) - 
production_chains_str = getattr(settings, 'block_production_chains', chains_str) + production_chains_str = self._env_value("AITBC_FORCE_BLOCK_PRODUCTION_CHAINS", "BLOCK_PRODUCTION_CHAINS", "block_production_chains") + if production_chains_str is None: + production_chains_str = getattr(settings, 'block_production_chains', ",".join(chains)) production_chains = [c.strip() for c in production_chains_str.split(",") if c.strip()] for chain_id in chains: @@ -261,15 +296,7 @@ class BlockchainNode: if chain_id in self._proposers: continue - proposer_config = ProposerConfig( - chain_id=chain_id, - proposer_id=settings.proposer_id, - interval_seconds=settings.block_time_seconds, - max_block_size_bytes=settings.max_block_size_bytes, - max_txs_per_block=settings.max_txs_per_block, - ) - - proposer = PoAProposer(config=proposer_config, session_factory=session_scope) + proposer = PoAProposer(config=self._proposer_config(chain_id), session_factory=session_scope) self._proposers[chain_id] = proposer asyncio.create_task(proposer.start()) diff --git a/apps/blockchain-node/src/aitbc_chain/mempool.py b/apps/blockchain-node/src/aitbc_chain/mempool.py index 61089653..063a10ba 100755 --- a/apps/blockchain-node/src/aitbc_chain/mempool.py +++ b/apps/blockchain-node/src/aitbc_chain/mempool.py @@ -7,9 +7,31 @@ from dataclasses import dataclass, field from threading import Lock from typing import Any, Dict, List, Optional +from sqlmodel import Session, SQLModel, create_engine, select, Field, text +from sqlalchemy import Column, String, Integer, Float, Text, Index, MetaData, Table + from .metrics import metrics_registry +mempool_metadata = MetaData() + + +class MempoolEntry(SQLModel, table=True): + __tablename__ = "mempool" + __table_args__ = {"metadata": mempool_metadata} + + chain_id: str = Field(primary_key=True) + tx_hash: str = Field(primary_key=True) + content: str = Field(sa_column=Column(Text, nullable=False)) + fee: int = Field(default=0, sa_column=Column(Integer, nullable=False)) + 
size_bytes: int = Field(default=0, sa_column=Column(Integer, nullable=False)) + received_at: float = Field(sa_column=Column(Float, nullable=False)) + + __table_args__ = ( + Index('idx_mempool_fee', 'fee', postgresql_ops={'fee': 'DESC'}), + ) + + @dataclass(frozen=True) class PendingTransaction: tx_hash: str @@ -150,32 +172,33 @@ class InMemoryMempool: class DatabaseMempool: - """SQLite-backed mempool for persistence and cross-service sharing.""" + """PostgreSQL-backed mempool for persistence and cross-service sharing.""" - def __init__(self, db_path: str, max_size: int = 10_000, min_fee: int = 0) -> None: - import sqlite3 - self._db_path = db_path + def __init__(self, db_url: str, max_size: int = 10_000, min_fee: int = 0) -> None: + self._db_url = db_url self._max_size = max_size self._min_fee = min_fee - self._conn = sqlite3.connect(db_path, check_same_thread=False) + self._engine = create_engine(db_url, echo=False, pool_pre_ping=True) self._lock = Lock() self._init_table() def _init_table(self) -> None: with self._lock: - self._conn.execute(""" - CREATE TABLE IF NOT EXISTS mempool ( - chain_id TEXT NOT NULL, - tx_hash TEXT NOT NULL, - content TEXT NOT NULL, - fee INTEGER DEFAULT 0, - size_bytes INTEGER DEFAULT 0, - received_at REAL NOT NULL, - PRIMARY KEY (chain_id, tx_hash) - ) - """) - self._conn.execute("CREATE INDEX IF NOT EXISTS idx_mempool_fee ON mempool(fee DESC)") - self._conn.commit() + with Session(self._engine) as session: + # Create table manually using raw SQL to avoid chain table conflicts + session.exec(text(""" + CREATE TABLE IF NOT EXISTS mempool ( + chain_id TEXT NOT NULL, + tx_hash TEXT NOT NULL, + content TEXT NOT NULL, + fee INTEGER DEFAULT 0, + size_bytes INTEGER DEFAULT 0, + received_at REAL NOT NULL, + PRIMARY KEY (chain_id, tx_hash) + ) + """)) + session.exec(text("CREATE INDEX IF NOT EXISTS idx_mempool_fee ON mempool(fee DESC)")) + session.commit() def add(self, tx: Dict[str, Any], chain_id: str = None) -> str: from .config import 
settings @@ -190,27 +213,42 @@ class DatabaseMempool: size_bytes = len(content.encode()) with self._lock: - # Check duplicate - row = self._conn.execute("SELECT 1 FROM mempool WHERE chain_id = ? AND tx_hash = ?", (chain_id, tx_hash)).fetchone() - if row: - return tx_hash - - # Evict if full - count = self._conn.execute("SELECT COUNT(*) FROM mempool WHERE chain_id = ?", (chain_id,)).fetchone()[0] - if count >= self._max_size: - self._conn.execute(""" - DELETE FROM mempool WHERE chain_id = ? AND tx_hash = ( - SELECT tx_hash FROM mempool WHERE chain_id = ? ORDER BY fee ASC, received_at DESC LIMIT 1 + with Session(self._engine) as session: + # Check duplicate + existing = session.exec( + select(MempoolEntry).where( + MempoolEntry.chain_id == chain_id, + MempoolEntry.tx_hash == tx_hash ) - """, (chain_id, chain_id)) - metrics_registry.increment(f"mempool_evictions_total_{chain_id}") + ).first() + if existing: + return tx_hash - self._conn.execute( - "INSERT INTO mempool (chain_id, tx_hash, content, fee, size_bytes, received_at) VALUES (?, ?, ?, ?, ?, ?)", - (chain_id, tx_hash, content, fee, size_bytes, time.time()) - ) - self._conn.commit() - metrics_registry.increment(f"mempool_tx_added_total_{chain_id}") + # Evict if full + count = session.exec( + select(MempoolEntry).where(MempoolEntry.chain_id == chain_id) + ).count() + if count >= self._max_size: + to_evict = session.exec( + select(MempoolEntry).where(MempoolEntry.chain_id == chain_id) + .order_by(MempoolEntry.fee.asc(), MempoolEntry.received_at.desc()) + .limit(1) + ).first() + if to_evict: + session.delete(to_evict) + metrics_registry.increment(f"mempool_evictions_total_{chain_id}") + + entry = MempoolEntry( + chain_id=chain_id, + tx_hash=tx_hash, + content=content, + fee=fee, + size_bytes=size_bytes, + received_at=time.time() + ) + session.add(entry) + session.commit() + metrics_registry.increment(f"mempool_tx_added_total_{chain_id}") self._update_gauge(chain_id) return tx_hash @@ -219,15 +257,16 @@ class 
DatabaseMempool: if chain_id is None: chain_id = settings.chain_id with self._lock: - rows = self._conn.execute( - "SELECT tx_hash, content, fee, size_bytes, received_at FROM mempool WHERE chain_id = ? ORDER BY fee DESC, received_at ASC", - (chain_id,) - ).fetchall() + with Session(self._engine) as session: + entries = session.exec( + select(MempoolEntry).where(MempoolEntry.chain_id == chain_id) + .order_by(MempoolEntry.fee.desc(), MempoolEntry.received_at.asc()) + ).all() return [ PendingTransaction( - tx_hash=r[0], content=json.loads(r[1]), - fee=r[2], size_bytes=r[3], received_at=r[4] - ) for r in rows + tx_hash=e.tx_hash, content=json.loads(e.content), + fee=e.fee, size_bytes=e.size_bytes, received_at=e.received_at + ) for e in entries ] def drain(self, max_count: int, max_bytes: int, chain_id: str = None) -> List[PendingTransaction]: @@ -235,35 +274,41 @@ class DatabaseMempool: if chain_id is None: chain_id = settings.chain_id with self._lock: - rows = self._conn.execute( - "SELECT tx_hash, content, fee, size_bytes, received_at FROM mempool WHERE chain_id = ? 
ORDER BY fee DESC, received_at ASC", - (chain_id,) - ).fetchall() + with Session(self._engine) as session: + entries = session.exec( + select(MempoolEntry).where(MempoolEntry.chain_id == chain_id) + .order_by(MempoolEntry.fee.desc(), MempoolEntry.received_at.asc()) + ).all() - result: List[PendingTransaction] = [] - total_bytes = 0 - hashes_to_remove: List[str] = [] + result: List[PendingTransaction] = [] + total_bytes = 0 + hashes_to_remove: List[str] = [] - for r in rows: - if len(result) >= max_count: - break - if total_bytes + r[3] > max_bytes: - continue - result.append(PendingTransaction( - tx_hash=r[0], content=json.loads(r[1]), - fee=r[2], size_bytes=r[3], received_at=r[4] - )) - total_bytes += r[3] - hashes_to_remove.append(r[0]) + for e in entries: + if len(result) >= max_count: + break + if total_bytes + e.size_bytes > max_bytes: + continue + result.append(PendingTransaction( + tx_hash=e.tx_hash, content=json.loads(e.content), + fee=e.fee, size_bytes=e.size_bytes, received_at=e.received_at + )) + total_bytes += e.size_bytes + hashes_to_remove.append(e.tx_hash) - if hashes_to_remove: - # Use parameterized query to avoid SQL injection - placeholders = ",".join(["?"] * len(hashes_to_remove)) - query = f"DELETE FROM mempool WHERE chain_id = ? 
AND tx_hash IN ({placeholders})" - self._conn.execute(query, [chain_id] + hashes_to_remove) - self._conn.commit() + if hashes_to_remove: + for hash_to_remove in hashes_to_remove: + entry = session.exec( + select(MempoolEntry).where( + MempoolEntry.chain_id == chain_id, + MempoolEntry.tx_hash == hash_to_remove + ) + ).first() + if entry: + session.delete(entry) + session.commit() - metrics_registry.increment(f"mempool_tx_drained_total_{chain_id}", float(len(result))) + metrics_registry.increment(f"mempool_tx_drained_total_{chain_id}", float(len(result))) self._update_gauge(chain_id) return result @@ -272,9 +317,19 @@ class DatabaseMempool: if chain_id is None: chain_id = settings.chain_id with self._lock: - cursor = self._conn.execute("DELETE FROM mempool WHERE chain_id = ? AND tx_hash = ?", (chain_id, tx_hash)) - self._conn.commit() - removed = cursor.rowcount > 0 + with Session(self._engine) as session: + entry = session.exec( + select(MempoolEntry).where( + MempoolEntry.chain_id == chain_id, + MempoolEntry.tx_hash == tx_hash + ) + ).first() + if entry: + session.delete(entry) + session.commit() + removed = True + else: + removed = False if removed: self._update_gauge(chain_id) return removed @@ -284,7 +339,10 @@ class DatabaseMempool: if chain_id is None: chain_id = settings.chain_id with self._lock: - return self._conn.execute("SELECT COUNT(*) FROM mempool WHERE chain_id = ?", (chain_id,)).fetchone()[0] + with Session(self._engine) as session: + return session.exec( + select(MempoolEntry).where(MempoolEntry.chain_id == chain_id) + ).count() def get_pending_transactions(self, chain_id: str = None, limit: int = 100) -> List[Dict[str, Any]]: """Get pending transactions for RPC endpoint""" @@ -293,18 +351,20 @@ class DatabaseMempool: chain_id = settings.chain_id with self._lock: - rows = self._conn.execute( - "SELECT content FROM mempool WHERE chain_id = ? 
ORDER BY fee DESC, received_at ASC LIMIT ?", - (chain_id, limit) - ).fetchall() + with Session(self._engine) as session: + entries = session.exec( + select(MempoolEntry).where(MempoolEntry.chain_id == chain_id) + .order_by(MempoolEntry.fee.desc(), MempoolEntry.received_at.asc()) + .limit(limit) + ).all() - return [json.loads(row[0]) for row in rows] + return [json.loads(e.content) for e in entries] def _update_gauge(self, chain_id: str = None) -> None: from .config import settings if chain_id is None: chain_id = settings.chain_id - count = self._conn.execute("SELECT COUNT(*) FROM mempool WHERE chain_id = ?", (chain_id,)).fetchone()[0] + count = self.size(chain_id) metrics_registry.set_gauge(f"mempool_size_{chain_id}", float(count)) @@ -312,10 +372,10 @@ class DatabaseMempool: _MEMPOOL: Optional[InMemoryMempool | DatabaseMempool] = None -def init_mempool(backend: str = "memory", db_path: str = "", max_size: int = 10_000, min_fee: int = 0) -> None: +def init_mempool(backend: str = "memory", db_url: str = "", max_size: int = 10_000, min_fee: int = 0) -> None: global _MEMPOOL - if backend == "database" and db_path: - _MEMPOOL = DatabaseMempool(db_path, max_size=max_size, min_fee=min_fee) + if backend == "database" and db_url: + _MEMPOOL = DatabaseMempool(db_url, max_size=max_size, min_fee=min_fee) else: _MEMPOOL = InMemoryMempool(max_size=max_size, min_fee=min_fee) diff --git a/apps/blockchain-node/src/aitbc_chain/p2p_network.py b/apps/blockchain-node/src/aitbc_chain/p2p_network.py index c019a590..ee484489 100644 --- a/apps/blockchain-node/src/aitbc_chain/p2p_network.py +++ b/apps/blockchain-node/src/aitbc_chain/p2p_network.py @@ -720,14 +720,14 @@ def main(): try: from .mempool import init_mempool import pathlib - - db_path = "" + + db_url = "" if settings.mempool_backend == "database": - db_path = str(settings.db_path.parent / "mempool.db") - + db_url = settings.mempool_db_url + init_mempool( backend=settings.mempool_backend, - db_path=db_path, + db_url=db_url, 
max_size=settings.mempool_max_size, min_fee=settings.min_fee ) diff --git a/apps/blockchain-node/src/aitbc_chain/sync.py b/apps/blockchain-node/src/aitbc_chain/sync.py index 485d1cf7..f9aefe83 100755 --- a/apps/blockchain-node/src/aitbc_chain/sync.py +++ b/apps/blockchain-node/src/aitbc_chain/sync.py @@ -414,6 +414,7 @@ class ChainSync: def _append_block(self, session: Session, block_data: Dict[str, Any], transactions: Optional[List[Dict[str, Any]]] = None) -> ImportResult: """Append a block to the chain tip.""" + block_hash = block_data["hash"] timestamp_str = block_data.get("timestamp", "") try: timestamp = datetime.fromisoformat(timestamp_str) if timestamp_str else datetime.now(datetime.UTC) @@ -486,11 +487,9 @@ class ChainSync: ) session.add(tx) - session.commit() - # Verify state root if provided if block_data.get("state_root"): - from aitbc_chain.config import settings + session.flush() state_manager = StateManager() accounts = session.exec( select(Account).where(Account.chain_id == self._chain_id) @@ -498,11 +497,33 @@ class ChainSync: account_dict = {acc.address: acc for acc in accounts} computed_root = state_manager.compute_state_root(account_dict) - expected_root = bytes.fromhex(block_data.get("state_root").replace("0x", "")) + try: + expected_root = bytes.fromhex(str(block_data.get("state_root")).replace("0x", "")) + except ValueError: + expected_root = None - if computed_root != expected_root: + if expected_root is None or len(expected_root) != 32: if settings.enforce_state_root_validation: metrics_registry.increment("sync_state_root_rejected_total") + session.rollback() + logger.error( + f"[SYNC] Invalid state root at height {block_data['height']}: " + f"{block_data.get('state_root')} - BLOCK REJECTED" + ) + return ImportResult( + accepted=False, + height=block_data["height"], + block_hash=block_hash, + reason=f"Invalid state root: {block_data.get('state_root')}" + ) + logger.warning( + f"[SYNC] Invalid state root at height {block_data['height']}: " + 
f"{block_data.get('state_root')}" + ) + elif computed_root != expected_root: + if settings.enforce_state_root_validation: + metrics_registry.increment("sync_state_root_rejected_total") + session.rollback() logger.error( f"[SYNC] State root mismatch at height {block_data['height']}: " f"expected {expected_root.hex()}, computed {computed_root.hex()} - BLOCK REJECTED" @@ -519,6 +540,8 @@ class ChainSync: f"expected {expected_root.hex()}, computed {computed_root.hex()}" ) + session.commit() + metrics_registry.increment("sync_blocks_accepted_total") metrics_registry.set_gauge("sync_chain_height", float(block_data["height"])) logger.info("Imported block", extra={ diff --git a/apps/blockchain-node/tests/test_mempool.py b/apps/blockchain-node/tests/test_mempool.py index a441f15f..e1037a50 100755 --- a/apps/blockchain-node/tests/test_mempool.py +++ b/apps/blockchain-node/tests/test_mempool.py @@ -250,6 +250,7 @@ class TestInitMempool: def test_init_database(self, tmp_path): db_path = str(tmp_path / "init.db") - init_mempool(backend="database", db_path=db_path, max_size=50, min_fee=0) + db_url = f"sqlite:///{db_path}" + init_mempool(backend="database", db_url=db_url, max_size=50, min_fee=0) pool = get_mempool() assert isinstance(pool, DatabaseMempool) diff --git a/apps/blockchain-node/tests/test_sync.py b/apps/blockchain-node/tests/test_sync.py index f62a5726..39416d08 100755 --- a/apps/blockchain-node/tests/test_sync.py +++ b/apps/blockchain-node/tests/test_sync.py @@ -11,6 +11,7 @@ from unittest.mock import AsyncMock, Mock from sqlmodel import Session, SQLModel, create_engine, select from aitbc_chain.models import Block, Transaction +from aitbc_chain.sync import settings as sync_settings from aitbc_chain.metrics import metrics_registry from aitbc_chain.sync import ChainSync, ProposerSignatureValidator, ImportResult @@ -286,6 +287,29 @@ class TestChainSyncBulkImport: stored_txs = session.exec(select(Transaction).where(Transaction.block_height == 1)).all() assert 
len(stored_txs) == 2 + def test_enforced_state_root_mismatch_rolls_back_block(self, session_factory, monkeypatch): + monkeypatch.setattr(sync_settings, "enforce_state_root_validation", True) + sync = ChainSync(session_factory, chain_id="test", validate_signatures=False) + blocks = _seed_chain(session_factory, count=1, chain_id="test") + last = blocks[-1] + ts = datetime(2026, 1, 1, 0, 0, 1) + bh = _make_block_hash("test", 1, last["hash"], ts) + + result = sync.import_block({ + "height": 1, + "hash": bh, + "parent_hash": last["hash"], + "proposer": "node-a", + "timestamp": ts.isoformat(), + "state_root": "0x" + "11" * 32, + }) + + assert result.accepted is False + assert "State root mismatch" in result.reason + with session_factory() as session: + stored_block = session.exec(select(Block).where(Block.chain_id == "test", Block.height == 1)).first() + assert stored_block is None + class TestChainSyncSignatureValidation: diff --git a/apps/exchange/postgresql+psycopg:/aitbc_exchange:password@localhost:5432/aitbc_exchange b/apps/exchange/postgresql+psycopg:/aitbc_exchange:password@localhost:5432/aitbc_exchange new file mode 100644 index 00000000..f6c0b90a Binary files /dev/null and b/apps/exchange/postgresql+psycopg:/aitbc_exchange:password@localhost:5432/aitbc_exchange differ diff --git a/scripts/utils/chain_regen_node.py b/scripts/utils/chain_regen_node.py new file mode 100755 index 00000000..f040473a --- /dev/null +++ b/scripts/utils/chain_regen_node.py @@ -0,0 +1,345 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import glob +import hashlib +import json +import os +import shutil +import socket +import sqlite3 +import subprocess +import sys +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +REPO_DIR = Path("/opt/aitbc") +BLOCKCHAIN_SRC = REPO_DIR / "apps" / "blockchain-node" / "src" +if str(REPO_DIR) not in sys.path: + sys.path.insert(0, str(REPO_DIR)) +if str(BLOCKCHAIN_SRC) not in 
sys.path: + sys.path.insert(0, str(BLOCKCHAIN_SRC)) + +from sqlmodel import Session, create_engine, select +from sqlalchemy.exc import SQLAlchemyError + +from aitbc_chain.config import ChainSettings +from aitbc_chain.models import Account, Block, Transaction +from aitbc_chain.state.merkle_patricia_trie import StateManager + +SERVICE_NAME = "aitbc-blockchain-node.service" +DATA_ROOT = Path("/var/lib/aitbc/data") +BACKUP_ROOT = Path("/var/lib/aitbc/backups/mpt-regeneration") +ENV_FILES = [Path("/etc/aitbc/.env"), Path("/etc/aitbc/node.env")] + + +def _run(command: list[str], check: bool = False) -> subprocess.CompletedProcess[str]: + return subprocess.run(command, text=True, capture_output=True, check=check) + + +def _service_state(service_name: str) -> dict[str, Any]: + active = _run(["systemctl", "is-active", service_name]).stdout.strip() + enabled = _run(["systemctl", "is-enabled", service_name]).stdout.strip() + fragment = _run(["systemctl", "show", service_name, "-p", "FragmentPath", "--value"]).stdout.strip() + dropins = _run(["systemctl", "show", service_name, "-p", "DropInPaths", "--value"]).stdout.strip() + return { + "active": active, + "enabled": enabled, + "fragment_path": fragment, + "drop_in_paths": [item for item in dropins.split() if item], + } + + +def _git_revision() -> str | None: + result = _run(["git", "-C", str(REPO_DIR), "rev-parse", "HEAD"]) + if result.returncode != 0: + return None + return result.stdout.strip() + + +def _sha256_file(path: Path) -> str | None: + if not path.exists() or not path.is_file(): + return None + digest = hashlib.sha256() + with path.open("rb") as handle: + for chunk in iter(lambda: handle.read(1024 * 1024), b""): + digest.update(chunk) + return digest.hexdigest() + + +def _load_genesis(path: Path) -> dict[str, Any]: + if not path.exists(): + return {} + with path.open() as handle: + return json.load(handle) + + +def _allocation_digest(genesis: dict[str, Any]) -> str | None: + allocations = genesis.get("allocations") 
+ if allocations is None: + return None + canonical = json.dumps(sorted(allocations, key=lambda item: item.get("address", "")), sort_keys=True, separators=(",", ":")) + return hashlib.sha256(canonical.encode()).hexdigest() + + +def _live_db_files(chain_id: str) -> list[Path]: + chain_dir = DATA_ROOT / chain_id + files = [ + chain_dir / "chain.db", + chain_dir / "chain.db-wal", + chain_dir / "chain.db-shm", + chain_dir / "chain.db-journal", + DATA_ROOT / "mempool.db", + DATA_ROOT / "mempool.db-wal", + DATA_ROOT / "mempool.db-shm", + DATA_ROOT / "mempool.db-journal", + chain_dir / "mempool.db", + chain_dir / "mempool.db-wal", + chain_dir / "mempool.db-shm", + chain_dir / "mempool.db-journal", + ] + return [path for path in files if path.exists()] + + +def _backup_sources(chain_id: str) -> list[Path]: + chain_dir = DATA_ROOT / chain_id + sources: list[Path] = [] + for pattern in [str(chain_dir / "chain.db*"), str(chain_dir / "genesis.json"), str(DATA_ROOT / "mempool.db*"), str(chain_dir / "mempool.db*")]: + sources.extend(Path(path) for path in glob.glob(pattern)) + sources.extend(path for path in ENV_FILES if path.exists()) + unique: dict[str, Path] = {} + for path in sources: + if path.exists(): + unique[str(path)] = path + return [unique[key] for key in sorted(unique)] + + +def _integrity_check(path: Path) -> str | None: + if not path.exists() or path.stat().st_size == 0: + return None + try: + conn = sqlite3.connect(f"file:{path}?mode=ro", uri=True) + try: + row = conn.execute("PRAGMA integrity_check").fetchone() + return row[0] if row else None + finally: + conn.close() + except sqlite3.Error as exc: + return f"error: {exc}" + + +def _db_snapshot(chain_id: str, db_path: Path) -> dict[str, Any]: + data: dict[str, Any] = { + "db_error": None, + "head": None, + "block_count": None, + "transaction_count": None, + "account_count": None, + "computed_state_root": None, + "head_state_root_matches_computed": None, + } + if not db_path.exists() or db_path.stat().st_size == 
0: + data["db_error"] = "database missing or empty" + return data + + engine = create_engine(f"sqlite:///{db_path}", echo=False) + try: + with Session(engine) as session: + head = session.exec(select(Block).where(Block.chain_id == chain_id).order_by(Block.height.desc()).limit(1)).first() + blocks = session.exec(select(Block).where(Block.chain_id == chain_id)).all() + transactions = session.exec(select(Transaction).where(Transaction.chain_id == chain_id)).all() + accounts = session.exec(select(Account).where(Account.chain_id == chain_id).order_by(Account.address)).all() + account_dict = {account.address: account for account in accounts} + computed_root = StateManager().compute_state_root(account_dict) + computed_root_hex = "0x" + computed_root.hex() + data.update( + { + "head": None if head is None else { + "height": head.height, + "hash": head.hash, + "state_root": head.state_root, + "proposer": head.proposer, + "timestamp": head.timestamp.isoformat() if head.timestamp else None, + }, + "block_count": len(blocks), + "transaction_count": len(transactions), + "account_count": len(accounts), + "account_digest": hashlib.sha256(json.dumps( + [{"address": account.address, "balance": account.balance, "nonce": account.nonce} for account in accounts], + sort_keys=True, + separators=(",", ":"), + ).encode()).hexdigest(), + "computed_state_root": computed_root_hex, + "head_state_root_matches_computed": bool(head and head.state_root == computed_root_hex), + } + ) + except SQLAlchemyError as exc: + data["db_error"] = str(exc) + finally: + engine.dispose() + return data
+
+
+def snapshot(chain_id: str, service_name: str) -> dict[str, Any]:
+    """Collect a diagnostic snapshot of this node for ``chain_id``.
+
+    The result combines host and service information, database file facts,
+    genesis file digests and everything reported by ``_db_snapshot``.
+
+    Args:
+        chain_id: Chain whose database and genesis file are inspected.
+        service_name: Service whose state is recorded under ``"service"``.
+
+    Returns:
+        A dictionary of snapshot fields, updated with the ``_db_snapshot`` result.
+    """
+    settings = ChainSettings()
+    db_path = settings.get_db_path(chain_id)
+    genesis_path = DATA_ROOT / chain_id / "genesis.json"
+    genesis = _load_genesis(genesis_path)
+    # Evaluate existence once so "db_exists" and "db_size_bytes" agree with each other.
+    db_exists = db_path.exists()
+    data: dict[str, Any] = {
+        "host": socket.gethostname(),
+        "chain_id": chain_id,
+        "timestamp": datetime.now(timezone.utc).isoformat(),
+        "repo_revision": _git_revision(),
+        "service": _service_state(service_name),
+        "db_path": str(db_path),
+        "db_exists": db_exists,
+        "db_size_bytes": db_path.stat().st_size if db_exists else None,
+        "db_integrity": _integrity_check(db_path),
+        "genesis_path": str(genesis_path),
+        "genesis_exists": genesis_path.exists(),
+        "genesis_file_sha256": _sha256_file(genesis_path),
+        "genesis_allocation_digest": _allocation_digest(genesis),
+        "genesis_allocation_count": len(genesis.get("allocations", [])) if genesis else 0,
+        "live_db_files": [str(path) for path in _live_db_files(chain_id)],
+    }
+    data.update(_db_snapshot(chain_id, db_path))
+    return data
+
+
+def backup(chain_id: str, service_name: str, backup_root: Path, timestamp: str | None) -> dict[str, Any]:
+    """Copy every backup source of ``chain_id`` under ``backup_root`` and write a manifest.
+
+    Files are stored in ``backup_root/<stamp>/<hostname>/<chain_id>/``, each at
+    its original absolute path re-rooted beneath that directory.
+
+    Args:
+        chain_id: Chain whose files are backed up.
+        service_name: Service inspected by the snapshot stored under ``"preflight"``.
+        backup_root: Directory that receives the timestamped backup tree.
+        timestamp: Directory stamp to use; a UTC ``%Y%m%dT%H%M%SZ`` stamp is
+            generated when this is ``None`` or empty.
+
+    Returns:
+        The manifest, which is also written to ``manifest.json`` in the target directory.
+    """
+    stamp = timestamp or datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
+    target = backup_root / stamp / socket.gethostname() / chain_id
+    target.mkdir(parents=True, exist_ok=True)
+    manifest: dict[str, Any] = {
+        "host": socket.gethostname(),
+        "chain_id": chain_id,
+        "timestamp": stamp,
+        "target": str(target),
+        "preflight": snapshot(chain_id, service_name),
+        "files": [],
+    }
+    for source in _backup_sources(chain_id):
+        # Drop the filesystem anchor so the absolute source path is recreated below target.
+        rel = source.relative_to(source.anchor)
+        destination = target / rel
+        destination.parent.mkdir(parents=True, exist_ok=True)
+        shutil.copy2(source, destination)
+        # NOTE(review): the prefix match also selects sidecar names such as
+        # "chain.db-wal"; confirm _integrity_check is meaningful for those.
+        is_database = source.name.startswith(("chain.db", "mempool.db"))
+        manifest["files"].append(
+            {
+                "source": str(source),
+                "backup": str(destination),
+                # Size and digest describe the stored copy: the source may keep
+                # changing after the copy if the service is still running.
+                "size_bytes": destination.stat().st_size,
+                "sha256": _sha256_file(destination),
+                "integrity_check": _integrity_check(destination) if is_database else None,
+            }
+        )
+    manifest_path = target / "manifest.json"
+    manifest_path.write_text(json.dumps(manifest, indent=2, sort_keys=True, default=str))
+    return manifest
+
+
+def reset(chain_id: str, service_name: str, yes: bool, force: bool) -> dict[str, Any]:
+    """Delete the live database files of ``chain_id``.
+
+    Args:
+        chain_id: Chain whose live database files are removed.
+        service_name: Service whose state is checked before deleting anything.
+        yes: Must be true; guards against accidental invocation.
+        force: Allow deletion even while the service reports ``active``.
+
+    Returns:
+        Host name, chain id and the paths that were actually removed.
+
+    Raises:
+        SystemExit: If ``yes`` is false, or the service is active and ``force`` is false.
+    """
+    if not yes:
+        raise SystemExit("reset requires --yes")
+    service = _service_state(service_name)
+    if service["active"] == "active" and not force:
+        raise SystemExit(f"{service_name} is active; stop it first or pass --force")
+    removed: list[str] = []
+    for path in _live_db_files(chain_id):
+        try:
+            path.unlink()
+        except FileNotFoundError:
+            # The file vanished between listing and deletion; skip it rather
+            # than abort half-way, and do not report it as removed.
+            continue
+        removed.append(str(path))
+    return {"host": socket.gethostname(), "chain_id": chain_id, "removed": removed}
+
+
+def set_role(chain_id: str, service_name: str, role: str) -> dict[str, Any]:
+    """Write or remove the systemd drop-in that pins this node's block-production role.
+
+    ``leader`` enables block production for ``chain_id``, ``follower`` disables
+    it with an empty chain list, and ``clear`` deletes the drop-in. systemd is
+    reloaded after every change.
+
+    Args:
+        chain_id: Chain listed in the leader's block-production variables.
+        service_name: Unit whose ``<service_name>.d`` drop-in directory is used.
+        role: One of ``"leader"``, ``"follower"`` or ``"clear"``.
+
+    Returns:
+        Host name, role, drop-in path and whether the drop-in exists afterwards.
+
+    Raises:
+        SystemExit: If ``role`` is not one of the supported values.
+    """
+    if role not in {"leader", "follower", "clear"}:
+        # Reject bad input before anything under /etc/systemd/system is touched.
+        raise SystemExit(f"unsupported role: {role}")
+    dropin_dir = Path("/etc/systemd/system") / f"{service_name}.d"
+    dropin_path = dropin_dir / "mpt-regeneration.conf"
+    if role == "clear":
+        dropin_path.unlink(missing_ok=True)
+    else:
+        is_leader = role == "leader"
+        enabled = "true" if is_leader else "false"
+        chains = chain_id if is_leader else ""
+        dropin_dir.mkdir(parents=True, exist_ok=True)
+        # Every spelling of the variables is set so none of them can override the role.
+        dropin_path.write_text(
+            "[Service]\n"
+            f'Environment="AITBC_FORCE_ENABLE_BLOCK_PRODUCTION={enabled}"\n'
+            f'Environment="AITBC_FORCE_BLOCK_PRODUCTION_CHAINS={chains}"\n'
+            f'Environment="ENABLE_BLOCK_PRODUCTION={enabled}"\n'
+            f'Environment="enable_block_production={enabled}"\n'
+            f'Environment="BLOCK_PRODUCTION_CHAINS={chains}"\n'
+            f'Environment="block_production_chains={chains}"\n'
+        )
+    _run(["systemctl", "daemon-reload"], check=True)
+    return {"host": socket.gethostname(), "role": role, "dropin": str(dropin_path), "exists": dropin_path.exists()}
+
+
+def verify(chain_id: str, service_name: str, require_nonzero_root: bool) -> dict[str, Any]:
+    """Take a snapshot and add an ``"ok"`` verdict to it.
+
+    ``ok`` is true only when the snapshot has no ``db_error``, a head block
+    exists and the head's state root equals the recomputed one. With
+    ``require_nonzero_root`` an all-zero computed root also fails.
+
+    Args:
+        chain_id: Chain to verify.
+        service_name: Service passed through to ``snapshot``.
+        require_nonzero_root: Treat a computed root of 32 zero bytes as a failure.
+
+    Returns:
+        The snapshot dictionary with the boolean ``"ok"`` key added.
+    """
+    data = snapshot(chain_id, service_name)
+    head = data.get("head")
+    ok = data.get("db_error") is None and head is not None and data.get("head_state_root_matches_computed") is True
+    if require_nonzero_root and data.get("computed_state_root") == "0x" + ("00" * 32):
+        ok = False
+    data["ok"] = ok
+    return data
+
+
+def print_json(data: Any) -> None:
+    """Print ``data`` as indented, key-sorted JSON; non-JSON values are rendered with ``str``."""
+    print(json.dumps(data, indent=2, sort_keys=True, default=str))
+
+
+def main() -> int:
+    """Parse the command line, run the selected sub-command and print its JSON result.
+
+    Returns:
+        ``0`` on success; ``2`` when ``verify`` reports ``ok`` as false.
+    """
+    parser = argparse.ArgumentParser(description="AITBC coordinated MPT chain regeneration node utility")
+    parser.add_argument("--chain-id", default="ait-mainnet")
+    parser.add_argument("--service-name", default=SERVICE_NAME)
+    subparsers = parser.add_subparsers(dest="command", required=True)
+
+    subparsers.add_parser("preflight")
+
+    backup_parser = subparsers.add_parser("backup")
+    backup_parser.add_argument("--backup-root", type=Path, default=BACKUP_ROOT)
+    backup_parser.add_argument("--timestamp")
+
+    reset_parser = subparsers.add_parser("reset")
+    reset_parser.add_argument("--yes", action="store_true")
+    reset_parser.add_argument("--force", action="store_true")
+
+    verify_parser = subparsers.add_parser("verify")
+    verify_parser.add_argument("--require-nonzero-root", action="store_true")
+
+    role_parser = subparsers.add_parser("set-role")
+    role_parser.add_argument("role", choices=["leader", "follower", "clear"])
+
+    args = parser.parse_args()
+
+    if args.command == "preflight":
+        print_json(snapshot(args.chain_id, args.service_name))
+    elif args.command == "backup":
+        print_json(backup(args.chain_id, args.service_name, args.backup_root, args.timestamp))
+    elif args.command == "reset":
+        print_json(reset(args.chain_id, args.service_name, args.yes, args.force))
+    elif args.command == "verify":
+        result = verify(args.chain_id, args.service_name, args.require_nonzero_root)
+        print_json(result)
+        # A distinct exit status lets callers retry on a failed verification.
+        return 0 if result.get("ok") else 2
+    elif args.command == "set-role":
+        print_json(set_role(args.chain_id, args.service_name, args.role))
+    return 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
diff
--git a/scripts/utils/coordinated_chain_regen.sh b/scripts/utils/coordinated_chain_regen.sh
new file mode 100755
index 00000000..462a9ac3
--- /dev/null
+++ b/scripts/utils/coordinated_chain_regen.sh
@@ -0,0 +1,180 @@
+#!/usr/bin/env bash
+# Drive a coordinated chain regeneration: run the node utility on every
+# node in NODES, locally or over ssh, one phase at a time.
+set -euo pipefail
+
+CHAIN_ID="${CHAIN_ID:-ait-mainnet}"
+SERVICE_NAME="${SERVICE_NAME:-aitbc-blockchain-node.service}"
+LEADER="${LEADER:-aitbc1}"
+NODES_RAW="${NODES:-localhost aitbc1 gitea-runner}"
+UTILITY="${UTILITY:-/opt/aitbc/scripts/utils/chain_regen_node.py}"
+PYTHON_BIN="${PYTHON_BIN:-/opt/aitbc/venv/bin/python}"
+TIMESTAMP="${TIMESTAMP:-$(date -u +%Y%m%dT%H%M%SZ)}"
+STARTUP_WAIT_SECONDS="${STARTUP_WAIT_SECONDS:-8}"
+FOLLOWER_START_WAIT_SECONDS="${FOLLOWER_START_WAIT_SECONDS:-8}"
+VERIFY_RETRIES="${VERIFY_RETRIES:-6}"
+VERIFY_RETRY_SECONDS="${VERIFY_RETRY_SECONDS:-10}"
+
+# Print the sub-command list and the effective settings.
+usage() {
+  cat <<USAGE
+Usage: $0 <preflight|backup|set-roles|clear-roles|stop|reset|start|verify|rollout>
+
+Environment:
+  CHAIN_ID=$CHAIN_ID
+  SERVICE_NAME=$SERVICE_NAME
+  LEADER=$LEADER
+  NODES="$NODES_RAW"
+  TIMESTAMP=$TIMESTAMP
+USAGE
+}
+
+# node_cmd NODE CMD...: run CMD here when NODE names this host, else over ssh.
+node_cmd() {
+  local node="$1"
+  shift
+  if [[ "$node" == "localhost" || "$node" == "local" || "$node" == "$(hostname)" ]]; then
+    "$@"
+  else
+    # ssh hands the remote shell one flat string, so quote every word to keep
+    # argument boundaries (spaces, globs, empty values) intact on the far side.
+    local remote_cmd
+    printf -v remote_cmd '%q ' "$@"
+    ssh "$node" "$remote_cmd"
+  fi
+}
+
+# run_node_utility NODE ARGS...: call the node utility on NODE with the shared options.
+run_node_utility() {
+  local node="$1"
+  shift
+  node_cmd "$node" "$PYTHON_BIN" "$UTILITY" --chain-id "$CHAIN_ID" --service-name "$SERVICE_NAME" "$@"
+}
+
+# for_each_node LABEL CMD...: run CMD once per node, passing the node name last.
+for_each_node() {
+  local action="$1"
+  shift
+  local node
+  # NODES_RAW is left unquoted on purpose: it is a whitespace-separated list.
+  for node in $NODES_RAW; do
+    echo "===== $node :: $action ====="
+    "$@" "$node"
+  done
+}
+
+preflight_node() {
+  run_node_utility "$1" preflight
+}
+
+backup_node() {
+  run_node_utility "$1" backup --timestamp "$TIMESTAMP"
+}
+
+# The node named by LEADER becomes the leader; every other node is a follower.
+set_role_node() {
+  local node="$1"
+  if [[ "$node" == "$LEADER" ]]; then
+    run_node_utility "$node" set-role leader
+  else
+    run_node_utility "$node" set-role follower
+  fi
+}
+
+clear_role_node() {
+  run_node_utility "$1" set-role clear
+}
+
+stop_node() {
+  node_cmd "$1" systemctl stop "$SERVICE_NAME"
+}
+
+reset_node() {
+  run_node_utility "$1" reset --yes
+}
+
+# Start the leader and give it STARTUP_WAIT_SECONDS before anything else runs.
+start_leader() {
+  echo "===== $LEADER :: start leader ====="
+  node_cmd "$LEADER" systemctl start "$SERVICE_NAME"
+  sleep "$STARTUP_WAIT_SECONDS"
+}
+
+# Start every node except the leader, then wait FOLLOWER_START_WAIT_SECONDS.
+start_followers() {
+  local node
+  for node in $NODES_RAW; do
+    if [[ "$node" == "$LEADER" ]]; then
+      continue
+    fi
+    echo "===== $node :: start follower ====="
+    node_cmd "$node" systemctl start "$SERVICE_NAME"
+  done
+  sleep "$FOLLOWER_START_WAIT_SECONDS"
+}
+
+verify_node() {
+  run_node_utility "$1" verify --require-nonzero-root
+}
+
+# Verify NODE up to VERIFY_RETRIES times, pausing VERIFY_RETRY_SECONDS between tries.
+verify_node_with_retry() {
+  local node="$1"
+  local attempt
+  for ((attempt = 1; attempt <= VERIFY_RETRIES; attempt++)); do
+    if verify_node "$node"; then
+      return 0
+    fi
+    if ((attempt < VERIFY_RETRIES)); then
+      echo "verify failed for $node, retrying in ${VERIFY_RETRY_SECONDS}s (${attempt}/${VERIFY_RETRIES})" >&2
+      sleep "$VERIFY_RETRY_SECONDS"
+    fi
+  done
+  # Reached after the last failed attempt and also when VERIFY_RETRIES < 1:
+  # never report success without a passing verify.
+  return 1
+}
+
+case "${1:-}" in
+  preflight)
+    for_each_node preflight preflight_node
+    ;;
+  backup)
+    for_each_node backup backup_node
+    ;;
+  set-roles)
+    for_each_node set-roles set_role_node
+    ;;
+  clear-roles)
+    for_each_node clear-roles clear_role_node
+    ;;
+  stop)
+    for_each_node stop stop_node
+    ;;
+  reset)
+    for_each_node reset reset_node
+    ;;
+  start)
+    start_leader
+    start_followers
+    ;;
+  verify)
+    for_each_node verify verify_node
+    ;;
+  rollout)
+    for_each_node preflight preflight_node
+    for_each_node set-roles set_role_node
+    # Stop before backing up: the backup is a plain file copy, and copying
+    # database files that are still being written can produce an unusable copy.
+    for_each_node stop stop_node
+    for_each_node backup backup_node
+    for_each_node reset reset_node
+    start_leader
+    start_followers
+    for_each_node verify verify_node_with_retry
+    ;;
+  *)
+    usage
+    exit 64
+    ;;
+esac
diff --git a/scripts/utils/init_production_genesis.py b/scripts/utils/init_production_genesis.py index dcfc189a..629b2c4f 100644 --- a/scripts/utils/init_production_genesis.py +++ b/scripts/utils/init_production_genesis.py @@ -137,7 +137,8 @@ def main() -> None: # Ensure mempool DB exists (though not needed for genesis) mempool_path = settings.db_path.parent / "mempool.db" -
init_mempool(backend="database", db_path=str(mempool_path), max_size=10000, min_fee=0) + mempool_url = f"sqlite:///{mempool_path}" + init_mempool(backend="database", db_url=mempool_url, max_size=10000, min_fee=0) print(f"[*] Mempool initialized at {mempool_path}") # Create genesis block