fix: update init_mempool calls to use db_url instead of db_path
- Update scripts/utils/init_production_genesis.py to use db_url - Update apps/blockchain-node/tests/test_mempool.py to use db_url - Update apps/blockchain-node/src/aitbc_chain/p2p_network.py to use db_url - Add MEMPOOL_DB_URL to /etc/aitbc/.env on both nodes for PostgreSQL mempool
This commit is contained in:
@@ -93,7 +93,7 @@ async def lifespan(app: FastAPI):
|
||||
init_db()
|
||||
init_mempool(
|
||||
backend=settings.mempool_backend,
|
||||
db_path=str(settings.db_path.parent / "mempool.db"),
|
||||
db_url=settings.mempool_db_url,
|
||||
max_size=settings.mempool_max_size,
|
||||
min_fee=settings.min_fee,
|
||||
)
|
||||
|
||||
@@ -81,6 +81,7 @@ class ChainSettings(BaseSettings):
|
||||
|
||||
# Mempool settings
|
||||
mempool_backend: str = "database" # "database" or "memory" (database recommended for persistence)
|
||||
mempool_db_url: str = "postgresql+psycopg://aitbc_mempool:password@localhost:5432/aitbc_mempool" # PostgreSQL URL for mempool (can be overridden by MEMPOOL_DB_URL env var)
|
||||
mempool_max_size: int = 10_000
|
||||
mempool_eviction_interval: int = 60 # seconds
|
||||
|
||||
|
||||
@@ -72,9 +72,16 @@ def get_engine(chain_id: str = "") -> object:
|
||||
@event.listens_for(engine, "connect")
|
||||
def set_encryption_key(dbapi_connection, connection_record):
|
||||
dbapi_connection.execute(f"PRAGMA key = '{key_hex}'")
|
||||
dbapi_connection.execute("PRAGMA journal_mode=WAL")
|
||||
dbapi_connection.execute("PRAGMA synchronous=NORMAL")
|
||||
else:
|
||||
# Use standard SQLite
|
||||
engine = create_engine(f"sqlite:///{db_path}", echo=False)
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
def set_wal_mode(dbapi_connection, connection_record):
|
||||
dbapi_connection.execute("PRAGMA journal_mode=WAL")
|
||||
dbapi_connection.execute("PRAGMA synchronous=NORMAL")
|
||||
|
||||
_engines[resolved_chain_id] = engine
|
||||
|
||||
@@ -87,8 +94,7 @@ _engine = create_engine(f"sqlite:///{settings.db_path}", echo=False)
|
||||
@event.listens_for(_engine, "connect")
|
||||
def set_sqlite_pragma(dbapi_connection, connection_record):
|
||||
cursor = dbapi_connection.cursor()
|
||||
# WAL mode disabled due to issues on btrfs raid (CoW already disabled on data directory)
|
||||
# cursor.execute("PRAGMA journal_mode=WAL")
|
||||
cursor.execute("PRAGMA journal_mode=WAL")
|
||||
cursor.execute("PRAGMA synchronous=NORMAL")
|
||||
cursor.execute("PRAGMA cache_size=-64000")
|
||||
cursor.execute("PRAGMA temp_store=MEMORY")
|
||||
|
||||
@@ -96,12 +96,46 @@ class BlockchainNode:
|
||||
self._stop_event = asyncio.Event()
|
||||
self._proposers: dict[str, PoAProposer] = {}
|
||||
|
||||
@staticmethod
|
||||
def _env_value(*names: str) -> Optional[str]:
|
||||
for name in names:
|
||||
value = os.getenv(name)
|
||||
if value is not None:
|
||||
return value
|
||||
return None
|
||||
|
||||
def _block_production_enabled(self) -> bool:
|
||||
override = self._env_value("AITBC_FORCE_ENABLE_BLOCK_PRODUCTION", "ENABLE_BLOCK_PRODUCTION", "enable_block_production")
|
||||
if override is not None:
|
||||
return override.strip().lower() in {"1", "true", "yes", "on"}
|
||||
return bool(getattr(settings, "enable_block_production", True))
|
||||
|
||||
def _supported_chains(self) -> list[str]:
|
||||
chains_str = getattr(settings, 'supported_chains', settings.chain_id)
|
||||
chains = [c.strip() for c in chains_str.split(",") if c.strip()]
|
||||
if not chains and settings.chain_id:
|
||||
chains = [settings.chain_id]
|
||||
return chains
|
||||
|
||||
def _proposer_config(self, chain_id: str) -> ProposerConfig:
|
||||
return ProposerConfig(
|
||||
chain_id=chain_id,
|
||||
proposer_id=settings.proposer_id,
|
||||
interval_seconds=settings.block_time_seconds,
|
||||
max_block_size_bytes=settings.max_block_size_bytes,
|
||||
max_txs_per_block=settings.max_txs_per_block,
|
||||
)
|
||||
|
||||
async def _ensure_genesis_for_chains(self) -> None:
|
||||
for chain_id in self._supported_chains():
|
||||
proposer = PoAProposer(config=self._proposer_config(chain_id), session_factory=session_scope)
|
||||
await proposer._ensure_genesis_block()
|
||||
|
||||
async def _setup_gossip_subscribers(self) -> None:
|
||||
logger.info("Setting up gossip subscribers")
|
||||
|
||||
# Parse supported chains
|
||||
chains_str = getattr(settings, 'supported_chains', settings.chain_id)
|
||||
chains = [c.strip() for c in chains_str.split(",") if c.strip()]
|
||||
chains = self._supported_chains()
|
||||
|
||||
# Transactions (single topic for all chains)
|
||||
try:
|
||||
@@ -213,8 +247,7 @@ class BlockchainNode:
|
||||
logger.info("Gossip backend initialized successfully")
|
||||
|
||||
# Parse supported chains
|
||||
chains_str = getattr(settings, 'supported_chains', settings.chain_id)
|
||||
chains = [c.strip() for c in chains_str.split(",") if c.strip()]
|
||||
chains = self._supported_chains()
|
||||
logger.info(f"Initializing databases for chains: {chains}")
|
||||
|
||||
# Initialize database for each supported chain
|
||||
@@ -224,12 +257,13 @@ class BlockchainNode:
|
||||
|
||||
init_mempool(
|
||||
backend=settings.mempool_backend,
|
||||
db_path=str(settings.db_path.parent / "mempool.db"),
|
||||
db_url=settings.mempool_db_url,
|
||||
max_size=settings.mempool_max_size,
|
||||
min_fee=settings.min_fee,
|
||||
)
|
||||
await self._ensure_genesis_for_chains()
|
||||
# Start proposers only if enabled (followers set enable_block_production=False)
|
||||
if getattr(settings, "enable_block_production", True):
|
||||
if self._block_production_enabled():
|
||||
self._start_proposers()
|
||||
else:
|
||||
logger.info("Block production disabled on this node", extra={"proposer_id": settings.proposer_id})
|
||||
@@ -245,11 +279,12 @@ class BlockchainNode:
|
||||
await self._shutdown()
|
||||
|
||||
def _start_proposers(self) -> None:
|
||||
chains_str = getattr(settings, 'supported_chains', settings.chain_id)
|
||||
chains = [c.strip() for c in chains_str.split(",") if c.strip()]
|
||||
chains = self._supported_chains()
|
||||
|
||||
# Get chains that should produce blocks (if specified, otherwise all supported chains)
|
||||
production_chains_str = getattr(settings, 'block_production_chains', chains_str)
|
||||
production_chains_str = self._env_value("AITBC_FORCE_BLOCK_PRODUCTION_CHAINS", "BLOCK_PRODUCTION_CHAINS", "block_production_chains")
|
||||
if production_chains_str is None:
|
||||
production_chains_str = getattr(settings, 'block_production_chains', ",".join(chains))
|
||||
production_chains = [c.strip() for c in production_chains_str.split(",") if c.strip()]
|
||||
|
||||
for chain_id in chains:
|
||||
@@ -261,15 +296,7 @@ class BlockchainNode:
|
||||
if chain_id in self._proposers:
|
||||
continue
|
||||
|
||||
proposer_config = ProposerConfig(
|
||||
chain_id=chain_id,
|
||||
proposer_id=settings.proposer_id,
|
||||
interval_seconds=settings.block_time_seconds,
|
||||
max_block_size_bytes=settings.max_block_size_bytes,
|
||||
max_txs_per_block=settings.max_txs_per_block,
|
||||
)
|
||||
|
||||
proposer = PoAProposer(config=proposer_config, session_factory=session_scope)
|
||||
proposer = PoAProposer(config=self._proposer_config(chain_id), session_factory=session_scope)
|
||||
self._proposers[chain_id] = proposer
|
||||
asyncio.create_task(proposer.start())
|
||||
|
||||
|
||||
@@ -7,9 +7,31 @@ from dataclasses import dataclass, field
|
||||
from threading import Lock
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlmodel import Session, SQLModel, create_engine, select, Field, text
|
||||
from sqlalchemy import Column, String, Integer, Float, Text, Index, MetaData, Table
|
||||
|
||||
from .metrics import metrics_registry
|
||||
|
||||
|
||||
mempool_metadata = MetaData()
|
||||
|
||||
|
||||
class MempoolEntry(SQLModel, table=True):
|
||||
__tablename__ = "mempool"
|
||||
__table_args__ = {"metadata": mempool_metadata}
|
||||
|
||||
chain_id: str = Field(primary_key=True)
|
||||
tx_hash: str = Field(primary_key=True)
|
||||
content: str = Field(sa_column=Column(Text, nullable=False))
|
||||
fee: int = Field(default=0, sa_column=Column(Integer, nullable=False))
|
||||
size_bytes: int = Field(default=0, sa_column=Column(Integer, nullable=False))
|
||||
received_at: float = Field(sa_column=Column(Float, nullable=False))
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_mempool_fee', 'fee', postgresql_ops={'fee': 'DESC'}),
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PendingTransaction:
|
||||
tx_hash: str
|
||||
@@ -150,32 +172,33 @@ class InMemoryMempool:
|
||||
|
||||
|
||||
class DatabaseMempool:
|
||||
"""SQLite-backed mempool for persistence and cross-service sharing."""
|
||||
"""PostgreSQL-backed mempool for persistence and cross-service sharing."""
|
||||
|
||||
def __init__(self, db_path: str, max_size: int = 10_000, min_fee: int = 0) -> None:
|
||||
import sqlite3
|
||||
self._db_path = db_path
|
||||
def __init__(self, db_url: str, max_size: int = 10_000, min_fee: int = 0) -> None:
|
||||
self._db_url = db_url
|
||||
self._max_size = max_size
|
||||
self._min_fee = min_fee
|
||||
self._conn = sqlite3.connect(db_path, check_same_thread=False)
|
||||
self._engine = create_engine(db_url, echo=False, pool_pre_ping=True)
|
||||
self._lock = Lock()
|
||||
self._init_table()
|
||||
|
||||
def _init_table(self) -> None:
|
||||
with self._lock:
|
||||
self._conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS mempool (
|
||||
chain_id TEXT NOT NULL,
|
||||
tx_hash TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
fee INTEGER DEFAULT 0,
|
||||
size_bytes INTEGER DEFAULT 0,
|
||||
received_at REAL NOT NULL,
|
||||
PRIMARY KEY (chain_id, tx_hash)
|
||||
)
|
||||
""")
|
||||
self._conn.execute("CREATE INDEX IF NOT EXISTS idx_mempool_fee ON mempool(fee DESC)")
|
||||
self._conn.commit()
|
||||
with Session(self._engine) as session:
|
||||
# Create table manually using raw SQL to avoid chain table conflicts
|
||||
session.exec(text("""
|
||||
CREATE TABLE IF NOT EXISTS mempool (
|
||||
chain_id TEXT NOT NULL,
|
||||
tx_hash TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
fee INTEGER DEFAULT 0,
|
||||
size_bytes INTEGER DEFAULT 0,
|
||||
received_at REAL NOT NULL,
|
||||
PRIMARY KEY (chain_id, tx_hash)
|
||||
)
|
||||
"""))
|
||||
session.exec(text("CREATE INDEX IF NOT EXISTS idx_mempool_fee ON mempool(fee DESC)"))
|
||||
session.commit()
|
||||
|
||||
def add(self, tx: Dict[str, Any], chain_id: str = None) -> str:
|
||||
from .config import settings
|
||||
@@ -190,27 +213,42 @@ class DatabaseMempool:
|
||||
size_bytes = len(content.encode())
|
||||
|
||||
with self._lock:
|
||||
# Check duplicate
|
||||
row = self._conn.execute("SELECT 1 FROM mempool WHERE chain_id = ? AND tx_hash = ?", (chain_id, tx_hash)).fetchone()
|
||||
if row:
|
||||
return tx_hash
|
||||
|
||||
# Evict if full
|
||||
count = self._conn.execute("SELECT COUNT(*) FROM mempool WHERE chain_id = ?", (chain_id,)).fetchone()[0]
|
||||
if count >= self._max_size:
|
||||
self._conn.execute("""
|
||||
DELETE FROM mempool WHERE chain_id = ? AND tx_hash = (
|
||||
SELECT tx_hash FROM mempool WHERE chain_id = ? ORDER BY fee ASC, received_at DESC LIMIT 1
|
||||
with Session(self._engine) as session:
|
||||
# Check duplicate
|
||||
existing = session.exec(
|
||||
select(MempoolEntry).where(
|
||||
MempoolEntry.chain_id == chain_id,
|
||||
MempoolEntry.tx_hash == tx_hash
|
||||
)
|
||||
""", (chain_id, chain_id))
|
||||
metrics_registry.increment(f"mempool_evictions_total_{chain_id}")
|
||||
).first()
|
||||
if existing:
|
||||
return tx_hash
|
||||
|
||||
self._conn.execute(
|
||||
"INSERT INTO mempool (chain_id, tx_hash, content, fee, size_bytes, received_at) VALUES (?, ?, ?, ?, ?, ?)",
|
||||
(chain_id, tx_hash, content, fee, size_bytes, time.time())
|
||||
)
|
||||
self._conn.commit()
|
||||
metrics_registry.increment(f"mempool_tx_added_total_{chain_id}")
|
||||
# Evict if full
|
||||
count = session.exec(
|
||||
select(MempoolEntry).where(MempoolEntry.chain_id == chain_id)
|
||||
).count()
|
||||
if count >= self._max_size:
|
||||
to_evict = session.exec(
|
||||
select(MempoolEntry).where(MempoolEntry.chain_id == chain_id)
|
||||
.order_by(MempoolEntry.fee.asc(), MempoolEntry.received_at.desc())
|
||||
.limit(1)
|
||||
).first()
|
||||
if to_evict:
|
||||
session.delete(to_evict)
|
||||
metrics_registry.increment(f"mempool_evictions_total_{chain_id}")
|
||||
|
||||
entry = MempoolEntry(
|
||||
chain_id=chain_id,
|
||||
tx_hash=tx_hash,
|
||||
content=content,
|
||||
fee=fee,
|
||||
size_bytes=size_bytes,
|
||||
received_at=time.time()
|
||||
)
|
||||
session.add(entry)
|
||||
session.commit()
|
||||
metrics_registry.increment(f"mempool_tx_added_total_{chain_id}")
|
||||
self._update_gauge(chain_id)
|
||||
return tx_hash
|
||||
|
||||
@@ -219,15 +257,16 @@ class DatabaseMempool:
|
||||
if chain_id is None:
|
||||
chain_id = settings.chain_id
|
||||
with self._lock:
|
||||
rows = self._conn.execute(
|
||||
"SELECT tx_hash, content, fee, size_bytes, received_at FROM mempool WHERE chain_id = ? ORDER BY fee DESC, received_at ASC",
|
||||
(chain_id,)
|
||||
).fetchall()
|
||||
with Session(self._engine) as session:
|
||||
entries = session.exec(
|
||||
select(MempoolEntry).where(MempoolEntry.chain_id == chain_id)
|
||||
.order_by(MempoolEntry.fee.desc(), MempoolEntry.received_at.asc())
|
||||
).all()
|
||||
return [
|
||||
PendingTransaction(
|
||||
tx_hash=r[0], content=json.loads(r[1]),
|
||||
fee=r[2], size_bytes=r[3], received_at=r[4]
|
||||
) for r in rows
|
||||
tx_hash=e.tx_hash, content=json.loads(e.content),
|
||||
fee=e.fee, size_bytes=e.size_bytes, received_at=e.received_at
|
||||
) for e in entries
|
||||
]
|
||||
|
||||
def drain(self, max_count: int, max_bytes: int, chain_id: str = None) -> List[PendingTransaction]:
|
||||
@@ -235,35 +274,41 @@ class DatabaseMempool:
|
||||
if chain_id is None:
|
||||
chain_id = settings.chain_id
|
||||
with self._lock:
|
||||
rows = self._conn.execute(
|
||||
"SELECT tx_hash, content, fee, size_bytes, received_at FROM mempool WHERE chain_id = ? ORDER BY fee DESC, received_at ASC",
|
||||
(chain_id,)
|
||||
).fetchall()
|
||||
with Session(self._engine) as session:
|
||||
entries = session.exec(
|
||||
select(MempoolEntry).where(MempoolEntry.chain_id == chain_id)
|
||||
.order_by(MempoolEntry.fee.desc(), MempoolEntry.received_at.asc())
|
||||
).all()
|
||||
|
||||
result: List[PendingTransaction] = []
|
||||
total_bytes = 0
|
||||
hashes_to_remove: List[str] = []
|
||||
result: List[PendingTransaction] = []
|
||||
total_bytes = 0
|
||||
hashes_to_remove: List[str] = []
|
||||
|
||||
for r in rows:
|
||||
if len(result) >= max_count:
|
||||
break
|
||||
if total_bytes + r[3] > max_bytes:
|
||||
continue
|
||||
result.append(PendingTransaction(
|
||||
tx_hash=r[0], content=json.loads(r[1]),
|
||||
fee=r[2], size_bytes=r[3], received_at=r[4]
|
||||
))
|
||||
total_bytes += r[3]
|
||||
hashes_to_remove.append(r[0])
|
||||
for e in entries:
|
||||
if len(result) >= max_count:
|
||||
break
|
||||
if total_bytes + e.size_bytes > max_bytes:
|
||||
continue
|
||||
result.append(PendingTransaction(
|
||||
tx_hash=e.tx_hash, content=json.loads(e.content),
|
||||
fee=e.fee, size_bytes=e.size_bytes, received_at=e.received_at
|
||||
))
|
||||
total_bytes += e.size_bytes
|
||||
hashes_to_remove.append(e.tx_hash)
|
||||
|
||||
if hashes_to_remove:
|
||||
# Use parameterized query to avoid SQL injection
|
||||
placeholders = ",".join(["?"] * len(hashes_to_remove))
|
||||
query = f"DELETE FROM mempool WHERE chain_id = ? AND tx_hash IN ({placeholders})"
|
||||
self._conn.execute(query, [chain_id] + hashes_to_remove)
|
||||
self._conn.commit()
|
||||
if hashes_to_remove:
|
||||
for hash_to_remove in hashes_to_remove:
|
||||
entry = session.exec(
|
||||
select(MempoolEntry).where(
|
||||
MempoolEntry.chain_id == chain_id,
|
||||
MempoolEntry.tx_hash == hash_to_remove
|
||||
)
|
||||
).first()
|
||||
if entry:
|
||||
session.delete(entry)
|
||||
session.commit()
|
||||
|
||||
metrics_registry.increment(f"mempool_tx_drained_total_{chain_id}", float(len(result)))
|
||||
metrics_registry.increment(f"mempool_tx_drained_total_{chain_id}", float(len(result)))
|
||||
self._update_gauge(chain_id)
|
||||
return result
|
||||
|
||||
@@ -272,9 +317,19 @@ class DatabaseMempool:
|
||||
if chain_id is None:
|
||||
chain_id = settings.chain_id
|
||||
with self._lock:
|
||||
cursor = self._conn.execute("DELETE FROM mempool WHERE chain_id = ? AND tx_hash = ?", (chain_id, tx_hash))
|
||||
self._conn.commit()
|
||||
removed = cursor.rowcount > 0
|
||||
with Session(self._engine) as session:
|
||||
entry = session.exec(
|
||||
select(MempoolEntry).where(
|
||||
MempoolEntry.chain_id == chain_id,
|
||||
MempoolEntry.tx_hash == tx_hash
|
||||
)
|
||||
).first()
|
||||
if entry:
|
||||
session.delete(entry)
|
||||
session.commit()
|
||||
removed = True
|
||||
else:
|
||||
removed = False
|
||||
if removed:
|
||||
self._update_gauge(chain_id)
|
||||
return removed
|
||||
@@ -284,7 +339,10 @@ class DatabaseMempool:
|
||||
if chain_id is None:
|
||||
chain_id = settings.chain_id
|
||||
with self._lock:
|
||||
return self._conn.execute("SELECT COUNT(*) FROM mempool WHERE chain_id = ?", (chain_id,)).fetchone()[0]
|
||||
with Session(self._engine) as session:
|
||||
return session.exec(
|
||||
select(MempoolEntry).where(MempoolEntry.chain_id == chain_id)
|
||||
).count()
|
||||
|
||||
def get_pending_transactions(self, chain_id: str = None, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""Get pending transactions for RPC endpoint"""
|
||||
@@ -293,18 +351,20 @@ class DatabaseMempool:
|
||||
chain_id = settings.chain_id
|
||||
|
||||
with self._lock:
|
||||
rows = self._conn.execute(
|
||||
"SELECT content FROM mempool WHERE chain_id = ? ORDER BY fee DESC, received_at ASC LIMIT ?",
|
||||
(chain_id, limit)
|
||||
).fetchall()
|
||||
with Session(self._engine) as session:
|
||||
entries = session.exec(
|
||||
select(MempoolEntry).where(MempoolEntry.chain_id == chain_id)
|
||||
.order_by(MempoolEntry.fee.desc(), MempoolEntry.received_at.asc())
|
||||
.limit(limit)
|
||||
).all()
|
||||
|
||||
return [json.loads(row[0]) for row in rows]
|
||||
return [json.loads(e.content) for e in entries]
|
||||
|
||||
def _update_gauge(self, chain_id: str = None) -> None:
|
||||
from .config import settings
|
||||
if chain_id is None:
|
||||
chain_id = settings.chain_id
|
||||
count = self._conn.execute("SELECT COUNT(*) FROM mempool WHERE chain_id = ?", (chain_id,)).fetchone()[0]
|
||||
count = self.size(chain_id)
|
||||
metrics_registry.set_gauge(f"mempool_size_{chain_id}", float(count))
|
||||
|
||||
|
||||
@@ -312,10 +372,10 @@ class DatabaseMempool:
|
||||
_MEMPOOL: Optional[InMemoryMempool | DatabaseMempool] = None
|
||||
|
||||
|
||||
def init_mempool(backend: str = "memory", db_path: str = "", max_size: int = 10_000, min_fee: int = 0) -> None:
|
||||
def init_mempool(backend: str = "memory", db_url: str = "", max_size: int = 10_000, min_fee: int = 0) -> None:
|
||||
global _MEMPOOL
|
||||
if backend == "database" and db_path:
|
||||
_MEMPOOL = DatabaseMempool(db_path, max_size=max_size, min_fee=min_fee)
|
||||
if backend == "database" and db_url:
|
||||
_MEMPOOL = DatabaseMempool(db_url, max_size=max_size, min_fee=min_fee)
|
||||
else:
|
||||
_MEMPOOL = InMemoryMempool(max_size=max_size, min_fee=min_fee)
|
||||
|
||||
|
||||
@@ -720,14 +720,14 @@ def main():
|
||||
try:
|
||||
from .mempool import init_mempool
|
||||
import pathlib
|
||||
|
||||
db_path = ""
|
||||
|
||||
db_url = ""
|
||||
if settings.mempool_backend == "database":
|
||||
db_path = str(settings.db_path.parent / "mempool.db")
|
||||
|
||||
db_url = settings.mempool_db_url
|
||||
|
||||
init_mempool(
|
||||
backend=settings.mempool_backend,
|
||||
db_path=db_path,
|
||||
db_url=db_url,
|
||||
max_size=settings.mempool_max_size,
|
||||
min_fee=settings.min_fee
|
||||
)
|
||||
|
||||
@@ -414,6 +414,7 @@ class ChainSync:
|
||||
def _append_block(self, session: Session, block_data: Dict[str, Any],
|
||||
transactions: Optional[List[Dict[str, Any]]] = None) -> ImportResult:
|
||||
"""Append a block to the chain tip."""
|
||||
block_hash = block_data["hash"]
|
||||
timestamp_str = block_data.get("timestamp", "")
|
||||
try:
|
||||
timestamp = datetime.fromisoformat(timestamp_str) if timestamp_str else datetime.now(datetime.UTC)
|
||||
@@ -486,11 +487,9 @@ class ChainSync:
|
||||
)
|
||||
session.add(tx)
|
||||
|
||||
session.commit()
|
||||
|
||||
# Verify state root if provided
|
||||
if block_data.get("state_root"):
|
||||
from aitbc_chain.config import settings
|
||||
session.flush()
|
||||
state_manager = StateManager()
|
||||
accounts = session.exec(
|
||||
select(Account).where(Account.chain_id == self._chain_id)
|
||||
@@ -498,11 +497,33 @@ class ChainSync:
|
||||
account_dict = {acc.address: acc for acc in accounts}
|
||||
|
||||
computed_root = state_manager.compute_state_root(account_dict)
|
||||
expected_root = bytes.fromhex(block_data.get("state_root").replace("0x", ""))
|
||||
try:
|
||||
expected_root = bytes.fromhex(str(block_data.get("state_root")).replace("0x", ""))
|
||||
except ValueError:
|
||||
expected_root = None
|
||||
|
||||
if computed_root != expected_root:
|
||||
if expected_root is None or len(expected_root) != 32:
|
||||
if settings.enforce_state_root_validation:
|
||||
metrics_registry.increment("sync_state_root_rejected_total")
|
||||
session.rollback()
|
||||
logger.error(
|
||||
f"[SYNC] Invalid state root at height {block_data['height']}: "
|
||||
f"{block_data.get('state_root')} - BLOCK REJECTED"
|
||||
)
|
||||
return ImportResult(
|
||||
accepted=False,
|
||||
height=block_data["height"],
|
||||
block_hash=block_hash,
|
||||
reason=f"Invalid state root: {block_data.get('state_root')}"
|
||||
)
|
||||
logger.warning(
|
||||
f"[SYNC] Invalid state root at height {block_data['height']}: "
|
||||
f"{block_data.get('state_root')}"
|
||||
)
|
||||
elif computed_root != expected_root:
|
||||
if settings.enforce_state_root_validation:
|
||||
metrics_registry.increment("sync_state_root_rejected_total")
|
||||
session.rollback()
|
||||
logger.error(
|
||||
f"[SYNC] State root mismatch at height {block_data['height']}: "
|
||||
f"expected {expected_root.hex()}, computed {computed_root.hex()} - BLOCK REJECTED"
|
||||
@@ -519,6 +540,8 @@ class ChainSync:
|
||||
f"expected {expected_root.hex()}, computed {computed_root.hex()}"
|
||||
)
|
||||
|
||||
session.commit()
|
||||
|
||||
metrics_registry.increment("sync_blocks_accepted_total")
|
||||
metrics_registry.set_gauge("sync_chain_height", float(block_data["height"]))
|
||||
logger.info("Imported block", extra={
|
||||
|
||||
@@ -250,6 +250,7 @@ class TestInitMempool:
|
||||
|
||||
def test_init_database(self, tmp_path):
|
||||
db_path = str(tmp_path / "init.db")
|
||||
init_mempool(backend="database", db_path=db_path, max_size=50, min_fee=0)
|
||||
db_url = f"sqlite:///{db_path}"
|
||||
init_mempool(backend="database", db_url=db_url, max_size=50, min_fee=0)
|
||||
pool = get_mempool()
|
||||
assert isinstance(pool, DatabaseMempool)
|
||||
|
||||
@@ -11,6 +11,7 @@ from unittest.mock import AsyncMock, Mock
|
||||
from sqlmodel import Session, SQLModel, create_engine, select
|
||||
|
||||
from aitbc_chain.models import Block, Transaction
|
||||
from aitbc_chain.sync import settings as sync_settings
|
||||
from aitbc_chain.metrics import metrics_registry
|
||||
from aitbc_chain.sync import ChainSync, ProposerSignatureValidator, ImportResult
|
||||
|
||||
@@ -286,6 +287,29 @@ class TestChainSyncBulkImport:
|
||||
stored_txs = session.exec(select(Transaction).where(Transaction.block_height == 1)).all()
|
||||
assert len(stored_txs) == 2
|
||||
|
||||
def test_enforced_state_root_mismatch_rolls_back_block(self, session_factory, monkeypatch):
|
||||
monkeypatch.setattr(sync_settings, "enforce_state_root_validation", True)
|
||||
sync = ChainSync(session_factory, chain_id="test", validate_signatures=False)
|
||||
blocks = _seed_chain(session_factory, count=1, chain_id="test")
|
||||
last = blocks[-1]
|
||||
ts = datetime(2026, 1, 1, 0, 0, 1)
|
||||
bh = _make_block_hash("test", 1, last["hash"], ts)
|
||||
|
||||
result = sync.import_block({
|
||||
"height": 1,
|
||||
"hash": bh,
|
||||
"parent_hash": last["hash"],
|
||||
"proposer": "node-a",
|
||||
"timestamp": ts.isoformat(),
|
||||
"state_root": "0x" + "11" * 32,
|
||||
})
|
||||
|
||||
assert result.accepted is False
|
||||
assert "State root mismatch" in result.reason
|
||||
with session_factory() as session:
|
||||
stored_block = session.exec(select(Block).where(Block.chain_id == "test", Block.height == 1)).first()
|
||||
assert stored_block is None
|
||||
|
||||
|
||||
class TestChainSyncSignatureValidation:
|
||||
|
||||
|
||||
Binary file not shown.
345
scripts/utils/chain_regen_node.py
Executable file
345
scripts/utils/chain_regen_node.py
Executable file
@@ -0,0 +1,345 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import socket
|
||||
import sqlite3
|
||||
import subprocess
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
REPO_DIR = Path("/opt/aitbc")
|
||||
BLOCKCHAIN_SRC = REPO_DIR / "apps" / "blockchain-node" / "src"
|
||||
if str(REPO_DIR) not in sys.path:
|
||||
sys.path.insert(0, str(REPO_DIR))
|
||||
if str(BLOCKCHAIN_SRC) not in sys.path:
|
||||
sys.path.insert(0, str(BLOCKCHAIN_SRC))
|
||||
|
||||
from sqlmodel import Session, create_engine, select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from aitbc_chain.config import ChainSettings
|
||||
from aitbc_chain.models import Account, Block, Transaction
|
||||
from aitbc_chain.state.merkle_patricia_trie import StateManager
|
||||
|
||||
SERVICE_NAME = "aitbc-blockchain-node.service"
|
||||
DATA_ROOT = Path("/var/lib/aitbc/data")
|
||||
BACKUP_ROOT = Path("/var/lib/aitbc/backups/mpt-regeneration")
|
||||
ENV_FILES = [Path("/etc/aitbc/.env"), Path("/etc/aitbc/node.env")]
|
||||
|
||||
|
||||
def _run(command: list[str], check: bool = False) -> subprocess.CompletedProcess[str]:
|
||||
return subprocess.run(command, text=True, capture_output=True, check=check)
|
||||
|
||||
|
||||
def _service_state(service_name: str) -> dict[str, Any]:
|
||||
active = _run(["systemctl", "is-active", service_name]).stdout.strip()
|
||||
enabled = _run(["systemctl", "is-enabled", service_name]).stdout.strip()
|
||||
fragment = _run(["systemctl", "show", service_name, "-p", "FragmentPath", "--value"]).stdout.strip()
|
||||
dropins = _run(["systemctl", "show", service_name, "-p", "DropInPaths", "--value"]).stdout.strip()
|
||||
return {
|
||||
"active": active,
|
||||
"enabled": enabled,
|
||||
"fragment_path": fragment,
|
||||
"drop_in_paths": [item for item in dropins.split() if item],
|
||||
}
|
||||
|
||||
|
||||
def _git_revision() -> str | None:
|
||||
result = _run(["git", "-C", str(REPO_DIR), "rev-parse", "HEAD"])
|
||||
if result.returncode != 0:
|
||||
return None
|
||||
return result.stdout.strip()
|
||||
|
||||
|
||||
def _sha256_file(path: Path) -> str | None:
|
||||
if not path.exists() or not path.is_file():
|
||||
return None
|
||||
digest = hashlib.sha256()
|
||||
with path.open("rb") as handle:
|
||||
for chunk in iter(lambda: handle.read(1024 * 1024), b""):
|
||||
digest.update(chunk)
|
||||
return digest.hexdigest()
|
||||
|
||||
|
||||
def _load_genesis(path: Path) -> dict[str, Any]:
|
||||
if not path.exists():
|
||||
return {}
|
||||
with path.open() as handle:
|
||||
return json.load(handle)
|
||||
|
||||
|
||||
def _allocation_digest(genesis: dict[str, Any]) -> str | None:
|
||||
allocations = genesis.get("allocations")
|
||||
if allocations is None:
|
||||
return None
|
||||
canonical = json.dumps(sorted(allocations, key=lambda item: item.get("address", "")), sort_keys=True, separators=(",", ":"))
|
||||
return hashlib.sha256(canonical.encode()).hexdigest()
|
||||
|
||||
|
||||
def _live_db_files(chain_id: str) -> list[Path]:
|
||||
chain_dir = DATA_ROOT / chain_id
|
||||
files = [
|
||||
chain_dir / "chain.db",
|
||||
chain_dir / "chain.db-wal",
|
||||
chain_dir / "chain.db-shm",
|
||||
chain_dir / "chain.db-journal",
|
||||
DATA_ROOT / "mempool.db",
|
||||
DATA_ROOT / "mempool.db-wal",
|
||||
DATA_ROOT / "mempool.db-shm",
|
||||
DATA_ROOT / "mempool.db-journal",
|
||||
chain_dir / "mempool.db",
|
||||
chain_dir / "mempool.db-wal",
|
||||
chain_dir / "mempool.db-shm",
|
||||
chain_dir / "mempool.db-journal",
|
||||
]
|
||||
return [path for path in files if path.exists()]
|
||||
|
||||
|
||||
def _backup_sources(chain_id: str) -> list[Path]:
|
||||
chain_dir = DATA_ROOT / chain_id
|
||||
sources: list[Path] = []
|
||||
for pattern in [str(chain_dir / "chain.db*"), str(chain_dir / "genesis.json"), str(DATA_ROOT / "mempool.db*"), str(chain_dir / "mempool.db*")]:
|
||||
sources.extend(Path(path) for path in glob.glob(pattern))
|
||||
sources.extend(path for path in ENV_FILES if path.exists())
|
||||
unique: dict[str, Path] = {}
|
||||
for path in sources:
|
||||
if path.exists():
|
||||
unique[str(path)] = path
|
||||
return [unique[key] for key in sorted(unique)]
|
||||
|
||||
|
||||
def _integrity_check(path: Path) -> str | None:
|
||||
if not path.exists() or path.stat().st_size == 0:
|
||||
return None
|
||||
try:
|
||||
conn = sqlite3.connect(f"file:{path}?mode=ro", uri=True)
|
||||
try:
|
||||
row = conn.execute("PRAGMA integrity_check").fetchone()
|
||||
return row[0] if row else None
|
||||
finally:
|
||||
conn.close()
|
||||
except sqlite3.Error as exc:
|
||||
return f"error: {exc}"
|
||||
|
||||
|
||||
def _db_snapshot(chain_id: str, db_path: Path) -> dict[str, Any]:
|
||||
data: dict[str, Any] = {
|
||||
"db_error": None,
|
||||
"head": None,
|
||||
"block_count": None,
|
||||
"transaction_count": None,
|
||||
"account_count": None,
|
||||
"computed_state_root": None,
|
||||
"head_state_root_matches_computed": None,
|
||||
}
|
||||
if not db_path.exists() or db_path.stat().st_size == 0:
|
||||
data["db_error"] = "database missing or empty"
|
||||
return data
|
||||
|
||||
engine = create_engine(f"sqlite:///{db_path}", echo=False)
|
||||
try:
|
||||
with Session(engine) as session:
|
||||
head = session.exec(select(Block).where(Block.chain_id == chain_id).order_by(Block.height.desc()).limit(1)).first()
|
||||
blocks = session.exec(select(Block).where(Block.chain_id == chain_id)).all()
|
||||
transactions = session.exec(select(Transaction).where(Transaction.chain_id == chain_id)).all()
|
||||
accounts = session.exec(select(Account).where(Account.chain_id == chain_id).order_by(Account.address)).all()
|
||||
account_dict = {account.address: account for account in accounts}
|
||||
computed_root = StateManager().compute_state_root(account_dict)
|
||||
computed_root_hex = "0x" + computed_root.hex()
|
||||
data.update(
|
||||
{
|
||||
"head": None if head is None else {
|
||||
"height": head.height,
|
||||
"hash": head.hash,
|
||||
"state_root": head.state_root,
|
||||
"proposer": head.proposer,
|
||||
"timestamp": head.timestamp.isoformat() if head.timestamp else None,
|
||||
},
|
||||
"block_count": len(blocks),
|
||||
"transaction_count": len(transactions),
|
||||
"account_count": len(accounts),
|
||||
"account_digest": hashlib.sha256(json.dumps(
|
||||
[{"address": account.address, "balance": account.balance, "nonce": account.nonce} for account in accounts],
|
||||
sort_keys=True,
|
||||
separators=(",", ":"),
|
||||
).encode()).hexdigest(),
|
||||
"computed_state_root": computed_root_hex,
|
||||
"head_state_root_matches_computed": bool(head and head.state_root == computed_root_hex),
|
||||
}
|
||||
)
|
||||
except SQLAlchemyError as exc:
|
||||
data["db_error"] = str(exc)
|
||||
finally:
|
||||
engine.dispose()
|
||||
return data
|
||||
|
||||
|
||||
def snapshot(chain_id: str, service_name: str) -> dict[str, Any]:
|
||||
settings = ChainSettings()
|
||||
db_path = settings.get_db_path(chain_id)
|
||||
genesis_path = DATA_ROOT / chain_id / "genesis.json"
|
||||
genesis = _load_genesis(genesis_path)
|
||||
data: dict[str, Any] = {
|
||||
"host": socket.gethostname(),
|
||||
"chain_id": chain_id,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"repo_revision": _git_revision(),
|
||||
"service": _service_state(service_name),
|
||||
"db_path": str(db_path),
|
||||
"db_exists": db_path.exists(),
|
||||
"db_size_bytes": db_path.stat().st_size if db_path.exists() else None,
|
||||
"db_integrity": _integrity_check(db_path),
|
||||
"genesis_path": str(genesis_path),
|
||||
"genesis_exists": genesis_path.exists(),
|
||||
"genesis_file_sha256": _sha256_file(genesis_path),
|
||||
"genesis_allocation_digest": _allocation_digest(genesis),
|
||||
"genesis_allocation_count": len(genesis.get("allocations", [])) if genesis else 0,
|
||||
"live_db_files": [str(path) for path in _live_db_files(chain_id)],
|
||||
}
|
||||
data.update(_db_snapshot(chain_id, db_path))
|
||||
return data
|
||||
|
||||
|
||||
def backup(chain_id: str, service_name: str, backup_root: Path, timestamp: str | None) -> dict[str, Any]:
|
||||
stamp = timestamp or datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||
target = backup_root / stamp / socket.gethostname() / chain_id
|
||||
target.mkdir(parents=True, exist_ok=True)
|
||||
manifest: dict[str, Any] = {
|
||||
"host": socket.gethostname(),
|
||||
"chain_id": chain_id,
|
||||
"timestamp": stamp,
|
||||
"target": str(target),
|
||||
"preflight": snapshot(chain_id, service_name),
|
||||
"files": [],
|
||||
}
|
||||
for source in _backup_sources(chain_id):
|
||||
rel = source.relative_to(source.anchor)
|
||||
destination = target / rel
|
||||
destination.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(source, destination)
|
||||
manifest["files"].append(
|
||||
{
|
||||
"source": str(source),
|
||||
"backup": str(destination),
|
||||
"size_bytes": source.stat().st_size,
|
||||
"sha256": _sha256_file(source),
|
||||
"integrity_check": _integrity_check(destination) if source.name.startswith("chain.db") or source.name.startswith("mempool.db") else None,
|
||||
}
|
||||
)
|
||||
manifest_path = target / "manifest.json"
|
||||
manifest_path.write_text(json.dumps(manifest, indent=2, sort_keys=True, default=str))
|
||||
return manifest
|
||||
|
||||
|
||||
def reset(chain_id: str, service_name: str, yes: bool, force: bool) -> dict[str, Any]:
|
||||
if not yes:
|
||||
raise SystemExit("reset requires --yes")
|
||||
service = _service_state(service_name)
|
||||
if service["active"] == "active" and not force:
|
||||
raise SystemExit(f"{service_name} is active; stop it first or pass --force")
|
||||
removed: list[str] = []
|
||||
for path in _live_db_files(chain_id):
|
||||
path.unlink()
|
||||
removed.append(str(path))
|
||||
return {"host": socket.gethostname(), "chain_id": chain_id, "removed": removed}
|
||||
|
||||
|
||||
def set_role(chain_id: str, service_name: str, role: str) -> dict[str, Any]:
|
||||
dropin_dir = Path("/etc/systemd/system") / f"{service_name}.d"
|
||||
dropin_path = dropin_dir / "mpt-regeneration.conf"
|
||||
if role == "clear":
|
||||
if dropin_path.exists():
|
||||
dropin_path.unlink()
|
||||
_run(["systemctl", "daemon-reload"], check=True)
|
||||
return {"host": socket.gethostname(), "role": role, "dropin": str(dropin_path), "exists": dropin_path.exists()}
|
||||
dropin_dir.mkdir(parents=True, exist_ok=True)
|
||||
if role == "leader":
|
||||
content = (
|
||||
"[Service]\n"
|
||||
'Environment="AITBC_FORCE_ENABLE_BLOCK_PRODUCTION=true"\n'
|
||||
f'Environment="AITBC_FORCE_BLOCK_PRODUCTION_CHAINS={chain_id}"\n'
|
||||
'Environment="ENABLE_BLOCK_PRODUCTION=true"\n'
|
||||
'Environment="enable_block_production=true"\n'
|
||||
f'Environment="BLOCK_PRODUCTION_CHAINS={chain_id}"\n'
|
||||
f'Environment="block_production_chains={chain_id}"\n'
|
||||
)
|
||||
elif role == "follower":
|
||||
content = (
|
||||
"[Service]\n"
|
||||
'Environment="AITBC_FORCE_ENABLE_BLOCK_PRODUCTION=false"\n'
|
||||
'Environment="AITBC_FORCE_BLOCK_PRODUCTION_CHAINS="\n'
|
||||
'Environment="ENABLE_BLOCK_PRODUCTION=false"\n'
|
||||
'Environment="enable_block_production=false"\n'
|
||||
'Environment="BLOCK_PRODUCTION_CHAINS="\n'
|
||||
'Environment="block_production_chains="\n'
|
||||
)
|
||||
else:
|
||||
raise SystemExit(f"unsupported role: {role}")
|
||||
dropin_path.write_text(content)
|
||||
_run(["systemctl", "daemon-reload"], check=True)
|
||||
return {"host": socket.gethostname(), "role": role, "dropin": str(dropin_path), "exists": dropin_path.exists()}
|
||||
|
||||
|
||||
def verify(chain_id: str, service_name: str, require_nonzero_root: bool) -> dict[str, Any]:
|
||||
data = snapshot(chain_id, service_name)
|
||||
head = data.get("head")
|
||||
ok = data.get("db_error") is None and head is not None and data.get("head_state_root_matches_computed") is True
|
||||
if require_nonzero_root and data.get("computed_state_root") == "0x" + ("00" * 32):
|
||||
ok = False
|
||||
data["ok"] = ok
|
||||
return data
|
||||
|
||||
|
||||
def print_json(data: Any) -> None:
|
||||
print(json.dumps(data, indent=2, sort_keys=True, default=str))
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description="AITBC coordinated MPT chain regeneration node utility")
|
||||
parser.add_argument("--chain-id", default="ait-mainnet")
|
||||
parser.add_argument("--service-name", default=SERVICE_NAME)
|
||||
subparsers = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
subparsers.add_parser("preflight")
|
||||
|
||||
backup_parser = subparsers.add_parser("backup")
|
||||
backup_parser.add_argument("--backup-root", type=Path, default=BACKUP_ROOT)
|
||||
backup_parser.add_argument("--timestamp")
|
||||
|
||||
reset_parser = subparsers.add_parser("reset")
|
||||
reset_parser.add_argument("--yes", action="store_true")
|
||||
reset_parser.add_argument("--force", action="store_true")
|
||||
|
||||
verify_parser = subparsers.add_parser("verify")
|
||||
verify_parser.add_argument("--require-nonzero-root", action="store_true")
|
||||
|
||||
role_parser = subparsers.add_parser("set-role")
|
||||
role_parser.add_argument("role", choices=["leader", "follower", "clear"])
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == "preflight":
|
||||
print_json(snapshot(args.chain_id, args.service_name))
|
||||
elif args.command == "backup":
|
||||
print_json(backup(args.chain_id, args.service_name, args.backup_root, args.timestamp))
|
||||
elif args.command == "reset":
|
||||
print_json(reset(args.chain_id, args.service_name, args.yes, args.force))
|
||||
elif args.command == "verify":
|
||||
result = verify(args.chain_id, args.service_name, args.require_nonzero_root)
|
||||
print_json(result)
|
||||
return 0 if result.get("ok") else 2
|
||||
elif args.command == "set-role":
|
||||
print_json(set_role(args.chain_id, args.service_name, args.role))
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
161
scripts/utils/coordinated_chain_regen.sh
Executable file
161
scripts/utils/coordinated_chain_regen.sh
Executable file
@@ -0,0 +1,161 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
CHAIN_ID="${CHAIN_ID:-ait-mainnet}"
|
||||
SERVICE_NAME="${SERVICE_NAME:-aitbc-blockchain-node.service}"
|
||||
LEADER="${LEADER:-aitbc1}"
|
||||
NODES_RAW="${NODES:-localhost aitbc1 gitea-runner}"
|
||||
UTILITY="${UTILITY:-/opt/aitbc/scripts/utils/chain_regen_node.py}"
|
||||
PYTHON_BIN="${PYTHON_BIN:-/opt/aitbc/venv/bin/python}"
|
||||
TIMESTAMP="${TIMESTAMP:-$(date -u +%Y%m%dT%H%M%SZ)}"
|
||||
STARTUP_WAIT_SECONDS="${STARTUP_WAIT_SECONDS:-8}"
|
||||
FOLLOWER_START_WAIT_SECONDS="${FOLLOWER_START_WAIT_SECONDS:-8}"
|
||||
VERIFY_RETRIES="${VERIFY_RETRIES:-6}"
|
||||
VERIFY_RETRY_SECONDS="${VERIFY_RETRY_SECONDS:-10}"
|
||||
|
||||
usage() {
|
||||
cat <<USAGE
|
||||
Usage: $0 <preflight|backup|set-roles|stop|reset|start|verify|rollout|clear-roles>
|
||||
|
||||
Environment:
|
||||
CHAIN_ID=$CHAIN_ID
|
||||
SERVICE_NAME=$SERVICE_NAME
|
||||
LEADER=$LEADER
|
||||
NODES="$NODES_RAW"
|
||||
TIMESTAMP=$TIMESTAMP
|
||||
USAGE
|
||||
}
|
||||
|
||||
node_cmd() {
|
||||
local node="$1"
|
||||
shift
|
||||
if [[ "$node" == "localhost" || "$node" == "local" || "$node" == "$(hostname)" ]]; then
|
||||
"$@"
|
||||
else
|
||||
ssh "$node" "$*"
|
||||
fi
|
||||
}
|
||||
|
||||
run_node_utility() {
|
||||
local node="$1"
|
||||
shift
|
||||
node_cmd "$node" "$PYTHON_BIN" "$UTILITY" --chain-id "$CHAIN_ID" --service-name "$SERVICE_NAME" "$@"
|
||||
}
|
||||
|
||||
for_each_node() {
|
||||
local action="$1"
|
||||
shift
|
||||
local node
|
||||
for node in $NODES_RAW; do
|
||||
echo "===== $node :: $action ====="
|
||||
"$@" "$node"
|
||||
done
|
||||
}
|
||||
|
||||
preflight_node() {
|
||||
run_node_utility "$1" preflight
|
||||
}
|
||||
|
||||
backup_node() {
|
||||
run_node_utility "$1" backup --timestamp "$TIMESTAMP"
|
||||
}
|
||||
|
||||
set_role_node() {
|
||||
local node="$1"
|
||||
if [[ "$node" == "$LEADER" ]]; then
|
||||
run_node_utility "$node" set-role leader
|
||||
else
|
||||
run_node_utility "$node" set-role follower
|
||||
fi
|
||||
}
|
||||
|
||||
clear_role_node() {
|
||||
run_node_utility "$1" set-role clear
|
||||
}
|
||||
|
||||
stop_node() {
|
||||
node_cmd "$1" systemctl stop "$SERVICE_NAME"
|
||||
}
|
||||
|
||||
reset_node() {
|
||||
run_node_utility "$1" reset --yes
|
||||
}
|
||||
|
||||
start_leader() {
|
||||
echo "===== $LEADER :: start leader ====="
|
||||
node_cmd "$LEADER" systemctl start "$SERVICE_NAME"
|
||||
sleep "$STARTUP_WAIT_SECONDS"
|
||||
}
|
||||
|
||||
start_followers() {
|
||||
local node
|
||||
for node in $NODES_RAW; do
|
||||
if [[ "$node" == "$LEADER" ]]; then
|
||||
continue
|
||||
fi
|
||||
echo "===== $node :: start follower ====="
|
||||
node_cmd "$node" systemctl start "$SERVICE_NAME"
|
||||
done
|
||||
sleep "$FOLLOWER_START_WAIT_SECONDS"
|
||||
}
|
||||
|
||||
verify_node() {
|
||||
run_node_utility "$1" verify --require-nonzero-root
|
||||
}
|
||||
|
||||
verify_node_with_retry() {
|
||||
local node="$1"
|
||||
local attempt
|
||||
for attempt in $(seq 1 "$VERIFY_RETRIES"); do
|
||||
if run_node_utility "$node" verify --require-nonzero-root; then
|
||||
return 0
|
||||
fi
|
||||
if [[ "$attempt" == "$VERIFY_RETRIES" ]]; then
|
||||
return 1
|
||||
fi
|
||||
echo "verify failed for $node, retrying in ${VERIFY_RETRY_SECONDS}s (${attempt}/${VERIFY_RETRIES})" >&2
|
||||
sleep "$VERIFY_RETRY_SECONDS"
|
||||
done
|
||||
}
|
||||
|
||||
case "${1:-}" in
|
||||
preflight)
|
||||
for_each_node preflight preflight_node
|
||||
;;
|
||||
backup)
|
||||
for_each_node backup backup_node
|
||||
;;
|
||||
set-roles)
|
||||
for_each_node set-roles set_role_node
|
||||
;;
|
||||
clear-roles)
|
||||
for_each_node clear-roles clear_role_node
|
||||
;;
|
||||
stop)
|
||||
for_each_node stop stop_node
|
||||
;;
|
||||
reset)
|
||||
for_each_node reset reset_node
|
||||
;;
|
||||
start)
|
||||
start_leader
|
||||
start_followers
|
||||
;;
|
||||
verify)
|
||||
for_each_node verify verify_node
|
||||
;;
|
||||
rollout)
|
||||
for_each_node preflight preflight_node
|
||||
for_each_node set-roles set_role_node
|
||||
for_each_node backup backup_node
|
||||
for_each_node stop stop_node
|
||||
for_each_node reset reset_node
|
||||
start_leader
|
||||
start_followers
|
||||
for_each_node verify verify_node_with_retry
|
||||
;;
|
||||
*)
|
||||
usage
|
||||
exit 64
|
||||
;;
|
||||
esac
|
||||
@@ -137,7 +137,8 @@ def main() -> None:
|
||||
|
||||
# Ensure mempool DB exists (though not needed for genesis)
|
||||
mempool_path = settings.db_path.parent / "mempool.db"
|
||||
init_mempool(backend="database", db_path=str(mempool_path), max_size=10000, min_fee=0)
|
||||
mempool_url = f"sqlite:///{mempool_path}"
|
||||
init_mempool(backend="database", db_url=mempool_url, max_size=10000, min_fee=0)
|
||||
print(f"[*] Mempool initialized at {mempool_path}")
|
||||
|
||||
# Create genesis block
|
||||
|
||||
Reference in New Issue
Block a user