fix: stabilize multichain hub and follower sync flow
Some checks failed
CLI Tests / test-cli (push) Has been cancelled
Integration Tests / test-service-integration (push) Has been cancelled
Python Tests / test-python (push) Has been cancelled
Security Scanning / security-scan (push) Has been cancelled

This commit is contained in:
aitbc
2026-04-13 14:31:23 +02:00
parent d72945f20c
commit bc96e47b8f
9 changed files with 469 additions and 234 deletions

View File

@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import time import time
from collections import defaultdict from collections import defaultdict
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
@@ -9,7 +10,7 @@ from fastapi.responses import JSONResponse, PlainTextResponse
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from .config import settings from .config import settings
from .database import init_db from .database import init_db, session_scope
from .gossip import create_backend, gossip_broker from .gossip import create_backend, gossip_broker
from .logger import get_logger from .logger import get_logger
from .mempool import init_mempool from .mempool import init_mempool
@@ -99,29 +100,36 @@ async def lifespan(app: FastAPI):
broadcast_url=settings.gossip_broadcast_url, broadcast_url=settings.gossip_broadcast_url,
) )
await gossip_broker.set_backend(backend) await gossip_broker.set_backend(backend)
proposers = []
# Initialize PoA proposer for mining integration # Initialize PoA proposer for mining integration
if settings.enable_block_production and settings.proposer_id: if settings.enable_block_production and settings.proposer_id:
try: try:
from .consensus import PoAProposer, ProposerConfig from .consensus import PoAProposer, ProposerConfig
proposer_config = ProposerConfig( supported_chains = [c.strip() for c in settings.supported_chains.split(",") if c.strip()]
chain_id=settings.chain_id, if not supported_chains and settings.chain_id:
proposer_id=settings.proposer_id, supported_chains = [settings.chain_id]
interval_seconds=settings.block_time_seconds,
max_block_size_bytes=settings.max_block_size_bytes,
max_txs_per_block=settings.max_txs_per_block,
)
proposer = PoAProposer(config=proposer_config, session_factory=session_scope)
# Set the proposer for mining integration for chain_id in supported_chains:
set_poa_proposer(proposer) proposer_config = ProposerConfig(
chain_id=chain_id,
proposer_id=settings.proposer_id,
interval_seconds=settings.block_time_seconds,
max_block_size_bytes=settings.max_block_size_bytes,
max_txs_per_block=settings.max_txs_per_block,
)
proposer = PoAProposer(config=proposer_config, session_factory=session_scope)
# Start the proposer if block production is enabled # Set the proposer for mining integration
asyncio.create_task(proposer.start()) set_poa_proposer(proposer)
# Start the proposer if block production is enabled
asyncio.create_task(proposer.start())
proposers.append(proposer)
_app_logger.info("PoA proposer initialized for mining integration", extra={ _app_logger.info("PoA proposer initialized for mining integration", extra={
"proposer_id": settings.proposer_id, "proposer_id": settings.proposer_id,
"chain_id": settings.chain_id "supported_chains": supported_chains
}) })
except Exception as e: except Exception as e:
_app_logger.warning(f"Failed to initialize PoA proposer for mining: {e}") _app_logger.warning(f"Failed to initialize PoA proposer for mining: {e}")
@@ -130,6 +138,11 @@ async def lifespan(app: FastAPI):
try: try:
yield yield
finally: finally:
for proposer in proposers:
try:
await proposer.stop()
except Exception as exc:
_app_logger.warning(f"Failed to stop PoA proposer during shutdown: {exc}")
await gossip_broker.shutdown() await gossip_broker.shutdown()
_app_logger.info("Blockchain node stopped") _app_logger.info("Blockchain node stopped")

View File

@@ -72,11 +72,28 @@ class ChainSyncService:
logger.info("Stopping chain sync service") logger.info("Stopping chain sync service")
self._stop_event.set() self._stop_event.set()
async def _get_import_head_height(self, session) -> int:
"""Get the current height on the local import target."""
try:
async with session.get(
f"http://{self.import_host}:{self.import_port}/rpc/head",
params={"chain_id": settings.chain_id},
) as resp:
if resp.status == 200:
head_data = await resp.json()
return int(head_data.get('height', 0))
if resp.status == 404:
return -1
logger.warning(f"Failed to get import head height: RPC returned status {resp.status}")
except Exception as e:
logger.warning(f"Failed to get import head height: {e}")
return -1
async def _broadcast_blocks(self): async def _broadcast_blocks(self):
"""Broadcast local blocks to other nodes""" """Broadcast local blocks to other nodes"""
import aiohttp import aiohttp
last_broadcast_height = 22505 last_broadcast_height = -1
retry_count = 0 retry_count = 0
max_retries = 5 max_retries = 5
base_delay = settings.blockchain_monitoring_interval_seconds # Use config setting instead of hardcoded value base_delay = settings.blockchain_monitoring_interval_seconds # Use config setting instead of hardcoded value
@@ -85,6 +102,10 @@ class ChainSyncService:
try: try:
# Get current head from local RPC # Get current head from local RPC
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
if last_broadcast_height < 0:
last_broadcast_height = await self._get_import_head_height(session)
logger.info(f"Initialized sync baseline at height {last_broadcast_height} for node {self.node_id}")
async with session.get(f"http://{self.source_host}:{self.source_port}/rpc/head") as resp: async with session.get(f"http://{self.source_host}:{self.source_port}/rpc/head") as resp:
if resp.status == 200: if resp.status == 200:
head_data = await resp.json() head_data = await resp.json()

View File

@@ -239,6 +239,12 @@ class PoAProposer:
self._logger.warning(f"Failed to process transaction {tx.tx_hash}: {e}") self._logger.warning(f"Failed to process transaction {tx.tx_hash}: {e}")
continue continue
if pending_txs and not processed_txs and getattr(settings, "propose_only_if_mempool_not_empty", True):
self._logger.warning(
f"[PROPOSE] Skipping block proposal: all drained transactions were invalid (count={len(pending_txs)}, chain={self._config.chain_id})"
)
return False
# Compute block hash with transaction data # Compute block hash with transaction data
block_hash = self._compute_block_hash(next_height, parent_hash, timestamp, processed_txs) block_hash = self._compute_block_hash(next_height, parent_hash, timestamp, processed_txs)
@@ -388,7 +394,12 @@ class PoAProposer:
def _fetch_chain_head(self) -> Optional[Block]: def _fetch_chain_head(self) -> Optional[Block]:
with self._session_factory() as session: with self._session_factory() as session:
return session.exec(select(Block).order_by(Block.height.desc()).limit(1)).first() return session.exec(
select(Block)
.where(Block.chain_id == self._config.chain_id)
.order_by(Block.height.desc())
.limit(1)
).first()
def _compute_block_hash(self, height: int, parent_hash: str, timestamp: datetime, transactions: list = None) -> str: def _compute_block_hash(self, height: int, parent_hash: str, timestamp: datetime, transactions: list = None) -> str:
# Include transaction hashes in block hash computation # Include transaction hashes in block hash computation

View File

@@ -35,11 +35,17 @@ class InMemoryMempool:
def __init__(self, max_size: int = 10_000, min_fee: int = 0, chain_id: str = None) -> None: def __init__(self, max_size: int = 10_000, min_fee: int = 0, chain_id: str = None) -> None:
from .config import settings from .config import settings
self._lock = Lock() self._lock = Lock()
self._transactions: Dict[str, PendingTransaction] = {} self._transactions: Dict[str, Dict[str, PendingTransaction]] = {}
self._max_size = max_size self._max_size = max_size
self._min_fee = min_fee self._min_fee = min_fee
self.chain_id = chain_id or settings.chain_id self.chain_id = chain_id or settings.chain_id
def _get_chain_transactions(self, chain_id: str) -> Dict[str, PendingTransaction]:
return self._transactions.setdefault(chain_id, {})
def _total_size(self) -> int:
return sum(len(chain_txs) for chain_txs in self._transactions.values())
def add(self, tx: Dict[str, Any], chain_id: str = None) -> str: def add(self, tx: Dict[str, Any], chain_id: str = None) -> str:
from .config import settings from .config import settings
if chain_id is None: if chain_id is None:
@@ -55,12 +61,13 @@ class InMemoryMempool:
fee=fee, size_bytes=size_bytes fee=fee, size_bytes=size_bytes
) )
with self._lock: with self._lock:
if tx_hash in self._transactions: chain_transactions = self._get_chain_transactions(chain_id)
if tx_hash in chain_transactions:
return tx_hash # duplicate return tx_hash # duplicate
if len(self._transactions) >= self._max_size: if len(chain_transactions) >= self._max_size:
self._evict_lowest_fee() self._evict_lowest_fee(chain_id)
self._transactions[tx_hash] = entry chain_transactions[tx_hash] = entry
metrics_registry.set_gauge("mempool_size", float(len(self._transactions))) metrics_registry.set_gauge("mempool_size", float(self._total_size()))
metrics_registry.increment(f"mempool_tx_added_total_{chain_id}") metrics_registry.increment(f"mempool_tx_added_total_{chain_id}")
return tx_hash return tx_hash
@@ -69,7 +76,7 @@ class InMemoryMempool:
if chain_id is None: if chain_id is None:
chain_id = settings.chain_id chain_id = settings.chain_id
with self._lock: with self._lock:
return list(self._transactions.values()) return list(self._get_chain_transactions(chain_id).values())
def drain(self, max_count: int, max_bytes: int, chain_id: str = None) -> List[PendingTransaction]: def drain(self, max_count: int, max_bytes: int, chain_id: str = None) -> List[PendingTransaction]:
from .config import settings from .config import settings
@@ -77,8 +84,9 @@ class InMemoryMempool:
chain_id = settings.chain_id chain_id = settings.chain_id
"""Drain transactions for block inclusion, prioritized by fee (highest first).""" """Drain transactions for block inclusion, prioritized by fee (highest first)."""
with self._lock: with self._lock:
chain_transactions = self._get_chain_transactions(chain_id)
sorted_txs = sorted( sorted_txs = sorted(
self._transactions.values(), chain_transactions.values(),
key=lambda t: (-t.fee, t.received_at) key=lambda t: (-t.fee, t.received_at)
) )
result: List[PendingTransaction] = [] result: List[PendingTransaction] = []
@@ -92,9 +100,9 @@ class InMemoryMempool:
total_bytes += tx.size_bytes total_bytes += tx.size_bytes
for tx in result: for tx in result:
del self._transactions[tx.tx_hash] del chain_transactions[tx.tx_hash]
metrics_registry.set_gauge("mempool_size", float(len(self._transactions))) metrics_registry.set_gauge("mempool_size", float(self._total_size()))
metrics_registry.increment(f"mempool_tx_drained_total_{chain_id}", float(len(result))) metrics_registry.increment(f"mempool_tx_drained_total_{chain_id}", float(len(result)))
return result return result
@@ -103,9 +111,9 @@ class InMemoryMempool:
if chain_id is None: if chain_id is None:
chain_id = settings.chain_id chain_id = settings.chain_id
with self._lock: with self._lock:
removed = self._transactions.pop(tx_hash, None) is not None removed = self._get_chain_transactions(chain_id).pop(tx_hash, None) is not None
if removed: if removed:
metrics_registry.set_gauge("mempool_size", float(len(self._transactions))) metrics_registry.set_gauge("mempool_size", float(self._total_size()))
return removed return removed
def size(self, chain_id: str = None) -> int: def size(self, chain_id: str = None) -> int:
@@ -113,7 +121,7 @@ class InMemoryMempool:
if chain_id is None: if chain_id is None:
chain_id = settings.chain_id chain_id = settings.chain_id
with self._lock: with self._lock:
return len(self._transactions) return len(self._get_chain_transactions(chain_id))
def get_pending_transactions(self, chain_id: str = None, limit: int = 100) -> List[Dict[str, Any]]: def get_pending_transactions(self, chain_id: str = None, limit: int = 100) -> List[Dict[str, Any]]:
"""Get pending transactions for RPC endpoint""" """Get pending transactions for RPC endpoint"""
@@ -124,20 +132,21 @@ class InMemoryMempool:
with self._lock: with self._lock:
# Get transactions sorted by fee (highest first) and time # Get transactions sorted by fee (highest first) and time
sorted_txs = sorted( sorted_txs = sorted(
self._transactions.values(), self._get_chain_transactions(chain_id).values(),
key=lambda t: (-t.fee, t.received_at) key=lambda t: (-t.fee, t.received_at)
) )
# Return only the content, limited by the limit parameter # Return only the content, limited by the limit parameter
return [tx.content for tx in sorted_txs[:limit]] return [tx.content for tx in sorted_txs[:limit]]
def _evict_lowest_fee(self) -> None: def _evict_lowest_fee(self, chain_id: str) -> None:
"""Evict the lowest-fee transaction to make room.""" """Evict the lowest-fee transaction to make room."""
if not self._transactions: chain_transactions = self._get_chain_transactions(chain_id)
if not chain_transactions:
return return
lowest = min(self._transactions.values(), key=lambda t: (t.fee, -t.received_at)) lowest = min(chain_transactions.values(), key=lambda t: (t.fee, -t.received_at))
del self._transactions[lowest.tx_hash] del chain_transactions[lowest.tx_hash]
metrics_registry.increment(f"mempool_evictions_total_{self.chain_id}") metrics_registry.increment(f"mempool_evictions_total_{chain_id}")
class DatabaseMempool: class DatabaseMempool:

View File

@@ -8,9 +8,11 @@ import logging
import time import time
import json import json
import os import os
import socket
from typing import Dict, List, Optional, Set from typing import Dict, List, Optional, Set
from dataclasses import dataclass, field, asdict from dataclasses import dataclass, field, asdict
from enum import Enum from enum import Enum
from ..config import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -59,6 +61,7 @@ class HubManager:
self.local_port = local_port self.local_port = local_port
self.island_id = island_id self.island_id = island_id
self.island_name = island_name self.island_name = island_name
self.island_chain_id = settings.island_chain_id or settings.chain_id or f"ait-{island_id[:8]}"
self.redis_url = redis_url or "redis://localhost:6379" self.redis_url = redis_url or "redis://localhost:6379"
# Hub registration status # Hub registration status
@@ -155,15 +158,20 @@ class HubManager:
credentials = {} credentials = {}
# Get genesis block hash from genesis.json # Get genesis block hash from genesis.json
genesis_path = '/var/lib/aitbc/data/ait-mainnet/genesis.json' genesis_candidates = [
if os.path.exists(genesis_path): str(settings.db_path.parent / 'genesis.json'),
with open(genesis_path, 'r') as f: f"/var/lib/aitbc/data/{settings.chain_id}/genesis.json",
genesis_data = json.load(f) '/var/lib/aitbc/data/ait-mainnet/genesis.json',
# Get genesis block hash ]
if 'blocks' in genesis_data and len(genesis_data['blocks']) > 0: for genesis_path in genesis_candidates:
genesis_block = genesis_data['blocks'][0] if os.path.exists(genesis_path):
credentials['genesis_block_hash'] = genesis_block.get('hash', '') with open(genesis_path, 'r') as f:
credentials['genesis_block'] = genesis_data genesis_data = json.load(f)
if 'blocks' in genesis_data and len(genesis_data['blocks']) > 0:
genesis_block = genesis_data['blocks'][0]
credentials['genesis_block_hash'] = genesis_block.get('hash', '')
credentials['genesis_block'] = genesis_data
break
# Get genesis address from keystore # Get genesis address from keystore
keystore_path = '/var/lib/aitbc/keystore/validator_keys.json' keystore_path = '/var/lib/aitbc/keystore/validator_keys.json'
@@ -177,12 +185,15 @@ class HubManager:
break break
# Add chain info # Add chain info
credentials['chain_id'] = self.island_chain_id or f"ait-{self.island_id[:8]}" credentials['chain_id'] = self.island_chain_id
credentials['island_id'] = self.island_id credentials['island_id'] = self.island_id
credentials['island_name'] = self.island_name credentials['island_name'] = self.island_name
# Add RPC endpoint (local) # Add RPC endpoint (local)
credentials['rpc_endpoint'] = f"http://{self.local_address}:8006" rpc_host = self.local_address
if rpc_host in {"0.0.0.0", "127.0.0.1", "localhost", ""}:
rpc_host = settings.hub_discovery_url or socket.gethostname()
credentials['rpc_endpoint'] = f"http://{rpc_host}:8006"
credentials['p2p_port'] = self.local_port credentials['p2p_port'] = self.local_port
return credentials return credentials
@@ -190,33 +201,6 @@ class HubManager:
logger.error(f"Failed to get blockchain credentials: {e}") logger.error(f"Failed to get blockchain credentials: {e}")
return {} return {}
def __init__(self, local_node_id: str, local_address: str, local_port: int,
island_id: str, island_name: str, redis_url: str):
self.local_node_id = local_node_id
self.local_address = local_address
self.local_port = local_port
self.island_id = island_id
self.island_name = island_name
self.island_chain_id = f"ait-{island_id[:8]}"
self.known_hubs: Dict[str, HubInfo] = {}
self.peer_registry: Dict[str, PeerInfo] = {}
self.peer_reputation: Dict[str, float] = {}
self.peer_last_seen: Dict[str, float] = {}
# GPU marketplace tracking
self.gpu_offers: Dict[str, dict] = {}
self.gpu_bids: Dict[str, dict] = {}
self.gpu_providers: Dict[str, dict] = {} # node_id -> gpu info
# Exchange tracking
self.exchange_orders: Dict[str, dict] = {} # order_id -> order info
self.exchange_order_books: Dict[str, Dict] = {} # pair -> {bids: [], asks: []}
# Redis client for persistence
self.redis_url = redis_url
self._redis_client = None
async def handle_join_request(self, join_request: dict) -> Optional[dict]: async def handle_join_request(self, join_request: dict) -> Optional[dict]:
""" """
Handle island join request from a new node Handle island join request from a new node

View File

@@ -7,6 +7,7 @@ Handles decentralized peer-to-peer mesh communication between blockchain nodes
import asyncio import asyncio
import json import json
import logging import logging
from .config import settings
from .mempool import get_mempool, compute_tx_hash from .mempool import get_mempool, compute_tx_hash
from .network.nat_traversal import NATTraversalService from .network.nat_traversal import NATTraversalService
from .network.island_manager import IslandManager from .network.island_manager import IslandManager
@@ -89,7 +90,7 @@ class P2PNetworkService:
self.port, self.port,
self.island_id, self.island_id,
self.island_name, self.island_name,
self.config.redis_url settings.redis_url
) )
await self.hub_manager.register_as_hub(self.public_endpoint[0] if self.public_endpoint else None, await self.hub_manager.register_as_hub(self.public_endpoint[0] if self.public_endpoint else None,
self.public_endpoint[1] if self.public_endpoint else None) self.public_endpoint[1] if self.public_endpoint else None)
@@ -158,6 +159,29 @@ class P2PNetworkService:
await self._server.wait_closed() await self._server.wait_closed()
async def _send_message(self, writer: asyncio.StreamWriter, message: Dict[str, Any]):
"""Serialize and send a newline-delimited JSON message"""
payload = json.dumps(message).encode() + b"\n"
writer.write(payload)
await writer.drain()
async def _ping_peers_loop(self):
"""Periodically ping active peers to keep connections healthy"""
while not self._stop_event.is_set():
try:
writers = list(self.active_connections.items())
for peer_id, writer in writers:
try:
await self._send_message(writer, {'type': 'ping', 'node_id': self.node_id})
except Exception as e:
logger.debug(f"Failed to ping {peer_id}: {e}")
except Exception as e:
logger.error(f"Error in ping loop: {e}")
await asyncio.sleep(10)
async def _mempool_sync_loop(self): async def _mempool_sync_loop(self):
"""Periodically check local mempool and broadcast new transactions to peers""" """Periodically check local mempool and broadcast new transactions to peers"""
self.seen_txs = set() self.seen_txs = set()
@@ -170,23 +194,26 @@ class P2PNetworkService:
if hasattr(mempool, '_transactions'): # InMemoryMempool if hasattr(mempool, '_transactions'): # InMemoryMempool
with mempool._lock: with mempool._lock:
for tx_hash, pending_tx in mempool._transactions.items(): for chain_id, chain_transactions in mempool._transactions.items():
if tx_hash not in self.seen_txs: for tx_hash, pending_tx in chain_transactions.items():
self.seen_txs.add(tx_hash) seen_key = (chain_id, tx_hash)
txs_to_broadcast.append(pending_tx.content) if seen_key not in self.seen_txs:
self.seen_txs.add(seen_key)
txs_to_broadcast.append(pending_tx.content)
elif hasattr(mempool, '_conn'): # DatabaseMempool elif hasattr(mempool, '_conn'): # DatabaseMempool
with mempool._lock: with mempool._lock:
cursor = mempool._conn.execute( cursor = mempool._conn.execute(
"SELECT tx_hash, content FROM mempool WHERE chain_id = ?", "SELECT chain_id, tx_hash, content FROM mempool"
('ait-mainnet',)
) )
for row in cursor.fetchall(): for row in cursor.fetchall():
tx_hash = row[0] chain_id = row[0]
if tx_hash not in self.seen_txs: tx_hash = row[1]
self.seen_txs.add(tx_hash) seen_key = (chain_id, tx_hash)
if seen_key not in self.seen_txs:
self.seen_txs.add(seen_key)
import json import json
txs_to_broadcast.append(json.loads(row[1])) txs_to_broadcast.append(json.loads(row[2]))
logger.debug(f"Mempool sync loop iteration. txs_to_broadcast: {len(txs_to_broadcast)}") logger.debug(f"Mempool sync loop iteration. txs_to_broadcast: {len(txs_to_broadcast)}")
for tx in txs_to_broadcast: for tx in txs_to_broadcast:
@@ -321,8 +348,8 @@ class P2PNetworkService:
if self.island_manager and peer_island_id: if self.island_manager and peer_island_id:
self.island_manager.add_island_peer(peer_island_id, peer_node_id) self.island_manager.add_island_peer(peer_island_id, peer_node_id)
# Add peer to hub manager if available and peer is a hub # Add peer to hub manager if available
if self.hub_manager and peer_is_hub: if self.hub_manager:
from .network.hub_manager import PeerInfo from .network.hub_manager import PeerInfo
self.hub_manager.register_peer(PeerInfo( self.hub_manager.register_peer(PeerInfo(
node_id=peer_node_id, node_id=peer_node_id,
@@ -359,7 +386,8 @@ class P2PNetworkService:
logger.error(f"Error handling inbound connection from {addr}: {e}") logger.error(f"Error handling inbound connection from {addr}: {e}")
writer.close() writer.close()
async def _listen_to_stream(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, endpoint: Tuple[str, int], outbound: bool, peer_id: str = None): async def _listen_to_stream(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, endpoint: Tuple[str, int],
outbound: bool, peer_id: str = None):
"""Read loop for an established TCP stream (both inbound and outbound)""" """Read loop for an established TCP stream (both inbound and outbound)"""
addr = endpoint addr = endpoint
try: try:
@@ -394,8 +422,8 @@ class P2PNetworkService:
if self.island_manager and peer_island_id: if self.island_manager and peer_island_id:
self.island_manager.add_island_peer(peer_island_id, peer_id) self.island_manager.add_island_peer(peer_island_id, peer_id)
# Add peer to hub manager if available and peer is a hub # Add peer to hub manager if available
if self.hub_manager and peer_is_hub: if self.hub_manager:
from .network.hub_manager import PeerInfo from .network.hub_manager import PeerInfo
self.hub_manager.register_peer(PeerInfo( self.hub_manager.register_peer(PeerInfo(
node_id=peer_id, node_id=peer_id,
@@ -423,7 +451,7 @@ class P2PNetworkService:
logger.debug(f"Received pong from {peer_id}") logger.debug(f"Received pong from {peer_id}")
elif msg_type == 'handshake': elif msg_type == 'handshake':
pass # Ignore subsequent handshakes pass # Ignore subsequent handshakes
elif msg_type == 'join_request': elif msg_type == 'join_request':
# Handle island join request (only if we're a hub) # Handle island join request (only if we're a hub)
if self.hub_manager: if self.hub_manager:
@@ -463,28 +491,29 @@ class P2PNetworkService:
if tx_data: if tx_data:
try: try:
tx_hash = compute_tx_hash(tx_data) tx_hash = compute_tx_hash(tx_data)
chain_id = tx_data.get('chain_id', settings.chain_id)
if not hasattr(self, 'seen_txs'): if not hasattr(self, 'seen_txs'):
self.seen_txs = set() self.seen_txs = set()
if tx_hash not in self.seen_txs: seen_key = (chain_id, tx_hash)
if seen_key not in self.seen_txs:
logger.info(f"Received new P2P transaction: {tx_hash}") logger.info(f"Received new P2P transaction: {tx_hash}")
self.seen_txs.add(tx_hash) self.seen_txs.add(seen_key)
mempool = get_mempool() mempool = get_mempool()
# Add to local mempool # Add to local mempool
mempool.add(tx_data) mempool.add(tx_data, chain_id=chain_id)
# Forward to other peers (Gossip) # Forward to other peers (Gossip)
forward_msg = {'type': 'new_transaction', 'tx': tx_data} forward_msg = {'type': 'new_transaction', 'tx': tx_data}
writers = list(self.active_connections.values()) writers = list(self.active_connections.values())
for w in writers: for w in writers:
if w != writer: # Don't send back to sender if w != writer: # Don't send back to sender
await self._send_message(w, forward_msg) await self._send_message(w, forward_msg)
except ValueError as e: except ValueError as e:
logger.debug(f"P2P tx rejected by mempool: {e}") logger.debug(f"P2P tx rejected by mempool: {e}")
except Exception as e: except Exception as e:
logger.error(f"P2P tx handling error: {e}") logger.error(f"P2P tx handling error: {e}")
else: else:
logger.info(f"Received {msg_type} from {peer_id}: {message}") logger.info(f"Received {msg_type} from {peer_id}: {message}")
# In a real node, we would forward blocks/txs to the internal event bus here # In a real node, we would forward blocks/txs to the internal event bus here
@@ -542,7 +571,8 @@ class P2PNetworkService:
logger.error(f"Error getting GPU specs: {e}") logger.error(f"Error getting GPU specs: {e}")
return {} return {}
async def send_join_request(self, hub_address: str, hub_port: int, island_id: str, island_name: str, node_id: str, public_key_pem: str) -> Optional[dict]: async def send_join_request(self, hub_address: str, hub_port: int, island_id: str, island_name: str, node_id: str,
public_key_pem: str) -> Optional[dict]:
""" """
Send join request to a hub and wait for response Send join request to a hub and wait for response
@@ -562,6 +592,34 @@ class P2PNetworkService:
reader, writer = await asyncio.open_connection(hub_address, hub_port) reader, writer = await asyncio.open_connection(hub_address, hub_port)
logger.info(f"Connected to hub {hub_address}:{hub_port}") logger.info(f"Connected to hub {hub_address}:{hub_port}")
handshake = {
'type': 'handshake',
'node_id': node_id,
'listen_port': self.port,
'island_id': island_id,
'island_name': island_name,
'is_hub': self.is_hub,
'island_chain_id': self.island_chain_id,
'public_address': self.public_endpoint[0] if self.public_endpoint else None,
'public_port': self.public_endpoint[1] if self.public_endpoint else None,
}
await self._send_message(writer, handshake)
logger.info("Sent handshake to hub")
data = await asyncio.wait_for(reader.readline(), timeout=10.0)
if not data:
logger.warning("No handshake response from hub")
writer.close()
await writer.wait_closed()
return None
response = json.loads(data.decode().strip())
if response.get('type') != 'handshake':
logger.warning(f"Unexpected handshake response type: {response.get('type')}")
writer.close()
await writer.wait_closed()
return None
# Send join request # Send join request
join_request = { join_request = {
'type': 'join_request', 'type': 'join_request',
@@ -604,9 +662,21 @@ class P2PNetworkService:
async def run_p2p_service(host: str, port: int, node_id: str, peers: str): async def run_p2p_service(host: str, port: int, node_id: str, peers: str):
"""Run P2P service""" """Run P2P service"""
service = P2PNetworkService(host, port, node_id, peers) stun_servers = [server.strip() for server in settings.stun_servers.split(',') if server.strip()]
service = P2PNetworkService(
host,
port,
node_id,
peers,
stun_servers=stun_servers or None,
island_id=settings.island_id,
island_name=settings.island_name,
is_hub=settings.is_hub,
island_chain_id=settings.island_chain_id or settings.chain_id,
)
await service.start() await service.start()
def main(): def main():
import argparse import argparse

View File

@@ -17,6 +17,7 @@ from ..mempool import get_mempool
from ..metrics import metrics_registry from ..metrics import metrics_registry
from ..models import Account, Block, Receipt, Transaction from ..models import Account, Block, Receipt, Transaction
from ..logger import get_logger from ..logger import get_logger
from ..sync import ChainSync
from ..contracts.agent_messaging_contract import messaging_contract from ..contracts.agent_messaging_contract import messaging_contract
_logger = get_logger(__name__) _logger = get_logger(__name__)
@@ -28,16 +29,18 @@ _last_import_time = 0
_import_lock = asyncio.Lock() _import_lock = asyncio.Lock()
# Global variable to store the PoA proposer # Global variable to store the PoA proposer
_poa_proposer = None _poa_proposers: Dict[str, Any] = {}
def set_poa_proposer(proposer): def set_poa_proposer(proposer, chain_id: str = None):
"""Set the global PoA proposer instance""" """Set the global PoA proposer instance"""
global _poa_proposer if chain_id is None:
_poa_proposer = proposer chain_id = getattr(getattr(proposer, "_config", None), "chain_id", None) or get_chain_id(None)
_poa_proposers[chain_id] = proposer
def get_poa_proposer(): def get_poa_proposer(chain_id: str = None):
"""Get the global PoA proposer instance""" """Get the global PoA proposer instance"""
return _poa_proposer chain_id = get_chain_id(chain_id)
return _poa_proposers.get(chain_id)
def get_chain_id(chain_id: str = None) -> str: def get_chain_id(chain_id: str = None) -> str:
"""Get chain_id from parameter or use default from settings""" """Get chain_id from parameter or use default from settings"""
@@ -46,6 +49,117 @@ def get_chain_id(chain_id: str = None) -> str:
return settings.chain_id return settings.chain_id
return chain_id return chain_id
def get_supported_chains() -> List[str]:
from ..config import settings
chains = [chain.strip() for chain in settings.supported_chains.split(",") if chain.strip()]
if not chains and settings.chain_id:
return [settings.chain_id]
return chains
def _normalize_transaction_data(tx_data: Dict[str, Any], chain_id: str) -> Dict[str, Any]:
sender = tx_data.get("from")
recipient = tx_data.get("to")
if not isinstance(sender, str) or not sender.strip():
raise ValueError("transaction.from is required")
if not isinstance(recipient, str) or not recipient.strip():
raise ValueError("transaction.to is required")
try:
amount = int(tx_data["amount"])
except KeyError as exc:
raise ValueError("transaction.amount is required") from exc
except (TypeError, ValueError) as exc:
raise ValueError("transaction.amount must be an integer") from exc
try:
fee = int(tx_data.get("fee", 10))
except (TypeError, ValueError) as exc:
raise ValueError("transaction.fee must be an integer") from exc
try:
nonce = int(tx_data.get("nonce", 0))
except (TypeError, ValueError) as exc:
raise ValueError("transaction.nonce must be an integer") from exc
if amount < 0:
raise ValueError("transaction.amount must be non-negative")
if fee < 0:
raise ValueError("transaction.fee must be non-negative")
if nonce < 0:
raise ValueError("transaction.nonce must be non-negative")
payload = tx_data.get("payload", "0x")
if payload is None:
payload = "0x"
return {
"chain_id": chain_id,
"from": sender.strip(),
"to": recipient.strip(),
"amount": amount,
"fee": fee,
"nonce": nonce,
"payload": payload,
"signature": tx_data.get("signature") or tx_data.get("sig"),
}
def _validate_transaction_admission(tx_data: Dict[str, Any], mempool: Any) -> None:
    """Reject a transaction that cannot be admitted to the mempool.

    Checks, in order: chain support, sender account existence, balance,
    account nonce, on-chain duplicates (by tx hash, then by sender nonce),
    and duplicates already pending in the mempool (by tx hash, then by
    sender nonce).  Raises ``ValueError`` with a descriptive message on the
    first failed check; returns None when the transaction is admissible.
    """
    from ..mempool import compute_tx_hash

    chain_id = tx_data["chain_id"]
    supported_chains = get_supported_chains()
    if not chain_id:
        raise ValueError("transaction.chain_id is required")
    if supported_chains and chain_id not in supported_chains:
        raise ValueError(f"unsupported chain_id '{chain_id}'. Supported chains: {supported_chains}")

    sender = tx_data["from"]
    nonce = tx_data["nonce"]
    tx_hash = compute_tx_hash(tx_data)

    with session_scope() as session:
        account = session.get(Account, (chain_id, sender))
        if account is None:
            raise ValueError(f"sender account not found on chain '{chain_id}'")

        required = tx_data["amount"] + tx_data["fee"]
        if account.balance < required:
            raise ValueError(
                f"insufficient balance for sender '{sender}' on chain '{chain_id}': has {account.balance}, needs {required}"
            )
        if nonce != account.nonce:
            raise ValueError(
                f"invalid nonce for sender '{sender}' on chain '{chain_id}': expected {account.nonce}, got {nonce}"
            )

        # Already mined under the same hash?
        confirmed = session.exec(
            select(Transaction)
            .where(Transaction.chain_id == chain_id)
            .where(Transaction.tx_hash == tx_hash)
        ).first()
        if confirmed is not None:
            raise ValueError(f"transaction '{tx_hash}' is already confirmed on chain '{chain_id}'")

        # Nonce already consumed by a different confirmed transaction?
        used_nonce = session.exec(
            select(Transaction)
            .where(Transaction.chain_id == chain_id)
            .where(Transaction.sender == sender)
            .where(Transaction.nonce == nonce)
        ).first()
        if used_nonce is not None:
            raise ValueError(
                f"sender '{sender}' already used nonce {nonce} on chain '{chain_id}'"
            )

    pending = mempool.list_transactions(chain_id=chain_id)
    # Hash-duplicate check runs over the whole pending set first, so it
    # takes precedence over the sender/nonce check below.
    if any(item.tx_hash == tx_hash for item in pending):
        raise ValueError(f"transaction '{tx_hash}' is already pending on chain '{chain_id}'")
    for item in pending:
        if item.content.get("from") == sender and item.content.get("nonce") == nonce:
            raise ValueError(
                f"sender '{sender}' already has pending nonce {nonce} on chain '{chain_id}'"
            )
def _serialize_receipt(receipt: Receipt) -> Dict[str, Any]: def _serialize_receipt(receipt: Receipt) -> Dict[str, Any]:
return { return {
@@ -118,17 +232,25 @@ async def get_head(chain_id: str = None) -> Dict[str, Any]:
@router.get("/blocks/{height}", summary="Get block by height") @router.get("/blocks/{height}", summary="Get block by height")
async def get_block(height: int) -> Dict[str, Any]: async def get_block(height: int, chain_id: str = None) -> Dict[str, Any]:
"""Get block by height"""
chain_id = get_chain_id(chain_id)
metrics_registry.increment("rpc_get_block_total") metrics_registry.increment("rpc_get_block_total")
start = time.perf_counter() start = time.perf_counter()
with session_scope() as session: with session_scope() as session:
block = session.exec(select(Block).where(Block.height == height)).first() block = session.exec(
select(Block).where(Block.chain_id == chain_id).where(Block.height == height)
).first()
if block is None: if block is None:
metrics_registry.increment("rpc_get_block_not_found_total") metrics_registry.increment("rpc_get_block_not_found_total")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="block not found") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="block not found")
metrics_registry.increment("rpc_get_block_success_total") metrics_registry.increment("rpc_get_block_success_total")
txs = session.exec(select(Transaction).where(Transaction.block_height == height)).all() txs = session.exec(
select(Transaction)
.where(Transaction.chain_id == chain_id)
.where(Transaction.block_height == height)
).all()
tx_list = [] tx_list = []
for tx in txs: for tx in txs:
t = dict(tx.payload) if tx.payload else {} t = dict(tx.payload) if tx.payload else {}
@@ -137,6 +259,7 @@ async def get_block(height: int) -> Dict[str, Any]:
metrics_registry.observe("rpc_get_block_duration_seconds", time.perf_counter() - start) metrics_registry.observe("rpc_get_block_duration_seconds", time.perf_counter() - start)
return { return {
"chain_id": block.chain_id,
"height": block.height, "height": block.height,
"hash": block.hash, "hash": block.hash,
"parent_hash": block.parent_hash, "parent_hash": block.parent_hash,
@@ -152,25 +275,15 @@ async def get_block(height: int) -> Dict[str, Any]:
async def submit_transaction(tx_data: dict) -> Dict[str, Any]: async def submit_transaction(tx_data: dict) -> Dict[str, Any]:
"""Submit a new transaction to the mempool""" """Submit a new transaction to the mempool"""
from ..mempool import get_mempool from ..mempool import get_mempool
from ..models import Transaction
try: try:
mempool = get_mempool() mempool = get_mempool()
chain_id = tx_data.get("chain_id") or get_chain_id(None)
# Create transaction data as dictionary tx_data_dict = _normalize_transaction_data(tx_data, chain_id)
tx_data_dict = { _validate_transaction_admission(tx_data_dict, mempool)
"chain_id": tx_data.get("chain_id", "ait-mainnet"),
"from": tx_data["from"],
"to": tx_data["to"],
"amount": tx_data["amount"],
"fee": tx_data.get("fee", 10),
"nonce": tx_data.get("nonce", 0),
"payload": tx_data.get("payload", "0x"),
"signature": tx_data.get("signature")
}
# Add to mempool tx_hash = mempool.add(tx_data_dict, chain_id=chain_id)
tx_hash = mempool.add(tx_data_dict)
return { return {
"success": True, "success": True,
@@ -282,7 +395,7 @@ async def query_transactions(
@router.get("/blocks-range", summary="Get blocks in height range") @router.get("/blocks-range", summary="Get blocks in height range")
async def get_blocks_range(start: int = 0, end: int = 10, include_tx: bool = True) -> Dict[str, Any]: async def get_blocks_range(start: int = 0, end: int = 10, include_tx: bool = True, chain_id: str = None) -> Dict[str, Any]:
"""Get blocks in a height range """Get blocks in a height range
Args: Args:
@@ -291,12 +404,12 @@ async def get_blocks_range(start: int = 0, end: int = 10, include_tx: bool = Tru
include_tx: Whether to include transaction data (default: True) include_tx: Whether to include transaction data (default: True)
""" """
with session_scope() as session: with session_scope() as session:
from ..config import settings as cfg
from ..models import Transaction from ..models import Transaction
chain_id = get_chain_id(chain_id)
blocks = session.exec( blocks = session.exec(
select(Block).where( select(Block).where(
Block.chain_id == cfg.chain_id, Block.chain_id == chain_id,
Block.height >= start, Block.height >= start,
Block.height <= end, Block.height <= end,
).order_by(Block.height.asc()) ).order_by(Block.height.asc())
@@ -317,7 +430,9 @@ async def get_blocks_range(start: int = 0, end: int = 10, include_tx: bool = Tru
if include_tx: if include_tx:
# Fetch transactions for this block # Fetch transactions for this block
txs = session.exec( txs = session.exec(
select(Transaction).where(Transaction.block_height == b.height) select(Transaction)
.where(Transaction.chain_id == chain_id)
.where(Transaction.block_height == b.height)
).all() ).all()
block_data["transactions"] = [tx.model_dump() for tx in txs] block_data["transactions"] = [tx.model_dump() for tx in txs]
@@ -424,59 +539,53 @@ async def import_block(block_data: dict) -> Dict[str, Any]:
_last_import_time = time.time() _last_import_time = time.time()
with session_scope() as session: chain_id = block_data.get("chain_id") or block_data.get("chainId") or get_chain_id(None)
# Convert timestamp string to datetime if needed
timestamp = block_data.get("timestamp") timestamp = block_data.get("timestamp")
if isinstance(timestamp, str): if isinstance(timestamp, str):
try: try:
timestamp = datetime.fromisoformat(timestamp.replace('Z', '+00:00')) timestamp = datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
except ValueError: except ValueError:
# Fallback to current time if parsing fails
timestamp = datetime.utcnow()
elif timestamp is None:
timestamp = datetime.utcnow() timestamp = datetime.utcnow()
elif timestamp is None:
timestamp = datetime.utcnow()
# Extract height from either 'number' or 'height' field height = block_data.get("number") or block_data.get("height")
height = block_data.get("number") or block_data.get("height") if height is None:
if height is None: raise ValueError("Block height is required")
raise ValueError("Block height is required")
# Check if block already exists to prevent duplicates transactions = block_data.get("transactions", [])
existing = session.execute( normalized_block = {
select(Block).where(Block.height == int(height)) "chain_id": chain_id,
).scalar_one_or_none() "height": int(height),
if existing: "hash": block_data.get("hash"),
return { "parent_hash": block_data.get("parent_hash") or block_data.get("parentHash", ""),
"success": True, "proposer": block_data.get("proposer") or block_data.get("miner", ""),
"block_number": existing.height, "timestamp": timestamp.isoformat() if isinstance(timestamp, datetime) else timestamp,
"block_hash": existing.hash, "tx_count": block_data.get("tx_count", len(transactions)),
"message": "Block already exists" "state_root": block_data.get("state_root") or block_data.get("stateRoot"),
} }
# Create block from data from ..config import settings as cfg
block = Block( sync = ChainSync(
chain_id=block_data.get("chainId", "ait-mainnet"), session_factory=session_scope,
height=int(height), chain_id=chain_id,
hash=block_data.get("hash"), validate_signatures=cfg.sync_validate_signatures,
parent_hash=block_data.get("parentHash", ""), )
proposer=block_data.get("miner", ""), result = sync.import_block(normalized_block, transactions=transactions)
timestamp=timestamp,
tx_count=len(block_data.get("transactions", [])),
state_root=block_data.get("stateRoot"),
block_metadata=json.dumps(block_data)
)
session.add(block) if result.accepted:
session.commit() _logger.info(f"Successfully imported block {result.height}")
_logger.info(f"Successfully imported block {block.height}")
metrics_registry.increment("blocks_imported_total") metrics_registry.increment("blocks_imported_total")
return { return {
"success": True, "success": result.accepted,
"block_number": block.height, "accepted": result.accepted,
"block_hash": block.hash "block_number": result.height,
} "block_hash": result.block_hash,
"chain_id": chain_id,
"reason": result.reason,
}
except Exception as e: except Exception as e:
_logger.error(f"Failed to import block: {e}") _logger.error(f"Failed to import block: {e}")

View File

@@ -214,7 +214,7 @@ class ChainSync:
with self._session_factory() as session: with self._session_factory() as session:
# Check for duplicate # Check for duplicate
existing = session.exec( existing = session.exec(
select(Block).where(Block.hash == block_hash) select(Block).where(Block.chain_id == self._chain_id).where(Block.hash == block_hash)
).first() ).first()
if existing: if existing:
metrics_registry.increment("sync_blocks_duplicate_total") metrics_registry.increment("sync_blocks_duplicate_total")

View File

@@ -13,10 +13,17 @@ from pathlib import Path
from typing import Optional from typing import Optional
from datetime import datetime from datetime import datetime
from ..utils.output import output, success, error, warning, info try:
from ..core.config import MultiChainConfig, load_multichain_config, get_default_node_config, add_node_config, remove_node_config from ..utils.output import output, success, error, warning, info
from ..core.node_client import NodeClient from ..core.config import MultiChainConfig, load_multichain_config, get_default_node_config, add_node_config, remove_node_config
from ..utils import output, error, success from ..core.node_client import NodeClient
except ImportError:
from utils import output, error, success, warning
from core.config import MultiChainConfig, load_multichain_config, get_default_node_config, add_node_config, remove_node_config
from core.node_client import NodeClient
def info(message):
print(message)
import uuid import uuid
@click.group() @click.group()
@@ -501,6 +508,9 @@ def join(ctx, island_id, island_name, chain_id, hub, is_hub):
# Get system hostname # Get system hostname
hostname = socket.gethostname() hostname = socket.gethostname()
sys.path.insert(0, '/opt/aitbc/apps/blockchain-node/src')
from aitbc_chain.config import settings as chain_settings
# Get public key from keystore # Get public key from keystore
keystore_path = '/var/lib/aitbc/keystore/validator_keys.json' keystore_path = '/var/lib/aitbc/keystore/validator_keys.json'
public_key_pem = None public_key_pem = None
@@ -522,22 +532,30 @@ def join(ctx, island_id, island_name, chain_id, hub, is_hub):
# Generate node_id using hostname-based method # Generate node_id using hostname-based method
local_address = socket.gethostbyname(hostname) local_address = socket.gethostbyname(hostname)
local_port = 8001 # Default hub port local_port = chain_settings.p2p_bind_port
content = f"{hostname}:{local_address}:{local_port}:{public_key_pem}" content = f"{hostname}:{local_address}:{local_port}:{public_key_pem}"
node_id = hashlib.sha256(content.encode()).hexdigest() node_id = hashlib.sha256(content.encode()).hexdigest()
# Resolve hub domain to IP # Resolve hub domain to IP
hub_ip = socket.gethostbyname(hub) hub_ip = socket.gethostbyname(hub)
hub_port = 8001 # Default hub port hub_port = chain_settings.p2p_bind_port
info(f"Connecting to hub {hub} ({hub_ip}:{hub_port})...") click.echo(f"Connecting to hub {hub} ({hub_ip}:{hub_port})...")
# Create P2P network service instance for sending join request # Create P2P network service instance for sending join request
sys.path.insert(0, '/opt/aitbc/apps/blockchain-node/src')
from aitbc_chain.p2p_network import P2PNetworkService from aitbc_chain.p2p_network import P2PNetworkService
# Create a minimal P2P service just for sending the join request # Create a minimal P2P service just for sending the join request
p2p_service = P2PNetworkService(local_address, local_port, node_id, []) p2p_service = P2PNetworkService(
local_address,
local_port,
node_id,
"",
island_id=island_id,
island_name=island_name,
is_hub=is_hub,
island_chain_id=chain_id or chain_settings.island_chain_id or chain_settings.chain_id,
)
# Send join request # Send join request
async def send_join(): async def send_join():
@@ -586,9 +604,9 @@ def join(ctx, island_id, island_name, chain_id, hub, is_hub):
# If registering as hub # If registering as hub
if is_hub: if is_hub:
info("Registering as hub...") click.echo("Registering as hub...")
# Hub registration would happen here via the hub register command # Hub registration would happen here via the hub register command
info("Run 'aitbc node hub register' to complete hub registration") click.echo("Run 'aitbc node hub register' to complete hub registration")
else: else:
error("Failed to join island - no response from hub") error("Failed to join island - no response from hub")
raise click.Abort() raise click.Abort()