feat: add phase 2 decentralized memory python services and tests

This commit is contained in:
oib
2026-02-28 23:19:02 +01:00
parent c63259ef2c
commit 5bc18d684c
10 changed files with 1054 additions and 0 deletions

View File

@@ -0,0 +1,56 @@
"""
Decentralized Memory Domain Models
Domain models for managing agent memory and knowledge graphs on IPFS/Filecoin.
"""
from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import Dict, Optional, List
from uuid import uuid4
from sqlalchemy import Column, JSON
from sqlmodel import Field, SQLModel, Relationship
class MemoryType(str, Enum):
VECTOR_DB = "vector_db"
KNOWLEDGE_GRAPH = "knowledge_graph"
POLICY_WEIGHTS = "policy_weights"
EPISODIC = "episodic"
class StorageStatus(str, Enum):
PENDING = "pending" # Upload to IPFS pending
UPLOADED = "uploaded" # Available on IPFS
PINNED = "pinned" # Pinned on Filecoin/Pinata
ANCHORED = "anchored" # CID written to blockchain
FAILED = "failed" # Upload failed
class AgentMemoryNode(SQLModel, table=True):
"""Represents a chunk of memory or knowledge stored on decentralized storage"""
__tablename__ = "agent_memory_node"
id: str = Field(default_factory=lambda: uuid4().hex, primary_key=True)
agent_id: str = Field(index=True)
memory_type: MemoryType = Field(index=True)
# Decentralized Storage Identifiers
cid: Optional[str] = Field(default=None, index=True) # IPFS Content Identifier
size_bytes: Optional[int] = Field(default=None)
# Encryption and Security
is_encrypted: bool = Field(default=True)
encryption_key_id: Optional[str] = Field(default=None) # Reference to KMS or Lit Protocol
zk_proof_hash: Optional[str] = Field(default=None) # Hash of the ZK proof verifying content validity
status: StorageStatus = Field(default=StorageStatus.PENDING, index=True)
metadata: Dict[str, str] = Field(default_factory=dict, sa_column=Column(JSON))
tags: List[str] = Field(default_factory=list, sa_column=Column(JSON))
# Blockchain Anchoring
anchor_tx_hash: Optional[str] = Field(default=None)
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)

View File

@@ -0,0 +1,122 @@
"""
Federated Learning Domain Models
Domain models for managing cross-agent knowledge sharing and collaborative model training.
"""
from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional
from uuid import uuid4
from sqlalchemy import Column, JSON
from sqlmodel import Field, SQLModel, Relationship
class TrainingStatus(str, Enum):
INITIALIZED = "initiated"
GATHERING_PARTICIPANTS = "gathering_participants"
TRAINING = "training"
AGGREGATING = "aggregating"
COMPLETED = "completed"
FAILED = "failed"
class ParticipantStatus(str, Enum):
INVITED = "invited"
JOINED = "joined"
TRAINING = "training"
SUBMITTED = "submitted"
DROPPED = "dropped"
class FederatedLearningSession(SQLModel, table=True):
"""Represents a collaborative training session across multiple agents"""
__tablename__ = "federated_learning_session"
id: str = Field(default_factory=lambda: uuid4().hex, primary_key=True)
initiator_agent_id: str = Field(index=True)
task_description: str = Field()
model_architecture_cid: str = Field() # IPFS CID pointing to model structure definition
initial_weights_cid: Optional[str] = Field(default=None) # Optional starting point
target_participants: int = Field(default=3)
current_round: int = Field(default=0)
total_rounds: int = Field(default=10)
aggregation_strategy: str = Field(default="fedavg") # e.g. fedavg, fedprox
min_participants_per_round: int = Field(default=2)
reward_pool_amount: float = Field(default=0.0) # Total AITBC allocated to reward participants
status: TrainingStatus = Field(default=TrainingStatus.INITIALIZED, index=True)
global_model_cid: Optional[str] = Field(default=None) # Final aggregated model
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
# Relationships
participants: List["TrainingParticipant"] = Relationship(back_populates="session")
rounds: List["TrainingRound"] = Relationship(back_populates="session")
class TrainingParticipant(SQLModel, table=True):
"""An agent participating in a federated learning session"""
__tablename__ = "training_participant"
id: str = Field(default_factory=lambda: uuid4().hex, primary_key=True)
session_id: str = Field(foreign_key="federated_learning_session.id", index=True)
agent_id: str = Field(index=True)
status: ParticipantStatus = Field(default=ParticipantStatus.JOINED, index=True)
data_samples_count: int = Field(default=0) # Claimed number of local samples used
compute_power_committed: float = Field(default=0.0) # TFLOPS
reputation_score_at_join: float = Field(default=0.0)
earned_reward: float = Field(default=0.0)
joined_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
# Relationships
session: FederatedLearningSession = Relationship(back_populates="participants")
class TrainingRound(SQLModel, table=True):
"""A specific round of federated learning"""
__tablename__ = "training_round"
id: str = Field(default_factory=lambda: uuid4().hex, primary_key=True)
session_id: str = Field(foreign_key="federated_learning_session.id", index=True)
round_number: int = Field()
status: str = Field(default="pending") # pending, active, aggregating, completed
starting_model_cid: str = Field() # Global model weights at start of round
aggregated_model_cid: Optional[str] = Field(default=None) # Resulting weights after round
metrics: Dict[str, float] = Field(default_factory=dict, sa_column=Column(JSON)) # e.g. loss, accuracy
started_at: datetime = Field(default_factory=datetime.utcnow)
completed_at: Optional[datetime] = Field(default=None)
# Relationships
session: FederatedLearningSession = Relationship(back_populates="rounds")
updates: List["LocalModelUpdate"] = Relationship(back_populates="round")
class LocalModelUpdate(SQLModel, table=True):
"""A local model update submitted by a participant for a specific round"""
__tablename__ = "local_model_update"
id: str = Field(default_factory=lambda: uuid4().hex, primary_key=True)
round_id: str = Field(foreign_key="training_round.id", index=True)
participant_agent_id: str = Field(index=True)
weights_cid: str = Field() # IPFS CID of the locally trained weights
zk_proof_hash: Optional[str] = Field(default=None) # Proof that training was executed correctly
is_aggregated: bool = Field(default=False)
rejected_reason: Optional[str] = Field(default=None) # e.g. "outlier", "failed zk verification"
submitted_at: datetime = Field(default_factory=datetime.utcnow)
# Relationships
round: TrainingRound = Relationship(back_populates="updates")

View File

@@ -0,0 +1,29 @@
from pydantic import BaseModel, Field
from typing import Optional, Dict, List
from .decentralized_memory import MemoryType, StorageStatus
class MemoryNodeCreate(BaseModel):
agent_id: str
memory_type: MemoryType
is_encrypted: bool = True
metadata: Dict[str, str] = Field(default_factory=dict)
tags: List[str] = Field(default_factory=list)
class MemoryNodeResponse(BaseModel):
id: str
agent_id: str
memory_type: MemoryType
cid: Optional[str]
size_bytes: Optional[int]
is_encrypted: bool
status: StorageStatus
metadata: Dict[str, str]
tags: List[str]
class Config:
orm_mode = True
class MemoryQueryRequest(BaseModel):
agent_id: str
memory_type: Optional[MemoryType] = None
tags: Optional[List[str]] = None

View File

@@ -0,0 +1,37 @@
from pydantic import BaseModel
from typing import Optional, Dict
from .federated_learning import TrainingStatus
class FederatedSessionCreate(BaseModel):
initiator_agent_id: str
task_description: str
model_architecture_cid: str
initial_weights_cid: Optional[str] = None
target_participants: int = 3
total_rounds: int = 10
aggregation_strategy: str = "fedavg"
min_participants_per_round: int = 2
reward_pool_amount: float = 0.0
class FederatedSessionResponse(BaseModel):
id: str
initiator_agent_id: str
task_description: str
target_participants: int
current_round: int
total_rounds: int
status: TrainingStatus
global_model_cid: Optional[str]
class Config:
orm_mode = True
class JoinSessionRequest(BaseModel):
agent_id: str
compute_power_committed: float
class SubmitUpdateRequest(BaseModel):
agent_id: str
weights_cid: str
zk_proof_hash: Optional[str] = None
data_samples_count: int

View File

@@ -0,0 +1,208 @@
"""
Federated Learning Service
Service for managing cross-agent knowledge sharing and collaborative model training.
"""
from __future__ import annotations
import logging
from datetime import datetime
from typing import List, Optional
from sqlmodel import Session, select
from fastapi import HTTPException
from ..domain.federated_learning import (
FederatedLearningSession, TrainingParticipant, TrainingRound,
LocalModelUpdate, TrainingStatus, ParticipantStatus
)
from ..schemas.federated_learning import (
FederatedSessionCreate, JoinSessionRequest, SubmitUpdateRequest
)
from ..blockchain.contract_interactions import ContractInteractionService
logger = logging.getLogger(__name__)
class FederatedLearningService:
def __init__(
self,
session: Session,
contract_service: ContractInteractionService
):
self.session = session
self.contract_service = contract_service
async def create_session(self, request: FederatedSessionCreate) -> FederatedLearningSession:
"""Create a new federated learning session"""
session = FederatedLearningSession(
initiator_agent_id=request.initiator_agent_id,
task_description=request.task_description,
model_architecture_cid=request.model_architecture_cid,
initial_weights_cid=request.initial_weights_cid,
target_participants=request.target_participants,
total_rounds=request.total_rounds,
aggregation_strategy=request.aggregation_strategy,
min_participants_per_round=request.min_participants_per_round,
reward_pool_amount=request.reward_pool_amount,
status=TrainingStatus.GATHERING_PARTICIPANTS
)
self.session.add(session)
self.session.commit()
self.session.refresh(session)
logger.info(f"Created Federated Learning Session {session.id} by {request.initiator_agent_id}")
return session
async def join_session(self, session_id: str, request: JoinSessionRequest) -> TrainingParticipant:
"""Allow an agent to join an active session"""
fl_session = self.session.get(FederatedLearningSession, session_id)
if not fl_session:
raise HTTPException(status_code=404, detail="Session not found")
if fl_session.status != TrainingStatus.GATHERING_PARTICIPANTS:
raise HTTPException(status_code=400, detail="Session is not currently accepting participants")
# Check if already joined
existing = self.session.exec(
select(TrainingParticipant).where(
TrainingParticipant.session_id == session_id,
TrainingParticipant.agent_id == request.agent_id
)
).first()
if existing:
raise HTTPException(status_code=400, detail="Agent already joined this session")
# In reality, fetch reputation from blockchain/service
mock_reputation = 95.0
participant = TrainingParticipant(
session_id=session_id,
agent_id=request.agent_id,
compute_power_committed=request.compute_power_committed,
reputation_score_at_join=mock_reputation,
status=ParticipantStatus.JOINED
)
self.session.add(participant)
self.session.commit()
self.session.refresh(participant)
# Check if we have enough participants to start
current_count = len(fl_session.participants) + 1 # +1 for the newly added but not refreshed one
if current_count >= fl_session.target_participants:
await self._start_training(fl_session)
return participant
async def _start_training(self, fl_session: FederatedLearningSession):
"""Internal method to transition from gathering to active training"""
fl_session.status = TrainingStatus.TRAINING
fl_session.current_round = 1
# Start Round 1
round1 = TrainingRound(
session_id=fl_session.id,
round_number=1,
status="active",
starting_model_cid=fl_session.initial_weights_cid or fl_session.model_architecture_cid
)
self.session.add(round1)
self.session.commit()
logger.info(f"Started training for session {fl_session.id}, Round 1 active.")
async def submit_local_update(self, session_id: str, round_id: str, request: SubmitUpdateRequest) -> LocalModelUpdate:
"""Participant submits their locally trained model weights"""
fl_session = self.session.get(FederatedLearningSession, session_id)
current_round = self.session.get(TrainingRound, round_id)
if not fl_session or not current_round:
raise HTTPException(status_code=404, detail="Session or Round not found")
if fl_session.status != TrainingStatus.TRAINING or current_round.status != "active":
raise HTTPException(status_code=400, detail="Round is not currently active")
participant = self.session.exec(
select(TrainingParticipant).where(
TrainingParticipant.session_id == session_id,
TrainingParticipant.agent_id == request.agent_id
)
).first()
if not participant:
raise HTTPException(status_code=403, detail="Agent is not a participant in this session")
update = LocalModelUpdate(
round_id=round_id,
participant_agent_id=request.agent_id,
weights_cid=request.weights_cid,
zk_proof_hash=request.zk_proof_hash
)
participant.data_samples_count += request.data_samples_count
participant.status = ParticipantStatus.SUBMITTED
self.session.add(update)
self.session.commit()
self.session.refresh(update)
# Check if we should trigger aggregation
updates_count = len(current_round.updates) + 1
if updates_count >= fl_session.min_participants_per_round:
# Note: In a real system, this might be triggered asynchronously via a Celery task
await self._aggregate_round(fl_session, current_round)
return update
async def _aggregate_round(self, fl_session: FederatedLearningSession, current_round: TrainingRound):
"""Mock aggregation process"""
current_round.status = "aggregating"
fl_session.status = TrainingStatus.AGGREGATING
self.session.commit()
# Mocking the actual heavy ML aggregation that would happen elsewhere
logger.info(f"Aggregating {len(current_round.updates)} updates for round {current_round.round_number}")
# Assume successful aggregation creates a new global CID
import hashlib
import time
mock_hash = hashlib.md5(str(time.time()).encode()).hexdigest()
new_global_cid = f"bafy_aggregated_{mock_hash[:20]}"
current_round.aggregated_model_cid = new_global_cid
current_round.status = "completed"
current_round.completed_at = datetime.utcnow()
current_round.metrics = {"loss": 0.5 - (current_round.round_number * 0.05), "accuracy": 0.7 + (current_round.round_number * 0.02)}
if fl_session.current_round >= fl_session.total_rounds:
fl_session.status = TrainingStatus.COMPLETED
fl_session.global_model_cid = new_global_cid
logger.info(f"Federated Learning Session {fl_session.id} fully completed.")
# Here we would handle reward distribution via smart contracts
else:
fl_session.current_round += 1
fl_session.status = TrainingStatus.TRAINING
# Start next round
next_round = TrainingRound(
session_id=fl_session.id,
round_number=fl_session.current_round,
status="active",
starting_model_cid=new_global_cid
)
self.session.add(next_round)
# Reset participant statuses
for p in fl_session.participants:
if p.status == ParticipantStatus.SUBMITTED:
p.status = ParticipantStatus.TRAINING
logger.info(f"Session {fl_session.id} progressing to Round {fl_session.current_round}")
self.session.commit()

View File

@@ -0,0 +1,158 @@
"""
IPFS Storage Adapter Service
Service for offloading agent vector databases and knowledge graphs to IPFS/Filecoin.
"""
from __future__ import annotations
import logging
import json
import hashlib
from typing import List, Optional, Dict, Any
from sqlmodel import Session, select
from fastapi import HTTPException
from ..domain.decentralized_memory import AgentMemoryNode, MemoryType, StorageStatus
from ..schemas.decentralized_memory import MemoryNodeCreate
from ..blockchain.contract_interactions import ContractInteractionService
# In a real environment, this would use a library like ipfshttpclient or a service like Pinata/Web3.Storage
# For this implementation, we will mock the interactions to demonstrate the architecture.
logger = logging.getLogger(__name__)
class IPFSAdapterService:
def __init__(
self,
session: Session,
contract_service: ContractInteractionService,
ipfs_gateway_url: str = "http://127.0.0.1:5001/api/v0",
pinning_service_token: Optional[str] = None
):
self.session = session
self.contract_service = contract_service
self.ipfs_gateway_url = ipfs_gateway_url
self.pinning_service_token = pinning_service_token
async def _mock_ipfs_upload(self, data: bytes) -> str:
"""Mock function to simulate IPFS CID generation (v1 format CID simulation)"""
# Using sha256 to simulate content hashing
hash_val = hashlib.sha256(data).hexdigest()
# Mocking a CIDv1 base32 string format (bafy...)
return f"bafybeig{hash_val[:40]}"
async def store_memory(
self,
request: MemoryNodeCreate,
raw_data: bytes,
zk_proof_hash: Optional[str] = None
) -> AgentMemoryNode:
"""
Upload raw memory data (e.g. serialized vector DB or JSON knowledge graph) to IPFS
and create a tracking record.
"""
# 1. Create initial record
node = AgentMemoryNode(
agent_id=request.agent_id,
memory_type=request.memory_type,
is_encrypted=request.is_encrypted,
metadata=request.metadata,
tags=request.tags,
size_bytes=len(raw_data),
status=StorageStatus.PENDING,
zk_proof_hash=zk_proof_hash
)
self.session.add(node)
self.session.commit()
self.session.refresh(node)
try:
# 2. Upload to IPFS (Mocked)
logger.info(f"Uploading {len(raw_data)} bytes to IPFS for agent {request.agent_id}")
cid = await self._mock_ipfs_upload(raw_data)
node.cid = cid
node.status = StorageStatus.UPLOADED
# 3. Pin to Filecoin/Pinning service (Mocked)
if self.pinning_service_token:
logger.info(f"Pinning CID {cid} to persistent storage")
node.status = StorageStatus.PINNED
self.session.commit()
self.session.refresh(node)
return node
except Exception as e:
logger.error(f"Failed to store memory node {node.id}: {str(e)}")
node.status = StorageStatus.FAILED
self.session.commit()
raise HTTPException(status_code=500, detail="Failed to upload data to decentralized storage")
async def get_memory_nodes(
self,
agent_id: str,
memory_type: Optional[MemoryType] = None,
tags: Optional[List[str]] = None
) -> List[AgentMemoryNode]:
"""Retrieve metadata for an agent's stored memory nodes"""
query = select(AgentMemoryNode).where(AgentMemoryNode.agent_id == agent_id)
if memory_type:
query = query.where(AgentMemoryNode.memory_type == memory_type)
# Execute query and filter by tags in Python (since SQLite JSON JSON_CONTAINS is complex via pure SQLAlchemy without specific dialects)
results = self.session.exec(query).all()
if tags and len(tags) > 0:
filtered_results = []
for r in results:
if all(tag in r.tags for tag in tags):
filtered_results.append(r)
return filtered_results
return results
async def anchor_to_blockchain(self, node_id: str) -> AgentMemoryNode:
"""
Anchor a specific IPFS CID to the agent's smart contract profile to ensure data lineage.
"""
node = self.session.get(AgentMemoryNode, node_id)
if not node:
raise HTTPException(status_code=404, detail="Memory node not found")
if not node.cid:
raise HTTPException(status_code=400, detail="Cannot anchor node without CID")
if node.status == StorageStatus.ANCHORED:
return node
try:
# Mocking the smart contract call to AgentMemory.sol
# tx_hash = await self.contract_service.anchor_agent_memory(node.agent_id, node.cid, node.zk_proof_hash)
tx_hash = "0x" + hashlib.md5(f"{node.id}{node.cid}".encode()).hexdigest()
node.anchor_tx_hash = tx_hash
node.status = StorageStatus.ANCHORED
self.session.commit()
self.session.refresh(node)
logger.info(f"Anchored memory {node_id} (CID: {node.cid}) to blockchain. Tx: {tx_hash}")
return node
except Exception as e:
logger.error(f"Failed to anchor memory node {node_id}: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to anchor CID to blockchain")
async def retrieve_memory(self, node_id: str) -> bytes:
"""Retrieve the raw data from IPFS"""
node = self.session.get(AgentMemoryNode, node_id)
if not node or not node.cid:
raise HTTPException(status_code=404, detail="Memory node or CID not found")
# Mocking retrieval
logger.info(f"Retrieving CID {node.cid} from IPFS network")
mock_data = b"{\"mock\": \"data\", \"info\": \"This represents decrypted vector db or KG data\"}"
return mock_data

View File

@@ -0,0 +1,115 @@
"""
ZK-Proof Memory Verification Service
Service for generating and verifying Zero-Knowledge proofs for decentralized memory retrieval.
Ensures that data retrieved from IPFS matches the anchored state on the blockchain
without revealing the contents of the data itself.
"""
from __future__ import annotations
import logging
import hashlib
import json
from typing import Dict, Any, Optional, Tuple
from fastapi import HTTPException
from sqlmodel import Session
from ..domain.decentralized_memory import AgentMemoryNode
from ..blockchain.contract_interactions import ContractInteractionService
logger = logging.getLogger(__name__)
class ZKMemoryVerificationService:
def __init__(
self,
session: Session,
contract_service: ContractInteractionService
):
self.session = session
self.contract_service = contract_service
async def generate_memory_proof(self, node_id: str, raw_data: bytes) -> Tuple[str, str]:
"""
Generate a Zero-Knowledge proof that the given raw data corresponds to
the structural integrity and properties required by the system,
and compute its hash for on-chain anchoring.
Returns:
Tuple[str, str]: (zk_proof_payload, zk_proof_hash)
"""
node = self.session.get(AgentMemoryNode, node_id)
if not node:
raise HTTPException(status_code=404, detail="Memory node not found")
# In a real ZK system (like snarkjs or circom), we would:
# 1. Compile the raw data into circuit inputs.
# 2. Run the witness generator.
# 3. Generate the proof.
# Mocking ZK Proof generation
logger.info(f"Generating ZK proof for memory node {node_id}")
# We simulate a proof by creating a structured JSON string
data_hash = hashlib.sha256(raw_data).hexdigest()
mock_proof = {
"pi_a": ["mock_pi_a_1", "mock_pi_a_2", "mock_pi_a_3"],
"pi_b": [["mock_pi_b_1", "mock_pi_b_2"], ["mock_pi_b_3", "mock_pi_b_4"]],
"pi_c": ["mock_pi_c_1", "mock_pi_c_2", "mock_pi_c_3"],
"protocol": "groth16",
"curve": "bn128",
"publicSignals": [data_hash, node.agent_id]
}
proof_payload = json.dumps(mock_proof)
# The proof hash is what gets stored on-chain
proof_hash = "0x" + hashlib.sha256(proof_payload.encode()).hexdigest()
return proof_payload, proof_hash
async def verify_retrieved_memory(
self,
node_id: str,
retrieved_data: bytes,
proof_payload: str
) -> bool:
"""
Verify that the retrieved data matches the on-chain anchored ZK proof.
"""
node = self.session.get(AgentMemoryNode, node_id)
if not node:
raise HTTPException(status_code=404, detail="Memory node not found")
if not node.zk_proof_hash:
raise HTTPException(status_code=400, detail="Memory node does not have an anchored ZK proof")
logger.info(f"Verifying ZK proof for retrieved memory {node_id}")
try:
# 1. Verify the provided proof payload matches the on-chain hash
calculated_hash = "0x" + hashlib.sha256(proof_payload.encode()).hexdigest()
if calculated_hash != node.zk_proof_hash:
logger.error("Proof payload hash does not match anchored hash")
return False
# 2. Verify the proof against the retrieved data (Circuit verification)
# In a real system, we might verify this locally or query the smart contract
# Local mock verification
proof_data = json.loads(proof_payload)
data_hash = hashlib.sha256(retrieved_data).hexdigest()
# Check if the public signals match the data we retrieved
if proof_data.get("publicSignals", [])[0] != data_hash:
logger.error("Public signals in proof do not match retrieved data hash")
return False
logger.info("ZK Memory Verification Successful")
return True
except Exception as e:
logger.error(f"Error during ZK memory verification: {str(e)}")
return False

View File

@@ -0,0 +1,129 @@
import pytest
from unittest.mock import AsyncMock
from sqlmodel import Session, create_engine, SQLModel
from sqlmodel.pool import StaticPool
from fastapi import HTTPException
from app.services.federated_learning import FederatedLearningService
from app.domain.federated_learning import TrainingStatus, ParticipantStatus
from app.schemas.federated_learning import FederatedSessionCreate, JoinSessionRequest, SubmitUpdateRequest
@pytest.fixture
def test_db():
engine = create_engine(
"sqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SQLModel.metadata.create_all(engine)
session = Session(engine)
yield session
session.close()
@pytest.fixture
def mock_contract_service():
return AsyncMock()
@pytest.fixture
def fl_service(test_db, mock_contract_service):
return FederatedLearningService(
session=test_db,
contract_service=mock_contract_service
)
@pytest.mark.asyncio
async def test_create_session(fl_service):
req = FederatedSessionCreate(
initiator_agent_id="agent-admin",
task_description="Train LLM on financial data",
model_architecture_cid="bafy_arch_123",
target_participants=2,
total_rounds=2,
min_participants_per_round=2
)
session = await fl_service.create_session(req)
assert session.status == TrainingStatus.GATHERING_PARTICIPANTS
assert session.initiator_agent_id == "agent-admin"
@pytest.mark.asyncio
async def test_join_session_and_start(fl_service):
req = FederatedSessionCreate(
initiator_agent_id="agent-admin",
task_description="Train LLM on financial data",
model_architecture_cid="bafy_arch_123",
target_participants=2,
total_rounds=2,
min_participants_per_round=2
)
session = await fl_service.create_session(req)
# Agent 1 joins
p1 = await fl_service.join_session(
session.id,
JoinSessionRequest(agent_id="agent-1", compute_power_committed=10.0)
)
assert p1.status == ParticipantStatus.JOINED
assert session.status == TrainingStatus.GATHERING_PARTICIPANTS
# Agent 2 joins, triggers start
p2 = await fl_service.join_session(
session.id,
JoinSessionRequest(agent_id="agent-2", compute_power_committed=15.0)
)
# Needs refresh
fl_service.session.refresh(session)
assert session.status == TrainingStatus.TRAINING
assert session.current_round == 1
assert len(session.rounds) == 1
assert session.rounds[0].status == "active"
@pytest.mark.asyncio
async def test_submit_updates_and_aggregate(fl_service):
# Setup
req = FederatedSessionCreate(
initiator_agent_id="agent-admin",
task_description="Train LLM on financial data",
model_architecture_cid="bafy_arch_123",
target_participants=2,
total_rounds=1, # Only 1 round for quick test
min_participants_per_round=2
)
session = await fl_service.create_session(req)
await fl_service.join_session(session.id, JoinSessionRequest(agent_id="agent-1", compute_power_committed=10.0))
await fl_service.join_session(session.id, JoinSessionRequest(agent_id="agent-2", compute_power_committed=15.0))
fl_service.session.refresh(session)
round1 = session.rounds[0]
# Agent 1 submits
u1 = await fl_service.submit_local_update(
session.id,
round1.id,
SubmitUpdateRequest(agent_id="agent-1", weights_cid="bafy_w1", data_samples_count=1000)
)
assert u1.weights_cid == "bafy_w1"
fl_service.session.refresh(session)
fl_service.session.refresh(round1)
# Not aggregated yet
assert round1.status == "active"
# Agent 2 submits, triggers aggregation and completion since total_rounds=1
u2 = await fl_service.submit_local_update(
session.id,
round1.id,
SubmitUpdateRequest(agent_id="agent-2", weights_cid="bafy_w2", data_samples_count=1500)
)
fl_service.session.refresh(session)
fl_service.session.refresh(round1)
assert round1.status == "completed"
assert session.status == TrainingStatus.COMPLETED
assert session.global_model_cid is not None
assert session.global_model_cid.startswith("bafy_aggregated_")

View File

@@ -0,0 +1,108 @@
import pytest
from unittest.mock import AsyncMock
from sqlmodel import Session, create_engine, SQLModel
from sqlmodel.pool import StaticPool
from fastapi import HTTPException
from app.services.ipfs_storage_adapter import IPFSAdapterService
from app.domain.decentralized_memory import MemoryType, StorageStatus
from app.schemas.decentralized_memory import MemoryNodeCreate
@pytest.fixture
def test_db():
engine = create_engine(
"sqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SQLModel.metadata.create_all(engine)
session = Session(engine)
yield session
session.close()
@pytest.fixture
def mock_contract_service():
return AsyncMock()
@pytest.fixture
def storage_service(test_db, mock_contract_service):
return IPFSAdapterService(
session=test_db,
contract_service=mock_contract_service,
pinning_service_token="mock_token"
)
@pytest.mark.asyncio
async def test_store_memory(storage_service):
request = MemoryNodeCreate(
agent_id="agent-007",
memory_type=MemoryType.VECTOR_DB,
tags=["training", "batch1"]
)
raw_data = b"mock_vector_embeddings_data"
node = await storage_service.store_memory(request, raw_data, zk_proof_hash="0xabc123")
assert node.agent_id == "agent-007"
assert node.memory_type == MemoryType.VECTOR_DB
assert node.cid is not None
assert node.cid.startswith("bafy")
assert node.size_bytes == len(raw_data)
assert node.status == StorageStatus.PINNED
assert node.zk_proof_hash == "0xabc123"
@pytest.mark.asyncio
async def test_get_memory_nodes(storage_service):
# Store multiple
await storage_service.store_memory(
MemoryNodeCreate(agent_id="agent-007", memory_type=MemoryType.VECTOR_DB, tags=["v1"]),
b"data1"
)
await storage_service.store_memory(
MemoryNodeCreate(agent_id="agent-007", memory_type=MemoryType.KNOWLEDGE_GRAPH, tags=["v1"]),
b"data2"
)
await storage_service.store_memory(
MemoryNodeCreate(agent_id="agent-008", memory_type=MemoryType.VECTOR_DB),
b"data3"
)
# Get all for agent-007
nodes = await storage_service.get_memory_nodes("agent-007")
assert len(nodes) == 2
# Filter by type
nodes_kg = await storage_service.get_memory_nodes("agent-007", memory_type=MemoryType.KNOWLEDGE_GRAPH)
assert len(nodes_kg) == 1
assert nodes_kg[0].memory_type == MemoryType.KNOWLEDGE_GRAPH
# Filter by tag
nodes_tag = await storage_service.get_memory_nodes("agent-007", tags=["v1"])
assert len(nodes_tag) == 2
@pytest.mark.asyncio
async def test_anchor_to_blockchain(storage_service):
node = await storage_service.store_memory(
MemoryNodeCreate(agent_id="agent-007", memory_type=MemoryType.VECTOR_DB),
b"data1"
)
assert node.anchor_tx_hash is None
anchored_node = await storage_service.anchor_to_blockchain(node.id)
assert anchored_node.status == StorageStatus.ANCHORED
assert anchored_node.anchor_tx_hash is not None
@pytest.mark.asyncio
async def test_retrieve_memory(storage_service):
node = await storage_service.store_memory(
MemoryNodeCreate(agent_id="agent-007", memory_type=MemoryType.VECTOR_DB),
b"data1"
)
data = await storage_service.retrieve_memory(node.id)
assert isinstance(data, bytes)
assert b"mock" in data

View File

@@ -0,0 +1,92 @@
import pytest
from unittest.mock import AsyncMock
from sqlmodel import Session, create_engine, SQLModel
from sqlmodel.pool import StaticPool
from app.services.zk_memory_verification import ZKMemoryVerificationService
from app.domain.decentralized_memory import AgentMemoryNode, MemoryType
@pytest.fixture
def test_db():
engine = create_engine(
"sqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SQLModel.metadata.create_all(engine)
session = Session(engine)
yield session
session.close()
@pytest.fixture
def mock_contract_service():
return AsyncMock()
@pytest.fixture
def zk_service(test_db, mock_contract_service):
return ZKMemoryVerificationService(
session=test_db,
contract_service=mock_contract_service
)
@pytest.mark.asyncio
async def test_generate_memory_proof(zk_service, test_db):
node = AgentMemoryNode(
agent_id="agent-zk",
memory_type=MemoryType.VECTOR_DB
)
test_db.add(node)
test_db.commit()
test_db.refresh(node)
raw_data = b"secret_vector_data"
proof_payload, proof_hash = await zk_service.generate_memory_proof(node.id, raw_data)
assert proof_payload is not None
assert proof_hash.startswith("0x")
assert "groth16" in proof_payload
@pytest.mark.asyncio
async def test_verify_retrieved_memory_success(zk_service, test_db):
node = AgentMemoryNode(
agent_id="agent-zk",
memory_type=MemoryType.VECTOR_DB
)
test_db.add(node)
test_db.commit()
test_db.refresh(node)
raw_data = b"secret_vector_data"
proof_payload, proof_hash = await zk_service.generate_memory_proof(node.id, raw_data)
# Simulate anchoring
node.zk_proof_hash = proof_hash
test_db.commit()
# Verify
is_valid = await zk_service.verify_retrieved_memory(node.id, raw_data, proof_payload)
assert is_valid is True
@pytest.mark.asyncio
async def test_verify_retrieved_memory_tampered_data(zk_service, test_db):
node = AgentMemoryNode(
agent_id="agent-zk",
memory_type=MemoryType.VECTOR_DB
)
test_db.add(node)
test_db.commit()
test_db.refresh(node)
raw_data = b"secret_vector_data"
proof_payload, proof_hash = await zk_service.generate_memory_proof(node.id, raw_data)
node.zk_proof_hash = proof_hash
test_db.commit()
# Tamper with data
tampered_data = b"secret_vector_data_modified"
is_valid = await zk_service.verify_retrieved_memory(node.id, tampered_data, proof_payload)
assert is_valid is False