feat: add phase 2 decentralized memory python services and tests

This commit is contained in:
oib
2026-02-28 23:19:02 +01:00
parent c63259ef2c
commit 5bc18d684c
10 changed files with 1054 additions and 0 deletions

View File

@@ -0,0 +1,208 @@
"""
Federated Learning Service
Service for managing cross-agent knowledge sharing and collaborative model training.
"""
from __future__ import annotations
import logging
from datetime import datetime
from typing import List, Optional
from sqlmodel import Session, select
from fastapi import HTTPException
from ..domain.federated_learning import (
FederatedLearningSession, TrainingParticipant, TrainingRound,
LocalModelUpdate, TrainingStatus, ParticipantStatus
)
from ..schemas.federated_learning import (
FederatedSessionCreate, JoinSessionRequest, SubmitUpdateRequest
)
from ..blockchain.contract_interactions import ContractInteractionService
logger = logging.getLogger(__name__)
class FederatedLearningService:
    """Coordinates federated learning sessions across agents.

    Responsibilities visible in this class: session creation, participant
    enrollment, round lifecycle (start / collect updates / aggregate), and a
    mocked aggregation step. All persistence goes through a SQLModel
    ``Session``; the blockchain bridge (``contract_service``) is injected but
    not yet called here — reward distribution is still a TODO in
    ``_aggregate_round``.
    """

    def __init__(
        self,
        session: Session,
        contract_service: ContractInteractionService
    ) -> None:
        # Database session shared by every operation of this service instance.
        self.session = session
        # Blockchain bridge; reserved for future reward payout / anchoring.
        self.contract_service = contract_service

    async def create_session(self, request: FederatedSessionCreate) -> FederatedLearningSession:
        """Create a new federated learning session.

        The session starts in GATHERING_PARTICIPANTS; training begins
        automatically once ``target_participants`` agents have joined
        (see ``join_session`` -> ``_start_training``).
        """
        session = FederatedLearningSession(
            initiator_agent_id=request.initiator_agent_id,
            task_description=request.task_description,
            model_architecture_cid=request.model_architecture_cid,
            initial_weights_cid=request.initial_weights_cid,
            target_participants=request.target_participants,
            total_rounds=request.total_rounds,
            aggregation_strategy=request.aggregation_strategy,
            min_participants_per_round=request.min_participants_per_round,
            reward_pool_amount=request.reward_pool_amount,
            status=TrainingStatus.GATHERING_PARTICIPANTS
        )
        self.session.add(session)
        self.session.commit()
        # Refresh so generated fields (e.g. the primary key) are populated.
        self.session.refresh(session)
        logger.info(f"Created Federated Learning Session {session.id} by {request.initiator_agent_id}")
        return session

    async def join_session(self, session_id: str, request: JoinSessionRequest) -> TrainingParticipant:
        """Allow an agent to join an active session.

        Raises:
            HTTPException: 404 if the session does not exist; 400 if the
                session is no longer gathering participants or the agent has
                already joined.
        """
        fl_session = self.session.get(FederatedLearningSession, session_id)
        if not fl_session:
            raise HTTPException(status_code=404, detail="Session not found")
        # Joining is only allowed before training starts.
        if fl_session.status != TrainingStatus.GATHERING_PARTICIPANTS:
            raise HTTPException(status_code=400, detail="Session is not currently accepting participants")
        # Check if already joined
        existing = self.session.exec(
            select(TrainingParticipant).where(
                TrainingParticipant.session_id == session_id,
                TrainingParticipant.agent_id == request.agent_id
            )
        ).first()
        if existing:
            raise HTTPException(status_code=400, detail="Agent already joined this session")
        # In reality, fetch reputation from blockchain/service
        mock_reputation = 95.0
        participant = TrainingParticipant(
            session_id=session_id,
            agent_id=request.agent_id,
            compute_power_committed=request.compute_power_committed,
            reputation_score_at_join=mock_reputation,
            status=ParticipantStatus.JOINED
        )
        self.session.add(participant)
        self.session.commit()
        self.session.refresh(participant)
        # Check if we have enough participants to start
        # NOTE(review): this assumes fl_session.participants was loaded BEFORE
        # the new participant was committed, hence the manual +1. If the ORM
        # refreshes the relationship after the commit above, this would
        # double-count — confirm the loading behavior.
        current_count = len(fl_session.participants) + 1  # +1 for the newly added but not refreshed one
        if current_count >= fl_session.target_participants:
            await self._start_training(fl_session)
        return participant

    async def _start_training(self, fl_session: FederatedLearningSession) -> None:
        """Internal method to transition from gathering to active training.

        Marks the session as TRAINING and opens round 1, seeded with the
        initial weights CID (falling back to the architecture CID).
        """
        fl_session.status = TrainingStatus.TRAINING
        fl_session.current_round = 1
        # Start Round 1
        round1 = TrainingRound(
            session_id=fl_session.id,
            round_number=1,
            status="active",
            starting_model_cid=fl_session.initial_weights_cid or fl_session.model_architecture_cid
        )
        self.session.add(round1)
        self.session.commit()
        logger.info(f"Started training for session {fl_session.id}, Round 1 active.")

    async def submit_local_update(self, session_id: str, round_id: str, request: SubmitUpdateRequest) -> LocalModelUpdate:
        """Participant submits their locally trained model weights.

        Triggers aggregation once ``min_participants_per_round`` updates have
        been received for the round.

        Raises:
            HTTPException: 404 if session/round missing; 400 if the round is
                not active; 403 if the agent is not a participant.
        """
        fl_session = self.session.get(FederatedLearningSession, session_id)
        current_round = self.session.get(TrainingRound, round_id)
        if not fl_session or not current_round:
            raise HTTPException(status_code=404, detail="Session or Round not found")
        if fl_session.status != TrainingStatus.TRAINING or current_round.status != "active":
            raise HTTPException(status_code=400, detail="Round is not currently active")
        participant = self.session.exec(
            select(TrainingParticipant).where(
                TrainingParticipant.session_id == session_id,
                TrainingParticipant.agent_id == request.agent_id
            )
        ).first()
        if not participant:
            raise HTTPException(status_code=403, detail="Agent is not a participant in this session")
        update = LocalModelUpdate(
            round_id=round_id,
            participant_agent_id=request.agent_id,
            weights_cid=request.weights_cid,
            zk_proof_hash=request.zk_proof_hash
        )
        # Track cumulative contribution; used later for reward weighting
        # (presumably — confirm against the reward-distribution design).
        participant.data_samples_count += request.data_samples_count
        participant.status = ParticipantStatus.SUBMITTED
        self.session.add(update)
        self.session.commit()
        self.session.refresh(update)
        # Check if we should trigger aggregation
        # Same pre-commit-snapshot assumption as in join_session: +1 accounts
        # for the update just committed but not present in the loaded relation.
        updates_count = len(current_round.updates) + 1
        if updates_count >= fl_session.min_participants_per_round:
            # Note: In a real system, this might be triggered asynchronously via a Celery task
            await self._aggregate_round(fl_session, current_round)
        return update

    async def _aggregate_round(self, fl_session: FederatedLearningSession, current_round: TrainingRound) -> None:
        """Mock aggregation process.

        Marks the round as aggregating, fabricates a new global model CID,
        records mock metrics, then either completes the session (last round)
        or opens the next round seeded with the new CID.
        """
        current_round.status = "aggregating"
        fl_session.status = TrainingStatus.AGGREGATING
        # Commit the intermediate state so observers see AGGREGATING.
        self.session.commit()
        # Mocking the actual heavy ML aggregation that would happen elsewhere
        logger.info(f"Aggregating {len(current_round.updates)} updates for round {current_round.round_number}")
        # Assume successful aggregation creates a new global CID
        # Local imports: only needed for this mock, not for the real flow.
        import hashlib
        import time
        # md5 of the wall clock is fine here — it is a mock CID, not security.
        mock_hash = hashlib.md5(str(time.time()).encode()).hexdigest()
        new_global_cid = f"bafy_aggregated_{mock_hash[:20]}"
        current_round.aggregated_model_cid = new_global_cid
        current_round.status = "completed"
        # NOTE(review): datetime.utcnow() returns a naive timestamp and is
        # deprecated in Python 3.12+ — consider datetime.now(timezone.utc).
        current_round.completed_at = datetime.utcnow()
        # Mock metrics that improve monotonically with the round number.
        current_round.metrics = {"loss": 0.5 - (current_round.round_number * 0.05), "accuracy": 0.7 + (current_round.round_number * 0.02)}
        if fl_session.current_round >= fl_session.total_rounds:
            fl_session.status = TrainingStatus.COMPLETED
            fl_session.global_model_cid = new_global_cid
            logger.info(f"Federated Learning Session {fl_session.id} fully completed.")
            # Here we would handle reward distribution via smart contracts
        else:
            fl_session.current_round += 1
            fl_session.status = TrainingStatus.TRAINING
            # Start next round
            next_round = TrainingRound(
                session_id=fl_session.id,
                round_number=fl_session.current_round,
                status="active",
                starting_model_cid=new_global_cid
            )
            self.session.add(next_round)
            # Reset participant statuses
            for p in fl_session.participants:
                if p.status == ParticipantStatus.SUBMITTED:
                    p.status = ParticipantStatus.TRAINING
            logger.info(f"Session {fl_session.id} progressing to Round {fl_session.current_round}")
        self.session.commit()

View File

@@ -0,0 +1,158 @@
"""
IPFS Storage Adapter Service
Service for offloading agent vector databases and knowledge graphs to IPFS/Filecoin.
"""
from __future__ import annotations
import logging
import json
import hashlib
from typing import List, Optional, Dict, Any
from sqlmodel import Session, select
from fastapi import HTTPException
from ..domain.decentralized_memory import AgentMemoryNode, MemoryType, StorageStatus
from ..schemas.decentralized_memory import MemoryNodeCreate
from ..blockchain.contract_interactions import ContractInteractionService
# In a real environment, this would use a library like ipfshttpclient or a service like Pinata/Web3.Storage
# For this implementation, we will mock the interactions to demonstrate the architecture.
logger = logging.getLogger(__name__)
class IPFSAdapterService:
    """Adapter for offloading agent memories (vector DBs, knowledge graphs)
    to IPFS/Filecoin and anchoring the resulting CIDs on-chain.

    All network interactions (IPFS upload/retrieval, pinning, smart-contract
    calls) are mocked in this implementation; only the database bookkeeping
    via the SQLModel ``Session`` is real.
    """

    def __init__(
        self,
        session: Session,
        contract_service: ContractInteractionService,
        ipfs_gateway_url: str = "http://127.0.0.1:5001/api/v0",
        pinning_service_token: Optional[str] = None
    ):
        self.session = session
        self.contract_service = contract_service
        # Local IPFS HTTP API endpoint (unused while uploads are mocked).
        self.ipfs_gateway_url = ipfs_gateway_url
        # Presence of a token is what enables the (mock) pinning step.
        self.pinning_service_token = pinning_service_token

    async def _mock_ipfs_upload(self, data: bytes) -> str:
        """Mock function to simulate IPFS CID generation (v1 format CID simulation)"""
        # Using sha256 to simulate content hashing
        hash_val = hashlib.sha256(data).hexdigest()
        # Mocking a CIDv1 base32 string format (bafy...)
        return f"bafybeig{hash_val[:40]}"

    async def store_memory(
        self,
        request: MemoryNodeCreate,
        raw_data: bytes,
        zk_proof_hash: Optional[str] = None
    ) -> AgentMemoryNode:
        """
        Upload raw memory data (e.g. serialized vector DB or JSON knowledge graph) to IPFS
        and create a tracking record.

        The record is created first (status PENDING) so a failed upload leaves
        an auditable FAILED row instead of vanishing silently.

        Raises:
            HTTPException: 500 when the (mock) upload or pinning step fails.
        """
        # 1. Create initial record
        node = AgentMemoryNode(
            agent_id=request.agent_id,
            memory_type=request.memory_type,
            is_encrypted=request.is_encrypted,
            metadata=request.metadata,
            tags=request.tags,
            size_bytes=len(raw_data),
            status=StorageStatus.PENDING,
            zk_proof_hash=zk_proof_hash
        )
        self.session.add(node)
        self.session.commit()
        self.session.refresh(node)
        try:
            # 2. Upload to IPFS (Mocked)
            logger.info(f"Uploading {len(raw_data)} bytes to IPFS for agent {request.agent_id}")
            cid = await self._mock_ipfs_upload(raw_data)
            node.cid = cid
            node.status = StorageStatus.UPLOADED
            # 3. Pin to Filecoin/Pinning service (Mocked)
            if self.pinning_service_token:
                logger.info(f"Pinning CID {cid} to persistent storage")
                node.status = StorageStatus.PINNED
            self.session.commit()
            self.session.refresh(node)
            return node
        except Exception as e:
            logger.error(f"Failed to store memory node {node.id}: {str(e)}")
            node.status = StorageStatus.FAILED
            self.session.commit()
            # Chain the original error so the root cause survives in tracebacks.
            raise HTTPException(status_code=500, detail="Failed to upload data to decentralized storage") from e

    async def get_memory_nodes(
        self,
        agent_id: str,
        memory_type: Optional[MemoryType] = None,
        tags: Optional[List[str]] = None
    ) -> List[AgentMemoryNode]:
        """Retrieve metadata for an agent's stored memory nodes.

        Optionally filters by memory type (in SQL) and by tags (in Python:
        a node matches only if it carries ALL requested tags).
        """
        query = select(AgentMemoryNode).where(AgentMemoryNode.agent_id == agent_id)
        if memory_type:
            query = query.where(AgentMemoryNode.memory_type == memory_type)
        # Execute query and filter by tags in Python (SQLite JSON containment
        # is awkward through pure SQLAlchemy without dialect-specific helpers).
        results = self.session.exec(query).all()
        if tags:
            # Guard with `or []`: nodes whose tags column is NULL/None would
            # otherwise raise TypeError on the `in` test.
            return [r for r in results if all(tag in (r.tags or []) for tag in tags)]
        return results

    async def anchor_to_blockchain(self, node_id: str) -> AgentMemoryNode:
        """
        Anchor a specific IPFS CID to the agent's smart contract profile to ensure data lineage.

        Idempotent: returns immediately if the node is already ANCHORED.

        Raises:
            HTTPException: 404 if the node is missing, 400 if it has no CID,
                500 if the (mock) contract call fails.
        """
        node = self.session.get(AgentMemoryNode, node_id)
        if not node:
            raise HTTPException(status_code=404, detail="Memory node not found")
        if not node.cid:
            raise HTTPException(status_code=400, detail="Cannot anchor node without CID")
        if node.status == StorageStatus.ANCHORED:
            return node
        try:
            # Mocking the smart contract call to AgentMemory.sol
            # tx_hash = await self.contract_service.anchor_agent_memory(node.agent_id, node.cid, node.zk_proof_hash)
            # md5 is acceptable here — this is a mock tx hash, not security.
            tx_hash = "0x" + hashlib.md5(f"{node.id}{node.cid}".encode()).hexdigest()
            node.anchor_tx_hash = tx_hash
            node.status = StorageStatus.ANCHORED
            self.session.commit()
            self.session.refresh(node)
            logger.info(f"Anchored memory {node_id} (CID: {node.cid}) to blockchain. Tx: {tx_hash}")
            return node
        except Exception as e:
            logger.error(f"Failed to anchor memory node {node_id}: {str(e)}")
            raise HTTPException(status_code=500, detail="Failed to anchor CID to blockchain") from e

    async def retrieve_memory(self, node_id: str) -> bytes:
        """Retrieve the raw data from IPFS (mocked: returns a fixed payload).

        Raises:
            HTTPException: 404 if the node or its CID is missing.
        """
        node = self.session.get(AgentMemoryNode, node_id)
        if not node or not node.cid:
            raise HTTPException(status_code=404, detail="Memory node or CID not found")
        # Mocking retrieval
        logger.info(f"Retrieving CID {node.cid} from IPFS network")
        mock_data = b"{\"mock\": \"data\", \"info\": \"This represents decrypted vector db or KG data\"}"
        return mock_data

View File

@@ -0,0 +1,115 @@
"""
ZK-Proof Memory Verification Service
Service for generating and verifying Zero-Knowledge proofs for decentralized memory retrieval.
Ensures that data retrieved from IPFS matches the anchored state on the blockchain
without revealing the contents of the data itself.
"""
from __future__ import annotations
import logging
import hashlib
import json
from typing import Dict, Any, Optional, Tuple
from fastapi import HTTPException
from sqlmodel import Session
from ..domain.decentralized_memory import AgentMemoryNode
from ..blockchain.contract_interactions import ContractInteractionService
logger = logging.getLogger(__name__)
class ZKMemoryVerificationService:
    """Generates and verifies (mock) Zero-Knowledge proofs for decentralized
    memory retrieval.

    The proof binds the sha256 of the raw data and the owning agent id into a
    payload whose sha256 hash is what gets anchored on-chain, so retrieved
    data can be checked against the anchored state without revealing contents.
    """

    def __init__(
        self,
        session: Session,
        contract_service: ContractInteractionService
    ):
        self.session = session
        # Injected for future on-chain verification; unused in the mock flow.
        self.contract_service = contract_service

    async def generate_memory_proof(self, node_id: str, raw_data: bytes) -> Tuple[str, str]:
        """
        Generate a Zero-Knowledge proof that the given raw data corresponds to
        the structural integrity and properties required by the system,
        and compute its hash for on-chain anchoring.

        Returns:
            Tuple[str, str]: (zk_proof_payload, zk_proof_hash)

        Raises:
            HTTPException: 404 if the memory node does not exist.
        """
        node = self.session.get(AgentMemoryNode, node_id)
        if not node:
            raise HTTPException(status_code=404, detail="Memory node not found")
        # In a real ZK system (like snarkjs or circom), we would:
        # 1. Compile the raw data into circuit inputs.
        # 2. Run the witness generator.
        # 3. Generate the proof.
        # Mocking ZK Proof generation
        logger.info(f"Generating ZK proof for memory node {node_id}")
        # We simulate a proof by creating a structured JSON string; the data
        # hash and agent id play the role of the circuit's public signals.
        data_hash = hashlib.sha256(raw_data).hexdigest()
        mock_proof = {
            "pi_a": ["mock_pi_a_1", "mock_pi_a_2", "mock_pi_a_3"],
            "pi_b": [["mock_pi_b_1", "mock_pi_b_2"], ["mock_pi_b_3", "mock_pi_b_4"]],
            "pi_c": ["mock_pi_c_1", "mock_pi_c_2", "mock_pi_c_3"],
            "protocol": "groth16",
            "curve": "bn128",
            "publicSignals": [data_hash, node.agent_id]
        }
        proof_payload = json.dumps(mock_proof)
        # The proof hash is what gets stored on-chain
        proof_hash = "0x" + hashlib.sha256(proof_payload.encode()).hexdigest()
        return proof_payload, proof_hash

    async def verify_retrieved_memory(
        self,
        node_id: str,
        retrieved_data: bytes,
        proof_payload: str
    ) -> bool:
        """
        Verify that the retrieved data matches the on-chain anchored ZK proof.

        Two checks: (1) the payload hashes to the node's anchored proof hash;
        (2) the payload's first public signal equals the sha256 of the
        retrieved data. Any failure or malformed payload yields False.

        Raises:
            HTTPException: 404 if the node is missing, 400 if it has no
                anchored proof hash.
        """
        node = self.session.get(AgentMemoryNode, node_id)
        if not node:
            raise HTTPException(status_code=404, detail="Memory node not found")
        if not node.zk_proof_hash:
            raise HTTPException(status_code=400, detail="Memory node does not have an anchored ZK proof")
        logger.info(f"Verifying ZK proof for retrieved memory {node_id}")
        try:
            # 1. Verify the provided proof payload matches the on-chain hash
            calculated_hash = "0x" + hashlib.sha256(proof_payload.encode()).hexdigest()
            if calculated_hash != node.zk_proof_hash:
                logger.error("Proof payload hash does not match anchored hash")
                return False
            # 2. Verify the proof against the retrieved data (Circuit verification)
            # In a real system, we might verify this locally or query the smart contract
            # Local mock verification
            proof_data = json.loads(proof_payload)
            data_hash = hashlib.sha256(retrieved_data).hexdigest()
            # Check if the public signals match the data we retrieved.
            # Explicit empty-signals guard: the previous `get(...)[0]` relied
            # on an IndexError falling through to the broad except below.
            signals = proof_data.get("publicSignals") or []
            if not signals or signals[0] != data_hash:
                logger.error("Public signals in proof do not match retrieved data hash")
                return False
            logger.info("ZK Memory Verification Successful")
            return True
        except Exception as e:
            # Malformed JSON or unexpected payload shapes are treated as a
            # failed verification, not a server error.
            logger.error(f"Error during ZK memory verification: {str(e)}")
            return False