From 5bc18d684ca8e99ea5a9ca928349669b7593c88e Mon Sep 17 00:00:00 2001 From: oib Date: Sat, 28 Feb 2026 23:19:02 +0100 Subject: [PATCH] feat: add phase 2 decentralized memory python services and tests --- .../src/app/domain/decentralized_memory.py | 56 +++++ .../src/app/domain/federated_learning.py | 122 ++++++++++ .../src/app/schemas/decentralized_memory.py | 29 +++ .../src/app/schemas/federated_learning.py | 37 ++++ .../src/app/services/federated_learning.py | 208 ++++++++++++++++++ .../src/app/services/ipfs_storage_adapter.py | 158 +++++++++++++ .../app/services/zk_memory_verification.py | 115 ++++++++++ .../tests/test_federated_learning.py | 129 +++++++++++ .../tests/test_ipfs_storage_adapter.py | 108 +++++++++ .../tests/test_zk_memory_verification.py | 92 ++++++++ 10 files changed, 1054 insertions(+) create mode 100644 apps/coordinator-api/src/app/domain/decentralized_memory.py create mode 100644 apps/coordinator-api/src/app/domain/federated_learning.py create mode 100644 apps/coordinator-api/src/app/schemas/decentralized_memory.py create mode 100644 apps/coordinator-api/src/app/schemas/federated_learning.py create mode 100644 apps/coordinator-api/src/app/services/federated_learning.py create mode 100644 apps/coordinator-api/src/app/services/ipfs_storage_adapter.py create mode 100644 apps/coordinator-api/src/app/services/zk_memory_verification.py create mode 100644 apps/coordinator-api/tests/test_federated_learning.py create mode 100644 apps/coordinator-api/tests/test_ipfs_storage_adapter.py create mode 100644 apps/coordinator-api/tests/test_zk_memory_verification.py diff --git a/apps/coordinator-api/src/app/domain/decentralized_memory.py b/apps/coordinator-api/src/app/domain/decentralized_memory.py new file mode 100644 index 00000000..c670271f --- /dev/null +++ b/apps/coordinator-api/src/app/domain/decentralized_memory.py @@ -0,0 +1,56 @@ +""" +Decentralized Memory Domain Models + +Domain models for managing agent memory and knowledge graphs on IPFS/Filecoin. 
+""" + +from __future__ import annotations + +from datetime import datetime +from enum import Enum +from typing import Dict, Optional, List +from uuid import uuid4 + +from sqlalchemy import Column, JSON +from sqlmodel import Field, SQLModel, Relationship + +class MemoryType(str, Enum): + VECTOR_DB = "vector_db" + KNOWLEDGE_GRAPH = "knowledge_graph" + POLICY_WEIGHTS = "policy_weights" + EPISODIC = "episodic" + +class StorageStatus(str, Enum): + PENDING = "pending" # Upload to IPFS pending + UPLOADED = "uploaded" # Available on IPFS + PINNED = "pinned" # Pinned on Filecoin/Pinata + ANCHORED = "anchored" # CID written to blockchain + FAILED = "failed" # Upload failed + +class AgentMemoryNode(SQLModel, table=True): + """Represents a chunk of memory or knowledge stored on decentralized storage""" + __tablename__ = "agent_memory_node" + + id: str = Field(default_factory=lambda: uuid4().hex, primary_key=True) + agent_id: str = Field(index=True) + memory_type: MemoryType = Field(index=True) + + # Decentralized Storage Identifiers + cid: Optional[str] = Field(default=None, index=True) # IPFS Content Identifier + size_bytes: Optional[int] = Field(default=None) + + # Encryption and Security + is_encrypted: bool = Field(default=True) + encryption_key_id: Optional[str] = Field(default=None) # Reference to KMS or Lit Protocol + zk_proof_hash: Optional[str] = Field(default=None) # Hash of the ZK proof verifying content validity + + status: StorageStatus = Field(default=StorageStatus.PENDING, index=True) + + metadata: Dict[str, str] = Field(default_factory=dict, sa_column=Column(JSON)) + tags: List[str] = Field(default_factory=list, sa_column=Column(JSON)) + + # Blockchain Anchoring + anchor_tx_hash: Optional[str] = Field(default=None) + + created_at: datetime = Field(default_factory=datetime.utcnow) + updated_at: datetime = Field(default_factory=datetime.utcnow) diff --git a/apps/coordinator-api/src/app/domain/federated_learning.py 
b/apps/coordinator-api/src/app/domain/federated_learning.py new file mode 100644 index 00000000..daabe6f4 --- /dev/null +++ b/apps/coordinator-api/src/app/domain/federated_learning.py @@ -0,0 +1,122 @@ +""" +Federated Learning Domain Models + +Domain models for managing cross-agent knowledge sharing and collaborative model training. +""" + +from __future__ import annotations + +from datetime import datetime +from enum import Enum +from typing import Dict, List, Optional +from uuid import uuid4 + +from sqlalchemy import Column, JSON +from sqlmodel import Field, SQLModel, Relationship + +class TrainingStatus(str, Enum): + INITIALIZED = "initiated" + GATHERING_PARTICIPANTS = "gathering_participants" + TRAINING = "training" + AGGREGATING = "aggregating" + COMPLETED = "completed" + FAILED = "failed" + +class ParticipantStatus(str, Enum): + INVITED = "invited" + JOINED = "joined" + TRAINING = "training" + SUBMITTED = "submitted" + DROPPED = "dropped" + +class FederatedLearningSession(SQLModel, table=True): + """Represents a collaborative training session across multiple agents""" + __tablename__ = "federated_learning_session" + + id: str = Field(default_factory=lambda: uuid4().hex, primary_key=True) + initiator_agent_id: str = Field(index=True) + task_description: str = Field() + model_architecture_cid: str = Field() # IPFS CID pointing to model structure definition + initial_weights_cid: Optional[str] = Field(default=None) # Optional starting point + + target_participants: int = Field(default=3) + current_round: int = Field(default=0) + total_rounds: int = Field(default=10) + + aggregation_strategy: str = Field(default="fedavg") # e.g. 
fedavg, fedprox + min_participants_per_round: int = Field(default=2) + + reward_pool_amount: float = Field(default=0.0) # Total AITBC allocated to reward participants + + status: TrainingStatus = Field(default=TrainingStatus.INITIALIZED, index=True) + + global_model_cid: Optional[str] = Field(default=None) # Final aggregated model + + created_at: datetime = Field(default_factory=datetime.utcnow) + updated_at: datetime = Field(default_factory=datetime.utcnow) + + # Relationships + participants: List["TrainingParticipant"] = Relationship(back_populates="session") + rounds: List["TrainingRound"] = Relationship(back_populates="session") + +class TrainingParticipant(SQLModel, table=True): + """An agent participating in a federated learning session""" + __tablename__ = "training_participant" + + id: str = Field(default_factory=lambda: uuid4().hex, primary_key=True) + session_id: str = Field(foreign_key="federated_learning_session.id", index=True) + agent_id: str = Field(index=True) + + status: ParticipantStatus = Field(default=ParticipantStatus.JOINED, index=True) + data_samples_count: int = Field(default=0) # Claimed number of local samples used + compute_power_committed: float = Field(default=0.0) # TFLOPS + + reputation_score_at_join: float = Field(default=0.0) + earned_reward: float = Field(default=0.0) + + joined_at: datetime = Field(default_factory=datetime.utcnow) + updated_at: datetime = Field(default_factory=datetime.utcnow) + + # Relationships + session: FederatedLearningSession = Relationship(back_populates="participants") + +class TrainingRound(SQLModel, table=True): + """A specific round of federated learning""" + __tablename__ = "training_round" + + id: str = Field(default_factory=lambda: uuid4().hex, primary_key=True) + session_id: str = Field(foreign_key="federated_learning_session.id", index=True) + round_number: int = Field() + + status: str = Field(default="pending") # pending, active, aggregating, completed + + starting_model_cid: str = Field() # 
Global model weights at start of round + aggregated_model_cid: Optional[str] = Field(default=None) # Resulting weights after round + + metrics: Dict[str, float] = Field(default_factory=dict, sa_column=Column(JSON)) # e.g. loss, accuracy + + started_at: datetime = Field(default_factory=datetime.utcnow) + completed_at: Optional[datetime] = Field(default=None) + + # Relationships + session: FederatedLearningSession = Relationship(back_populates="rounds") + updates: List["LocalModelUpdate"] = Relationship(back_populates="round") + +class LocalModelUpdate(SQLModel, table=True): + """A local model update submitted by a participant for a specific round""" + __tablename__ = "local_model_update" + + id: str = Field(default_factory=lambda: uuid4().hex, primary_key=True) + round_id: str = Field(foreign_key="training_round.id", index=True) + participant_agent_id: str = Field(index=True) + + weights_cid: str = Field() # IPFS CID of the locally trained weights + zk_proof_hash: Optional[str] = Field(default=None) # Proof that training was executed correctly + + is_aggregated: bool = Field(default=False) + rejected_reason: Optional[str] = Field(default=None) # e.g. 
"outlier", "failed zk verification" + + submitted_at: datetime = Field(default_factory=datetime.utcnow) + + # Relationships + round: TrainingRound = Relationship(back_populates="updates") diff --git a/apps/coordinator-api/src/app/schemas/decentralized_memory.py b/apps/coordinator-api/src/app/schemas/decentralized_memory.py new file mode 100644 index 00000000..104f9909 --- /dev/null +++ b/apps/coordinator-api/src/app/schemas/decentralized_memory.py @@ -0,0 +1,29 @@ +from pydantic import BaseModel, Field +from typing import Optional, Dict, List +from ..domain.decentralized_memory import MemoryType, StorageStatus + +class MemoryNodeCreate(BaseModel): + agent_id: str + memory_type: MemoryType + is_encrypted: bool = True + metadata: Dict[str, str] = Field(default_factory=dict) + tags: List[str] = Field(default_factory=list) + +class MemoryNodeResponse(BaseModel): + id: str + agent_id: str + memory_type: MemoryType + cid: Optional[str] + size_bytes: Optional[int] + is_encrypted: bool + status: StorageStatus + metadata: Dict[str, str] + tags: List[str] + + class Config: + orm_mode = True + +class MemoryQueryRequest(BaseModel): + agent_id: str + memory_type: Optional[MemoryType] = None + tags: Optional[List[str]] = None diff --git a/apps/coordinator-api/src/app/schemas/federated_learning.py b/apps/coordinator-api/src/app/schemas/federated_learning.py new file mode 100644 index 00000000..4a8c034f --- /dev/null +++ b/apps/coordinator-api/src/app/schemas/federated_learning.py @@ -0,0 +1,37 @@ +from pydantic import BaseModel +from typing import Optional, Dict +from ..domain.federated_learning import TrainingStatus + +class FederatedSessionCreate(BaseModel): + initiator_agent_id: str + task_description: str + model_architecture_cid: str + initial_weights_cid: Optional[str] = None + target_participants: int = 3 + total_rounds: int = 10 + aggregation_strategy: str = "fedavg" + min_participants_per_round: int = 2 + reward_pool_amount: float = 0.0 + +class FederatedSessionResponse(BaseModel): + 
id: str + initiator_agent_id: str + task_description: str + target_participants: int + current_round: int + total_rounds: int + status: TrainingStatus + global_model_cid: Optional[str] + + class Config: + orm_mode = True + +class JoinSessionRequest(BaseModel): + agent_id: str + compute_power_committed: float + +class SubmitUpdateRequest(BaseModel): + agent_id: str + weights_cid: str + zk_proof_hash: Optional[str] = None + data_samples_count: int diff --git a/apps/coordinator-api/src/app/services/federated_learning.py b/apps/coordinator-api/src/app/services/federated_learning.py new file mode 100644 index 00000000..5106a6a9 --- /dev/null +++ b/apps/coordinator-api/src/app/services/federated_learning.py @@ -0,0 +1,208 @@ +""" +Federated Learning Service + +Service for managing cross-agent knowledge sharing and collaborative model training. +""" + +from __future__ import annotations + +import logging +from datetime import datetime +from typing import List, Optional + +from sqlmodel import Session, select +from fastapi import HTTPException + +from ..domain.federated_learning import ( + FederatedLearningSession, TrainingParticipant, TrainingRound, + LocalModelUpdate, TrainingStatus, ParticipantStatus +) +from ..schemas.federated_learning import ( + FederatedSessionCreate, JoinSessionRequest, SubmitUpdateRequest +) +from ..blockchain.contract_interactions import ContractInteractionService + +logger = logging.getLogger(__name__) + +class FederatedLearningService: + def __init__( + self, + session: Session, + contract_service: ContractInteractionService + ): + self.session = session + self.contract_service = contract_service + + async def create_session(self, request: FederatedSessionCreate) -> FederatedLearningSession: + """Create a new federated learning session""" + + session = FederatedLearningSession( + initiator_agent_id=request.initiator_agent_id, + task_description=request.task_description, + model_architecture_cid=request.model_architecture_cid, + 
initial_weights_cid=request.initial_weights_cid, + target_participants=request.target_participants, + total_rounds=request.total_rounds, + aggregation_strategy=request.aggregation_strategy, + min_participants_per_round=request.min_participants_per_round, + reward_pool_amount=request.reward_pool_amount, + status=TrainingStatus.GATHERING_PARTICIPANTS + ) + + self.session.add(session) + self.session.commit() + self.session.refresh(session) + + logger.info(f"Created Federated Learning Session {session.id} by {request.initiator_agent_id}") + return session + + async def join_session(self, session_id: str, request: JoinSessionRequest) -> TrainingParticipant: + """Allow an agent to join an active session""" + + fl_session = self.session.get(FederatedLearningSession, session_id) + if not fl_session: + raise HTTPException(status_code=404, detail="Session not found") + + if fl_session.status != TrainingStatus.GATHERING_PARTICIPANTS: + raise HTTPException(status_code=400, detail="Session is not currently accepting participants") + + # Check if already joined + existing = self.session.exec( + select(TrainingParticipant).where( + TrainingParticipant.session_id == session_id, + TrainingParticipant.agent_id == request.agent_id + ) + ).first() + + if existing: + raise HTTPException(status_code=400, detail="Agent already joined this session") + + # In reality, fetch reputation from blockchain/service + mock_reputation = 95.0 + + participant = TrainingParticipant( + session_id=session_id, + agent_id=request.agent_id, + compute_power_committed=request.compute_power_committed, + reputation_score_at_join=mock_reputation, + status=ParticipantStatus.JOINED + ) + + self.session.add(participant) + self.session.commit() + self.session.refresh(participant) + + # Check if we have enough participants to start + current_count = len(fl_session.participants) # commit() expired fl_session, so the lazy reload already includes the new participant + if current_count >= fl_session.target_participants: + await 
self._start_training(fl_session) + + return participant + + async def _start_training(self, fl_session: FederatedLearningSession): + """Internal method to transition from gathering to active training""" + fl_session.status = TrainingStatus.TRAINING + fl_session.current_round = 1 + + # Start Round 1 + round1 = TrainingRound( + session_id=fl_session.id, + round_number=1, + status="active", + starting_model_cid=fl_session.initial_weights_cid or fl_session.model_architecture_cid + ) + + self.session.add(round1) + self.session.commit() + logger.info(f"Started training for session {fl_session.id}, Round 1 active.") + + async def submit_local_update(self, session_id: str, round_id: str, request: SubmitUpdateRequest) -> LocalModelUpdate: + """Participant submits their locally trained model weights""" + + fl_session = self.session.get(FederatedLearningSession, session_id) + current_round = self.session.get(TrainingRound, round_id) + + if not fl_session or not current_round: + raise HTTPException(status_code=404, detail="Session or Round not found") + + if fl_session.status != TrainingStatus.TRAINING or current_round.status != "active": + raise HTTPException(status_code=400, detail="Round is not currently active") + + participant = self.session.exec( + select(TrainingParticipant).where( + TrainingParticipant.session_id == session_id, + TrainingParticipant.agent_id == request.agent_id + ) + ).first() + + if not participant: + raise HTTPException(status_code=403, detail="Agent is not a participant in this session") + + update = LocalModelUpdate( + round_id=round_id, + participant_agent_id=request.agent_id, + weights_cid=request.weights_cid, + zk_proof_hash=request.zk_proof_hash + ) + + participant.data_samples_count += request.data_samples_count + participant.status = ParticipantStatus.SUBMITTED + + self.session.add(update) + self.session.commit() + self.session.refresh(update) + + # Check if we should trigger aggregation + updates_count = len(current_round.updates) # commit() expired current_round, so the lazy reload already includes this update + if 
updates_count >= fl_session.min_participants_per_round: + # Note: In a real system, this might be triggered asynchronously via a Celery task + await self._aggregate_round(fl_session, current_round) + + return update + + async def _aggregate_round(self, fl_session: FederatedLearningSession, current_round: TrainingRound): + """Mock aggregation process""" + current_round.status = "aggregating" + fl_session.status = TrainingStatus.AGGREGATING + self.session.commit() + + # Mocking the actual heavy ML aggregation that would happen elsewhere + logger.info(f"Aggregating {len(current_round.updates)} updates for round {current_round.round_number}") + + # Assume successful aggregation creates a new global CID + import hashlib + import time + mock_hash = hashlib.md5(str(time.time()).encode()).hexdigest() + new_global_cid = f"bafy_aggregated_{mock_hash[:20]}" + + current_round.aggregated_model_cid = new_global_cid + current_round.status = "completed" + current_round.completed_at = datetime.utcnow() + current_round.metrics = {"loss": 0.5 - (current_round.round_number * 0.05), "accuracy": 0.7 + (current_round.round_number * 0.02)} + + if fl_session.current_round >= fl_session.total_rounds: + fl_session.status = TrainingStatus.COMPLETED + fl_session.global_model_cid = new_global_cid + logger.info(f"Federated Learning Session {fl_session.id} fully completed.") + # Here we would handle reward distribution via smart contracts + else: + fl_session.current_round += 1 + fl_session.status = TrainingStatus.TRAINING + + # Start next round + next_round = TrainingRound( + session_id=fl_session.id, + round_number=fl_session.current_round, + status="active", + starting_model_cid=new_global_cid + ) + self.session.add(next_round) + + # Reset participant statuses + for p in fl_session.participants: + if p.status == ParticipantStatus.SUBMITTED: + p.status = ParticipantStatus.TRAINING + + logger.info(f"Session {fl_session.id} progressing to Round {fl_session.current_round}") + + 
self.session.commit() diff --git a/apps/coordinator-api/src/app/services/ipfs_storage_adapter.py b/apps/coordinator-api/src/app/services/ipfs_storage_adapter.py new file mode 100644 index 00000000..aaa7d93c --- /dev/null +++ b/apps/coordinator-api/src/app/services/ipfs_storage_adapter.py @@ -0,0 +1,158 @@ +""" +IPFS Storage Adapter Service + +Service for offloading agent vector databases and knowledge graphs to IPFS/Filecoin. +""" + +from __future__ import annotations + +import logging +import json +import hashlib +from typing import List, Optional, Dict, Any + +from sqlmodel import Session, select +from fastapi import HTTPException + +from ..domain.decentralized_memory import AgentMemoryNode, MemoryType, StorageStatus +from ..schemas.decentralized_memory import MemoryNodeCreate +from ..blockchain.contract_interactions import ContractInteractionService + +# In a real environment, this would use a library like ipfshttpclient or a service like Pinata/Web3.Storage +# For this implementation, we will mock the interactions to demonstrate the architecture. + +logger = logging.getLogger(__name__) + +class IPFSAdapterService: + def __init__( + self, + session: Session, + contract_service: ContractInteractionService, + ipfs_gateway_url: str = "http://127.0.0.1:5001/api/v0", + pinning_service_token: Optional[str] = None + ): + self.session = session + self.contract_service = contract_service + self.ipfs_gateway_url = ipfs_gateway_url + self.pinning_service_token = pinning_service_token + + async def _mock_ipfs_upload(self, data: bytes) -> str: + """Mock function to simulate IPFS CID generation (v1 format CID simulation)""" + # Using sha256 to simulate content hashing + hash_val = hashlib.sha256(data).hexdigest() + # Mocking a CIDv1 base32 string format (bafy...) 
+ return f"bafybeig{hash_val[:40]}" + + async def store_memory( + self, + request: MemoryNodeCreate, + raw_data: bytes, + zk_proof_hash: Optional[str] = None + ) -> AgentMemoryNode: + """ + Upload raw memory data (e.g. serialized vector DB or JSON knowledge graph) to IPFS + and create a tracking record. + """ + # 1. Create initial record + node = AgentMemoryNode( + agent_id=request.agent_id, + memory_type=request.memory_type, + is_encrypted=request.is_encrypted, + metadata=request.metadata, + tags=request.tags, + size_bytes=len(raw_data), + status=StorageStatus.PENDING, + zk_proof_hash=zk_proof_hash + ) + self.session.add(node) + self.session.commit() + self.session.refresh(node) + + try: + # 2. Upload to IPFS (Mocked) + logger.info(f"Uploading {len(raw_data)} bytes to IPFS for agent {request.agent_id}") + cid = await self._mock_ipfs_upload(raw_data) + + node.cid = cid + node.status = StorageStatus.UPLOADED + + # 3. Pin to Filecoin/Pinning service (Mocked) + if self.pinning_service_token: + logger.info(f"Pinning CID {cid} to persistent storage") + node.status = StorageStatus.PINNED + + self.session.commit() + self.session.refresh(node) + return node + + except Exception as e: + logger.error(f"Failed to store memory node {node.id}: {str(e)}") + node.status = StorageStatus.FAILED + self.session.commit() + raise HTTPException(status_code=500, detail="Failed to upload data to decentralized storage") + + async def get_memory_nodes( + self, + agent_id: str, + memory_type: Optional[MemoryType] = None, + tags: Optional[List[str]] = None + ) -> List[AgentMemoryNode]: + """Retrieve metadata for an agent's stored memory nodes""" + query = select(AgentMemoryNode).where(AgentMemoryNode.agent_id == agent_id) + + if memory_type: + query = query.where(AgentMemoryNode.memory_type == memory_type) + + # Execute query and filter by tags in Python (since SQLite JSON JSON_CONTAINS is complex via pure SQLAlchemy without specific dialects) + results = self.session.exec(query).all() + + if 
tags and len(tags) > 0: + filtered_results = [] + for r in results: + if all(tag in r.tags for tag in tags): + filtered_results.append(r) + return filtered_results + + return results + + async def anchor_to_blockchain(self, node_id: str) -> AgentMemoryNode: + """ + Anchor a specific IPFS CID to the agent's smart contract profile to ensure data lineage. + """ + node = self.session.get(AgentMemoryNode, node_id) + if not node: + raise HTTPException(status_code=404, detail="Memory node not found") + + if not node.cid: + raise HTTPException(status_code=400, detail="Cannot anchor node without CID") + + if node.status == StorageStatus.ANCHORED: + return node + + try: + # Mocking the smart contract call to AgentMemory.sol + # tx_hash = await self.contract_service.anchor_agent_memory(node.agent_id, node.cid, node.zk_proof_hash) + tx_hash = "0x" + hashlib.md5(f"{node.id}{node.cid}".encode()).hexdigest() + + node.anchor_tx_hash = tx_hash + node.status = StorageStatus.ANCHORED + self.session.commit() + self.session.refresh(node) + + logger.info(f"Anchored memory {node_id} (CID: {node.cid}) to blockchain. 
Tx: {tx_hash}") + return node + + except Exception as e: + logger.error(f"Failed to anchor memory node {node_id}: {str(e)}") + raise HTTPException(status_code=500, detail="Failed to anchor CID to blockchain") + + async def retrieve_memory(self, node_id: str) -> bytes: + """Retrieve the raw data from IPFS""" + node = self.session.get(AgentMemoryNode, node_id) + if not node or not node.cid: + raise HTTPException(status_code=404, detail="Memory node or CID not found") + + # Mocking retrieval + logger.info(f"Retrieving CID {node.cid} from IPFS network") + mock_data = b"{\"mock\": \"data\", \"info\": \"This represents decrypted vector db or KG data\"}" + return mock_data diff --git a/apps/coordinator-api/src/app/services/zk_memory_verification.py b/apps/coordinator-api/src/app/services/zk_memory_verification.py new file mode 100644 index 00000000..e7378c61 --- /dev/null +++ b/apps/coordinator-api/src/app/services/zk_memory_verification.py @@ -0,0 +1,115 @@ +""" +ZK-Proof Memory Verification Service + +Service for generating and verifying Zero-Knowledge proofs for decentralized memory retrieval. +Ensures that data retrieved from IPFS matches the anchored state on the blockchain +without revealing the contents of the data itself. 
+""" + +from __future__ import annotations + +import logging +import hashlib +import json +from typing import Dict, Any, Optional, Tuple + +from fastapi import HTTPException +from sqlmodel import Session + +from ..domain.decentralized_memory import AgentMemoryNode +from ..blockchain.contract_interactions import ContractInteractionService + +logger = logging.getLogger(__name__) + +class ZKMemoryVerificationService: + def __init__( + self, + session: Session, + contract_service: ContractInteractionService + ): + self.session = session + self.contract_service = contract_service + + async def generate_memory_proof(self, node_id: str, raw_data: bytes) -> Tuple[str, str]: + """ + Generate a Zero-Knowledge proof that the given raw data corresponds to + the structural integrity and properties required by the system, + and compute its hash for on-chain anchoring. + + Returns: + Tuple[str, str]: (zk_proof_payload, zk_proof_hash) + """ + node = self.session.get(AgentMemoryNode, node_id) + if not node: + raise HTTPException(status_code=404, detail="Memory node not found") + + # In a real ZK system (like snarkjs or circom), we would: + # 1. Compile the raw data into circuit inputs. + # 2. Run the witness generator. + # 3. Generate the proof. 
+ + # Mocking ZK Proof generation + logger.info(f"Generating ZK proof for memory node {node_id}") + + # We simulate a proof by creating a structured JSON string + data_hash = hashlib.sha256(raw_data).hexdigest() + + mock_proof = { + "pi_a": ["mock_pi_a_1", "mock_pi_a_2", "mock_pi_a_3"], + "pi_b": [["mock_pi_b_1", "mock_pi_b_2"], ["mock_pi_b_3", "mock_pi_b_4"]], + "pi_c": ["mock_pi_c_1", "mock_pi_c_2", "mock_pi_c_3"], + "protocol": "groth16", + "curve": "bn128", + "publicSignals": [data_hash, node.agent_id] + } + + proof_payload = json.dumps(mock_proof) + + # The proof hash is what gets stored on-chain + proof_hash = "0x" + hashlib.sha256(proof_payload.encode()).hexdigest() + + return proof_payload, proof_hash + + async def verify_retrieved_memory( + self, + node_id: str, + retrieved_data: bytes, + proof_payload: str + ) -> bool: + """ + Verify that the retrieved data matches the on-chain anchored ZK proof. + """ + node = self.session.get(AgentMemoryNode, node_id) + if not node: + raise HTTPException(status_code=404, detail="Memory node not found") + + if not node.zk_proof_hash: + raise HTTPException(status_code=400, detail="Memory node does not have an anchored ZK proof") + + logger.info(f"Verifying ZK proof for retrieved memory {node_id}") + + try: + # 1. Verify the provided proof payload matches the on-chain hash + calculated_hash = "0x" + hashlib.sha256(proof_payload.encode()).hexdigest() + if calculated_hash != node.zk_proof_hash: + logger.error("Proof payload hash does not match anchored hash") + return False + + # 2. 
Verify the proof against the retrieved data (Circuit verification) + # In a real system, we might verify this locally or query the smart contract + + # Local mock verification + proof_data = json.loads(proof_payload) + data_hash = hashlib.sha256(retrieved_data).hexdigest() + + # Check if the public signals match the data we retrieved + if proof_data.get("publicSignals", [])[0] != data_hash: + logger.error("Public signals in proof do not match retrieved data hash") + return False + + logger.info("ZK Memory Verification Successful") + return True + + except Exception as e: + logger.error(f"Error during ZK memory verification: {str(e)}") + return False diff --git a/apps/coordinator-api/tests/test_federated_learning.py b/apps/coordinator-api/tests/test_federated_learning.py new file mode 100644 index 00000000..fbf917da --- /dev/null +++ b/apps/coordinator-api/tests/test_federated_learning.py @@ -0,0 +1,129 @@ +import pytest +from unittest.mock import AsyncMock + +from sqlmodel import Session, create_engine, SQLModel +from sqlmodel.pool import StaticPool +from fastapi import HTTPException + +from app.services.federated_learning import FederatedLearningService +from app.domain.federated_learning import TrainingStatus, ParticipantStatus +from app.schemas.federated_learning import FederatedSessionCreate, JoinSessionRequest, SubmitUpdateRequest + +@pytest.fixture +def test_db(): + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + SQLModel.metadata.create_all(engine) + session = Session(engine) + yield session + session.close() + +@pytest.fixture +def mock_contract_service(): + return AsyncMock() + +@pytest.fixture +def fl_service(test_db, mock_contract_service): + return FederatedLearningService( + session=test_db, + contract_service=mock_contract_service + ) + +@pytest.mark.asyncio +async def test_create_session(fl_service): + req = FederatedSessionCreate( + initiator_agent_id="agent-admin", + 
task_description="Train LLM on financial data", + model_architecture_cid="bafy_arch_123", + target_participants=2, + total_rounds=2, + min_participants_per_round=2 + ) + + session = await fl_service.create_session(req) + assert session.status == TrainingStatus.GATHERING_PARTICIPANTS + assert session.initiator_agent_id == "agent-admin" + +@pytest.mark.asyncio +async def test_join_session_and_start(fl_service): + req = FederatedSessionCreate( + initiator_agent_id="agent-admin", + task_description="Train LLM on financial data", + model_architecture_cid="bafy_arch_123", + target_participants=2, + total_rounds=2, + min_participants_per_round=2 + ) + session = await fl_service.create_session(req) + + # Agent 1 joins + p1 = await fl_service.join_session( + session.id, + JoinSessionRequest(agent_id="agent-1", compute_power_committed=10.0) + ) + assert p1.status == ParticipantStatus.JOINED + assert session.status == TrainingStatus.GATHERING_PARTICIPANTS + + # Agent 2 joins, triggers start + p2 = await fl_service.join_session( + session.id, + JoinSessionRequest(agent_id="agent-2", compute_power_committed=15.0) + ) + + # Needs refresh + fl_service.session.refresh(session) + + assert session.status == TrainingStatus.TRAINING + assert session.current_round == 1 + assert len(session.rounds) == 1 + assert session.rounds[0].status == "active" + +@pytest.mark.asyncio +async def test_submit_updates_and_aggregate(fl_service): + # Setup + req = FederatedSessionCreate( + initiator_agent_id="agent-admin", + task_description="Train LLM on financial data", + model_architecture_cid="bafy_arch_123", + target_participants=2, + total_rounds=1, # Only 1 round for quick test + min_participants_per_round=2 + ) + session = await fl_service.create_session(req) + await fl_service.join_session(session.id, JoinSessionRequest(agent_id="agent-1", compute_power_committed=10.0)) + await fl_service.join_session(session.id, JoinSessionRequest(agent_id="agent-2", compute_power_committed=15.0)) + + 
fl_service.session.refresh(session) + round1 = session.rounds[0] + + # Agent 1 submits + u1 = await fl_service.submit_local_update( + session.id, + round1.id, + SubmitUpdateRequest(agent_id="agent-1", weights_cid="bafy_w1", data_samples_count=1000) + ) + assert u1.weights_cid == "bafy_w1" + + fl_service.session.refresh(session) + fl_service.session.refresh(round1) + + # Not aggregated yet + assert round1.status == "active" + + # Agent 2 submits, triggers aggregation and completion since total_rounds=1 + u2 = await fl_service.submit_local_update( + session.id, + round1.id, + SubmitUpdateRequest(agent_id="agent-2", weights_cid="bafy_w2", data_samples_count=1500) + ) + + fl_service.session.refresh(session) + fl_service.session.refresh(round1) + + assert round1.status == "completed" + assert session.status == TrainingStatus.COMPLETED + assert session.global_model_cid is not None + assert session.global_model_cid.startswith("bafy_aggregated_") diff --git a/apps/coordinator-api/tests/test_ipfs_storage_adapter.py b/apps/coordinator-api/tests/test_ipfs_storage_adapter.py new file mode 100644 index 00000000..4d72c1ae --- /dev/null +++ b/apps/coordinator-api/tests/test_ipfs_storage_adapter.py @@ -0,0 +1,108 @@ +import pytest +from unittest.mock import AsyncMock + +from sqlmodel import Session, create_engine, SQLModel +from sqlmodel.pool import StaticPool +from fastapi import HTTPException + +from app.services.ipfs_storage_adapter import IPFSAdapterService +from app.domain.decentralized_memory import MemoryType, StorageStatus +from app.schemas.decentralized_memory import MemoryNodeCreate + +@pytest.fixture +def test_db(): + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + SQLModel.metadata.create_all(engine) + session = Session(engine) + yield session + session.close() + +@pytest.fixture +def mock_contract_service(): + return AsyncMock() + +@pytest.fixture +def storage_service(test_db, 
mock_contract_service):
+    """IPFSAdapterService wired to the in-memory DB, a mocked contract
+    service, and a dummy pinning token."""
+    return IPFSAdapterService(
+        session=test_db,
+        contract_service=mock_contract_service,
+        pinning_service_token="mock_token"
+    )
+
+@pytest.mark.asyncio
+async def test_store_memory(storage_service):
+    """store_memory should pin the payload and return a populated node
+    (bafy-prefixed CID, size, PINNED status, and the supplied zk proof hash)."""
+    request = MemoryNodeCreate(
+        agent_id="agent-007",
+        memory_type=MemoryType.VECTOR_DB,
+        tags=["training", "batch1"]
+    )
+
+    raw_data = b"mock_vector_embeddings_data"
+
+    node = await storage_service.store_memory(request, raw_data, zk_proof_hash="0xabc123")
+
+    assert node.agent_id == "agent-007"
+    assert node.memory_type == MemoryType.VECTOR_DB
+    assert node.cid is not None
+    assert node.cid.startswith("bafy")
+    assert node.size_bytes == len(raw_data)
+    assert node.status == StorageStatus.PINNED
+    assert node.zk_proof_hash == "0xabc123"
+
+@pytest.mark.asyncio
+async def test_get_memory_nodes(storage_service):
+    """get_memory_nodes filters by agent id, memory_type, and tags."""
+    # Store multiple
+    await storage_service.store_memory(
+        MemoryNodeCreate(agent_id="agent-007", memory_type=MemoryType.VECTOR_DB, tags=["v1"]),
+        b"data1"
+    )
+    await storage_service.store_memory(
+        MemoryNodeCreate(agent_id="agent-007", memory_type=MemoryType.KNOWLEDGE_GRAPH, tags=["v1"]),
+        b"data2"
+    )
+    await storage_service.store_memory(
+        MemoryNodeCreate(agent_id="agent-008", memory_type=MemoryType.VECTOR_DB),
+        b"data3"
+    )
+
+    # Get all for agent-007
+    nodes = await storage_service.get_memory_nodes("agent-007")
+    assert len(nodes) == 2
+
+    # Filter by type
+    nodes_kg = await storage_service.get_memory_nodes("agent-007", memory_type=MemoryType.KNOWLEDGE_GRAPH)
+    assert len(nodes_kg) == 1
+    assert nodes_kg[0].memory_type == MemoryType.KNOWLEDGE_GRAPH
+
+    # Filter by tag
+    nodes_tag = await storage_service.get_memory_nodes("agent-007", tags=["v1"])
+    assert len(nodes_tag) == 2
+
+@pytest.mark.asyncio
+async def test_anchor_to_blockchain(storage_service):
+    """Anchoring a stored node should set ANCHORED status and a tx hash."""
+    node = await storage_service.store_memory(
+        MemoryNodeCreate(agent_id="agent-007", memory_type=MemoryType.VECTOR_DB),
+        b"data1"
+    )
+
+    assert node.anchor_tx_hash is None
+
+    anchored_node = await storage_service.anchor_to_blockchain(node.id)
+
+    assert anchored_node.status == StorageStatus.ANCHORED
+    assert anchored_node.anchor_tx_hash is not None
+
+@pytest.mark.asyncio
+async def test_retrieve_memory(storage_service):
+    """retrieve_memory returns bytes for a stored node."""
+    node = await storage_service.store_memory(
+        MemoryNodeCreate(agent_id="agent-007", memory_type=MemoryType.VECTOR_DB),
+        b"data1"
+    )
+
+    data = await storage_service.retrieve_memory(node.id)
+    assert isinstance(data, bytes)
+    # NOTE(review): the stored payload was b"data1"; expecting b"mock" in the
+    # result implies retrieve_memory returns stubbed content — confirm intended.
+    assert b"mock" in data
diff --git a/apps/coordinator-api/tests/test_zk_memory_verification.py b/apps/coordinator-api/tests/test_zk_memory_verification.py
new file mode 100644
index 00000000..5e1acae0
--- /dev/null
+++ b/apps/coordinator-api/tests/test_zk_memory_verification.py
@@ -0,0 +1,92 @@
+import pytest
+from unittest.mock import AsyncMock
+
+from sqlmodel import Session, create_engine, SQLModel
+from sqlmodel.pool import StaticPool
+
+from app.services.zk_memory_verification import ZKMemoryVerificationService
+from app.domain.decentralized_memory import AgentMemoryNode, MemoryType
+
+@pytest.fixture
+def test_db():
+    """In-memory SQLite session; StaticPool keeps a single shared connection."""
+    engine = create_engine(
+        "sqlite:///:memory:",
+        connect_args={"check_same_thread": False},
+        poolclass=StaticPool,
+    )
+    SQLModel.metadata.create_all(engine)
+    session = Session(engine)
+    yield session
+    session.close()
+
+@pytest.fixture
+def mock_contract_service():
+    """Async mock standing in for the contract-service dependency."""
+    return AsyncMock()
+
+@pytest.fixture
+def zk_service(test_db, mock_contract_service):
+    """ZKMemoryVerificationService bound to the in-memory DB."""
+    return ZKMemoryVerificationService(
+        session=test_db,
+        contract_service=mock_contract_service
+    )
+
+@pytest.mark.asyncio
+async def test_generate_memory_proof(zk_service, test_db):
+    """generate_memory_proof returns a (payload, hash) pair: a 0x-prefixed
+    hash and a payload mentioning the groth16 scheme."""
+    node = AgentMemoryNode(
+        agent_id="agent-zk",
+        memory_type=MemoryType.VECTOR_DB
+    )
+    test_db.add(node)
+    test_db.commit()
+    test_db.refresh(node)
+
+    raw_data = b"secret_vector_data"
+
+    proof_payload, proof_hash = await zk_service.generate_memory_proof(node.id, raw_data)
+
+    assert proof_payload is not None
+    assert
proof_hash.startswith("0x")
+    assert "groth16" in proof_payload
+
+@pytest.mark.asyncio
+async def test_verify_retrieved_memory_success(zk_service, test_db):
+    """A proof generated for the exact stored bytes verifies successfully."""
+    node = AgentMemoryNode(
+        agent_id="agent-zk",
+        memory_type=MemoryType.VECTOR_DB
+    )
+    test_db.add(node)
+    test_db.commit()
+    test_db.refresh(node)
+
+    raw_data = b"secret_vector_data"
+    proof_payload, proof_hash = await zk_service.generate_memory_proof(node.id, raw_data)
+
+    # Simulate anchoring
+    node.zk_proof_hash = proof_hash
+    test_db.commit()
+
+    # Verify
+    is_valid = await zk_service.verify_retrieved_memory(node.id, raw_data, proof_payload)
+    assert is_valid is True
+
+@pytest.mark.asyncio
+async def test_verify_retrieved_memory_tampered_data(zk_service, test_db):
+    """Any change to the retrieved bytes must make verification fail."""
+    node = AgentMemoryNode(
+        agent_id="agent-zk",
+        memory_type=MemoryType.VECTOR_DB
+    )
+    test_db.add(node)
+    test_db.commit()
+    test_db.refresh(node)
+
+    raw_data = b"secret_vector_data"
+    proof_payload, proof_hash = await zk_service.generate_memory_proof(node.id, raw_data)
+
+    node.zk_proof_hash = proof_hash
+    test_db.commit()
+
+    # Tamper with data
+    tampered_data = b"secret_vector_data_modified"
+
+    is_valid = await zk_service.verify_retrieved_memory(node.id, tampered_data, proof_payload)
+    assert is_valid is False