feat: add phase 2 decentralized memory python services and tests
This commit is contained in:
129
apps/coordinator-api/tests/test_federated_learning.py
Normal file
129
apps/coordinator-api/tests/test_federated_learning.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from sqlmodel import Session, create_engine, SQLModel
|
||||
from sqlmodel.pool import StaticPool
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.services.federated_learning import FederatedLearningService
|
||||
from app.domain.federated_learning import TrainingStatus, ParticipantStatus
|
||||
from app.schemas.federated_learning import FederatedSessionCreate, JoinSessionRequest, SubmitUpdateRequest
|
||||
|
||||
@pytest.fixture
|
||||
def test_db():
|
||||
engine = create_engine(
|
||||
"sqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
session = Session(engine)
|
||||
yield session
|
||||
session.close()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_contract_service():
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.fixture
|
||||
def fl_service(test_db, mock_contract_service):
|
||||
return FederatedLearningService(
|
||||
session=test_db,
|
||||
contract_service=mock_contract_service
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session(fl_service):
|
||||
req = FederatedSessionCreate(
|
||||
initiator_agent_id="agent-admin",
|
||||
task_description="Train LLM on financial data",
|
||||
model_architecture_cid="bafy_arch_123",
|
||||
target_participants=2,
|
||||
total_rounds=2,
|
||||
min_participants_per_round=2
|
||||
)
|
||||
|
||||
session = await fl_service.create_session(req)
|
||||
assert session.status == TrainingStatus.GATHERING_PARTICIPANTS
|
||||
assert session.initiator_agent_id == "agent-admin"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join_session_and_start(fl_service):
|
||||
req = FederatedSessionCreate(
|
||||
initiator_agent_id="agent-admin",
|
||||
task_description="Train LLM on financial data",
|
||||
model_architecture_cid="bafy_arch_123",
|
||||
target_participants=2,
|
||||
total_rounds=2,
|
||||
min_participants_per_round=2
|
||||
)
|
||||
session = await fl_service.create_session(req)
|
||||
|
||||
# Agent 1 joins
|
||||
p1 = await fl_service.join_session(
|
||||
session.id,
|
||||
JoinSessionRequest(agent_id="agent-1", compute_power_committed=10.0)
|
||||
)
|
||||
assert p1.status == ParticipantStatus.JOINED
|
||||
assert session.status == TrainingStatus.GATHERING_PARTICIPANTS
|
||||
|
||||
# Agent 2 joins, triggers start
|
||||
p2 = await fl_service.join_session(
|
||||
session.id,
|
||||
JoinSessionRequest(agent_id="agent-2", compute_power_committed=15.0)
|
||||
)
|
||||
|
||||
# Needs refresh
|
||||
fl_service.session.refresh(session)
|
||||
|
||||
assert session.status == TrainingStatus.TRAINING
|
||||
assert session.current_round == 1
|
||||
assert len(session.rounds) == 1
|
||||
assert session.rounds[0].status == "active"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_submit_updates_and_aggregate(fl_service):
|
||||
# Setup
|
||||
req = FederatedSessionCreate(
|
||||
initiator_agent_id="agent-admin",
|
||||
task_description="Train LLM on financial data",
|
||||
model_architecture_cid="bafy_arch_123",
|
||||
target_participants=2,
|
||||
total_rounds=1, # Only 1 round for quick test
|
||||
min_participants_per_round=2
|
||||
)
|
||||
session = await fl_service.create_session(req)
|
||||
await fl_service.join_session(session.id, JoinSessionRequest(agent_id="agent-1", compute_power_committed=10.0))
|
||||
await fl_service.join_session(session.id, JoinSessionRequest(agent_id="agent-2", compute_power_committed=15.0))
|
||||
|
||||
fl_service.session.refresh(session)
|
||||
round1 = session.rounds[0]
|
||||
|
||||
# Agent 1 submits
|
||||
u1 = await fl_service.submit_local_update(
|
||||
session.id,
|
||||
round1.id,
|
||||
SubmitUpdateRequest(agent_id="agent-1", weights_cid="bafy_w1", data_samples_count=1000)
|
||||
)
|
||||
assert u1.weights_cid == "bafy_w1"
|
||||
|
||||
fl_service.session.refresh(session)
|
||||
fl_service.session.refresh(round1)
|
||||
|
||||
# Not aggregated yet
|
||||
assert round1.status == "active"
|
||||
|
||||
# Agent 2 submits, triggers aggregation and completion since total_rounds=1
|
||||
u2 = await fl_service.submit_local_update(
|
||||
session.id,
|
||||
round1.id,
|
||||
SubmitUpdateRequest(agent_id="agent-2", weights_cid="bafy_w2", data_samples_count=1500)
|
||||
)
|
||||
|
||||
fl_service.session.refresh(session)
|
||||
fl_service.session.refresh(round1)
|
||||
|
||||
assert round1.status == "completed"
|
||||
assert session.status == TrainingStatus.COMPLETED
|
||||
assert session.global_model_cid is not None
|
||||
assert session.global_model_cid.startswith("bafy_aggregated_")
|
||||
108
apps/coordinator-api/tests/test_ipfs_storage_adapter.py
Normal file
108
apps/coordinator-api/tests/test_ipfs_storage_adapter.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from sqlmodel import Session, create_engine, SQLModel
|
||||
from sqlmodel.pool import StaticPool
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.services.ipfs_storage_adapter import IPFSAdapterService
|
||||
from app.domain.decentralized_memory import MemoryType, StorageStatus
|
||||
from app.schemas.decentralized_memory import MemoryNodeCreate
|
||||
|
||||
@pytest.fixture
|
||||
def test_db():
|
||||
engine = create_engine(
|
||||
"sqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
session = Session(engine)
|
||||
yield session
|
||||
session.close()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_contract_service():
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.fixture
|
||||
def storage_service(test_db, mock_contract_service):
|
||||
return IPFSAdapterService(
|
||||
session=test_db,
|
||||
contract_service=mock_contract_service,
|
||||
pinning_service_token="mock_token"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_memory(storage_service):
|
||||
request = MemoryNodeCreate(
|
||||
agent_id="agent-007",
|
||||
memory_type=MemoryType.VECTOR_DB,
|
||||
tags=["training", "batch1"]
|
||||
)
|
||||
|
||||
raw_data = b"mock_vector_embeddings_data"
|
||||
|
||||
node = await storage_service.store_memory(request, raw_data, zk_proof_hash="0xabc123")
|
||||
|
||||
assert node.agent_id == "agent-007"
|
||||
assert node.memory_type == MemoryType.VECTOR_DB
|
||||
assert node.cid is not None
|
||||
assert node.cid.startswith("bafy")
|
||||
assert node.size_bytes == len(raw_data)
|
||||
assert node.status == StorageStatus.PINNED
|
||||
assert node.zk_proof_hash == "0xabc123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_memory_nodes(storage_service):
|
||||
# Store multiple
|
||||
await storage_service.store_memory(
|
||||
MemoryNodeCreate(agent_id="agent-007", memory_type=MemoryType.VECTOR_DB, tags=["v1"]),
|
||||
b"data1"
|
||||
)
|
||||
await storage_service.store_memory(
|
||||
MemoryNodeCreate(agent_id="agent-007", memory_type=MemoryType.KNOWLEDGE_GRAPH, tags=["v1"]),
|
||||
b"data2"
|
||||
)
|
||||
await storage_service.store_memory(
|
||||
MemoryNodeCreate(agent_id="agent-008", memory_type=MemoryType.VECTOR_DB),
|
||||
b"data3"
|
||||
)
|
||||
|
||||
# Get all for agent-007
|
||||
nodes = await storage_service.get_memory_nodes("agent-007")
|
||||
assert len(nodes) == 2
|
||||
|
||||
# Filter by type
|
||||
nodes_kg = await storage_service.get_memory_nodes("agent-007", memory_type=MemoryType.KNOWLEDGE_GRAPH)
|
||||
assert len(nodes_kg) == 1
|
||||
assert nodes_kg[0].memory_type == MemoryType.KNOWLEDGE_GRAPH
|
||||
|
||||
# Filter by tag
|
||||
nodes_tag = await storage_service.get_memory_nodes("agent-007", tags=["v1"])
|
||||
assert len(nodes_tag) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_anchor_to_blockchain(storage_service):
|
||||
node = await storage_service.store_memory(
|
||||
MemoryNodeCreate(agent_id="agent-007", memory_type=MemoryType.VECTOR_DB),
|
||||
b"data1"
|
||||
)
|
||||
|
||||
assert node.anchor_tx_hash is None
|
||||
|
||||
anchored_node = await storage_service.anchor_to_blockchain(node.id)
|
||||
|
||||
assert anchored_node.status == StorageStatus.ANCHORED
|
||||
assert anchored_node.anchor_tx_hash is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_memory(storage_service):
|
||||
node = await storage_service.store_memory(
|
||||
MemoryNodeCreate(agent_id="agent-007", memory_type=MemoryType.VECTOR_DB),
|
||||
b"data1"
|
||||
)
|
||||
|
||||
data = await storage_service.retrieve_memory(node.id)
|
||||
assert isinstance(data, bytes)
|
||||
assert b"mock" in data
|
||||
92
apps/coordinator-api/tests/test_zk_memory_verification.py
Normal file
92
apps/coordinator-api/tests/test_zk_memory_verification.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from sqlmodel import Session, create_engine, SQLModel
|
||||
from sqlmodel.pool import StaticPool
|
||||
|
||||
from app.services.zk_memory_verification import ZKMemoryVerificationService
|
||||
from app.domain.decentralized_memory import AgentMemoryNode, MemoryType
|
||||
|
||||
@pytest.fixture
|
||||
def test_db():
|
||||
engine = create_engine(
|
||||
"sqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
session = Session(engine)
|
||||
yield session
|
||||
session.close()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_contract_service():
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.fixture
|
||||
def zk_service(test_db, mock_contract_service):
|
||||
return ZKMemoryVerificationService(
|
||||
session=test_db,
|
||||
contract_service=mock_contract_service
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_memory_proof(zk_service, test_db):
|
||||
node = AgentMemoryNode(
|
||||
agent_id="agent-zk",
|
||||
memory_type=MemoryType.VECTOR_DB
|
||||
)
|
||||
test_db.add(node)
|
||||
test_db.commit()
|
||||
test_db.refresh(node)
|
||||
|
||||
raw_data = b"secret_vector_data"
|
||||
|
||||
proof_payload, proof_hash = await zk_service.generate_memory_proof(node.id, raw_data)
|
||||
|
||||
assert proof_payload is not None
|
||||
assert proof_hash.startswith("0x")
|
||||
assert "groth16" in proof_payload
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_retrieved_memory_success(zk_service, test_db):
|
||||
node = AgentMemoryNode(
|
||||
agent_id="agent-zk",
|
||||
memory_type=MemoryType.VECTOR_DB
|
||||
)
|
||||
test_db.add(node)
|
||||
test_db.commit()
|
||||
test_db.refresh(node)
|
||||
|
||||
raw_data = b"secret_vector_data"
|
||||
proof_payload, proof_hash = await zk_service.generate_memory_proof(node.id, raw_data)
|
||||
|
||||
# Simulate anchoring
|
||||
node.zk_proof_hash = proof_hash
|
||||
test_db.commit()
|
||||
|
||||
# Verify
|
||||
is_valid = await zk_service.verify_retrieved_memory(node.id, raw_data, proof_payload)
|
||||
assert is_valid is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_retrieved_memory_tampered_data(zk_service, test_db):
|
||||
node = AgentMemoryNode(
|
||||
agent_id="agent-zk",
|
||||
memory_type=MemoryType.VECTOR_DB
|
||||
)
|
||||
test_db.add(node)
|
||||
test_db.commit()
|
||||
test_db.refresh(node)
|
||||
|
||||
raw_data = b"secret_vector_data"
|
||||
proof_payload, proof_hash = await zk_service.generate_memory_proof(node.id, raw_data)
|
||||
|
||||
node.zk_proof_hash = proof_hash
|
||||
test_db.commit()
|
||||
|
||||
# Tamper with data
|
||||
tampered_data = b"secret_vector_data_modified"
|
||||
|
||||
is_valid = await zk_service.verify_retrieved_memory(node.id, tampered_data, proof_payload)
|
||||
assert is_valid is False
|
||||
Reference in New Issue
Block a user