feat: add phase 2 decentralized memory python services and tests
This commit is contained in:
56
apps/coordinator-api/src/app/domain/decentralized_memory.py
Normal file
56
apps/coordinator-api/src/app/domain/decentralized_memory.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""
|
||||
Decentralized Memory Domain Models
|
||||
|
||||
Domain models for managing agent memory and knowledge graphs on IPFS/Filecoin.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Dict, Optional, List
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import Column, JSON
|
||||
from sqlmodel import Field, SQLModel, Relationship
|
||||
|
||||
class MemoryType(str, Enum):
|
||||
VECTOR_DB = "vector_db"
|
||||
KNOWLEDGE_GRAPH = "knowledge_graph"
|
||||
POLICY_WEIGHTS = "policy_weights"
|
||||
EPISODIC = "episodic"
|
||||
|
||||
class StorageStatus(str, Enum):
|
||||
PENDING = "pending" # Upload to IPFS pending
|
||||
UPLOADED = "uploaded" # Available on IPFS
|
||||
PINNED = "pinned" # Pinned on Filecoin/Pinata
|
||||
ANCHORED = "anchored" # CID written to blockchain
|
||||
FAILED = "failed" # Upload failed
|
||||
|
||||
class AgentMemoryNode(SQLModel, table=True):
|
||||
"""Represents a chunk of memory or knowledge stored on decentralized storage"""
|
||||
__tablename__ = "agent_memory_node"
|
||||
|
||||
id: str = Field(default_factory=lambda: uuid4().hex, primary_key=True)
|
||||
agent_id: str = Field(index=True)
|
||||
memory_type: MemoryType = Field(index=True)
|
||||
|
||||
# Decentralized Storage Identifiers
|
||||
cid: Optional[str] = Field(default=None, index=True) # IPFS Content Identifier
|
||||
size_bytes: Optional[int] = Field(default=None)
|
||||
|
||||
# Encryption and Security
|
||||
is_encrypted: bool = Field(default=True)
|
||||
encryption_key_id: Optional[str] = Field(default=None) # Reference to KMS or Lit Protocol
|
||||
zk_proof_hash: Optional[str] = Field(default=None) # Hash of the ZK proof verifying content validity
|
||||
|
||||
status: StorageStatus = Field(default=StorageStatus.PENDING, index=True)
|
||||
|
||||
metadata: Dict[str, str] = Field(default_factory=dict, sa_column=Column(JSON))
|
||||
tags: List[str] = Field(default_factory=list, sa_column=Column(JSON))
|
||||
|
||||
# Blockchain Anchoring
|
||||
anchor_tx_hash: Optional[str] = Field(default=None)
|
||||
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
122
apps/coordinator-api/src/app/domain/federated_learning.py
Normal file
122
apps/coordinator-api/src/app/domain/federated_learning.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""
|
||||
Federated Learning Domain Models
|
||||
|
||||
Domain models for managing cross-agent knowledge sharing and collaborative model training.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import Column, JSON
|
||||
from sqlmodel import Field, SQLModel, Relationship
|
||||
|
||||
class TrainingStatus(str, Enum):
|
||||
INITIALIZED = "initiated"
|
||||
GATHERING_PARTICIPANTS = "gathering_participants"
|
||||
TRAINING = "training"
|
||||
AGGREGATING = "aggregating"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
class ParticipantStatus(str, Enum):
|
||||
INVITED = "invited"
|
||||
JOINED = "joined"
|
||||
TRAINING = "training"
|
||||
SUBMITTED = "submitted"
|
||||
DROPPED = "dropped"
|
||||
|
||||
class FederatedLearningSession(SQLModel, table=True):
|
||||
"""Represents a collaborative training session across multiple agents"""
|
||||
__tablename__ = "federated_learning_session"
|
||||
|
||||
id: str = Field(default_factory=lambda: uuid4().hex, primary_key=True)
|
||||
initiator_agent_id: str = Field(index=True)
|
||||
task_description: str = Field()
|
||||
model_architecture_cid: str = Field() # IPFS CID pointing to model structure definition
|
||||
initial_weights_cid: Optional[str] = Field(default=None) # Optional starting point
|
||||
|
||||
target_participants: int = Field(default=3)
|
||||
current_round: int = Field(default=0)
|
||||
total_rounds: int = Field(default=10)
|
||||
|
||||
aggregation_strategy: str = Field(default="fedavg") # e.g. fedavg, fedprox
|
||||
min_participants_per_round: int = Field(default=2)
|
||||
|
||||
reward_pool_amount: float = Field(default=0.0) # Total AITBC allocated to reward participants
|
||||
|
||||
status: TrainingStatus = Field(default=TrainingStatus.INITIALIZED, index=True)
|
||||
|
||||
global_model_cid: Optional[str] = Field(default=None) # Final aggregated model
|
||||
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
participants: List["TrainingParticipant"] = Relationship(back_populates="session")
|
||||
rounds: List["TrainingRound"] = Relationship(back_populates="session")
|
||||
|
||||
class TrainingParticipant(SQLModel, table=True):
|
||||
"""An agent participating in a federated learning session"""
|
||||
__tablename__ = "training_participant"
|
||||
|
||||
id: str = Field(default_factory=lambda: uuid4().hex, primary_key=True)
|
||||
session_id: str = Field(foreign_key="federated_learning_session.id", index=True)
|
||||
agent_id: str = Field(index=True)
|
||||
|
||||
status: ParticipantStatus = Field(default=ParticipantStatus.JOINED, index=True)
|
||||
data_samples_count: int = Field(default=0) # Claimed number of local samples used
|
||||
compute_power_committed: float = Field(default=0.0) # TFLOPS
|
||||
|
||||
reputation_score_at_join: float = Field(default=0.0)
|
||||
earned_reward: float = Field(default=0.0)
|
||||
|
||||
joined_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
session: FederatedLearningSession = Relationship(back_populates="participants")
|
||||
|
||||
class TrainingRound(SQLModel, table=True):
|
||||
"""A specific round of federated learning"""
|
||||
__tablename__ = "training_round"
|
||||
|
||||
id: str = Field(default_factory=lambda: uuid4().hex, primary_key=True)
|
||||
session_id: str = Field(foreign_key="federated_learning_session.id", index=True)
|
||||
round_number: int = Field()
|
||||
|
||||
status: str = Field(default="pending") # pending, active, aggregating, completed
|
||||
|
||||
starting_model_cid: str = Field() # Global model weights at start of round
|
||||
aggregated_model_cid: Optional[str] = Field(default=None) # Resulting weights after round
|
||||
|
||||
metrics: Dict[str, float] = Field(default_factory=dict, sa_column=Column(JSON)) # e.g. loss, accuracy
|
||||
|
||||
started_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
completed_at: Optional[datetime] = Field(default=None)
|
||||
|
||||
# Relationships
|
||||
session: FederatedLearningSession = Relationship(back_populates="rounds")
|
||||
updates: List["LocalModelUpdate"] = Relationship(back_populates="round")
|
||||
|
||||
class LocalModelUpdate(SQLModel, table=True):
|
||||
"""A local model update submitted by a participant for a specific round"""
|
||||
__tablename__ = "local_model_update"
|
||||
|
||||
id: str = Field(default_factory=lambda: uuid4().hex, primary_key=True)
|
||||
round_id: str = Field(foreign_key="training_round.id", index=True)
|
||||
participant_agent_id: str = Field(index=True)
|
||||
|
||||
weights_cid: str = Field() # IPFS CID of the locally trained weights
|
||||
zk_proof_hash: Optional[str] = Field(default=None) # Proof that training was executed correctly
|
||||
|
||||
is_aggregated: bool = Field(default=False)
|
||||
rejected_reason: Optional[str] = Field(default=None) # e.g. "outlier", "failed zk verification"
|
||||
|
||||
submitted_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
round: TrainingRound = Relationship(back_populates="updates")
|
||||
29
apps/coordinator-api/src/app/schemas/decentralized_memory.py
Normal file
29
apps/coordinator-api/src/app/schemas/decentralized_memory.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, Dict, List
|
||||
from .decentralized_memory import MemoryType, StorageStatus
|
||||
|
||||
class MemoryNodeCreate(BaseModel):
|
||||
agent_id: str
|
||||
memory_type: MemoryType
|
||||
is_encrypted: bool = True
|
||||
metadata: Dict[str, str] = Field(default_factory=dict)
|
||||
tags: List[str] = Field(default_factory=list)
|
||||
|
||||
class MemoryNodeResponse(BaseModel):
|
||||
id: str
|
||||
agent_id: str
|
||||
memory_type: MemoryType
|
||||
cid: Optional[str]
|
||||
size_bytes: Optional[int]
|
||||
is_encrypted: bool
|
||||
status: StorageStatus
|
||||
metadata: Dict[str, str]
|
||||
tags: List[str]
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
class MemoryQueryRequest(BaseModel):
|
||||
agent_id: str
|
||||
memory_type: Optional[MemoryType] = None
|
||||
tags: Optional[List[str]] = None
|
||||
37
apps/coordinator-api/src/app/schemas/federated_learning.py
Normal file
37
apps/coordinator-api/src/app/schemas/federated_learning.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, Dict
|
||||
from .federated_learning import TrainingStatus
|
||||
|
||||
class FederatedSessionCreate(BaseModel):
|
||||
initiator_agent_id: str
|
||||
task_description: str
|
||||
model_architecture_cid: str
|
||||
initial_weights_cid: Optional[str] = None
|
||||
target_participants: int = 3
|
||||
total_rounds: int = 10
|
||||
aggregation_strategy: str = "fedavg"
|
||||
min_participants_per_round: int = 2
|
||||
reward_pool_amount: float = 0.0
|
||||
|
||||
class FederatedSessionResponse(BaseModel):
|
||||
id: str
|
||||
initiator_agent_id: str
|
||||
task_description: str
|
||||
target_participants: int
|
||||
current_round: int
|
||||
total_rounds: int
|
||||
status: TrainingStatus
|
||||
global_model_cid: Optional[str]
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
class JoinSessionRequest(BaseModel):
|
||||
agent_id: str
|
||||
compute_power_committed: float
|
||||
|
||||
class SubmitUpdateRequest(BaseModel):
|
||||
agent_id: str
|
||||
weights_cid: str
|
||||
zk_proof_hash: Optional[str] = None
|
||||
data_samples_count: int
|
||||
208
apps/coordinator-api/src/app/services/federated_learning.py
Normal file
208
apps/coordinator-api/src/app/services/federated_learning.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""
|
||||
Federated Learning Service
|
||||
|
||||
Service for managing cross-agent knowledge sharing and collaborative model training.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlmodel import Session, select
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ..domain.federated_learning import (
|
||||
FederatedLearningSession, TrainingParticipant, TrainingRound,
|
||||
LocalModelUpdate, TrainingStatus, ParticipantStatus
|
||||
)
|
||||
from ..schemas.federated_learning import (
|
||||
FederatedSessionCreate, JoinSessionRequest, SubmitUpdateRequest
|
||||
)
|
||||
from ..blockchain.contract_interactions import ContractInteractionService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class FederatedLearningService:
|
||||
def __init__(
|
||||
self,
|
||||
session: Session,
|
||||
contract_service: ContractInteractionService
|
||||
):
|
||||
self.session = session
|
||||
self.contract_service = contract_service
|
||||
|
||||
async def create_session(self, request: FederatedSessionCreate) -> FederatedLearningSession:
|
||||
"""Create a new federated learning session"""
|
||||
|
||||
session = FederatedLearningSession(
|
||||
initiator_agent_id=request.initiator_agent_id,
|
||||
task_description=request.task_description,
|
||||
model_architecture_cid=request.model_architecture_cid,
|
||||
initial_weights_cid=request.initial_weights_cid,
|
||||
target_participants=request.target_participants,
|
||||
total_rounds=request.total_rounds,
|
||||
aggregation_strategy=request.aggregation_strategy,
|
||||
min_participants_per_round=request.min_participants_per_round,
|
||||
reward_pool_amount=request.reward_pool_amount,
|
||||
status=TrainingStatus.GATHERING_PARTICIPANTS
|
||||
)
|
||||
|
||||
self.session.add(session)
|
||||
self.session.commit()
|
||||
self.session.refresh(session)
|
||||
|
||||
logger.info(f"Created Federated Learning Session {session.id} by {request.initiator_agent_id}")
|
||||
return session
|
||||
|
||||
async def join_session(self, session_id: str, request: JoinSessionRequest) -> TrainingParticipant:
|
||||
"""Allow an agent to join an active session"""
|
||||
|
||||
fl_session = self.session.get(FederatedLearningSession, session_id)
|
||||
if not fl_session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
if fl_session.status != TrainingStatus.GATHERING_PARTICIPANTS:
|
||||
raise HTTPException(status_code=400, detail="Session is not currently accepting participants")
|
||||
|
||||
# Check if already joined
|
||||
existing = self.session.exec(
|
||||
select(TrainingParticipant).where(
|
||||
TrainingParticipant.session_id == session_id,
|
||||
TrainingParticipant.agent_id == request.agent_id
|
||||
)
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="Agent already joined this session")
|
||||
|
||||
# In reality, fetch reputation from blockchain/service
|
||||
mock_reputation = 95.0
|
||||
|
||||
participant = TrainingParticipant(
|
||||
session_id=session_id,
|
||||
agent_id=request.agent_id,
|
||||
compute_power_committed=request.compute_power_committed,
|
||||
reputation_score_at_join=mock_reputation,
|
||||
status=ParticipantStatus.JOINED
|
||||
)
|
||||
|
||||
self.session.add(participant)
|
||||
self.session.commit()
|
||||
self.session.refresh(participant)
|
||||
|
||||
# Check if we have enough participants to start
|
||||
current_count = len(fl_session.participants) + 1 # +1 for the newly added but not refreshed one
|
||||
if current_count >= fl_session.target_participants:
|
||||
await self._start_training(fl_session)
|
||||
|
||||
return participant
|
||||
|
||||
async def _start_training(self, fl_session: FederatedLearningSession):
|
||||
"""Internal method to transition from gathering to active training"""
|
||||
fl_session.status = TrainingStatus.TRAINING
|
||||
fl_session.current_round = 1
|
||||
|
||||
# Start Round 1
|
||||
round1 = TrainingRound(
|
||||
session_id=fl_session.id,
|
||||
round_number=1,
|
||||
status="active",
|
||||
starting_model_cid=fl_session.initial_weights_cid or fl_session.model_architecture_cid
|
||||
)
|
||||
|
||||
self.session.add(round1)
|
||||
self.session.commit()
|
||||
logger.info(f"Started training for session {fl_session.id}, Round 1 active.")
|
||||
|
||||
async def submit_local_update(self, session_id: str, round_id: str, request: SubmitUpdateRequest) -> LocalModelUpdate:
|
||||
"""Participant submits their locally trained model weights"""
|
||||
|
||||
fl_session = self.session.get(FederatedLearningSession, session_id)
|
||||
current_round = self.session.get(TrainingRound, round_id)
|
||||
|
||||
if not fl_session or not current_round:
|
||||
raise HTTPException(status_code=404, detail="Session or Round not found")
|
||||
|
||||
if fl_session.status != TrainingStatus.TRAINING or current_round.status != "active":
|
||||
raise HTTPException(status_code=400, detail="Round is not currently active")
|
||||
|
||||
participant = self.session.exec(
|
||||
select(TrainingParticipant).where(
|
||||
TrainingParticipant.session_id == session_id,
|
||||
TrainingParticipant.agent_id == request.agent_id
|
||||
)
|
||||
).first()
|
||||
|
||||
if not participant:
|
||||
raise HTTPException(status_code=403, detail="Agent is not a participant in this session")
|
||||
|
||||
update = LocalModelUpdate(
|
||||
round_id=round_id,
|
||||
participant_agent_id=request.agent_id,
|
||||
weights_cid=request.weights_cid,
|
||||
zk_proof_hash=request.zk_proof_hash
|
||||
)
|
||||
|
||||
participant.data_samples_count += request.data_samples_count
|
||||
participant.status = ParticipantStatus.SUBMITTED
|
||||
|
||||
self.session.add(update)
|
||||
self.session.commit()
|
||||
self.session.refresh(update)
|
||||
|
||||
# Check if we should trigger aggregation
|
||||
updates_count = len(current_round.updates) + 1
|
||||
if updates_count >= fl_session.min_participants_per_round:
|
||||
# Note: In a real system, this might be triggered asynchronously via a Celery task
|
||||
await self._aggregate_round(fl_session, current_round)
|
||||
|
||||
return update
|
||||
|
||||
async def _aggregate_round(self, fl_session: FederatedLearningSession, current_round: TrainingRound):
|
||||
"""Mock aggregation process"""
|
||||
current_round.status = "aggregating"
|
||||
fl_session.status = TrainingStatus.AGGREGATING
|
||||
self.session.commit()
|
||||
|
||||
# Mocking the actual heavy ML aggregation that would happen elsewhere
|
||||
logger.info(f"Aggregating {len(current_round.updates)} updates for round {current_round.round_number}")
|
||||
|
||||
# Assume successful aggregation creates a new global CID
|
||||
import hashlib
|
||||
import time
|
||||
mock_hash = hashlib.md5(str(time.time()).encode()).hexdigest()
|
||||
new_global_cid = f"bafy_aggregated_{mock_hash[:20]}"
|
||||
|
||||
current_round.aggregated_model_cid = new_global_cid
|
||||
current_round.status = "completed"
|
||||
current_round.completed_at = datetime.utcnow()
|
||||
current_round.metrics = {"loss": 0.5 - (current_round.round_number * 0.05), "accuracy": 0.7 + (current_round.round_number * 0.02)}
|
||||
|
||||
if fl_session.current_round >= fl_session.total_rounds:
|
||||
fl_session.status = TrainingStatus.COMPLETED
|
||||
fl_session.global_model_cid = new_global_cid
|
||||
logger.info(f"Federated Learning Session {fl_session.id} fully completed.")
|
||||
# Here we would handle reward distribution via smart contracts
|
||||
else:
|
||||
fl_session.current_round += 1
|
||||
fl_session.status = TrainingStatus.TRAINING
|
||||
|
||||
# Start next round
|
||||
next_round = TrainingRound(
|
||||
session_id=fl_session.id,
|
||||
round_number=fl_session.current_round,
|
||||
status="active",
|
||||
starting_model_cid=new_global_cid
|
||||
)
|
||||
self.session.add(next_round)
|
||||
|
||||
# Reset participant statuses
|
||||
for p in fl_session.participants:
|
||||
if p.status == ParticipantStatus.SUBMITTED:
|
||||
p.status = ParticipantStatus.TRAINING
|
||||
|
||||
logger.info(f"Session {fl_session.id} progressing to Round {fl_session.current_round}")
|
||||
|
||||
self.session.commit()
|
||||
158
apps/coordinator-api/src/app/services/ipfs_storage_adapter.py
Normal file
158
apps/coordinator-api/src/app/services/ipfs_storage_adapter.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""
|
||||
IPFS Storage Adapter Service
|
||||
|
||||
Service for offloading agent vector databases and knowledge graphs to IPFS/Filecoin.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import json
|
||||
import hashlib
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
from sqlmodel import Session, select
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ..domain.decentralized_memory import AgentMemoryNode, MemoryType, StorageStatus
|
||||
from ..schemas.decentralized_memory import MemoryNodeCreate
|
||||
from ..blockchain.contract_interactions import ContractInteractionService
|
||||
|
||||
# In a real environment, this would use a library like ipfshttpclient or a service like Pinata/Web3.Storage
|
||||
# For this implementation, we will mock the interactions to demonstrate the architecture.
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class IPFSAdapterService:
|
||||
def __init__(
|
||||
self,
|
||||
session: Session,
|
||||
contract_service: ContractInteractionService,
|
||||
ipfs_gateway_url: str = "http://127.0.0.1:5001/api/v0",
|
||||
pinning_service_token: Optional[str] = None
|
||||
):
|
||||
self.session = session
|
||||
self.contract_service = contract_service
|
||||
self.ipfs_gateway_url = ipfs_gateway_url
|
||||
self.pinning_service_token = pinning_service_token
|
||||
|
||||
async def _mock_ipfs_upload(self, data: bytes) -> str:
|
||||
"""Mock function to simulate IPFS CID generation (v1 format CID simulation)"""
|
||||
# Using sha256 to simulate content hashing
|
||||
hash_val = hashlib.sha256(data).hexdigest()
|
||||
# Mocking a CIDv1 base32 string format (bafy...)
|
||||
return f"bafybeig{hash_val[:40]}"
|
||||
|
||||
async def store_memory(
|
||||
self,
|
||||
request: MemoryNodeCreate,
|
||||
raw_data: bytes,
|
||||
zk_proof_hash: Optional[str] = None
|
||||
) -> AgentMemoryNode:
|
||||
"""
|
||||
Upload raw memory data (e.g. serialized vector DB or JSON knowledge graph) to IPFS
|
||||
and create a tracking record.
|
||||
"""
|
||||
# 1. Create initial record
|
||||
node = AgentMemoryNode(
|
||||
agent_id=request.agent_id,
|
||||
memory_type=request.memory_type,
|
||||
is_encrypted=request.is_encrypted,
|
||||
metadata=request.metadata,
|
||||
tags=request.tags,
|
||||
size_bytes=len(raw_data),
|
||||
status=StorageStatus.PENDING,
|
||||
zk_proof_hash=zk_proof_hash
|
||||
)
|
||||
self.session.add(node)
|
||||
self.session.commit()
|
||||
self.session.refresh(node)
|
||||
|
||||
try:
|
||||
# 2. Upload to IPFS (Mocked)
|
||||
logger.info(f"Uploading {len(raw_data)} bytes to IPFS for agent {request.agent_id}")
|
||||
cid = await self._mock_ipfs_upload(raw_data)
|
||||
|
||||
node.cid = cid
|
||||
node.status = StorageStatus.UPLOADED
|
||||
|
||||
# 3. Pin to Filecoin/Pinning service (Mocked)
|
||||
if self.pinning_service_token:
|
||||
logger.info(f"Pinning CID {cid} to persistent storage")
|
||||
node.status = StorageStatus.PINNED
|
||||
|
||||
self.session.commit()
|
||||
self.session.refresh(node)
|
||||
return node
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store memory node {node.id}: {str(e)}")
|
||||
node.status = StorageStatus.FAILED
|
||||
self.session.commit()
|
||||
raise HTTPException(status_code=500, detail="Failed to upload data to decentralized storage")
|
||||
|
||||
async def get_memory_nodes(
|
||||
self,
|
||||
agent_id: str,
|
||||
memory_type: Optional[MemoryType] = None,
|
||||
tags: Optional[List[str]] = None
|
||||
) -> List[AgentMemoryNode]:
|
||||
"""Retrieve metadata for an agent's stored memory nodes"""
|
||||
query = select(AgentMemoryNode).where(AgentMemoryNode.agent_id == agent_id)
|
||||
|
||||
if memory_type:
|
||||
query = query.where(AgentMemoryNode.memory_type == memory_type)
|
||||
|
||||
# Execute query and filter by tags in Python (since SQLite JSON JSON_CONTAINS is complex via pure SQLAlchemy without specific dialects)
|
||||
results = self.session.exec(query).all()
|
||||
|
||||
if tags and len(tags) > 0:
|
||||
filtered_results = []
|
||||
for r in results:
|
||||
if all(tag in r.tags for tag in tags):
|
||||
filtered_results.append(r)
|
||||
return filtered_results
|
||||
|
||||
return results
|
||||
|
||||
async def anchor_to_blockchain(self, node_id: str) -> AgentMemoryNode:
|
||||
"""
|
||||
Anchor a specific IPFS CID to the agent's smart contract profile to ensure data lineage.
|
||||
"""
|
||||
node = self.session.get(AgentMemoryNode, node_id)
|
||||
if not node:
|
||||
raise HTTPException(status_code=404, detail="Memory node not found")
|
||||
|
||||
if not node.cid:
|
||||
raise HTTPException(status_code=400, detail="Cannot anchor node without CID")
|
||||
|
||||
if node.status == StorageStatus.ANCHORED:
|
||||
return node
|
||||
|
||||
try:
|
||||
# Mocking the smart contract call to AgentMemory.sol
|
||||
# tx_hash = await self.contract_service.anchor_agent_memory(node.agent_id, node.cid, node.zk_proof_hash)
|
||||
tx_hash = "0x" + hashlib.md5(f"{node.id}{node.cid}".encode()).hexdigest()
|
||||
|
||||
node.anchor_tx_hash = tx_hash
|
||||
node.status = StorageStatus.ANCHORED
|
||||
self.session.commit()
|
||||
self.session.refresh(node)
|
||||
|
||||
logger.info(f"Anchored memory {node_id} (CID: {node.cid}) to blockchain. Tx: {tx_hash}")
|
||||
return node
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to anchor memory node {node_id}: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Failed to anchor CID to blockchain")
|
||||
|
||||
async def retrieve_memory(self, node_id: str) -> bytes:
|
||||
"""Retrieve the raw data from IPFS"""
|
||||
node = self.session.get(AgentMemoryNode, node_id)
|
||||
if not node or not node.cid:
|
||||
raise HTTPException(status_code=404, detail="Memory node or CID not found")
|
||||
|
||||
# Mocking retrieval
|
||||
logger.info(f"Retrieving CID {node.cid} from IPFS network")
|
||||
mock_data = b"{\"mock\": \"data\", \"info\": \"This represents decrypted vector db or KG data\"}"
|
||||
return mock_data
|
||||
115
apps/coordinator-api/src/app/services/zk_memory_verification.py
Normal file
115
apps/coordinator-api/src/app/services/zk_memory_verification.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
ZK-Proof Memory Verification Service
|
||||
|
||||
Service for generating and verifying Zero-Knowledge proofs for decentralized memory retrieval.
|
||||
Ensures that data retrieved from IPFS matches the anchored state on the blockchain
|
||||
without revealing the contents of the data itself.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import hashlib
|
||||
import json
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlmodel import Session
|
||||
|
||||
from ..domain.decentralized_memory import AgentMemoryNode
|
||||
from ..blockchain.contract_interactions import ContractInteractionService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ZKMemoryVerificationService:
|
||||
def __init__(
|
||||
self,
|
||||
session: Session,
|
||||
contract_service: ContractInteractionService
|
||||
):
|
||||
self.session = session
|
||||
self.contract_service = contract_service
|
||||
|
||||
async def generate_memory_proof(self, node_id: str, raw_data: bytes) -> Tuple[str, str]:
|
||||
"""
|
||||
Generate a Zero-Knowledge proof that the given raw data corresponds to
|
||||
the structural integrity and properties required by the system,
|
||||
and compute its hash for on-chain anchoring.
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: (zk_proof_payload, zk_proof_hash)
|
||||
"""
|
||||
node = self.session.get(AgentMemoryNode, node_id)
|
||||
if not node:
|
||||
raise HTTPException(status_code=404, detail="Memory node not found")
|
||||
|
||||
# In a real ZK system (like snarkjs or circom), we would:
|
||||
# 1. Compile the raw data into circuit inputs.
|
||||
# 2. Run the witness generator.
|
||||
# 3. Generate the proof.
|
||||
|
||||
# Mocking ZK Proof generation
|
||||
logger.info(f"Generating ZK proof for memory node {node_id}")
|
||||
|
||||
# We simulate a proof by creating a structured JSON string
|
||||
data_hash = hashlib.sha256(raw_data).hexdigest()
|
||||
|
||||
mock_proof = {
|
||||
"pi_a": ["mock_pi_a_1", "mock_pi_a_2", "mock_pi_a_3"],
|
||||
"pi_b": [["mock_pi_b_1", "mock_pi_b_2"], ["mock_pi_b_3", "mock_pi_b_4"]],
|
||||
"pi_c": ["mock_pi_c_1", "mock_pi_c_2", "mock_pi_c_3"],
|
||||
"protocol": "groth16",
|
||||
"curve": "bn128",
|
||||
"publicSignals": [data_hash, node.agent_id]
|
||||
}
|
||||
|
||||
proof_payload = json.dumps(mock_proof)
|
||||
|
||||
# The proof hash is what gets stored on-chain
|
||||
proof_hash = "0x" + hashlib.sha256(proof_payload.encode()).hexdigest()
|
||||
|
||||
return proof_payload, proof_hash
|
||||
|
||||
async def verify_retrieved_memory(
|
||||
self,
|
||||
node_id: str,
|
||||
retrieved_data: bytes,
|
||||
proof_payload: str
|
||||
) -> bool:
|
||||
"""
|
||||
Verify that the retrieved data matches the on-chain anchored ZK proof.
|
||||
"""
|
||||
node = self.session.get(AgentMemoryNode, node_id)
|
||||
if not node:
|
||||
raise HTTPException(status_code=404, detail="Memory node not found")
|
||||
|
||||
if not node.zk_proof_hash:
|
||||
raise HTTPException(status_code=400, detail="Memory node does not have an anchored ZK proof")
|
||||
|
||||
logger.info(f"Verifying ZK proof for retrieved memory {node_id}")
|
||||
|
||||
try:
|
||||
# 1. Verify the provided proof payload matches the on-chain hash
|
||||
calculated_hash = "0x" + hashlib.sha256(proof_payload.encode()).hexdigest()
|
||||
if calculated_hash != node.zk_proof_hash:
|
||||
logger.error("Proof payload hash does not match anchored hash")
|
||||
return False
|
||||
|
||||
# 2. Verify the proof against the retrieved data (Circuit verification)
|
||||
# In a real system, we might verify this locally or query the smart contract
|
||||
|
||||
# Local mock verification
|
||||
proof_data = json.loads(proof_payload)
|
||||
data_hash = hashlib.sha256(retrieved_data).hexdigest()
|
||||
|
||||
# Check if the public signals match the data we retrieved
|
||||
if proof_data.get("publicSignals", [])[0] != data_hash:
|
||||
logger.error("Public signals in proof do not match retrieved data hash")
|
||||
return False
|
||||
|
||||
logger.info("ZK Memory Verification Successful")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during ZK memory verification: {str(e)}")
|
||||
return False
|
||||
129
apps/coordinator-api/tests/test_federated_learning.py
Normal file
129
apps/coordinator-api/tests/test_federated_learning.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from sqlmodel import Session, create_engine, SQLModel
|
||||
from sqlmodel.pool import StaticPool
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.services.federated_learning import FederatedLearningService
|
||||
from app.domain.federated_learning import TrainingStatus, ParticipantStatus
|
||||
from app.schemas.federated_learning import FederatedSessionCreate, JoinSessionRequest, SubmitUpdateRequest
|
||||
|
||||
@pytest.fixture
|
||||
def test_db():
|
||||
engine = create_engine(
|
||||
"sqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
session = Session(engine)
|
||||
yield session
|
||||
session.close()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_contract_service():
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.fixture
|
||||
def fl_service(test_db, mock_contract_service):
|
||||
return FederatedLearningService(
|
||||
session=test_db,
|
||||
contract_service=mock_contract_service
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session(fl_service):
|
||||
req = FederatedSessionCreate(
|
||||
initiator_agent_id="agent-admin",
|
||||
task_description="Train LLM on financial data",
|
||||
model_architecture_cid="bafy_arch_123",
|
||||
target_participants=2,
|
||||
total_rounds=2,
|
||||
min_participants_per_round=2
|
||||
)
|
||||
|
||||
session = await fl_service.create_session(req)
|
||||
assert session.status == TrainingStatus.GATHERING_PARTICIPANTS
|
||||
assert session.initiator_agent_id == "agent-admin"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join_session_and_start(fl_service):
|
||||
req = FederatedSessionCreate(
|
||||
initiator_agent_id="agent-admin",
|
||||
task_description="Train LLM on financial data",
|
||||
model_architecture_cid="bafy_arch_123",
|
||||
target_participants=2,
|
||||
total_rounds=2,
|
||||
min_participants_per_round=2
|
||||
)
|
||||
session = await fl_service.create_session(req)
|
||||
|
||||
# Agent 1 joins
|
||||
p1 = await fl_service.join_session(
|
||||
session.id,
|
||||
JoinSessionRequest(agent_id="agent-1", compute_power_committed=10.0)
|
||||
)
|
||||
assert p1.status == ParticipantStatus.JOINED
|
||||
assert session.status == TrainingStatus.GATHERING_PARTICIPANTS
|
||||
|
||||
# Agent 2 joins, triggers start
|
||||
p2 = await fl_service.join_session(
|
||||
session.id,
|
||||
JoinSessionRequest(agent_id="agent-2", compute_power_committed=15.0)
|
||||
)
|
||||
|
||||
# Needs refresh
|
||||
fl_service.session.refresh(session)
|
||||
|
||||
assert session.status == TrainingStatus.TRAINING
|
||||
assert session.current_round == 1
|
||||
assert len(session.rounds) == 1
|
||||
assert session.rounds[0].status == "active"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_submit_updates_and_aggregate(fl_service):
|
||||
# Setup
|
||||
req = FederatedSessionCreate(
|
||||
initiator_agent_id="agent-admin",
|
||||
task_description="Train LLM on financial data",
|
||||
model_architecture_cid="bafy_arch_123",
|
||||
target_participants=2,
|
||||
total_rounds=1, # Only 1 round for quick test
|
||||
min_participants_per_round=2
|
||||
)
|
||||
session = await fl_service.create_session(req)
|
||||
await fl_service.join_session(session.id, JoinSessionRequest(agent_id="agent-1", compute_power_committed=10.0))
|
||||
await fl_service.join_session(session.id, JoinSessionRequest(agent_id="agent-2", compute_power_committed=15.0))
|
||||
|
||||
fl_service.session.refresh(session)
|
||||
round1 = session.rounds[0]
|
||||
|
||||
# Agent 1 submits
|
||||
u1 = await fl_service.submit_local_update(
|
||||
session.id,
|
||||
round1.id,
|
||||
SubmitUpdateRequest(agent_id="agent-1", weights_cid="bafy_w1", data_samples_count=1000)
|
||||
)
|
||||
assert u1.weights_cid == "bafy_w1"
|
||||
|
||||
fl_service.session.refresh(session)
|
||||
fl_service.session.refresh(round1)
|
||||
|
||||
# Not aggregated yet
|
||||
assert round1.status == "active"
|
||||
|
||||
# Agent 2 submits, triggers aggregation and completion since total_rounds=1
|
||||
u2 = await fl_service.submit_local_update(
|
||||
session.id,
|
||||
round1.id,
|
||||
SubmitUpdateRequest(agent_id="agent-2", weights_cid="bafy_w2", data_samples_count=1500)
|
||||
)
|
||||
|
||||
fl_service.session.refresh(session)
|
||||
fl_service.session.refresh(round1)
|
||||
|
||||
assert round1.status == "completed"
|
||||
assert session.status == TrainingStatus.COMPLETED
|
||||
assert session.global_model_cid is not None
|
||||
assert session.global_model_cid.startswith("bafy_aggregated_")
|
||||
108
apps/coordinator-api/tests/test_ipfs_storage_adapter.py
Normal file
108
apps/coordinator-api/tests/test_ipfs_storage_adapter.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from sqlmodel import Session, create_engine, SQLModel
|
||||
from sqlmodel.pool import StaticPool
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.services.ipfs_storage_adapter import IPFSAdapterService
|
||||
from app.domain.decentralized_memory import MemoryType, StorageStatus
|
||||
from app.schemas.decentralized_memory import MemoryNodeCreate
|
||||
|
||||
@pytest.fixture
|
||||
def test_db():
|
||||
engine = create_engine(
|
||||
"sqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
session = Session(engine)
|
||||
yield session
|
||||
session.close()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_contract_service():
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.fixture
|
||||
def storage_service(test_db, mock_contract_service):
|
||||
return IPFSAdapterService(
|
||||
session=test_db,
|
||||
contract_service=mock_contract_service,
|
||||
pinning_service_token="mock_token"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_memory(storage_service):
|
||||
request = MemoryNodeCreate(
|
||||
agent_id="agent-007",
|
||||
memory_type=MemoryType.VECTOR_DB,
|
||||
tags=["training", "batch1"]
|
||||
)
|
||||
|
||||
raw_data = b"mock_vector_embeddings_data"
|
||||
|
||||
node = await storage_service.store_memory(request, raw_data, zk_proof_hash="0xabc123")
|
||||
|
||||
assert node.agent_id == "agent-007"
|
||||
assert node.memory_type == MemoryType.VECTOR_DB
|
||||
assert node.cid is not None
|
||||
assert node.cid.startswith("bafy")
|
||||
assert node.size_bytes == len(raw_data)
|
||||
assert node.status == StorageStatus.PINNED
|
||||
assert node.zk_proof_hash == "0xabc123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_memory_nodes(storage_service):
|
||||
# Store multiple
|
||||
await storage_service.store_memory(
|
||||
MemoryNodeCreate(agent_id="agent-007", memory_type=MemoryType.VECTOR_DB, tags=["v1"]),
|
||||
b"data1"
|
||||
)
|
||||
await storage_service.store_memory(
|
||||
MemoryNodeCreate(agent_id="agent-007", memory_type=MemoryType.KNOWLEDGE_GRAPH, tags=["v1"]),
|
||||
b"data2"
|
||||
)
|
||||
await storage_service.store_memory(
|
||||
MemoryNodeCreate(agent_id="agent-008", memory_type=MemoryType.VECTOR_DB),
|
||||
b"data3"
|
||||
)
|
||||
|
||||
# Get all for agent-007
|
||||
nodes = await storage_service.get_memory_nodes("agent-007")
|
||||
assert len(nodes) == 2
|
||||
|
||||
# Filter by type
|
||||
nodes_kg = await storage_service.get_memory_nodes("agent-007", memory_type=MemoryType.KNOWLEDGE_GRAPH)
|
||||
assert len(nodes_kg) == 1
|
||||
assert nodes_kg[0].memory_type == MemoryType.KNOWLEDGE_GRAPH
|
||||
|
||||
# Filter by tag
|
||||
nodes_tag = await storage_service.get_memory_nodes("agent-007", tags=["v1"])
|
||||
assert len(nodes_tag) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_anchor_to_blockchain(storage_service):
|
||||
node = await storage_service.store_memory(
|
||||
MemoryNodeCreate(agent_id="agent-007", memory_type=MemoryType.VECTOR_DB),
|
||||
b"data1"
|
||||
)
|
||||
|
||||
assert node.anchor_tx_hash is None
|
||||
|
||||
anchored_node = await storage_service.anchor_to_blockchain(node.id)
|
||||
|
||||
assert anchored_node.status == StorageStatus.ANCHORED
|
||||
assert anchored_node.anchor_tx_hash is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_memory(storage_service):
|
||||
node = await storage_service.store_memory(
|
||||
MemoryNodeCreate(agent_id="agent-007", memory_type=MemoryType.VECTOR_DB),
|
||||
b"data1"
|
||||
)
|
||||
|
||||
data = await storage_service.retrieve_memory(node.id)
|
||||
assert isinstance(data, bytes)
|
||||
assert b"mock" in data
|
||||
92
apps/coordinator-api/tests/test_zk_memory_verification.py
Normal file
92
apps/coordinator-api/tests/test_zk_memory_verification.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from sqlmodel import Session, create_engine, SQLModel
|
||||
from sqlmodel.pool import StaticPool
|
||||
|
||||
from app.services.zk_memory_verification import ZKMemoryVerificationService
|
||||
from app.domain.decentralized_memory import AgentMemoryNode, MemoryType
|
||||
|
||||
@pytest.fixture
|
||||
def test_db():
|
||||
engine = create_engine(
|
||||
"sqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
session = Session(engine)
|
||||
yield session
|
||||
session.close()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_contract_service():
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.fixture
|
||||
def zk_service(test_db, mock_contract_service):
|
||||
return ZKMemoryVerificationService(
|
||||
session=test_db,
|
||||
contract_service=mock_contract_service
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_memory_proof(zk_service, test_db):
|
||||
node = AgentMemoryNode(
|
||||
agent_id="agent-zk",
|
||||
memory_type=MemoryType.VECTOR_DB
|
||||
)
|
||||
test_db.add(node)
|
||||
test_db.commit()
|
||||
test_db.refresh(node)
|
||||
|
||||
raw_data = b"secret_vector_data"
|
||||
|
||||
proof_payload, proof_hash = await zk_service.generate_memory_proof(node.id, raw_data)
|
||||
|
||||
assert proof_payload is not None
|
||||
assert proof_hash.startswith("0x")
|
||||
assert "groth16" in proof_payload
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_retrieved_memory_success(zk_service, test_db):
|
||||
node = AgentMemoryNode(
|
||||
agent_id="agent-zk",
|
||||
memory_type=MemoryType.VECTOR_DB
|
||||
)
|
||||
test_db.add(node)
|
||||
test_db.commit()
|
||||
test_db.refresh(node)
|
||||
|
||||
raw_data = b"secret_vector_data"
|
||||
proof_payload, proof_hash = await zk_service.generate_memory_proof(node.id, raw_data)
|
||||
|
||||
# Simulate anchoring
|
||||
node.zk_proof_hash = proof_hash
|
||||
test_db.commit()
|
||||
|
||||
# Verify
|
||||
is_valid = await zk_service.verify_retrieved_memory(node.id, raw_data, proof_payload)
|
||||
assert is_valid is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_retrieved_memory_tampered_data(zk_service, test_db):
|
||||
node = AgentMemoryNode(
|
||||
agent_id="agent-zk",
|
||||
memory_type=MemoryType.VECTOR_DB
|
||||
)
|
||||
test_db.add(node)
|
||||
test_db.commit()
|
||||
test_db.refresh(node)
|
||||
|
||||
raw_data = b"secret_vector_data"
|
||||
proof_payload, proof_hash = await zk_service.generate_memory_proof(node.id, raw_data)
|
||||
|
||||
node.zk_proof_hash = proof_hash
|
||||
test_db.commit()
|
||||
|
||||
# Tamper with data
|
||||
tampered_data = b"secret_vector_data_modified"
|
||||
|
||||
is_valid = await zk_service.verify_retrieved_memory(node.id, tampered_data, proof_payload)
|
||||
assert is_valid is False
|
||||
Reference in New Issue
Block a user