fix: update init_mempool calls to use db_url instead of db_path

- Update scripts/utils/init_production_genesis.py to use db_url
- Update apps/blockchain-node/tests/test_mempool.py to use db_url
- Update apps/blockchain-node/src/aitbc_chain/p2p_network.py to use db_url
- Add MEMPOOL_DB_URL to /etc/aitbc/.env on both nodes for PostgreSQL mempool
This commit is contained in:
aitbc
2026-05-03 21:14:26 +02:00
parent f07a5de9b1
commit 4f870e9d8d
13 changed files with 764 additions and 115 deletions

View File

@@ -93,7 +93,7 @@ async def lifespan(app: FastAPI):
init_db()
init_mempool(
backend=settings.mempool_backend,
db_path=str(settings.db_path.parent / "mempool.db"),
db_url=settings.mempool_db_url,
max_size=settings.mempool_max_size,
min_fee=settings.min_fee,
)

View File

@@ -81,6 +81,7 @@ class ChainSettings(BaseSettings):
# Mempool settings
mempool_backend: str = "database" # "database" or "memory" (database recommended for persistence)
mempool_db_url: str = "postgresql+psycopg://aitbc_mempool:password@localhost:5432/aitbc_mempool" # PostgreSQL URL for mempool (can be overridden by MEMPOOL_DB_URL env var)
mempool_max_size: int = 10_000
mempool_eviction_interval: int = 60 # seconds

View File

@@ -72,9 +72,16 @@ def get_engine(chain_id: str = "") -> object:
@event.listens_for(engine, "connect")
def set_encryption_key(dbapi_connection, connection_record):
dbapi_connection.execute(f"PRAGMA key = '{key_hex}'")
dbapi_connection.execute("PRAGMA journal_mode=WAL")
dbapi_connection.execute("PRAGMA synchronous=NORMAL")
else:
# Use standard SQLite
engine = create_engine(f"sqlite:///{db_path}", echo=False)
@event.listens_for(engine, "connect")
def set_wal_mode(dbapi_connection, connection_record):
dbapi_connection.execute("PRAGMA journal_mode=WAL")
dbapi_connection.execute("PRAGMA synchronous=NORMAL")
_engines[resolved_chain_id] = engine
@@ -87,8 +94,7 @@ _engine = create_engine(f"sqlite:///{settings.db_path}", echo=False)
@event.listens_for(_engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
# WAL mode disabled due to issues on btrfs raid (CoW already disabled on data directory)
# cursor.execute("PRAGMA journal_mode=WAL")
cursor.execute("PRAGMA journal_mode=WAL")
cursor.execute("PRAGMA synchronous=NORMAL")
cursor.execute("PRAGMA cache_size=-64000")
cursor.execute("PRAGMA temp_store=MEMORY")

View File

@@ -96,12 +96,46 @@ class BlockchainNode:
self._stop_event = asyncio.Event()
self._proposers: dict[str, PoAProposer] = {}
@staticmethod
def _env_value(*names: str) -> Optional[str]:
    """Return the value of the first environment variable in ``names`` that is set.

    Names are probed in the order given. A variable that is set to an empty
    string still counts as set and is returned as ``""``. ``None`` is returned
    when none of the names is defined (or no names are passed).
    """
    lookups = (os.getenv(candidate) for candidate in names)
    return next((found for found in lookups if found is not None), None)
def _block_production_enabled(self) -> bool:
    """Decide whether this node should produce blocks.

    An environment override (first of ``AITBC_FORCE_ENABLE_BLOCK_PRODUCTION``,
    ``ENABLE_BLOCK_PRODUCTION``, ``enable_block_production`` that is set) wins
    over the settings object. The override is true only for ``1``, ``true``,
    ``yes`` or ``on`` (case-insensitive, surrounding whitespace ignored); any
    other value, including an empty string, disables production.
    """
    truthy = {"1", "true", "yes", "on"}
    forced = self._env_value("AITBC_FORCE_ENABLE_BLOCK_PRODUCTION", "ENABLE_BLOCK_PRODUCTION", "enable_block_production")
    if forced is None:
        # No override in the environment: fall back to settings, defaulting to enabled.
        return bool(getattr(settings, "enable_block_production", True))
    return forced.strip().lower() in truthy
def _supported_chains(self) -> list[str]:
    """Return the chain ids this node serves.

    Reads the comma-separated ``settings.supported_chains`` (falling back to
    ``settings.chain_id`` when that attribute is absent), trims each entry and
    drops empty ones. If nothing remains and ``settings.chain_id`` is truthy,
    that single chain id is returned instead.
    """
    raw = getattr(settings, 'supported_chains', settings.chain_id)
    trimmed = (piece.strip() for piece in raw.split(","))
    parsed = [piece for piece in trimmed if piece]
    if parsed:
        return parsed
    # Empty list: fall back to the primary chain id when one is configured.
    return [settings.chain_id] if settings.chain_id else parsed
def _proposer_config(self, chain_id: str) -> ProposerConfig:
    """Build the ``ProposerConfig`` for ``chain_id`` from the node settings."""
    options = {
        "chain_id": chain_id,
        "proposer_id": settings.proposer_id,
        "interval_seconds": settings.block_time_seconds,
        "max_block_size_bytes": settings.max_block_size_bytes,
        "max_txs_per_block": settings.max_txs_per_block,
    }
    return ProposerConfig(**options)
async def _ensure_genesis_for_chains(self) -> None:
    """Call ``_ensure_genesis_block`` on a fresh proposer for every supported chain.

    The proposers created here are used only for that call; they are not
    stored on the node and are not started.
    """
    for chain in self._supported_chains():
        config = self._proposer_config(chain)
        bootstrap = PoAProposer(config=config, session_factory=session_scope)
        await bootstrap._ensure_genesis_block()
async def _setup_gossip_subscribers(self) -> None:
logger.info("Setting up gossip subscribers")
# Parse supported chains
chains_str = getattr(settings, 'supported_chains', settings.chain_id)
chains = [c.strip() for c in chains_str.split(",") if c.strip()]
chains = self._supported_chains()
# Transactions (single topic for all chains)
try:
@@ -213,8 +247,7 @@ class BlockchainNode:
logger.info("Gossip backend initialized successfully")
# Parse supported chains
chains_str = getattr(settings, 'supported_chains', settings.chain_id)
chains = [c.strip() for c in chains_str.split(",") if c.strip()]
chains = self._supported_chains()
logger.info(f"Initializing databases for chains: {chains}")
# Initialize database for each supported chain
@@ -224,12 +257,13 @@ class BlockchainNode:
init_mempool(
backend=settings.mempool_backend,
db_path=str(settings.db_path.parent / "mempool.db"),
db_url=settings.mempool_db_url,
max_size=settings.mempool_max_size,
min_fee=settings.min_fee,
)
await self._ensure_genesis_for_chains()
# Start proposers only if enabled (followers set enable_block_production=False)
if getattr(settings, "enable_block_production", True):
if self._block_production_enabled():
self._start_proposers()
else:
logger.info("Block production disabled on this node", extra={"proposer_id": settings.proposer_id})
@@ -245,11 +279,12 @@ class BlockchainNode:
await self._shutdown()
def _start_proposers(self) -> None:
chains_str = getattr(settings, 'supported_chains', settings.chain_id)
chains = [c.strip() for c in chains_str.split(",") if c.strip()]
chains = self._supported_chains()
# Get chains that should produce blocks (if specified, otherwise all supported chains)
production_chains_str = getattr(settings, 'block_production_chains', chains_str)
production_chains_str = self._env_value("AITBC_FORCE_BLOCK_PRODUCTION_CHAINS", "BLOCK_PRODUCTION_CHAINS", "block_production_chains")
if production_chains_str is None:
production_chains_str = getattr(settings, 'block_production_chains', ",".join(chains))
production_chains = [c.strip() for c in production_chains_str.split(",") if c.strip()]
for chain_id in chains:
@@ -261,15 +296,7 @@ class BlockchainNode:
if chain_id in self._proposers:
continue
proposer_config = ProposerConfig(
chain_id=chain_id,
proposer_id=settings.proposer_id,
interval_seconds=settings.block_time_seconds,
max_block_size_bytes=settings.max_block_size_bytes,
max_txs_per_block=settings.max_txs_per_block,
)
proposer = PoAProposer(config=proposer_config, session_factory=session_scope)
proposer = PoAProposer(config=self._proposer_config(chain_id), session_factory=session_scope)
self._proposers[chain_id] = proposer
asyncio.create_task(proposer.start())

View File

@@ -7,9 +7,31 @@ from dataclasses import dataclass, field
from threading import Lock
from typing import Any, Dict, List, Optional
from sqlmodel import Session, SQLModel, create_engine, select, Field, text
from sqlalchemy import Column, String, Integer, Float, Text, Index, MetaData, Table
from .metrics import metrics_registry
mempool_metadata = MetaData()
class MempoolEntry(SQLModel, table=True):
    """ORM row of the ``mempool`` table: one pending transaction per (chain_id, tx_hash)."""

    __tablename__ = "mempool"
    # The class body previously bound ``__table_args__`` twice. Python keeps only
    # the last binding of a name in a class body, so the earlier
    # ``{"metadata": mempool_metadata}`` dict never took effect and has been
    # removed; this single binding is the one that was already in force.
    # NOTE(review): if the table is meant to be registered on ``mempool_metadata``
    # rather than the default SQLModel metadata, that needs a supported
    # mechanism - confirm the intent.
    # NOTE(review): ``postgresql_ops`` is SQLAlchemy's hook for operator classes;
    # it is used here to get a descending index on ``fee`` - confirm the DDL.
    __table_args__ = (
        Index('idx_mempool_fee', 'fee', postgresql_ops={'fee': 'DESC'}),
    )

    # Composite primary key: (chain_id, tx_hash).
    chain_id: str = Field(primary_key=True)
    tx_hash: str = Field(primary_key=True)
    # Transaction payload as text; the mempool code serialises it with JSON.
    content: str = Field(sa_column=Column(Text, nullable=False))
    fee: int = Field(default=0, sa_column=Column(Integer, nullable=False))
    size_bytes: int = Field(default=0, sa_column=Column(Integer, nullable=False))
    # Arrival time as a float; callers store ``time.time()`` here.
    received_at: float = Field(sa_column=Column(Float, nullable=False))
@dataclass(frozen=True)
class PendingTransaction:
tx_hash: str
@@ -150,32 +172,33 @@ class InMemoryMempool:
class DatabaseMempool:
"""SQLite-backed mempool for persistence and cross-service sharing."""
"""PostgreSQL-backed mempool for persistence and cross-service sharing."""
def __init__(self, db_path: str, max_size: int = 10_000, min_fee: int = 0) -> None:
import sqlite3
self._db_path = db_path
def __init__(self, db_url: str, max_size: int = 10_000, min_fee: int = 0) -> None:
self._db_url = db_url
self._max_size = max_size
self._min_fee = min_fee
self._conn = sqlite3.connect(db_path, check_same_thread=False)
self._engine = create_engine(db_url, echo=False, pool_pre_ping=True)
self._lock = Lock()
self._init_table()
def _init_table(self) -> None:
with self._lock:
self._conn.execute("""
CREATE TABLE IF NOT EXISTS mempool (
chain_id TEXT NOT NULL,
tx_hash TEXT NOT NULL,
content TEXT NOT NULL,
fee INTEGER DEFAULT 0,
size_bytes INTEGER DEFAULT 0,
received_at REAL NOT NULL,
PRIMARY KEY (chain_id, tx_hash)
)
""")
self._conn.execute("CREATE INDEX IF NOT EXISTS idx_mempool_fee ON mempool(fee DESC)")
self._conn.commit()
with Session(self._engine) as session:
# Create table manually using raw SQL to avoid chain table conflicts
session.exec(text("""
CREATE TABLE IF NOT EXISTS mempool (
chain_id TEXT NOT NULL,
tx_hash TEXT NOT NULL,
content TEXT NOT NULL,
fee INTEGER DEFAULT 0,
size_bytes INTEGER DEFAULT 0,
received_at REAL NOT NULL,
PRIMARY KEY (chain_id, tx_hash)
)
"""))
session.exec(text("CREATE INDEX IF NOT EXISTS idx_mempool_fee ON mempool(fee DESC)"))
session.commit()
def add(self, tx: Dict[str, Any], chain_id: str = None) -> str:
from .config import settings
@@ -190,27 +213,42 @@ class DatabaseMempool:
size_bytes = len(content.encode())
with self._lock:
# Check duplicate
row = self._conn.execute("SELECT 1 FROM mempool WHERE chain_id = ? AND tx_hash = ?", (chain_id, tx_hash)).fetchone()
if row:
return tx_hash
# Evict if full
count = self._conn.execute("SELECT COUNT(*) FROM mempool WHERE chain_id = ?", (chain_id,)).fetchone()[0]
if count >= self._max_size:
self._conn.execute("""
DELETE FROM mempool WHERE chain_id = ? AND tx_hash = (
SELECT tx_hash FROM mempool WHERE chain_id = ? ORDER BY fee ASC, received_at DESC LIMIT 1
with Session(self._engine) as session:
# Check duplicate
existing = session.exec(
select(MempoolEntry).where(
MempoolEntry.chain_id == chain_id,
MempoolEntry.tx_hash == tx_hash
)
""", (chain_id, chain_id))
metrics_registry.increment(f"mempool_evictions_total_{chain_id}")
).first()
if existing:
return tx_hash
self._conn.execute(
"INSERT INTO mempool (chain_id, tx_hash, content, fee, size_bytes, received_at) VALUES (?, ?, ?, ?, ?, ?)",
(chain_id, tx_hash, content, fee, size_bytes, time.time())
)
self._conn.commit()
metrics_registry.increment(f"mempool_tx_added_total_{chain_id}")
# Evict if full
count = session.exec(
select(MempoolEntry).where(MempoolEntry.chain_id == chain_id)
).count()
if count >= self._max_size:
to_evict = session.exec(
select(MempoolEntry).where(MempoolEntry.chain_id == chain_id)
.order_by(MempoolEntry.fee.asc(), MempoolEntry.received_at.desc())
.limit(1)
).first()
if to_evict:
session.delete(to_evict)
metrics_registry.increment(f"mempool_evictions_total_{chain_id}")
entry = MempoolEntry(
chain_id=chain_id,
tx_hash=tx_hash,
content=content,
fee=fee,
size_bytes=size_bytes,
received_at=time.time()
)
session.add(entry)
session.commit()
metrics_registry.increment(f"mempool_tx_added_total_{chain_id}")
self._update_gauge(chain_id)
return tx_hash
@@ -219,15 +257,16 @@ class DatabaseMempool:
if chain_id is None:
chain_id = settings.chain_id
with self._lock:
rows = self._conn.execute(
"SELECT tx_hash, content, fee, size_bytes, received_at FROM mempool WHERE chain_id = ? ORDER BY fee DESC, received_at ASC",
(chain_id,)
).fetchall()
with Session(self._engine) as session:
entries = session.exec(
select(MempoolEntry).where(MempoolEntry.chain_id == chain_id)
.order_by(MempoolEntry.fee.desc(), MempoolEntry.received_at.asc())
).all()
return [
PendingTransaction(
tx_hash=r[0], content=json.loads(r[1]),
fee=r[2], size_bytes=r[3], received_at=r[4]
) for r in rows
tx_hash=e.tx_hash, content=json.loads(e.content),
fee=e.fee, size_bytes=e.size_bytes, received_at=e.received_at
) for e in entries
]
def drain(self, max_count: int, max_bytes: int, chain_id: str = None) -> List[PendingTransaction]:
@@ -235,35 +274,41 @@ class DatabaseMempool:
if chain_id is None:
chain_id = settings.chain_id
with self._lock:
rows = self._conn.execute(
"SELECT tx_hash, content, fee, size_bytes, received_at FROM mempool WHERE chain_id = ? ORDER BY fee DESC, received_at ASC",
(chain_id,)
).fetchall()
with Session(self._engine) as session:
entries = session.exec(
select(MempoolEntry).where(MempoolEntry.chain_id == chain_id)
.order_by(MempoolEntry.fee.desc(), MempoolEntry.received_at.asc())
).all()
result: List[PendingTransaction] = []
total_bytes = 0
hashes_to_remove: List[str] = []
result: List[PendingTransaction] = []
total_bytes = 0
hashes_to_remove: List[str] = []
for r in rows:
if len(result) >= max_count:
break
if total_bytes + r[3] > max_bytes:
continue
result.append(PendingTransaction(
tx_hash=r[0], content=json.loads(r[1]),
fee=r[2], size_bytes=r[3], received_at=r[4]
))
total_bytes += r[3]
hashes_to_remove.append(r[0])
for e in entries:
if len(result) >= max_count:
break
if total_bytes + e.size_bytes > max_bytes:
continue
result.append(PendingTransaction(
tx_hash=e.tx_hash, content=json.loads(e.content),
fee=e.fee, size_bytes=e.size_bytes, received_at=e.received_at
))
total_bytes += e.size_bytes
hashes_to_remove.append(e.tx_hash)
if hashes_to_remove:
# Use parameterized query to avoid SQL injection
placeholders = ",".join(["?"] * len(hashes_to_remove))
query = f"DELETE FROM mempool WHERE chain_id = ? AND tx_hash IN ({placeholders})"
self._conn.execute(query, [chain_id] + hashes_to_remove)
self._conn.commit()
if hashes_to_remove:
for hash_to_remove in hashes_to_remove:
entry = session.exec(
select(MempoolEntry).where(
MempoolEntry.chain_id == chain_id,
MempoolEntry.tx_hash == hash_to_remove
)
).first()
if entry:
session.delete(entry)
session.commit()
metrics_registry.increment(f"mempool_tx_drained_total_{chain_id}", float(len(result)))
metrics_registry.increment(f"mempool_tx_drained_total_{chain_id}", float(len(result)))
self._update_gauge(chain_id)
return result
@@ -272,9 +317,19 @@ class DatabaseMempool:
if chain_id is None:
chain_id = settings.chain_id
with self._lock:
cursor = self._conn.execute("DELETE FROM mempool WHERE chain_id = ? AND tx_hash = ?", (chain_id, tx_hash))
self._conn.commit()
removed = cursor.rowcount > 0
with Session(self._engine) as session:
entry = session.exec(
select(MempoolEntry).where(
MempoolEntry.chain_id == chain_id,
MempoolEntry.tx_hash == tx_hash
)
).first()
if entry:
session.delete(entry)
session.commit()
removed = True
else:
removed = False
if removed:
self._update_gauge(chain_id)
return removed
@@ -284,7 +339,10 @@ class DatabaseMempool:
if chain_id is None:
chain_id = settings.chain_id
with self._lock:
return self._conn.execute("SELECT COUNT(*) FROM mempool WHERE chain_id = ?", (chain_id,)).fetchone()[0]
with Session(self._engine) as session:
return session.exec(
select(MempoolEntry).where(MempoolEntry.chain_id == chain_id)
).count()
def get_pending_transactions(self, chain_id: str = None, limit: int = 100) -> List[Dict[str, Any]]:
"""Get pending transactions for RPC endpoint"""
@@ -293,18 +351,20 @@ class DatabaseMempool:
chain_id = settings.chain_id
with self._lock:
rows = self._conn.execute(
"SELECT content FROM mempool WHERE chain_id = ? ORDER BY fee DESC, received_at ASC LIMIT ?",
(chain_id, limit)
).fetchall()
with Session(self._engine) as session:
entries = session.exec(
select(MempoolEntry).where(MempoolEntry.chain_id == chain_id)
.order_by(MempoolEntry.fee.desc(), MempoolEntry.received_at.asc())
.limit(limit)
).all()
return [json.loads(row[0]) for row in rows]
return [json.loads(e.content) for e in entries]
def _update_gauge(self, chain_id: str = None) -> None:
from .config import settings
if chain_id is None:
chain_id = settings.chain_id
count = self._conn.execute("SELECT COUNT(*) FROM mempool WHERE chain_id = ?", (chain_id,)).fetchone()[0]
count = self.size(chain_id)
metrics_registry.set_gauge(f"mempool_size_{chain_id}", float(count))
@@ -312,10 +372,10 @@ class DatabaseMempool:
_MEMPOOL: Optional[InMemoryMempool | DatabaseMempool] = None
def init_mempool(backend: str = "memory", db_path: str = "", max_size: int = 10_000, min_fee: int = 0) -> None:
def init_mempool(backend: str = "memory", db_url: str = "", max_size: int = 10_000, min_fee: int = 0) -> None:
global _MEMPOOL
if backend == "database" and db_path:
_MEMPOOL = DatabaseMempool(db_path, max_size=max_size, min_fee=min_fee)
if backend == "database" and db_url:
_MEMPOOL = DatabaseMempool(db_url, max_size=max_size, min_fee=min_fee)
else:
_MEMPOOL = InMemoryMempool(max_size=max_size, min_fee=min_fee)

View File

@@ -720,14 +720,14 @@ def main():
try:
from .mempool import init_mempool
import pathlib
db_path = ""
db_url = ""
if settings.mempool_backend == "database":
db_path = str(settings.db_path.parent / "mempool.db")
db_url = settings.mempool_db_url
init_mempool(
backend=settings.mempool_backend,
db_path=db_path,
db_url=db_url,
max_size=settings.mempool_max_size,
min_fee=settings.min_fee
)

View File

@@ -414,6 +414,7 @@ class ChainSync:
def _append_block(self, session: Session, block_data: Dict[str, Any],
transactions: Optional[List[Dict[str, Any]]] = None) -> ImportResult:
"""Append a block to the chain tip."""
block_hash = block_data["hash"]
timestamp_str = block_data.get("timestamp", "")
try:
timestamp = datetime.fromisoformat(timestamp_str) if timestamp_str else datetime.now(datetime.UTC)
@@ -486,11 +487,9 @@ class ChainSync:
)
session.add(tx)
session.commit()
# Verify state root if provided
if block_data.get("state_root"):
from aitbc_chain.config import settings
session.flush()
state_manager = StateManager()
accounts = session.exec(
select(Account).where(Account.chain_id == self._chain_id)
@@ -498,11 +497,33 @@ class ChainSync:
account_dict = {acc.address: acc for acc in accounts}
computed_root = state_manager.compute_state_root(account_dict)
expected_root = bytes.fromhex(block_data.get("state_root").replace("0x", ""))
try:
expected_root = bytes.fromhex(str(block_data.get("state_root")).replace("0x", ""))
except ValueError:
expected_root = None
if computed_root != expected_root:
if expected_root is None or len(expected_root) != 32:
if settings.enforce_state_root_validation:
metrics_registry.increment("sync_state_root_rejected_total")
session.rollback()
logger.error(
f"[SYNC] Invalid state root at height {block_data['height']}: "
f"{block_data.get('state_root')} - BLOCK REJECTED"
)
return ImportResult(
accepted=False,
height=block_data["height"],
block_hash=block_hash,
reason=f"Invalid state root: {block_data.get('state_root')}"
)
logger.warning(
f"[SYNC] Invalid state root at height {block_data['height']}: "
f"{block_data.get('state_root')}"
)
elif computed_root != expected_root:
if settings.enforce_state_root_validation:
metrics_registry.increment("sync_state_root_rejected_total")
session.rollback()
logger.error(
f"[SYNC] State root mismatch at height {block_data['height']}: "
f"expected {expected_root.hex()}, computed {computed_root.hex()} - BLOCK REJECTED"
@@ -519,6 +540,8 @@ class ChainSync:
f"expected {expected_root.hex()}, computed {computed_root.hex()}"
)
session.commit()
metrics_registry.increment("sync_blocks_accepted_total")
metrics_registry.set_gauge("sync_chain_height", float(block_data["height"]))
logger.info("Imported block", extra={

View File

@@ -250,6 +250,7 @@ class TestInitMempool:
def test_init_database(self, tmp_path):
db_path = str(tmp_path / "init.db")
init_mempool(backend="database", db_path=db_path, max_size=50, min_fee=0)
db_url = f"sqlite:///{db_path}"
init_mempool(backend="database", db_url=db_url, max_size=50, min_fee=0)
pool = get_mempool()
assert isinstance(pool, DatabaseMempool)

View File

@@ -11,6 +11,7 @@ from unittest.mock import AsyncMock, Mock
from sqlmodel import Session, SQLModel, create_engine, select
from aitbc_chain.models import Block, Transaction
from aitbc_chain.sync import settings as sync_settings
from aitbc_chain.metrics import metrics_registry
from aitbc_chain.sync import ChainSync, ProposerSignatureValidator, ImportResult
@@ -286,6 +287,29 @@ class TestChainSyncBulkImport:
stored_txs = session.exec(select(Transaction).where(Transaction.block_height == 1)).all()
assert len(stored_txs) == 2
def test_enforced_state_root_mismatch_rolls_back_block(self, session_factory, monkeypatch):
    """With enforcement on, a block with a wrong state root is rejected and not stored."""
    monkeypatch.setattr(sync_settings, "enforce_state_root_validation", True)
    sync = ChainSync(session_factory, chain_id="test", validate_signatures=False)

    # Seed a one-block chain and build a child block on top of it.
    parent = _seed_chain(session_factory, count=1, chain_id="test")[-1]
    block_time = datetime(2026, 1, 1, 0, 0, 1)
    candidate = {
        "height": 1,
        "hash": _make_block_hash("test", 1, parent["hash"], block_time),
        "parent_hash": parent["hash"],
        "proposer": "node-a",
        "timestamp": block_time.isoformat(),
        # Well-formed 32-byte root that the test expects not to match.
        "state_root": "0x" + "11" * 32,
    }

    result = sync.import_block(candidate)

    assert result.accepted is False
    assert "State root mismatch" in result.reason

    # The rejected block must not be visible from a fresh session.
    with session_factory() as session:
        lookup = select(Block).where(Block.chain_id == "test", Block.height == 1)
        assert session.exec(lookup).first() is None
class TestChainSyncSignatureValidation:

345
scripts/utils/chain_regen_node.py Executable file
View File

@@ -0,0 +1,345 @@
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import glob
import hashlib
import json
import os
import shutil
import socket
import sqlite3
import subprocess
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
REPO_DIR = Path("/opt/aitbc")
BLOCKCHAIN_SRC = REPO_DIR / "apps" / "blockchain-node" / "src"
if str(REPO_DIR) not in sys.path:
sys.path.insert(0, str(REPO_DIR))
if str(BLOCKCHAIN_SRC) not in sys.path:
sys.path.insert(0, str(BLOCKCHAIN_SRC))
from sqlmodel import Session, create_engine, select
from sqlalchemy.exc import SQLAlchemyError
from aitbc_chain.config import ChainSettings
from aitbc_chain.models import Account, Block, Transaction
from aitbc_chain.state.merkle_patricia_trie import StateManager
SERVICE_NAME = "aitbc-blockchain-node.service"
DATA_ROOT = Path("/var/lib/aitbc/data")
BACKUP_ROOT = Path("/var/lib/aitbc/backups/mpt-regeneration")
ENV_FILES = [Path("/etc/aitbc/.env"), Path("/etc/aitbc/node.env")]
def _run(command: list[str], check: bool = False) -> subprocess.CompletedProcess[str]:
    """Run ``command`` and return the completed process with captured text output.

    stdout and stderr are captured and decoded as text. When ``check`` is
    true, a non-zero exit status raises ``subprocess.CalledProcessError``.
    """
    completed = subprocess.run(command, check=check, capture_output=True, text=True)
    return completed
def _service_state(service_name: str) -> dict[str, Any]:
    """Collect systemd facts about ``service_name``.

    Returns a dict with the stripped output of ``systemctl is-active`` and
    ``is-enabled``, the unit's ``FragmentPath``, and its ``DropInPaths`` split
    on whitespace. Exit codes are not checked; only stdout is read.
    """

    def systemctl(*arguments: str) -> str:
        return _run(["systemctl", *arguments]).stdout.strip()

    active = systemctl("is-active", service_name)
    enabled = systemctl("is-enabled", service_name)
    fragment = systemctl("show", service_name, "-p", "FragmentPath", "--value")
    dropins = systemctl("show", service_name, "-p", "DropInPaths", "--value")
    return {
        "active": active,
        "enabled": enabled,
        "fragment_path": fragment,
        # str.split() with no separator never yields empty strings.
        "drop_in_paths": dropins.split(),
    }
def _git_revision() -> str | None:
    """Return the ``HEAD`` commit hash of the repository at ``REPO_DIR``.

    Returns ``None`` when ``git rev-parse`` exits with a non-zero status.
    """
    outcome = _run(["git", "-C", str(REPO_DIR), "rev-parse", "HEAD"])
    return outcome.stdout.strip() if outcome.returncode == 0 else None
def _sha256_file(path: Path) -> str | None:
    """Return the hex SHA-256 of the file at ``path``, or ``None`` if it is not a regular file.

    The file is read in 1 MiB chunks so large databases are not loaded whole.
    """
    # is_file() is already False for a path that does not exist.
    if not path.is_file():
        return None
    hasher = hashlib.sha256()
    with path.open("rb") as stream:
        while chunk := stream.read(1024 * 1024):
            hasher.update(chunk)
    return hasher.hexdigest()
def _load_genesis(path: Path) -> dict[str, Any]:
    """Load the genesis JSON document at ``path``.

    Returns an empty dict when the file does not exist. The file is decoded
    as UTF-8 explicitly so the result does not depend on the process locale.

    Raises:
        json.JSONDecodeError: if the file exists but is not valid JSON.
    """
    if not path.exists():
        return {}
    with path.open(encoding="utf-8") as handle:
        return json.load(handle)
def _allocation_digest(genesis: dict[str, Any]) -> str | None:
    """Return a hex SHA-256 over the genesis allocations, independent of their order.

    Allocations are sorted by their ``address`` entry (missing address sorts
    as ``""``) and serialised as compact JSON with sorted keys before hashing.
    Returns ``None`` when the genesis document has no ``allocations`` key.
    Each allocation must support ``.get`` because the sort key calls it.
    """
    allocations = genesis.get("allocations")
    if allocations is None:
        return None

    def by_address(entry: Any) -> Any:
        return entry.get("address", "")

    ordered = sorted(allocations, key=by_address)
    canonical = json.dumps(ordered, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(canonical.encode()).hexdigest()
def _live_db_files(chain_id: str) -> list[Path]:
    """List the existing chain and mempool SQLite files for ``chain_id``.

    Candidates are ``chain.db`` in the chain directory and ``mempool.db`` in
    both the data root and the chain directory, each with its ``-wal``,
    ``-shm`` and ``-journal`` companions. Only paths that exist are returned,
    in that candidate order.
    """
    chain_dir = DATA_ROOT / chain_id
    databases = (chain_dir / "chain.db", DATA_ROOT / "mempool.db", chain_dir / "mempool.db")
    companions = ("", "-wal", "-shm", "-journal")
    candidates = [
        database.with_name(database.name + suffix)
        for database in databases
        for suffix in companions
    ]
    return [candidate for candidate in candidates if candidate.exists()]
def _backup_sources(chain_id: str) -> list[Path]:
    """Return the files to back up for ``chain_id``, de-duplicated and sorted by path.

    Covers ``chain.db*``, ``genesis.json`` and ``mempool.db*`` in the chain
    directory, ``mempool.db*`` in the data root, and every file in
    ``ENV_FILES`` that exists.
    """
    chain_dir = DATA_ROOT / chain_id
    patterns = (
        chain_dir / "chain.db*",
        chain_dir / "genesis.json",
        DATA_ROOT / "mempool.db*",
        chain_dir / "mempool.db*",
    )
    found = [Path(match) for pattern in patterns for match in glob.glob(str(pattern))]
    found += [env_file for env_file in ENV_FILES if env_file.exists()]
    # Keyed by path string so overlapping patterns collapse to one entry.
    by_path = {str(candidate): candidate for candidate in found if candidate.exists()}
    return [by_path[key] for key in sorted(by_path)]
def _integrity_check(path: Path) -> str | None:
    """Run ``PRAGMA integrity_check`` on the SQLite file at ``path``, read-only.

    Returns:
        ``None`` when the file is missing or empty, the first result row of
        the pragma otherwise (``"ok"`` for a healthy database), or a string
        of the form ``"error: ..."`` when SQLite raises.
    """
    if not path.exists() or path.stat().st_size == 0:
        return None
    # Build the URI with Path.as_uri() so characters such as "?", "#" or "%" in
    # the file name are percent-encoded instead of being parsed as URI syntax.
    # as_uri() needs an absolute path, hence resolve().
    uri = f"{path.resolve().as_uri()}?mode=ro"
    try:
        conn = sqlite3.connect(uri, uri=True)
        try:
            row = conn.execute("PRAGMA integrity_check").fetchone()
            return row[0] if row else None
        finally:
            conn.close()
    except sqlite3.Error as exc:
        return f"error: {exc}"
def _db_snapshot(chain_id: str, db_path: Path) -> dict[str, Any]:
    """Summarise the chain database at ``db_path`` for ``chain_id``.

    Returns a dict with the head block, block/transaction/account counts, a
    digest over the account rows, the state root computed from the accounts,
    and whether the head block's stored state root equals the computed one.
    Every key is always present; on failure the values stay ``None`` and
    ``db_error`` carries the reason (missing/empty file, or the SQLAlchemy
    error text).
    """
    # Local import so this block does not depend on a change to the module's
    # import list; sqlalchemy is already a dependency of this file.
    from sqlalchemy import func

    data: dict[str, Any] = {
        "db_error": None,
        "head": None,
        "block_count": None,
        "transaction_count": None,
        "account_count": None,
        # Present from the start so error results have the same keys as successes.
        "account_digest": None,
        "computed_state_root": None,
        "head_state_root_matches_computed": None,
    }
    if not db_path.exists() or db_path.stat().st_size == 0:
        data["db_error"] = "database missing or empty"
        return data
    engine = create_engine(f"sqlite:///{db_path}", echo=False)
    try:
        with Session(engine) as session:
            head = session.exec(
                select(Block).where(Block.chain_id == chain_id).order_by(Block.height.desc()).limit(1)
            ).first()
            # Count in SQL rather than loading every block/transaction row just
            # to take len() of the result.
            block_count = session.exec(
                select(func.count()).select_from(Block).where(Block.chain_id == chain_id)
            ).one()
            transaction_count = session.exec(
                select(func.count()).select_from(Transaction).where(Transaction.chain_id == chain_id)
            ).one()
            # Accounts are needed in full: they feed both the digest and the state root.
            accounts = session.exec(
                select(Account).where(Account.chain_id == chain_id).order_by(Account.address)
            ).all()
            account_dict = {account.address: account for account in accounts}
            computed_root = StateManager().compute_state_root(account_dict)
            computed_root_hex = "0x" + computed_root.hex()
            data.update(
                {
                    "head": None if head is None else {
                        "height": head.height,
                        "hash": head.hash,
                        "state_root": head.state_root,
                        "proposer": head.proposer,
                        "timestamp": head.timestamp.isoformat() if head.timestamp else None,
                    },
                    "block_count": block_count,
                    "transaction_count": transaction_count,
                    "account_count": len(accounts),
                    "account_digest": hashlib.sha256(json.dumps(
                        [{"address": account.address, "balance": account.balance, "nonce": account.nonce} for account in accounts],
                        sort_keys=True,
                        separators=(",", ":"),
                    ).encode()).hexdigest(),
                    "computed_state_root": computed_root_hex,
                    "head_state_root_matches_computed": bool(head and head.state_root == computed_root_hex),
                }
            )
    except SQLAlchemyError as exc:
        data["db_error"] = str(exc)
    finally:
        engine.dispose()
    return data
def snapshot(chain_id: str, service_name: str) -> dict[str, Any]:
    """Build a preflight report for ``chain_id`` on this host.

    Combines host/service/repository facts, database file facts (path, size,
    integrity check), genesis file facts (hashes, allocation count) and the
    database summary from ``_db_snapshot`` into a single dict.
    """
    db_path = ChainSettings().get_db_path(chain_id)
    genesis_path = DATA_ROOT / chain_id / "genesis.json"
    genesis = _load_genesis(genesis_path)
    db_present = db_path.exists()

    report: dict[str, Any] = {
        "host": socket.gethostname(),
        "chain_id": chain_id,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "repo_revision": _git_revision(),
        "service": _service_state(service_name),
        "db_path": str(db_path),
        "db_exists": db_present,
        "db_size_bytes": db_path.stat().st_size if db_present else None,
        "db_integrity": _integrity_check(db_path),
        "genesis_path": str(genesis_path),
        "genesis_exists": genesis_path.exists(),
        "genesis_file_sha256": _sha256_file(genesis_path),
        "genesis_allocation_digest": _allocation_digest(genesis),
        "genesis_allocation_count": len(genesis.get("allocations", [])) if genesis else 0,
        "live_db_files": [str(entry) for entry in _live_db_files(chain_id)],
    }
    report.update(_db_snapshot(chain_id, db_path))
    return report
def backup(chain_id: str, service_name: str, backup_root: Path, timestamp: str | None) -> dict[str, Any]:
    """Copy the chain's database, genesis and env files into a timestamped backup tree.

    Files land under ``backup_root / <timestamp> / <hostname> / <chain_id>``,
    each at its original absolute path relative to the filesystem root. A
    ``manifest.json`` with the preflight snapshot and one record per copied
    file is written into the same directory and also returned.

    Args:
        chain_id: chain whose files are backed up.
        service_name: systemd unit inspected for the preflight snapshot.
        backup_root: directory under which the backup tree is created.
        timestamp: directory name to use; defaults to the current UTC time
            formatted as ``%Y%m%dT%H%M%SZ``.
    """
    stamp = timestamp or datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    target = backup_root / stamp / socket.gethostname() / chain_id
    target.mkdir(parents=True, exist_ok=True)
    manifest: dict[str, Any] = {
        "host": socket.gethostname(),
        "chain_id": chain_id,
        "timestamp": stamp,
        "target": str(target),
        "preflight": snapshot(chain_id, service_name),
        "files": [],
    }
    # WAL/SHM/journal companions are not SQLite databases themselves; running
    # integrity_check on them only records a misleading "not a database" error.
    sidecar_suffixes = ("-wal", "-shm", "-journal")
    for source in _backup_sources(chain_id):
        # Strip the root so the absolute source path is recreated under target.
        rel = source.relative_to(source.anchor)
        destination = target / rel
        destination.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(source, destination)
        is_database = (
            source.name.startswith(("chain.db", "mempool.db"))
            and not source.name.endswith(sidecar_suffixes)
        )
        manifest["files"].append(
            {
                "source": str(source),
                "backup": str(destination),
                "size_bytes": source.stat().st_size,
                "sha256": _sha256_file(source),
                "integrity_check": _integrity_check(destination) if is_database else None,
            }
        )
    manifest_path = target / "manifest.json"
    manifest_path.write_text(json.dumps(manifest, indent=2, sort_keys=True, default=str))
    return manifest
def reset(chain_id: str, service_name: str, yes: bool, force: bool) -> dict[str, Any]:
    """Delete the live chain and mempool database files for ``chain_id``.

    Exits with ``SystemExit`` unless ``yes`` is true, and also when the
    service reports ``active`` and ``force`` is false. Returns the host name,
    the chain id and the list of removed paths.
    """
    if not yes:
        raise SystemExit("reset requires --yes")
    running = _service_state(service_name)["active"] == "active"
    if running and not force:
        raise SystemExit(f"{service_name} is active; stop it first or pass --force")

    deleted: list[str] = []
    for db_file in _live_db_files(chain_id):
        db_file.unlink()
        deleted.append(str(db_file))
    return {"host": socket.gethostname(), "chain_id": chain_id, "removed": deleted}
def set_role(chain_id: str, service_name: str, role: str) -> dict[str, Any]:
    """Write or remove the systemd drop-in that pins this node's block-production role.

    The drop-in lives at
    ``/etc/systemd/system/<service_name>.d/mpt-regeneration.conf``.

    * ``leader``: enables block production and restricts it to ``chain_id``.
    * ``follower``: disables block production and sets empty production chains.
    * ``clear``: removes the drop-in if it exists.

    ``systemctl daemon-reload`` is run after every change.

    Raises:
        SystemExit: if ``role`` is not one of the three values above. The role
            is checked before anything on disk is touched, so a bad role no
            longer leaves an empty drop-in directory behind.
        subprocess.CalledProcessError: if ``daemon-reload`` fails.
    """
    if role not in ("leader", "follower", "clear"):
        raise SystemExit(f"unsupported role: {role}")

    dropin_dir = Path("/etc/systemd/system") / f"{service_name}.d"
    dropin_path = dropin_dir / "mpt-regeneration.conf"
    if role == "clear":
        if dropin_path.exists():
            dropin_path.unlink()
        _run(["systemctl", "daemon-reload"], check=True)
        return {"host": socket.gethostname(), "role": role, "dropin": str(dropin_path), "exists": dropin_path.exists()}

    dropin_dir.mkdir(parents=True, exist_ok=True)
    # Leader and follower drop-ins differ only in these two values. Each
    # setting is written under three names, matching the names the node's
    # environment lookup probes.
    enabled, chains = ("true", chain_id) if role == "leader" else ("false", "")
    content = (
        "[Service]\n"
        f'Environment="AITBC_FORCE_ENABLE_BLOCK_PRODUCTION={enabled}"\n'
        f'Environment="AITBC_FORCE_BLOCK_PRODUCTION_CHAINS={chains}"\n'
        f'Environment="ENABLE_BLOCK_PRODUCTION={enabled}"\n'
        f'Environment="enable_block_production={enabled}"\n'
        f'Environment="BLOCK_PRODUCTION_CHAINS={chains}"\n'
        f'Environment="block_production_chains={chains}"\n'
    )
    dropin_path.write_text(content)
    _run(["systemctl", "daemon-reload"], check=True)
    return {"host": socket.gethostname(), "role": role, "dropin": str(dropin_path), "exists": dropin_path.exists()}
def verify(chain_id: str, service_name: str, require_nonzero_root: bool) -> dict[str, Any]:
    """Take a snapshot and add an ``ok`` verdict to it.

    ``ok`` is true when the database was read without error, a head block
    exists and its stored state root equals the computed one. With
    ``require_nonzero_root`` an all-zero computed root also fails the check.
    """
    report = snapshot(chain_id, service_name)
    healthy = (
        report.get("db_error") is None
        and report.get("head") is not None
        and report.get("head_state_root_matches_computed") is True
    )
    zero_root = "0x" + ("00" * 32)
    if require_nonzero_root and report.get("computed_state_root") == zero_root:
        healthy = False
    report["ok"] = healthy
    return report
def print_json(data: Any) -> None:
    """Print ``data`` to stdout as indented JSON with sorted keys.

    Values JSON cannot encode natively are converted with ``str``.
    """
    rendered = json.dumps(data, indent=2, sort_keys=True, default=str)
    print(rendered)
def main() -> int:
    """Parse the command line and run the selected node sub-command.

    Every sub-command prints its result as JSON. ``verify`` returns exit code 2
    when the result is not ``ok``; all other paths return 0.
    """
    parser = argparse.ArgumentParser(description="AITBC coordinated MPT chain regeneration node utility")
    parser.add_argument("--chain-id", default="ait-mainnet")
    parser.add_argument("--service-name", default=SERVICE_NAME)
    commands = parser.add_subparsers(dest="command", required=True)

    commands.add_parser("preflight")

    backup_cmd = commands.add_parser("backup")
    backup_cmd.add_argument("--backup-root", type=Path, default=BACKUP_ROOT)
    backup_cmd.add_argument("--timestamp")

    reset_cmd = commands.add_parser("reset")
    reset_cmd.add_argument("--yes", action="store_true")
    reset_cmd.add_argument("--force", action="store_true")

    verify_cmd = commands.add_parser("verify")
    verify_cmd.add_argument("--require-nonzero-root", action="store_true")

    role_cmd = commands.add_parser("set-role")
    role_cmd.add_argument("role", choices=["leader", "follower", "clear"])

    opts = parser.parse_args()
    command = opts.command

    # verify is the only sub-command whose result drives the exit code.
    if command == "verify":
        outcome = verify(opts.chain_id, opts.service_name, opts.require_nonzero_root)
        print_json(outcome)
        return 0 if outcome.get("ok") else 2

    if command == "preflight":
        payload = snapshot(opts.chain_id, opts.service_name)
    elif command == "backup":
        payload = backup(opts.chain_id, opts.service_name, opts.backup_root, opts.timestamp)
    elif command == "reset":
        payload = reset(opts.chain_id, opts.service_name, opts.yes, opts.force)
    elif command == "set-role":
        payload = set_role(opts.chain_id, opts.service_name, opts.role)
    else:
        # argparse rejects unknown commands, so nothing is printed here.
        return 0

    print_json(payload)
    return 0
# Script entry point: propagate main()'s return value as the process exit status.
if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)

View File

@@ -0,0 +1,161 @@
#!/usr/bin/env bash
# Orchestrates the per-node chain regeneration utility across a set of nodes
# (preflight, backup, role assignment, stop, reset, start, verify).
# Every setting below can be overridden from the environment.
set -euo pipefail
# Chain identifier passed to the node utility as --chain-id.
CHAIN_ID="${CHAIN_ID:-ait-mainnet}"
# systemd unit that is stopped/started and passed as --service-name.
SERVICE_NAME="${SERVICE_NAME:-aitbc-blockchain-node.service}"
# Node that receives the "leader" role and is started before the others.
LEADER="${LEADER:-aitbc1}"
# Whitespace-separated node list, taken from the NODES environment variable.
NODES_RAW="${NODES:-localhost aitbc1 gitea-runner}"
# Path of the node utility and the interpreter used to run it on each node.
UTILITY="${UTILITY:-/opt/aitbc/scripts/utils/chain_regen_node.py}"
PYTHON_BIN="${PYTHON_BIN:-/opt/aitbc/venv/bin/python}"
# UTC timestamp passed to "backup --timestamp"; computed once so all nodes share it.
TIMESTAMP="${TIMESTAMP:-$(date -u +%Y%m%dT%H%M%SZ)}"
# Seconds to sleep after starting the leader / after starting all followers.
STARTUP_WAIT_SECONDS="${STARTUP_WAIT_SECONDS:-8}"
FOLLOWER_START_WAIT_SECONDS="${FOLLOWER_START_WAIT_SECONDS:-8}"
# Number of verify attempts per node during rollout, and the delay between them.
VERIFY_RETRIES="${VERIFY_RETRIES:-6}"
VERIFY_RETRY_SECONDS="${VERIFY_RETRY_SECONDS:-10}"
# Print the command synopsis followed by the effective configuration values.
usage() {
  printf '%s\n' \
    "Usage: $0 <preflight|backup|set-roles|stop|reset|start|verify|rollout|clear-roles>" \
    "Environment:" \
    "CHAIN_ID=${CHAIN_ID}" \
    "SERVICE_NAME=${SERVICE_NAME}" \
    "LEADER=${LEADER}" \
    "NODES=\"${NODES_RAW}\"" \
    "TIMESTAMP=${TIMESTAMP}"
}
# Run a command on a node: directly when the node is this machine
# ("localhost", "local" or the local hostname), otherwise over ssh.
#   $1    node name
#   $2..  command and its arguments
node_cmd() {
  local node="$1"
  shift
  if [[ "$node" == "localhost" || "$node" == "local" || "$node" == "$(hostname)" ]]; then
    "$@"
  else
    # ssh joins its command arguments into one string that the remote shell
    # re-parses. Quote every argument so values containing spaces or shell
    # metacharacters reach the remote command unchanged instead of being
    # re-split or interpreted.
    local remote_cmd=""
    if (( $# > 0 )); then
      printf -v remote_cmd '%q ' "$@"
    fi
    ssh "$node" "$remote_cmd"
  fi
}
# Invoke the node utility on a node with the shared chain/service options.
#   $1    node name
#   $2..  utility sub-command and its arguments
run_node_utility() {
  local target="$1"
  shift
  local -a utility_cmd=(
    "$PYTHON_BIN" "$UTILITY"
    --chain-id "$CHAIN_ID"
    --service-name "$SERVICE_NAME"
  )
  node_cmd "$target" "${utility_cmd[@]}" "$@"
}
# Call a per-node function once for every configured node, printing a banner first.
#   $1    label shown in the banner
#   $2..  function (plus leading arguments); the node name is appended last
for_each_node() {
  local label="$1"
  shift
  local host
  # NODES_RAW is intentionally unquoted: it is a whitespace-separated list.
  for host in $NODES_RAW; do
    printf '===== %s :: %s =====\n' "$host" "$label"
    "$@" "$host"
  done
}
# Run the utility's "preflight" sub-command on node $1.
preflight_node() {
  local node="$1"
  run_node_utility "$node" preflight
}
# Run the utility's "backup" sub-command on node $1 using the shared TIMESTAMP.
backup_node() {
  local node="$1"
  run_node_utility "$node" backup --timestamp "$TIMESTAMP"
}
# Assign node $1 the "leader" role when it is $LEADER, otherwise "follower".
set_role_node() {
  local node="$1"
  local role="follower"
  if [[ "$node" == "$LEADER" ]]; then
    role="leader"
  fi
  run_node_utility "$node" set-role "$role"
}
# Run "set-role clear" on node $1.
clear_role_node() {
  local node="$1"
  run_node_utility "$node" set-role clear
}
# Stop the configured systemd service on node $1.
stop_node() {
  local node="$1"
  node_cmd "$node" systemctl stop "$SERVICE_NAME"
}
# Run the utility's "reset" sub-command on node $1, passing --yes.
reset_node() {
  local node="$1"
  run_node_utility "$node" reset --yes
}
# Start the service on the leader node, then wait STARTUP_WAIT_SECONDS.
start_leader() {
  printf '===== %s :: start leader =====\n' "$LEADER"
  node_cmd "$LEADER" systemctl start "$SERVICE_NAME"
  sleep "$STARTUP_WAIT_SECONDS"
}
# Start the service on every node except the leader, then wait
# FOLLOWER_START_WAIT_SECONDS once after all of them were started.
start_followers() {
  local host
  # NODES_RAW is intentionally unquoted: it is a whitespace-separated list.
  for host in $NODES_RAW; do
    if [[ "$host" != "$LEADER" ]]; then
      printf '===== %s :: start follower =====\n' "$host"
      node_cmd "$host" systemctl start "$SERVICE_NAME"
    fi
  done
  sleep "$FOLLOWER_START_WAIT_SECONDS"
}
# Run the utility's "verify" sub-command on node $1, passing --require-nonzero-root.
verify_node() {
  local node="$1"
  run_node_utility "$node" verify --require-nonzero-root
}
# Verify node $1, retrying up to VERIFY_RETRIES times with a pause of
# VERIFY_RETRY_SECONDS between attempts.
# Returns 0 as soon as one attempt succeeds, 1 when all attempts fail or
# VERIFY_RETRIES is not a positive integer.
verify_node_with_retry() {
  local node="$1"
  local attempt
  # A zero or non-numeric count would otherwise skip the loop entirely and
  # report success without ever running a verification.
  if ! [[ "$VERIFY_RETRIES" =~ ^[1-9][0-9]*$ ]]; then
    echo "VERIFY_RETRIES must be a positive integer, got '${VERIFY_RETRIES}'" >&2
    return 1
  fi
  for (( attempt = 1; attempt <= VERIFY_RETRIES; attempt++ )); do
    if verify_node "$node"; then
      return 0
    fi
    # No point sleeping after the final attempt.
    if (( attempt == VERIFY_RETRIES )); then
      break
    fi
    echo "verify failed for $node, retrying in ${VERIFY_RETRY_SECONDS}s (${attempt}/${VERIFY_RETRIES})" >&2
    sleep "$VERIFY_RETRY_SECONDS"
  done
  # Fail by default: success is only ever reported from inside the loop.
  return 1
}
# Command dispatch: the first positional argument selects the action.
case "${1:-}" in
  preflight)
    for_each_node preflight preflight_node
    ;;
  backup)
    for_each_node backup backup_node
    ;;
  set-roles)
    for_each_node set-roles set_role_node
    ;;
  clear-roles)
    for_each_node clear-roles clear_role_node
    ;;
  stop)
    for_each_node stop stop_node
    ;;
  reset)
    for_each_node reset reset_node
    ;;
  start)
    # The leader is started first and given time to come up before followers.
    start_leader
    start_followers
    ;;
  verify)
    for_each_node verify verify_node
    ;;
  rollout)
    # Full sequence; "set -e" aborts the rollout at the first failing step.
    for_each_node preflight preflight_node
    for_each_node set-roles set_role_node
    for_each_node backup backup_node
    for_each_node stop stop_node
    for_each_node reset reset_node
    start_leader
    start_followers
    for_each_node verify verify_node_with_retry
    ;;
  -h|--help|help)
    # Explicit help request: usage on stdout, success exit status.
    usage
    ;;
  *)
    # Missing or unknown command: usage is a diagnostic, so it goes to stderr.
    usage >&2
    exit 64
    ;;
esac

View File

@@ -137,7 +137,8 @@ def main() -> None:
# Ensure mempool DB exists (though not needed for genesis)
mempool_path = settings.db_path.parent / "mempool.db"
init_mempool(backend="database", db_path=str(mempool_path), max_size=10000, min_fee=0)
mempool_url = f"sqlite:///{mempool_path}"
init_mempool(backend="database", db_url=mempool_url, max_size=10000, min_fee=0)
print(f"[*] Mempool initialized at {mempool_path}")
# Create genesis block