diff --git a/apps/blockchain-node/src/aitbc_chain/app.py b/apps/blockchain-node/src/aitbc_chain/app.py index f77fbac3..7ce44436 100755 --- a/apps/blockchain-node/src/aitbc_chain/app.py +++ b/apps/blockchain-node/src/aitbc_chain/app.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import time from collections import defaultdict from contextlib import asynccontextmanager @@ -9,7 +10,7 @@ from fastapi.responses import JSONResponse, PlainTextResponse from starlette.middleware.base import BaseHTTPMiddleware from .config import settings -from .database import init_db +from .database import init_db, session_scope from .gossip import create_backend, gossip_broker from .logger import get_logger from .mempool import init_mempool @@ -99,29 +100,36 @@ async def lifespan(app: FastAPI): broadcast_url=settings.gossip_broadcast_url, ) await gossip_broker.set_backend(backend) + proposers = [] # Initialize PoA proposer for mining integration if settings.enable_block_production and settings.proposer_id: try: from .consensus import PoAProposer, ProposerConfig - proposer_config = ProposerConfig( - chain_id=settings.chain_id, - proposer_id=settings.proposer_id, - interval_seconds=settings.block_time_seconds, - max_block_size_bytes=settings.max_block_size_bytes, - max_txs_per_block=settings.max_txs_per_block, - ) - proposer = PoAProposer(config=proposer_config, session_factory=session_scope) - - # Set the proposer for mining integration - set_poa_proposer(proposer) - - # Start the proposer if block production is enabled - asyncio.create_task(proposer.start()) - + supported_chains = [c.strip() for c in settings.supported_chains.split(",") if c.strip()] + if not supported_chains and settings.chain_id: + supported_chains = [settings.chain_id] + + for chain_id in supported_chains: + proposer_config = ProposerConfig( + chain_id=chain_id, + proposer_id=settings.proposer_id, + interval_seconds=settings.block_time_seconds, + max_block_size_bytes=settings.max_block_size_bytes, + max_txs_per_block=settings.max_txs_per_block, + ) + proposer = PoAProposer(config=proposer_config, session_factory=session_scope) + + # Set the proposer for mining integration + set_poa_proposer(proposer) + + # Start the proposer if block production is enabled + asyncio.create_task(proposer.start()) + proposers.append(proposer) + _app_logger.info("PoA proposer initialized for mining integration", extra={ "proposer_id": settings.proposer_id, - "chain_id": settings.chain_id + "supported_chains": supported_chains }) except Exception as e: _app_logger.warning(f"Failed to initialize PoA proposer for mining: {e}") @@ -130,6 +138,11 @@ async def lifespan(app: FastAPI): try: yield finally: + for proposer in proposers: + try: + await proposer.stop() + except Exception as exc: + _app_logger.warning(f"Failed to stop PoA proposer during shutdown: {exc}") await gossip_broker.shutdown() _app_logger.info("Blockchain node stopped") diff --git a/apps/blockchain-node/src/aitbc_chain/chain_sync.py b/apps/blockchain-node/src/aitbc_chain/chain_sync.py index 78ced5a1..6a01df60 100644 --- a/apps/blockchain-node/src/aitbc_chain/chain_sync.py +++ b/apps/blockchain-node/src/aitbc_chain/chain_sync.py @@ -71,12 +71,29 @@ class ChainSyncService: """Stop chain synchronization service""" logger.info("Stopping chain sync service") self._stop_event.set() + + async def _get_import_head_height(self, session) -> int: + """Get the current height on the local import target.""" + try: + async with session.get( + f"http://{self.import_host}:{self.import_port}/rpc/head", + params={"chain_id": settings.chain_id}, + ) as resp: + if resp.status == 200: + head_data = await resp.json() + return int(head_data.get('height', 0)) + if resp.status == 404: + return -1 + logger.warning(f"Failed to get import head height: RPC returned status {resp.status}") + except Exception as e: + logger.warning(f"Failed to get import head height: {e}") + return -1 async def _broadcast_blocks(self): """Broadcast local blocks to other nodes""" import aiohttp - last_broadcast_height = 22505 + last_broadcast_height = -1 retry_count = 0 max_retries = 5 base_delay = settings.blockchain_monitoring_interval_seconds # Use config setting instead of hardcoded value @@ -85,6 +102,10 @@ class ChainSyncService: try: # Get current head from local RPC async with aiohttp.ClientSession() as session: + if last_broadcast_height < 0: + last_broadcast_height = await self._get_import_head_height(session) + logger.info(f"Initialized sync baseline at height {last_broadcast_height} for node {self.node_id}") + async with session.get(f"http://{self.source_host}:{self.source_port}/rpc/head") as resp: if resp.status == 200: head_data = await resp.json() diff --git a/apps/blockchain-node/src/aitbc_chain/consensus/poa.py b/apps/blockchain-node/src/aitbc_chain/consensus/poa.py index fecefde0..d88d3d07 100755 --- a/apps/blockchain-node/src/aitbc_chain/consensus/poa.py +++ b/apps/blockchain-node/src/aitbc_chain/consensus/poa.py @@ -171,25 +171,25 @@ class PoAProposer: recipient = tx_data.get("to") value = tx_data.get("amount", 0) fee = tx_data.get("fee", 0) - + self._logger.info(f"[PROPOSE] Processing tx {tx.tx_hash}: from={sender}, to={recipient}, amount={value}, fee={fee}") - + if not sender or not recipient: self._logger.warning(f"[PROPOSE] Skipping tx {tx.tx_hash}: missing sender or recipient") continue - + # Get sender account sender_account = session.get(Account, (self._config.chain_id, sender)) if not sender_account: self._logger.warning(f"[PROPOSE] Skipping tx {tx.tx_hash}: sender account not found for {sender}") continue - + # Check sufficient balance total_cost = value + fee if sender_account.balance < total_cost: self._logger.warning(f"[PROPOSE] Skipping tx {tx.tx_hash}: insufficient balance (has {sender_account.balance}, needs {total_cost})") continue - + # Get or create recipient account recipient_account = session.get(Account, (self._config.chain_id, recipient)) if not recipient_account: @@ -199,12 +199,12 @@ class PoAProposer: session.flush() else: self._logger.info(f"[PROPOSE] Recipient account exists for {recipient}") - + # Update balances sender_account.balance -= total_cost sender_account.nonce += 1 recipient_account.balance += value - + # Check if transaction already exists in database existing_tx = session.exec( select(Transaction).where( @@ -212,11 +212,11 @@ class PoAProposer: Transaction.tx_hash == tx.tx_hash ) ).first() - + if existing_tx: self._logger.warning(f"[PROPOSE] Skipping tx {tx.tx_hash}: already exists in database at block {existing_tx.block_height}") continue - + # Create transaction record transaction = Transaction( chain_id=self._config.chain_id, @@ -234,11 +234,17 @@ class PoAProposer: session.add(transaction) processed_txs.append(tx) self._logger.info(f"[PROPOSE] Successfully processed tx {tx.tx_hash}: updated balances") - + except Exception as e: self._logger.warning(f"Failed to process transaction {tx.tx_hash}: {e}") continue - + + if pending_txs and not processed_txs and getattr(settings, "propose_only_if_mempool_not_empty", True): + self._logger.warning( + f"[PROPOSE] Skipping block proposal: all drained transactions were invalid (count={len(pending_txs)}, chain={self._config.chain_id})" + ) + return False + # Compute block hash with transaction data block_hash = self._compute_block_hash(next_height, parent_hash, timestamp, processed_txs) @@ -388,7 +394,12 @@ class PoAProposer: def _fetch_chain_head(self) -> Optional[Block]: with self._session_factory() as session: - return session.exec(select(Block).order_by(Block.height.desc()).limit(1)).first() + return session.exec( + select(Block) + .where(Block.chain_id == self._config.chain_id) + .order_by(Block.height.desc()) + .limit(1) + ).first() def _compute_block_hash(self, height: int, parent_hash: str, timestamp: datetime, transactions: list = None) -> str: # Include transaction hashes in block hash computation diff --git a/apps/blockchain-node/src/aitbc_chain/mempool.py b/apps/blockchain-node/src/aitbc_chain/mempool.py index 28e238a0..5dab0eff 100755 --- a/apps/blockchain-node/src/aitbc_chain/mempool.py +++ b/apps/blockchain-node/src/aitbc_chain/mempool.py @@ -35,11 +35,17 @@ class InMemoryMempool: def __init__(self, max_size: int = 10_000, min_fee: int = 0, chain_id: str = None) -> None: from .config import settings self._lock = Lock() - self._transactions: Dict[str, PendingTransaction] = {} + self._transactions: Dict[str, Dict[str, PendingTransaction]] = {} self._max_size = max_size self._min_fee = min_fee self.chain_id = chain_id or settings.chain_id + def _get_chain_transactions(self, chain_id: str) -> Dict[str, PendingTransaction]: + return self._transactions.setdefault(chain_id, {}) + + def _total_size(self) -> int: + return sum(len(chain_txs) for chain_txs in self._transactions.values()) + def add(self, tx: Dict[str, Any], chain_id: str = None) -> str: from .config import settings if chain_id is None: @@ -55,12 +61,13 @@ class InMemoryMempool: fee=fee, size_bytes=size_bytes ) with self._lock: - if tx_hash in self._transactions: + chain_transactions = self._get_chain_transactions(chain_id) + if tx_hash in chain_transactions: return tx_hash # duplicate - if len(self._transactions) >= self._max_size: - self._evict_lowest_fee() - self._transactions[tx_hash] = entry - metrics_registry.set_gauge("mempool_size", float(len(self._transactions))) + if len(chain_transactions) >= self._max_size: + self._evict_lowest_fee(chain_id) + chain_transactions[tx_hash] = entry + metrics_registry.set_gauge("mempool_size", float(self._total_size())) metrics_registry.increment(f"mempool_tx_added_total_{chain_id}") return tx_hash @@ -69,7 +76,7 @@ class InMemoryMempool: if chain_id is None: chain_id = settings.chain_id with self._lock: - return list(self._transactions.values()) + return list(self._get_chain_transactions(chain_id).values()) def drain(self, max_count: int, max_bytes: int, chain_id: str = None) -> List[PendingTransaction]: from .config import settings @@ -77,8 +84,9 @@ class InMemoryMempool: chain_id = settings.chain_id """Drain transactions for block inclusion, prioritized by fee (highest first).""" with self._lock: + chain_transactions = self._get_chain_transactions(chain_id) sorted_txs = sorted( - self._transactions.values(), + chain_transactions.values(), key=lambda t: (-t.fee, t.received_at) ) result: List[PendingTransaction] = [] @@ -92,9 +100,9 @@ class InMemoryMempool: total_bytes += tx.size_bytes for tx in result: - del self._transactions[tx.tx_hash] + del chain_transactions[tx.tx_hash] - metrics_registry.set_gauge("mempool_size", float(len(self._transactions))) + metrics_registry.set_gauge("mempool_size", float(self._total_size())) metrics_registry.increment(f"mempool_tx_drained_total_{chain_id}", float(len(result))) return result @@ -103,9 +111,9 @@ class InMemoryMempool: if chain_id is None: chain_id = settings.chain_id with self._lock: - removed = self._transactions.pop(tx_hash, None) is not None + removed = self._get_chain_transactions(chain_id).pop(tx_hash, None) is not None if removed: - metrics_registry.set_gauge("mempool_size", float(len(self._transactions))) + metrics_registry.set_gauge("mempool_size", float(self._total_size())) return removed def size(self, chain_id: str = None) -> int: @@ -113,7 +121,7 @@ class InMemoryMempool: if chain_id is None: chain_id = settings.chain_id with self._lock: - return len(self._transactions) + return len(self._get_chain_transactions(chain_id)) def get_pending_transactions(self, chain_id: str = None, limit: int = 100) -> List[Dict[str, Any]]: """Get pending transactions for RPC endpoint""" @@ -124,20 +132,21 @@ class InMemoryMempool: with self._lock: # Get transactions sorted by fee (highest first) and time sorted_txs = sorted( - self._transactions.values(), + self._get_chain_transactions(chain_id).values(), key=lambda t: (-t.fee, t.received_at) ) # Return only the content, limited by the limit parameter return [tx.content for tx in sorted_txs[:limit]] - def _evict_lowest_fee(self) -> None: + def _evict_lowest_fee(self, chain_id: str) -> None: """Evict the lowest-fee transaction to make room.""" - if not self._transactions: + chain_transactions = self._get_chain_transactions(chain_id) + if not chain_transactions: return - lowest = min(self._transactions.values(), key=lambda t: (t.fee, -t.received_at)) - del self._transactions[lowest.tx_hash] - metrics_registry.increment(f"mempool_evictions_total_{self.chain_id}") + lowest = min(chain_transactions.values(), key=lambda t: (t.fee, -t.received_at)) + del chain_transactions[lowest.tx_hash] + metrics_registry.increment(f"mempool_evictions_total_{chain_id}") class DatabaseMempool: diff --git a/apps/blockchain-node/src/aitbc_chain/network/hub_manager.py b/apps/blockchain-node/src/aitbc_chain/network/hub_manager.py index a945b137..aa5df3a4 100644 --- a/apps/blockchain-node/src/aitbc_chain/network/hub_manager.py +++ b/apps/blockchain-node/src/aitbc_chain/network/hub_manager.py @@ -8,9 +8,11 @@ import logging import time import json import os +import socket from typing import Dict, List, Optional, Set from dataclasses import dataclass, field, asdict from enum import Enum +from ..config import settings logger = logging.getLogger(__name__) @@ -59,6 +61,7 @@ class HubManager: self.local_port = local_port self.island_id = island_id self.island_name = island_name + self.island_chain_id = settings.island_chain_id or settings.chain_id or f"ait-{island_id[:8]}" self.redis_url = redis_url or "redis://localhost:6379" # Hub registration status @@ -155,15 +158,20 @@ class HubManager: credentials = {} # Get genesis block hash from genesis.json - genesis_path = '/var/lib/aitbc/data/ait-mainnet/genesis.json' - if os.path.exists(genesis_path): - with open(genesis_path, 'r') as f: - genesis_data = json.load(f) - # Get genesis block hash - if 'blocks' in genesis_data and len(genesis_data['blocks']) > 0: - genesis_block = genesis_data['blocks'][0] - credentials['genesis_block_hash'] = genesis_block.get('hash', '') - credentials['genesis_block'] = genesis_data + genesis_candidates = [ + str(settings.db_path.parent / 'genesis.json'), + f"/var/lib/aitbc/data/{settings.chain_id}/genesis.json", + '/var/lib/aitbc/data/ait-mainnet/genesis.json', + ] + for genesis_path in genesis_candidates: + if os.path.exists(genesis_path): + with open(genesis_path, 'r') as f: + genesis_data = json.load(f) + if 'blocks' in genesis_data and len(genesis_data['blocks']) > 0: + genesis_block = genesis_data['blocks'][0] + credentials['genesis_block_hash'] = genesis_block.get('hash', '') + credentials['genesis_block'] = genesis_data + break # Get genesis address from keystore keystore_path = '/var/lib/aitbc/keystore/validator_keys.json' @@ -177,12 +185,15 @@ class HubManager: break # Add chain info - credentials['chain_id'] = self.island_chain_id or f"ait-{self.island_id[:8]}" + credentials['chain_id'] = self.island_chain_id credentials['island_id'] = self.island_id credentials['island_name'] = self.island_name # Add RPC endpoint (local) - credentials['rpc_endpoint'] = f"http://{self.local_address}:8006" + rpc_host = self.local_address + if rpc_host in {"0.0.0.0", "127.0.0.1", "localhost", ""}: + rpc_host = settings.hub_discovery_url or socket.gethostname() + credentials['rpc_endpoint'] = f"http://{rpc_host}:8006" credentials['p2p_port'] = self.local_port return credentials @@ -190,33 +201,6 @@ class HubManager: logger.error(f"Failed to get blockchain credentials: {e}") return {} - def __init__(self, local_node_id: str, local_address: str, local_port: int, - island_id: str, island_name: str, redis_url: str): - self.local_node_id = local_node_id - self.local_address = local_address - self.local_port = local_port - self.island_id = island_id - self.island_name = island_name - self.island_chain_id = f"ait-{island_id[:8]}" - - self.known_hubs: Dict[str, HubInfo] = {} - self.peer_registry: Dict[str, PeerInfo] = {} - self.peer_reputation: Dict[str, float] = {} - self.peer_last_seen: Dict[str, float] = {} - - # GPU marketplace tracking - self.gpu_offers: Dict[str, dict] = {} - self.gpu_bids: Dict[str, dict] = {} - self.gpu_providers: Dict[str, dict] = {} # node_id -> gpu info - - # Exchange tracking - self.exchange_orders: Dict[str, dict] = {} # order_id -> order info - self.exchange_order_books: Dict[str, Dict] = {} # pair -> {bids: [], asks: []} - - # Redis client for persistence - self.redis_url = redis_url - self._redis_client = None - async def handle_join_request(self, join_request: dict) -> Optional[dict]: """ Handle island join request from a new node diff --git a/apps/blockchain-node/src/aitbc_chain/p2p_network.py b/apps/blockchain-node/src/aitbc_chain/p2p_network.py index d19b9d65..d147425f 100644 --- a/apps/blockchain-node/src/aitbc_chain/p2p_network.py +++ b/apps/blockchain-node/src/aitbc_chain/p2p_network.py @@ -7,6 +7,7 @@ Handles decentralized peer-to-peer mesh communication between blockchain nodes import asyncio import json import logging +from .config import settings from .mempool import get_mempool, compute_tx_hash from .network.nat_traversal import NATTraversalService from .network.island_manager import IslandManager @@ -89,7 +90,7 @@ class P2PNetworkService: self.port, self.island_id, self.island_name, - self.config.redis_url + settings.redis_url ) await self.hub_manager.register_as_hub(self.public_endpoint[0] if self.public_endpoint else None, self.public_endpoint[1] if self.public_endpoint else None) @@ -158,6 +159,29 @@ class P2PNetworkService: await self._server.wait_closed() + async def _send_message(self, writer: asyncio.StreamWriter, message: Dict[str, Any]): + """Serialize and send a newline-delimited JSON message""" + payload = json.dumps(message).encode() + b"\n" + writer.write(payload) + await writer.drain() + + + async def _ping_peers_loop(self): + """Periodically ping active peers to keep connections healthy""" + while not self._stop_event.is_set(): + try: + writers = list(self.active_connections.items()) + for peer_id, writer in writers: + try: + await self._send_message(writer, {'type': 'ping', 'node_id': self.node_id}) + except Exception as e: + logger.debug(f"Failed to ping {peer_id}: {e}") + except Exception as e: + logger.error(f"Error in ping loop: {e}") + + await asyncio.sleep(10) + + async def _mempool_sync_loop(self): """Periodically check local mempool and broadcast new transactions to peers""" self.seen_txs = set() @@ -170,23 +194,26 @@ class P2PNetworkService: if hasattr(mempool, '_transactions'): # InMemoryMempool with mempool._lock: - for tx_hash, pending_tx in mempool._transactions.items(): - if tx_hash not in self.seen_txs: - self.seen_txs.add(tx_hash) - txs_to_broadcast.append(pending_tx.content) + for chain_id, chain_transactions in mempool._transactions.items(): + for tx_hash, pending_tx in chain_transactions.items(): + seen_key = (chain_id, tx_hash) + if seen_key not in self.seen_txs: + self.seen_txs.add(seen_key) + txs_to_broadcast.append(pending_tx.content) elif hasattr(mempool, '_conn'): # DatabaseMempool with mempool._lock: cursor = mempool._conn.execute( - "SELECT tx_hash, content FROM mempool WHERE chain_id = ?", - ('ait-mainnet',) + "SELECT chain_id, tx_hash, content FROM mempool" ) for row in cursor.fetchall(): - tx_hash = row[0] - if tx_hash not in self.seen_txs: - self.seen_txs.add(tx_hash) + chain_id = row[0] + tx_hash = row[1] + seen_key = (chain_id, tx_hash) + if seen_key not in self.seen_txs: + self.seen_txs.add(seen_key) import json - txs_to_broadcast.append(json.loads(row[1])) + txs_to_broadcast.append(json.loads(row[2])) logger.debug(f"Mempool sync loop iteration. txs_to_broadcast: {len(txs_to_broadcast)}") for tx in txs_to_broadcast: @@ -297,32 +324,32 @@ class P2PNetworkService: # Store peer's island information logger.info(f"Peer {peer_node_id} from island {peer_island_id} (hub: {peer_is_hub})") - + # Store peer's public endpoint if provided if peer_public_address and peer_public_port: logger.info(f"Peer {peer_node_id} public endpoint: {peer_public_address}:{peer_public_port}") - + # Accept handshake and store connection logger.info(f"Handshake accepted from node {peer_node_id} at {addr}") - + # If we already have a connection to this node, drop the new one to prevent duplicates if peer_node_id in self.active_connections: logger.info(f"Already connected to node {peer_node_id}. Dropping duplicate inbound.") writer.close() return - + self.active_connections[peer_node_id] = writer - + # Map their listening endpoint so we don't try to dial them remote_ip = addr[0] self.connected_endpoints.add((remote_ip, peer_listen_port)) - + # Add peer to island manager if available if self.island_manager and peer_island_id: self.island_manager.add_island_peer(peer_island_id, peer_node_id) - - # Add peer to hub manager if available and peer is a hub - if self.hub_manager and peer_is_hub: + + # Add peer to hub manager if available + if self.hub_manager: from .network.hub_manager import PeerInfo self.hub_manager.register_peer(PeerInfo( node_id=peer_node_id, @@ -334,7 +361,7 @@ class P2PNetworkService: public_port=peer_public_port, last_seen=asyncio.get_event_loop().time() )) - + # Reply with our handshake including island information reply_handshake = { 'type': 'handshake', @@ -348,10 +375,10 @@ class P2PNetworkService: 'public_port': self.public_endpoint[1] if self.public_endpoint else None } await self._send_message(writer, reply_handshake) - + # Listen for messages await self._listen_to_stream(reader, writer, (remote_ip, peer_listen_port), outbound=False, peer_id=peer_node_id) - + except asyncio.TimeoutError: logger.warning(f"Timeout waiting for handshake from {addr}") writer.close() @@ -359,7 +386,8 @@ class P2PNetworkService: logger.error(f"Error handling inbound connection from {addr}: {e}") writer.close() - async def _listen_to_stream(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, endpoint: Tuple[str, int], outbound: bool, peer_id: str = None): + async def _listen_to_stream(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, endpoint: Tuple[str, int], + outbound: bool, peer_id: str = None): """Read loop for an established TCP stream (both inbound and outbound)""" addr = endpoint try: @@ -367,35 +395,35 @@ class P2PNetworkService: data = await reader.readline() if not data: break # Connection closed remotely - + try: message = json.loads(data.decode().strip()) - + msg_type = message.get('type') - + # If this is an outbound connection, the first message MUST be their handshake reply if outbound and peer_id is None: if msg_type == 'handshake': peer_id = message.get('node_id') peer_island_id = message.get('island_id', '') peer_is_hub = message.get('is_hub', False) - + if not peer_id or peer_id == self.node_id: logger.warning(f"Invalid handshake reply from {addr}. Closing.") break - + if peer_id in self.active_connections: logger.info(f"Already connected to node {peer_id}. Closing duplicate outbound.") break - + self.active_connections[peer_id] = writer - + # Add peer to island manager if available if self.island_manager and peer_island_id: self.island_manager.add_island_peer(peer_island_id, peer_id) - - # Add peer to hub manager if available and peer is a hub - if self.hub_manager and peer_is_hub: + + # Add peer to hub manager if available + if self.hub_manager: from .network.hub_manager import PeerInfo self.hub_manager.register_peer(PeerInfo( node_id=peer_id, @@ -407,23 +435,23 @@ class P2PNetworkService: public_port=message.get('public_port'), last_seen=asyncio.get_event_loop().time() )) - + logger.info(f"Outbound handshake complete. Connected to node {peer_id} (island: {peer_island_id})") continue else: logger.warning(f"Expected handshake reply from {addr}, got {msg_type}") break - + # Normal message handling if msg_type == 'ping': logger.debug(f"Received ping from {peer_id}") await self._send_message(writer, {'type': 'pong', 'node_id': self.node_id}) - + elif msg_type == 'pong': logger.debug(f"Received pong from {peer_id}") - + elif msg_type == 'handshake': - pass # Ignore subsequent handshakes + pass # Ignore subsequent handshakes elif msg_type == 'join_request': # Handle island join request (only if we're a hub) if self.hub_manager: @@ -463,35 +491,36 @@ class P2PNetworkService: if tx_data: try: tx_hash = compute_tx_hash(tx_data) + chain_id = tx_data.get('chain_id', settings.chain_id) if not hasattr(self, 'seen_txs'): self.seen_txs = set() - - if tx_hash not in self.seen_txs: + + seen_key = (chain_id, tx_hash) + if seen_key not in self.seen_txs: logger.info(f"Received new P2P transaction: {tx_hash}") - self.seen_txs.add(tx_hash) + self.seen_txs.add(seen_key) mempool = get_mempool() # Add to local mempool - mempool.add(tx_data) - + mempool.add(tx_data, chain_id=chain_id) + # Forward to other peers (Gossip) forward_msg = {'type': 'new_transaction', 'tx': tx_data} writers = list(self.active_connections.values()) for w in writers: - if w != writer: # Don't send back to sender + if w != writer: # Don't send back to sender await self._send_message(w, forward_msg) except ValueError as e: logger.debug(f"P2P tx rejected by mempool: {e}") except Exception as e: logger.error(f"P2P tx handling error: {e}") - else: logger.info(f"Received {msg_type} from {peer_id}: {message}") # In a real node, we would forward blocks/txs to the internal event bus here - + except json.JSONDecodeError: logger.warning(f"Invalid JSON received from {addr}") - + except asyncio.CancelledError: pass except Exception as e: @@ -542,7 +571,8 @@ class P2PNetworkService: logger.error(f"Error getting GPU specs: {e}") return {} - async def send_join_request(self, hub_address: str, hub_port: int, island_id: str, island_name: str, node_id: str, public_key_pem: str) -> Optional[dict]: + async def send_join_request(self, hub_address: str, hub_port: int, island_id: str, island_name: str, node_id: str, + public_key_pem: str) -> Optional[dict]: """ Send join request to a hub and wait for response @@ -562,6 +592,34 @@ class P2PNetworkService: reader, writer = await asyncio.open_connection(hub_address, hub_port) logger.info(f"Connected to hub {hub_address}:{hub_port}") + handshake = { + 'type': 'handshake', + 'node_id': node_id, + 'listen_port': self.port, + 'island_id': island_id, + 'island_name': island_name, + 'is_hub': self.is_hub, + 'island_chain_id': self.island_chain_id, + 'public_address': self.public_endpoint[0] if self.public_endpoint else None, + 'public_port': self.public_endpoint[1] if self.public_endpoint else None, + } + await self._send_message(writer, handshake) + logger.info("Sent handshake to hub") + + data = await asyncio.wait_for(reader.readline(), timeout=10.0) + if not data: + logger.warning("No handshake response from hub") + writer.close() + await writer.wait_closed() + return None + + response = json.loads(data.decode().strip()) + if response.get('type') != 'handshake': + logger.warning(f"Unexpected handshake response type: {response.get('type')}") + writer.close() + await writer.wait_closed() + return None + # Send join request join_request = { 'type': 'join_request', @@ -604,20 +662,32 @@ class P2PNetworkService: async def run_p2p_service(host: str, port: int, node_id: str, peers: str): """Run P2P service""" - service = P2PNetworkService(host, port, node_id, peers) + stun_servers = [server.strip() for server in settings.stun_servers.split(',') if server.strip()] + service = P2PNetworkService( + host, + port, + node_id, + peers, + stun_servers=stun_servers or None, + island_id=settings.island_id, + island_name=settings.island_name, + is_hub=settings.is_hub, + island_chain_id=settings.island_chain_id or settings.chain_id, + ) await service.start() + def main(): import argparse - + parser = argparse.ArgumentParser(description="AITBC Direct TCP P2P Mesh Network") parser.add_argument("--host", default="0.0.0.0", help="Bind host") parser.add_argument("--port", type=int, default=7070, help="Bind port") parser.add_argument("--node-id", required=True, help="Node identifier (required for handshake)") parser.add_argument("--peers", default="", help="Comma separated list of initial peers to dial (ip:port)") - + args = parser.parse_args() - + logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' diff --git a/apps/blockchain-node/src/aitbc_chain/rpc/router.py b/apps/blockchain-node/src/aitbc_chain/rpc/router.py index bdac81da..4cf54c7f 100755 --- a/apps/blockchain-node/src/aitbc_chain/rpc/router.py +++ b/apps/blockchain-node/src/aitbc_chain/rpc/router.py @@ -17,6 +17,7 @@ from ..mempool import get_mempool from ..metrics import metrics_registry from ..models import Account, Block, Receipt, Transaction from ..logger import get_logger +from ..sync import ChainSync from ..contracts.agent_messaging_contract import messaging_contract _logger = get_logger(__name__) @@ -28,16 +29,18 @@ _last_import_time = 0 _import_lock = asyncio.Lock() # Global variable to store the PoA proposer -_poa_proposer = None +_poa_proposers: Dict[str, Any] = {} -def set_poa_proposer(proposer): +def set_poa_proposer(proposer, chain_id: str = None): """Set the global PoA proposer instance""" - global _poa_proposer - _poa_proposer = proposer + if chain_id is None: + chain_id = getattr(getattr(proposer, "_config", None), "chain_id", None) or get_chain_id(None) + _poa_proposers[chain_id] = proposer -def get_poa_proposer(): +def get_poa_proposer(chain_id: str = None): """Get the global PoA proposer instance""" - return _poa_proposer + chain_id = get_chain_id(chain_id) + return _poa_proposers.get(chain_id) def get_chain_id(chain_id: str = None) -> str: """Get chain_id from parameter or use default from settings""" @@ -46,6 +49,117 @@ def get_chain_id(chain_id: str = None) -> str: return settings.chain_id return chain_id +def get_supported_chains() -> List[str]: + from ..config import settings + chains = [chain.strip() for chain in settings.supported_chains.split(",") if chain.strip()] + if not chains and settings.chain_id: + return [settings.chain_id] + return chains + +def _normalize_transaction_data(tx_data: Dict[str, Any], chain_id: str) -> Dict[str, Any]: + sender = tx_data.get("from") + recipient = tx_data.get("to") + if not isinstance(sender, str) or not sender.strip(): + raise ValueError("transaction.from is required") + if not isinstance(recipient, str) or not recipient.strip(): + raise ValueError("transaction.to is required") + + try: + amount = int(tx_data["amount"]) + except KeyError as exc: + raise ValueError("transaction.amount is required") from exc + except (TypeError, ValueError) as exc: + raise ValueError("transaction.amount must be an integer") from exc + + try: + fee = int(tx_data.get("fee", 10)) + except (TypeError, ValueError) as exc: + raise ValueError("transaction.fee must be an integer") from exc + + try: + nonce = int(tx_data.get("nonce", 0)) + except (TypeError, ValueError) as exc: + raise ValueError("transaction.nonce must be an integer") from exc + + if amount < 0: + raise ValueError("transaction.amount must be non-negative") + if fee < 0: + raise ValueError("transaction.fee must be non-negative") + if nonce < 0: + raise ValueError("transaction.nonce must be non-negative") + + payload = tx_data.get("payload", "0x") + if payload is None: + payload = "0x" + + return { + "chain_id": chain_id, + "from": sender.strip(), + "to": recipient.strip(), + "amount": amount, + "fee": fee, + "nonce": nonce, + "payload": payload, + "signature": tx_data.get("signature") or tx_data.get("sig"), + } + +def _validate_transaction_admission(tx_data: Dict[str, Any], mempool: Any) -> None: + from ..mempool import compute_tx_hash + + chain_id = tx_data["chain_id"] + supported_chains = get_supported_chains() + if not chain_id: + raise ValueError("transaction.chain_id is required") + if supported_chains and chain_id not in supported_chains: + raise ValueError(f"unsupported chain_id '{chain_id}'. Supported chains: {supported_chains}") + + tx_hash = compute_tx_hash(tx_data) + + with session_scope() as session: + sender_account = session.get(Account, (chain_id, tx_data["from"])) + if sender_account is None: + raise ValueError(f"sender account not found on chain '{chain_id}'") + + total_cost = tx_data["amount"] + tx_data["fee"] + if sender_account.balance < total_cost: + raise ValueError( + f"insufficient balance for sender '{tx_data['from']}' on chain '{chain_id}': has {sender_account.balance}, needs {total_cost}" + ) + + if tx_data["nonce"] != sender_account.nonce: + raise ValueError( + f"invalid nonce for sender '{tx_data['from']}' on chain '{chain_id}': expected {sender_account.nonce}, got {tx_data['nonce']}" + ) + + existing_tx = session.exec( + select(Transaction) + .where(Transaction.chain_id == chain_id) + .where(Transaction.tx_hash == tx_hash) + ).first() + if existing_tx is not None: + raise ValueError(f"transaction '{tx_hash}' is already confirmed on chain '{chain_id}'") + + existing_nonce = session.exec( + select(Transaction) + .where(Transaction.chain_id == chain_id) + .where(Transaction.sender == tx_data["from"]) + .where(Transaction.nonce == tx_data["nonce"]) + ).first() + if existing_nonce is not None: + raise ValueError( + f"sender '{tx_data['from']}' already used nonce {tx_data['nonce']} on chain '{chain_id}'" + ) + + pending_txs = mempool.list_transactions(chain_id=chain_id) + if any(pending_tx.tx_hash == tx_hash for pending_tx in pending_txs): + raise ValueError(f"transaction '{tx_hash}' is already pending on chain '{chain_id}'") + if any( + pending_tx.content.get("from") == tx_data["from"] and pending_tx.content.get("nonce") == tx_data["nonce"] + for pending_tx in pending_txs + ): + raise ValueError( + f"sender '{tx_data['from']}' already has pending nonce {tx_data['nonce']} on chain '{chain_id}'" + ) def _serialize_receipt(receipt: Receipt) -> Dict[str, Any]: return { @@ -118,17 +232,25 @@ async def get_head(chain_id: str = None) -> Dict[str, Any]: @router.get("/blocks/{height}", summary="Get block by height") -async def get_block(height: int) -> Dict[str, Any]: +async def get_block(height: int, chain_id: str = None) -> Dict[str, Any]: + """Get block by height""" + chain_id = get_chain_id(chain_id) metrics_registry.increment("rpc_get_block_total") start = time.perf_counter() with session_scope() as session: - block = session.exec(select(Block).where(Block.height == height)).first() + block = session.exec( + select(Block).where(Block.chain_id == chain_id).where(Block.height == height) + ).first() if block is None: metrics_registry.increment("rpc_get_block_not_found_total") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="block not found") metrics_registry.increment("rpc_get_block_success_total") - txs = session.exec(select(Transaction).where(Transaction.block_height == height)).all() + txs = session.exec( + select(Transaction) + .where(Transaction.chain_id == chain_id) + .where(Transaction.block_height == height) + ).all() tx_list = [] for tx in txs: t = dict(tx.payload) if tx.payload else {} @@ -137,6 +259,7 @@ async def get_block(height: int) -> Dict[str, Any]: metrics_registry.observe("rpc_get_block_duration_seconds", time.perf_counter() - start) return { + "chain_id": block.chain_id, "height": block.height, "hash": block.hash, "parent_hash": block.parent_hash, @@ -152,26 +275,16 @@ async def get_block(height: int) -> Dict[str, Any]: async def submit_transaction(tx_data: dict) -> Dict[str, Any]: """Submit a new transaction to the mempool""" from ..mempool import get_mempool - from ..models import Transaction - + try: mempool = get_mempool() - - # Create transaction data as dictionary - tx_data_dict = { - "chain_id": tx_data.get("chain_id", "ait-mainnet"), - "from": tx_data["from"], - "to": tx_data["to"], - "amount": tx_data["amount"], - "fee": tx_data.get("fee", 10), - "nonce": tx_data.get("nonce", 0), - "payload": tx_data.get("payload", "0x"), - "signature": tx_data.get("signature") - } - - # Add to mempool - tx_hash = mempool.add(tx_data_dict) - + chain_id = tx_data.get("chain_id") or get_chain_id(None) + + tx_data_dict = _normalize_transaction_data(tx_data, chain_id) + _validate_transaction_admission(tx_data_dict, mempool) + + tx_hash = mempool.add(tx_data_dict, chain_id=chain_id) + return { "success": True, "transaction_hash": tx_hash, @@ -282,7 +395,7 @@ async def query_transactions( @router.get("/blocks-range", summary="Get blocks in height range") -async def get_blocks_range(start: int = 0, end: int = 10, include_tx: bool = True) -> Dict[str, Any]: +async def get_blocks_range(start: int = 0, end: int = 10, include_tx: bool = True, chain_id: str = None) -> Dict[str, Any]: """Get blocks in a height range Args: @@ -291,12 +404,12 @@ async def get_blocks_range(start: int = 0, end: int = 10, include_tx: bool = Tru include_tx: Whether to include transaction data (default: True) """ with session_scope() as session: - from ..config import settings as cfg from ..models import Transaction + chain_id = get_chain_id(chain_id) blocks = session.exec( select(Block).where( - Block.chain_id == cfg.chain_id, + Block.chain_id == chain_id, Block.height >= start, Block.height <= end, ).order_by(Block.height.asc()) @@ -317,7 +430,9 @@ async def get_blocks_range(start: int = 0, end: int = 10, include_tx: bool = Tru if include_tx: # Fetch transactions for this block txs = session.exec( - select(Transaction).where(Transaction.block_height == b.height) + select(Transaction) + .where(Transaction.chain_id == chain_id) + .where(Transaction.block_height == b.height) ).all() block_data["transactions"] = [tx.model_dump() for tx in txs] @@ -424,59 +539,53 @@ async def import_block(block_data: dict) -> Dict[str, Any]: _last_import_time = time.time() - with session_scope() as session: - # Convert timestamp string to datetime if needed - timestamp = block_data.get("timestamp") - if isinstance(timestamp, str): - try: - timestamp = datetime.fromisoformat(timestamp.replace('Z', '+00:00')) - except ValueError: - # Fallback to current time if parsing fails - timestamp = datetime.utcnow() - elif timestamp is None: + chain_id = block_data.get("chain_id") or block_data.get("chainId") or get_chain_id(None) + + timestamp = block_data.get("timestamp") + if isinstance(timestamp, str): + try: + timestamp = datetime.fromisoformat(timestamp.replace('Z', '+00:00')) + except ValueError: timestamp = datetime.utcnow() - - # Extract height from either 'number' or 'height' field - height = block_data.get("number") or block_data.get("height") - if height is None: - raise ValueError("Block height is required") - - # Check if block already exists to prevent duplicates - existing = session.execute( - select(Block).where(Block.height == int(height)) - ).scalar_one_or_none() - if existing: - return { - "success": True, - "block_number": existing.height, - "block_hash": existing.hash, - "message": "Block already exists" - } - - # Create block from data - block = Block( - chain_id=block_data.get("chainId", "ait-mainnet"), - height=int(height), - hash=block_data.get("hash"), - parent_hash=block_data.get("parentHash", ""), - proposer=block_data.get("miner", ""), - timestamp=timestamp, - tx_count=len(block_data.get("transactions", [])), - state_root=block_data.get("stateRoot"), - block_metadata=json.dumps(block_data) - ) - - session.add(block) - session.commit() - - _logger.info(f"Successfully imported block {block.height}") + elif timestamp is None: + timestamp = datetime.utcnow() + + height = block_data.get("number") or block_data.get("height") + if height is None: + raise ValueError("Block height is required") + + transactions = block_data.get("transactions", []) + normalized_block = { + "chain_id": chain_id, + "height": int(height), + "hash": block_data.get("hash"), + "parent_hash": block_data.get("parent_hash") or block_data.get("parentHash", ""), + "proposer": block_data.get("proposer") or block_data.get("miner", ""), + "timestamp": timestamp.isoformat() if isinstance(timestamp, datetime) else timestamp, + "tx_count": block_data.get("tx_count", len(transactions)), + "state_root": block_data.get("state_root") or block_data.get("stateRoot"), + } + + from ..config import settings as cfg + sync = ChainSync( + session_factory=session_scope, + chain_id=chain_id, + validate_signatures=cfg.sync_validate_signatures, + ) + result = sync.import_block(normalized_block, transactions=transactions) + + if result.accepted: + _logger.info(f"Successfully imported block {result.height}") metrics_registry.increment("blocks_imported_total") - - return { - "success": True, - "block_number": block.height, - "block_hash": block.hash - } + + return { + "success": result.accepted, + "accepted": result.accepted, + "block_number": result.height, + "block_hash": result.block_hash, + "chain_id": chain_id, + "reason": result.reason, + } except Exception as e: _logger.error(f"Failed to import block: {e}") diff --git a/apps/blockchain-node/src/aitbc_chain/sync.py b/apps/blockchain-node/src/aitbc_chain/sync.py index 4e840b90..832f9737 100755 --- a/apps/blockchain-node/src/aitbc_chain/sync.py +++ b/apps/blockchain-node/src/aitbc_chain/sync.py @@ -214,7 +214,7 @@ class ChainSync: with self._session_factory() as session: # Check for duplicate existing = session.exec( - select(Block).where(Block.hash == block_hash) + select(Block).where(Block.chain_id == self._chain_id).where(Block.hash == block_hash) ).first() if existing: metrics_registry.increment("sync_blocks_duplicate_total") diff --git a/cli/aitbc_cli/commands/node.py b/cli/aitbc_cli/commands/node.py index 730aef32..c8e79ab0 100755 --- a/cli/aitbc_cli/commands/node.py +++ b/cli/aitbc_cli/commands/node.py @@ -13,10 +13,17 @@ from pathlib import Path from typing import Optional from datetime import datetime -from ..utils.output import output, success, error, warning, info -from ..core.config import MultiChainConfig, load_multichain_config, get_default_node_config, add_node_config, remove_node_config -from ..core.node_client import NodeClient -from ..utils import output, error, success +try: + from ..utils.output import output, success, error, warning, info + from ..core.config import MultiChainConfig, load_multichain_config, get_default_node_config, add_node_config, remove_node_config + from ..core.node_client import NodeClient +except ImportError: + from utils import output, error, success, warning + from core.config import MultiChainConfig, load_multichain_config, get_default_node_config, add_node_config, remove_node_config + from core.node_client import NodeClient + + def info(message): + print(message) import uuid @click.group() @@ -501,6 +508,9 @@ def join(ctx, island_id, island_name, chain_id, hub, is_hub): # Get system hostname hostname = socket.gethostname() + sys.path.insert(0, '/opt/aitbc/apps/blockchain-node/src') + from aitbc_chain.config import settings as chain_settings + # Get public key from keystore keystore_path = '/var/lib/aitbc/keystore/validator_keys.json' public_key_pem = None @@ -522,22 +532,30 @@ def join(ctx, island_id, island_name, chain_id, hub, is_hub): # Generate node_id using hostname-based method local_address = socket.gethostbyname(hostname) - local_port = 8001 # Default hub port + local_port = chain_settings.p2p_bind_port content = f"{hostname}:{local_address}:{local_port}:{public_key_pem}" node_id = hashlib.sha256(content.encode()).hexdigest() # Resolve hub domain to IP hub_ip = socket.gethostbyname(hub) - hub_port = 8001 # Default hub port + hub_port = chain_settings.p2p_bind_port - info(f"Connecting to hub {hub} ({hub_ip}:{hub_port})...") + click.echo(f"Connecting to hub {hub} ({hub_ip}:{hub_port})...") # Create P2P network service instance for sending join request - sys.path.insert(0, '/opt/aitbc/apps/blockchain-node/src') from aitbc_chain.p2p_network import P2PNetworkService # Create a minimal P2P service just for sending the join request - p2p_service = P2PNetworkService(local_address, local_port, node_id, []) + p2p_service = P2PNetworkService( + local_address, + local_port, + node_id, + "", + island_id=island_id, + island_name=island_name, + is_hub=is_hub, + island_chain_id=chain_id or chain_settings.island_chain_id or chain_settings.chain_id, + ) # Send join request async def send_join(): @@ -586,9 +604,9 @@ def join(ctx, island_id, island_name, chain_id, hub, is_hub): # If registering as hub if is_hub: - info("Registering as hub...") + click.echo("Registering as hub...") # Hub registration would happen here via the hub register command - info("Run 'aitbc node hub register' to complete hub registration") + click.echo("Run 'aitbc node hub register' to complete hub registration") else: error("Failed to join island - no response from hub") raise click.Abort()