fix: stabilize multichain hub and follower sync flow
Some checks failed
CLI Tests / test-cli (push) Has been cancelled
Integration Tests / test-service-integration (push) Has been cancelled
Python Tests / test-python (push) Has been cancelled
Security Scanning / security-scan (push) Has been cancelled

This commit is contained in:
aitbc
2026-04-13 14:31:23 +02:00
parent d72945f20c
commit bc96e47b8f
9 changed files with 469 additions and 234 deletions

View File

@@ -17,6 +17,7 @@ from ..mempool import get_mempool
from ..metrics import metrics_registry
from ..models import Account, Block, Receipt, Transaction
from ..logger import get_logger
from ..sync import ChainSync
from ..contracts.agent_messaging_contract import messaging_contract
_logger = get_logger(__name__)
@@ -28,16 +29,18 @@ _last_import_time = 0
_import_lock = asyncio.Lock()
# Global variable to store the PoA proposer
_poa_proposer = None
_poa_proposers: Dict[str, Any] = {}
def set_poa_proposer(proposer):
def set_poa_proposer(proposer, chain_id: str = None):
"""Set the global PoA proposer instance"""
global _poa_proposer
_poa_proposer = proposer
if chain_id is None:
chain_id = getattr(getattr(proposer, "_config", None), "chain_id", None) or get_chain_id(None)
_poa_proposers[chain_id] = proposer
def get_poa_proposer():
def get_poa_proposer(chain_id: str = None):
"""Get the global PoA proposer instance"""
return _poa_proposer
chain_id = get_chain_id(chain_id)
return _poa_proposers.get(chain_id)
def get_chain_id(chain_id: str = None) -> str:
"""Get chain_id from parameter or use default from settings"""
@@ -46,6 +49,117 @@ def get_chain_id(chain_id: str = None) -> str:
return settings.chain_id
return chain_id
def get_supported_chains() -> List[str]:
from ..config import settings
chains = [chain.strip() for chain in settings.supported_chains.split(",") if chain.strip()]
if not chains and settings.chain_id:
return [settings.chain_id]
return chains
def _normalize_transaction_data(tx_data: Dict[str, Any], chain_id: str) -> Dict[str, Any]:
sender = tx_data.get("from")
recipient = tx_data.get("to")
if not isinstance(sender, str) or not sender.strip():
raise ValueError("transaction.from is required")
if not isinstance(recipient, str) or not recipient.strip():
raise ValueError("transaction.to is required")
try:
amount = int(tx_data["amount"])
except KeyError as exc:
raise ValueError("transaction.amount is required") from exc
except (TypeError, ValueError) as exc:
raise ValueError("transaction.amount must be an integer") from exc
try:
fee = int(tx_data.get("fee", 10))
except (TypeError, ValueError) as exc:
raise ValueError("transaction.fee must be an integer") from exc
try:
nonce = int(tx_data.get("nonce", 0))
except (TypeError, ValueError) as exc:
raise ValueError("transaction.nonce must be an integer") from exc
if amount < 0:
raise ValueError("transaction.amount must be non-negative")
if fee < 0:
raise ValueError("transaction.fee must be non-negative")
if nonce < 0:
raise ValueError("transaction.nonce must be non-negative")
payload = tx_data.get("payload", "0x")
if payload is None:
payload = "0x"
return {
"chain_id": chain_id,
"from": sender.strip(),
"to": recipient.strip(),
"amount": amount,
"fee": fee,
"nonce": nonce,
"payload": payload,
"signature": tx_data.get("signature") or tx_data.get("sig"),
}
def _validate_transaction_admission(tx_data: Dict[str, Any], mempool: Any) -> None:
from ..mempool import compute_tx_hash
chain_id = tx_data["chain_id"]
supported_chains = get_supported_chains()
if not chain_id:
raise ValueError("transaction.chain_id is required")
if supported_chains and chain_id not in supported_chains:
raise ValueError(f"unsupported chain_id '{chain_id}'. Supported chains: {supported_chains}")
tx_hash = compute_tx_hash(tx_data)
with session_scope() as session:
sender_account = session.get(Account, (chain_id, tx_data["from"]))
if sender_account is None:
raise ValueError(f"sender account not found on chain '{chain_id}'")
total_cost = tx_data["amount"] + tx_data["fee"]
if sender_account.balance < total_cost:
raise ValueError(
f"insufficient balance for sender '{tx_data['from']}' on chain '{chain_id}': has {sender_account.balance}, needs {total_cost}"
)
if tx_data["nonce"] != sender_account.nonce:
raise ValueError(
f"invalid nonce for sender '{tx_data['from']}' on chain '{chain_id}': expected {sender_account.nonce}, got {tx_data['nonce']}"
)
existing_tx = session.exec(
select(Transaction)
.where(Transaction.chain_id == chain_id)
.where(Transaction.tx_hash == tx_hash)
).first()
if existing_tx is not None:
raise ValueError(f"transaction '{tx_hash}' is already confirmed on chain '{chain_id}'")
existing_nonce = session.exec(
select(Transaction)
.where(Transaction.chain_id == chain_id)
.where(Transaction.sender == tx_data["from"])
.where(Transaction.nonce == tx_data["nonce"])
).first()
if existing_nonce is not None:
raise ValueError(
f"sender '{tx_data['from']}' already used nonce {tx_data['nonce']} on chain '{chain_id}'"
)
pending_txs = mempool.list_transactions(chain_id=chain_id)
if any(pending_tx.tx_hash == tx_hash for pending_tx in pending_txs):
raise ValueError(f"transaction '{tx_hash}' is already pending on chain '{chain_id}'")
if any(
pending_tx.content.get("from") == tx_data["from"] and pending_tx.content.get("nonce") == tx_data["nonce"]
for pending_tx in pending_txs
):
raise ValueError(
f"sender '{tx_data['from']}' already has pending nonce {tx_data['nonce']} on chain '{chain_id}'"
)
def _serialize_receipt(receipt: Receipt) -> Dict[str, Any]:
return {
@@ -118,17 +232,25 @@ async def get_head(chain_id: str = None) -> Dict[str, Any]:
@router.get("/blocks/{height}", summary="Get block by height")
async def get_block(height: int) -> Dict[str, Any]:
async def get_block(height: int, chain_id: str = None) -> Dict[str, Any]:
"""Get block by height"""
chain_id = get_chain_id(chain_id)
metrics_registry.increment("rpc_get_block_total")
start = time.perf_counter()
with session_scope() as session:
block = session.exec(select(Block).where(Block.height == height)).first()
block = session.exec(
select(Block).where(Block.chain_id == chain_id).where(Block.height == height)
).first()
if block is None:
metrics_registry.increment("rpc_get_block_not_found_total")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="block not found")
metrics_registry.increment("rpc_get_block_success_total")
txs = session.exec(select(Transaction).where(Transaction.block_height == height)).all()
txs = session.exec(
select(Transaction)
.where(Transaction.chain_id == chain_id)
.where(Transaction.block_height == height)
).all()
tx_list = []
for tx in txs:
t = dict(tx.payload) if tx.payload else {}
@@ -137,6 +259,7 @@ async def get_block(height: int) -> Dict[str, Any]:
metrics_registry.observe("rpc_get_block_duration_seconds", time.perf_counter() - start)
return {
"chain_id": block.chain_id,
"height": block.height,
"hash": block.hash,
"parent_hash": block.parent_hash,
@@ -152,26 +275,16 @@ async def get_block(height: int) -> Dict[str, Any]:
async def submit_transaction(tx_data: dict) -> Dict[str, Any]:
"""Submit a new transaction to the mempool"""
from ..mempool import get_mempool
from ..models import Transaction
try:
mempool = get_mempool()
# Create transaction data as dictionary
tx_data_dict = {
"chain_id": tx_data.get("chain_id", "ait-mainnet"),
"from": tx_data["from"],
"to": tx_data["to"],
"amount": tx_data["amount"],
"fee": tx_data.get("fee", 10),
"nonce": tx_data.get("nonce", 0),
"payload": tx_data.get("payload", "0x"),
"signature": tx_data.get("signature")
}
# Add to mempool
tx_hash = mempool.add(tx_data_dict)
chain_id = tx_data.get("chain_id") or get_chain_id(None)
tx_data_dict = _normalize_transaction_data(tx_data, chain_id)
_validate_transaction_admission(tx_data_dict, mempool)
tx_hash = mempool.add(tx_data_dict, chain_id=chain_id)
return {
"success": True,
"transaction_hash": tx_hash,
@@ -282,7 +395,7 @@ async def query_transactions(
@router.get("/blocks-range", summary="Get blocks in height range")
async def get_blocks_range(start: int = 0, end: int = 10, include_tx: bool = True) -> Dict[str, Any]:
async def get_blocks_range(start: int = 0, end: int = 10, include_tx: bool = True, chain_id: str = None) -> Dict[str, Any]:
"""Get blocks in a height range
Args:
@@ -291,12 +404,12 @@ async def get_blocks_range(start: int = 0, end: int = 10, include_tx: bool = Tru
include_tx: Whether to include transaction data (default: True)
"""
with session_scope() as session:
from ..config import settings as cfg
from ..models import Transaction
chain_id = get_chain_id(chain_id)
blocks = session.exec(
select(Block).where(
Block.chain_id == cfg.chain_id,
Block.chain_id == chain_id,
Block.height >= start,
Block.height <= end,
).order_by(Block.height.asc())
@@ -317,7 +430,9 @@ async def get_blocks_range(start: int = 0, end: int = 10, include_tx: bool = Tru
if include_tx:
# Fetch transactions for this block
txs = session.exec(
select(Transaction).where(Transaction.block_height == b.height)
select(Transaction)
.where(Transaction.chain_id == chain_id)
.where(Transaction.block_height == b.height)
).all()
block_data["transactions"] = [tx.model_dump() for tx in txs]
@@ -424,59 +539,53 @@ async def import_block(block_data: dict) -> Dict[str, Any]:
_last_import_time = time.time()
with session_scope() as session:
# Convert timestamp string to datetime if needed
timestamp = block_data.get("timestamp")
if isinstance(timestamp, str):
try:
timestamp = datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
except ValueError:
# Fallback to current time if parsing fails
timestamp = datetime.utcnow()
elif timestamp is None:
chain_id = block_data.get("chain_id") or block_data.get("chainId") or get_chain_id(None)
timestamp = block_data.get("timestamp")
if isinstance(timestamp, str):
try:
timestamp = datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
except ValueError:
timestamp = datetime.utcnow()
# Extract height from either 'number' or 'height' field
height = block_data.get("number") or block_data.get("height")
if height is None:
raise ValueError("Block height is required")
# Check if block already exists to prevent duplicates
existing = session.execute(
select(Block).where(Block.height == int(height))
).scalar_one_or_none()
if existing:
return {
"success": True,
"block_number": existing.height,
"block_hash": existing.hash,
"message": "Block already exists"
}
# Create block from data
block = Block(
chain_id=block_data.get("chainId", "ait-mainnet"),
height=int(height),
hash=block_data.get("hash"),
parent_hash=block_data.get("parentHash", ""),
proposer=block_data.get("miner", ""),
timestamp=timestamp,
tx_count=len(block_data.get("transactions", [])),
state_root=block_data.get("stateRoot"),
block_metadata=json.dumps(block_data)
)
session.add(block)
session.commit()
_logger.info(f"Successfully imported block {block.height}")
elif timestamp is None:
timestamp = datetime.utcnow()
height = block_data.get("number") or block_data.get("height")
if height is None:
raise ValueError("Block height is required")
transactions = block_data.get("transactions", [])
normalized_block = {
"chain_id": chain_id,
"height": int(height),
"hash": block_data.get("hash"),
"parent_hash": block_data.get("parent_hash") or block_data.get("parentHash", ""),
"proposer": block_data.get("proposer") or block_data.get("miner", ""),
"timestamp": timestamp.isoformat() if isinstance(timestamp, datetime) else timestamp,
"tx_count": block_data.get("tx_count", len(transactions)),
"state_root": block_data.get("state_root") or block_data.get("stateRoot"),
}
from ..config import settings as cfg
sync = ChainSync(
session_factory=session_scope,
chain_id=chain_id,
validate_signatures=cfg.sync_validate_signatures,
)
result = sync.import_block(normalized_block, transactions=transactions)
if result.accepted:
_logger.info(f"Successfully imported block {result.height}")
metrics_registry.increment("blocks_imported_total")
return {
"success": True,
"block_number": block.height,
"block_hash": block.hash
}
return {
"success": result.accepted,
"accepted": result.accepted,
"block_number": result.height,
"block_hash": result.block_hash,
"chain_id": chain_id,
"reason": result.reason,
}
except Exception as e:
_logger.error(f"Failed to import block: {e}")