feat: add phase 2 decentralized memory python services and tests

This commit is contained in:
oib
2026-02-28 23:19:02 +01:00
parent c63259ef2c
commit 5bc18d684c
10 changed files with 1054 additions and 0 deletions

View File

@@ -0,0 +1,129 @@
import pytest
from unittest.mock import AsyncMock
from sqlmodel import Session, create_engine, SQLModel
from sqlmodel.pool import StaticPool
from fastapi import HTTPException
from app.services.federated_learning import FederatedLearningService
from app.domain.federated_learning import TrainingStatus, ParticipantStatus
from app.schemas.federated_learning import FederatedSessionCreate, JoinSessionRequest, SubmitUpdateRequest
@pytest.fixture
def test_db():
    """Yield an in-memory SQLite session, closed automatically after the test."""
    engine = create_engine(
        "sqlite:///:memory:",
        connect_args={"check_same_thread": False},  # allow pytest-asyncio thread hops
        poolclass=StaticPool,  # one shared connection so :memory: persists
    )
    SQLModel.metadata.create_all(engine)
    with Session(engine) as db_session:
        yield db_session
@pytest.fixture
def mock_contract_service():
    """Async mock standing in for the blockchain contract service."""
    return AsyncMock()
@pytest.fixture
def fl_service(test_db, mock_contract_service):
    """FederatedLearningService wired to the test DB and a mocked contract layer."""
    return FederatedLearningService(session=test_db, contract_service=mock_contract_service)
@pytest.mark.asyncio
async def test_create_session(fl_service):
    """A freshly created session records its initiator and starts gathering participants."""
    request = FederatedSessionCreate(
        initiator_agent_id="agent-admin",
        task_description="Train LLM on financial data",
        model_architecture_cid="bafy_arch_123",
        target_participants=2,
        total_rounds=2,
        min_participants_per_round=2,
    )
    created = await fl_service.create_session(request)
    assert created.initiator_agent_id == "agent-admin"
    assert created.status == TrainingStatus.GATHERING_PARTICIPANTS
@pytest.mark.asyncio
async def test_join_session_and_start(fl_service):
    """Reaching target_participants flips the session into TRAINING with round 1 active."""
    request = FederatedSessionCreate(
        initiator_agent_id="agent-admin",
        task_description="Train LLM on financial data",
        model_architecture_cid="bafy_arch_123",
        target_participants=2,
        total_rounds=2,
        min_participants_per_round=2,
    )
    fl_session = await fl_service.create_session(request)

    # One participant is below the target, so the session keeps gathering.
    first = await fl_service.join_session(
        fl_session.id,
        JoinSessionRequest(agent_id="agent-1", compute_power_committed=10.0),
    )
    assert first.status == ParticipantStatus.JOINED
    assert fl_session.status == TrainingStatus.GATHERING_PARTICIPANTS

    # The second join hits the target and should kick off training.
    await fl_service.join_session(
        fl_session.id,
        JoinSessionRequest(agent_id="agent-2", compute_power_committed=15.0),
    )
    fl_service.session.refresh(fl_session)  # state was mutated inside the service
    assert fl_session.status == TrainingStatus.TRAINING
    assert fl_session.current_round == 1
    assert len(fl_session.rounds) == 1
    assert fl_session.rounds[0].status == "active"
@pytest.mark.asyncio
async def test_submit_updates_and_aggregate(fl_service):
    """With total_rounds=1, the last local update triggers aggregation and completion."""
    request = FederatedSessionCreate(
        initiator_agent_id="agent-admin",
        task_description="Train LLM on financial data",
        model_architecture_cid="bafy_arch_123",
        target_participants=2,
        total_rounds=1,  # single round keeps the test short
        min_participants_per_round=2,
    )
    fl_session = await fl_service.create_session(request)
    for agent, power in (("agent-1", 10.0), ("agent-2", 15.0)):
        await fl_service.join_session(
            fl_session.id,
            JoinSessionRequest(agent_id=agent, compute_power_committed=power),
        )
    fl_service.session.refresh(fl_session)
    active_round = fl_session.rounds[0]

    # A single update must not trigger aggregation yet.
    first_update = await fl_service.submit_local_update(
        fl_session.id,
        active_round.id,
        SubmitUpdateRequest(agent_id="agent-1", weights_cid="bafy_w1", data_samples_count=1000),
    )
    assert first_update.weights_cid == "bafy_w1"
    fl_service.session.refresh(fl_session)
    fl_service.session.refresh(active_round)
    assert active_round.status == "active"

    # Second update completes the round and — since it is the only round — the session.
    await fl_service.submit_local_update(
        fl_session.id,
        active_round.id,
        SubmitUpdateRequest(agent_id="agent-2", weights_cid="bafy_w2", data_samples_count=1500),
    )
    fl_service.session.refresh(fl_session)
    fl_service.session.refresh(active_round)
    assert active_round.status == "completed"
    assert fl_session.status == TrainingStatus.COMPLETED
    assert fl_session.global_model_cid is not None
    assert fl_session.global_model_cid.startswith("bafy_aggregated_")

View File

@@ -0,0 +1,108 @@
import pytest
from unittest.mock import AsyncMock
from sqlmodel import Session, create_engine, SQLModel
from sqlmodel.pool import StaticPool
from fastapi import HTTPException
from app.services.ipfs_storage_adapter import IPFSAdapterService
from app.domain.decentralized_memory import MemoryType, StorageStatus
from app.schemas.decentralized_memory import MemoryNodeCreate
@pytest.fixture
def test_db():
    """Yield an in-memory SQLite session, closed automatically after the test."""
    engine = create_engine(
        "sqlite:///:memory:",
        connect_args={"check_same_thread": False},  # allow pytest-asyncio thread hops
        poolclass=StaticPool,  # one shared connection so :memory: persists
    )
    SQLModel.metadata.create_all(engine)
    with Session(engine) as db_session:
        yield db_session
@pytest.fixture
def mock_contract_service():
    """Async mock standing in for the blockchain contract service."""
    return AsyncMock()
@pytest.fixture
def storage_service(test_db, mock_contract_service):
    """IPFSAdapterService backed by the test DB, a mocked contract layer, and a dummy pinning token."""
    return IPFSAdapterService(
        session=test_db,
        contract_service=mock_contract_service,
        pinning_service_token="mock_token",
    )
@pytest.mark.asyncio
async def test_store_memory(storage_service):
    """Storing raw bytes pins them and records owner, CID, size and proof hash."""
    payload = b"mock_vector_embeddings_data"
    node = await storage_service.store_memory(
        MemoryNodeCreate(
            agent_id="agent-007",
            memory_type=MemoryType.VECTOR_DB,
            tags=["training", "batch1"],
        ),
        payload,
        zk_proof_hash="0xabc123",
    )
    assert node.agent_id == "agent-007"
    assert node.memory_type == MemoryType.VECTOR_DB
    assert node.cid is not None
    assert node.cid.startswith("bafy")
    assert node.size_bytes == len(payload)
    assert node.status == StorageStatus.PINNED
    assert node.zk_proof_hash == "0xabc123"
@pytest.mark.asyncio
async def test_get_memory_nodes(storage_service):
    """Listing supports agent scoping plus memory-type and tag filters."""
    fixtures = [
        (MemoryNodeCreate(agent_id="agent-007", memory_type=MemoryType.VECTOR_DB, tags=["v1"]), b"data1"),
        (MemoryNodeCreate(agent_id="agent-007", memory_type=MemoryType.KNOWLEDGE_GRAPH, tags=["v1"]), b"data2"),
        (MemoryNodeCreate(agent_id="agent-008", memory_type=MemoryType.VECTOR_DB), b"data3"),
    ]
    for create_request, blob in fixtures:
        await storage_service.store_memory(create_request, blob)

    # Everything owned by agent-007 (agent-008's node is excluded).
    all_nodes = await storage_service.get_memory_nodes("agent-007")
    assert len(all_nodes) == 2

    # Narrowed by memory type.
    kg_nodes = await storage_service.get_memory_nodes("agent-007", memory_type=MemoryType.KNOWLEDGE_GRAPH)
    assert len(kg_nodes) == 1
    assert kg_nodes[0].memory_type == MemoryType.KNOWLEDGE_GRAPH

    # Narrowed by tag.
    tagged_nodes = await storage_service.get_memory_nodes("agent-007", tags=["v1"])
    assert len(tagged_nodes) == 2
@pytest.mark.asyncio
async def test_anchor_to_blockchain(storage_service):
    """Anchoring a stored node sets ANCHORED status and a transaction hash."""
    stored = await storage_service.store_memory(
        MemoryNodeCreate(agent_id="agent-007", memory_type=MemoryType.VECTOR_DB),
        b"data1",
    )
    assert stored.anchor_tx_hash is None  # not anchored on creation

    anchored = await storage_service.anchor_to_blockchain(stored.id)
    assert anchored.status == StorageStatus.ANCHORED
    assert anchored.anchor_tx_hash is not None
@pytest.mark.asyncio
async def test_retrieve_memory(storage_service):
    """Retrieval returns the stored content as bytes."""
    stored = await storage_service.store_memory(
        MemoryNodeCreate(agent_id="agent-007", memory_type=MemoryType.VECTOR_DB),
        b"data1",
    )
    retrieved = await storage_service.retrieve_memory(stored.id)
    assert isinstance(retrieved, bytes)
    assert b"mock" in retrieved  # adapter serves mock content in tests

View File

@@ -0,0 +1,92 @@
import pytest
from unittest.mock import AsyncMock
from sqlmodel import Session, create_engine, SQLModel
from sqlmodel.pool import StaticPool
from app.services.zk_memory_verification import ZKMemoryVerificationService
from app.domain.decentralized_memory import AgentMemoryNode, MemoryType
@pytest.fixture
def test_db():
    """Yield an in-memory SQLite session, closed automatically after the test."""
    engine = create_engine(
        "sqlite:///:memory:",
        connect_args={"check_same_thread": False},  # allow pytest-asyncio thread hops
        poolclass=StaticPool,  # one shared connection so :memory: persists
    )
    SQLModel.metadata.create_all(engine)
    with Session(engine) as db_session:
        yield db_session
@pytest.fixture
def mock_contract_service():
    """Async mock standing in for the blockchain contract service."""
    return AsyncMock()
@pytest.fixture
def zk_service(test_db, mock_contract_service):
    """ZKMemoryVerificationService wired to the test DB and a mocked contract layer."""
    return ZKMemoryVerificationService(session=test_db, contract_service=mock_contract_service)
@pytest.mark.asyncio
async def test_generate_memory_proof(zk_service, test_db):
    """Proof generation returns a groth16 payload and a 0x-prefixed hash."""
    node = AgentMemoryNode(agent_id="agent-zk", memory_type=MemoryType.VECTOR_DB)
    test_db.add(node)
    test_db.commit()
    test_db.refresh(node)

    payload, digest = await zk_service.generate_memory_proof(node.id, b"secret_vector_data")
    assert payload is not None
    assert "groth16" in payload
    assert digest.startswith("0x")
@pytest.mark.asyncio
async def test_verify_retrieved_memory_success(zk_service, test_db):
    """Verification succeeds when the anchored proof matches the original bytes."""
    node = AgentMemoryNode(agent_id="agent-zk", memory_type=MemoryType.VECTOR_DB)
    test_db.add(node)
    test_db.commit()
    test_db.refresh(node)

    original = b"secret_vector_data"
    payload, digest = await zk_service.generate_memory_proof(node.id, original)

    # Pretend the proof hash was anchored on-chain.
    node.zk_proof_hash = digest
    test_db.commit()

    assert await zk_service.verify_retrieved_memory(node.id, original, payload) is True
@pytest.mark.asyncio
async def test_verify_retrieved_memory_tampered_data(zk_service, test_db):
    """Verification fails when the retrieved bytes differ from the proven ones."""
    node = AgentMemoryNode(agent_id="agent-zk", memory_type=MemoryType.VECTOR_DB)
    test_db.add(node)
    test_db.commit()
    test_db.refresh(node)

    original = b"secret_vector_data"
    payload, digest = await zk_service.generate_memory_proof(node.id, original)
    node.zk_proof_hash = digest
    test_db.commit()

    # Same proof, different bytes: must be rejected.
    tampered = b"secret_vector_data_modified"
    assert await zk_service.verify_retrieved_memory(node.id, tampered, payload) is False