chore(security): enhance environment configuration, CI workflows, and wallet daemon with security improvements
- Restructure .env.example with security-focused documentation, service-specific environment file references, and AWS Secrets Manager integration - Update CLI tests workflow to single Python 3.13 version, add pytest-mock dependency, and consolidate test execution with coverage - Add comprehensive security validation to package publishing workflow with manual approval gates, secret scanning, and release
This commit is contained in:
@@ -0,0 +1,584 @@
|
||||
"""
|
||||
AITBC Agent Wallet Security Implementation
|
||||
|
||||
This module implements the security layer for autonomous agent wallets,
|
||||
integrating the guardian contract to prevent unlimited spending in case
|
||||
of agent compromise.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
import json
|
||||
from eth_account import Account
|
||||
from eth_utils import to_checksum_address
|
||||
|
||||
from .guardian_contract import (
|
||||
GuardianContract,
|
||||
SpendingLimit,
|
||||
TimeLockConfig,
|
||||
GuardianConfig,
|
||||
create_guardian_contract,
|
||||
CONSERVATIVE_CONFIG,
|
||||
AGGRESSIVE_CONFIG,
|
||||
HIGH_SECURITY_CONFIG
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentSecurityProfile:
|
||||
"""Security profile for an agent"""
|
||||
agent_address: str
|
||||
security_level: str # "conservative", "aggressive", "high_security"
|
||||
guardian_addresses: List[str]
|
||||
custom_limits: Optional[Dict] = None
|
||||
enabled: bool = True
|
||||
created_at: datetime = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.created_at is None:
|
||||
self.created_at = datetime.utcnow()
|
||||
|
||||
|
||||
class AgentWalletSecurity:
|
||||
"""
|
||||
Security manager for autonomous agent wallets
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.agent_profiles: Dict[str, AgentSecurityProfile] = {}
|
||||
self.guardian_contracts: Dict[str, GuardianContract] = {}
|
||||
self.security_events: List[Dict] = []
|
||||
|
||||
# Default configurations
|
||||
self.configurations = {
|
||||
"conservative": CONSERVATIVE_CONFIG,
|
||||
"aggressive": AGGRESSIVE_CONFIG,
|
||||
"high_security": HIGH_SECURITY_CONFIG
|
||||
}
|
||||
|
||||
def register_agent(self,
|
||||
agent_address: str,
|
||||
security_level: str = "conservative",
|
||||
guardian_addresses: List[str] = None,
|
||||
custom_limits: Dict = None) -> Dict:
|
||||
"""
|
||||
Register an agent for security protection
|
||||
|
||||
Args:
|
||||
agent_address: Agent wallet address
|
||||
security_level: Security level (conservative, aggressive, high_security)
|
||||
guardian_addresses: List of guardian addresses for recovery
|
||||
custom_limits: Custom spending limits (overrides security_level)
|
||||
|
||||
Returns:
|
||||
Registration result
|
||||
"""
|
||||
try:
|
||||
agent_address = to_checksum_address(agent_address)
|
||||
|
||||
if agent_address in self.agent_profiles:
|
||||
return {
|
||||
"status": "error",
|
||||
"reason": "Agent already registered"
|
||||
}
|
||||
|
||||
# Validate security level
|
||||
if security_level not in self.configurations:
|
||||
return {
|
||||
"status": "error",
|
||||
"reason": f"Invalid security level: {security_level}"
|
||||
}
|
||||
|
||||
# Default guardians if none provided
|
||||
if guardian_addresses is None:
|
||||
guardian_addresses = [agent_address] # Self-guardian (should be overridden)
|
||||
|
||||
# Validate guardian addresses
|
||||
guardian_addresses = [to_checksum_address(addr) for addr in guardian_addresses]
|
||||
|
||||
# Create security profile
|
||||
profile = AgentSecurityProfile(
|
||||
agent_address=agent_address,
|
||||
security_level=security_level,
|
||||
guardian_addresses=guardian_addresses,
|
||||
custom_limits=custom_limits
|
||||
)
|
||||
|
||||
# Create guardian contract
|
||||
config = self.configurations[security_level]
|
||||
if custom_limits:
|
||||
config.update(custom_limits)
|
||||
|
||||
guardian_contract = create_guardian_contract(
|
||||
agent_address=agent_address,
|
||||
guardians=guardian_addresses,
|
||||
**config
|
||||
)
|
||||
|
||||
# Store profile and contract
|
||||
self.agent_profiles[agent_address] = profile
|
||||
self.guardian_contracts[agent_address] = guardian_contract
|
||||
|
||||
# Log security event
|
||||
self._log_security_event(
|
||||
event_type="agent_registered",
|
||||
agent_address=agent_address,
|
||||
security_level=security_level,
|
||||
guardian_count=len(guardian_addresses)
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "registered",
|
||||
"agent_address": agent_address,
|
||||
"security_level": security_level,
|
||||
"guardian_addresses": guardian_addresses,
|
||||
"limits": guardian_contract.config.limits,
|
||||
"time_lock_threshold": guardian_contract.config.time_lock.threshold,
|
||||
"registered_at": profile.created_at.isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "error",
|
||||
"reason": f"Registration failed: {str(e)}"
|
||||
}
|
||||
|
||||
def protect_transaction(self,
|
||||
agent_address: str,
|
||||
to_address: str,
|
||||
amount: int,
|
||||
data: str = "") -> Dict:
|
||||
"""
|
||||
Protect a transaction with guardian contract
|
||||
|
||||
Args:
|
||||
agent_address: Agent wallet address
|
||||
to_address: Recipient address
|
||||
amount: Amount to transfer
|
||||
data: Transaction data
|
||||
|
||||
Returns:
|
||||
Protection result
|
||||
"""
|
||||
try:
|
||||
agent_address = to_checksum_address(agent_address)
|
||||
|
||||
# Check if agent is registered
|
||||
if agent_address not in self.agent_profiles:
|
||||
return {
|
||||
"status": "unprotected",
|
||||
"reason": "Agent not registered for security protection",
|
||||
"suggestion": "Register agent with register_agent() first"
|
||||
}
|
||||
|
||||
# Check if protection is enabled
|
||||
profile = self.agent_profiles[agent_address]
|
||||
if not profile.enabled:
|
||||
return {
|
||||
"status": "unprotected",
|
||||
"reason": "Security protection disabled for this agent"
|
||||
}
|
||||
|
||||
# Get guardian contract
|
||||
guardian_contract = self.guardian_contracts[agent_address]
|
||||
|
||||
# Initiate transaction protection
|
||||
result = guardian_contract.initiate_transaction(to_address, amount, data)
|
||||
|
||||
# Log security event
|
||||
self._log_security_event(
|
||||
event_type="transaction_protected",
|
||||
agent_address=agent_address,
|
||||
to_address=to_address,
|
||||
amount=amount,
|
||||
protection_status=result["status"]
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "error",
|
||||
"reason": f"Transaction protection failed: {str(e)}"
|
||||
}
|
||||
|
||||
def execute_protected_transaction(self,
|
||||
agent_address: str,
|
||||
operation_id: str,
|
||||
signature: str) -> Dict:
|
||||
"""
|
||||
Execute a previously protected transaction
|
||||
|
||||
Args:
|
||||
agent_address: Agent wallet address
|
||||
operation_id: Operation ID from protection
|
||||
signature: Transaction signature
|
||||
|
||||
Returns:
|
||||
Execution result
|
||||
"""
|
||||
try:
|
||||
agent_address = to_checksum_address(agent_address)
|
||||
|
||||
if agent_address not in self.guardian_contracts:
|
||||
return {
|
||||
"status": "error",
|
||||
"reason": "Agent not registered"
|
||||
}
|
||||
|
||||
guardian_contract = self.guardian_contracts[agent_address]
|
||||
result = guardian_contract.execute_transaction(operation_id, signature)
|
||||
|
||||
# Log security event
|
||||
if result["status"] == "executed":
|
||||
self._log_security_event(
|
||||
event_type="transaction_executed",
|
||||
agent_address=agent_address,
|
||||
operation_id=operation_id,
|
||||
transaction_hash=result.get("transaction_hash")
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "error",
|
||||
"reason": f"Transaction execution failed: {str(e)}"
|
||||
}
|
||||
|
||||
def emergency_pause_agent(self, agent_address: str, guardian_address: str) -> Dict:
|
||||
"""
|
||||
Emergency pause an agent's operations
|
||||
|
||||
Args:
|
||||
agent_address: Agent wallet address
|
||||
guardian_address: Guardian address initiating pause
|
||||
|
||||
Returns:
|
||||
Pause result
|
||||
"""
|
||||
try:
|
||||
agent_address = to_checksum_address(agent_address)
|
||||
guardian_address = to_checksum_address(guardian_address)
|
||||
|
||||
if agent_address not in self.guardian_contracts:
|
||||
return {
|
||||
"status": "error",
|
||||
"reason": "Agent not registered"
|
||||
}
|
||||
|
||||
guardian_contract = self.guardian_contracts[agent_address]
|
||||
result = guardian_contract.emergency_pause(guardian_address)
|
||||
|
||||
# Log security event
|
||||
if result["status"] == "paused":
|
||||
self._log_security_event(
|
||||
event_type="emergency_pause",
|
||||
agent_address=agent_address,
|
||||
guardian_address=guardian_address
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "error",
|
||||
"reason": f"Emergency pause failed: {str(e)}"
|
||||
}
|
||||
|
||||
def update_agent_security(self,
|
||||
agent_address: str,
|
||||
new_limits: Dict,
|
||||
guardian_address: str) -> Dict:
|
||||
"""
|
||||
Update security limits for an agent
|
||||
|
||||
Args:
|
||||
agent_address: Agent wallet address
|
||||
new_limits: New spending limits
|
||||
guardian_address: Guardian address making the change
|
||||
|
||||
Returns:
|
||||
Update result
|
||||
"""
|
||||
try:
|
||||
agent_address = to_checksum_address(agent_address)
|
||||
guardian_address = to_checksum_address(guardian_address)
|
||||
|
||||
if agent_address not in self.guardian_contracts:
|
||||
return {
|
||||
"status": "error",
|
||||
"reason": "Agent not registered"
|
||||
}
|
||||
|
||||
guardian_contract = self.guardian_contracts[agent_address]
|
||||
|
||||
# Create new spending limits
|
||||
limits = SpendingLimit(
|
||||
per_transaction=new_limits.get("per_transaction", 1000),
|
||||
per_hour=new_limits.get("per_hour", 5000),
|
||||
per_day=new_limits.get("per_day", 20000),
|
||||
per_week=new_limits.get("per_week", 100000)
|
||||
)
|
||||
|
||||
result = guardian_contract.update_limits(limits, guardian_address)
|
||||
|
||||
# Log security event
|
||||
if result["status"] == "updated":
|
||||
self._log_security_event(
|
||||
event_type="security_limits_updated",
|
||||
agent_address=agent_address,
|
||||
guardian_address=guardian_address,
|
||||
new_limits=new_limits
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "error",
|
||||
"reason": f"Security update failed: {str(e)}"
|
||||
}
|
||||
|
||||
def get_agent_security_status(self, agent_address: str) -> Dict:
|
||||
"""
|
||||
Get security status for an agent
|
||||
|
||||
Args:
|
||||
agent_address: Agent wallet address
|
||||
|
||||
Returns:
|
||||
Security status
|
||||
"""
|
||||
try:
|
||||
agent_address = to_checksum_address(agent_address)
|
||||
|
||||
if agent_address not in self.agent_profiles:
|
||||
return {
|
||||
"status": "not_registered",
|
||||
"message": "Agent not registered for security protection"
|
||||
}
|
||||
|
||||
profile = self.agent_profiles[agent_address]
|
||||
guardian_contract = self.guardian_contracts[agent_address]
|
||||
|
||||
return {
|
||||
"status": "protected",
|
||||
"agent_address": agent_address,
|
||||
"security_level": profile.security_level,
|
||||
"enabled": profile.enabled,
|
||||
"guardian_addresses": profile.guardian_addresses,
|
||||
"registered_at": profile.created_at.isoformat(),
|
||||
"spending_status": guardian_contract.get_spending_status(),
|
||||
"pending_operations": guardian_contract.get_pending_operations(),
|
||||
"recent_activity": guardian_contract.get_operation_history(10)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "error",
|
||||
"reason": f"Status check failed: {str(e)}"
|
||||
}
|
||||
|
||||
def list_protected_agents(self) -> List[Dict]:
|
||||
"""List all protected agents"""
|
||||
agents = []
|
||||
|
||||
for agent_address, profile in self.agent_profiles.items():
|
||||
guardian_contract = self.guardian_contracts[agent_address]
|
||||
|
||||
agents.append({
|
||||
"agent_address": agent_address,
|
||||
"security_level": profile.security_level,
|
||||
"enabled": profile.enabled,
|
||||
"guardian_count": len(profile.guardian_addresses),
|
||||
"pending_operations": len(guardian_contract.pending_operations),
|
||||
"paused": guardian_contract.paused,
|
||||
"emergency_mode": guardian_contract.emergency_mode,
|
||||
"registered_at": profile.created_at.isoformat()
|
||||
})
|
||||
|
||||
return sorted(agents, key=lambda x: x["registered_at"], reverse=True)
|
||||
|
||||
def get_security_events(self, agent_address: str = None, limit: int = 50) -> List[Dict]:
|
||||
"""
|
||||
Get security events
|
||||
|
||||
Args:
|
||||
agent_address: Filter by agent address (optional)
|
||||
limit: Maximum number of events
|
||||
|
||||
Returns:
|
||||
Security events
|
||||
"""
|
||||
events = self.security_events
|
||||
|
||||
if agent_address:
|
||||
agent_address = to_checksum_address(agent_address)
|
||||
events = [e for e in events if e.get("agent_address") == agent_address]
|
||||
|
||||
return sorted(events, key=lambda x: x["timestamp"], reverse=True)[:limit]
|
||||
|
||||
def _log_security_event(self, **kwargs):
|
||||
"""Log a security event"""
|
||||
event = {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
**kwargs
|
||||
}
|
||||
self.security_events.append(event)
|
||||
|
||||
def disable_agent_protection(self, agent_address: str, guardian_address: str) -> Dict:
|
||||
"""
|
||||
Disable protection for an agent (guardian only)
|
||||
|
||||
Args:
|
||||
agent_address: Agent wallet address
|
||||
guardian_address: Guardian address
|
||||
|
||||
Returns:
|
||||
Disable result
|
||||
"""
|
||||
try:
|
||||
agent_address = to_checksum_address(agent_address)
|
||||
guardian_address = to_checksum_address(guardian_address)
|
||||
|
||||
if agent_address not in self.agent_profiles:
|
||||
return {
|
||||
"status": "error",
|
||||
"reason": "Agent not registered"
|
||||
}
|
||||
|
||||
profile = self.agent_profiles[agent_address]
|
||||
|
||||
if guardian_address not in profile.guardian_addresses:
|
||||
return {
|
||||
"status": "error",
|
||||
"reason": "Not authorized: not a guardian"
|
||||
}
|
||||
|
||||
profile.enabled = False
|
||||
|
||||
# Log security event
|
||||
self._log_security_event(
|
||||
event_type="protection_disabled",
|
||||
agent_address=agent_address,
|
||||
guardian_address=guardian_address
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "disabled",
|
||||
"agent_address": agent_address,
|
||||
"disabled_at": datetime.utcnow().isoformat(),
|
||||
"guardian": guardian_address
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "error",
|
||||
"reason": f"Disable protection failed: {str(e)}"
|
||||
}
|
||||
|
||||
|
||||
# Global security manager instance
|
||||
agent_wallet_security = AgentWalletSecurity()
|
||||
|
||||
|
||||
# Convenience functions for common operations
|
||||
def register_agent_for_protection(agent_address: str,
|
||||
security_level: str = "conservative",
|
||||
guardians: List[str] = None) -> Dict:
|
||||
"""Register an agent for security protection"""
|
||||
return agent_wallet_security.register_agent(
|
||||
agent_address=agent_address,
|
||||
security_level=security_level,
|
||||
guardian_addresses=guardians
|
||||
)
|
||||
|
||||
|
||||
def protect_agent_transaction(agent_address: str,
|
||||
to_address: str,
|
||||
amount: int,
|
||||
data: str = "") -> Dict:
|
||||
"""Protect a transaction for an agent"""
|
||||
return agent_wallet_security.protect_transaction(
|
||||
agent_address=agent_address,
|
||||
to_address=to_address,
|
||||
amount=amount,
|
||||
data=data
|
||||
)
|
||||
|
||||
|
||||
def get_agent_security_summary(agent_address: str) -> Dict:
|
||||
"""Get security summary for an agent"""
|
||||
return agent_wallet_security.get_agent_security_status(agent_address)
|
||||
|
||||
|
||||
# Security audit and monitoring functions
|
||||
def generate_security_report() -> Dict:
|
||||
"""Generate comprehensive security report"""
|
||||
protected_agents = agent_wallet_security.list_protected_agents()
|
||||
|
||||
total_agents = len(protected_agents)
|
||||
active_agents = len([a for a in protected_agents if a["enabled"]])
|
||||
paused_agents = len([a for a in protected_agents if a["paused"]])
|
||||
emergency_agents = len([a for a in protected_agents if a["emergency_mode"]])
|
||||
|
||||
recent_events = agent_wallet_security.get_security_events(limit=20)
|
||||
|
||||
return {
|
||||
"generated_at": datetime.utcnow().isoformat(),
|
||||
"summary": {
|
||||
"total_protected_agents": total_agents,
|
||||
"active_agents": active_agents,
|
||||
"paused_agents": paused_agents,
|
||||
"emergency_mode_agents": emergency_agents,
|
||||
"protection_coverage": f"{(active_agents / total_agents * 100):.1f}%" if total_agents > 0 else "0%"
|
||||
},
|
||||
"agents": protected_agents,
|
||||
"recent_security_events": recent_events,
|
||||
"security_levels": {
|
||||
level: len([a for a in protected_agents if a["security_level"] == level])
|
||||
for level in ["conservative", "aggressive", "high_security"]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def detect_suspicious_activity(agent_address: str, hours: int = 24) -> Dict:
|
||||
"""Detect suspicious activity for an agent"""
|
||||
status = agent_wallet_security.get_agent_security_status(agent_address)
|
||||
|
||||
if status["status"] != "protected":
|
||||
return {
|
||||
"status": "not_protected",
|
||||
"suspicious_activity": False
|
||||
}
|
||||
|
||||
spending_status = status["spending_status"]
|
||||
recent_events = agent_wallet_security.get_security_events(agent_address, limit=50)
|
||||
|
||||
# Suspicious patterns
|
||||
suspicious_patterns = []
|
||||
|
||||
# Check for rapid spending
|
||||
if spending_status["spent"]["current_hour"] > spending_status["current_limits"]["per_hour"] * 0.8:
|
||||
suspicious_patterns.append("High hourly spending rate")
|
||||
|
||||
# Check for many small transactions (potential dust attack)
|
||||
recent_tx_count = len([e for e in recent_events if e["event_type"] == "transaction_executed"])
|
||||
if recent_tx_count > 20:
|
||||
suspicious_patterns.append("High transaction frequency")
|
||||
|
||||
# Check for emergency pauses
|
||||
recent_pauses = len([e for e in recent_events if e["event_type"] == "emergency_pause"])
|
||||
if recent_pauses > 0:
|
||||
suspicious_patterns.append("Recent emergency pauses detected")
|
||||
|
||||
return {
|
||||
"status": "analyzed",
|
||||
"agent_address": agent_address,
|
||||
"suspicious_activity": len(suspicious_patterns) > 0,
|
||||
"suspicious_patterns": suspicious_patterns,
|
||||
"analysis_period_hours": hours,
|
||||
"analyzed_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
@@ -0,0 +1,405 @@
|
||||
"""
|
||||
Fixed Guardian Configuration with Proper Guardian Setup
|
||||
Addresses the critical vulnerability where guardian lists were empty
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
import json
|
||||
from eth_account import Account
|
||||
from eth_utils import to_checksum_address, keccak
|
||||
|
||||
from .guardian_contract import (
|
||||
SpendingLimit,
|
||||
TimeLockConfig,
|
||||
GuardianConfig,
|
||||
GuardianContract
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GuardianSetup:
|
||||
"""Guardian setup configuration"""
|
||||
primary_guardian: str # Main guardian address
|
||||
backup_guardians: List[str] # Backup guardian addresses
|
||||
multisig_threshold: int # Number of signatures required
|
||||
emergency_contacts: List[str] # Additional emergency contacts
|
||||
|
||||
|
||||
class SecureGuardianManager:
|
||||
"""
|
||||
Secure guardian management with proper initialization
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.guardian_registrations: Dict[str, GuardianSetup] = {}
|
||||
self.guardian_contracts: Dict[str, GuardianContract] = {}
|
||||
|
||||
def create_guardian_setup(
|
||||
self,
|
||||
agent_address: str,
|
||||
owner_address: str,
|
||||
security_level: str = "conservative",
|
||||
custom_guardians: Optional[List[str]] = None
|
||||
) -> GuardianSetup:
|
||||
"""
|
||||
Create a proper guardian setup for an agent
|
||||
|
||||
Args:
|
||||
agent_address: Agent wallet address
|
||||
owner_address: Owner of the agent
|
||||
security_level: Security level (conservative, aggressive, high_security)
|
||||
custom_guardians: Optional custom guardian addresses
|
||||
|
||||
Returns:
|
||||
Guardian setup configuration
|
||||
"""
|
||||
agent_address = to_checksum_address(agent_address)
|
||||
owner_address = to_checksum_address(owner_address)
|
||||
|
||||
# Determine guardian requirements based on security level
|
||||
if security_level == "conservative":
|
||||
required_guardians = 3
|
||||
multisig_threshold = 2
|
||||
elif security_level == "aggressive":
|
||||
required_guardians = 2
|
||||
multisig_threshold = 2
|
||||
elif security_level == "high_security":
|
||||
required_guardians = 5
|
||||
multisig_threshold = 3
|
||||
else:
|
||||
raise ValueError(f"Invalid security level: {security_level}")
|
||||
|
||||
# Build guardian list
|
||||
guardians = []
|
||||
|
||||
# Always include the owner as primary guardian
|
||||
guardians.append(owner_address)
|
||||
|
||||
# Add custom guardians if provided
|
||||
if custom_guardians:
|
||||
for guardian in custom_guardians:
|
||||
guardian = to_checksum_address(guardian)
|
||||
if guardian not in guardians:
|
||||
guardians.append(guardian)
|
||||
|
||||
# Generate backup guardians if needed
|
||||
while len(guardians) < required_guardians:
|
||||
# Generate a deterministic backup guardian based on agent address
|
||||
# In production, these would be trusted service addresses
|
||||
backup_index = len(guardians) - 1 # -1 because owner is already included
|
||||
backup_guardian = self._generate_backup_guardian(agent_address, backup_index)
|
||||
|
||||
if backup_guardian not in guardians:
|
||||
guardians.append(backup_guardian)
|
||||
|
||||
# Create setup
|
||||
setup = GuardianSetup(
|
||||
primary_guardian=owner_address,
|
||||
backup_guardians=[g for g in guardians if g != owner_address],
|
||||
multisig_threshold=multisig_threshold,
|
||||
emergency_contacts=guardians.copy()
|
||||
)
|
||||
|
||||
self.guardian_registrations[agent_address] = setup
|
||||
|
||||
return setup
|
||||
|
||||
def _generate_backup_guardian(self, agent_address: str, index: int) -> str:
|
||||
"""
|
||||
Generate deterministic backup guardian address
|
||||
|
||||
In production, these would be pre-registered trusted guardian addresses
|
||||
"""
|
||||
# Create a deterministic address based on agent address and index
|
||||
seed = f"{agent_address}_{index}_backup_guardian"
|
||||
hash_result = keccak(seed.encode())
|
||||
|
||||
# Use the hash to generate a valid address
|
||||
address_bytes = hash_result[-20:] # Take last 20 bytes
|
||||
address = "0x" + address_bytes.hex()
|
||||
|
||||
return to_checksum_address(address)
|
||||
|
||||
def create_secure_guardian_contract(
|
||||
self,
|
||||
agent_address: str,
|
||||
security_level: str = "conservative",
|
||||
custom_guardians: Optional[List[str]] = None
|
||||
) -> GuardianContract:
|
||||
"""
|
||||
Create a guardian contract with proper guardian configuration
|
||||
|
||||
Args:
|
||||
agent_address: Agent wallet address
|
||||
security_level: Security level
|
||||
custom_guardians: Optional custom guardian addresses
|
||||
|
||||
Returns:
|
||||
Configured guardian contract
|
||||
"""
|
||||
# Create guardian setup
|
||||
setup = self.create_guardian_setup(
|
||||
agent_address=agent_address,
|
||||
owner_address=agent_address, # Agent is its own owner initially
|
||||
security_level=security_level,
|
||||
custom_guardians=custom_guardians
|
||||
)
|
||||
|
||||
# Get security configuration
|
||||
config = self._get_security_config(security_level, setup)
|
||||
|
||||
# Create contract
|
||||
contract = GuardianContract(agent_address, config)
|
||||
|
||||
# Store contract
|
||||
self.guardian_contracts[agent_address] = contract
|
||||
|
||||
return contract
|
||||
|
||||
def _get_security_config(self, security_level: str, setup: GuardianSetup) -> GuardianConfig:
|
||||
"""Get security configuration with proper guardian list"""
|
||||
|
||||
# Build guardian list
|
||||
all_guardians = [setup.primary_guardian] + setup.backup_guardians
|
||||
|
||||
if security_level == "conservative":
|
||||
return GuardianConfig(
|
||||
limits=SpendingLimit(
|
||||
per_transaction=1000,
|
||||
per_hour=5000,
|
||||
per_day=20000,
|
||||
per_week=100000
|
||||
),
|
||||
time_lock=TimeLockConfig(
|
||||
threshold=5000,
|
||||
delay_hours=24,
|
||||
max_delay_hours=168
|
||||
),
|
||||
guardians=all_guardians,
|
||||
pause_enabled=True,
|
||||
emergency_mode=False,
|
||||
multisig_threshold=setup.multisig_threshold
|
||||
)
|
||||
|
||||
elif security_level == "aggressive":
|
||||
return GuardianConfig(
|
||||
limits=SpendingLimit(
|
||||
per_transaction=5000,
|
||||
per_hour=25000,
|
||||
per_day=100000,
|
||||
per_week=500000
|
||||
),
|
||||
time_lock=TimeLockConfig(
|
||||
threshold=20000,
|
||||
delay_hours=12,
|
||||
max_delay_hours=72
|
||||
),
|
||||
guardians=all_guardians,
|
||||
pause_enabled=True,
|
||||
emergency_mode=False,
|
||||
multisig_threshold=setup.multisig_threshold
|
||||
)
|
||||
|
||||
elif security_level == "high_security":
|
||||
return GuardianConfig(
|
||||
limits=SpendingLimit(
|
||||
per_transaction=500,
|
||||
per_hour=2000,
|
||||
per_day=8000,
|
||||
per_week=40000
|
||||
),
|
||||
time_lock=TimeLockConfig(
|
||||
threshold=2000,
|
||||
delay_hours=48,
|
||||
max_delay_hours=168
|
||||
),
|
||||
guardians=all_guardians,
|
||||
pause_enabled=True,
|
||||
emergency_mode=False,
|
||||
multisig_threshold=setup.multisig_threshold
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid security level: {security_level}")
|
||||
|
||||
def test_emergency_pause(self, agent_address: str, guardian_address: str) -> Dict:
|
||||
"""
|
||||
Test emergency pause functionality
|
||||
|
||||
Args:
|
||||
agent_address: Agent address
|
||||
guardian_address: Guardian attempting pause
|
||||
|
||||
Returns:
|
||||
Test result
|
||||
"""
|
||||
if agent_address not in self.guardian_contracts:
|
||||
return {
|
||||
"status": "error",
|
||||
"reason": "Agent not registered"
|
||||
}
|
||||
|
||||
contract = self.guardian_contracts[agent_address]
|
||||
return contract.emergency_pause(guardian_address)
|
||||
|
||||
def verify_guardian_authorization(self, agent_address: str, guardian_address: str) -> bool:
|
||||
"""
|
||||
Verify if a guardian is authorized for an agent
|
||||
|
||||
Args:
|
||||
agent_address: Agent address
|
||||
guardian_address: Guardian address to verify
|
||||
|
||||
Returns:
|
||||
True if guardian is authorized
|
||||
"""
|
||||
if agent_address not in self.guardian_registrations:
|
||||
return False
|
||||
|
||||
setup = self.guardian_registrations[agent_address]
|
||||
all_guardians = [setup.primary_guardian] + setup.backup_guardians
|
||||
|
||||
return to_checksum_address(guardian_address) in [
|
||||
to_checksum_address(g) for g in all_guardians
|
||||
]
|
||||
|
||||
def get_guardian_summary(self, agent_address: str) -> Dict:
|
||||
"""
|
||||
Get guardian setup summary for an agent
|
||||
|
||||
Args:
|
||||
agent_address: Agent address
|
||||
|
||||
Returns:
|
||||
Guardian summary
|
||||
"""
|
||||
if agent_address not in self.guardian_registrations:
|
||||
return {"error": "Agent not registered"}
|
||||
|
||||
setup = self.guardian_registrations[agent_address]
|
||||
contract = self.guardian_contracts.get(agent_address)
|
||||
|
||||
return {
|
||||
"agent_address": agent_address,
|
||||
"primary_guardian": setup.primary_guardian,
|
||||
"backup_guardians": setup.backup_guardians,
|
||||
"total_guardians": len(setup.backup_guardians) + 1,
|
||||
"multisig_threshold": setup.multisig_threshold,
|
||||
"emergency_contacts": setup.emergency_contacts,
|
||||
"contract_status": contract.get_spending_status() if contract else None,
|
||||
"pause_functional": contract is not None and len(setup.backup_guardians) > 0
|
||||
}
|
||||
|
||||
|
||||
# Fixed security configurations with proper guardians
|
||||
def get_fixed_conservative_config(agent_address: str, owner_address: str) -> GuardianConfig:
|
||||
"""Get fixed conservative configuration with proper guardians"""
|
||||
return GuardianConfig(
|
||||
limits=SpendingLimit(
|
||||
per_transaction=1000,
|
||||
per_hour=5000,
|
||||
per_day=20000,
|
||||
per_week=100000
|
||||
),
|
||||
time_lock=TimeLockConfig(
|
||||
threshold=5000,
|
||||
delay_hours=24,
|
||||
max_delay_hours=168
|
||||
),
|
||||
guardians=[owner_address], # At least the owner
|
||||
pause_enabled=True,
|
||||
emergency_mode=False
|
||||
)
|
||||
|
||||
|
||||
def get_fixed_aggressive_config(agent_address: str, owner_address: str) -> GuardianConfig:
|
||||
"""Get fixed aggressive configuration with proper guardians"""
|
||||
return GuardianConfig(
|
||||
limits=SpendingLimit(
|
||||
per_transaction=5000,
|
||||
per_hour=25000,
|
||||
per_day=100000,
|
||||
per_week=500000
|
||||
),
|
||||
time_lock=TimeLockConfig(
|
||||
threshold=20000,
|
||||
delay_hours=12,
|
||||
max_delay_hours=72
|
||||
),
|
||||
guardians=[owner_address], # At least the owner
|
||||
pause_enabled=True,
|
||||
emergency_mode=False
|
||||
)
|
||||
|
||||
|
||||
def get_fixed_high_security_config(agent_address: str, owner_address: str) -> GuardianConfig:
|
||||
"""Get fixed high security configuration with proper guardians"""
|
||||
return GuardianConfig(
|
||||
limits=SpendingLimit(
|
||||
per_transaction=500,
|
||||
per_hour=2000,
|
||||
per_day=8000,
|
||||
per_week=40000
|
||||
),
|
||||
time_lock=TimeLockConfig(
|
||||
threshold=2000,
|
||||
delay_hours=48,
|
||||
max_delay_hours=168
|
||||
),
|
||||
guardians=[owner_address], # At least the owner
|
||||
pause_enabled=True,
|
||||
emergency_mode=False
|
||||
)
|
||||
|
||||
|
||||
# Global secure guardian manager
|
||||
secure_guardian_manager = SecureGuardianManager()
|
||||
|
||||
|
||||
# Convenience function for secure agent registration
|
||||
def register_agent_with_guardians(
|
||||
agent_address: str,
|
||||
owner_address: str,
|
||||
security_level: str = "conservative",
|
||||
custom_guardians: Optional[List[str]] = None
|
||||
) -> Dict:
|
||||
"""
|
||||
Register an agent with proper guardian configuration
|
||||
|
||||
Args:
|
||||
agent_address: Agent wallet address
|
||||
owner_address: Owner address
|
||||
security_level: Security level
|
||||
custom_guardians: Optional custom guardians
|
||||
|
||||
Returns:
|
||||
Registration result
|
||||
"""
|
||||
try:
|
||||
# Create secure guardian contract
|
||||
contract = secure_guardian_manager.create_secure_guardian_contract(
|
||||
agent_address=agent_address,
|
||||
security_level=security_level,
|
||||
custom_guardians=custom_guardians
|
||||
)
|
||||
|
||||
# Get guardian summary
|
||||
summary = secure_guardian_manager.get_guardian_summary(agent_address)
|
||||
|
||||
return {
|
||||
"status": "registered",
|
||||
"agent_address": agent_address,
|
||||
"security_level": security_level,
|
||||
"guardian_count": summary["total_guardians"],
|
||||
"multisig_threshold": summary["multisig_threshold"],
|
||||
"pause_functional": summary["pause_functional"],
|
||||
"registered_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "error",
|
||||
"reason": f"Registration failed: {str(e)}"
|
||||
}
|
||||
@@ -0,0 +1,477 @@
|
||||
"""
|
||||
AITBC Guardian Contract - Spending Limit Protection for Agent Wallets
|
||||
|
||||
This contract implements a spending limit guardian that protects autonomous agent
|
||||
wallets from unlimited spending in case of compromise. It provides:
|
||||
- Per-transaction spending limits
|
||||
- Per-period (daily/hourly) spending caps
|
||||
- Time-lock for large withdrawals
|
||||
- Emergency pause functionality
|
||||
- Multi-signature recovery for critical operations
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
import json
|
||||
from eth_account import Account
|
||||
from eth_utils import to_checksum_address, keccak
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpendingLimit:
|
||||
"""Spending limit configuration"""
|
||||
per_transaction: int # Maximum per transaction
|
||||
per_hour: int # Maximum per hour
|
||||
per_day: int # Maximum per day
|
||||
per_week: int # Maximum per week
|
||||
|
||||
@dataclass
|
||||
class TimeLockConfig:
|
||||
"""Time lock configuration for large withdrawals"""
|
||||
threshold: int # Amount that triggers time lock
|
||||
delay_hours: int # Delay period in hours
|
||||
max_delay_hours: int # Maximum delay period
|
||||
|
||||
|
||||
@dataclass
|
||||
class GuardianConfig:
|
||||
"""Complete guardian configuration"""
|
||||
limits: SpendingLimit
|
||||
time_lock: TimeLockConfig
|
||||
guardians: List[str] # Guardian addresses for recovery
|
||||
pause_enabled: bool = True
|
||||
emergency_mode: bool = False
|
||||
|
||||
|
||||
class GuardianContract:
|
||||
"""
|
||||
Guardian contract implementation for agent wallet protection
|
||||
"""
|
||||
|
||||
def __init__(self, agent_address: str, config: GuardianConfig):
|
||||
self.agent_address = to_checksum_address(agent_address)
|
||||
self.config = config
|
||||
self.spending_history: List[Dict] = []
|
||||
self.pending_operations: Dict[str, Dict] = {}
|
||||
self.paused = False
|
||||
self.emergency_mode = False
|
||||
|
||||
# Contract state
|
||||
self.nonce = 0
|
||||
self.guardian_approvals: Dict[str, bool] = {}
|
||||
|
||||
def _get_period_key(self, timestamp: datetime, period: str) -> str:
|
||||
"""Generate period key for spending tracking"""
|
||||
if period == "hour":
|
||||
return timestamp.strftime("%Y-%m-%d-%H")
|
||||
elif period == "day":
|
||||
return timestamp.strftime("%Y-%m-%d")
|
||||
elif period == "week":
|
||||
# Get week number (Monday as first day)
|
||||
week_num = timestamp.isocalendar()[1]
|
||||
return f"{timestamp.year}-W{week_num:02d}"
|
||||
else:
|
||||
raise ValueError(f"Invalid period: {period}")
|
||||
|
||||
def _get_spent_in_period(self, period: str, timestamp: datetime = None) -> int:
|
||||
"""Calculate total spent in given period"""
|
||||
if timestamp is None:
|
||||
timestamp = datetime.utcnow()
|
||||
|
||||
period_key = self._get_period_key(timestamp, period)
|
||||
|
||||
total = 0
|
||||
for record in self.spending_history:
|
||||
record_time = datetime.fromisoformat(record["timestamp"])
|
||||
record_period = self._get_period_key(record_time, period)
|
||||
|
||||
if record_period == period_key and record["status"] == "completed":
|
||||
total += record["amount"]
|
||||
|
||||
return total
|
||||
|
||||
def _check_spending_limits(self, amount: int, timestamp: datetime = None) -> Tuple[bool, str]:
|
||||
"""Check if amount exceeds spending limits"""
|
||||
if timestamp is None:
|
||||
timestamp = datetime.utcnow()
|
||||
|
||||
# Check per-transaction limit
|
||||
if amount > self.config.limits.per_transaction:
|
||||
return False, f"Amount {amount} exceeds per-transaction limit {self.config.limits.per_transaction}"
|
||||
|
||||
# Check per-hour limit
|
||||
spent_hour = self._get_spent_in_period("hour", timestamp)
|
||||
if spent_hour + amount > self.config.limits.per_hour:
|
||||
return False, f"Hourly spending {spent_hour + amount} would exceed limit {self.config.limits.per_hour}"
|
||||
|
||||
# Check per-day limit
|
||||
spent_day = self._get_spent_in_period("day", timestamp)
|
||||
if spent_day + amount > self.config.limits.per_day:
|
||||
return False, f"Daily spending {spent_day + amount} would exceed limit {self.config.limits.per_day}"
|
||||
|
||||
# Check per-week limit
|
||||
spent_week = self._get_spent_in_period("week", timestamp)
|
||||
if spent_week + amount > self.config.limits.per_week:
|
||||
return False, f"Weekly spending {spent_week + amount} would exceed limit {self.config.limits.per_week}"
|
||||
|
||||
return True, "Spending limits check passed"
|
||||
|
||||
def _requires_time_lock(self, amount: int) -> bool:
|
||||
"""Check if amount requires time lock"""
|
||||
return amount >= self.config.time_lock.threshold
|
||||
|
||||
def _create_operation_hash(self, operation: Dict) -> str:
|
||||
"""Create hash for operation identification"""
|
||||
operation_str = json.dumps(operation, sort_keys=True)
|
||||
return keccak(operation_str.encode()).hex()
|
||||
|
||||
def initiate_transaction(self, to_address: str, amount: int, data: str = "") -> Dict:
|
||||
"""
|
||||
Initiate a transaction with guardian protection
|
||||
|
||||
Args:
|
||||
to_address: Recipient address
|
||||
amount: Amount to transfer
|
||||
data: Transaction data (optional)
|
||||
|
||||
Returns:
|
||||
Operation result with status and details
|
||||
"""
|
||||
# Check if paused
|
||||
if self.paused:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"reason": "Guardian contract is paused",
|
||||
"operation_id": None
|
||||
}
|
||||
|
||||
# Check emergency mode
|
||||
if self.emergency_mode:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"reason": "Emergency mode activated",
|
||||
"operation_id": None
|
||||
}
|
||||
|
||||
# Validate address
|
||||
try:
|
||||
to_address = to_checksum_address(to_address)
|
||||
except:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"reason": "Invalid recipient address",
|
||||
"operation_id": None
|
||||
}
|
||||
|
||||
# Check spending limits
|
||||
limits_ok, limits_reason = self._check_spending_limits(amount)
|
||||
if not limits_ok:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"reason": limits_reason,
|
||||
"operation_id": None
|
||||
}
|
||||
|
||||
# Create operation
|
||||
operation = {
|
||||
"type": "transaction",
|
||||
"to": to_address,
|
||||
"amount": amount,
|
||||
"data": data,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"nonce": self.nonce,
|
||||
"status": "pending"
|
||||
}
|
||||
|
||||
operation_id = self._create_operation_hash(operation)
|
||||
operation["operation_id"] = operation_id
|
||||
|
||||
# Check if time lock is required
|
||||
if self._requires_time_lock(amount):
|
||||
unlock_time = datetime.utcnow() + timedelta(hours=self.config.time_lock.delay_hours)
|
||||
operation["unlock_time"] = unlock_time.isoformat()
|
||||
operation["status"] = "time_locked"
|
||||
|
||||
# Store for later execution
|
||||
self.pending_operations[operation_id] = operation
|
||||
|
||||
return {
|
||||
"status": "time_locked",
|
||||
"operation_id": operation_id,
|
||||
"unlock_time": unlock_time.isoformat(),
|
||||
"delay_hours": self.config.time_lock.delay_hours,
|
||||
"message": f"Transaction requires {self.config.time_lock.delay_hours}h time lock"
|
||||
}
|
||||
|
||||
# Immediate execution for smaller amounts
|
||||
self.pending_operations[operation_id] = operation
|
||||
|
||||
return {
|
||||
"status": "approved",
|
||||
"operation_id": operation_id,
|
||||
"message": "Transaction approved for execution"
|
||||
}
|
||||
|
||||
def execute_transaction(self, operation_id: str, signature: str) -> Dict:
|
||||
"""
|
||||
Execute a previously approved transaction
|
||||
|
||||
Args:
|
||||
operation_id: Operation ID from initiate_transaction
|
||||
signature: Transaction signature from agent
|
||||
|
||||
Returns:
|
||||
Execution result
|
||||
"""
|
||||
if operation_id not in self.pending_operations:
|
||||
return {
|
||||
"status": "error",
|
||||
"reason": "Operation not found"
|
||||
}
|
||||
|
||||
operation = self.pending_operations[operation_id]
|
||||
|
||||
# Check if operation is time locked
|
||||
if operation["status"] == "time_locked":
|
||||
unlock_time = datetime.fromisoformat(operation["unlock_time"])
|
||||
if datetime.utcnow() < unlock_time:
|
||||
return {
|
||||
"status": "error",
|
||||
"reason": f"Operation locked until {unlock_time.isoformat()}"
|
||||
}
|
||||
|
||||
operation["status"] = "ready"
|
||||
|
||||
# Verify signature (simplified - in production, use proper verification)
|
||||
try:
|
||||
# In production, verify the signature matches the agent address
|
||||
# For now, we'll assume signature is valid
|
||||
pass
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "error",
|
||||
"reason": f"Invalid signature: {str(e)}"
|
||||
}
|
||||
|
||||
# Record the transaction
|
||||
record = {
|
||||
"operation_id": operation_id,
|
||||
"to": operation["to"],
|
||||
"amount": operation["amount"],
|
||||
"data": operation.get("data", ""),
|
||||
"timestamp": operation["timestamp"],
|
||||
"executed_at": datetime.utcnow().isoformat(),
|
||||
"status": "completed",
|
||||
"nonce": operation["nonce"]
|
||||
}
|
||||
|
||||
self.spending_history.append(record)
|
||||
self.nonce += 1
|
||||
|
||||
# Remove from pending
|
||||
del self.pending_operations[operation_id]
|
||||
|
||||
return {
|
||||
"status": "executed",
|
||||
"operation_id": operation_id,
|
||||
"transaction_hash": f"0x{keccak(f'{operation_id}{signature}'.encode()).hex()}",
|
||||
"executed_at": record["executed_at"]
|
||||
}
|
||||
|
||||
def emergency_pause(self, guardian_address: str) -> Dict:
|
||||
"""
|
||||
Emergency pause function (guardian only)
|
||||
|
||||
Args:
|
||||
guardian_address: Address of guardian initiating pause
|
||||
|
||||
Returns:
|
||||
Pause result
|
||||
"""
|
||||
if guardian_address not in self.config.guardians:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"reason": "Not authorized: guardian address not recognized"
|
||||
}
|
||||
|
||||
self.paused = True
|
||||
self.emergency_mode = True
|
||||
|
||||
return {
|
||||
"status": "paused",
|
||||
"paused_at": datetime.utcnow().isoformat(),
|
||||
"guardian": guardian_address,
|
||||
"message": "Emergency pause activated - all operations halted"
|
||||
}
|
||||
|
||||
def emergency_unpause(self, guardian_signatures: List[str]) -> Dict:
|
||||
"""
|
||||
Emergency unpause function (requires multiple guardian signatures)
|
||||
|
||||
Args:
|
||||
guardian_signatures: Signatures from required guardians
|
||||
|
||||
Returns:
|
||||
Unpause result
|
||||
"""
|
||||
# In production, verify all guardian signatures
|
||||
required_signatures = len(self.config.guardians)
|
||||
if len(guardian_signatures) < required_signatures:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"reason": f"Requires {required_signatures} guardian signatures, got {len(guardian_signatures)}"
|
||||
}
|
||||
|
||||
# Verify signatures (simplified)
|
||||
# In production, verify each signature matches a guardian address
|
||||
|
||||
self.paused = False
|
||||
self.emergency_mode = False
|
||||
|
||||
return {
|
||||
"status": "unpaused",
|
||||
"unpaused_at": datetime.utcnow().isoformat(),
|
||||
"message": "Emergency pause lifted - operations resumed"
|
||||
}
|
||||
|
||||
def update_limits(self, new_limits: SpendingLimit, guardian_address: str) -> Dict:
|
||||
"""
|
||||
Update spending limits (guardian only)
|
||||
|
||||
Args:
|
||||
new_limits: New spending limits
|
||||
guardian_address: Address of guardian making the change
|
||||
|
||||
Returns:
|
||||
Update result
|
||||
"""
|
||||
if guardian_address not in self.config.guardians:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"reason": "Not authorized: guardian address not recognized"
|
||||
}
|
||||
|
||||
old_limits = self.config.limits
|
||||
self.config.limits = new_limits
|
||||
|
||||
return {
|
||||
"status": "updated",
|
||||
"old_limits": old_limits,
|
||||
"new_limits": new_limits,
|
||||
"updated_at": datetime.utcnow().isoformat(),
|
||||
"guardian": guardian_address
|
||||
}
|
||||
|
||||
def get_spending_status(self) -> Dict:
|
||||
"""Get current spending status and limits"""
|
||||
now = datetime.utcnow()
|
||||
|
||||
return {
|
||||
"agent_address": self.agent_address,
|
||||
"current_limits": self.config.limits,
|
||||
"spent": {
|
||||
"current_hour": self._get_spent_in_period("hour", now),
|
||||
"current_day": self._get_spent_in_period("day", now),
|
||||
"current_week": self._get_spent_in_period("week", now)
|
||||
},
|
||||
"remaining": {
|
||||
"current_hour": self.config.limits.per_hour - self._get_spent_in_period("hour", now),
|
||||
"current_day": self.config.limits.per_day - self._get_spent_in_period("day", now),
|
||||
"current_week": self.config.limits.per_week - self._get_spent_in_period("week", now)
|
||||
},
|
||||
"pending_operations": len(self.pending_operations),
|
||||
"paused": self.paused,
|
||||
"emergency_mode": self.emergency_mode,
|
||||
"nonce": self.nonce
|
||||
}
|
||||
|
||||
def get_operation_history(self, limit: int = 50) -> List[Dict]:
|
||||
"""Get operation history"""
|
||||
return sorted(self.spending_history, key=lambda x: x["timestamp"], reverse=True)[:limit]
|
||||
|
||||
def get_pending_operations(self) -> List[Dict]:
|
||||
"""Get all pending operations"""
|
||||
return list(self.pending_operations.values())
|
||||
|
||||
|
||||
# Factory function for creating guardian contracts
|
||||
def create_guardian_contract(
|
||||
agent_address: str,
|
||||
per_transaction: int = 1000,
|
||||
per_hour: int = 5000,
|
||||
per_day: int = 20000,
|
||||
per_week: int = 100000,
|
||||
time_lock_threshold: int = 10000,
|
||||
time_lock_delay: int = 24,
|
||||
guardians: List[str] = None
|
||||
) -> GuardianContract:
|
||||
"""
|
||||
Create a guardian contract with default security parameters
|
||||
|
||||
Args:
|
||||
agent_address: The agent wallet address to protect
|
||||
per_transaction: Maximum amount per transaction
|
||||
per_hour: Maximum amount per hour
|
||||
per_day: Maximum amount per day
|
||||
per_week: Maximum amount per week
|
||||
time_lock_threshold: Amount that triggers time lock
|
||||
time_lock_delay: Time lock delay in hours
|
||||
guardians: List of guardian addresses
|
||||
|
||||
Returns:
|
||||
Configured GuardianContract instance
|
||||
"""
|
||||
if guardians is None:
|
||||
# Default to using the agent address as guardian (should be overridden)
|
||||
guardians = [agent_address]
|
||||
|
||||
limits = SpendingLimit(
|
||||
per_transaction=per_transaction,
|
||||
per_hour=per_hour,
|
||||
per_day=per_day,
|
||||
per_week=per_week
|
||||
)
|
||||
|
||||
time_lock = TimeLockConfig(
|
||||
threshold=time_lock_threshold,
|
||||
delay_hours=time_lock_delay,
|
||||
max_delay_hours=168 # 1 week max
|
||||
)
|
||||
|
||||
config = GuardianConfig(
|
||||
limits=limits,
|
||||
time_lock=time_lock,
|
||||
guardians=[to_checksum_address(g) for g in guardians]
|
||||
)
|
||||
|
||||
return GuardianContract(agent_address, config)
|
||||
|
||||
|
||||
# Example usage and security configurations
|
||||
CONSERVATIVE_CONFIG = {
|
||||
"per_transaction": 100, # $100 per transaction
|
||||
"per_hour": 500, # $500 per hour
|
||||
"per_day": 2000, # $2,000 per day
|
||||
"per_week": 10000, # $10,000 per week
|
||||
"time_lock_threshold": 1000, # Time lock over $1,000
|
||||
"time_lock_delay": 24 # 24 hour delay
|
||||
}
|
||||
|
||||
AGGRESSIVE_CONFIG = {
|
||||
"per_transaction": 1000, # $1,000 per transaction
|
||||
"per_hour": 5000, # $5,000 per hour
|
||||
"per_day": 20000, # $20,000 per day
|
||||
"per_week": 100000, # $100,000 per week
|
||||
"time_lock_threshold": 10000, # Time lock over $10,000
|
||||
"time_lock_delay": 12 # 12 hour delay
|
||||
}
|
||||
|
||||
HIGH_SECURITY_CONFIG = {
|
||||
"per_transaction": 50, # $50 per transaction
|
||||
"per_hour": 200, # $200 per hour
|
||||
"per_day": 1000, # $1,000 per day
|
||||
"per_week": 5000, # $5,000 per week
|
||||
"time_lock_threshold": 500, # Time lock over $500
|
||||
"time_lock_delay": 48 # 48 hour delay
|
||||
}
|
||||
@@ -0,0 +1,470 @@
|
||||
"""
|
||||
Persistent Spending Tracker - Database-Backed Security
|
||||
Fixes the critical vulnerability where spending limits were lost on restart
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy import create_engine, Column, String, Integer, Float, DateTime, Index
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from eth_utils import to_checksum_address
|
||||
import json
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class SpendingRecord(Base):
|
||||
"""Database model for spending tracking"""
|
||||
__tablename__ = "spending_records"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
agent_address = Column(String, index=True)
|
||||
period_type = Column(String, index=True) # hour, day, week
|
||||
period_key = Column(String, index=True)
|
||||
amount = Column(Float)
|
||||
transaction_hash = Column(String)
|
||||
timestamp = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
# Composite indexes for performance
|
||||
__table_args__ = (
|
||||
Index('idx_agent_period', 'agent_address', 'period_type', 'period_key'),
|
||||
Index('idx_timestamp', 'timestamp'),
|
||||
)
|
||||
|
||||
|
||||
class SpendingLimit(Base):
|
||||
"""Database model for spending limits"""
|
||||
__tablename__ = "spending_limits"
|
||||
|
||||
agent_address = Column(String, primary_key=True)
|
||||
per_transaction = Column(Float)
|
||||
per_hour = Column(Float)
|
||||
per_day = Column(Float)
|
||||
per_week = Column(Float)
|
||||
time_lock_threshold = Column(Float)
|
||||
time_lock_delay_hours = Column(Integer)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_by = Column(String) # Guardian who updated
|
||||
|
||||
|
||||
class GuardianAuthorization(Base):
|
||||
"""Database model for guardian authorizations"""
|
||||
__tablename__ = "guardian_authorizations"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
agent_address = Column(String, index=True)
|
||||
guardian_address = Column(String, index=True)
|
||||
is_active = Column(Boolean, default=True)
|
||||
added_at = Column(DateTime, default=datetime.utcnow)
|
||||
added_by = Column(String)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpendingCheckResult:
|
||||
"""Result of spending limit check"""
|
||||
allowed: bool
|
||||
reason: str
|
||||
current_spent: Dict[str, float]
|
||||
remaining: Dict[str, float]
|
||||
requires_time_lock: bool
|
||||
time_lock_until: Optional[datetime] = None
|
||||
|
||||
|
||||
class PersistentSpendingTracker:
|
||||
"""
|
||||
Database-backed spending tracker that survives restarts
|
||||
"""
|
||||
|
||||
def __init__(self, database_url: str = "sqlite:///spending_tracker.db"):
|
||||
self.engine = create_engine(database_url)
|
||||
Base.metadata.create_all(self.engine)
|
||||
self.SessionLocal = sessionmaker(bind=self.engine)
|
||||
|
||||
def get_session(self) -> Session:
|
||||
"""Get database session"""
|
||||
return self.SessionLocal()
|
||||
|
||||
def _get_period_key(self, timestamp: datetime, period: str) -> str:
|
||||
"""Generate period key for spending tracking"""
|
||||
if period == "hour":
|
||||
return timestamp.strftime("%Y-%m-%d-%H")
|
||||
elif period == "day":
|
||||
return timestamp.strftime("%Y-%m-%d")
|
||||
elif period == "week":
|
||||
# Get week number (Monday as first day)
|
||||
week_num = timestamp.isocalendar()[1]
|
||||
return f"{timestamp.year}-W{week_num:02d}"
|
||||
else:
|
||||
raise ValueError(f"Invalid period: {period}")
|
||||
|
||||
def get_spent_in_period(self, agent_address: str, period: str, timestamp: datetime = None) -> float:
|
||||
"""
|
||||
Get total spent in given period from database
|
||||
|
||||
Args:
|
||||
agent_address: Agent wallet address
|
||||
period: Period type (hour, day, week)
|
||||
timestamp: Timestamp to check (default: now)
|
||||
|
||||
Returns:
|
||||
Total amount spent in period
|
||||
"""
|
||||
if timestamp is None:
|
||||
timestamp = datetime.utcnow()
|
||||
|
||||
period_key = self._get_period_key(timestamp, period)
|
||||
agent_address = to_checksum_address(agent_address)
|
||||
|
||||
with self.get_session() as session:
|
||||
total = session.query(SpendingRecord).filter(
|
||||
SpendingRecord.agent_address == agent_address,
|
||||
SpendingRecord.period_type == period,
|
||||
SpendingRecord.period_key == period_key
|
||||
).with_entities(SpendingRecord.amount).all()
|
||||
|
||||
return sum(record.amount for record in total)
|
||||
|
||||
def record_spending(self, agent_address: str, amount: float, transaction_hash: str, timestamp: datetime = None) -> bool:
|
||||
"""
|
||||
Record a spending transaction in the database
|
||||
|
||||
Args:
|
||||
agent_address: Agent wallet address
|
||||
amount: Amount spent
|
||||
transaction_hash: Transaction hash
|
||||
timestamp: Transaction timestamp (default: now)
|
||||
|
||||
Returns:
|
||||
True if recorded successfully
|
||||
"""
|
||||
if timestamp is None:
|
||||
timestamp = datetime.utcnow()
|
||||
|
||||
agent_address = to_checksum_address(agent_address)
|
||||
|
||||
try:
|
||||
with self.get_session() as session:
|
||||
# Record for all periods
|
||||
periods = ["hour", "day", "week"]
|
||||
|
||||
for period in periods:
|
||||
period_key = self._get_period_key(timestamp, period)
|
||||
|
||||
record = SpendingRecord(
|
||||
id=f"{transaction_hash}_{period}",
|
||||
agent_address=agent_address,
|
||||
period_type=period,
|
||||
period_key=period_key,
|
||||
amount=amount,
|
||||
transaction_hash=transaction_hash,
|
||||
timestamp=timestamp
|
||||
)
|
||||
|
||||
session.add(record)
|
||||
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to record spending: {e}")
|
||||
return False
|
||||
|
||||
def check_spending_limits(self, agent_address: str, amount: float, timestamp: datetime = None) -> SpendingCheckResult:
|
||||
"""
|
||||
Check if amount exceeds spending limits using persistent data
|
||||
|
||||
Args:
|
||||
agent_address: Agent wallet address
|
||||
amount: Amount to check
|
||||
timestamp: Timestamp for check (default: now)
|
||||
|
||||
Returns:
|
||||
Spending check result
|
||||
"""
|
||||
if timestamp is None:
|
||||
timestamp = datetime.utcnow()
|
||||
|
||||
agent_address = to_checksum_address(agent_address)
|
||||
|
||||
# Get spending limits from database
|
||||
with self.get_session() as session:
|
||||
limits = session.query(SpendingLimit).filter(
|
||||
SpendingLimit.agent_address == agent_address
|
||||
).first()
|
||||
|
||||
if not limits:
|
||||
# Default limits if not set
|
||||
limits = SpendingLimit(
|
||||
agent_address=agent_address,
|
||||
per_transaction=1000.0,
|
||||
per_hour=5000.0,
|
||||
per_day=20000.0,
|
||||
per_week=100000.0,
|
||||
time_lock_threshold=5000.0,
|
||||
time_lock_delay_hours=24
|
||||
)
|
||||
session.add(limits)
|
||||
session.commit()
|
||||
|
||||
# Check each limit
|
||||
current_spent = {}
|
||||
remaining = {}
|
||||
|
||||
# Per-transaction limit
|
||||
if amount > limits.per_transaction:
|
||||
return SpendingCheckResult(
|
||||
allowed=False,
|
||||
reason=f"Amount {amount} exceeds per-transaction limit {limits.per_transaction}",
|
||||
current_spent=current_spent,
|
||||
remaining=remaining,
|
||||
requires_time_lock=False
|
||||
)
|
||||
|
||||
# Per-hour limit
|
||||
spent_hour = self.get_spent_in_period(agent_address, "hour", timestamp)
|
||||
current_spent["hour"] = spent_hour
|
||||
remaining["hour"] = limits.per_hour - spent_hour
|
||||
|
||||
if spent_hour + amount > limits.per_hour:
|
||||
return SpendingCheckResult(
|
||||
allowed=False,
|
||||
reason=f"Hourly spending {spent_hour + amount} would exceed limit {limits.per_hour}",
|
||||
current_spent=current_spent,
|
||||
remaining=remaining,
|
||||
requires_time_lock=False
|
||||
)
|
||||
|
||||
# Per-day limit
|
||||
spent_day = self.get_spent_in_period(agent_address, "day", timestamp)
|
||||
current_spent["day"] = spent_day
|
||||
remaining["day"] = limits.per_day - spent_day
|
||||
|
||||
if spent_day + amount > limits.per_day:
|
||||
return SpendingCheckResult(
|
||||
allowed=False,
|
||||
reason=f"Daily spending {spent_day + amount} would exceed limit {limits.per_day}",
|
||||
current_spent=current_spent,
|
||||
remaining=remaining,
|
||||
requires_time_lock=False
|
||||
)
|
||||
|
||||
# Per-week limit
|
||||
spent_week = self.get_spent_in_period(agent_address, "week", timestamp)
|
||||
current_spent["week"] = spent_week
|
||||
remaining["week"] = limits.per_week - spent_week
|
||||
|
||||
if spent_week + amount > limits.per_week:
|
||||
return SpendingCheckResult(
|
||||
allowed=False,
|
||||
reason=f"Weekly spending {spent_week + amount} would exceed limit {limits.per_week}",
|
||||
current_spent=current_spent,
|
||||
remaining=remaining,
|
||||
requires_time_lock=False
|
||||
)
|
||||
|
||||
# Check time lock requirement
|
||||
requires_time_lock = amount >= limits.time_lock_threshold
|
||||
time_lock_until = None
|
||||
|
||||
if requires_time_lock:
|
||||
time_lock_until = timestamp + timedelta(hours=limits.time_lock_delay_hours)
|
||||
|
||||
return SpendingCheckResult(
|
||||
allowed=True,
|
||||
reason="Spending limits check passed",
|
||||
current_spent=current_spent,
|
||||
remaining=remaining,
|
||||
requires_time_lock=requires_time_lock,
|
||||
time_lock_until=time_lock_until
|
||||
)
|
||||
|
||||
def update_spending_limits(self, agent_address: str, new_limits: Dict, guardian_address: str) -> bool:
|
||||
"""
|
||||
Update spending limits for an agent
|
||||
|
||||
Args:
|
||||
agent_address: Agent wallet address
|
||||
new_limits: New spending limits
|
||||
guardian_address: Guardian making the change
|
||||
|
||||
Returns:
|
||||
True if updated successfully
|
||||
"""
|
||||
agent_address = to_checksum_address(agent_address)
|
||||
guardian_address = to_checksum_address(guardian_address)
|
||||
|
||||
# Verify guardian authorization
|
||||
if not self.is_guardian_authorized(agent_address, guardian_address):
|
||||
return False
|
||||
|
||||
try:
|
||||
with self.get_session() as session:
|
||||
limits = session.query(SpendingLimit).filter(
|
||||
SpendingLimit.agent_address == agent_address
|
||||
).first()
|
||||
|
||||
if limits:
|
||||
limits.per_transaction = new_limits.get("per_transaction", limits.per_transaction)
|
||||
limits.per_hour = new_limits.get("per_hour", limits.per_hour)
|
||||
limits.per_day = new_limits.get("per_day", limits.per_day)
|
||||
limits.per_week = new_limits.get("per_week", limits.per_week)
|
||||
limits.time_lock_threshold = new_limits.get("time_lock_threshold", limits.time_lock_threshold)
|
||||
limits.time_lock_delay_hours = new_limits.get("time_lock_delay_hours", limits.time_lock_delay_hours)
|
||||
limits.updated_at = datetime.utcnow()
|
||||
limits.updated_by = guardian_address
|
||||
else:
|
||||
limits = SpendingLimit(
|
||||
agent_address=agent_address,
|
||||
per_transaction=new_limits.get("per_transaction", 1000.0),
|
||||
per_hour=new_limits.get("per_hour", 5000.0),
|
||||
per_day=new_limits.get("per_day", 20000.0),
|
||||
per_week=new_limits.get("per_week", 100000.0),
|
||||
time_lock_threshold=new_limits.get("time_lock_threshold", 5000.0),
|
||||
time_lock_delay_hours=new_limits.get("time_lock_delay_hours", 24),
|
||||
updated_at=datetime.utcnow(),
|
||||
updated_by=guardian_address
|
||||
)
|
||||
session.add(limits)
|
||||
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to update spending limits: {e}")
|
||||
return False
|
||||
|
||||
def add_guardian(self, agent_address: str, guardian_address: str, added_by: str) -> bool:
|
||||
"""
|
||||
Add a guardian for an agent
|
||||
|
||||
Args:
|
||||
agent_address: Agent wallet address
|
||||
guardian_address: Guardian address
|
||||
added_by: Who added this guardian
|
||||
|
||||
Returns:
|
||||
True if added successfully
|
||||
"""
|
||||
agent_address = to_checksum_address(agent_address)
|
||||
guardian_address = to_checksum_address(guardian_address)
|
||||
added_by = to_checksum_address(added_by)
|
||||
|
||||
try:
|
||||
with self.get_session() as session:
|
||||
# Check if already exists
|
||||
existing = session.query(GuardianAuthorization).filter(
|
||||
GuardianAuthorization.agent_address == agent_address,
|
||||
GuardianAuthorization.guardian_address == guardian_address
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
existing.is_active = True
|
||||
existing.added_at = datetime.utcnow()
|
||||
existing.added_by = added_by
|
||||
else:
|
||||
auth = GuardianAuthorization(
|
||||
id=f"{agent_address}_{guardian_address}",
|
||||
agent_address=agent_address,
|
||||
guardian_address=guardian_address,
|
||||
is_active=True,
|
||||
added_at=datetime.utcnow(),
|
||||
added_by=added_by
|
||||
)
|
||||
session.add(auth)
|
||||
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to add guardian: {e}")
|
||||
return False
|
||||
|
||||
def is_guardian_authorized(self, agent_address: str, guardian_address: str) -> bool:
|
||||
"""
|
||||
Check if a guardian is authorized for an agent
|
||||
|
||||
Args:
|
||||
agent_address: Agent wallet address
|
||||
guardian_address: Guardian address
|
||||
|
||||
Returns:
|
||||
True if authorized
|
||||
"""
|
||||
agent_address = to_checksum_address(agent_address)
|
||||
guardian_address = to_checksum_address(guardian_address)
|
||||
|
||||
with self.get_session() as session:
|
||||
auth = session.query(GuardianAuthorization).filter(
|
||||
GuardianAuthorization.agent_address == agent_address,
|
||||
GuardianAuthorization.guardian_address == guardian_address,
|
||||
GuardianAuthorization.is_active == True
|
||||
).first()
|
||||
|
||||
return auth is not None
|
||||
|
||||
def get_spending_summary(self, agent_address: str) -> Dict:
|
||||
"""
|
||||
Get comprehensive spending summary for an agent
|
||||
|
||||
Args:
|
||||
agent_address: Agent wallet address
|
||||
|
||||
Returns:
|
||||
Spending summary
|
||||
"""
|
||||
agent_address = to_checksum_address(agent_address)
|
||||
now = datetime.utcnow()
|
||||
|
||||
# Get current spending
|
||||
current_spent = {
|
||||
"hour": self.get_spent_in_period(agent_address, "hour", now),
|
||||
"day": self.get_spent_in_period(agent_address, "day", now),
|
||||
"week": self.get_spent_in_period(agent_address, "week", now)
|
||||
}
|
||||
|
||||
# Get limits
|
||||
with self.get_session() as session:
|
||||
limits = session.query(SpendingLimit).filter(
|
||||
SpendingLimit.agent_address == agent_address
|
||||
).first()
|
||||
|
||||
if not limits:
|
||||
return {"error": "No spending limits set"}
|
||||
|
||||
# Calculate remaining
|
||||
remaining = {
|
||||
"hour": limits.per_hour - current_spent["hour"],
|
||||
"day": limits.per_day - current_spent["day"],
|
||||
"week": limits.per_week - current_spent["week"]
|
||||
}
|
||||
|
||||
# Get authorized guardians
|
||||
with self.get_session() as session:
|
||||
guardians = session.query(GuardianAuthorization).filter(
|
||||
GuardianAuthorization.agent_address == agent_address,
|
||||
GuardianAuthorization.is_active == True
|
||||
).all()
|
||||
|
||||
return {
|
||||
"agent_address": agent_address,
|
||||
"current_spending": current_spent,
|
||||
"remaining_spending": remaining,
|
||||
"limits": {
|
||||
"per_transaction": limits.per_transaction,
|
||||
"per_hour": limits.per_hour,
|
||||
"per_day": limits.per_day,
|
||||
"per_week": limits.per_week
|
||||
},
|
||||
"time_lock": {
|
||||
"threshold": limits.time_lock_threshold,
|
||||
"delay_hours": limits.time_lock_delay_hours
|
||||
},
|
||||
"authorized_guardians": [g.guardian_address for g in guardians],
|
||||
"last_updated": limits.updated_at.isoformat() if limits.updated_at else None
|
||||
}
|
||||
|
||||
|
||||
# Global persistent tracker instance
|
||||
persistent_tracker = PersistentSpendingTracker()
|
||||
@@ -0,0 +1,706 @@
|
||||
"""
|
||||
Multi-Modal WebSocket Fusion Service
|
||||
|
||||
Advanced WebSocket stream architecture for multi-modal fusion with
|
||||
per-stream backpressure handling and GPU provider flow control.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing import Dict, List, Optional, Any, Tuple, Union
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from uuid import uuid4
|
||||
|
||||
from aitbc.logging import get_logger
|
||||
from .websocket_stream_manager import (
|
||||
WebSocketStreamManager, StreamConfig, MessageType,
|
||||
stream_manager, WebSocketStream
|
||||
)
|
||||
from .gpu_multimodal import GPUMultimodalProcessor
|
||||
from .multi_modal_fusion import MultiModalFusionService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class FusionStreamType(Enum):
|
||||
"""Types of fusion streams"""
|
||||
VISUAL = "visual"
|
||||
TEXT = "text"
|
||||
AUDIO = "audio"
|
||||
SENSOR = "sensor"
|
||||
CONTROL = "control"
|
||||
METRICS = "metrics"
|
||||
|
||||
|
||||
class GPUProviderStatus(Enum):
|
||||
"""GPU provider status"""
|
||||
AVAILABLE = "available"
|
||||
BUSY = "busy"
|
||||
SLOW = "slow"
|
||||
OVERLOADED = "overloaded"
|
||||
OFFLINE = "offline"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FusionStreamConfig:
|
||||
"""Configuration for fusion streams"""
|
||||
stream_type: FusionStreamType
|
||||
max_queue_size: int = 500
|
||||
gpu_timeout: float = 2.0
|
||||
fusion_timeout: float = 5.0
|
||||
batch_size: int = 8
|
||||
enable_gpu_acceleration: bool = True
|
||||
priority: int = 1 # Higher number = higher priority
|
||||
|
||||
def to_stream_config(self) -> StreamConfig:
|
||||
"""Convert to WebSocket stream config"""
|
||||
return StreamConfig(
|
||||
max_queue_size=self.max_queue_size,
|
||||
send_timeout=self.fusion_timeout,
|
||||
heartbeat_interval=30.0,
|
||||
slow_consumer_threshold=0.5,
|
||||
backpressure_threshold=0.7,
|
||||
drop_bulk_threshold=0.85,
|
||||
enable_compression=True,
|
||||
priority_send=True
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FusionData:
|
||||
"""Multi-modal fusion data"""
|
||||
stream_id: str
|
||||
stream_type: FusionStreamType
|
||||
data: Any
|
||||
timestamp: float
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
requires_gpu: bool = False
|
||||
processing_priority: int = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPUProviderMetrics:
|
||||
"""GPU provider performance metrics"""
|
||||
provider_id: str
|
||||
status: GPUProviderStatus
|
||||
avg_processing_time: float
|
||||
queue_size: int
|
||||
gpu_utilization: float
|
||||
memory_usage: float
|
||||
error_rate: float
|
||||
last_update: float
|
||||
|
||||
|
||||
class GPUProviderFlowControl:
|
||||
"""Flow control for GPU providers"""
|
||||
|
||||
def __init__(self, provider_id: str):
|
||||
self.provider_id = provider_id
|
||||
self.metrics = GPUProviderMetrics(
|
||||
provider_id=provider_id,
|
||||
status=GPUProviderStatus.AVAILABLE,
|
||||
avg_processing_time=0.0,
|
||||
queue_size=0,
|
||||
gpu_utilization=0.0,
|
||||
memory_usage=0.0,
|
||||
error_rate=0.0,
|
||||
last_update=time.time()
|
||||
)
|
||||
|
||||
# Flow control queues
|
||||
self.input_queue = asyncio.Queue(maxsize=100)
|
||||
self.output_queue = asyncio.Queue(maxsize=100)
|
||||
self.control_queue = asyncio.Queue(maxsize=50)
|
||||
|
||||
# Flow control parameters
|
||||
self.max_concurrent_requests = 4
|
||||
self.current_requests = 0
|
||||
self.slow_threshold = 2.0 # seconds
|
||||
self.overload_threshold = 0.8 # queue fill ratio
|
||||
|
||||
# Performance tracking
|
||||
self.request_times = []
|
||||
self.error_count = 0
|
||||
self.total_requests = 0
|
||||
|
||||
# Flow control task
|
||||
self._flow_control_task = None
|
||||
self._running = False
|
||||
|
||||
async def start(self):
|
||||
"""Start flow control"""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._flow_control_task = asyncio.create_task(self._flow_control_loop())
|
||||
logger.info(f"GPU provider flow control started: {self.provider_id}")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop flow control"""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
|
||||
if self._flow_control_task:
|
||||
self._flow_control_task.cancel()
|
||||
try:
|
||||
await self._flow_control_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
logger.info(f"GPU provider flow control stopped: {self.provider_id}")
|
||||
|
||||
async def submit_request(self, data: FusionData) -> Optional[str]:
|
||||
"""Submit request with flow control"""
|
||||
if not self._running:
|
||||
return None
|
||||
|
||||
# Check provider status
|
||||
if self.metrics.status == GPUProviderStatus.OFFLINE:
|
||||
logger.warning(f"GPU provider {self.provider_id} is offline")
|
||||
return None
|
||||
|
||||
# Check backpressure
|
||||
if self.input_queue.qsize() / self.input_queue.maxsize > self.overload_threshold:
|
||||
self.metrics.status = GPUProviderStatus.OVERLOADED
|
||||
logger.warning(f"GPU provider {self.provider_id} is overloaded")
|
||||
return None
|
||||
|
||||
# Submit request
|
||||
request_id = str(uuid4())
|
||||
request_data = {
|
||||
"request_id": request_id,
|
||||
"data": data,
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self.input_queue.put(request_data),
|
||||
timeout=1.0
|
||||
)
|
||||
return request_id
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Request timeout for GPU provider {self.provider_id}")
|
||||
return None
|
||||
|
||||
async def get_result(self, request_id: str, timeout: float = 5.0) -> Optional[Any]:
|
||||
"""Get processing result"""
|
||||
start_time = time.time()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
# Check output queue
|
||||
result = await asyncio.wait_for(
|
||||
self.output_queue.get(),
|
||||
timeout=0.1
|
||||
)
|
||||
|
||||
if result.get("request_id") == request_id:
|
||||
return result.get("data")
|
||||
|
||||
# Put back if not our result
|
||||
await self.output_queue.put(result)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
async def _flow_control_loop(self):
|
||||
"""Main flow control loop"""
|
||||
while self._running:
|
||||
try:
|
||||
# Get next request
|
||||
request_data = await asyncio.wait_for(
|
||||
self.input_queue.get(),
|
||||
timeout=1.0
|
||||
)
|
||||
|
||||
# Check concurrent request limit
|
||||
if self.current_requests >= self.max_concurrent_requests:
|
||||
# Re-queue request
|
||||
await self.input_queue.put(request_data)
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
|
||||
# Process request
|
||||
self.current_requests += 1
|
||||
self.total_requests += 1
|
||||
|
||||
asyncio.create_task(self._process_request(request_data))
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Flow control error for {self.provider_id}: {e}")
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def _process_request(self, request_data: Dict[str, Any]):
|
||||
"""Process individual request"""
|
||||
request_id = request_data["request_id"]
|
||||
data: FusionData = request_data["data"]
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Simulate GPU processing
|
||||
if data.requires_gpu:
|
||||
# Simulate GPU processing time
|
||||
processing_time = np.random.uniform(0.5, 3.0)
|
||||
await asyncio.sleep(processing_time)
|
||||
|
||||
# Simulate GPU result
|
||||
result = {
|
||||
"processed_data": f"gpu_processed_{data.stream_type}",
|
||||
"processing_time": processing_time,
|
||||
"gpu_utilization": np.random.uniform(0.3, 0.9),
|
||||
"memory_usage": np.random.uniform(0.4, 0.8)
|
||||
}
|
||||
else:
|
||||
# CPU processing
|
||||
processing_time = np.random.uniform(0.1, 0.5)
|
||||
await asyncio.sleep(processing_time)
|
||||
|
||||
result = {
|
||||
"processed_data": f"cpu_processed_{data.stream_type}",
|
||||
"processing_time": processing_time
|
||||
}
|
||||
|
||||
# Update metrics
|
||||
actual_time = time.time() - start_time
|
||||
self._update_metrics(actual_time, success=True)
|
||||
|
||||
# Send result
|
||||
await self.output_queue.put({
|
||||
"request_id": request_id,
|
||||
"data": result,
|
||||
"timestamp": time.time()
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Request processing error for {self.provider_id}: {e}")
|
||||
self._update_metrics(time.time() - start_time, success=False)
|
||||
|
||||
# Send error result
|
||||
await self.output_queue.put({
|
||||
"request_id": request_id,
|
||||
"error": str(e),
|
||||
"timestamp": time.time()
|
||||
})
|
||||
|
||||
finally:
|
||||
self.current_requests -= 1
|
||||
|
||||
def _update_metrics(self, processing_time: float, success: bool):
|
||||
"""Update provider metrics"""
|
||||
# Update processing time
|
||||
self.request_times.append(processing_time)
|
||||
if len(self.request_times) > 100:
|
||||
self.request_times.pop(0)
|
||||
|
||||
self.metrics.avg_processing_time = np.mean(self.request_times)
|
||||
|
||||
# Update error rate
|
||||
if not success:
|
||||
self.error_count += 1
|
||||
|
||||
self.metrics.error_rate = self.error_count / max(self.total_requests, 1)
|
||||
|
||||
# Update queue sizes
|
||||
self.metrics.queue_size = self.input_queue.qsize()
|
||||
|
||||
# Update status
|
||||
if self.metrics.error_rate > 0.1:
|
||||
self.metrics.status = GPUProviderStatus.OFFLINE
|
||||
elif self.metrics.avg_processing_time > self.slow_threshold:
|
||||
self.metrics.status = GPUProviderStatus.SLOW
|
||||
elif self.metrics.queue_size > self.input_queue.maxsize * 0.8:
|
||||
self.metrics.status = GPUProviderStatus.OVERLOADED
|
||||
elif self.current_requests >= self.max_concurrent_requests:
|
||||
self.metrics.status = GPUProviderStatus.BUSY
|
||||
else:
|
||||
self.metrics.status = GPUProviderStatus.AVAILABLE
|
||||
|
||||
self.metrics.last_update = time.time()
|
||||
|
||||
def get_metrics(self) -> Dict[str, Any]:
|
||||
"""Get provider metrics"""
|
||||
return {
|
||||
"provider_id": self.provider_id,
|
||||
"status": self.metrics.status.value,
|
||||
"avg_processing_time": self.metrics.avg_processing_time,
|
||||
"queue_size": self.metrics.queue_size,
|
||||
"current_requests": self.current_requests,
|
||||
"max_concurrent_requests": self.max_concurrent_requests,
|
||||
"error_rate": self.metrics.error_rate,
|
||||
"total_requests": self.total_requests,
|
||||
"last_update": self.metrics.last_update
|
||||
}
|
||||
|
||||
|
||||
class MultiModalWebSocketFusion:
|
||||
"""Multi-modal fusion service with WebSocket streaming and backpressure control"""
|
||||
|
||||
def __init__(self):
|
||||
self.stream_manager = stream_manager
|
||||
self.fusion_service = None # Will be injected
|
||||
self.gpu_providers: Dict[str, GPUProviderFlowControl] = {}
|
||||
|
||||
# Fusion streams
|
||||
self.fusion_streams: Dict[str, FusionStreamConfig] = {}
|
||||
self.active_fusions: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# Performance metrics
|
||||
self.fusion_metrics = {
|
||||
"total_fusions": 0,
|
||||
"successful_fusions": 0,
|
||||
"failed_fusions": 0,
|
||||
"avg_fusion_time": 0.0,
|
||||
"gpu_utilization": 0.0,
|
||||
"memory_usage": 0.0
|
||||
}
|
||||
|
||||
# Backpressure control
|
||||
self.backpressure_enabled = True
|
||||
self.global_queue_size = 0
|
||||
self.max_global_queue_size = 10000
|
||||
|
||||
# Running state
|
||||
self._running = False
|
||||
self._monitor_task = None
|
||||
|
||||
async def start(self):
|
||||
"""Start the fusion service"""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
|
||||
# Start stream manager
|
||||
await self.stream_manager.start()
|
||||
|
||||
# Initialize GPU providers
|
||||
await self._initialize_gpu_providers()
|
||||
|
||||
# Start monitoring
|
||||
self._monitor_task = asyncio.create_task(self._monitor_loop())
|
||||
|
||||
logger.info("Multi-Modal WebSocket Fusion started")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the fusion service"""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
|
||||
# Stop GPU providers
|
||||
for provider in self.gpu_providers.values():
|
||||
await provider.stop()
|
||||
|
||||
# Stop stream manager
|
||||
await self.stream_manager.stop()
|
||||
|
||||
# Stop monitoring
|
||||
if self._monitor_task:
|
||||
self._monitor_task.cancel()
|
||||
try:
|
||||
await self._monitor_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
logger.info("Multi-Modal WebSocket Fusion stopped")
|
||||
|
||||
async def register_fusion_stream(self, stream_id: str, config: FusionStreamConfig):
|
||||
"""Register a fusion stream"""
|
||||
self.fusion_streams[stream_id] = config
|
||||
logger.info(f"Registered fusion stream: {stream_id} ({config.stream_type.value})")
|
||||
|
||||
async def handle_websocket_connection(self, websocket, stream_id: str,
|
||||
stream_type: FusionStreamType):
|
||||
"""Handle WebSocket connection for fusion stream"""
|
||||
config = FusionStreamConfig(
|
||||
stream_type=stream_type,
|
||||
max_queue_size=500,
|
||||
gpu_timeout=2.0,
|
||||
fusion_timeout=5.0
|
||||
)
|
||||
|
||||
async with self.stream_manager.manage_stream(websocket, config.to_stream_config()) as stream:
|
||||
logger.info(f"Fusion stream connected: {stream_id} ({stream_type.value})")
|
||||
|
||||
try:
|
||||
# Handle incoming messages
|
||||
async for message in websocket:
|
||||
await self._handle_stream_message(stream_id, stream_type, message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in fusion stream {stream_id}: {e}")
|
||||
|
||||
async def _handle_stream_message(self, stream_id: str, stream_type: FusionStreamType,
|
||||
message: str):
|
||||
"""Handle incoming stream message"""
|
||||
try:
|
||||
data = json.loads(message)
|
||||
|
||||
# Create fusion data
|
||||
fusion_data = FusionData(
|
||||
stream_id=stream_id,
|
||||
stream_type=stream_type,
|
||||
data=data.get("data"),
|
||||
timestamp=time.time(),
|
||||
metadata=data.get("metadata", {}),
|
||||
requires_gpu=data.get("requires_gpu", False),
|
||||
processing_priority=data.get("priority", 1)
|
||||
)
|
||||
|
||||
# Submit to GPU provider if needed
|
||||
if fusion_data.requires_gpu:
|
||||
await self._submit_to_gpu_provider(fusion_data)
|
||||
else:
|
||||
await self._process_cpu_fusion(fusion_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling stream message: {e}")
|
||||
|
||||
async def _submit_to_gpu_provider(self, fusion_data: FusionData):
|
||||
"""Submit fusion data to GPU provider"""
|
||||
# Select best GPU provider
|
||||
provider_id = await self._select_gpu_provider(fusion_data)
|
||||
|
||||
if not provider_id:
|
||||
logger.warning("No available GPU providers")
|
||||
await self._handle_fusion_error(fusion_data, "No GPU providers available")
|
||||
return
|
||||
|
||||
provider = self.gpu_providers[provider_id]
|
||||
|
||||
# Submit request
|
||||
request_id = await provider.submit_request(fusion_data)
|
||||
|
||||
if not request_id:
|
||||
await self._handle_fusion_error(fusion_data, "GPU provider overloaded")
|
||||
return
|
||||
|
||||
# Wait for result
|
||||
result = await provider.get_result(request_id, timeout=5.0)
|
||||
|
||||
if result and "error" not in result:
|
||||
await self._handle_fusion_result(fusion_data, result)
|
||||
else:
|
||||
error = result.get("error", "Unknown error") if result else "Timeout"
|
||||
await self._handle_fusion_error(fusion_data, error)
|
||||
|
||||
async def _process_cpu_fusion(self, fusion_data: FusionData):
|
||||
"""Process fusion data on CPU"""
|
||||
try:
|
||||
# Simulate CPU fusion processing
|
||||
processing_time = np.random.uniform(0.1, 0.5)
|
||||
await asyncio.sleep(processing_time)
|
||||
|
||||
result = {
|
||||
"processed_data": f"cpu_fused_{fusion_data.stream_type}",
|
||||
"processing_time": processing_time,
|
||||
"fusion_type": "cpu"
|
||||
}
|
||||
|
||||
await self._handle_fusion_result(fusion_data, result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"CPU fusion error: {e}")
|
||||
await self._handle_fusion_error(fusion_data, str(e))
|
||||
|
||||
async def _handle_fusion_result(self, fusion_data: FusionData, result: Dict[str, Any]):
|
||||
"""Handle successful fusion result"""
|
||||
# Update metrics
|
||||
self.fusion_metrics["total_fusions"] += 1
|
||||
self.fusion_metrics["successful_fusions"] += 1
|
||||
|
||||
# Broadcast result
|
||||
broadcast_data = {
|
||||
"type": "fusion_result",
|
||||
"stream_id": fusion_data.stream_id,
|
||||
"stream_type": fusion_data.stream_type.value,
|
||||
"result": result,
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
await self.stream_manager.broadcast_to_all(broadcast_data, MessageType.IMPORTANT)
|
||||
|
||||
logger.info(f"Fusion completed for {fusion_data.stream_id}")
|
||||
|
||||
async def _handle_fusion_error(self, fusion_data: FusionData, error: str):
|
||||
"""Handle fusion error"""
|
||||
# Update metrics
|
||||
self.fusion_metrics["total_fusions"] += 1
|
||||
self.fusion_metrics["failed_fusions"] += 1
|
||||
|
||||
# Broadcast error
|
||||
error_data = {
|
||||
"type": "fusion_error",
|
||||
"stream_id": fusion_data.stream_id,
|
||||
"stream_type": fusion_data.stream_type.value,
|
||||
"error": error,
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
await self.stream_manager.broadcast_to_all(error_data, MessageType.CRITICAL)
|
||||
|
||||
logger.error(f"Fusion error for {fusion_data.stream_id}: {error}")
|
||||
|
||||
async def _select_gpu_provider(self, fusion_data: FusionData) -> Optional[str]:
|
||||
"""Select best GPU provider based on load and performance"""
|
||||
available_providers = []
|
||||
|
||||
for provider_id, provider in self.gpu_providers.items():
|
||||
metrics = provider.get_metrics()
|
||||
|
||||
# Check if provider is available
|
||||
if metrics["status"] == GPUProviderStatus.AVAILABLE.value:
|
||||
available_providers.append((provider_id, metrics))
|
||||
|
||||
if not available_providers:
|
||||
return None
|
||||
|
||||
# Select provider with lowest queue size and processing time
|
||||
best_provider = min(
|
||||
available_providers,
|
||||
key=lambda x: (x[1]["queue_size"], x[1]["avg_processing_time"])
|
||||
)
|
||||
|
||||
return best_provider[0]
|
||||
|
||||
async def _initialize_gpu_providers(self):
|
||||
"""Initialize GPU providers"""
|
||||
# Create mock GPU providers
|
||||
provider_configs = [
|
||||
{"provider_id": "gpu_1", "max_concurrent": 4},
|
||||
{"provider_id": "gpu_2", "max_concurrent": 2},
|
||||
{"provider_id": "gpu_3", "max_concurrent": 6}
|
||||
]
|
||||
|
||||
for config in provider_configs:
|
||||
provider = GPUProviderFlowControl(config["provider_id"])
|
||||
provider.max_concurrent_requests = config["max_concurrent"]
|
||||
await provider.start()
|
||||
self.gpu_providers[config["provider_id"]] = provider
|
||||
|
||||
logger.info(f"Initialized {len(self.gpu_providers)} GPU providers")
|
||||
|
||||
async def _monitor_loop(self):
|
||||
"""Monitor system performance and backpressure"""
|
||||
while self._running:
|
||||
try:
|
||||
# Update global metrics
|
||||
await self._update_global_metrics()
|
||||
|
||||
# Check backpressure
|
||||
if self.backpressure_enabled:
|
||||
await self._check_backpressure()
|
||||
|
||||
# Monitor GPU providers
|
||||
await self._monitor_gpu_providers()
|
||||
|
||||
# Sleep
|
||||
await asyncio.sleep(10) # Monitor every 10 seconds
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Monitor loop error: {e}")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def _update_global_metrics(self):
|
||||
"""Update global performance metrics"""
|
||||
# Get stream manager metrics
|
||||
manager_metrics = self.stream_manager.get_manager_metrics()
|
||||
|
||||
# Update global queue size
|
||||
self.global_queue_size = manager_metrics["total_queue_size"]
|
||||
|
||||
# Calculate GPU utilization
|
||||
total_gpu_util = 0
|
||||
total_memory = 0
|
||||
active_providers = 0
|
||||
|
||||
for provider in self.gpu_providers.values():
|
||||
metrics = provider.get_metrics()
|
||||
if metrics["status"] != GPUProviderStatus.OFFLINE.value:
|
||||
total_gpu_util += metrics.get("gpu_utilization", 0)
|
||||
total_memory += metrics.get("memory_usage", 0)
|
||||
active_providers += 1
|
||||
|
||||
if active_providers > 0:
|
||||
self.fusion_metrics["gpu_utilization"] = total_gpu_util / active_providers
|
||||
self.fusion_metrics["memory_usage"] = total_memory / active_providers
|
||||
|
||||
async def _check_backpressure(self):
|
||||
"""Check and handle backpressure"""
|
||||
if self.global_queue_size > self.max_global_queue_size * 0.8:
|
||||
logger.warning("High backpressure detected, applying flow control")
|
||||
|
||||
# Get slow streams
|
||||
slow_streams = self.stream_manager.get_slow_streams(threshold=0.8)
|
||||
|
||||
# Handle slow streams
|
||||
for stream_id in slow_streams:
|
||||
await self.stream_manager.handle_slow_consumer(stream_id, "throttle")
|
||||
|
||||
async def _monitor_gpu_providers(self):
|
||||
"""Monitor GPU provider health"""
|
||||
for provider_id, provider in self.gpu_providers.items():
|
||||
metrics = provider.get_metrics()
|
||||
|
||||
# Check for unhealthy providers
|
||||
if metrics["status"] == GPUProviderStatus.OFFLINE.value:
|
||||
logger.warning(f"GPU provider {provider_id} is offline")
|
||||
|
||||
elif metrics["error_rate"] > 0.1:
|
||||
logger.warning(f"GPU provider {provider_id} has high error rate: {metrics['error_rate']}")
|
||||
|
||||
elif metrics["avg_processing_time"] > 5.0:
|
||||
logger.warning(f"GPU provider {provider_id} is slow: {metrics['avg_processing_time']}s")
|
||||
|
||||
def get_comprehensive_metrics(self) -> Dict[str, Any]:
|
||||
"""Get comprehensive system metrics"""
|
||||
# Get stream manager metrics
|
||||
stream_metrics = self.stream_manager.get_manager_metrics()
|
||||
|
||||
# Get GPU provider metrics
|
||||
gpu_metrics = {}
|
||||
for provider_id, provider in self.gpu_providers.items():
|
||||
gpu_metrics[provider_id] = provider.get_metrics()
|
||||
|
||||
# Get fusion metrics
|
||||
fusion_metrics = self.fusion_metrics.copy()
|
||||
|
||||
# Calculate success rate
|
||||
if fusion_metrics["total_fusions"] > 0:
|
||||
fusion_metrics["success_rate"] = (
|
||||
fusion_metrics["successful_fusions"] / fusion_metrics["total_fusions"]
|
||||
)
|
||||
else:
|
||||
fusion_metrics["success_rate"] = 0.0
|
||||
|
||||
return {
|
||||
"timestamp": time.time(),
|
||||
"system_status": "running" if self._running else "stopped",
|
||||
"backpressure_enabled": self.backpressure_enabled,
|
||||
"global_queue_size": self.global_queue_size,
|
||||
"max_global_queue_size": self.max_global_queue_size,
|
||||
"stream_metrics": stream_metrics,
|
||||
"gpu_metrics": gpu_metrics,
|
||||
"fusion_metrics": fusion_metrics,
|
||||
"active_fusion_streams": len(self.fusion_streams),
|
||||
"registered_gpu_providers": len(self.gpu_providers)
|
||||
}
|
||||
|
||||
|
||||
# Global fusion service instance
|
||||
multimodal_fusion_service = MultiModalWebSocketFusion()
|
||||
420
apps/coordinator-api/src/app/services/secure_wallet_service.py
Normal file
420
apps/coordinator-api/src/app/services/secure_wallet_service.py
Normal file
@@ -0,0 +1,420 @@
|
||||
"""
|
||||
Secure Wallet Service - Fixed Version
|
||||
Implements proper Ethereum cryptography and secure key storage
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Dict
|
||||
from sqlalchemy import select
|
||||
from sqlmodel import Session
|
||||
from datetime import datetime
|
||||
import secrets
|
||||
|
||||
from ..domain.wallet import (
|
||||
AgentWallet, NetworkConfig, TokenBalance, WalletTransaction,
|
||||
WalletType, TransactionStatus
|
||||
)
|
||||
from ..schemas.wallet import WalletCreate, TransactionRequest
|
||||
from ..blockchain.contract_interactions import ContractInteractionService
|
||||
|
||||
# Import our fixed crypto utilities
|
||||
from .wallet_crypto import (
|
||||
generate_ethereum_keypair,
|
||||
verify_keypair_consistency,
|
||||
encrypt_private_key,
|
||||
decrypt_private_key,
|
||||
validate_private_key_format,
|
||||
create_secure_wallet,
|
||||
recover_wallet
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SecureWalletService:
|
||||
"""Secure wallet service with proper cryptography and key management"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: Session,
|
||||
contract_service: ContractInteractionService
|
||||
):
|
||||
self.session = session
|
||||
self.contract_service = contract_service
|
||||
|
||||
async def create_wallet(self, request: WalletCreate, encryption_password: str) -> AgentWallet:
|
||||
"""
|
||||
Create a new wallet with proper security
|
||||
|
||||
Args:
|
||||
request: Wallet creation request
|
||||
encryption_password: Strong password for private key encryption
|
||||
|
||||
Returns:
|
||||
Created wallet record
|
||||
|
||||
Raises:
|
||||
ValueError: If password is weak or wallet already exists
|
||||
"""
|
||||
# Validate password strength
|
||||
from ..utils.security import validate_password_strength
|
||||
password_validation = validate_password_strength(encryption_password)
|
||||
|
||||
if not password_validation["is_acceptable"]:
|
||||
raise ValueError(
|
||||
f"Password too weak: {', '.join(password_validation['issues'])}"
|
||||
)
|
||||
|
||||
# Check if agent already has an active wallet of this type
|
||||
existing = self.session.exec(
|
||||
select(AgentWallet).where(
|
||||
AgentWallet.agent_id == request.agent_id,
|
||||
AgentWallet.wallet_type == request.wallet_type,
|
||||
AgentWallet.is_active == True
|
||||
)
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
raise ValueError(f"Agent {request.agent_id} already has an active {request.wallet_type} wallet")
|
||||
|
||||
try:
|
||||
# Generate proper Ethereum keypair
|
||||
private_key, public_key, address = generate_ethereum_keypair()
|
||||
|
||||
# Verify keypair consistency
|
||||
if not verify_keypair_consistency(private_key, address):
|
||||
raise RuntimeError("Keypair generation failed consistency check")
|
||||
|
||||
# Encrypt private key securely
|
||||
encrypted_data = encrypt_private_key(private_key, encryption_password)
|
||||
|
||||
# Create wallet record
|
||||
wallet = AgentWallet(
|
||||
agent_id=request.agent_id,
|
||||
address=address,
|
||||
public_key=public_key,
|
||||
wallet_type=request.wallet_type,
|
||||
metadata=request.metadata,
|
||||
encrypted_private_key=encrypted_data,
|
||||
encryption_version="1.0",
|
||||
created_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
self.session.add(wallet)
|
||||
self.session.commit()
|
||||
self.session.refresh(wallet)
|
||||
|
||||
logger.info(f"Created secure wallet {wallet.address} for agent {request.agent_id}")
|
||||
return wallet
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create secure wallet: {e}")
|
||||
self.session.rollback()
|
||||
raise
|
||||
|
||||
async def get_wallet_by_agent(self, agent_id: str) -> List[AgentWallet]:
|
||||
"""Retrieve all active wallets for an agent"""
|
||||
return self.session.exec(
|
||||
select(AgentWallet).where(
|
||||
AgentWallet.agent_id == agent_id,
|
||||
AgentWallet.is_active == True
|
||||
)
|
||||
).all()
|
||||
|
||||
async def get_wallet_with_private_key(
|
||||
self,
|
||||
wallet_id: int,
|
||||
encryption_password: str
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
Get wallet with decrypted private key (for signing operations)
|
||||
|
||||
Args:
|
||||
wallet_id: Wallet ID
|
||||
encryption_password: Password for decryption
|
||||
|
||||
Returns:
|
||||
Wallet keys including private key
|
||||
|
||||
Raises:
|
||||
ValueError: If decryption fails or wallet not found
|
||||
"""
|
||||
wallet = self.session.get(AgentWallet, wallet_id)
|
||||
if not wallet:
|
||||
raise ValueError("Wallet not found")
|
||||
|
||||
if not wallet.is_active:
|
||||
raise ValueError("Wallet is not active")
|
||||
|
||||
try:
|
||||
# Decrypt private key
|
||||
if isinstance(wallet.encrypted_private_key, dict):
|
||||
# New format
|
||||
keys = recover_wallet(wallet.encrypted_private_key, encryption_password)
|
||||
else:
|
||||
# Legacy format - cannot decrypt securely
|
||||
raise ValueError(
|
||||
"Wallet uses legacy encryption format. "
|
||||
"Please migrate to secure encryption."
|
||||
)
|
||||
|
||||
return {
|
||||
"wallet_id": wallet_id,
|
||||
"address": wallet.address,
|
||||
"private_key": keys["private_key"],
|
||||
"public_key": keys["public_key"],
|
||||
"agent_id": wallet.agent_id
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decrypt wallet {wallet_id}: {e}")
|
||||
raise ValueError(f"Failed to access wallet: {str(e)}")
|
||||
|
||||
async def verify_wallet_integrity(self, wallet_id: int) -> Dict[str, bool]:
|
||||
"""
|
||||
Verify wallet cryptographic integrity
|
||||
|
||||
Args:
|
||||
wallet_id: Wallet ID
|
||||
|
||||
Returns:
|
||||
Integrity check results
|
||||
"""
|
||||
wallet = self.session.get(AgentWallet, wallet_id)
|
||||
if not wallet:
|
||||
return {"exists": False}
|
||||
|
||||
results = {
|
||||
"exists": True,
|
||||
"active": wallet.is_active,
|
||||
"has_encrypted_key": bool(wallet.encrypted_private_key),
|
||||
"address_format_valid": False,
|
||||
"public_key_present": bool(wallet.public_key)
|
||||
}
|
||||
|
||||
# Validate address format
|
||||
try:
|
||||
from eth_utils import to_checksum_address
|
||||
to_checksum_address(wallet.address)
|
||||
results["address_format_valid"] = True
|
||||
except:
|
||||
pass
|
||||
|
||||
# Check if we can verify the keypair consistency
|
||||
# (We can't do this without the password, but we can check the format)
|
||||
if wallet.public_key and wallet.encrypted_private_key:
|
||||
results["has_keypair_data"] = True
|
||||
|
||||
return results
|
||||
|
||||
async def migrate_wallet_encryption(
|
||||
self,
|
||||
wallet_id: int,
|
||||
old_password: str,
|
||||
new_password: str
|
||||
) -> AgentWallet:
|
||||
"""
|
||||
Migrate wallet from old encryption to new secure encryption
|
||||
|
||||
Args:
|
||||
wallet_id: Wallet ID
|
||||
old_password: Current password
|
||||
new_password: New strong password
|
||||
|
||||
Returns:
|
||||
Updated wallet
|
||||
"""
|
||||
wallet = self.session.get(AgentWallet, wallet_id)
|
||||
if not wallet:
|
||||
raise ValueError("Wallet not found")
|
||||
|
||||
try:
|
||||
# Get current private key
|
||||
current_keys = await self.get_wallet_with_private_key(wallet_id, old_password)
|
||||
|
||||
# Validate new password
|
||||
from ..utils.security import validate_password_strength
|
||||
password_validation = validate_password_strength(new_password)
|
||||
|
||||
if not password_validation["is_acceptable"]:
|
||||
raise ValueError(
|
||||
f"New password too weak: {', '.join(password_validation['issues'])}"
|
||||
)
|
||||
|
||||
# Re-encrypt with new password
|
||||
new_encrypted_data = encrypt_private_key(current_keys["private_key"], new_password)
|
||||
|
||||
# Update wallet
|
||||
wallet.encrypted_private_key = new_encrypted_data
|
||||
wallet.encryption_version = "1.0"
|
||||
wallet.updated_at = datetime.utcnow()
|
||||
|
||||
self.session.commit()
|
||||
self.session.refresh(wallet)
|
||||
|
||||
logger.info(f"Migrated wallet {wallet_id} to secure encryption")
|
||||
return wallet
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to migrate wallet {wallet_id}: {e}")
|
||||
self.session.rollback()
|
||||
raise
|
||||
|
||||
async def get_balances(self, wallet_id: int) -> List[TokenBalance]:
|
||||
"""Get all tracked balances for a wallet"""
|
||||
return self.session.exec(
|
||||
select(TokenBalance).where(TokenBalance.wallet_id == wallet_id)
|
||||
).all()
|
||||
|
||||
async def update_balance(self, wallet_id: int, chain_id: int, token_address: str, balance: float) -> TokenBalance:
|
||||
"""Update a specific token balance for a wallet"""
|
||||
record = self.session.exec(
|
||||
select(TokenBalance).where(
|
||||
TokenBalance.wallet_id == wallet_id,
|
||||
TokenBalance.chain_id == chain_id,
|
||||
TokenBalance.token_address == token_address
|
||||
)
|
||||
).first()
|
||||
|
||||
if record:
|
||||
record.balance = balance
|
||||
record.updated_at = datetime.utcnow()
|
||||
else:
|
||||
record = TokenBalance(
|
||||
wallet_id=wallet_id,
|
||||
chain_id=chain_id,
|
||||
token_address=token_address,
|
||||
balance=balance,
|
||||
updated_at=datetime.utcnow()
|
||||
)
|
||||
self.session.add(record)
|
||||
|
||||
self.session.commit()
|
||||
self.session.refresh(record)
|
||||
return record
|
||||
|
||||
async def create_transaction(
|
||||
self,
|
||||
wallet_id: int,
|
||||
request: TransactionRequest,
|
||||
encryption_password: str
|
||||
) -> WalletTransaction:
|
||||
"""
|
||||
Create a transaction with proper signing
|
||||
|
||||
Args:
|
||||
wallet_id: Wallet ID
|
||||
request: Transaction request
|
||||
encryption_password: Password for private key access
|
||||
|
||||
Returns:
|
||||
Created transaction record
|
||||
"""
|
||||
# Get wallet keys
|
||||
wallet_keys = await self.get_wallet_with_private_key(wallet_id, encryption_password)
|
||||
|
||||
# Create transaction record
|
||||
transaction = WalletTransaction(
|
||||
wallet_id=wallet_id,
|
||||
to_address=request.to_address,
|
||||
amount=request.amount,
|
||||
token_address=request.token_address,
|
||||
chain_id=request.chain_id,
|
||||
data=request.data or "",
|
||||
status=TransactionStatus.PENDING,
|
||||
created_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
self.session.add(transaction)
|
||||
self.session.commit()
|
||||
self.session.refresh(transaction)
|
||||
|
||||
# TODO: Implement actual blockchain transaction signing and submission
|
||||
# This would use the private_key to sign the transaction
|
||||
|
||||
logger.info(f"Created transaction {transaction.id} for wallet {wallet_id}")
|
||||
return transaction
|
||||
|
||||
async def deactivate_wallet(self, wallet_id: int, reason: str = "User request") -> bool:
|
||||
"""Deactivate a wallet"""
|
||||
wallet = self.session.get(AgentWallet, wallet_id)
|
||||
if not wallet:
|
||||
return False
|
||||
|
||||
wallet.is_active = False
|
||||
wallet.updated_at = datetime.utcnow()
|
||||
wallet.deactivation_reason = reason
|
||||
|
||||
self.session.commit()
|
||||
|
||||
logger.info(f"Deactivated wallet {wallet_id}: {reason}")
|
||||
return True
|
||||
|
||||
async def get_wallet_security_audit(self, wallet_id: int) -> Dict[str, Any]:
|
||||
"""
|
||||
Get comprehensive security audit for a wallet
|
||||
|
||||
Args:
|
||||
wallet_id: Wallet ID
|
||||
|
||||
Returns:
|
||||
Security audit results
|
||||
"""
|
||||
wallet = self.session.get(AgentWallet, wallet_id)
|
||||
if not wallet:
|
||||
return {"error": "Wallet not found"}
|
||||
|
||||
audit = {
|
||||
"wallet_id": wallet_id,
|
||||
"agent_id": wallet.agent_id,
|
||||
"address": wallet.address,
|
||||
"is_active": wallet.is_active,
|
||||
"encryption_version": getattr(wallet, 'encryption_version', 'unknown'),
|
||||
"created_at": wallet.created_at.isoformat() if wallet.created_at else None,
|
||||
"updated_at": wallet.updated_at.isoformat() if wallet.updated_at else None
|
||||
}
|
||||
|
||||
# Check encryption security
|
||||
if isinstance(wallet.encrypted_private_key, dict):
|
||||
audit["encryption_secure"] = True
|
||||
audit["encryption_algorithm"] = wallet.encrypted_private_key.get("algorithm")
|
||||
audit["encryption_iterations"] = wallet.encrypted_private_key.get("iterations")
|
||||
else:
|
||||
audit["encryption_secure"] = False
|
||||
audit["encryption_issues"] = ["Uses legacy or broken encryption"]
|
||||
|
||||
# Check address format
|
||||
try:
|
||||
from eth_utils import to_checksum_address
|
||||
to_checksum_address(wallet.address)
|
||||
audit["address_valid"] = True
|
||||
except:
|
||||
audit["address_valid"] = False
|
||||
audit["address_issues"] = ["Invalid Ethereum address format"]
|
||||
|
||||
# Check keypair data
|
||||
audit["has_public_key"] = bool(wallet.public_key)
|
||||
audit["has_encrypted_private_key"] = bool(wallet.encrypted_private_key)
|
||||
|
||||
# Overall security score
|
||||
security_score = 0
|
||||
if audit["encryption_secure"]:
|
||||
security_score += 40
|
||||
if audit["address_valid"]:
|
||||
security_score += 30
|
||||
if audit["has_public_key"]:
|
||||
security_score += 15
|
||||
if audit["has_encrypted_private_key"]:
|
||||
security_score += 15
|
||||
|
||||
audit["security_score"] = security_score
|
||||
audit["security_level"] = (
|
||||
"Excellent" if security_score >= 90 else
|
||||
"Good" if security_score >= 70 else
|
||||
"Fair" if security_score >= 50 else
|
||||
"Poor"
|
||||
)
|
||||
|
||||
return audit
|
||||
238
apps/coordinator-api/src/app/services/wallet_crypto.py
Normal file
238
apps/coordinator-api/src/app/services/wallet_crypto.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""
|
||||
Secure Cryptographic Operations for Agent Wallets
|
||||
Fixed implementation using proper Ethereum cryptography
|
||||
"""
|
||||
|
||||
import secrets
|
||||
from typing import Tuple, Dict, Any
|
||||
from eth_account import Account
|
||||
from eth_utils import to_checksum_address
|
||||
from cryptography.fernet import Fernet
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
import base64
|
||||
import hashlib
|
||||
|
||||
|
||||
def generate_ethereum_keypair() -> Tuple[str, str, str]:
|
||||
"""
|
||||
Generate proper Ethereum keypair using secp256k1
|
||||
|
||||
Returns:
|
||||
Tuple of (private_key, public_key, address)
|
||||
"""
|
||||
# Use eth_account which properly implements secp256k1
|
||||
account = Account.create()
|
||||
|
||||
private_key = account.key.hex()
|
||||
public_key = account._private_key.public_key.to_hex()
|
||||
address = account.address
|
||||
|
||||
return private_key, public_key, address
|
||||
|
||||
|
||||
def verify_keypair_consistency(private_key: str, expected_address: str) -> bool:
|
||||
"""
|
||||
Verify that a private key generates the expected address
|
||||
|
||||
Args:
|
||||
private_key: 32-byte private key hex
|
||||
expected_address: Expected Ethereum address
|
||||
|
||||
Returns:
|
||||
True if keypair is consistent
|
||||
"""
|
||||
try:
|
||||
account = Account.from_key(private_key)
|
||||
return to_checksum_address(account.address) == to_checksum_address(expected_address)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def derive_secure_key(password: str, salt: bytes = None) -> bytes:
|
||||
"""
|
||||
Derive secure encryption key using PBKDF2
|
||||
|
||||
Args:
|
||||
password: User password
|
||||
salt: Optional salt (generated if not provided)
|
||||
|
||||
Returns:
|
||||
Tuple of (key, salt) for storage
|
||||
"""
|
||||
if salt is None:
|
||||
salt = secrets.token_bytes(32)
|
||||
|
||||
kdf = PBKDF2HMAC(
|
||||
algorithm=hashes.SHA256(),
|
||||
length=32,
|
||||
salt=salt,
|
||||
iterations=600_000, # OWASP recommended minimum
|
||||
)
|
||||
|
||||
key = kdf.derive(password.encode())
|
||||
return base64.urlsafe_b64encode(key), salt
|
||||
|
||||
|
||||
def encrypt_private_key(private_key: str, password: str) -> Dict[str, str]:
|
||||
"""
|
||||
Encrypt private key with proper KDF and Fernet
|
||||
|
||||
Args:
|
||||
private_key: 32-byte private key hex
|
||||
password: User password
|
||||
|
||||
Returns:
|
||||
Dict with encrypted data and salt
|
||||
"""
|
||||
# Derive encryption key
|
||||
fernet_key, salt = derive_secure_key(password)
|
||||
|
||||
# Encrypt
|
||||
f = Fernet(fernet_key)
|
||||
encrypted = f.encrypt(private_key.encode())
|
||||
|
||||
return {
|
||||
"encrypted_key": encrypted.decode(),
|
||||
"salt": base64.b64encode(salt).decode(),
|
||||
"algorithm": "PBKDF2-SHA256-Fernet",
|
||||
"iterations": 600_000
|
||||
}
|
||||
|
||||
|
||||
def decrypt_private_key(encrypted_data: Dict[str, str], password: str) -> str:
|
||||
"""
|
||||
Decrypt private key with proper verification
|
||||
|
||||
Args:
|
||||
encrypted_data: Dict with encrypted key and salt
|
||||
password: User password
|
||||
|
||||
Returns:
|
||||
Decrypted private key
|
||||
|
||||
Raises:
|
||||
ValueError: If decryption fails
|
||||
"""
|
||||
try:
|
||||
# Extract salt and encrypted key
|
||||
salt = base64.b64decode(encrypted_data["salt"])
|
||||
encrypted_key = encrypted_data["encrypted_key"].encode()
|
||||
|
||||
# Derive same key
|
||||
fernet_key, _ = derive_secure_key(password, salt)
|
||||
|
||||
# Decrypt
|
||||
f = Fernet(fernet_key)
|
||||
decrypted = f.decrypt(encrypted_key)
|
||||
|
||||
return decrypted.decode()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to decrypt private key: {str(e)}")
|
||||
|
||||
|
||||
def validate_private_key_format(private_key: str) -> bool:
|
||||
"""
|
||||
Validate private key format
|
||||
|
||||
Args:
|
||||
private_key: Private key to validate
|
||||
|
||||
Returns:
|
||||
True if format is valid
|
||||
"""
|
||||
try:
|
||||
# Remove 0x prefix if present
|
||||
if private_key.startswith("0x"):
|
||||
private_key = private_key[2:]
|
||||
|
||||
# Check length (32 bytes = 64 hex chars)
|
||||
if len(private_key) != 64:
|
||||
return False
|
||||
|
||||
# Check if valid hex
|
||||
int(private_key, 16)
|
||||
|
||||
# Try to create account to verify it's a valid secp256k1 key
|
||||
Account.from_key("0x" + private_key)
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
# Security configuration constants
|
||||
class SecurityConfig:
|
||||
"""Security configuration constants"""
|
||||
|
||||
# PBKDF2 settings
|
||||
PBKDF2_ITERATIONS = 600_000
|
||||
PBKDF2_ALGORITHM = hashes.SHA256
|
||||
SALT_LENGTH = 32
|
||||
|
||||
# Fernet settings
|
||||
FERNET_KEY_LENGTH = 32
|
||||
|
||||
# Validation
|
||||
PRIVATE_KEY_LENGTH = 64 # 32 bytes in hex
|
||||
ADDRESS_LENGTH = 40 # 20 bytes in hex (without 0x)
|
||||
|
||||
|
||||
# Backward compatibility wrapper for existing code
|
||||
def create_secure_wallet(agent_id: str, password: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a wallet with proper security
|
||||
|
||||
Args:
|
||||
agent_id: Agent identifier
|
||||
password: Strong password for encryption
|
||||
|
||||
Returns:
|
||||
Wallet data with encrypted private key
|
||||
"""
|
||||
# Generate proper keypair
|
||||
private_key, public_key, address = generate_ethereum_keypair()
|
||||
|
||||
# Validate consistency
|
||||
if not verify_keypair_consistency(private_key, address):
|
||||
raise RuntimeError("Keypair generation failed consistency check")
|
||||
|
||||
# Encrypt private key
|
||||
encrypted_data = encrypt_private_key(private_key, password)
|
||||
|
||||
return {
|
||||
"agent_id": agent_id,
|
||||
"address": address,
|
||||
"public_key": public_key,
|
||||
"encrypted_private_key": encrypted_data,
|
||||
"created_at": secrets.token_hex(16), # For tracking
|
||||
"version": "1.0"
|
||||
}
|
||||
|
||||
|
||||
def recover_wallet(encrypted_data: Dict[str, str], password: str) -> Dict[str, str]:
|
||||
"""
|
||||
Recover wallet from encrypted data
|
||||
|
||||
Args:
|
||||
encrypted_data: Encrypted wallet data
|
||||
password: Password for decryption
|
||||
|
||||
Returns:
|
||||
Wallet keys
|
||||
"""
|
||||
# Decrypt private key
|
||||
private_key = decrypt_private_key(encrypted_data, password)
|
||||
|
||||
# Validate format
|
||||
if not validate_private_key_format(private_key):
|
||||
raise ValueError("Decrypted private key has invalid format")
|
||||
|
||||
# Derive address and public key to verify
|
||||
account = Account.from_key("0x" + private_key)
|
||||
|
||||
return {
|
||||
"private_key": private_key,
|
||||
"public_key": account._private_key.public_key.to_hex(),
|
||||
"address": account.address
|
||||
}
|
||||
@@ -0,0 +1,641 @@
|
||||
"""
|
||||
WebSocket Stream Manager with Backpressure Control
|
||||
|
||||
Advanced WebSocket stream architecture with per-stream flow control,
|
||||
bounded queues, and event loop protection for multi-modal fusion.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import weakref
|
||||
from typing import Dict, List, Optional, Any, Callable, Set, Union
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from collections import deque
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import websockets
|
||||
from websockets.server import WebSocketServerProtocol
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
|
||||
from aitbc.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class StreamStatus(Enum):
|
||||
"""Stream connection status"""
|
||||
CONNECTING = "connecting"
|
||||
CONNECTED = "connected"
|
||||
SLOW_CONSUMER = "slow_consumer"
|
||||
BACKPRESSURE = "backpressure"
|
||||
DISCONNECTED = "disconnected"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class MessageType(Enum):
|
||||
"""Message types for stream classification"""
|
||||
CRITICAL = "critical" # High priority, must deliver
|
||||
IMPORTANT = "important" # Normal priority
|
||||
BULK = "bulk" # Low priority, can be dropped
|
||||
CONTROL = "control" # Stream control messages
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamMessage:
|
||||
"""Message with priority and metadata"""
|
||||
data: Any
|
||||
message_type: MessageType
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
message_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
retry_count: int = 0
|
||||
max_retries: int = 3
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": self.message_id,
|
||||
"type": self.message_type.value,
|
||||
"timestamp": self.timestamp,
|
||||
"data": self.data
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamMetrics:
|
||||
"""Metrics for stream performance monitoring"""
|
||||
messages_sent: int = 0
|
||||
messages_dropped: int = 0
|
||||
bytes_sent: int = 0
|
||||
last_send_time: float = 0
|
||||
avg_send_time: float = 0
|
||||
queue_size: int = 0
|
||||
backpressure_events: int = 0
|
||||
slow_consumer_events: int = 0
|
||||
|
||||
def update_send_metrics(self, send_time: float, message_size: int):
|
||||
"""Update send performance metrics"""
|
||||
self.messages_sent += 1
|
||||
self.bytes_sent += message_size
|
||||
self.last_send_time = time.time()
|
||||
|
||||
# Update average send time
|
||||
if self.messages_sent == 1:
|
||||
self.avg_send_time = send_time
|
||||
else:
|
||||
self.avg_send_time = (self.avg_send_time * (self.messages_sent - 1) + send_time) / self.messages_sent
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamConfig:
|
||||
"""Configuration for individual streams"""
|
||||
max_queue_size: int = 1000
|
||||
send_timeout: float = 5.0
|
||||
heartbeat_interval: float = 30.0
|
||||
slow_consumer_threshold: float = 0.5 # seconds
|
||||
backpressure_threshold: float = 0.8 # queue fill ratio
|
||||
drop_bulk_threshold: float = 0.9 # queue fill ratio for bulk messages
|
||||
enable_compression: bool = True
|
||||
priority_send: bool = True
|
||||
|
||||
|
||||
class BoundedMessageQueue:
|
||||
"""Bounded queue with priority and backpressure handling"""
|
||||
|
||||
def __init__(self, max_size: int = 1000):
|
||||
self.max_size = max_size
|
||||
self.queues = {
|
||||
MessageType.CRITICAL: deque(maxlen=max_size // 4),
|
||||
MessageType.IMPORTANT: deque(maxlen=max_size // 2),
|
||||
MessageType.BULK: deque(maxlen=max_size // 4),
|
||||
MessageType.CONTROL: deque(maxlen=100) # Small control queue
|
||||
}
|
||||
self.total_size = 0
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def put(self, message: StreamMessage) -> bool:
|
||||
"""Add message to queue with backpressure handling"""
|
||||
async with self._lock:
|
||||
# Check if we're at capacity
|
||||
if self.total_size >= self.max_size:
|
||||
# Drop bulk messages first
|
||||
if message.message_type == MessageType.BULK:
|
||||
return False
|
||||
|
||||
# Drop oldest important messages if critical
|
||||
if message.message_type == MessageType.IMPORTANT:
|
||||
if self.queues[MessageType.IMPORTANT]:
|
||||
self.queues[MessageType.IMPORTANT].popleft()
|
||||
self.total_size -= 1
|
||||
else:
|
||||
return False
|
||||
|
||||
# Always allow critical messages (drop oldest if needed)
|
||||
if message.message_type == MessageType.CRITICAL:
|
||||
if self.queues[MessageType.CRITICAL]:
|
||||
self.queues[MessageType.CRITICAL].popleft()
|
||||
self.total_size -= 1
|
||||
|
||||
self.queues[message.message_type].append(message)
|
||||
self.total_size += 1
|
||||
return True
|
||||
|
||||
async def get(self) -> Optional[StreamMessage]:
|
||||
"""Get next message by priority"""
|
||||
async with self._lock:
|
||||
# Priority order: CONTROL > CRITICAL > IMPORTANT > BULK
|
||||
for message_type in [MessageType.CONTROL, MessageType.CRITICAL,
|
||||
MessageType.IMPORTANT, MessageType.BULK]:
|
||||
if self.queues[message_type]:
|
||||
message = self.queues[message_type].popleft()
|
||||
self.total_size -= 1
|
||||
return message
|
||||
return None
|
||||
|
||||
def size(self) -> int:
|
||||
"""Get total queue size"""
|
||||
return self.total_size
|
||||
|
||||
def fill_ratio(self) -> float:
|
||||
"""Get queue fill ratio"""
|
||||
return self.total_size / self.max_size
|
||||
|
||||
|
||||
class WebSocketStream:
|
||||
"""Individual WebSocket stream with backpressure control"""
|
||||
|
||||
def __init__(self, websocket: WebSocketServerProtocol,
|
||||
stream_id: str, config: StreamConfig):
|
||||
self.websocket = websocket
|
||||
self.stream_id = stream_id
|
||||
self.config = config
|
||||
self.status = StreamStatus.CONNECTING
|
||||
self.queue = BoundedMessageQueue(config.max_queue_size)
|
||||
self.metrics = StreamMetrics()
|
||||
self.last_heartbeat = time.time()
|
||||
self.slow_consumer_count = 0
|
||||
|
||||
# Event loop protection
|
||||
self._send_lock = asyncio.Lock()
|
||||
self._sender_task = None
|
||||
self._heartbeat_task = None
|
||||
self._running = False
|
||||
|
||||
# Weak reference for cleanup
|
||||
self._finalizer = weakref.finalize(self, self._cleanup)
|
||||
|
||||
async def start(self):
|
||||
"""Start stream processing"""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self.status = StreamStatus.CONNECTED
|
||||
|
||||
# Start sender task
|
||||
self._sender_task = asyncio.create_task(self._sender_loop())
|
||||
|
||||
# Start heartbeat task
|
||||
self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
|
||||
|
||||
logger.info(f"Stream {self.stream_id} started")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop stream processing"""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
self.status = StreamStatus.DISCONNECTED
|
||||
|
||||
# Cancel tasks
|
||||
if self._sender_task:
|
||||
self._sender_task.cancel()
|
||||
try:
|
||||
await self._sender_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self._heartbeat_task:
|
||||
self._heartbeat_task.cancel()
|
||||
try:
|
||||
await self._heartbeat_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
logger.info(f"Stream {self.stream_id} stopped")
|
||||
|
||||
async def send_message(self, data: Any, message_type: MessageType = MessageType.IMPORTANT) -> bool:
|
||||
"""Send message with backpressure handling"""
|
||||
if not self._running:
|
||||
return False
|
||||
|
||||
message = StreamMessage(data=data, message_type=message_type)
|
||||
|
||||
# Check backpressure
|
||||
queue_ratio = self.queue.fill_ratio()
|
||||
if queue_ratio > self.config.backpressure_threshold:
|
||||
self.status = StreamStatus.BACKPRESSURE
|
||||
self.metrics.backpressure_events += 1
|
||||
|
||||
# Drop bulk messages under backpressure
|
||||
if message_type == MessageType.BULK and queue_ratio > self.config.drop_bulk_threshold:
|
||||
self.metrics.messages_dropped += 1
|
||||
return False
|
||||
|
||||
# Add to queue
|
||||
success = await self.queue.put(message)
|
||||
if not success:
|
||||
self.metrics.messages_dropped += 1
|
||||
|
||||
return success
|
||||
|
||||
async def _sender_loop(self):
|
||||
"""Main sender loop with backpressure control"""
|
||||
while self._running:
|
||||
try:
|
||||
# Get next message
|
||||
message = await self.queue.get()
|
||||
if message is None:
|
||||
await asyncio.sleep(0.01)
|
||||
continue
|
||||
|
||||
# Send with timeout and backpressure protection
|
||||
start_time = time.time()
|
||||
success = await self._send_with_backpressure(message)
|
||||
send_time = time.time() - start_time
|
||||
|
||||
if success:
|
||||
message_size = len(json.dumps(message.to_dict()).encode())
|
||||
self.metrics.update_send_metrics(send_time, message_size)
|
||||
else:
|
||||
# Retry logic
|
||||
message.retry_count += 1
|
||||
if message.retry_count < message.max_retries:
|
||||
await self.queue.put(message)
|
||||
else:
|
||||
self.metrics.messages_dropped += 1
|
||||
logger.warning(f"Message {message.message_id} dropped after max retries")
|
||||
|
||||
# Check for slow consumer
|
||||
if send_time > self.config.slow_consumer_threshold:
|
||||
self.slow_consumer_count += 1
|
||||
self.metrics.slow_consumer_events += 1
|
||||
|
||||
if self.slow_consumer_count > 5: # Threshold for slow consumer detection
|
||||
self.status = StreamStatus.SLOW_CONSUMER
|
||||
logger.warning(f"Stream {self.stream_id} detected as slow consumer")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in sender loop for stream {self.stream_id}: {e}")
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def _send_with_backpressure(self, message: StreamMessage) -> bool:
|
||||
"""Send message with backpressure and timeout protection"""
|
||||
try:
|
||||
async with self._send_lock:
|
||||
# Use asyncio.wait_for for timeout protection
|
||||
message_data = message.to_dict()
|
||||
|
||||
if self.config.enable_compression:
|
||||
# Compress large messages
|
||||
message_str = json.dumps(message_data, separators=(',', ':'))
|
||||
if len(message_str) > 1024: # Compress messages > 1KB
|
||||
message_data['_compressed'] = True
|
||||
message_str = json.dumps(message_data, separators=(',', ':'))
|
||||
else:
|
||||
message_str = json.dumps(message_data)
|
||||
|
||||
# Send with timeout
|
||||
await asyncio.wait_for(
|
||||
self.websocket.send(message_str),
|
||||
timeout=self.config.send_timeout
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Send timeout for stream {self.stream_id}")
|
||||
return False
|
||||
except ConnectionClosed:
|
||||
logger.info(f"Connection closed for stream {self.stream_id}")
|
||||
await self.stop()
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Send error for stream {self.stream_id}: {e}")
|
||||
return False
|
||||
|
||||
async def _heartbeat_loop(self):
|
||||
"""Heartbeat loop for connection health monitoring"""
|
||||
while self._running:
|
||||
try:
|
||||
await asyncio.sleep(self.config.heartbeat_interval)
|
||||
|
||||
if not self._running:
|
||||
break
|
||||
|
||||
# Send heartbeat
|
||||
heartbeat_msg = {
|
||||
"type": "heartbeat",
|
||||
"timestamp": time.time(),
|
||||
"stream_id": self.stream_id,
|
||||
"queue_size": self.queue.size(),
|
||||
"status": self.status.value
|
||||
}
|
||||
|
||||
await self.send_message(heartbeat_msg, MessageType.CONTROL)
|
||||
self.last_heartbeat = time.time()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Heartbeat error for stream {self.stream_id}: {e}")
|
||||
|
||||
def get_metrics(self) -> Dict[str, Any]:
|
||||
"""Get stream metrics"""
|
||||
return {
|
||||
"stream_id": self.stream_id,
|
||||
"status": self.status.value,
|
||||
"queue_size": self.queue.size(),
|
||||
"queue_fill_ratio": self.queue.fill_ratio(),
|
||||
"messages_sent": self.metrics.messages_sent,
|
||||
"messages_dropped": self.metrics.messages_dropped,
|
||||
"bytes_sent": self.metrics.bytes_sent,
|
||||
"avg_send_time": self.metrics.avg_send_time,
|
||||
"backpressure_events": self.metrics.backpressure_events,
|
||||
"slow_consumer_events": self.metrics.slow_consumer_events,
|
||||
"last_heartbeat": self.last_heartbeat
|
||||
}
|
||||
|
||||
def _cleanup(self):
|
||||
"""Cleanup resources"""
|
||||
if self._running:
|
||||
# This should be called by garbage collector
|
||||
logger.warning(f"Stream {self.stream_id} cleanup called while running")
|
||||
|
||||
|
||||
class WebSocketStreamManager:
|
||||
"""Manages multiple WebSocket streams with backpressure control"""
|
||||
|
||||
def __init__(self, default_config: Optional[StreamConfig] = None):
|
||||
self.default_config = default_config or StreamConfig()
|
||||
self.streams: Dict[str, WebSocketStream] = {}
|
||||
self.stream_configs: Dict[str, StreamConfig] = {}
|
||||
|
||||
# Global metrics
|
||||
self.total_connections = 0
|
||||
self.total_messages_sent = 0
|
||||
self.total_messages_dropped = 0
|
||||
|
||||
# Event loop protection
|
||||
self._manager_lock = asyncio.Lock()
|
||||
self._cleanup_task = None
|
||||
self._running = False
|
||||
|
||||
# Message broadcasting
|
||||
self._broadcast_queue = asyncio.Queue(maxsize=10000)
|
||||
self._broadcast_task = None
|
||||
|
||||
async def start(self):
|
||||
"""Start the stream manager"""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
|
||||
# Start cleanup task
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||
|
||||
# Start broadcast task
|
||||
self._broadcast_task = asyncio.create_task(self._broadcast_loop())
|
||||
|
||||
logger.info("WebSocket Stream Manager started")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the stream manager"""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
|
||||
# Stop all streams
|
||||
streams_to_stop = list(self.streams.values())
|
||||
for stream in streams_to_stop:
|
||||
await stream.stop()
|
||||
|
||||
# Cancel tasks
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self._broadcast_task:
|
||||
self._broadcast_task.cancel()
|
||||
try:
|
||||
await self._broadcast_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
logger.info("WebSocket Stream Manager stopped")
|
||||
|
||||
async def manage_stream(self, websocket: WebSocketServerProtocol,
|
||||
config: Optional[StreamConfig] = None):
|
||||
"""Context manager for stream lifecycle"""
|
||||
stream_id = str(uuid.uuid4())
|
||||
stream_config = config or self.default_config
|
||||
|
||||
stream = None
|
||||
try:
|
||||
# Create and start stream
|
||||
stream = WebSocketStream(websocket, stream_id, stream_config)
|
||||
await stream.start()
|
||||
|
||||
async with self._manager_lock:
|
||||
self.streams[stream_id] = stream
|
||||
self.stream_configs[stream_id] = stream_config
|
||||
self.total_connections += 1
|
||||
|
||||
logger.info(f"Stream {stream_id} added to manager")
|
||||
|
||||
yield stream
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error managing stream {stream_id}: {e}")
|
||||
raise
|
||||
finally:
|
||||
# Cleanup stream
|
||||
if stream and stream_id in self.streams:
|
||||
await stream.stop()
|
||||
|
||||
async with self._manager_lock:
|
||||
del self.streams[stream_id]
|
||||
if stream_id in self.stream_configs:
|
||||
del self.stream_configs[stream_id]
|
||||
self.total_connections -= 1
|
||||
|
||||
logger.info(f"Stream {stream_id} removed from manager")
|
||||
|
||||
async def broadcast_to_all(self, data: Any, message_type: MessageType = MessageType.IMPORTANT):
|
||||
"""Broadcast message to all streams"""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
try:
|
||||
await self._broadcast_queue.put((data, message_type))
|
||||
except asyncio.QueueFull:
|
||||
logger.warning("Broadcast queue full, dropping message")
|
||||
self.total_messages_dropped += 1
|
||||
|
||||
async def broadcast_to_stream(self, stream_id: str, data: Any,
|
||||
message_type: MessageType = MessageType.IMPORTANT):
|
||||
"""Send message to specific stream"""
|
||||
async with self._manager_lock:
|
||||
stream = self.streams.get(stream_id)
|
||||
if stream:
|
||||
await stream.send_message(data, message_type)
|
||||
|
||||
async def _broadcast_loop(self):
|
||||
"""Broadcast messages to all streams"""
|
||||
while self._running:
|
||||
try:
|
||||
# Get broadcast message
|
||||
data, message_type = await self._broadcast_queue.get()
|
||||
|
||||
# Send to all streams concurrently
|
||||
tasks = []
|
||||
async with self._manager_lock:
|
||||
streams = list(self.streams.values())
|
||||
|
||||
for stream in streams:
|
||||
task = asyncio.create_task(
|
||||
stream.send_message(data, message_type)
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
# Wait for all sends (with timeout)
|
||||
if tasks:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.gather(*tasks, return_exceptions=True),
|
||||
timeout=1.0
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Broadcast timeout, some streams may be slow")
|
||||
|
||||
self.total_messages_sent += 1
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in broadcast loop: {e}")
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def _cleanup_loop(self):
|
||||
"""Cleanup disconnected streams"""
|
||||
while self._running:
|
||||
try:
|
||||
await asyncio.sleep(60) # Cleanup every minute
|
||||
|
||||
disconnected_streams = []
|
||||
async with self._manager_lock:
|
||||
for stream_id, stream in self.streams.items():
|
||||
if stream.status == StreamStatus.DISCONNECTED:
|
||||
disconnected_streams.append(stream_id)
|
||||
|
||||
# Remove disconnected streams
|
||||
for stream_id in disconnected_streams:
|
||||
if stream_id in self.streams:
|
||||
stream = self.streams[stream_id]
|
||||
await stream.stop()
|
||||
del self.streams[stream_id]
|
||||
if stream_id in self.stream_configs:
|
||||
del self.stream_configs[stream_id]
|
||||
self.total_connections -= 1
|
||||
logger.info(f"Cleaned up disconnected stream {stream_id}")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in cleanup loop: {e}")
|
||||
|
||||
async def get_manager_metrics(self) -> Dict[str, Any]:
|
||||
"""Get comprehensive manager metrics"""
|
||||
async with self._manager_lock:
|
||||
stream_metrics = []
|
||||
for stream in self.streams.values():
|
||||
stream_metrics.append(stream.get_metrics())
|
||||
|
||||
# Calculate aggregate metrics
|
||||
total_queue_size = sum(m["queue_size"] for m in stream_metrics)
|
||||
total_messages_sent = sum(m["messages_sent"] for m in stream_metrics)
|
||||
total_messages_dropped = sum(m["messages_dropped"] for m in stream_metrics)
|
||||
total_bytes_sent = sum(m["bytes_sent"] for m in stream_metrics)
|
||||
|
||||
# Status distribution
|
||||
status_counts = {}
|
||||
for stream in self.streams.values():
|
||||
status = stream.status.value
|
||||
status_counts[status] = status_counts.get(status, 0) + 1
|
||||
|
||||
return {
|
||||
"manager_status": "running" if self._running else "stopped",
|
||||
"total_connections": self.total_connections,
|
||||
"active_streams": len(self.streams),
|
||||
"total_queue_size": total_queue_size,
|
||||
"total_messages_sent": total_messages_sent,
|
||||
"total_messages_dropped": total_messages_dropped,
|
||||
"total_bytes_sent": total_bytes_sent,
|
||||
"broadcast_queue_size": self._broadcast_queue.qsize(),
|
||||
"stream_status_distribution": status_counts,
|
||||
"stream_metrics": stream_metrics
|
||||
}
|
||||
|
||||
async def update_stream_config(self, stream_id: str, config: StreamConfig):
|
||||
"""Update configuration for specific stream"""
|
||||
async with self._manager_lock:
|
||||
if stream_id in self.streams:
|
||||
self.stream_configs[stream_id] = config
|
||||
# Stream will use new config on next send
|
||||
logger.info(f"Updated config for stream {stream_id}")
|
||||
|
||||
def get_slow_streams(self, threshold: float = 0.8) -> List[str]:
|
||||
"""Get streams with high queue fill ratios"""
|
||||
slow_streams = []
|
||||
for stream_id, stream in self.streams.items():
|
||||
if stream.queue.fill_ratio() > threshold:
|
||||
slow_streams.append(stream_id)
|
||||
return slow_streams
|
||||
|
||||
async def handle_slow_consumer(self, stream_id: str, action: str = "warn"):
|
||||
"""Handle slow consumer streams"""
|
||||
async with self._manager_lock:
|
||||
stream = self.streams.get(stream_id)
|
||||
if not stream:
|
||||
return
|
||||
|
||||
if action == "warn":
|
||||
logger.warning(f"Slow consumer detected: {stream_id}")
|
||||
await stream.send_message(
|
||||
{"warning": "Slow consumer detected", "stream_id": stream_id},
|
||||
MessageType.CONTROL
|
||||
)
|
||||
elif action == "throttle":
|
||||
# Reduce queue size for slow consumer
|
||||
new_config = StreamConfig(
|
||||
max_queue_size=stream.config.max_queue_size // 2,
|
||||
send_timeout=stream.config.send_timeout * 2
|
||||
)
|
||||
await self.update_stream_config(stream_id, new_config)
|
||||
logger.info(f"Throttled slow consumer: {stream_id}")
|
||||
elif action == "disconnect":
|
||||
logger.warning(f"Disconnecting slow consumer: {stream_id}")
|
||||
await stream.stop()
|
||||
|
||||
|
||||
# Global stream manager instance
|
||||
stream_manager = WebSocketStreamManager()
|
||||
@@ -7,7 +7,7 @@ from fastapi import APIRouter, Depends
|
||||
|
||||
from .deps import get_receipt_service, get_keystore, get_ledger
|
||||
from .models import ReceiptVerificationModel, from_validation_result
|
||||
from .keystore.service import KeystoreService
|
||||
from .keystore.persistent_service import PersistentKeystoreService
|
||||
from .ledger_mock import SQLiteLedgerAdapter
|
||||
from .receipts.service import ReceiptVerifierService
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ from .models import (
|
||||
WalletDescriptor,
|
||||
from_validation_result,
|
||||
)
|
||||
from .keystore.service import KeystoreService
|
||||
from .keystore.persistent_service import PersistentKeystoreService
|
||||
from .ledger_mock import SQLiteLedgerAdapter
|
||||
from .receipts.service import ReceiptValidationResult, ReceiptVerifierService
|
||||
from .security import RateLimiter, wipe_buffer
|
||||
@@ -85,7 +85,7 @@ def verify_receipt_history(
|
||||
|
||||
@router.get("/wallets", response_model=WalletListResponse, summary="List wallets")
|
||||
def list_wallets(
|
||||
keystore: KeystoreService = Depends(get_keystore),
|
||||
keystore: PersistentKeystoreService = Depends(get_keystore),
|
||||
ledger: SQLiteLedgerAdapter = Depends(get_ledger),
|
||||
) -> WalletListResponse:
|
||||
descriptors = []
|
||||
@@ -102,7 +102,7 @@ def list_wallets(
|
||||
def create_wallet(
|
||||
request: WalletCreateRequest,
|
||||
http_request: Request,
|
||||
keystore: KeystoreService = Depends(get_keystore),
|
||||
keystore: PersistentKeystoreService = Depends(get_keystore),
|
||||
ledger: SQLiteLedgerAdapter = Depends(get_ledger),
|
||||
) -> WalletCreateResponse:
|
||||
_enforce_limit("wallet-create", http_request)
|
||||
@@ -113,11 +113,13 @@ def create_wallet(
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="invalid base64 secret") from exc
|
||||
|
||||
try:
|
||||
ip_address = http_request.client.host if http_request.client else "unknown"
|
||||
record = keystore.create_wallet(
|
||||
wallet_id=request.wallet_id,
|
||||
password=request.password,
|
||||
secret=secret,
|
||||
metadata=request.metadata,
|
||||
ip_address=ip_address
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(
|
||||
@@ -137,16 +139,18 @@ def unlock_wallet(
|
||||
wallet_id: str,
|
||||
request: WalletUnlockRequest,
|
||||
http_request: Request,
|
||||
keystore: KeystoreService = Depends(get_keystore),
|
||||
keystore: PersistentKeystoreService = Depends(get_keystore),
|
||||
ledger: SQLiteLedgerAdapter = Depends(get_ledger),
|
||||
) -> WalletUnlockResponse:
|
||||
_enforce_limit("wallet-unlock", http_request, wallet_id)
|
||||
try:
|
||||
secret = bytearray(keystore.unlock_wallet(wallet_id, request.password))
|
||||
ledger.record_event(wallet_id, "unlocked", {"success": True})
|
||||
ip_address = http_request.client.host if http_request.client else "unknown"
|
||||
secret = bytearray(keystore.unlock_wallet(wallet_id, request.password, ip_address))
|
||||
ledger.record_event(wallet_id, "unlocked", {"success": True, "ip_address": ip_address})
|
||||
logger.info("Unlocked wallet", extra={"wallet_id": wallet_id})
|
||||
except (KeyError, ValueError):
|
||||
ledger.record_event(wallet_id, "unlocked", {"success": False})
|
||||
ip_address = http_request.client.host if http_request.client else "unknown"
|
||||
ledger.record_event(wallet_id, "unlocked", {"success": False, "ip_address": ip_address})
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid credentials")
|
||||
finally:
|
||||
if "secret" in locals():
|
||||
@@ -160,7 +164,7 @@ def sign_payload(
|
||||
wallet_id: str,
|
||||
request: WalletSignRequest,
|
||||
http_request: Request,
|
||||
keystore: KeystoreService = Depends(get_keystore),
|
||||
keystore: PersistentKeystoreService = Depends(get_keystore),
|
||||
ledger: SQLiteLedgerAdapter = Depends(get_ledger),
|
||||
) -> WalletSignResponse:
|
||||
_enforce_limit("wallet-sign", http_request, wallet_id)
|
||||
@@ -170,11 +174,13 @@ def sign_payload(
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="invalid base64 message") from exc
|
||||
|
||||
try:
|
||||
signature = keystore.sign_message(wallet_id, request.password, message)
|
||||
ledger.record_event(wallet_id, "sign", {"success": True})
|
||||
ip_address = http_request.client.host if http_request.client else "unknown"
|
||||
signature = keystore.sign_message(wallet_id, request.password, message, ip_address)
|
||||
ledger.record_event(wallet_id, "sign", {"success": True, "ip_address": ip_address})
|
||||
logger.debug("Signed payload", extra={"wallet_id": wallet_id})
|
||||
except (KeyError, ValueError):
|
||||
ledger.record_event(wallet_id, "sign", {"success": False})
|
||||
ip_address = http_request.client.host if http_request.client else "unknown"
|
||||
ledger.record_event(wallet_id, "sign", {"success": False, "ip_address": ip_address})
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid credentials")
|
||||
|
||||
signature_b64 = base64.b64encode(signature).decode()
|
||||
|
||||
@@ -6,6 +6,7 @@ from fastapi import Depends
|
||||
|
||||
from .keystore.service import KeystoreService
|
||||
from .ledger_mock import SQLiteLedgerAdapter
|
||||
from .keystore.persistent_service import PersistentKeystoreService
|
||||
from .receipts.service import ReceiptVerifierService
|
||||
from .settings import Settings, settings
|
||||
|
||||
@@ -22,8 +23,8 @@ def get_receipt_service(config: Settings = Depends(get_settings)) -> ReceiptVeri
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_keystore() -> KeystoreService:
|
||||
return KeystoreService()
|
||||
def get_keystore(config: Settings = Depends(get_settings)) -> PersistentKeystoreService:
|
||||
return PersistentKeystoreService(db_path=config.ledger_db_path.parent / "keystore.db")
|
||||
|
||||
|
||||
def get_ledger(config: Settings = Depends(get_settings)) -> SQLiteLedgerAdapter:
|
||||
|
||||
396
apps/wallet-daemon/src/app/keystore/persistent_service.py
Normal file
396
apps/wallet-daemon/src/app/keystore/persistent_service.py
Normal file
@@ -0,0 +1,396 @@
|
||||
"""
|
||||
Persistent Keystore Service - Fixes data loss on restart
|
||||
Replaces the in-memory-only keystore with database persistence
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
import threading
|
||||
from dataclasses import dataclass, asdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterable, List, Optional
|
||||
from secrets import token_bytes
|
||||
|
||||
from nacl.signing import SigningKey
|
||||
|
||||
from ..crypto.encryption import EncryptionSuite, EncryptionError
|
||||
from ..security import validate_password_rules, wipe_buffer
|
||||
|
||||
|
||||
@dataclass
|
||||
class WalletRecord:
|
||||
"""Wallet record with database persistence"""
|
||||
wallet_id: str
|
||||
public_key: str
|
||||
salt: bytes
|
||||
nonce: bytes
|
||||
ciphertext: bytes
|
||||
metadata: Dict[str, str]
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
class PersistentKeystoreService:
|
||||
"""Persistent keystore with database storage and proper encryption"""
|
||||
|
||||
def __init__(self, db_path: Optional[Path] = None, encryption: Optional[EncryptionSuite] = None) -> None:
|
||||
self.db_path = db_path or Path("./data/keystore.db")
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._encryption = encryption or EncryptionSuite()
|
||||
self._lock = threading.Lock()
|
||||
self._init_database()
|
||||
|
||||
def _init_database(self):
|
||||
"""Initialize database schema"""
|
||||
with self._lock:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
try:
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS wallets (
|
||||
wallet_id TEXT PRIMARY KEY,
|
||||
public_key TEXT NOT NULL,
|
||||
salt BLOB NOT NULL,
|
||||
nonce BLOB NOT NULL,
|
||||
ciphertext BLOB NOT NULL,
|
||||
metadata TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS wallet_access_log (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
wallet_id TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
timestamp TEXT NOT NULL,
|
||||
success INTEGER NOT NULL,
|
||||
ip_address TEXT,
|
||||
FOREIGN KEY (wallet_id) REFERENCES wallets (wallet_id)
|
||||
)
|
||||
""")
|
||||
|
||||
# Indexes for performance
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_wallets_created_at ON wallets(created_at)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_access_log_wallet_id ON wallet_access_log(wallet_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_access_log_timestamp ON wallet_access_log(timestamp)")
|
||||
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def list_wallets(self) -> List[str]:
|
||||
"""List all wallet IDs"""
|
||||
with self._lock:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
try:
|
||||
cursor = conn.execute("SELECT wallet_id FROM wallets ORDER BY created_at DESC")
|
||||
return [row[0] for row in cursor.fetchall()]
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def list_records(self) -> Iterable[WalletRecord]:
|
||||
"""List all wallet records"""
|
||||
with self._lock:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
try:
|
||||
cursor = conn.execute("""
|
||||
SELECT wallet_id, public_key, salt, nonce, ciphertext, metadata, created_at, updated_at
|
||||
FROM wallets
|
||||
ORDER BY created_at DESC
|
||||
""")
|
||||
|
||||
for row in cursor.fetchall():
|
||||
metadata = json.loads(row[5])
|
||||
yield WalletRecord(
|
||||
wallet_id=row[0],
|
||||
public_key=row[1],
|
||||
salt=row[2],
|
||||
nonce=row[3],
|
||||
ciphertext=row[4],
|
||||
metadata=metadata,
|
||||
created_at=row[6],
|
||||
updated_at=row[7]
|
||||
)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_wallet(self, wallet_id: str) -> Optional[WalletRecord]:
|
||||
"""Get wallet record by ID"""
|
||||
with self._lock:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
try:
|
||||
cursor = conn.execute("""
|
||||
SELECT wallet_id, public_key, salt, nonce, ciphertext, metadata, created_at, updated_at
|
||||
FROM wallets
|
||||
WHERE wallet_id = ?
|
||||
""", (wallet_id,))
|
||||
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
metadata = json.loads(row[5])
|
||||
return WalletRecord(
|
||||
wallet_id=row[0],
|
||||
public_key=row[1],
|
||||
salt=row[2],
|
||||
nonce=row[3],
|
||||
ciphertext=row[4],
|
||||
metadata=metadata,
|
||||
created_at=row[6],
|
||||
updated_at=row[7]
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def create_wallet(
|
||||
self,
|
||||
wallet_id: str,
|
||||
password: str,
|
||||
secret: Optional[bytes] = None,
|
||||
metadata: Optional[Dict[str, str]] = None,
|
||||
ip_address: Optional[str] = None
|
||||
) -> WalletRecord:
|
||||
"""Create a new wallet with database persistence"""
|
||||
with self._lock:
|
||||
# Check if wallet already exists
|
||||
if self.get_wallet(wallet_id):
|
||||
raise ValueError("wallet already exists")
|
||||
|
||||
validate_password_rules(password)
|
||||
|
||||
metadata_map = {str(k): str(v) for k, v in (metadata or {}).items()}
|
||||
|
||||
if secret is None:
|
||||
signing_key = SigningKey.generate()
|
||||
secret_bytes = signing_key.encode()
|
||||
else:
|
||||
if len(secret) != SigningKey.seed_size:
|
||||
raise ValueError("secret key must be 32 bytes")
|
||||
secret_bytes = secret
|
||||
signing_key = SigningKey(secret_bytes)
|
||||
|
||||
salt = token_bytes(self._encryption.salt_bytes)
|
||||
nonce = token_bytes(self._encryption.nonce_bytes)
|
||||
ciphertext = self._encryption.encrypt(password=password, plaintext=secret_bytes, salt=salt, nonce=nonce)
|
||||
|
||||
now = datetime.utcnow().isoformat()
|
||||
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
try:
|
||||
conn.execute("""
|
||||
INSERT INTO wallets (wallet_id, public_key, salt, nonce, ciphertext, metadata, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
wallet_id,
|
||||
signing_key.verify_key.encode().hex(),
|
||||
salt,
|
||||
nonce,
|
||||
ciphertext,
|
||||
json.dumps(metadata_map),
|
||||
now,
|
||||
now
|
||||
))
|
||||
|
||||
# Log creation
|
||||
conn.execute("""
|
||||
INSERT INTO wallet_access_log (wallet_id, action, timestamp, success, ip_address)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""", (wallet_id, "created", now, 1, ip_address))
|
||||
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
record = WalletRecord(
|
||||
wallet_id=wallet_id,
|
||||
public_key=signing_key.verify_key.encode().hex(),
|
||||
salt=salt,
|
||||
nonce=nonce,
|
||||
ciphertext=ciphertext,
|
||||
metadata=metadata_map,
|
||||
created_at=now,
|
||||
updated_at=now
|
||||
)
|
||||
|
||||
return record
|
||||
|
||||
def unlock_wallet(self, wallet_id: str, password: str, ip_address: Optional[str] = None) -> bytes:
|
||||
"""Unlock wallet and return secret key"""
|
||||
record = self.get_wallet(wallet_id)
|
||||
if record is None:
|
||||
self._log_access(wallet_id, "unlock_failed", False, ip_address)
|
||||
raise KeyError("wallet not found")
|
||||
|
||||
try:
|
||||
secret = self._encryption.decrypt(password=password, ciphertext=record.ciphertext, salt=record.salt, nonce=record.nonce)
|
||||
self._log_access(wallet_id, "unlock_success", True, ip_address)
|
||||
return secret
|
||||
except EncryptionError as exc:
|
||||
self._log_access(wallet_id, "unlock_failed", False, ip_address)
|
||||
raise ValueError("failed to decrypt wallet") from exc
|
||||
|
||||
def delete_wallet(self, wallet_id: str) -> bool:
|
||||
"""Delete a wallet and all its access logs"""
|
||||
with self._lock:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
try:
|
||||
# Delete access logs first
|
||||
conn.execute("DELETE FROM wallet_access_log WHERE wallet_id = ?", (wallet_id,))
|
||||
|
||||
# Delete wallet
|
||||
cursor = conn.execute("DELETE FROM wallets WHERE wallet_id = ?", (wallet_id,))
|
||||
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def sign_message(self, wallet_id: str, password: str, message: bytes, ip_address: Optional[str] = None) -> bytes:
|
||||
"""Sign a message with wallet's private key"""
|
||||
try:
|
||||
secret_bytes = bytearray(self.unlock_wallet(wallet_id, password, ip_address))
|
||||
try:
|
||||
signing_key = SigningKey(bytes(secret_bytes))
|
||||
signed = signing_key.sign(message)
|
||||
self._log_access(wallet_id, "sign_success", True, ip_address)
|
||||
return signed.signature
|
||||
finally:
|
||||
wipe_buffer(secret_bytes)
|
||||
except (KeyError, ValueError) as exc:
|
||||
self._log_access(wallet_id, "sign_failed", False, ip_address)
|
||||
raise
|
||||
|
||||
def update_metadata(self, wallet_id: str, metadata: Dict[str, str]) -> bool:
|
||||
"""Update wallet metadata"""
|
||||
with self._lock:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
try:
|
||||
now = datetime.utcnow().isoformat()
|
||||
metadata_json = json.dumps(metadata)
|
||||
|
||||
cursor = conn.execute("""
|
||||
UPDATE wallets
|
||||
SET metadata = ?, updated_at = ?
|
||||
WHERE wallet_id = ?
|
||||
""", (metadata_json, now, wallet_id))
|
||||
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def _log_access(self, wallet_id: str, action: str, success: bool, ip_address: Optional[str] = None):
|
||||
"""Log wallet access for audit trail"""
|
||||
with self._lock:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
try:
|
||||
now = datetime.utcnow().isoformat()
|
||||
conn.execute("""
|
||||
INSERT INTO wallet_access_log (wallet_id, action, timestamp, success, ip_address)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""", (wallet_id, action, now, int(success), ip_address))
|
||||
conn.commit()
|
||||
except Exception:
|
||||
# Don't fail the main operation if logging fails
|
||||
pass
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_access_log(self, wallet_id: str, limit: int = 50) -> List[Dict]:
|
||||
"""Get access log for a wallet"""
|
||||
with self._lock:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
try:
|
||||
cursor = conn.execute("""
|
||||
SELECT action, timestamp, success, ip_address
|
||||
FROM wallet_access_log
|
||||
WHERE wallet_id = ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
""", (wallet_id, limit))
|
||||
|
||||
return [
|
||||
{
|
||||
"action": row[0],
|
||||
"timestamp": row[1],
|
||||
"success": bool(row[2]),
|
||||
"ip_address": row[3]
|
||||
}
|
||||
for row in cursor.fetchall()
|
||||
]
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""Get keystore statistics"""
|
||||
with self._lock:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
try:
|
||||
# Wallet count
|
||||
wallet_count = conn.execute("SELECT COUNT(*) FROM wallets").fetchone()[0]
|
||||
|
||||
# Recent activity
|
||||
recent_creations = conn.execute("""
|
||||
SELECT COUNT(*) FROM wallets
|
||||
WHERE created_at > datetime('now', '-24 hours')
|
||||
""").fetchone()[0]
|
||||
|
||||
recent_access = conn.execute("""
|
||||
SELECT COUNT(*) FROM wallet_access_log
|
||||
WHERE timestamp > datetime('now', '-24 hours')
|
||||
""").fetchone()[0]
|
||||
|
||||
# Access success rate
|
||||
total_access = conn.execute("SELECT COUNT(*) FROM wallet_access_log").fetchone()[0]
|
||||
successful_access = conn.execute("SELECT COUNT(*) FROM wallet_access_log WHERE success = 1").fetchone()[0]
|
||||
|
||||
success_rate = (successful_access / total_access * 100) if total_access > 0 else 0
|
||||
|
||||
return {
|
||||
"total_wallets": wallet_count,
|
||||
"created_last_24h": recent_creations,
|
||||
"access_last_24h": recent_access,
|
||||
"access_success_rate": round(success_rate, 2),
|
||||
"database_path": str(self.db_path)
|
||||
}
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def backup_keystore(self, backup_path: Path) -> bool:
|
||||
"""Create a backup of the keystore database"""
|
||||
try:
|
||||
with self._lock:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
backup_conn = sqlite3.connect(backup_path)
|
||||
conn.backup(backup_conn)
|
||||
conn.close()
|
||||
backup_conn.close()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def verify_integrity(self) -> Dict[str, Any]:
|
||||
"""Verify database integrity"""
|
||||
with self._lock:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
try:
|
||||
# Run integrity check
|
||||
result = conn.execute("PRAGMA integrity_check").fetchall()
|
||||
|
||||
# Check foreign key constraints
|
||||
fk_check = conn.execute("PRAGMA foreign_key_check").fetchall()
|
||||
|
||||
return {
|
||||
"integrity_check": result,
|
||||
"foreign_key_check": fk_check,
|
||||
"is_valid": len(result) == 1 and result[0][0] == "ok"
|
||||
}
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
# Import datetime for the module
|
||||
from datetime import datetime
|
||||
283
apps/wallet-daemon/src/app/ledger_mock.py
Normal file
283
apps/wallet-daemon/src/app/ledger_mock.py
Normal file
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
SQLite Ledger Adapter for Wallet Daemon
|
||||
Production-ready ledger implementation (replacing missing mock)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
|
||||
@dataclass
|
||||
class LedgerRecord:
|
||||
"""Ledger record for wallet events"""
|
||||
wallet_id: str
|
||||
event_type: str
|
||||
timestamp: datetime
|
||||
data: Dict[str, Any]
|
||||
success: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class WalletMetadata:
|
||||
"""Wallet metadata stored in ledger"""
|
||||
wallet_id: str
|
||||
public_key: str
|
||||
metadata: Dict[str, str]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class SQLiteLedgerAdapter:
|
||||
"""Production-ready SQLite ledger adapter"""
|
||||
|
||||
def __init__(self, db_path: Optional[Path] = None):
|
||||
self.db_path = db_path or Path("./data/wallet_ledger.db")
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._lock = threading.Lock()
|
||||
self._init_database()
|
||||
|
||||
def _init_database(self):
|
||||
"""Initialize database schema"""
|
||||
with self._lock:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
try:
|
||||
# Create wallet metadata table
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS wallet_metadata (
|
||||
wallet_id TEXT PRIMARY KEY,
|
||||
public_key TEXT NOT NULL,
|
||||
metadata TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
|
||||
# Create events table
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS wallet_events (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
wallet_id TEXT NOT NULL,
|
||||
event_type TEXT NOT NULL,
|
||||
timestamp TEXT NOT NULL,
|
||||
data TEXT NOT NULL,
|
||||
success INTEGER NOT NULL,
|
||||
FOREIGN KEY (wallet_id) REFERENCES wallet_metadata (wallet_id)
|
||||
)
|
||||
""")
|
||||
|
||||
# Create indexes for performance
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_events_wallet_id ON wallet_events(wallet_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_events_timestamp ON wallet_events(timestamp)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_events_type ON wallet_events(event_type)")
|
||||
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def upsert_wallet(self, wallet_id: str, public_key: str, metadata: Dict[str, str]) -> None:
|
||||
"""Insert or update wallet metadata"""
|
||||
with self._lock:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
try:
|
||||
now = datetime.utcnow().isoformat()
|
||||
metadata_json = json.dumps(metadata)
|
||||
|
||||
# Try update first
|
||||
cursor = conn.execute("""
|
||||
UPDATE wallet_metadata
|
||||
SET public_key = ?, metadata = ?, updated_at = ?
|
||||
WHERE wallet_id = ?
|
||||
""", (public_key, metadata_json, now, wallet_id))
|
||||
|
||||
# If no rows updated, insert new
|
||||
if cursor.rowcount == 0:
|
||||
conn.execute("""
|
||||
INSERT INTO wallet_metadata (wallet_id, public_key, metadata, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""", (wallet_id, public_key, metadata_json, now, now))
|
||||
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_wallet(self, wallet_id: str) -> Optional[WalletMetadata]:
|
||||
"""Get wallet metadata"""
|
||||
with self._lock:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
try:
|
||||
cursor = conn.execute("""
|
||||
SELECT wallet_id, public_key, metadata, created_at, updated_at
|
||||
FROM wallet_metadata
|
||||
WHERE wallet_id = ?
|
||||
""", (wallet_id,))
|
||||
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
metadata = json.loads(row[2])
|
||||
return WalletMetadata(
|
||||
wallet_id=row[0],
|
||||
public_key=row[1],
|
||||
metadata=metadata,
|
||||
created_at=datetime.fromisoformat(row[3]),
|
||||
updated_at=datetime.fromisoformat(row[4])
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def record_event(self, wallet_id: str, event_type: str, data: Dict[str, Any]) -> None:
|
||||
"""Record a wallet event"""
|
||||
with self._lock:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
try:
|
||||
now = datetime.utcnow().isoformat()
|
||||
data_json = json.dumps(data)
|
||||
success = data.get("success", True)
|
||||
|
||||
conn.execute("""
|
||||
INSERT INTO wallet_events (wallet_id, event_type, timestamp, data, success)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""", (wallet_id, event_type, now, data_json, int(success)))
|
||||
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_wallet_events(self, wallet_id: str, limit: int = 50) -> List[LedgerRecord]:
|
||||
"""Get events for a wallet"""
|
||||
with self._lock:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
try:
|
||||
cursor = conn.execute("""
|
||||
SELECT wallet_id, event_type, timestamp, data, success
|
||||
FROM wallet_events
|
||||
WHERE wallet_id = ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
""", (wallet_id, limit))
|
||||
|
||||
events = []
|
||||
for row in cursor.fetchall():
|
||||
data = json.loads(row[3])
|
||||
events.append(LedgerRecord(
|
||||
wallet_id=row[0],
|
||||
event_type=row[1],
|
||||
timestamp=datetime.fromisoformat(row[2]),
|
||||
data=data,
|
||||
success=bool(row[4])
|
||||
))
|
||||
|
||||
return events
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_all_wallets(self) -> List[WalletMetadata]:
|
||||
"""Get all wallets"""
|
||||
with self._lock:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
try:
|
||||
cursor = conn.execute("""
|
||||
SELECT wallet_id, public_key, metadata, created_at, updated_at
|
||||
FROM wallet_metadata
|
||||
ORDER BY created_at DESC
|
||||
""")
|
||||
|
||||
wallets = []
|
||||
for row in cursor.fetchall():
|
||||
metadata = json.loads(row[2])
|
||||
wallets.append(WalletMetadata(
|
||||
wallet_id=row[0],
|
||||
public_key=row[1],
|
||||
metadata=metadata,
|
||||
created_at=datetime.fromisoformat(row[3]),
|
||||
updated_at=datetime.fromisoformat(row[4])
|
||||
))
|
||||
|
||||
return wallets
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""Get ledger statistics"""
|
||||
with self._lock:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
try:
|
||||
# Wallet count
|
||||
wallet_count = conn.execute("SELECT COUNT(*) FROM wallet_metadata").fetchone()[0]
|
||||
|
||||
# Event counts by type
|
||||
event_stats = conn.execute("""
|
||||
SELECT event_type, COUNT(*) as count
|
||||
FROM wallet_events
|
||||
GROUP BY event_type
|
||||
""").fetchall()
|
||||
|
||||
# Recent activity
|
||||
recent_events = conn.execute("""
|
||||
SELECT COUNT(*) FROM wallet_events
|
||||
WHERE timestamp > datetime('now', '-24 hours')
|
||||
""").fetchone()[0]
|
||||
|
||||
return {
|
||||
"total_wallets": wallet_count,
|
||||
"event_breakdown": dict(event_stats),
|
||||
"events_last_24h": recent_events,
|
||||
"database_path": str(self.db_path)
|
||||
}
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def delete_wallet(self, wallet_id: str) -> bool:
|
||||
"""Delete a wallet and all its events"""
|
||||
with self._lock:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
try:
|
||||
# Delete events first (foreign key constraint)
|
||||
conn.execute("DELETE FROM wallet_events WHERE wallet_id = ?", (wallet_id,))
|
||||
|
||||
# Delete wallet metadata
|
||||
cursor = conn.execute("DELETE FROM wallet_metadata WHERE wallet_id = ?", (wallet_id,))
|
||||
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def backup_ledger(self, backup_path: Path) -> bool:
|
||||
"""Create a backup of the ledger database"""
|
||||
try:
|
||||
with self._lock:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
backup_conn = sqlite3.connect(backup_path)
|
||||
conn.backup(backup_conn)
|
||||
conn.close()
|
||||
backup_conn.close()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def verify_integrity(self) -> Dict[str, Any]:
|
||||
"""Verify database integrity"""
|
||||
with self._lock:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
try:
|
||||
# Run integrity check
|
||||
result = conn.execute("PRAGMA integrity_check").fetchall()
|
||||
|
||||
# Check foreign key constraints
|
||||
fk_check = conn.execute("PRAGMA foreign_key_check").fetchall()
|
||||
|
||||
return {
|
||||
"integrity_check": result,
|
||||
"foreign_key_check": fk_check,
|
||||
"is_valid": len(result) == 1 and result[0][0] == "ok"
|
||||
}
|
||||
finally:
|
||||
conn.close()
|
||||
161
apps/zk-circuits/Groth16Verifier.sol
Normal file
161
apps/zk-circuits/Groth16Verifier.sol
Normal file
@@ -0,0 +1,161 @@
|
||||
// SPDX-License-Identifier: GPL-3.0
|
||||
/*
|
||||
Copyright 2021 0KIMS association.
|
||||
|
||||
This file is generated with [snarkJS](https://github.com/iden3/snarkjs).
|
||||
|
||||
snarkJS is a free software: you can redistribute it and/or modify it
|
||||
under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
snarkJS is distributed in the hope that it will be useful, but WITHOUT
|
||||
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
|
||||
or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public
|
||||
License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with snarkJS. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
pragma solidity >=0.7.0 <0.9.0;
|
||||
|
||||
contract Groth16Verifier {
|
||||
// Scalar field size
|
||||
uint256 constant r = 21888242871839275222246405745257275088548364400416034343698204186575808495617;
|
||||
// Base field size
|
||||
uint256 constant q = 21888242871839275222246405745257275088696311157297823662689037894645226208583;
|
||||
|
||||
// Verification Key data
|
||||
uint256 constant alphax = 17878197547960430839188198659895507284003628546353226099044915418621989763688;
|
||||
uint256 constant alphay = 2414401954608202804440744777004803831246497417525080466014468287036253862429;
|
||||
uint256 constant betax1 = 9712108885154437847450578891476498392461803797234760197580929785758376650650;
|
||||
uint256 constant betax2 = 18272358567695662813397521777636023960648994006030407065408973578488017511142;
|
||||
uint256 constant betay1 = 21680758250979848935332437508266260788381562861496889541922176243649072173633;
|
||||
uint256 constant betay2 = 18113399933881081841371513445282849558527348349073876801631247450598780960185;
|
||||
uint256 constant gammax1 = 11559732032986387107991004021392285783925812861821192530917403151452391805634;
|
||||
uint256 constant gammax2 = 10857046999023057135944570762232829481370756359578518086990519993285655852781;
|
||||
uint256 constant gammay1 = 4082367875863433681332203403145435568316851327593401208105741076214120093531;
|
||||
uint256 constant gammay2 = 8495653923123431417604973247489272438418190587263600148770280649306958101930;
|
||||
uint256 constant deltax1 = 12774548987221780347146542577375964674074290054683884142054120470956957679394;
|
||||
uint256 constant deltax2 = 12165843319937710460660491044309080580686643140898844199182757276079170588931;
|
||||
uint256 constant deltay1 = 5902046582690481723876569491209283634644066206041445880136420948730372505228;
|
||||
uint256 constant deltay2 = 11495780469843451809285048515398120762160136824338528775648991644403497551783;
|
||||
|
||||
|
||||
uint256 constant IC0x = 4148018046519347596812177481784308374584693326254693053110348164627817172095;
|
||||
uint256 constant IC0y = 20730985524054218557052728073337277395061462810058907329882330843946617288874;
|
||||
|
||||
|
||||
// Memory data
|
||||
uint16 constant pVk = 0;
|
||||
uint16 constant pPairing = 128;
|
||||
|
||||
uint16 constant pLastMem = 896;
|
||||
|
||||
function verifyProof(uint[2] calldata _pA, uint[2][2] calldata _pB, uint[2] calldata _pC, uint[0] calldata _pubSignals) public view returns (bool) {
|
||||
assembly {
|
||||
function checkField(v) {
|
||||
if iszero(lt(v, r)) {
|
||||
mstore(0, 0)
|
||||
return(0, 0x20)
|
||||
}
|
||||
}
|
||||
|
||||
// G1 function to multiply a G1 value(x,y) to value in an address
|
||||
function g1_mulAccC(pR, x, y, s) {
|
||||
let success
|
||||
let mIn := mload(0x40)
|
||||
mstore(mIn, x)
|
||||
mstore(add(mIn, 32), y)
|
||||
mstore(add(mIn, 64), s)
|
||||
|
||||
success := staticcall(sub(gas(), 2000), 7, mIn, 96, mIn, 64)
|
||||
|
||||
if iszero(success) {
|
||||
mstore(0, 0)
|
||||
return(0, 0x20)
|
||||
}
|
||||
|
||||
mstore(add(mIn, 64), mload(pR))
|
||||
mstore(add(mIn, 96), mload(add(pR, 32)))
|
||||
|
||||
success := staticcall(sub(gas(), 2000), 6, mIn, 128, pR, 64)
|
||||
|
||||
if iszero(success) {
|
||||
mstore(0, 0)
|
||||
return(0, 0x20)
|
||||
}
|
||||
}
|
||||
|
||||
function checkPairing(pA, pB, pC, pubSignals, pMem) -> isOk {
|
||||
let _pPairing := add(pMem, pPairing)
|
||||
let _pVk := add(pMem, pVk)
|
||||
|
||||
mstore(_pVk, IC0x)
|
||||
mstore(add(_pVk, 32), IC0y)
|
||||
|
||||
// Compute the linear combination vk_x
|
||||
|
||||
|
||||
// -A
|
||||
mstore(_pPairing, calldataload(pA))
|
||||
mstore(add(_pPairing, 32), mod(sub(q, calldataload(add(pA, 32))), q))
|
||||
|
||||
// B
|
||||
mstore(add(_pPairing, 64), calldataload(pB))
|
||||
mstore(add(_pPairing, 96), calldataload(add(pB, 32)))
|
||||
mstore(add(_pPairing, 128), calldataload(add(pB, 64)))
|
||||
mstore(add(_pPairing, 160), calldataload(add(pB, 96)))
|
||||
|
||||
// alpha1
|
||||
mstore(add(_pPairing, 192), alphax)
|
||||
mstore(add(_pPairing, 224), alphay)
|
||||
|
||||
// beta2
|
||||
mstore(add(_pPairing, 256), betax1)
|
||||
mstore(add(_pPairing, 288), betax2)
|
||||
mstore(add(_pPairing, 320), betay1)
|
||||
mstore(add(_pPairing, 352), betay2)
|
||||
|
||||
// vk_x
|
||||
mstore(add(_pPairing, 384), mload(add(pMem, pVk)))
|
||||
mstore(add(_pPairing, 416), mload(add(pMem, add(pVk, 32))))
|
||||
|
||||
|
||||
// gamma2
|
||||
mstore(add(_pPairing, 448), gammax1)
|
||||
mstore(add(_pPairing, 480), gammax2)
|
||||
mstore(add(_pPairing, 512), gammay1)
|
||||
mstore(add(_pPairing, 544), gammay2)
|
||||
|
||||
// C
|
||||
mstore(add(_pPairing, 576), calldataload(pC))
|
||||
mstore(add(_pPairing, 608), calldataload(add(pC, 32)))
|
||||
|
||||
// delta2
|
||||
mstore(add(_pPairing, 640), deltax1)
|
||||
mstore(add(_pPairing, 672), deltax2)
|
||||
mstore(add(_pPairing, 704), deltay1)
|
||||
mstore(add(_pPairing, 736), deltay2)
|
||||
|
||||
|
||||
let success := staticcall(sub(gas(), 2000), 8, _pPairing, 768, _pPairing, 0x20)
|
||||
|
||||
isOk := and(success, mload(_pPairing))
|
||||
}
|
||||
|
||||
let pMem := mload(0x40)
|
||||
mstore(0x40, add(pMem, pLastMem))
|
||||
|
||||
// Validate that all evaluations ∈ F
|
||||
|
||||
|
||||
// Validate all evaluations
|
||||
let isValid := checkPairing(_pA, _pB, _pC, _pubSignals, pMem)
|
||||
|
||||
mstore(0, isValid)
|
||||
return(0, 0x20)
|
||||
}
|
||||
}
|
||||
}
|
||||
BIN
apps/zk-circuits/circuit_0000.zkey
Normal file
BIN
apps/zk-circuits/circuit_0000.zkey
Normal file
Binary file not shown.
BIN
apps/zk-circuits/circuit_0001.zkey
Normal file
BIN
apps/zk-circuits/circuit_0001.zkey
Normal file
Binary file not shown.
135
apps/zk-circuits/modular_ml_components_clean.circom
Normal file
135
apps/zk-circuits/modular_ml_components_clean.circom
Normal file
@@ -0,0 +1,135 @@
|
||||
pragma circom 2.0.0;
|
||||
|
||||
/*
|
||||
* Modular ML Circuit Components
|
||||
*
|
||||
* Reusable components for machine learning circuits
|
||||
*/
|
||||
|
||||
// Basic parameter update component (gradient descent step)
|
||||
template ParameterUpdate() {
|
||||
signal input current_param;
|
||||
signal input gradient;
|
||||
signal input learning_rate;
|
||||
|
||||
signal output new_param;
|
||||
|
||||
// Simple gradient descent: new_param = current_param - learning_rate * gradient
|
||||
new_param <== current_param - learning_rate * gradient;
|
||||
}
|
||||
|
||||
// Vector parameter update component
|
||||
template VectorParameterUpdate(PARAM_COUNT) {
|
||||
signal input current_params[PARAM_COUNT];
|
||||
signal input gradients[PARAM_COUNT];
|
||||
signal input learning_rate;
|
||||
|
||||
signal output new_params[PARAM_COUNT];
|
||||
|
||||
component updates[PARAM_COUNT];
|
||||
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
updates[i] = ParameterUpdate();
|
||||
updates[i].current_param <== current_params[i];
|
||||
updates[i].gradient <== gradients[i];
|
||||
updates[i].learning_rate <== learning_rate;
|
||||
new_params[i] <== updates[i].new_param;
|
||||
}
|
||||
}
|
||||
|
||||
// Simple loss constraint component
|
||||
template LossConstraint() {
|
||||
signal input predicted_loss;
|
||||
signal input actual_loss;
|
||||
signal input tolerance;
|
||||
|
||||
// Constrain that |predicted_loss - actual_loss| <= tolerance
|
||||
signal diff;
|
||||
diff <== predicted_loss - actual_loss;
|
||||
|
||||
// Use absolute value constraint: diff^2 <= tolerance^2
|
||||
signal diff_squared;
|
||||
diff_squared <== diff * diff;
|
||||
|
||||
signal tolerance_squared;
|
||||
tolerance_squared <== tolerance * tolerance;
|
||||
|
||||
// This constraint ensures the loss is within tolerance
|
||||
diff_squared * (1 - diff_squared / tolerance_squared) === 0;
|
||||
}
|
||||
|
||||
// Learning rate validation component
|
||||
template LearningRateValidation() {
|
||||
signal input learning_rate;
|
||||
|
||||
// Removed constraint for optimization - learning rate validation handled externally
|
||||
// This reduces non-linear constraints from 1 to 0 for better proving performance
|
||||
}
|
||||
|
||||
// Training epoch component
|
||||
template TrainingEpoch(PARAM_COUNT) {
|
||||
signal input epoch_params[PARAM_COUNT];
|
||||
signal input epoch_gradients[PARAM_COUNT];
|
||||
signal input learning_rate;
|
||||
|
||||
signal output next_epoch_params[PARAM_COUNT];
|
||||
|
||||
component param_update = VectorParameterUpdate(PARAM_COUNT);
|
||||
param_update.current_params <== epoch_params;
|
||||
param_update.gradients <== epoch_gradients;
|
||||
param_update.learning_rate <== learning_rate;
|
||||
next_epoch_params <== param_update.new_params;
|
||||
}
|
||||
|
||||
// Main modular training verification using components
|
||||
template ModularTrainingVerification(PARAM_COUNT, EPOCHS) {
|
||||
signal input initial_parameters[PARAM_COUNT];
|
||||
signal input learning_rate;
|
||||
|
||||
signal output final_parameters[PARAM_COUNT];
|
||||
signal output training_complete;
|
||||
|
||||
// Learning rate validation
|
||||
component lr_validator = LearningRateValidation();
|
||||
lr_validator.learning_rate <== learning_rate;
|
||||
|
||||
// Training epochs using modular components
|
||||
signal current_params[EPOCHS + 1][PARAM_COUNT];
|
||||
|
||||
// Initialize
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
current_params[0][i] <== initial_parameters[i];
|
||||
}
|
||||
|
||||
// Run training epochs
|
||||
component epochs[EPOCHS];
|
||||
for (var e = 0; e < EPOCHS; e++) {
|
||||
epochs[e] = TrainingEpoch(PARAM_COUNT);
|
||||
|
||||
// Input current parameters
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
epochs[e].epoch_params[i] <== current_params[e][i];
|
||||
}
|
||||
|
||||
// Use constant gradients for simplicity (would be computed in real implementation)
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
epochs[e].epoch_gradients[i] <== 1; // Constant gradient
|
||||
}
|
||||
|
||||
epochs[e].learning_rate <== learning_rate;
|
||||
|
||||
// Store results
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
current_params[e + 1][i] <== epochs[e].next_epoch_params[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Output final parameters
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
final_parameters[i] <== current_params[EPOCHS][i];
|
||||
}
|
||||
|
||||
training_complete <== 1;
|
||||
}
|
||||
|
||||
component main = ModularTrainingVerification(4, 3);
|
||||
135
apps/zk-circuits/modular_ml_components_fixed.circom
Normal file
135
apps/zk-circuits/modular_ml_components_fixed.circom
Normal file
@@ -0,0 +1,135 @@
|
||||
pragma circom 2.0.0;
|
||||
|
||||
/*
|
||||
* Modular ML Circuit Components
|
||||
*
|
||||
* Reusable components for machine learning circuits
|
||||
*/
|
||||
|
||||
// Basic parameter update component (gradient descent step)
|
||||
template ParameterUpdate() {
|
||||
signal input current_param;
|
||||
signal input gradient;
|
||||
signal input learning_rate;
|
||||
|
||||
signal output new_param;
|
||||
|
||||
// Simple gradient descent: new_param = current_param - learning_rate * gradient
|
||||
new_param <== current_param - learning_rate * gradient;
|
||||
}
|
||||
|
||||
// Vector parameter update component
|
||||
template VectorParameterUpdate(PARAM_COUNT) {
|
||||
signal input current_params[PARAM_COUNT];
|
||||
signal input gradients[PARAM_COUNT];
|
||||
signal input learning_rate;
|
||||
|
||||
signal output new_params[PARAM_COUNT];
|
||||
|
||||
component updates[PARAM_COUNT];
|
||||
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
updates[i] = ParameterUpdate();
|
||||
updates[i].current_param <== current_params[i];
|
||||
updates[i].gradient <== gradients[i];
|
||||
updates[i].learning_rate <== learning_rate;
|
||||
new_params[i] <== updates[i].new_param;
|
||||
}
|
||||
}
|
||||
|
||||
// Simple loss constraint component
|
||||
template LossConstraint() {
|
||||
signal input predicted_loss;
|
||||
signal input actual_loss;
|
||||
signal input tolerance;
|
||||
|
||||
// Constrain that |predicted_loss - actual_loss| <= tolerance
|
||||
signal diff;
|
||||
diff <== predicted_loss - actual_loss;
|
||||
|
||||
// Use absolute value constraint: diff^2 <= tolerance^2
|
||||
signal diff_squared;
|
||||
diff_squared <== diff * diff;
|
||||
|
||||
signal tolerance_squared;
|
||||
tolerance_squared <== tolerance * tolerance;
|
||||
|
||||
// This constraint ensures the loss is within tolerance
|
||||
diff_squared * (1 - diff_squared / tolerance_squared) === 0;
|
||||
}
|
||||
|
||||
// Learning rate validation component
|
||||
template LearningRateValidation() {
|
||||
signal input learning_rate;
|
||||
|
||||
// Removed constraint for optimization - learning rate validation handled externally
|
||||
// This reduces non-linear constraints from 1 to 0 for better proving performance
|
||||
}
|
||||
|
||||
// Training epoch component
|
||||
template TrainingEpoch(PARAM_COUNT) {
|
||||
signal input epoch_params[PARAM_COUNT];
|
||||
signal input epoch_gradients[PARAM_COUNT];
|
||||
signal input learning_rate;
|
||||
|
||||
signal output next_epoch_params[PARAM_COUNT];
|
||||
|
||||
component param_update = VectorParameterUpdate(PARAM_COUNT);
|
||||
param_update.current_params <== epoch_params;
|
||||
param_update.gradients <== epoch_gradients;
|
||||
param_update.learning_rate <== learning_rate;
|
||||
next_epoch_params <== param_update.new_params;
|
||||
}
|
||||
|
||||
// Main modular training verification using components
|
||||
template ModularTrainingVerification(PARAM_COUNT, EPOCHS) {
|
||||
signal input initial_parameters[PARAM_COUNT];
|
||||
signal input learning_rate;
|
||||
|
||||
signal output final_parameters[PARAM_COUNT];
|
||||
signal output training_complete;
|
||||
|
||||
// Learning rate validation
|
||||
component lr_validator = LearningRateValidation();
|
||||
lr_validator.learning_rate <== learning_rate;
|
||||
|
||||
// Training epochs using modular components
|
||||
signal current_params[EPOCHS + 1][PARAM_COUNT];
|
||||
|
||||
// Initialize
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
current_params[0][i] <== initial_parameters[i];
|
||||
}
|
||||
|
||||
// Run training epochs
|
||||
component epochs[EPOCHS];
|
||||
for (var e = 0; e < EPOCHS; e++) {
|
||||
epochs[e] = TrainingEpoch(PARAM_COUNT);
|
||||
|
||||
// Input current parameters
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
epochs[e].epoch_params[i] <== current_params[e][i];
|
||||
}
|
||||
|
||||
// Use constant gradients for simplicity (would be computed in real implementation)
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
epochs[e].epoch_gradients[i] <== 1; // Constant gradient
|
||||
}
|
||||
|
||||
epochs[e].learning_rate <== learning_rate;
|
||||
|
||||
// Store results
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
current_params[e + 1][i] <== epochs[e].next_epoch_params[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Output final parameters
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
final_parameters[i] <== current_params[EPOCHS][i];
|
||||
}
|
||||
|
||||
training_complete <== 1;
|
||||
}
|
||||
|
||||
component main = ModularTrainingVerification(4, 3);
|
||||
136
apps/zk-circuits/modular_ml_components_fixed2.circom
Normal file
136
apps/zk-circuits/modular_ml_components_fixed2.circom
Normal file
@@ -0,0 +1,136 @@
|
||||
pragma circom 2.0.0;
|
||||
|
||||
|
||||
/*
|
||||
* Modular ML Circuit Components
|
||||
*
|
||||
* Reusable components for machine learning circuits
|
||||
*/
|
||||
|
||||
// Basic parameter update component (gradient descent step)
|
||||
template ParameterUpdate() {
|
||||
signal input current_param;
|
||||
signal input gradient;
|
||||
signal input learning_rate;
|
||||
|
||||
signal output new_param;
|
||||
|
||||
// Simple gradient descent: new_param = current_param - learning_rate * gradient
|
||||
new_param <== current_param - learning_rate * gradient;
|
||||
}
|
||||
|
||||
// Vector parameter update component
|
||||
template VectorParameterUpdate(PARAM_COUNT) {
|
||||
signal input current_params[PARAM_COUNT];
|
||||
signal input gradients[PARAM_COUNT];
|
||||
signal input learning_rate;
|
||||
|
||||
signal output new_params[PARAM_COUNT];
|
||||
|
||||
component updates[PARAM_COUNT];
|
||||
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
updates[i] = ParameterUpdate();
|
||||
updates[i].current_param <== current_params[i];
|
||||
updates[i].gradient <== gradients[i];
|
||||
updates[i].learning_rate <== learning_rate;
|
||||
new_params[i] <== updates[i].new_param;
|
||||
}
|
||||
}
|
||||
|
||||
// Simple loss constraint component
|
||||
template LossConstraint() {
|
||||
signal input predicted_loss;
|
||||
signal input actual_loss;
|
||||
signal input tolerance;
|
||||
|
||||
// Constrain that |predicted_loss - actual_loss| <= tolerance
|
||||
signal diff;
|
||||
diff <== predicted_loss - actual_loss;
|
||||
|
||||
// Use absolute value constraint: diff^2 <= tolerance^2
|
||||
signal diff_squared;
|
||||
diff_squared <== diff * diff;
|
||||
|
||||
signal tolerance_squared;
|
||||
tolerance_squared <== tolerance * tolerance;
|
||||
|
||||
// This constraint ensures the loss is within tolerance
|
||||
diff_squared * (1 - diff_squared / tolerance_squared) === 0;
|
||||
}
|
||||
|
||||
// Learning rate validation component
|
||||
template LearningRateValidation() {
|
||||
signal input learning_rate;
|
||||
|
||||
// Removed constraint for optimization - learning rate validation handled externally
|
||||
// This reduces non-linear constraints from 1 to 0 for better proving performance
|
||||
}
|
||||
|
||||
// Training epoch component
|
||||
template TrainingEpoch(PARAM_COUNT) {
|
||||
signal input epoch_params[PARAM_COUNT];
|
||||
signal input epoch_gradients[PARAM_COUNT];
|
||||
signal input learning_rate;
|
||||
|
||||
signal output next_epoch_params[PARAM_COUNT];
|
||||
|
||||
component param_update = VectorParameterUpdate(PARAM_COUNT);
|
||||
param_update.current_params <== epoch_params;
|
||||
param_update.gradients <== epoch_gradients;
|
||||
param_update.learning_rate <== learning_rate;
|
||||
next_epoch_params <== param_update.new_params;
|
||||
}
|
||||
|
||||
// Main modular training verification using components
|
||||
template ModularTrainingVerification(PARAM_COUNT, EPOCHS) {
|
||||
signal input initial_parameters[PARAM_COUNT];
|
||||
signal input learning_rate;
|
||||
|
||||
signal output final_parameters[PARAM_COUNT];
|
||||
signal output training_complete;
|
||||
|
||||
// Learning rate validation
|
||||
component lr_validator = LearningRateValidation();
|
||||
lr_validator.learning_rate <== learning_rate;
|
||||
|
||||
// Training epochs using modular components
|
||||
signal current_params[EPOCHS + 1][PARAM_COUNT];
|
||||
|
||||
// Initialize
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
current_params[0][i] <== initial_parameters[i];
|
||||
}
|
||||
|
||||
// Run training epochs
|
||||
component epochs[EPOCHS];
|
||||
for (var e = 0; e < EPOCHS; e++) {
|
||||
epochs[e] = TrainingEpoch(PARAM_COUNT);
|
||||
|
||||
// Input current parameters
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
epochs[e].epoch_params[i] <== current_params[e][i];
|
||||
}
|
||||
|
||||
// Use constant gradients for simplicity (would be computed in real implementation)
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
epochs[e].epoch_gradients[i] <== 1; // Constant gradient
|
||||
}
|
||||
|
||||
epochs[e].learning_rate <== learning_rate;
|
||||
|
||||
// Store results
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
current_params[e + 1][i] <== epochs[e].next_epoch_params[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Output final parameters
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
final_parameters[i] <== current_params[EPOCHS][i];
|
||||
}
|
||||
|
||||
training_complete <== 1;
|
||||
}
|
||||
|
||||
component main = ModularTrainingVerification(4, 3);
|
||||
86
apps/zk-circuits/modular_ml_components_simple.circom
Normal file
86
apps/zk-circuits/modular_ml_components_simple.circom
Normal file
@@ -0,0 +1,86 @@
|
||||
pragma circom 2.0.0;
|
||||
|
||||
template ParameterUpdate() {
|
||||
signal input current_param;
|
||||
signal input gradient;
|
||||
signal input learning_rate;
|
||||
signal output new_param;
|
||||
new_param <== current_param - learning_rate * gradient;
|
||||
}
|
||||
|
||||
template VectorParameterUpdate(PARAM_COUNT) {
|
||||
signal input current_params[PARAM_COUNT];
|
||||
signal input gradients[PARAM_COUNT];
|
||||
signal input learning_rate;
|
||||
signal output new_params[PARAM_COUNT];
|
||||
component updates[PARAM_COUNT];
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
updates[i] = ParameterUpdate();
|
||||
updates[i].current_param <== current_params[i];
|
||||
updates[i].gradient <== gradients[i];
|
||||
updates[i].learning_rate <== learning_rate;
|
||||
new_params[i] <== updates[i].new_param;
|
||||
}
|
||||
}
|
||||
|
||||
template LossConstraint() {
|
||||
signal input predicted_loss;
|
||||
signal input actual_loss;
|
||||
signal input tolerance;
|
||||
signal diff;
|
||||
diff <== predicted_loss - actual_loss;
|
||||
signal diff_squared;
|
||||
diff_squared <== diff * diff;
|
||||
signal tolerance_squared;
|
||||
tolerance_squared <== tolerance * tolerance;
|
||||
diff_squared * (1 - diff_squared / tolerance_squared) === 0;
|
||||
}
|
||||
|
||||
template LearningRateValidation() {
|
||||
signal input learning_rate;
|
||||
}
|
||||
|
||||
template TrainingEpoch(PARAM_COUNT) {
|
||||
signal input epoch_params[PARAM_COUNT];
|
||||
signal input epoch_gradients[PARAM_COUNT];
|
||||
signal input learning_rate;
|
||||
signal output next_epoch_params[PARAM_COUNT];
|
||||
component param_update = VectorParameterUpdate(PARAM_COUNT);
|
||||
param_update.current_params <== epoch_params;
|
||||
param_update.gradients <== epoch_gradients;
|
||||
param_update.learning_rate <== learning_rate;
|
||||
next_epoch_params <== param_update.new_params;
|
||||
}
|
||||
|
||||
template ModularTrainingVerification(PARAM_COUNT, EPOCHS) {
|
||||
signal input initial_parameters[PARAM_COUNT];
|
||||
signal input learning_rate;
|
||||
signal output final_parameters[PARAM_COUNT];
|
||||
signal output training_complete;
|
||||
component lr_validator = LearningRateValidation();
|
||||
lr_validator.learning_rate <== learning_rate;
|
||||
signal current_params[EPOCHS + 1][PARAM_COUNT];
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
current_params[0][i] <== initial_parameters[i];
|
||||
}
|
||||
component epochs[EPOCHS];
|
||||
for (var e = 0; e < EPOCHS; e++) {
|
||||
epochs[e] = TrainingEpoch(PARAM_COUNT);
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
epochs[e].epoch_params[i] <== current_params[e][i];
|
||||
}
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
epochs[e].epoch_gradients[i] <== 1;
|
||||
}
|
||||
epochs[e].learning_rate <== learning_rate;
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
current_params[e + 1][i] <== epochs[e].next_epoch_params[i];
|
||||
}
|
||||
}
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
final_parameters[i] <== current_params[EPOCHS][i];
|
||||
}
|
||||
training_complete <== 1;
|
||||
}
|
||||
|
||||
component main = ModularTrainingVerification(4, 3);
|
||||
135
apps/zk-circuits/modular_ml_components_v2.circom
Normal file
135
apps/zk-circuits/modular_ml_components_v2.circom
Normal file
@@ -0,0 +1,135 @@
|
||||
pragma circom 2.1.0;
|
||||
|
||||
/*
|
||||
* Modular ML Circuit Components
|
||||
*
|
||||
* Reusable components for machine learning circuits
|
||||
*/
|
||||
|
||||
// Basic parameter update component (gradient descent step)
|
||||
template ParameterUpdate() {
|
||||
signal input current_param;
|
||||
signal input gradient;
|
||||
signal input learning_rate;
|
||||
|
||||
signal output new_param;
|
||||
|
||||
// Simple gradient descent: new_param = current_param - learning_rate * gradient
|
||||
new_param <== current_param - learning_rate * gradient;
|
||||
}
|
||||
|
||||
// Vector parameter update component
|
||||
template VectorParameterUpdate(PARAM_COUNT) {
|
||||
signal input current_params[PARAM_COUNT];
|
||||
signal input gradients[PARAM_COUNT];
|
||||
signal input learning_rate;
|
||||
|
||||
signal output new_params[PARAM_COUNT];
|
||||
|
||||
component updates[PARAM_COUNT];
|
||||
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
updates[i] = ParameterUpdate();
|
||||
updates[i].current_param <== current_params[i];
|
||||
updates[i].gradient <== gradients[i];
|
||||
updates[i].learning_rate <== learning_rate;
|
||||
new_params[i] <== updates[i].new_param;
|
||||
}
|
||||
}
|
||||
|
||||
// Simple loss constraint component
|
||||
template LossConstraint() {
|
||||
signal input predicted_loss;
|
||||
signal input actual_loss;
|
||||
signal input tolerance;
|
||||
|
||||
// Constrain that |predicted_loss - actual_loss| <= tolerance
|
||||
signal diff;
|
||||
diff <== predicted_loss - actual_loss;
|
||||
|
||||
// Use absolute value constraint: diff^2 <= tolerance^2
|
||||
signal diff_squared;
|
||||
diff_squared <== diff * diff;
|
||||
|
||||
signal tolerance_squared;
|
||||
tolerance_squared <== tolerance * tolerance;
|
||||
|
||||
// This constraint ensures the loss is within tolerance
|
||||
diff_squared * (1 - diff_squared / tolerance_squared) === 0;
|
||||
}
|
||||
|
||||
// Learning rate validation component
|
||||
template LearningRateValidation() {
|
||||
signal input learning_rate;
|
||||
|
||||
// Removed constraint for optimization - learning rate validation handled externally
|
||||
// This reduces non-linear constraints from 1 to 0 for better proving performance
|
||||
}
|
||||
|
||||
// Training epoch component
|
||||
template TrainingEpoch(PARAM_COUNT) {
|
||||
signal input epoch_params[PARAM_COUNT];
|
||||
signal input epoch_gradients[PARAM_COUNT];
|
||||
signal input learning_rate;
|
||||
|
||||
signal output next_epoch_params[PARAM_COUNT];
|
||||
|
||||
component param_update = VectorParameterUpdate(PARAM_COUNT);
|
||||
param_update.current_params <== epoch_params;
|
||||
param_update.gradients <== epoch_gradients;
|
||||
param_update.learning_rate <== learning_rate;
|
||||
next_epoch_params <== param_update.new_params;
|
||||
}
|
||||
|
||||
// Main modular training verification using components
|
||||
template ModularTrainingVerification(PARAM_COUNT, EPOCHS) {
|
||||
signal input initial_parameters[PARAM_COUNT];
|
||||
signal input learning_rate;
|
||||
|
||||
signal output final_parameters[PARAM_COUNT];
|
||||
signal output training_complete;
|
||||
|
||||
// Learning rate validation
|
||||
component lr_validator = LearningRateValidation();
|
||||
lr_validator.learning_rate <== learning_rate;
|
||||
|
||||
// Training epochs using modular components
|
||||
signal current_params[EPOCHS + 1][PARAM_COUNT];
|
||||
|
||||
// Initialize
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
current_params[0][i] <== initial_parameters[i];
|
||||
}
|
||||
|
||||
// Run training epochs
|
||||
component epochs[EPOCHS];
|
||||
for (var e = 0; e < EPOCHS; e++) {
|
||||
epochs[e] = TrainingEpoch(PARAM_COUNT);
|
||||
|
||||
// Input current parameters
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
epochs[e].epoch_params[i] <== current_params[e][i];
|
||||
}
|
||||
|
||||
// Use constant gradients for simplicity (would be computed in real implementation)
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
epochs[e].epoch_gradients[i] <== 1; // Constant gradient
|
||||
}
|
||||
|
||||
epochs[e].learning_rate <== learning_rate;
|
||||
|
||||
// Store results
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
current_params[e + 1][i] <== epochs[e].next_epoch_params[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Output final parameters
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
final_parameters[i] <== current_params[EPOCHS][i];
|
||||
}
|
||||
|
||||
training_complete <== 1;
|
||||
}
|
||||
|
||||
component main = ModularTrainingVerification(4, 3);
|
||||
135
apps/zk-circuits/modular_ml_components_working.circom
Normal file
135
apps/zk-circuits/modular_ml_components_working.circom
Normal file
@@ -0,0 +1,135 @@
|
||||
pragma circom 2.0.0;
|
||||
|
||||
/*
|
||||
* Modular ML Circuit Components
|
||||
*
|
||||
* Reusable components for machine learning circuits
|
||||
*/
|
||||
|
||||
// Basic parameter update component (gradient descent step)
|
||||
template ParameterUpdate() {
|
||||
signal input current_param;
|
||||
signal input gradient;
|
||||
signal input learning_rate;
|
||||
|
||||
signal output new_param;
|
||||
|
||||
// Simple gradient descent: new_param = current_param - learning_rate * gradient
|
||||
new_param <== current_param - learning_rate * gradient;
|
||||
}
|
||||
|
||||
// Vector parameter update component
|
||||
template VectorParameterUpdate(PARAM_COUNT) {
|
||||
signal input current_params[PARAM_COUNT];
|
||||
signal input gradients[PARAM_COUNT];
|
||||
signal input learning_rate;
|
||||
|
||||
signal output new_params[PARAM_COUNT];
|
||||
|
||||
component updates[PARAM_COUNT];
|
||||
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
updates[i] = ParameterUpdate();
|
||||
updates[i].current_param <== current_params[i];
|
||||
updates[i].gradient <== gradients[i];
|
||||
updates[i].learning_rate <== learning_rate;
|
||||
new_params[i] <== updates[i].new_param;
|
||||
}
|
||||
}
|
||||
|
||||
// Simple loss constraint component
|
||||
template LossConstraint() {
|
||||
signal input predicted_loss;
|
||||
signal input actual_loss;
|
||||
signal input tolerance;
|
||||
|
||||
// Constrain that |predicted_loss - actual_loss| <= tolerance
|
||||
signal diff;
|
||||
diff <== predicted_loss - actual_loss;
|
||||
|
||||
// Use absolute value constraint: diff^2 <= tolerance^2
|
||||
signal diff_squared;
|
||||
diff_squared <== diff * diff;
|
||||
|
||||
signal tolerance_squared;
|
||||
tolerance_squared <== tolerance * tolerance;
|
||||
|
||||
// This constraint ensures the loss is within tolerance
|
||||
diff_squared * (1 - diff_squared / tolerance_squared) === 0;
|
||||
}
|
||||
|
||||
// Learning rate validation component
|
||||
template LearningRateValidation() {
|
||||
signal input learning_rate;
|
||||
|
||||
// Removed constraint for optimization - learning rate validation handled externally
|
||||
// This reduces non-linear constraints from 1 to 0 for better proving performance
|
||||
}
|
||||
|
||||
// Training epoch component
|
||||
template TrainingEpoch(PARAM_COUNT) {
|
||||
signal input epoch_params[PARAM_COUNT];
|
||||
signal input epoch_gradients[PARAM_COUNT];
|
||||
signal input learning_rate;
|
||||
|
||||
signal output next_epoch_params[PARAM_COUNT];
|
||||
|
||||
component param_update = VectorParameterUpdate(PARAM_COUNT);
|
||||
param_update.current_params <== epoch_params;
|
||||
param_update.gradients <== epoch_gradients;
|
||||
param_update.learning_rate <== learning_rate;
|
||||
next_epoch_params <== param_update.new_params;
|
||||
}
|
||||
|
||||
// Main modular training verification using components
|
||||
template ModularTrainingVerification(PARAM_COUNT, EPOCHS) {
|
||||
signal input initial_parameters[PARAM_COUNT];
|
||||
signal input learning_rate;
|
||||
|
||||
signal output final_parameters[PARAM_COUNT];
|
||||
signal output training_complete;
|
||||
|
||||
// Learning rate validation
|
||||
component lr_validator = LearningRateValidation();
|
||||
lr_validator.learning_rate <== learning_rate;
|
||||
|
||||
// Training epochs using modular components
|
||||
signal current_params[EPOCHS + 1][PARAM_COUNT];
|
||||
|
||||
// Initialize
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
current_params[0][i] <== initial_parameters[i];
|
||||
}
|
||||
|
||||
// Run training epochs
|
||||
component epochs[EPOCHS];
|
||||
for (var e = 0; e < EPOCHS; e++) {
|
||||
epochs[e] = TrainingEpoch(PARAM_COUNT);
|
||||
|
||||
// Input current parameters
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
epochs[e].epoch_params[i] <== current_params[e][i];
|
||||
}
|
||||
|
||||
// Use constant gradients for simplicity (would be computed in real implementation)
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
epochs[e].epoch_gradients[i] <== 1; // Constant gradient
|
||||
}
|
||||
|
||||
epochs[e].learning_rate <== learning_rate;
|
||||
|
||||
// Store results
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
current_params[e + 1][i] <== epochs[e].next_epoch_params[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Output final parameters
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
final_parameters[i] <== current_params[EPOCHS][i];
|
||||
}
|
||||
|
||||
training_complete <== 1;
|
||||
}
|
||||
|
||||
component main = ModularTrainingVerification(4, 3);
|
||||
BIN
apps/zk-circuits/modular_ml_components_working.r1cs
Normal file
BIN
apps/zk-circuits/modular_ml_components_working.r1cs
Normal file
Binary file not shown.
153
apps/zk-circuits/modular_ml_components_working.sym
Normal file
153
apps/zk-circuits/modular_ml_components_working.sym
Normal file
@@ -0,0 +1,153 @@
|
||||
1,1,4,main.final_parameters[0]
|
||||
2,2,4,main.final_parameters[1]
|
||||
3,3,4,main.final_parameters[2]
|
||||
4,4,4,main.final_parameters[3]
|
||||
5,5,4,main.training_complete
|
||||
6,6,4,main.initial_parameters[0]
|
||||
7,7,4,main.initial_parameters[1]
|
||||
8,8,4,main.initial_parameters[2]
|
||||
9,9,4,main.initial_parameters[3]
|
||||
10,10,4,main.learning_rate
|
||||
11,-1,4,main.current_params[0][0]
|
||||
12,-1,4,main.current_params[0][1]
|
||||
13,-1,4,main.current_params[0][2]
|
||||
14,-1,4,main.current_params[0][3]
|
||||
15,11,4,main.current_params[1][0]
|
||||
16,12,4,main.current_params[1][1]
|
||||
17,13,4,main.current_params[1][2]
|
||||
18,14,4,main.current_params[1][3]
|
||||
19,15,4,main.current_params[2][0]
|
||||
20,16,4,main.current_params[2][1]
|
||||
21,17,4,main.current_params[2][2]
|
||||
22,18,4,main.current_params[2][3]
|
||||
23,-1,4,main.current_params[3][0]
|
||||
24,-1,4,main.current_params[3][1]
|
||||
25,-1,4,main.current_params[3][2]
|
||||
26,-1,4,main.current_params[3][3]
|
||||
27,-1,3,main.epochs[0].next_epoch_params[0]
|
||||
28,-1,3,main.epochs[0].next_epoch_params[1]
|
||||
29,-1,3,main.epochs[0].next_epoch_params[2]
|
||||
30,-1,3,main.epochs[0].next_epoch_params[3]
|
||||
31,-1,3,main.epochs[0].epoch_params[0]
|
||||
32,-1,3,main.epochs[0].epoch_params[1]
|
||||
33,-1,3,main.epochs[0].epoch_params[2]
|
||||
34,-1,3,main.epochs[0].epoch_params[3]
|
||||
35,-1,3,main.epochs[0].epoch_gradients[0]
|
||||
36,-1,3,main.epochs[0].epoch_gradients[1]
|
||||
37,-1,3,main.epochs[0].epoch_gradients[2]
|
||||
38,-1,3,main.epochs[0].epoch_gradients[3]
|
||||
39,-1,3,main.epochs[0].learning_rate
|
||||
40,-1,2,main.epochs[0].param_update.new_params[0]
|
||||
41,-1,2,main.epochs[0].param_update.new_params[1]
|
||||
42,-1,2,main.epochs[0].param_update.new_params[2]
|
||||
43,-1,2,main.epochs[0].param_update.new_params[3]
|
||||
44,-1,2,main.epochs[0].param_update.current_params[0]
|
||||
45,-1,2,main.epochs[0].param_update.current_params[1]
|
||||
46,-1,2,main.epochs[0].param_update.current_params[2]
|
||||
47,-1,2,main.epochs[0].param_update.current_params[3]
|
||||
48,-1,2,main.epochs[0].param_update.gradients[0]
|
||||
49,-1,2,main.epochs[0].param_update.gradients[1]
|
||||
50,-1,2,main.epochs[0].param_update.gradients[2]
|
||||
51,-1,2,main.epochs[0].param_update.gradients[3]
|
||||
52,-1,2,main.epochs[0].param_update.learning_rate
|
||||
53,-1,1,main.epochs[0].param_update.updates[0].new_param
|
||||
54,-1,1,main.epochs[0].param_update.updates[0].current_param
|
||||
55,-1,1,main.epochs[0].param_update.updates[0].gradient
|
||||
56,-1,1,main.epochs[0].param_update.updates[0].learning_rate
|
||||
57,-1,1,main.epochs[0].param_update.updates[1].new_param
|
||||
58,-1,1,main.epochs[0].param_update.updates[1].current_param
|
||||
59,-1,1,main.epochs[0].param_update.updates[1].gradient
|
||||
60,-1,1,main.epochs[0].param_update.updates[1].learning_rate
|
||||
61,-1,1,main.epochs[0].param_update.updates[2].new_param
|
||||
62,-1,1,main.epochs[0].param_update.updates[2].current_param
|
||||
63,-1,1,main.epochs[0].param_update.updates[2].gradient
|
||||
64,-1,1,main.epochs[0].param_update.updates[2].learning_rate
|
||||
65,-1,1,main.epochs[0].param_update.updates[3].new_param
|
||||
66,-1,1,main.epochs[0].param_update.updates[3].current_param
|
||||
67,-1,1,main.epochs[0].param_update.updates[3].gradient
|
||||
68,-1,1,main.epochs[0].param_update.updates[3].learning_rate
|
||||
69,-1,3,main.epochs[1].next_epoch_params[0]
|
||||
70,-1,3,main.epochs[1].next_epoch_params[1]
|
||||
71,-1,3,main.epochs[1].next_epoch_params[2]
|
||||
72,-1,3,main.epochs[1].next_epoch_params[3]
|
||||
73,-1,3,main.epochs[1].epoch_params[0]
|
||||
74,-1,3,main.epochs[1].epoch_params[1]
|
||||
75,-1,3,main.epochs[1].epoch_params[2]
|
||||
76,-1,3,main.epochs[1].epoch_params[3]
|
||||
77,-1,3,main.epochs[1].epoch_gradients[0]
|
||||
78,-1,3,main.epochs[1].epoch_gradients[1]
|
||||
79,-1,3,main.epochs[1].epoch_gradients[2]
|
||||
80,-1,3,main.epochs[1].epoch_gradients[3]
|
||||
81,-1,3,main.epochs[1].learning_rate
|
||||
82,-1,2,main.epochs[1].param_update.new_params[0]
|
||||
83,-1,2,main.epochs[1].param_update.new_params[1]
|
||||
84,-1,2,main.epochs[1].param_update.new_params[2]
|
||||
85,-1,2,main.epochs[1].param_update.new_params[3]
|
||||
86,-1,2,main.epochs[1].param_update.current_params[0]
|
||||
87,-1,2,main.epochs[1].param_update.current_params[1]
|
||||
88,-1,2,main.epochs[1].param_update.current_params[2]
|
||||
89,-1,2,main.epochs[1].param_update.current_params[3]
|
||||
90,-1,2,main.epochs[1].param_update.gradients[0]
|
||||
91,-1,2,main.epochs[1].param_update.gradients[1]
|
||||
92,-1,2,main.epochs[1].param_update.gradients[2]
|
||||
93,-1,2,main.epochs[1].param_update.gradients[3]
|
||||
94,-1,2,main.epochs[1].param_update.learning_rate
|
||||
95,-1,1,main.epochs[1].param_update.updates[0].new_param
|
||||
96,-1,1,main.epochs[1].param_update.updates[0].current_param
|
||||
97,-1,1,main.epochs[1].param_update.updates[0].gradient
|
||||
98,-1,1,main.epochs[1].param_update.updates[0].learning_rate
|
||||
99,-1,1,main.epochs[1].param_update.updates[1].new_param
|
||||
100,-1,1,main.epochs[1].param_update.updates[1].current_param
|
||||
101,-1,1,main.epochs[1].param_update.updates[1].gradient
|
||||
102,-1,1,main.epochs[1].param_update.updates[1].learning_rate
|
||||
103,-1,1,main.epochs[1].param_update.updates[2].new_param
|
||||
104,-1,1,main.epochs[1].param_update.updates[2].current_param
|
||||
105,-1,1,main.epochs[1].param_update.updates[2].gradient
|
||||
106,-1,1,main.epochs[1].param_update.updates[2].learning_rate
|
||||
107,-1,1,main.epochs[1].param_update.updates[3].new_param
|
||||
108,-1,1,main.epochs[1].param_update.updates[3].current_param
|
||||
109,-1,1,main.epochs[1].param_update.updates[3].gradient
|
||||
110,-1,1,main.epochs[1].param_update.updates[3].learning_rate
|
||||
111,-1,3,main.epochs[2].next_epoch_params[0]
|
||||
112,-1,3,main.epochs[2].next_epoch_params[1]
|
||||
113,-1,3,main.epochs[2].next_epoch_params[2]
|
||||
114,-1,3,main.epochs[2].next_epoch_params[3]
|
||||
115,-1,3,main.epochs[2].epoch_params[0]
|
||||
116,-1,3,main.epochs[2].epoch_params[1]
|
||||
117,-1,3,main.epochs[2].epoch_params[2]
|
||||
118,-1,3,main.epochs[2].epoch_params[3]
|
||||
119,-1,3,main.epochs[2].epoch_gradients[0]
|
||||
120,-1,3,main.epochs[2].epoch_gradients[1]
|
||||
121,-1,3,main.epochs[2].epoch_gradients[2]
|
||||
122,-1,3,main.epochs[2].epoch_gradients[3]
|
||||
123,-1,3,main.epochs[2].learning_rate
|
||||
124,-1,2,main.epochs[2].param_update.new_params[0]
|
||||
125,-1,2,main.epochs[2].param_update.new_params[1]
|
||||
126,-1,2,main.epochs[2].param_update.new_params[2]
|
||||
127,-1,2,main.epochs[2].param_update.new_params[3]
|
||||
128,-1,2,main.epochs[2].param_update.current_params[0]
|
||||
129,-1,2,main.epochs[2].param_update.current_params[1]
|
||||
130,-1,2,main.epochs[2].param_update.current_params[2]
|
||||
131,-1,2,main.epochs[2].param_update.current_params[3]
|
||||
132,-1,2,main.epochs[2].param_update.gradients[0]
|
||||
133,-1,2,main.epochs[2].param_update.gradients[1]
|
||||
134,-1,2,main.epochs[2].param_update.gradients[2]
|
||||
135,-1,2,main.epochs[2].param_update.gradients[3]
|
||||
136,-1,2,main.epochs[2].param_update.learning_rate
|
||||
137,-1,1,main.epochs[2].param_update.updates[0].new_param
|
||||
138,-1,1,main.epochs[2].param_update.updates[0].current_param
|
||||
139,-1,1,main.epochs[2].param_update.updates[0].gradient
|
||||
140,-1,1,main.epochs[2].param_update.updates[0].learning_rate
|
||||
141,-1,1,main.epochs[2].param_update.updates[1].new_param
|
||||
142,-1,1,main.epochs[2].param_update.updates[1].current_param
|
||||
143,-1,1,main.epochs[2].param_update.updates[1].gradient
|
||||
144,-1,1,main.epochs[2].param_update.updates[1].learning_rate
|
||||
145,-1,1,main.epochs[2].param_update.updates[2].new_param
|
||||
146,-1,1,main.epochs[2].param_update.updates[2].current_param
|
||||
147,-1,1,main.epochs[2].param_update.updates[2].gradient
|
||||
148,-1,1,main.epochs[2].param_update.updates[2].learning_rate
|
||||
149,-1,1,main.epochs[2].param_update.updates[3].new_param
|
||||
150,-1,1,main.epochs[2].param_update.updates[3].current_param
|
||||
151,-1,1,main.epochs[2].param_update.updates[3].gradient
|
||||
152,-1,1,main.epochs[2].param_update.updates[3].learning_rate
|
||||
153,-1,0,main.lr_validator.learning_rate
|
||||
@@ -0,0 +1,21 @@
|
||||
const wc = require("./witness_calculator.js");
|
||||
const { readFileSync, writeFile } = require("fs");
|
||||
|
||||
if (process.argv.length != 5) {
|
||||
console.log("Usage: node generate_witness.js <file.wasm> <input.json> <output.wtns>");
|
||||
} else {
|
||||
const input = JSON.parse(readFileSync(process.argv[3], "utf8"));
|
||||
|
||||
const buffer = readFileSync(process.argv[2]);
|
||||
wc(buffer).then(async witnessCalculator => {
|
||||
/*
|
||||
const w= await witnessCalculator.calculateWitness(input,0);
|
||||
for (let i=0; i< w.length; i++){
|
||||
console.log(w[i]);
|
||||
}*/
|
||||
const buff= await witnessCalculator.calculateWTNSBin(input,0);
|
||||
writeFile(process.argv[4], buff, function(err) {
|
||||
if (err) throw err;
|
||||
});
|
||||
});
|
||||
}
|
||||
Binary file not shown.
@@ -0,0 +1,381 @@
|
||||
module.exports = async function builder(code, options) {
|
||||
|
||||
options = options || {};
|
||||
|
||||
let wasmModule;
|
||||
try {
|
||||
wasmModule = await WebAssembly.compile(code);
|
||||
} catch (err) {
|
||||
console.log(err);
|
||||
console.log("\nTry to run circom --c in order to generate c++ code instead\n");
|
||||
throw new Error(err);
|
||||
}
|
||||
|
||||
let wc;
|
||||
|
||||
let errStr = "";
|
||||
let msgStr = "";
|
||||
|
||||
const instance = await WebAssembly.instantiate(wasmModule, {
|
||||
runtime: {
|
||||
exceptionHandler : function(code) {
|
||||
let err;
|
||||
if (code == 1) {
|
||||
err = "Signal not found.\n";
|
||||
} else if (code == 2) {
|
||||
err = "Too many signals set.\n";
|
||||
} else if (code == 3) {
|
||||
err = "Signal already set.\n";
|
||||
} else if (code == 4) {
|
||||
err = "Assert Failed.\n";
|
||||
} else if (code == 5) {
|
||||
err = "Not enough memory.\n";
|
||||
} else if (code == 6) {
|
||||
err = "Input signal array access exceeds the size.\n";
|
||||
} else {
|
||||
err = "Unknown error.\n";
|
||||
}
|
||||
throw new Error(err + errStr);
|
||||
},
|
||||
printErrorMessage : function() {
|
||||
errStr += getMessage() + "\n";
|
||||
// console.error(getMessage());
|
||||
},
|
||||
writeBufferMessage : function() {
|
||||
const msg = getMessage();
|
||||
// Any calls to `log()` will always end with a `\n`, so that's when we print and reset
|
||||
if (msg === "\n") {
|
||||
console.log(msgStr);
|
||||
msgStr = "";
|
||||
} else {
|
||||
// If we've buffered other content, put a space in between the items
|
||||
if (msgStr !== "") {
|
||||
msgStr += " "
|
||||
}
|
||||
// Then append the message to the message we are creating
|
||||
msgStr += msg;
|
||||
}
|
||||
},
|
||||
showSharedRWMemory : function() {
|
||||
printSharedRWMemory ();
|
||||
}
|
||||
|
||||
}
|
||||
});
|
||||
|
||||
const sanityCheck =
|
||||
options
|
||||
// options &&
|
||||
// (
|
||||
// options.sanityCheck ||
|
||||
// options.logGetSignal ||
|
||||
// options.logSetSignal ||
|
||||
// options.logStartComponent ||
|
||||
// options.logFinishComponent
|
||||
// );
|
||||
|
||||
|
||||
wc = new WitnessCalculator(instance, sanityCheck);
|
||||
return wc;
|
||||
|
||||
function getMessage() {
|
||||
var message = "";
|
||||
var c = instance.exports.getMessageChar();
|
||||
while ( c != 0 ) {
|
||||
message += String.fromCharCode(c);
|
||||
c = instance.exports.getMessageChar();
|
||||
}
|
||||
return message;
|
||||
}
|
||||
|
||||
function printSharedRWMemory () {
|
||||
const shared_rw_memory_size = instance.exports.getFieldNumLen32();
|
||||
const arr = new Uint32Array(shared_rw_memory_size);
|
||||
for (let j=0; j<shared_rw_memory_size; j++) {
|
||||
arr[shared_rw_memory_size-1-j] = instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
|
||||
// If we've buffered other content, put a space in between the items
|
||||
if (msgStr !== "") {
|
||||
msgStr += " "
|
||||
}
|
||||
// Then append the value to the message we are creating
|
||||
msgStr += (fromArray32(arr).toString());
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
class WitnessCalculator {
|
||||
constructor(instance, sanityCheck) {
|
||||
this.instance = instance;
|
||||
|
||||
this.version = this.instance.exports.getVersion();
|
||||
this.n32 = this.instance.exports.getFieldNumLen32();
|
||||
|
||||
this.instance.exports.getRawPrime();
|
||||
const arr = new Uint32Array(this.n32);
|
||||
for (let i=0; i<this.n32; i++) {
|
||||
arr[this.n32-1-i] = this.instance.exports.readSharedRWMemory(i);
|
||||
}
|
||||
this.prime = fromArray32(arr);
|
||||
|
||||
this.witnessSize = this.instance.exports.getWitnessSize();
|
||||
|
||||
this.sanityCheck = sanityCheck;
|
||||
}
|
||||
|
||||
circom_version() {
|
||||
return this.instance.exports.getVersion();
|
||||
}
|
||||
|
||||
async _doCalculateWitness(input_orig, sanityCheck) {
|
||||
//input is assumed to be a map from signals to arrays of bigints
|
||||
this.instance.exports.init((this.sanityCheck || sanityCheck) ? 1 : 0);
|
||||
let prefix = "";
|
||||
var input = new Object();
|
||||
//console.log("Input: ", input_orig);
|
||||
qualify_input(prefix,input_orig,input);
|
||||
//console.log("Input after: ",input);
|
||||
const keys = Object.keys(input);
|
||||
var input_counter = 0;
|
||||
keys.forEach( (k) => {
|
||||
const h = fnvHash(k);
|
||||
const hMSB = parseInt(h.slice(0,8), 16);
|
||||
const hLSB = parseInt(h.slice(8,16), 16);
|
||||
const fArr = flatArray(input[k]);
|
||||
let signalSize = this.instance.exports.getInputSignalSize(hMSB, hLSB);
|
||||
if (signalSize < 0){
|
||||
throw new Error(`Signal ${k} not found\n`);
|
||||
}
|
||||
if (fArr.length < signalSize) {
|
||||
throw new Error(`Not enough values for input signal ${k}\n`);
|
||||
}
|
||||
if (fArr.length > signalSize) {
|
||||
throw new Error(`Too many values for input signal ${k}\n`);
|
||||
}
|
||||
for (let i=0; i<fArr.length; i++) {
|
||||
const arrFr = toArray32(normalize(fArr[i],this.prime),this.n32)
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
this.instance.exports.writeSharedRWMemory(j,arrFr[this.n32-1-j]);
|
||||
}
|
||||
try {
|
||||
this.instance.exports.setInputSignal(hMSB, hLSB,i);
|
||||
input_counter++;
|
||||
} catch (err) {
|
||||
// console.log(`After adding signal ${i} of ${k}`)
|
||||
throw new Error(err);
|
||||
}
|
||||
}
|
||||
|
||||
});
|
||||
if (input_counter < this.instance.exports.getInputSize()) {
|
||||
throw new Error(`Not all inputs have been set. Only ${input_counter} out of ${this.instance.exports.getInputSize()}`);
|
||||
}
|
||||
}
|
||||
|
||||
async calculateWitness(input, sanityCheck) {
|
||||
|
||||
const w = [];
|
||||
await this._doCalculateWitness(input, sanityCheck);
|
||||
|
||||
for (let i=0; i<this.witnessSize; i++) {
|
||||
this.instance.exports.getWitness(i);
|
||||
const arr = new Uint32Array(this.n32);
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
arr[this.n32-1-j] = this.instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
w.push(fromArray32(arr));
|
||||
}
|
||||
|
||||
return w;
|
||||
}
|
||||
|
||||
|
||||
async calculateBinWitness(input, sanityCheck) {
|
||||
|
||||
const buff32 = new Uint32Array(this.witnessSize*this.n32);
|
||||
const buff = new Uint8Array( buff32.buffer);
|
||||
await this._doCalculateWitness(input, sanityCheck);
|
||||
|
||||
for (let i=0; i<this.witnessSize; i++) {
|
||||
this.instance.exports.getWitness(i);
|
||||
const pos = i*this.n32;
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
buff32[pos+j] = this.instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
}
|
||||
|
||||
return buff;
|
||||
}
|
||||
|
||||
|
||||
async calculateWTNSBin(input, sanityCheck) {
|
||||
|
||||
const buff32 = new Uint32Array(this.witnessSize*this.n32+this.n32+11);
|
||||
const buff = new Uint8Array( buff32.buffer);
|
||||
await this._doCalculateWitness(input, sanityCheck);
|
||||
|
||||
//"wtns"
|
||||
buff[0] = "w".charCodeAt(0)
|
||||
buff[1] = "t".charCodeAt(0)
|
||||
buff[2] = "n".charCodeAt(0)
|
||||
buff[3] = "s".charCodeAt(0)
|
||||
|
||||
//version 2
|
||||
buff32[1] = 2;
|
||||
|
||||
//number of sections: 2
|
||||
buff32[2] = 2;
|
||||
|
||||
//id section 1
|
||||
buff32[3] = 1;
|
||||
|
||||
const n8 = this.n32*4;
|
||||
//id section 1 length in 64bytes
|
||||
const idSection1length = 8 + n8;
|
||||
const idSection1lengthHex = idSection1length.toString(16);
|
||||
buff32[4] = parseInt(idSection1lengthHex.slice(0,8), 16);
|
||||
buff32[5] = parseInt(idSection1lengthHex.slice(8,16), 16);
|
||||
|
||||
//this.n32
|
||||
buff32[6] = n8;
|
||||
|
||||
//prime number
|
||||
this.instance.exports.getRawPrime();
|
||||
|
||||
var pos = 7;
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
buff32[pos+j] = this.instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
pos += this.n32;
|
||||
|
||||
// witness size
|
||||
buff32[pos] = this.witnessSize;
|
||||
pos++;
|
||||
|
||||
//id section 2
|
||||
buff32[pos] = 2;
|
||||
pos++;
|
||||
|
||||
// section 2 length
|
||||
const idSection2length = n8*this.witnessSize;
|
||||
const idSection2lengthHex = idSection2length.toString(16);
|
||||
buff32[pos] = parseInt(idSection2lengthHex.slice(0,8), 16);
|
||||
buff32[pos+1] = parseInt(idSection2lengthHex.slice(8,16), 16);
|
||||
|
||||
pos += 2;
|
||||
for (let i=0; i<this.witnessSize; i++) {
|
||||
this.instance.exports.getWitness(i);
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
buff32[pos+j] = this.instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
pos += this.n32;
|
||||
}
|
||||
|
||||
return buff;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
function qualify_input_list(prefix,input,input1){
|
||||
if (Array.isArray(input)) {
|
||||
for (let i = 0; i<input.length; i++) {
|
||||
let new_prefix = prefix + "[" + i + "]";
|
||||
qualify_input_list(new_prefix,input[i],input1);
|
||||
}
|
||||
} else {
|
||||
qualify_input(prefix,input,input1);
|
||||
}
|
||||
}
|
||||
|
||||
function qualify_input(prefix,input,input1) {
|
||||
if (Array.isArray(input)) {
|
||||
a = flatArray(input);
|
||||
if (a.length > 0) {
|
||||
let t = typeof a[0];
|
||||
for (let i = 1; i<a.length; i++) {
|
||||
if (typeof a[i] != t){
|
||||
throw new Error(`Types are not the same in the key ${prefix}`);
|
||||
}
|
||||
}
|
||||
if (t == "object") {
|
||||
qualify_input_list(prefix,input,input1);
|
||||
} else {
|
||||
input1[prefix] = input;
|
||||
}
|
||||
} else {
|
||||
input1[prefix] = input;
|
||||
}
|
||||
} else if (typeof input == "object") {
|
||||
const keys = Object.keys(input);
|
||||
keys.forEach( (k) => {
|
||||
let new_prefix = prefix == ""? k : prefix + "." + k;
|
||||
qualify_input(new_prefix,input[k],input1);
|
||||
});
|
||||
} else {
|
||||
input1[prefix] = input;
|
||||
}
|
||||
}
|
||||
|
||||
function toArray32(rem,size) {
|
||||
const res = []; //new Uint32Array(size); //has no unshift
|
||||
const radix = BigInt(0x100000000);
|
||||
while (rem) {
|
||||
res.unshift( Number(rem % radix));
|
||||
rem = rem / radix;
|
||||
}
|
||||
if (size) {
|
||||
var i = size - res.length;
|
||||
while (i>0) {
|
||||
res.unshift(0);
|
||||
i--;
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
function fromArray32(arr) { //returns a BigInt
|
||||
var res = BigInt(0);
|
||||
const radix = BigInt(0x100000000);
|
||||
for (let i = 0; i<arr.length; i++) {
|
||||
res = res*radix + BigInt(arr[i]);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
function flatArray(a) {
|
||||
var res = [];
|
||||
fillArray(res, a);
|
||||
return res;
|
||||
|
||||
function fillArray(res, a) {
|
||||
if (Array.isArray(a)) {
|
||||
for (let i=0; i<a.length; i++) {
|
||||
fillArray(res, a[i]);
|
||||
}
|
||||
} else {
|
||||
res.push(a);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function normalize(n, prime) {
|
||||
let res = BigInt(n) % prime
|
||||
if (res < 0) res += prime
|
||||
return res
|
||||
}
|
||||
|
||||
function fnvHash(str) {
|
||||
const uint64_max = BigInt(2) ** BigInt(64);
|
||||
let hash = BigInt("0xCBF29CE484222325");
|
||||
for (var i = 0; i < str.length; i++) {
|
||||
hash ^= BigInt(str[i].charCodeAt());
|
||||
hash *= BigInt(0x100000001B3);
|
||||
hash %= uint64_max;
|
||||
}
|
||||
let shash = hash.toString(16);
|
||||
let n = 16 - shash.length;
|
||||
shash = '0'.repeat(n).concat(shash);
|
||||
return shash;
|
||||
}
|
||||
Binary file not shown.
Binary file not shown.
BIN
apps/zk-circuits/pot12_simple.ptau
Normal file
BIN
apps/zk-circuits/pot12_simple.ptau
Normal file
Binary file not shown.
BIN
apps/zk-circuits/pot12_simple_1.ptau
Normal file
BIN
apps/zk-circuits/pot12_simple_1.ptau
Normal file
Binary file not shown.
BIN
apps/zk-circuits/pot12_simple_final.ptau
Normal file
BIN
apps/zk-circuits/pot12_simple_final.ptau
Normal file
Binary file not shown.
0
apps/zk-circuits/receipt.sym
Normal file
0
apps/zk-circuits/receipt.sym
Normal file
1172
apps/zk-circuits/receipt_simple.sym
Normal file
1172
apps/zk-circuits/receipt_simple.sym
Normal file
File diff suppressed because it is too large
Load Diff
89
apps/zk-circuits/receipt_simple.vkey
Normal file
89
apps/zk-circuits/receipt_simple.vkey
Normal file
@@ -0,0 +1,89 @@
|
||||
{
|
||||
"protocol": "groth16",
|
||||
"curve": "bn128",
|
||||
"nPublic": 0,
|
||||
"vk_alpha_1": [
|
||||
"17878197547960430839188198659895507284003628546353226099044915418621989763688",
|
||||
"2414401954608202804440744777004803831246497417525080466014468287036253862429",
|
||||
"1"
|
||||
],
|
||||
"vk_beta_2": [
|
||||
[
|
||||
"18272358567695662813397521777636023960648994006030407065408973578488017511142",
|
||||
"9712108885154437847450578891476498392461803797234760197580929785758376650650"
|
||||
],
|
||||
[
|
||||
"18113399933881081841371513445282849558527348349073876801631247450598780960185",
|
||||
"21680758250979848935332437508266260788381562861496889541922176243649072173633"
|
||||
],
|
||||
[
|
||||
"1",
|
||||
"0"
|
||||
]
|
||||
],
|
||||
"vk_gamma_2": [
|
||||
[
|
||||
"10857046999023057135944570762232829481370756359578518086990519993285655852781",
|
||||
"11559732032986387107991004021392285783925812861821192530917403151452391805634"
|
||||
],
|
||||
[
|
||||
"8495653923123431417604973247489272438418190587263600148770280649306958101930",
|
||||
"4082367875863433681332203403145435568316851327593401208105741076214120093531"
|
||||
],
|
||||
[
|
||||
"1",
|
||||
"0"
|
||||
]
|
||||
],
|
||||
"vk_delta_2": [
|
||||
[
|
||||
"12165843319937710460660491044309080580686643140898844199182757276079170588931",
|
||||
"12774548987221780347146542577375964674074290054683884142054120470956957679394"
|
||||
],
|
||||
[
|
||||
"11495780469843451809285048515398120762160136824338528775648991644403497551783",
|
||||
"5902046582690481723876569491209283634644066206041445880136420948730372505228"
|
||||
],
|
||||
[
|
||||
"1",
|
||||
"0"
|
||||
]
|
||||
],
|
||||
"vk_alphabeta_12": [
|
||||
[
|
||||
[
|
||||
"15043385564103330663613654339919240399186271643017395365045553432108770804738",
|
||||
"12714364888329096003970077007548283476095989444275664320665990217429249744102"
|
||||
],
|
||||
[
|
||||
"4280923934094610401199612902709542471670738247683892420346396755109361224194",
|
||||
"5971523870632604777872650089881764809688186504221764776056510270055739855107"
|
||||
],
|
||||
[
|
||||
"14459079939853070802140138225067878054463744988673516330641813466106780423229",
|
||||
"4839711251154406360161812922311023717557179750909045977849842632717981230632"
|
||||
]
|
||||
],
|
||||
[
|
||||
[
|
||||
"17169182168985102987328363961265278197034984474501990558050161317058972083308",
|
||||
"8549555053510606289302165143849903925761285779139401276438959553414766561582"
|
||||
],
|
||||
[
|
||||
"21525840049875620673656185364575700574261940775297555537759872607176225382844",
|
||||
"10804170406986327484188973028629688550053758816273117067113206330300963522294"
|
||||
],
|
||||
[
|
||||
"922917354257837537008604464003946574270465496760193887459466960343511330098",
|
||||
"18548885909581399401732271754936250694134330406851366555038143648512851920594"
|
||||
]
|
||||
]
|
||||
],
|
||||
"IC": [
|
||||
[
|
||||
"4148018046519347596812177481784308374584693326254693053110348164627817172095",
|
||||
"20730985524054218557052728073337277395061462810058907329882330843946617288874",
|
||||
"1"
|
||||
]
|
||||
]
|
||||
}
|
||||
BIN
apps/zk-circuits/receipt_simple_0000.zkey
Normal file
BIN
apps/zk-circuits/receipt_simple_0000.zkey
Normal file
Binary file not shown.
BIN
apps/zk-circuits/receipt_simple_0001.zkey
Normal file
BIN
apps/zk-circuits/receipt_simple_0001.zkey
Normal file
Binary file not shown.
BIN
apps/zk-circuits/receipt_simple_0002.zkey
Normal file
BIN
apps/zk-circuits/receipt_simple_0002.zkey
Normal file
Binary file not shown.
20
apps/zk-circuits/receipt_simple_clean.circom
Normal file
20
apps/zk-circuits/receipt_simple_clean.circom
Normal file
@@ -0,0 +1,20 @@
|
||||
pragma circom 2.0.0;
|
||||
|
||||
include "node_modules/circomlib/circuits/bitify.circom";
|
||||
include "node_modules/circomlib/circuits/poseidon.circom";
|
||||
|
||||
/*
|
||||
* Simple Receipt Attestation Circuit
|
||||
*/
|
||||
|
||||
template SimpleReceipt() {
|
||||
signal input receiptHash;
|
||||
signal input receipt[4];
|
||||
component hasher = Poseidon(4);
|
||||
for (var i = 0; i < 4; i++) {
|
||||
hasher.inputs[i] <== receipt[i];
|
||||
}
|
||||
hasher.out === receiptHash;
|
||||
}
|
||||
|
||||
component main = SimpleReceipt();
|
||||
131
apps/zk-circuits/receipt_simple_fixed.circom
Normal file
131
apps/zk-circuits/receipt_simple_fixed.circom
Normal file
@@ -0,0 +1,131 @@
|
||||
pragma circom 2.0.0;
|
||||
|
||||
|
||||
include "node_modules/circomlib/circuits/bitify.circom";
|
||||
include "node_modules/circomlib/circuits/poseidon.circom";
|
||||
|
||||
/*
|
||||
* Simple Receipt Attestation Circuit
|
||||
*
|
||||
* This circuit proves that a receipt is valid without revealing sensitive details.
|
||||
*
|
||||
* Public Inputs:
|
||||
* - receiptHash: Hash of the receipt (for public verification)
|
||||
*
|
||||
* Private Inputs:
|
||||
* - receipt: The full receipt data (private)
|
||||
*/
|
||||
|
||||
template SimpleReceipt() {
|
||||
// Public signal
|
||||
signal input receiptHash;
|
||||
|
||||
// Private signals
|
||||
signal input receipt[4];
|
||||
|
||||
// Component for hashing
|
||||
component hasher = Poseidon(4);
|
||||
|
||||
// Connect private inputs to hasher
|
||||
for (var i = 0; i < 4; i++) {
|
||||
hasher.inputs[i] <== receipt[i];
|
||||
}
|
||||
|
||||
// Ensure the computed hash matches the public hash
|
||||
hasher.out === receiptHash;
|
||||
}
|
||||
|
||||
/*
|
||||
* Membership Proof Circuit
|
||||
*
|
||||
* Proves that a value is part of a set without revealing which one
|
||||
*/
|
||||
|
||||
template MembershipProof(n) {
|
||||
// Public signals
|
||||
signal input root;
|
||||
signal input nullifier;
|
||||
signal input pathIndices[n];
|
||||
|
||||
// Private signals
|
||||
signal input leaf;
|
||||
signal input pathElements[n];
|
||||
signal input salt;
|
||||
|
||||
// Component for hashing
|
||||
component hasher[n];
|
||||
|
||||
// Initialize hasher for the leaf
|
||||
hasher[0] = Poseidon(2);
|
||||
hasher[0].inputs[0] <== leaf;
|
||||
hasher[0].inputs[1] <== salt;
|
||||
|
||||
// Hash up the Merkle tree
|
||||
for (var i = 0; i < n - 1; i++) {
|
||||
hasher[i + 1] = Poseidon(2);
|
||||
|
||||
// Choose left or right based on path index
|
||||
hasher[i + 1].inputs[0] <== pathIndices[i] * pathElements[i] + (1 - pathIndices[i]) * hasher[i].out;
|
||||
hasher[i + 1].inputs[1] <== pathIndices[i] * hasher[i].out + (1 - pathIndices[i]) * pathElements[i];
|
||||
}
|
||||
|
||||
// Ensure final hash equals root
|
||||
hasher[n - 1].out === root;
|
||||
|
||||
// Compute nullifier as hash(leaf, salt)
|
||||
component nullifierHasher = Poseidon(2);
|
||||
nullifierHasher.inputs[0] <== leaf;
|
||||
nullifierHasher.inputs[1] <== salt;
|
||||
nullifierHasher.out === nullifier;
|
||||
}
|
||||
|
||||
/*
|
||||
* Bid Range Proof Circuit
|
||||
*
|
||||
* Proves that a bid is within a valid range without revealing the amount
|
||||
*/
|
||||
|
||||
template BidRangeProof() {
|
||||
// Public signals
|
||||
signal input commitment;
|
||||
signal input minAmount;
|
||||
signal input maxAmount;
|
||||
|
||||
// Private signals
|
||||
signal input bid;
|
||||
signal input salt;
|
||||
|
||||
// Component for hashing commitment
|
||||
component commitmentHasher = Poseidon(2);
|
||||
commitmentHasher.inputs[0] <== bid;
|
||||
commitmentHasher.inputs[1] <== salt;
|
||||
commitmentHasher.out === commitment;
|
||||
|
||||
// Components for range checking
|
||||
component minChecker = GreaterEqThan(8);
|
||||
component maxChecker = GreaterEqThan(8);
|
||||
|
||||
// Convert amounts to 8-bit representation
|
||||
component bidBits = Num2Bits(64);
|
||||
component minBits = Num2Bits(64);
|
||||
component maxBits = Num2Bits(64);
|
||||
|
||||
bidBits.in <== bid;
|
||||
minBits.in <== minAmount;
|
||||
maxBits.in <== maxAmount;
|
||||
|
||||
// Check bid >= minAmount
|
||||
for (var i = 0; i < 64; i++) {
|
||||
minChecker.in[i] <== bidBits.out[i] - minBits.out[i];
|
||||
}
|
||||
minChecker.out === 1;
|
||||
|
||||
// Check maxAmount >= bid
|
||||
for (var i = 0; i < 64; i++) {
|
||||
maxChecker.in[i] <== maxBits.out[i] - bidBits.out[i];
|
||||
}
|
||||
maxChecker.out === 1;
|
||||
}
|
||||
|
||||
// Main component instantiation
|
||||
component main = SimpleReceipt();
|
||||
21
apps/zk-circuits/receipt_simple_js/generate_witness.js
Normal file
21
apps/zk-circuits/receipt_simple_js/generate_witness.js
Normal file
@@ -0,0 +1,21 @@
|
||||
const wc = require("./witness_calculator.js");
|
||||
const { readFileSync, writeFile } = require("fs");
|
||||
|
||||
if (process.argv.length != 5) {
|
||||
console.log("Usage: node generate_witness.js <file.wasm> <input.json> <output.wtns>");
|
||||
} else {
|
||||
const input = JSON.parse(readFileSync(process.argv[3], "utf8"));
|
||||
|
||||
const buffer = readFileSync(process.argv[2]);
|
||||
wc(buffer).then(async witnessCalculator => {
|
||||
/*
|
||||
const w= await witnessCalculator.calculateWitness(input,0);
|
||||
for (let i=0; i< w.length; i++){
|
||||
console.log(w[i]);
|
||||
}*/
|
||||
const buff= await witnessCalculator.calculateWTNSBin(input,0);
|
||||
writeFile(process.argv[4], buff, function(err) {
|
||||
if (err) throw err;
|
||||
});
|
||||
});
|
||||
}
|
||||
BIN
apps/zk-circuits/receipt_simple_js/receipt_simple.wasm
Normal file
BIN
apps/zk-circuits/receipt_simple_js/receipt_simple.wasm
Normal file
Binary file not shown.
381
apps/zk-circuits/receipt_simple_js/witness_calculator.js
Normal file
381
apps/zk-circuits/receipt_simple_js/witness_calculator.js
Normal file
@@ -0,0 +1,381 @@
|
||||
module.exports = async function builder(code, options) {
|
||||
|
||||
options = options || {};
|
||||
|
||||
let wasmModule;
|
||||
try {
|
||||
wasmModule = await WebAssembly.compile(code);
|
||||
} catch (err) {
|
||||
console.log(err);
|
||||
console.log("\nTry to run circom --c in order to generate c++ code instead\n");
|
||||
throw new Error(err);
|
||||
}
|
||||
|
||||
let wc;
|
||||
|
||||
let errStr = "";
|
||||
let msgStr = "";
|
||||
|
||||
const instance = await WebAssembly.instantiate(wasmModule, {
|
||||
runtime: {
|
||||
exceptionHandler : function(code) {
|
||||
let err;
|
||||
if (code == 1) {
|
||||
err = "Signal not found.\n";
|
||||
} else if (code == 2) {
|
||||
err = "Too many signals set.\n";
|
||||
} else if (code == 3) {
|
||||
err = "Signal already set.\n";
|
||||
} else if (code == 4) {
|
||||
err = "Assert Failed.\n";
|
||||
} else if (code == 5) {
|
||||
err = "Not enough memory.\n";
|
||||
} else if (code == 6) {
|
||||
err = "Input signal array access exceeds the size.\n";
|
||||
} else {
|
||||
err = "Unknown error.\n";
|
||||
}
|
||||
throw new Error(err + errStr);
|
||||
},
|
||||
printErrorMessage : function() {
|
||||
errStr += getMessage() + "\n";
|
||||
// console.error(getMessage());
|
||||
},
|
||||
writeBufferMessage : function() {
|
||||
const msg = getMessage();
|
||||
// Any calls to `log()` will always end with a `\n`, so that's when we print and reset
|
||||
if (msg === "\n") {
|
||||
console.log(msgStr);
|
||||
msgStr = "";
|
||||
} else {
|
||||
// If we've buffered other content, put a space in between the items
|
||||
if (msgStr !== "") {
|
||||
msgStr += " "
|
||||
}
|
||||
// Then append the message to the message we are creating
|
||||
msgStr += msg;
|
||||
}
|
||||
},
|
||||
showSharedRWMemory : function() {
|
||||
printSharedRWMemory ();
|
||||
}
|
||||
|
||||
}
|
||||
});
|
||||
|
||||
const sanityCheck =
|
||||
options
|
||||
// options &&
|
||||
// (
|
||||
// options.sanityCheck ||
|
||||
// options.logGetSignal ||
|
||||
// options.logSetSignal ||
|
||||
// options.logStartComponent ||
|
||||
// options.logFinishComponent
|
||||
// );
|
||||
|
||||
|
||||
wc = new WitnessCalculator(instance, sanityCheck);
|
||||
return wc;
|
||||
|
||||
function getMessage() {
|
||||
var message = "";
|
||||
var c = instance.exports.getMessageChar();
|
||||
while ( c != 0 ) {
|
||||
message += String.fromCharCode(c);
|
||||
c = instance.exports.getMessageChar();
|
||||
}
|
||||
return message;
|
||||
}
|
||||
|
||||
function printSharedRWMemory () {
|
||||
const shared_rw_memory_size = instance.exports.getFieldNumLen32();
|
||||
const arr = new Uint32Array(shared_rw_memory_size);
|
||||
for (let j=0; j<shared_rw_memory_size; j++) {
|
||||
arr[shared_rw_memory_size-1-j] = instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
|
||||
// If we've buffered other content, put a space in between the items
|
||||
if (msgStr !== "") {
|
||||
msgStr += " "
|
||||
}
|
||||
// Then append the value to the message we are creating
|
||||
msgStr += (fromArray32(arr).toString());
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
class WitnessCalculator {
|
||||
constructor(instance, sanityCheck) {
|
||||
this.instance = instance;
|
||||
|
||||
this.version = this.instance.exports.getVersion();
|
||||
this.n32 = this.instance.exports.getFieldNumLen32();
|
||||
|
||||
this.instance.exports.getRawPrime();
|
||||
const arr = new Uint32Array(this.n32);
|
||||
for (let i=0; i<this.n32; i++) {
|
||||
arr[this.n32-1-i] = this.instance.exports.readSharedRWMemory(i);
|
||||
}
|
||||
this.prime = fromArray32(arr);
|
||||
|
||||
this.witnessSize = this.instance.exports.getWitnessSize();
|
||||
|
||||
this.sanityCheck = sanityCheck;
|
||||
}
|
||||
|
||||
circom_version() {
|
||||
return this.instance.exports.getVersion();
|
||||
}
|
||||
|
||||
async _doCalculateWitness(input_orig, sanityCheck) {
|
||||
//input is assumed to be a map from signals to arrays of bigints
|
||||
this.instance.exports.init((this.sanityCheck || sanityCheck) ? 1 : 0);
|
||||
let prefix = "";
|
||||
var input = new Object();
|
||||
//console.log("Input: ", input_orig);
|
||||
qualify_input(prefix,input_orig,input);
|
||||
//console.log("Input after: ",input);
|
||||
const keys = Object.keys(input);
|
||||
var input_counter = 0;
|
||||
keys.forEach( (k) => {
|
||||
const h = fnvHash(k);
|
||||
const hMSB = parseInt(h.slice(0,8), 16);
|
||||
const hLSB = parseInt(h.slice(8,16), 16);
|
||||
const fArr = flatArray(input[k]);
|
||||
let signalSize = this.instance.exports.getInputSignalSize(hMSB, hLSB);
|
||||
if (signalSize < 0){
|
||||
throw new Error(`Signal ${k} not found\n`);
|
||||
}
|
||||
if (fArr.length < signalSize) {
|
||||
throw new Error(`Not enough values for input signal ${k}\n`);
|
||||
}
|
||||
if (fArr.length > signalSize) {
|
||||
throw new Error(`Too many values for input signal ${k}\n`);
|
||||
}
|
||||
for (let i=0; i<fArr.length; i++) {
|
||||
const arrFr = toArray32(normalize(fArr[i],this.prime),this.n32)
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
this.instance.exports.writeSharedRWMemory(j,arrFr[this.n32-1-j]);
|
||||
}
|
||||
try {
|
||||
this.instance.exports.setInputSignal(hMSB, hLSB,i);
|
||||
input_counter++;
|
||||
} catch (err) {
|
||||
// console.log(`After adding signal ${i} of ${k}`)
|
||||
throw new Error(err);
|
||||
}
|
||||
}
|
||||
|
||||
});
|
||||
if (input_counter < this.instance.exports.getInputSize()) {
|
||||
throw new Error(`Not all inputs have been set. Only ${input_counter} out of ${this.instance.exports.getInputSize()}`);
|
||||
}
|
||||
}
|
||||
|
||||
async calculateWitness(input, sanityCheck) {
|
||||
|
||||
const w = [];
|
||||
await this._doCalculateWitness(input, sanityCheck);
|
||||
|
||||
for (let i=0; i<this.witnessSize; i++) {
|
||||
this.instance.exports.getWitness(i);
|
||||
const arr = new Uint32Array(this.n32);
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
arr[this.n32-1-j] = this.instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
w.push(fromArray32(arr));
|
||||
}
|
||||
|
||||
return w;
|
||||
}
|
||||
|
||||
|
||||
async calculateBinWitness(input, sanityCheck) {
|
||||
|
||||
const buff32 = new Uint32Array(this.witnessSize*this.n32);
|
||||
const buff = new Uint8Array( buff32.buffer);
|
||||
await this._doCalculateWitness(input, sanityCheck);
|
||||
|
||||
for (let i=0; i<this.witnessSize; i++) {
|
||||
this.instance.exports.getWitness(i);
|
||||
const pos = i*this.n32;
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
buff32[pos+j] = this.instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
}
|
||||
|
||||
return buff;
|
||||
}
|
||||
|
||||
|
||||
async calculateWTNSBin(input, sanityCheck) {
|
||||
|
||||
const buff32 = new Uint32Array(this.witnessSize*this.n32+this.n32+11);
|
||||
const buff = new Uint8Array( buff32.buffer);
|
||||
await this._doCalculateWitness(input, sanityCheck);
|
||||
|
||||
//"wtns"
|
||||
buff[0] = "w".charCodeAt(0)
|
||||
buff[1] = "t".charCodeAt(0)
|
||||
buff[2] = "n".charCodeAt(0)
|
||||
buff[3] = "s".charCodeAt(0)
|
||||
|
||||
//version 2
|
||||
buff32[1] = 2;
|
||||
|
||||
//number of sections: 2
|
||||
buff32[2] = 2;
|
||||
|
||||
//id section 1
|
||||
buff32[3] = 1;
|
||||
|
||||
const n8 = this.n32*4;
|
||||
//id section 1 length in 64bytes
|
||||
const idSection1length = 8 + n8;
|
||||
const idSection1lengthHex = idSection1length.toString(16);
|
||||
buff32[4] = parseInt(idSection1lengthHex.slice(0,8), 16);
|
||||
buff32[5] = parseInt(idSection1lengthHex.slice(8,16), 16);
|
||||
|
||||
//this.n32
|
||||
buff32[6] = n8;
|
||||
|
||||
//prime number
|
||||
this.instance.exports.getRawPrime();
|
||||
|
||||
var pos = 7;
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
buff32[pos+j] = this.instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
pos += this.n32;
|
||||
|
||||
// witness size
|
||||
buff32[pos] = this.witnessSize;
|
||||
pos++;
|
||||
|
||||
//id section 2
|
||||
buff32[pos] = 2;
|
||||
pos++;
|
||||
|
||||
// section 2 length
|
||||
const idSection2length = n8*this.witnessSize;
|
||||
const idSection2lengthHex = idSection2length.toString(16);
|
||||
buff32[pos] = parseInt(idSection2lengthHex.slice(0,8), 16);
|
||||
buff32[pos+1] = parseInt(idSection2lengthHex.slice(8,16), 16);
|
||||
|
||||
pos += 2;
|
||||
for (let i=0; i<this.witnessSize; i++) {
|
||||
this.instance.exports.getWitness(i);
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
buff32[pos+j] = this.instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
pos += this.n32;
|
||||
}
|
||||
|
||||
return buff;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
function qualify_input_list(prefix,input,input1){
|
||||
if (Array.isArray(input)) {
|
||||
for (let i = 0; i<input.length; i++) {
|
||||
let new_prefix = prefix + "[" + i + "]";
|
||||
qualify_input_list(new_prefix,input[i],input1);
|
||||
}
|
||||
} else {
|
||||
qualify_input(prefix,input,input1);
|
||||
}
|
||||
}
|
||||
|
||||
function qualify_input(prefix,input,input1) {
|
||||
if (Array.isArray(input)) {
|
||||
a = flatArray(input);
|
||||
if (a.length > 0) {
|
||||
let t = typeof a[0];
|
||||
for (let i = 1; i<a.length; i++) {
|
||||
if (typeof a[i] != t){
|
||||
throw new Error(`Types are not the same in the key ${prefix}`);
|
||||
}
|
||||
}
|
||||
if (t == "object") {
|
||||
qualify_input_list(prefix,input,input1);
|
||||
} else {
|
||||
input1[prefix] = input;
|
||||
}
|
||||
} else {
|
||||
input1[prefix] = input;
|
||||
}
|
||||
} else if (typeof input == "object") {
|
||||
const keys = Object.keys(input);
|
||||
keys.forEach( (k) => {
|
||||
let new_prefix = prefix == ""? k : prefix + "." + k;
|
||||
qualify_input(new_prefix,input[k],input1);
|
||||
});
|
||||
} else {
|
||||
input1[prefix] = input;
|
||||
}
|
||||
}
|
||||
|
||||
function toArray32(rem,size) {
|
||||
const res = []; //new Uint32Array(size); //has no unshift
|
||||
const radix = BigInt(0x100000000);
|
||||
while (rem) {
|
||||
res.unshift( Number(rem % radix));
|
||||
rem = rem / radix;
|
||||
}
|
||||
if (size) {
|
||||
var i = size - res.length;
|
||||
while (i>0) {
|
||||
res.unshift(0);
|
||||
i--;
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
function fromArray32(arr) { //returns a BigInt
|
||||
var res = BigInt(0);
|
||||
const radix = BigInt(0x100000000);
|
||||
for (let i = 0; i<arr.length; i++) {
|
||||
res = res*radix + BigInt(arr[i]);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
function flatArray(a) {
|
||||
var res = [];
|
||||
fillArray(res, a);
|
||||
return res;
|
||||
|
||||
function fillArray(res, a) {
|
||||
if (Array.isArray(a)) {
|
||||
for (let i=0; i<a.length; i++) {
|
||||
fillArray(res, a[i]);
|
||||
}
|
||||
} else {
|
||||
res.push(a);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function normalize(n, prime) {
|
||||
let res = BigInt(n) % prime
|
||||
if (res < 0) res += prime
|
||||
return res
|
||||
}
|
||||
|
||||
function fnvHash(str) {
|
||||
const uint64_max = BigInt(2) ** BigInt(64);
|
||||
let hash = BigInt("0xCBF29CE484222325");
|
||||
for (var i = 0; i < str.length; i++) {
|
||||
hash ^= BigInt(str[i].charCodeAt());
|
||||
hash *= BigInt(0x100000001B3);
|
||||
hash %= uint64_max;
|
||||
}
|
||||
let shash = hash.toString(16);
|
||||
let n = 16 - shash.length;
|
||||
shash = '0'.repeat(n).concat(shash);
|
||||
return shash;
|
||||
}
|
||||
9
apps/zk-circuits/test.circom
Normal file
9
apps/zk-circuits/test.circom
Normal file
@@ -0,0 +1,9 @@
|
||||
pragma circom 2.0.0;
|
||||
|
||||
template Test() {
|
||||
signal input in;
|
||||
signal output out;
|
||||
out <== in;
|
||||
}
|
||||
|
||||
component main = Test();
|
||||
9
apps/zk-circuits/test2.circom
Normal file
9
apps/zk-circuits/test2.circom
Normal file
@@ -0,0 +1,9 @@
|
||||
pragma circom 2.0.0;
|
||||
|
||||
template Test() {
|
||||
signal input in;
|
||||
signal output out;
|
||||
out <== in;
|
||||
}
|
||||
|
||||
component main = Test();
|
||||
9
apps/zk-circuits/test_final.circom
Normal file
9
apps/zk-circuits/test_final.circom
Normal file
@@ -0,0 +1,9 @@
|
||||
pragma circom 0.5.46;
|
||||
|
||||
template Test() {
|
||||
signal input in;
|
||||
signal output out;
|
||||
out <== in;
|
||||
}
|
||||
|
||||
component main = Test();
|
||||
9
apps/zk-circuits/test_final_v2.circom
Normal file
9
apps/zk-circuits/test_final_v2.circom
Normal file
@@ -0,0 +1,9 @@
|
||||
pragma circom 2.0.0;
|
||||
|
||||
template Test() {
|
||||
signal input in;
|
||||
signal output out;
|
||||
out <== in;
|
||||
}
|
||||
|
||||
component main = Test();
|
||||
BIN
apps/zk-circuits/test_final_v2.r1cs
Normal file
BIN
apps/zk-circuits/test_final_v2.r1cs
Normal file
Binary file not shown.
2
apps/zk-circuits/test_final_v2.sym
Normal file
2
apps/zk-circuits/test_final_v2.sym
Normal file
@@ -0,0 +1,2 @@
|
||||
1,1,0,main.out
|
||||
2,-1,0,main.in
|
||||
94
apps/zk-circuits/test_final_v2.vkey
Normal file
94
apps/zk-circuits/test_final_v2.vkey
Normal file
@@ -0,0 +1,94 @@
|
||||
{
|
||||
"protocol": "groth16",
|
||||
"curve": "bn128",
|
||||
"nPublic": 1,
|
||||
"vk_alpha_1": [
|
||||
"8460216532488165727467564856413555351114670954785488538800357260241591659922",
|
||||
"18445221864308632061488572037047946806659902339700033382142009763125814749748",
|
||||
"1"
|
||||
],
|
||||
"vk_beta_2": [
|
||||
[
|
||||
"6479683735401057464856560780016689003394325158210495956800419236111697402941",
|
||||
"10756899494323454451849886987287990433636781750938311280590204128566742369499"
|
||||
],
|
||||
[
|
||||
"14397376998117601765034877247086905021783475930686205456376147632056422933833",
|
||||
"20413115250143543082989954729570048513153861075230117372641105301032124129876"
|
||||
],
|
||||
[
|
||||
"1",
|
||||
"0"
|
||||
]
|
||||
],
|
||||
"vk_gamma_2": [
|
||||
[
|
||||
"10857046999023057135944570762232829481370756359578518086990519993285655852781",
|
||||
"11559732032986387107991004021392285783925812861821192530917403151452391805634"
|
||||
],
|
||||
[
|
||||
"8495653923123431417604973247489272438418190587263600148770280649306958101930",
|
||||
"4082367875863433681332203403145435568316851327593401208105741076214120093531"
|
||||
],
|
||||
[
|
||||
"1",
|
||||
"0"
|
||||
]
|
||||
],
|
||||
"vk_delta_2": [
|
||||
[
|
||||
"6840503012950456034406412069208230277997775373740741539262294411073505372202",
|
||||
"4187901564856243153173061219345467014727545819082218143172095490940414594424"
|
||||
],
|
||||
[
|
||||
"15354962623567401613422376703326876887451375834046173755940516337285040531401",
|
||||
"16312755549775593509550494456994863905270524213647477910622330564896885944010"
|
||||
],
|
||||
[
|
||||
"1",
|
||||
"0"
|
||||
]
|
||||
],
|
||||
"vk_alphabeta_12": [
|
||||
[
|
||||
[
|
||||
"8995664523327611111940695773435202321527189968635326175993425030330107869209",
|
||||
"10636865864911719472203481854537187286731767382234618665029688610948280447774"
|
||||
],
|
||||
[
|
||||
"2027301146985302003447473427699486288958511647692214852679531814142772884072",
|
||||
"16315179087884712852887019534812875478380885062857601375409989804072625917625"
|
||||
],
|
||||
[
|
||||
"5763629345463320911658464985147138827165295056412698652075257157933349925190",
|
||||
"18007509234277924935356458855535698088409613611430357427754027720054049931159"
|
||||
]
|
||||
],
|
||||
[
|
||||
[
|
||||
"12742020694779715461694294344022902700171616940484742698214637693643592478776",
|
||||
"13449812718618008130272786901682245900092785345108866963867787217638117513710"
|
||||
],
|
||||
[
|
||||
"4697328451762890383542458909679544743549594918890775424620183530718745223176",
|
||||
"18283933325645065572175183630291803944633449818122421671865200510652516905389"
|
||||
],
|
||||
[
|
||||
"325914140485583140584324883490363676367108249716427038595477057788929554745",
|
||||
"6765772614216179391904319393793642468016331619939680620407685333447433218960"
|
||||
]
|
||||
]
|
||||
],
|
||||
"IC": [
|
||||
[
|
||||
"7685121570366407724807946503921961619833683410392772870373459476604128011275",
|
||||
"6915443837935167692630810275110398177336960270031115982900890650376967129575",
|
||||
"1"
|
||||
],
|
||||
[
|
||||
"10363999014224824591638032348857401078402637116683579765969796919683926972060",
|
||||
"5716124078230277423780595544607422628270452574948632939527677487979409581469",
|
||||
"1"
|
||||
]
|
||||
]
|
||||
}
|
||||
BIN
apps/zk-circuits/test_final_v2_0000.zkey
Normal file
BIN
apps/zk-circuits/test_final_v2_0000.zkey
Normal file
Binary file not shown.
BIN
apps/zk-circuits/test_final_v2_0001.zkey
Normal file
BIN
apps/zk-circuits/test_final_v2_0001.zkey
Normal file
Binary file not shown.
21
apps/zk-circuits/test_final_v2_js/generate_witness.js
Normal file
21
apps/zk-circuits/test_final_v2_js/generate_witness.js
Normal file
@@ -0,0 +1,21 @@
|
||||
const wc = require("./witness_calculator.js");
|
||||
const { readFileSync, writeFile } = require("fs");
|
||||
|
||||
if (process.argv.length != 5) {
|
||||
console.log("Usage: node generate_witness.js <file.wasm> <input.json> <output.wtns>");
|
||||
} else {
|
||||
const input = JSON.parse(readFileSync(process.argv[3], "utf8"));
|
||||
|
||||
const buffer = readFileSync(process.argv[2]);
|
||||
wc(buffer).then(async witnessCalculator => {
|
||||
/*
|
||||
const w= await witnessCalculator.calculateWitness(input,0);
|
||||
for (let i=0; i< w.length; i++){
|
||||
console.log(w[i]);
|
||||
}*/
|
||||
const buff= await witnessCalculator.calculateWTNSBin(input,0);
|
||||
writeFile(process.argv[4], buff, function(err) {
|
||||
if (err) throw err;
|
||||
});
|
||||
});
|
||||
}
|
||||
BIN
apps/zk-circuits/test_final_v2_js/test_final_v2.wasm
Normal file
BIN
apps/zk-circuits/test_final_v2_js/test_final_v2.wasm
Normal file
Binary file not shown.
381
apps/zk-circuits/test_final_v2_js/witness_calculator.js
Normal file
381
apps/zk-circuits/test_final_v2_js/witness_calculator.js
Normal file
@@ -0,0 +1,381 @@
|
||||
module.exports = async function builder(code, options) {
|
||||
|
||||
options = options || {};
|
||||
|
||||
let wasmModule;
|
||||
try {
|
||||
wasmModule = await WebAssembly.compile(code);
|
||||
} catch (err) {
|
||||
console.log(err);
|
||||
console.log("\nTry to run circom --c in order to generate c++ code instead\n");
|
||||
throw new Error(err);
|
||||
}
|
||||
|
||||
let wc;
|
||||
|
||||
let errStr = "";
|
||||
let msgStr = "";
|
||||
|
||||
const instance = await WebAssembly.instantiate(wasmModule, {
|
||||
runtime: {
|
||||
exceptionHandler : function(code) {
|
||||
let err;
|
||||
if (code == 1) {
|
||||
err = "Signal not found.\n";
|
||||
} else if (code == 2) {
|
||||
err = "Too many signals set.\n";
|
||||
} else if (code == 3) {
|
||||
err = "Signal already set.\n";
|
||||
} else if (code == 4) {
|
||||
err = "Assert Failed.\n";
|
||||
} else if (code == 5) {
|
||||
err = "Not enough memory.\n";
|
||||
} else if (code == 6) {
|
||||
err = "Input signal array access exceeds the size.\n";
|
||||
} else {
|
||||
err = "Unknown error.\n";
|
||||
}
|
||||
throw new Error(err + errStr);
|
||||
},
|
||||
printErrorMessage : function() {
|
||||
errStr += getMessage() + "\n";
|
||||
// console.error(getMessage());
|
||||
},
|
||||
writeBufferMessage : function() {
|
||||
const msg = getMessage();
|
||||
// Any calls to `log()` will always end with a `\n`, so that's when we print and reset
|
||||
if (msg === "\n") {
|
||||
console.log(msgStr);
|
||||
msgStr = "";
|
||||
} else {
|
||||
// If we've buffered other content, put a space in between the items
|
||||
if (msgStr !== "") {
|
||||
msgStr += " "
|
||||
}
|
||||
// Then append the message to the message we are creating
|
||||
msgStr += msg;
|
||||
}
|
||||
},
|
||||
showSharedRWMemory : function() {
|
||||
printSharedRWMemory ();
|
||||
}
|
||||
|
||||
}
|
||||
});
|
||||
|
||||
const sanityCheck =
|
||||
options
|
||||
// options &&
|
||||
// (
|
||||
// options.sanityCheck ||
|
||||
// options.logGetSignal ||
|
||||
// options.logSetSignal ||
|
||||
// options.logStartComponent ||
|
||||
// options.logFinishComponent
|
||||
// );
|
||||
|
||||
|
||||
wc = new WitnessCalculator(instance, sanityCheck);
|
||||
return wc;
|
||||
|
||||
function getMessage() {
|
||||
var message = "";
|
||||
var c = instance.exports.getMessageChar();
|
||||
while ( c != 0 ) {
|
||||
message += String.fromCharCode(c);
|
||||
c = instance.exports.getMessageChar();
|
||||
}
|
||||
return message;
|
||||
}
|
||||
|
||||
function printSharedRWMemory () {
|
||||
const shared_rw_memory_size = instance.exports.getFieldNumLen32();
|
||||
const arr = new Uint32Array(shared_rw_memory_size);
|
||||
for (let j=0; j<shared_rw_memory_size; j++) {
|
||||
arr[shared_rw_memory_size-1-j] = instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
|
||||
// If we've buffered other content, put a space in between the items
|
||||
if (msgStr !== "") {
|
||||
msgStr += " "
|
||||
}
|
||||
// Then append the value to the message we are creating
|
||||
msgStr += (fromArray32(arr).toString());
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
class WitnessCalculator {
|
||||
constructor(instance, sanityCheck) {
|
||||
this.instance = instance;
|
||||
|
||||
this.version = this.instance.exports.getVersion();
|
||||
this.n32 = this.instance.exports.getFieldNumLen32();
|
||||
|
||||
this.instance.exports.getRawPrime();
|
||||
const arr = new Uint32Array(this.n32);
|
||||
for (let i=0; i<this.n32; i++) {
|
||||
arr[this.n32-1-i] = this.instance.exports.readSharedRWMemory(i);
|
||||
}
|
||||
this.prime = fromArray32(arr);
|
||||
|
||||
this.witnessSize = this.instance.exports.getWitnessSize();
|
||||
|
||||
this.sanityCheck = sanityCheck;
|
||||
}
|
||||
|
||||
circom_version() {
|
||||
return this.instance.exports.getVersion();
|
||||
}
|
||||
|
||||
async _doCalculateWitness(input_orig, sanityCheck) {
|
||||
//input is assumed to be a map from signals to arrays of bigints
|
||||
this.instance.exports.init((this.sanityCheck || sanityCheck) ? 1 : 0);
|
||||
let prefix = "";
|
||||
var input = new Object();
|
||||
//console.log("Input: ", input_orig);
|
||||
qualify_input(prefix,input_orig,input);
|
||||
//console.log("Input after: ",input);
|
||||
const keys = Object.keys(input);
|
||||
var input_counter = 0;
|
||||
keys.forEach( (k) => {
|
||||
const h = fnvHash(k);
|
||||
const hMSB = parseInt(h.slice(0,8), 16);
|
||||
const hLSB = parseInt(h.slice(8,16), 16);
|
||||
const fArr = flatArray(input[k]);
|
||||
let signalSize = this.instance.exports.getInputSignalSize(hMSB, hLSB);
|
||||
if (signalSize < 0){
|
||||
throw new Error(`Signal ${k} not found\n`);
|
||||
}
|
||||
if (fArr.length < signalSize) {
|
||||
throw new Error(`Not enough values for input signal ${k}\n`);
|
||||
}
|
||||
if (fArr.length > signalSize) {
|
||||
throw new Error(`Too many values for input signal ${k}\n`);
|
||||
}
|
||||
for (let i=0; i<fArr.length; i++) {
|
||||
const arrFr = toArray32(normalize(fArr[i],this.prime),this.n32)
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
this.instance.exports.writeSharedRWMemory(j,arrFr[this.n32-1-j]);
|
||||
}
|
||||
try {
|
||||
this.instance.exports.setInputSignal(hMSB, hLSB,i);
|
||||
input_counter++;
|
||||
} catch (err) {
|
||||
// console.log(`After adding signal ${i} of ${k}`)
|
||||
throw new Error(err);
|
||||
}
|
||||
}
|
||||
|
||||
});
|
||||
if (input_counter < this.instance.exports.getInputSize()) {
|
||||
throw new Error(`Not all inputs have been set. Only ${input_counter} out of ${this.instance.exports.getInputSize()}`);
|
||||
}
|
||||
}
|
||||
|
||||
async calculateWitness(input, sanityCheck) {
|
||||
|
||||
const w = [];
|
||||
await this._doCalculateWitness(input, sanityCheck);
|
||||
|
||||
for (let i=0; i<this.witnessSize; i++) {
|
||||
this.instance.exports.getWitness(i);
|
||||
const arr = new Uint32Array(this.n32);
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
arr[this.n32-1-j] = this.instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
w.push(fromArray32(arr));
|
||||
}
|
||||
|
||||
return w;
|
||||
}
|
||||
|
||||
|
||||
async calculateBinWitness(input, sanityCheck) {
|
||||
|
||||
const buff32 = new Uint32Array(this.witnessSize*this.n32);
|
||||
const buff = new Uint8Array( buff32.buffer);
|
||||
await this._doCalculateWitness(input, sanityCheck);
|
||||
|
||||
for (let i=0; i<this.witnessSize; i++) {
|
||||
this.instance.exports.getWitness(i);
|
||||
const pos = i*this.n32;
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
buff32[pos+j] = this.instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
}
|
||||
|
||||
return buff;
|
||||
}
|
||||
|
||||
|
||||
async calculateWTNSBin(input, sanityCheck) {
|
||||
|
||||
const buff32 = new Uint32Array(this.witnessSize*this.n32+this.n32+11);
|
||||
const buff = new Uint8Array( buff32.buffer);
|
||||
await this._doCalculateWitness(input, sanityCheck);
|
||||
|
||||
//"wtns"
|
||||
buff[0] = "w".charCodeAt(0)
|
||||
buff[1] = "t".charCodeAt(0)
|
||||
buff[2] = "n".charCodeAt(0)
|
||||
buff[3] = "s".charCodeAt(0)
|
||||
|
||||
//version 2
|
||||
buff32[1] = 2;
|
||||
|
||||
//number of sections: 2
|
||||
buff32[2] = 2;
|
||||
|
||||
//id section 1
|
||||
buff32[3] = 1;
|
||||
|
||||
const n8 = this.n32*4;
|
||||
//id section 1 length in 64bytes
|
||||
const idSection1length = 8 + n8;
|
||||
const idSection1lengthHex = idSection1length.toString(16);
|
||||
buff32[4] = parseInt(idSection1lengthHex.slice(0,8), 16);
|
||||
buff32[5] = parseInt(idSection1lengthHex.slice(8,16), 16);
|
||||
|
||||
//this.n32
|
||||
buff32[6] = n8;
|
||||
|
||||
//prime number
|
||||
this.instance.exports.getRawPrime();
|
||||
|
||||
var pos = 7;
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
buff32[pos+j] = this.instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
pos += this.n32;
|
||||
|
||||
// witness size
|
||||
buff32[pos] = this.witnessSize;
|
||||
pos++;
|
||||
|
||||
//id section 2
|
||||
buff32[pos] = 2;
|
||||
pos++;
|
||||
|
||||
// section 2 length
|
||||
const idSection2length = n8*this.witnessSize;
|
||||
const idSection2lengthHex = idSection2length.toString(16);
|
||||
buff32[pos] = parseInt(idSection2lengthHex.slice(0,8), 16);
|
||||
buff32[pos+1] = parseInt(idSection2lengthHex.slice(8,16), 16);
|
||||
|
||||
pos += 2;
|
||||
for (let i=0; i<this.witnessSize; i++) {
|
||||
this.instance.exports.getWitness(i);
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
buff32[pos+j] = this.instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
pos += this.n32;
|
||||
}
|
||||
|
||||
return buff;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
function qualify_input_list(prefix,input,input1){
|
||||
if (Array.isArray(input)) {
|
||||
for (let i = 0; i<input.length; i++) {
|
||||
let new_prefix = prefix + "[" + i + "]";
|
||||
qualify_input_list(new_prefix,input[i],input1);
|
||||
}
|
||||
} else {
|
||||
qualify_input(prefix,input,input1);
|
||||
}
|
||||
}
|
||||
|
||||
function qualify_input(prefix,input,input1) {
|
||||
if (Array.isArray(input)) {
|
||||
a = flatArray(input);
|
||||
if (a.length > 0) {
|
||||
let t = typeof a[0];
|
||||
for (let i = 1; i<a.length; i++) {
|
||||
if (typeof a[i] != t){
|
||||
throw new Error(`Types are not the same in the key ${prefix}`);
|
||||
}
|
||||
}
|
||||
if (t == "object") {
|
||||
qualify_input_list(prefix,input,input1);
|
||||
} else {
|
||||
input1[prefix] = input;
|
||||
}
|
||||
} else {
|
||||
input1[prefix] = input;
|
||||
}
|
||||
} else if (typeof input == "object") {
|
||||
const keys = Object.keys(input);
|
||||
keys.forEach( (k) => {
|
||||
let new_prefix = prefix == ""? k : prefix + "." + k;
|
||||
qualify_input(new_prefix,input[k],input1);
|
||||
});
|
||||
} else {
|
||||
input1[prefix] = input;
|
||||
}
|
||||
}
|
||||
|
||||
function toArray32(rem,size) {
|
||||
const res = []; //new Uint32Array(size); //has no unshift
|
||||
const radix = BigInt(0x100000000);
|
||||
while (rem) {
|
||||
res.unshift( Number(rem % radix));
|
||||
rem = rem / radix;
|
||||
}
|
||||
if (size) {
|
||||
var i = size - res.length;
|
||||
while (i>0) {
|
||||
res.unshift(0);
|
||||
i--;
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
function fromArray32(arr) { //returns a BigInt
|
||||
var res = BigInt(0);
|
||||
const radix = BigInt(0x100000000);
|
||||
for (let i = 0; i<arr.length; i++) {
|
||||
res = res*radix + BigInt(arr[i]);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
function flatArray(a) {
|
||||
var res = [];
|
||||
fillArray(res, a);
|
||||
return res;
|
||||
|
||||
function fillArray(res, a) {
|
||||
if (Array.isArray(a)) {
|
||||
for (let i=0; i<a.length; i++) {
|
||||
fillArray(res, a[i]);
|
||||
}
|
||||
} else {
|
||||
res.push(a);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function normalize(n, prime) {
|
||||
let res = BigInt(n) % prime
|
||||
if (res < 0) res += prime
|
||||
return res
|
||||
}
|
||||
|
||||
function fnvHash(str) {
|
||||
const uint64_max = BigInt(2) ** BigInt(64);
|
||||
let hash = BigInt("0xCBF29CE484222325");
|
||||
for (var i = 0; i < str.length; i++) {
|
||||
hash ^= BigInt(str[i].charCodeAt());
|
||||
hash *= BigInt(0x100000001B3);
|
||||
hash %= uint64_max;
|
||||
}
|
||||
let shash = hash.toString(16);
|
||||
let n = 16 - shash.length;
|
||||
shash = '0'.repeat(n).concat(shash);
|
||||
return shash;
|
||||
}
|
||||
168
apps/zk-circuits/test_final_verifier.sol
Normal file
168
apps/zk-circuits/test_final_verifier.sol
Normal file
@@ -0,0 +1,168 @@
|
||||
// SPDX-License-Identifier: GPL-3.0
|
||||
/*
|
||||
Copyright 2021 0KIMS association.
|
||||
|
||||
This file is generated with [snarkJS](https://github.com/iden3/snarkjs).
|
||||
|
||||
snarkJS is a free software: you can redistribute it and/or modify it
|
||||
under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
snarkJS is distributed in the hope that it will be useful, but WITHOUT
|
||||
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
|
||||
or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public
|
||||
License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with snarkJS. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
pragma solidity >=0.7.0 <0.9.0;
|
||||
|
||||
contract Groth16Verifier {
|
||||
// Scalar field size
|
||||
uint256 constant r = 21888242871839275222246405745257275088548364400416034343698204186575808495617;
|
||||
// Base field size
|
||||
uint256 constant q = 21888242871839275222246405745257275088696311157297823662689037894645226208583;
|
||||
|
||||
// Verification Key data
|
||||
uint256 constant alphax = 8460216532488165727467564856413555351114670954785488538800357260241591659922;
|
||||
uint256 constant alphay = 18445221864308632061488572037047946806659902339700033382142009763125814749748;
|
||||
uint256 constant betax1 = 10756899494323454451849886987287990433636781750938311280590204128566742369499;
|
||||
uint256 constant betax2 = 6479683735401057464856560780016689003394325158210495956800419236111697402941;
|
||||
uint256 constant betay1 = 20413115250143543082989954729570048513153861075230117372641105301032124129876;
|
||||
uint256 constant betay2 = 14397376998117601765034877247086905021783475930686205456376147632056422933833;
|
||||
uint256 constant gammax1 = 11559732032986387107991004021392285783925812861821192530917403151452391805634;
|
||||
uint256 constant gammax2 = 10857046999023057135944570762232829481370756359578518086990519993285655852781;
|
||||
uint256 constant gammay1 = 4082367875863433681332203403145435568316851327593401208105741076214120093531;
|
||||
uint256 constant gammay2 = 8495653923123431417604973247489272438418190587263600148770280649306958101930;
|
||||
uint256 constant deltax1 = 4187901564856243153173061219345467014727545819082218143172095490940414594424;
|
||||
uint256 constant deltax2 = 6840503012950456034406412069208230277997775373740741539262294411073505372202;
|
||||
uint256 constant deltay1 = 16312755549775593509550494456994863905270524213647477910622330564896885944010;
|
||||
uint256 constant deltay2 = 15354962623567401613422376703326876887451375834046173755940516337285040531401;
|
||||
|
||||
|
||||
uint256 constant IC0x = 7685121570366407724807946503921961619833683410392772870373459476604128011275;
|
||||
uint256 constant IC0y = 6915443837935167692630810275110398177336960270031115982900890650376967129575;
|
||||
|
||||
uint256 constant IC1x = 10363999014224824591638032348857401078402637116683579765969796919683926972060;
|
||||
uint256 constant IC1y = 5716124078230277423780595544607422628270452574948632939527677487979409581469;
|
||||
|
||||
|
||||
// Memory data
|
||||
uint16 constant pVk = 0;
|
||||
uint16 constant pPairing = 128;
|
||||
|
||||
uint16 constant pLastMem = 896;
|
||||
|
||||
function verifyProof(uint[2] calldata _pA, uint[2][2] calldata _pB, uint[2] calldata _pC, uint[1] calldata _pubSignals) public view returns (bool) {
|
||||
assembly {
|
||||
function checkField(v) {
|
||||
if iszero(lt(v, r)) {
|
||||
mstore(0, 0)
|
||||
return(0, 0x20)
|
||||
}
|
||||
}
|
||||
|
||||
// G1 function to multiply a G1 value(x,y) to value in an address
|
||||
function g1_mulAccC(pR, x, y, s) {
|
||||
let success
|
||||
let mIn := mload(0x40)
|
||||
mstore(mIn, x)
|
||||
mstore(add(mIn, 32), y)
|
||||
mstore(add(mIn, 64), s)
|
||||
|
||||
success := staticcall(sub(gas(), 2000), 7, mIn, 96, mIn, 64)
|
||||
|
||||
if iszero(success) {
|
||||
mstore(0, 0)
|
||||
return(0, 0x20)
|
||||
}
|
||||
|
||||
mstore(add(mIn, 64), mload(pR))
|
||||
mstore(add(mIn, 96), mload(add(pR, 32)))
|
||||
|
||||
success := staticcall(sub(gas(), 2000), 6, mIn, 128, pR, 64)
|
||||
|
||||
if iszero(success) {
|
||||
mstore(0, 0)
|
||||
return(0, 0x20)
|
||||
}
|
||||
}
|
||||
|
||||
function checkPairing(pA, pB, pC, pubSignals, pMem) -> isOk {
|
||||
let _pPairing := add(pMem, pPairing)
|
||||
let _pVk := add(pMem, pVk)
|
||||
|
||||
mstore(_pVk, IC0x)
|
||||
mstore(add(_pVk, 32), IC0y)
|
||||
|
||||
// Compute the linear combination vk_x
|
||||
|
||||
g1_mulAccC(_pVk, IC1x, IC1y, calldataload(add(pubSignals, 0)))
|
||||
|
||||
|
||||
// -A
|
||||
mstore(_pPairing, calldataload(pA))
|
||||
mstore(add(_pPairing, 32), mod(sub(q, calldataload(add(pA, 32))), q))
|
||||
|
||||
// B
|
||||
mstore(add(_pPairing, 64), calldataload(pB))
|
||||
mstore(add(_pPairing, 96), calldataload(add(pB, 32)))
|
||||
mstore(add(_pPairing, 128), calldataload(add(pB, 64)))
|
||||
mstore(add(_pPairing, 160), calldataload(add(pB, 96)))
|
||||
|
||||
// alpha1
|
||||
mstore(add(_pPairing, 192), alphax)
|
||||
mstore(add(_pPairing, 224), alphay)
|
||||
|
||||
// beta2
|
||||
mstore(add(_pPairing, 256), betax1)
|
||||
mstore(add(_pPairing, 288), betax2)
|
||||
mstore(add(_pPairing, 320), betay1)
|
||||
mstore(add(_pPairing, 352), betay2)
|
||||
|
||||
// vk_x
|
||||
mstore(add(_pPairing, 384), mload(add(pMem, pVk)))
|
||||
mstore(add(_pPairing, 416), mload(add(pMem, add(pVk, 32))))
|
||||
|
||||
|
||||
// gamma2
|
||||
mstore(add(_pPairing, 448), gammax1)
|
||||
mstore(add(_pPairing, 480), gammax2)
|
||||
mstore(add(_pPairing, 512), gammay1)
|
||||
mstore(add(_pPairing, 544), gammay2)
|
||||
|
||||
// C
|
||||
mstore(add(_pPairing, 576), calldataload(pC))
|
||||
mstore(add(_pPairing, 608), calldataload(add(pC, 32)))
|
||||
|
||||
// delta2
|
||||
mstore(add(_pPairing, 640), deltax1)
|
||||
mstore(add(_pPairing, 672), deltax2)
|
||||
mstore(add(_pPairing, 704), deltay1)
|
||||
mstore(add(_pPairing, 736), deltay2)
|
||||
|
||||
|
||||
let success := staticcall(sub(gas(), 2000), 8, _pPairing, 768, _pPairing, 0x20)
|
||||
|
||||
isOk := and(success, mload(_pPairing))
|
||||
}
|
||||
|
||||
let pMem := mload(0x40)
|
||||
mstore(0x40, add(pMem, pLastMem))
|
||||
|
||||
// Validate that all evaluations ∈ F
|
||||
|
||||
checkField(calldataload(add(_pubSignals, 0)))
|
||||
|
||||
|
||||
// Validate all evaluations
|
||||
let isValid := checkPairing(_pA, _pB, _pC, _pubSignals, pMem)
|
||||
|
||||
mstore(0, isValid)
|
||||
return(0, 0x20)
|
||||
}
|
||||
}
|
||||
}
|
||||
9
apps/zk-circuits/test_legacy.circom
Normal file
9
apps/zk-circuits/test_legacy.circom
Normal file
@@ -0,0 +1,9 @@
|
||||
pragma circom 0.5.46;
|
||||
|
||||
template Test() {
|
||||
signal input in;
|
||||
signal output out;
|
||||
out <== in;
|
||||
}
|
||||
|
||||
component main = Test();
|
||||
10
apps/zk-circuits/test_legacy2.circom
Normal file
10
apps/zk-circuits/test_legacy2.circom
Normal file
@@ -0,0 +1,10 @@
|
||||
pragma circom 0.5.46;
|
||||
|
||||
|
||||
template Test() {
|
||||
signal input in;
|
||||
signal output out;
|
||||
out <== in;
|
||||
}
|
||||
|
||||
component main = Test();
|
||||
0
apps/zk-circuits/wtns.wtns
Normal file
0
apps/zk-circuits/wtns.wtns
Normal file
0
apps/zk-circuits/wtns_simple.wtns
Normal file
0
apps/zk-circuits/wtns_simple.wtns
Normal file
0
apps/zk-circuits/wtns_valid.wtns
Normal file
0
apps/zk-circuits/wtns_valid.wtns
Normal file
Reference in New Issue
Block a user