Add hash conflict detection and cleanup across chains in import operations and improve database initialization
Some checks failed
Some checks failed
- Add hash conflict detection in import_block to delete existing blocks with same hash - Add hash conflict cleanup in import_chain before importing blocks - Add logging for hash conflict deletions showing affected chains - Add WAL file permission setting in init_db for .db-shm and .db-wal files - Add test_import_chain_clears_hash_conflicts_across_chains to verify cross-chain hash cleanup
This commit is contained in:
@@ -86,9 +86,16 @@ def init_db() -> None:
|
|||||||
# If tables already exist, that's okay
|
# If tables already exist, that's okay
|
||||||
if "already exists" not in str(e):
|
if "already exists" not in str(e):
|
||||||
raise
|
raise
|
||||||
# Set restrictive file permissions on database file
|
# Set restrictive file permissions on database file and WAL files
|
||||||
if settings.db_path.exists():
|
if settings.db_path.exists():
|
||||||
os.chmod(settings.db_path, stat.S_IRUSR | stat.S_IWUSR) # Read/write for owner only
|
os.chmod(settings.db_path, stat.S_IRUSR | stat.S_IWUSR) # Read/write for owner only
|
||||||
|
# Also set permissions on WAL files if they exist
|
||||||
|
wal_shm = settings.db_path.with_suffix('.db-shm')
|
||||||
|
wal_wal = settings.db_path.with_suffix('.db-wal')
|
||||||
|
if wal_shm.exists():
|
||||||
|
os.chmod(wal_shm, stat.S_IRUSR | stat.S_IWUSR)
|
||||||
|
if wal_wal.exists():
|
||||||
|
os.chmod(wal_wal, stat.S_IRUSR | stat.S_IWUSR)
|
||||||
|
|
||||||
# Restricted engine access - only for internal use
|
# Restricted engine access - only for internal use
|
||||||
def get_engine():
|
def get_engine():
|
||||||
|
|||||||
@@ -647,11 +647,23 @@ async def import_block(block_data: dict) -> Dict[str, Any]:
|
|||||||
timestamp = datetime.utcnow()
|
timestamp = datetime.utcnow()
|
||||||
|
|
||||||
with session_scope() as session:
|
with session_scope() as session:
|
||||||
|
# Check for hash conflicts across chains
|
||||||
|
block_hash = block_data["hash"]
|
||||||
|
existing_block = session.execute(
|
||||||
|
select(Block).where(Block.hash == block_hash)
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if existing_block:
|
||||||
|
# Delete existing block with conflicting hash
|
||||||
|
_logger.warning(f"Deleting existing block with conflicting hash {block_hash} from chain {existing_block[0].chain_id}")
|
||||||
|
session.execute(delete(Block).where(Block.hash == block_hash))
|
||||||
|
session.commit()
|
||||||
|
|
||||||
# Create block
|
# Create block
|
||||||
block = Block(
|
block = Block(
|
||||||
chain_id=chain_id,
|
chain_id=chain_id,
|
||||||
height=block_data["height"],
|
height=block_data["height"],
|
||||||
hash=block_data["hash"],
|
hash=block_hash,
|
||||||
parent_hash=block_data["parent_hash"],
|
parent_hash=block_data["parent_hash"],
|
||||||
proposer=block_data["proposer"],
|
proposer=block_data["proposer"],
|
||||||
timestamp=timestamp,
|
timestamp=timestamp,
|
||||||
@@ -859,6 +871,19 @@ async def import_chain(import_data: dict) -> Dict[str, Any]:
|
|||||||
session.execute(delete(Account).where(Account.chain_id == chain_id))
|
session.execute(delete(Account).where(Account.chain_id == chain_id))
|
||||||
_logger.info(f"Clearing existing blocks for chain {chain_id}")
|
_logger.info(f"Clearing existing blocks for chain {chain_id}")
|
||||||
session.execute(delete(Block).where(Block.chain_id == chain_id))
|
session.execute(delete(Block).where(Block.chain_id == chain_id))
|
||||||
|
|
||||||
|
import_hashes = {block_data["hash"] for block_data in unique_blocks}
|
||||||
|
if import_hashes:
|
||||||
|
hash_conflict_result = session.execute(
|
||||||
|
select(Block.hash, Block.chain_id)
|
||||||
|
.where(Block.hash.in_(import_hashes))
|
||||||
|
)
|
||||||
|
hash_conflicts = hash_conflict_result.all()
|
||||||
|
if hash_conflicts:
|
||||||
|
conflict_chains = {chain_id for _, chain_id in hash_conflicts}
|
||||||
|
_logger.warning(f"Clearing {len(hash_conflicts)} blocks with conflicting hashes across chains: {conflict_chains}")
|
||||||
|
session.execute(delete(Block).where(Block.hash.in_(import_hashes)))
|
||||||
|
|
||||||
session.commit()
|
session.commit()
|
||||||
session.expire_all()
|
session.expire_all()
|
||||||
|
|
||||||
|
|||||||
@@ -238,3 +238,105 @@ async def test_import_chain_dedupes_duplicate_heights_and_preserves_transaction_
|
|||||||
assert len(chain_a_transactions) == 1
|
assert len(chain_a_transactions) == 1
|
||||||
assert chain_a_transactions[0].tx_hash == _hex("incoming-tx-1")
|
assert chain_a_transactions[0].tx_hash == _hex("incoming-tx-1")
|
||||||
assert chain_a_transactions[0].timestamp == "2026-01-02T00:00:02"
|
assert chain_a_transactions[0].timestamp == "2026-01-02T00:00:02"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_import_chain_clears_hash_conflicts_across_chains(isolated_engine):
|
||||||
|
"""Test that import-chain clears blocks with conflicting hashes across different chains."""
|
||||||
|
from aitbc_chain.rpc import router as rpc_router
|
||||||
|
from aitbc_chain.database import get_engine
|
||||||
|
|
||||||
|
with Session(isolated_engine) as session:
|
||||||
|
session.add(
|
||||||
|
Block(
|
||||||
|
chain_id="chain-a",
|
||||||
|
height=0,
|
||||||
|
hash=_hex("chain-a-block-0"),
|
||||||
|
parent_hash="0x00",
|
||||||
|
proposer="node-a",
|
||||||
|
timestamp=datetime(2026, 1, 1, 0, 0, 0),
|
||||||
|
tx_count=0,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
session.add(
|
||||||
|
Block(
|
||||||
|
chain_id="chain-a",
|
||||||
|
height=1,
|
||||||
|
hash=_hex("chain-a-block-1"),
|
||||||
|
parent_hash=_hex("chain-a-block-0"),
|
||||||
|
proposer="node-a",
|
||||||
|
timestamp=datetime(2026, 1, 1, 0, 0, 1),
|
||||||
|
tx_count=0,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
session.add(
|
||||||
|
Block(
|
||||||
|
chain_id="chain-b",
|
||||||
|
height=0,
|
||||||
|
hash=_hex("chain-b-block-0"),
|
||||||
|
parent_hash="0x00",
|
||||||
|
proposer="node-b",
|
||||||
|
timestamp=datetime(2026, 1, 1, 0, 0, 0),
|
||||||
|
tx_count=0,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
session.add(
|
||||||
|
Block(
|
||||||
|
chain_id="chain-b",
|
||||||
|
height=1,
|
||||||
|
hash=_hex("chain-b-block-1"),
|
||||||
|
parent_hash=_hex("chain-b-block-0"),
|
||||||
|
proposer="node-b",
|
||||||
|
timestamp=datetime(2026, 1, 1, 0, 0, 1),
|
||||||
|
tx_count=0,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
with Session(isolated_engine) as session:
|
||||||
|
chain_a_blocks = session.exec(
|
||||||
|
select(Block).where(Block.chain_id == "chain-a").order_by(Block.height)
|
||||||
|
).all()
|
||||||
|
|
||||||
|
conflicting_hash = chain_a_blocks[0].hash
|
||||||
|
|
||||||
|
import_payload = {
|
||||||
|
"chain_id": "chain-c",
|
||||||
|
"blocks": [
|
||||||
|
{
|
||||||
|
"chain_id": "chain-c",
|
||||||
|
"height": 0,
|
||||||
|
"hash": conflicting_hash,
|
||||||
|
"parent_hash": _hex("parent-0"),
|
||||||
|
"proposer": _hex("proposer-0"),
|
||||||
|
"timestamp": "2026-01-01T00:00:00",
|
||||||
|
"tx_count": 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"chain_id": "chain-c",
|
||||||
|
"height": 1,
|
||||||
|
"hash": _hex("chain-c-block-1"),
|
||||||
|
"parent_hash": conflicting_hash,
|
||||||
|
"proposer": _hex("proposer-1"),
|
||||||
|
"timestamp": "2026-01-01T00:00:01",
|
||||||
|
"tx_count": 0,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await rpc_router.import_chain(import_payload)
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["imported_blocks"] == 2
|
||||||
|
|
||||||
|
with Session(isolated_engine) as session:
|
||||||
|
chain_c_blocks = session.exec(
|
||||||
|
select(Block).where(Block.chain_id == "chain-c").order_by(Block.height)
|
||||||
|
).all()
|
||||||
|
chain_a_blocks_after = session.exec(
|
||||||
|
select(Block).where(Block.chain_id == "chain-a").order_by(Block.height)
|
||||||
|
).all()
|
||||||
|
|
||||||
|
assert [block.height for block in chain_c_blocks] == [0, 1]
|
||||||
|
assert chain_c_blocks[0].hash == conflicting_hash
|
||||||
|
assert len(chain_a_blocks_after) == 1
|
||||||
|
assert chain_a_blocks_after[0].height == 1
|
||||||
|
|||||||
@@ -11,10 +11,21 @@ import urllib.parse
|
|||||||
import random
|
import random
|
||||||
|
|
||||||
# Database setup
|
# Database setup
|
||||||
def init_db():
|
def get_db_path():
|
||||||
"""Initialize SQLite database"""
|
"""Get database path and ensure directory exists"""
|
||||||
import os
|
import os
|
||||||
db_path = os.getenv("EXCHANGE_DATABASE_URL", "sqlite:////var/lib/aitbc/data/exchange/exchange.db").replace("sqlite:///", "")
|
db_path = os.getenv("EXCHANGE_DATABASE_URL", "sqlite:////var/lib/aitbc/data/exchange/exchange.db").replace("sqlite:///", "")
|
||||||
|
|
||||||
|
# Create directory if it doesn't exist
|
||||||
|
db_dir = os.path.dirname(db_path)
|
||||||
|
if db_dir and not os.path.exists(db_dir):
|
||||||
|
os.makedirs(db_dir, exist_ok=True)
|
||||||
|
|
||||||
|
return db_path
|
||||||
|
|
||||||
|
def init_db():
|
||||||
|
"""Initialize SQLite database"""
|
||||||
|
db_path = get_db_path()
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
@@ -61,8 +72,7 @@ def init_db():
|
|||||||
|
|
||||||
def create_mock_trades():
|
def create_mock_trades():
|
||||||
"""Create some mock trades"""
|
"""Create some mock trades"""
|
||||||
import os
|
db_path = get_db_path()
|
||||||
db_path = os.getenv("EXCHANGE_DATABASE_URL", "sqlite:////var/lib/aitbc/data/exchange/exchange.db").replace("sqlite:///", "")
|
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
@@ -121,8 +131,7 @@ class ExchangeAPIHandler(BaseHTTPRequestHandler):
|
|||||||
query = urllib.parse.parse_qs(parsed.query)
|
query = urllib.parse.parse_qs(parsed.query)
|
||||||
limit = int(query.get('limit', [20])[0])
|
limit = int(query.get('limit', [20])[0])
|
||||||
|
|
||||||
import os
|
db_path = get_db_path()
|
||||||
db_path = os.getenv("EXCHANGE_DATABASE_URL", "sqlite:////var/lib/aitbc/data/exchange/exchange.db").replace("sqlite:///", "")
|
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
@@ -149,8 +158,7 @@ class ExchangeAPIHandler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
def get_orderbook(self):
|
def get_orderbook(self):
|
||||||
"""Get order book"""
|
"""Get order book"""
|
||||||
import os
|
db_path = get_db_path()
|
||||||
db_path = os.getenv("EXCHANGE_DATABASE_URL", "sqlite:////var/lib/aitbc/data/exchange/exchange.db").replace("sqlite:///", "")
|
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
@@ -255,8 +263,7 @@ class ExchangeAPIHandler(BaseHTTPRequestHandler):
|
|||||||
# Store order in local database for orderbook
|
# Store order in local database for orderbook
|
||||||
total = amount * price
|
total = amount * price
|
||||||
|
|
||||||
import os
|
db_path = get_db_path()
|
||||||
db_path = os.getenv("EXCHANGE_DATABASE_URL", "sqlite:////var/lib/aitbc/data/exchange/exchange.db").replace("sqlite:///", "")
|
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
@@ -297,8 +304,7 @@ class ExchangeAPIHandler(BaseHTTPRequestHandler):
|
|||||||
# Fallback to database-only if blockchain is down
|
# Fallback to database-only if blockchain is down
|
||||||
total = amount * price
|
total = amount * price
|
||||||
|
|
||||||
import os
|
db_path = get_db_path()
|
||||||
db_path = os.getenv("EXCHANGE_DATABASE_URL", "sqlite:////var/lib/aitbc/data/exchange/exchange.db").replace("sqlite:///", "")
|
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user