refactor(contracts): remove deprecated AIPowerRental contract in favor of bounty system
- Delete AIPowerRental.sol (566 lines) - replaced by AgentBounty.sol - Remove rental agreement system with provider/consumer model - Remove performance metrics and SLA tracking - Remove dispute resolution mechanism - Remove ZK-proof verification for performance - Remove provider/consumer authorization system - Bounty system provides superior developer incentive structure
This commit is contained in:
993
apps/coordinator-api/src/app/services/advanced_learning.py
Normal file
993
apps/coordinator-api/src/app/services/advanced_learning.py
Normal file
@@ -0,0 +1,993 @@
|
||||
"""
|
||||
Advanced Learning Service for AI-Powered Agent Features
|
||||
Implements meta-learning, federated learning, and continuous model improvement
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, List, Any, Optional, Tuple, Union
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
import json
|
||||
import numpy as np
|
||||
from dataclasses import dataclass, asdict, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LearningType(str, Enum):
|
||||
"""Types of learning approaches"""
|
||||
META_LEARNING = "meta_learning"
|
||||
FEDERATED = "federated"
|
||||
REINFORCEMENT = "reinforcement"
|
||||
SUPERVISED = "supervised"
|
||||
UNSUPERVISED = "unsupervised"
|
||||
TRANSFER = "transfer"
|
||||
CONTINUAL = "continual"
|
||||
|
||||
|
||||
class ModelType(str, Enum):
|
||||
"""Types of AI models"""
|
||||
TASK_PLANNING = "task_planning"
|
||||
BIDDING_STRATEGY = "bidding_strategy"
|
||||
RESOURCE_ALLOCATION = "resource_allocation"
|
||||
COMMUNICATION = "communication"
|
||||
COLLABORATION = "collaboration"
|
||||
DECISION_MAKING = "decision_making"
|
||||
PREDICTION = "prediction"
|
||||
CLASSIFICATION = "classification"
|
||||
|
||||
|
||||
class LearningStatus(str, Enum):
|
||||
"""Learning process status"""
|
||||
INITIALIZING = "initializing"
|
||||
TRAINING = "training"
|
||||
VALIDATING = "validating"
|
||||
DEPLOYING = "deploying"
|
||||
ACTIVE = "active"
|
||||
PAUSED = "paused"
|
||||
FAILED = "failed"
|
||||
COMPLETED = "completed"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LearningModel:
|
||||
"""AI learning model information"""
|
||||
id: str
|
||||
agent_id: str
|
||||
model_type: ModelType
|
||||
learning_type: LearningType
|
||||
version: str
|
||||
parameters: Dict[str, Any]
|
||||
performance_metrics: Dict[str, float]
|
||||
training_data_size: int
|
||||
validation_data_size: int
|
||||
created_at: datetime
|
||||
last_updated: datetime
|
||||
status: LearningStatus
|
||||
accuracy: float = 0.0
|
||||
precision: float = 0.0
|
||||
recall: float = 0.0
|
||||
f1_score: float = 0.0
|
||||
loss: float = 0.0
|
||||
training_time: float = 0.0
|
||||
inference_time: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class LearningSession:
|
||||
"""Learning session information"""
|
||||
id: str
|
||||
model_id: str
|
||||
agent_id: str
|
||||
learning_type: LearningType
|
||||
start_time: datetime
|
||||
end_time: Optional[datetime]
|
||||
status: LearningStatus
|
||||
training_data: List[Dict[str, Any]]
|
||||
validation_data: List[Dict[str, Any]]
|
||||
hyperparameters: Dict[str, Any]
|
||||
results: Dict[str, float]
|
||||
iterations: int
|
||||
convergence_threshold: float
|
||||
early_stopping: bool
|
||||
checkpoint_frequency: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class FederatedNode:
|
||||
"""Federated learning node information"""
|
||||
id: str
|
||||
agent_id: str
|
||||
endpoint: str
|
||||
data_size: int
|
||||
model_version: str
|
||||
last_sync: datetime
|
||||
contribution_weight: float
|
||||
bandwidth_limit: int
|
||||
computation_limit: int
|
||||
is_active: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetaLearningTask:
|
||||
"""Meta-learning task definition"""
|
||||
id: str
|
||||
task_type: str
|
||||
input_features: List[str]
|
||||
output_features: List[str]
|
||||
support_set_size: int
|
||||
query_set_size: int
|
||||
adaptation_steps: int
|
||||
inner_lr: float
|
||||
outer_lr: float
|
||||
meta_iterations: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class LearningAnalytics:
|
||||
"""Learning analytics data"""
|
||||
agent_id: str
|
||||
model_id: str
|
||||
total_training_time: float
|
||||
total_inference_time: float
|
||||
accuracy_improvement: float
|
||||
performance_gain: float
|
||||
data_efficiency: float
|
||||
computation_efficiency: float
|
||||
learning_rate: float
|
||||
convergence_speed: float
|
||||
last_evaluation: datetime
|
||||
|
||||
|
||||
class AdvancedLearningService:
|
||||
"""Service for advanced AI learning capabilities"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
self.config = config
|
||||
self.models: Dict[str, LearningModel] = {}
|
||||
self.learning_sessions: Dict[str, LearningSession] = {}
|
||||
self.federated_nodes: Dict[str, FederatedNode] = {}
|
||||
self.meta_learning_tasks: Dict[str, MetaLearningTask] = {}
|
||||
self.learning_analytics: Dict[str, LearningAnalytics] = {}
|
||||
|
||||
# Configuration
|
||||
self.max_model_size = 100 * 1024 * 1024 # 100MB
|
||||
self.max_training_time = 3600 # 1 hour
|
||||
self.default_batch_size = 32
|
||||
self.default_learning_rate = 0.001
|
||||
self.convergence_threshold = 0.001
|
||||
self.early_stopping_patience = 10
|
||||
|
||||
# Learning algorithms
|
||||
self.meta_learning_algorithms = ["MAML", "Reptile", "Meta-SGD"]
|
||||
self.federated_algorithms = ["FedAvg", "FedProx", "FedNova"]
|
||||
self.reinforcement_algorithms = ["DQN", "PPO", "A3C", "SAC"]
|
||||
|
||||
# Model registry
|
||||
self.model_templates: Dict[ModelType, Dict[str, Any]] = {
|
||||
ModelType.TASK_PLANNING: {
|
||||
"architecture": "transformer",
|
||||
"layers": 6,
|
||||
"hidden_size": 512,
|
||||
"attention_heads": 8
|
||||
},
|
||||
ModelType.BIDDING_STRATEGY: {
|
||||
"architecture": "lstm",
|
||||
"layers": 3,
|
||||
"hidden_size": 256,
|
||||
"dropout": 0.2
|
||||
},
|
||||
ModelType.RESOURCE_ALLOCATION: {
|
||||
"architecture": "cnn",
|
||||
"layers": 4,
|
||||
"filters": 64,
|
||||
"kernel_size": 3
|
||||
},
|
||||
ModelType.COMMUNICATION: {
|
||||
"architecture": "rnn",
|
||||
"layers": 2,
|
||||
"hidden_size": 128,
|
||||
"bidirectional": True
|
||||
},
|
||||
ModelType.COLLABORATION: {
|
||||
"architecture": "gnn",
|
||||
"layers": 3,
|
||||
"hidden_size": 256,
|
||||
"aggregation": "mean"
|
||||
},
|
||||
ModelType.DECISION_MAKING: {
|
||||
"architecture": "mlp",
|
||||
"layers": 4,
|
||||
"hidden_size": 512,
|
||||
"activation": "relu"
|
||||
},
|
||||
ModelType.PREDICTION: {
|
||||
"architecture": "transformer",
|
||||
"layers": 8,
|
||||
"hidden_size": 768,
|
||||
"attention_heads": 12
|
||||
},
|
||||
ModelType.CLASSIFICATION: {
|
||||
"architecture": "cnn",
|
||||
"layers": 5,
|
||||
"filters": 128,
|
||||
"kernel_size": 3
|
||||
}
|
||||
}
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the advanced learning service"""
|
||||
logger.info("Initializing Advanced Learning Service")
|
||||
|
||||
# Load existing models and sessions
|
||||
await self._load_learning_data()
|
||||
|
||||
# Start background tasks
|
||||
asyncio.create_task(self._monitor_learning_sessions())
|
||||
asyncio.create_task(self._process_federated_learning())
|
||||
asyncio.create_task(self._optimize_model_performance())
|
||||
asyncio.create_task(self._cleanup_inactive_sessions())
|
||||
|
||||
logger.info("Advanced Learning Service initialized")
|
||||
|
||||
async def create_model(
|
||||
self,
|
||||
agent_id: str,
|
||||
model_type: ModelType,
|
||||
learning_type: LearningType,
|
||||
hyperparameters: Optional[Dict[str, Any]] = None
|
||||
) -> LearningModel:
|
||||
"""Create a new learning model"""
|
||||
|
||||
try:
|
||||
# Generate model ID
|
||||
model_id = await self._generate_model_id()
|
||||
|
||||
# Get model template
|
||||
template = self.model_templates.get(model_type, {})
|
||||
|
||||
# Merge with hyperparameters
|
||||
parameters = {**template, **(hyperparameters or {})}
|
||||
|
||||
# Create model
|
||||
model = LearningModel(
|
||||
id=model_id,
|
||||
agent_id=agent_id,
|
||||
model_type=model_type,
|
||||
learning_type=learning_type,
|
||||
version="1.0.0",
|
||||
parameters=parameters,
|
||||
performance_metrics={},
|
||||
training_data_size=0,
|
||||
validation_data_size=0,
|
||||
created_at=datetime.utcnow(),
|
||||
last_updated=datetime.utcnow(),
|
||||
status=LearningStatus.INITIALIZING
|
||||
)
|
||||
|
||||
# Store model
|
||||
self.models[model_id] = model
|
||||
|
||||
# Initialize analytics
|
||||
self.learning_analytics[model_id] = LearningAnalytics(
|
||||
agent_id=agent_id,
|
||||
model_id=model_id,
|
||||
total_training_time=0.0,
|
||||
total_inference_time=0.0,
|
||||
accuracy_improvement=0.0,
|
||||
performance_gain=0.0,
|
||||
data_efficiency=0.0,
|
||||
computation_efficiency=0.0,
|
||||
learning_rate=self.default_learning_rate,
|
||||
convergence_speed=0.0,
|
||||
last_evaluation=datetime.utcnow()
|
||||
)
|
||||
|
||||
logger.info(f"Model created: {model_id} for agent {agent_id}")
|
||||
return model
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create model: {e}")
|
||||
raise
|
||||
|
||||
async def start_learning_session(
|
||||
self,
|
||||
model_id: str,
|
||||
training_data: List[Dict[str, Any]],
|
||||
validation_data: List[Dict[str, Any]],
|
||||
hyperparameters: Optional[Dict[str, Any]] = None
|
||||
) -> LearningSession:
|
||||
"""Start a learning session"""
|
||||
|
||||
try:
|
||||
if model_id not in self.models:
|
||||
raise ValueError(f"Model {model_id} not found")
|
||||
|
||||
model = self.models[model_id]
|
||||
|
||||
# Generate session ID
|
||||
session_id = await self._generate_session_id()
|
||||
|
||||
# Default hyperparameters
|
||||
default_hyperparams = {
|
||||
"learning_rate": self.default_learning_rate,
|
||||
"batch_size": self.default_batch_size,
|
||||
"epochs": 100,
|
||||
"convergence_threshold": self.convergence_threshold,
|
||||
"early_stopping": True,
|
||||
"early_stopping_patience": self.early_stopping_patience
|
||||
}
|
||||
|
||||
# Merge hyperparameters
|
||||
final_hyperparams = {**default_hyperparams, **(hyperparameters or {})}
|
||||
|
||||
# Create session
|
||||
session = LearningSession(
|
||||
id=session_id,
|
||||
model_id=model_id,
|
||||
agent_id=model.agent_id,
|
||||
learning_type=model.learning_type,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=None,
|
||||
status=LearningStatus.INITIALIZING,
|
||||
training_data=training_data,
|
||||
validation_data=validation_data,
|
||||
hyperparameters=final_hyperparams,
|
||||
results={},
|
||||
iterations=0,
|
||||
convergence_threshold=final_hyperparams.get("convergence_threshold", self.convergence_threshold),
|
||||
early_stopping=final_hyperparams.get("early_stopping", True),
|
||||
checkpoint_frequency=10
|
||||
)
|
||||
|
||||
# Store session
|
||||
self.learning_sessions[session_id] = session
|
||||
|
||||
# Update model status
|
||||
model.status = LearningStatus.TRAINING
|
||||
model.last_updated = datetime.utcnow()
|
||||
|
||||
# Start training
|
||||
asyncio.create_task(self._execute_learning_session(session_id))
|
||||
|
||||
logger.info(f"Learning session started: {session_id}")
|
||||
return session
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start learning session: {e}")
|
||||
raise
|
||||
|
||||
async def execute_meta_learning(
|
||||
self,
|
||||
agent_id: str,
|
||||
tasks: List[MetaLearningTask],
|
||||
algorithm: str = "MAML"
|
||||
) -> str:
|
||||
"""Execute meta-learning for rapid adaptation"""
|
||||
|
||||
try:
|
||||
# Create meta-learning model
|
||||
model = await self.create_model(
|
||||
agent_id=agent_id,
|
||||
model_type=ModelType.TASK_PLANNING,
|
||||
learning_type=LearningType.META_LEARNING
|
||||
)
|
||||
|
||||
# Generate session ID
|
||||
session_id = await self._generate_session_id()
|
||||
|
||||
# Prepare meta-learning data
|
||||
meta_data = await self._prepare_meta_learning_data(tasks)
|
||||
|
||||
# Create session
|
||||
session = LearningSession(
|
||||
id=session_id,
|
||||
model_id=model.id,
|
||||
agent_id=agent_id,
|
||||
learning_type=LearningType.META_LEARNING,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=None,
|
||||
status=LearningStatus.TRAINING,
|
||||
training_data=meta_data["training"],
|
||||
validation_data=meta_data["validation"],
|
||||
hyperparameters={
|
||||
"algorithm": algorithm,
|
||||
"inner_lr": 0.01,
|
||||
"outer_lr": 0.001,
|
||||
"meta_iterations": 1000,
|
||||
"adaptation_steps": 5
|
||||
},
|
||||
results={},
|
||||
iterations=0,
|
||||
convergence_threshold=0.001,
|
||||
early_stopping=True,
|
||||
checkpoint_frequency=10
|
||||
)
|
||||
|
||||
self.learning_sessions[session_id] = session
|
||||
|
||||
# Execute meta-learning
|
||||
asyncio.create_task(self._execute_meta_learning(session_id, algorithm))
|
||||
|
||||
logger.info(f"Meta-learning started: {session_id}")
|
||||
return session_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to execute meta-learning: {e}")
|
||||
raise
|
||||
|
||||
async def setup_federated_learning(
|
||||
self,
|
||||
model_id: str,
|
||||
nodes: List[FederatedNode],
|
||||
algorithm: str = "FedAvg"
|
||||
) -> str:
|
||||
"""Setup federated learning across multiple agents"""
|
||||
|
||||
try:
|
||||
if model_id not in self.models:
|
||||
raise ValueError(f"Model {model_id} not found")
|
||||
|
||||
# Register nodes
|
||||
for node in nodes:
|
||||
self.federated_nodes[node.id] = node
|
||||
|
||||
# Generate session ID
|
||||
session_id = await self._generate_session_id()
|
||||
|
||||
# Create federated session
|
||||
session = LearningSession(
|
||||
id=session_id,
|
||||
model_id=model_id,
|
||||
agent_id="federated",
|
||||
learning_type=LearningType.FEDERATED,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=None,
|
||||
status=LearningStatus.TRAINING,
|
||||
training_data=[],
|
||||
validation_data=[],
|
||||
hyperparameters={
|
||||
"algorithm": algorithm,
|
||||
"aggregation_frequency": 10,
|
||||
"min_participants": 2,
|
||||
"max_participants": len(nodes),
|
||||
"communication_rounds": 100
|
||||
},
|
||||
results={},
|
||||
iterations=0,
|
||||
convergence_threshold=0.001,
|
||||
early_stopping=False,
|
||||
checkpoint_frequency=5
|
||||
)
|
||||
|
||||
self.learning_sessions[session_id] = session
|
||||
|
||||
# Start federated learning
|
||||
asyncio.create_task(self._execute_federated_learning(session_id, algorithm))
|
||||
|
||||
logger.info(f"Federated learning setup: {session_id}")
|
||||
return session_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to setup federated learning: {e}")
|
||||
raise
|
||||
|
||||
async def predict_with_model(
|
||||
self,
|
||||
model_id: str,
|
||||
input_data: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Make prediction using trained model"""
|
||||
|
||||
try:
|
||||
if model_id not in self.models:
|
||||
raise ValueError(f"Model {model_id} not found")
|
||||
|
||||
model = self.models[model_id]
|
||||
|
||||
if model.status != LearningStatus.ACTIVE:
|
||||
raise ValueError(f"Model {model_id} not active")
|
||||
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
# Simulate inference
|
||||
prediction = await self._simulate_inference(model, input_data)
|
||||
|
||||
# Update analytics
|
||||
inference_time = (datetime.utcnow() - start_time).total_seconds()
|
||||
analytics = self.learning_analytics[model_id]
|
||||
analytics.total_inference_time += inference_time
|
||||
analytics.last_evaluation = datetime.utcnow()
|
||||
|
||||
logger.info(f"Prediction made with model {model_id}")
|
||||
return prediction
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to predict with model {model_id}: {e}")
|
||||
raise
|
||||
|
||||
async def adapt_model(
|
||||
self,
|
||||
model_id: str,
|
||||
adaptation_data: List[Dict[str, Any]],
|
||||
adaptation_steps: int = 5
|
||||
) -> Dict[str, float]:
|
||||
"""Adapt model to new data"""
|
||||
|
||||
try:
|
||||
if model_id not in self.models:
|
||||
raise ValueError(f"Model {model_id} not found")
|
||||
|
||||
model = self.models[model_id]
|
||||
|
||||
if model.learning_type not in [LearningType.META_LEARNING, LearningType.CONTINUAL]:
|
||||
raise ValueError(f"Model {model_id} does not support adaptation")
|
||||
|
||||
# Simulate model adaptation
|
||||
adaptation_results = await self._simulate_model_adaptation(
|
||||
model, adaptation_data, adaptation_steps
|
||||
)
|
||||
|
||||
# Update model performance
|
||||
model.accuracy = adaptation_results.get("accuracy", model.accuracy)
|
||||
model.last_updated = datetime.utcnow()
|
||||
|
||||
# Update analytics
|
||||
analytics = self.learning_analytics[model_id]
|
||||
analytics.accuracy_improvement = adaptation_results.get("improvement", 0.0)
|
||||
analytics.data_efficiency = adaptation_results.get("data_efficiency", 0.0)
|
||||
|
||||
logger.info(f"Model adapted: {model_id}")
|
||||
return adaptation_results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to adapt model {model_id}: {e}")
|
||||
raise
|
||||
|
||||
async def get_model_performance(self, model_id: str) -> Dict[str, Any]:
|
||||
"""Get comprehensive model performance metrics"""
|
||||
|
||||
try:
|
||||
if model_id not in self.models:
|
||||
raise ValueError(f"Model {model_id} not found")
|
||||
|
||||
model = self.models[model_id]
|
||||
analytics = self.learning_analytics[model_id]
|
||||
|
||||
# Calculate performance metrics
|
||||
performance = {
|
||||
"model_id": model_id,
|
||||
"model_type": model.model_type.value,
|
||||
"learning_type": model.learning_type.value,
|
||||
"status": model.status.value,
|
||||
"accuracy": model.accuracy,
|
||||
"precision": model.precision,
|
||||
"recall": model.recall,
|
||||
"f1_score": model.f1_score,
|
||||
"loss": model.loss,
|
||||
"training_time": model.training_time,
|
||||
"inference_time": model.inference_time,
|
||||
"total_training_time": analytics.total_training_time,
|
||||
"total_inference_time": analytics.total_inference_time,
|
||||
"accuracy_improvement": analytics.accuracy_improvement,
|
||||
"performance_gain": analytics.performance_gain,
|
||||
"data_efficiency": analytics.data_efficiency,
|
||||
"computation_efficiency": analytics.computation_efficiency,
|
||||
"learning_rate": analytics.learning_rate,
|
||||
"convergence_speed": analytics.convergence_speed,
|
||||
"last_updated": model.last_updated,
|
||||
"last_evaluation": analytics.last_evaluation
|
||||
}
|
||||
|
||||
return performance
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get model performance: {e}")
|
||||
raise
|
||||
|
||||
async def get_learning_analytics(self, agent_id: str) -> List[LearningAnalytics]:
|
||||
"""Get learning analytics for an agent"""
|
||||
|
||||
analytics = []
|
||||
for model_id, model_analytics in self.learning_analytics.items():
|
||||
if model_analytics.agent_id == agent_id:
|
||||
analytics.append(model_analytics)
|
||||
|
||||
return analytics
|
||||
|
||||
async def get_top_models(
|
||||
self,
|
||||
model_type: Optional[ModelType] = None,
|
||||
limit: int = 100
|
||||
) -> List[LearningModel]:
|
||||
"""Get top performing models"""
|
||||
|
||||
models = list(self.models.values())
|
||||
|
||||
if model_type:
|
||||
models = [m for m in models if m.model_type == model_type]
|
||||
|
||||
# Sort by accuracy
|
||||
models.sort(key=lambda x: x.accuracy, reverse=True)
|
||||
|
||||
return models[:limit]
|
||||
|
||||
async def optimize_model(self, model_id: str) -> bool:
|
||||
"""Optimize model performance"""
|
||||
|
||||
try:
|
||||
if model_id not in self.models:
|
||||
raise ValueError(f"Model {model_id} not found")
|
||||
|
||||
model = self.models[model_id]
|
||||
|
||||
# Simulate optimization
|
||||
optimization_results = await self._simulate_model_optimization(model)
|
||||
|
||||
# Update model
|
||||
model.accuracy = optimization_results.get("accuracy", model.accuracy)
|
||||
model.inference_time = optimization_results.get("inference_time", model.inference_time)
|
||||
model.last_updated = datetime.utcnow()
|
||||
|
||||
logger.info(f"Model optimized: {model_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to optimize model {model_id}: {e}")
|
||||
return False
|
||||
|
||||
async def _execute_learning_session(self, session_id: str):
|
||||
"""Execute a learning session"""
|
||||
|
||||
try:
|
||||
session = self.learning_sessions[session_id]
|
||||
model = self.models[session.model_id]
|
||||
|
||||
session.status = LearningStatus.TRAINING
|
||||
|
||||
# Simulate training
|
||||
for iteration in range(session.hyperparameters.get("epochs", 100)):
|
||||
if session.status != LearningStatus.TRAINING:
|
||||
break
|
||||
|
||||
# Simulate training step
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Update metrics
|
||||
session.iterations = iteration
|
||||
|
||||
# Check convergence
|
||||
if iteration > 0 and iteration % 10 == 0:
|
||||
loss = np.random.uniform(0.1, 1.0) * (1.0 - iteration / 100)
|
||||
session.results[f"epoch_{iteration}"] = {"loss": loss}
|
||||
|
||||
if loss < session.convergence_threshold:
|
||||
session.status = LearningStatus.COMPLETED
|
||||
break
|
||||
|
||||
# Early stopping
|
||||
if session.early_stopping and iteration > session.early_stopping_patience:
|
||||
if loss > session.results.get(f"epoch_{iteration - session.early_stopping_patience}", {}).get("loss", 1.0):
|
||||
session.status = LearningStatus.COMPLETED
|
||||
break
|
||||
|
||||
# Update model
|
||||
model.accuracy = np.random.uniform(0.7, 0.95)
|
||||
model.precision = np.random.uniform(0.7, 0.95)
|
||||
model.recall = np.random.uniform(0.7, 0.95)
|
||||
model.f1_score = np.random.uniform(0.7, 0.95)
|
||||
model.loss = session.results.get(f"epoch_{session.iterations}", {}).get("loss", 0.1)
|
||||
model.training_time = (datetime.utcnow() - session.start_time).total_seconds()
|
||||
model.inference_time = np.random.uniform(0.01, 0.1)
|
||||
model.status = LearningStatus.ACTIVE
|
||||
model.last_updated = datetime.utcnow()
|
||||
|
||||
session.end_time = datetime.utcnow()
|
||||
session.status = LearningStatus.COMPLETED
|
||||
|
||||
# Update analytics
|
||||
analytics = self.learning_analytics[session.model_id]
|
||||
analytics.total_training_time += model.training_time
|
||||
analytics.convergence_speed = session.iterations / model.training_time
|
||||
|
||||
logger.info(f"Learning session completed: {session_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to execute learning session {session_id}: {e}")
|
||||
session.status = LearningStatus.FAILED
|
||||
|
||||
async def _execute_meta_learning(self, session_id: str, algorithm: str):
|
||||
"""Execute meta-learning"""
|
||||
|
||||
try:
|
||||
session = self.learning_sessions[session_id]
|
||||
model = self.models[session.model_id]
|
||||
|
||||
session.status = LearningStatus.TRAINING
|
||||
|
||||
# Simulate meta-learning
|
||||
for iteration in range(session.hyperparameters.get("meta_iterations", 1000)):
|
||||
if session.status != LearningStatus.TRAINING:
|
||||
break
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Simulate meta-learning step
|
||||
session.iterations = iteration
|
||||
|
||||
if iteration % 100 == 0:
|
||||
loss = np.random.uniform(0.1, 1.0) * (1.0 - iteration / 1000)
|
||||
session.results[f"meta_iter_{iteration}"] = {"loss": loss}
|
||||
|
||||
if loss < session.convergence_threshold:
|
||||
break
|
||||
|
||||
# Update model with meta-learning results
|
||||
model.accuracy = np.random.uniform(0.8, 0.98)
|
||||
model.status = LearningStatus.ACTIVE
|
||||
model.last_updated = datetime.utcnow()
|
||||
|
||||
session.end_time = datetime.utcnow()
|
||||
session.status = LearningStatus.COMPLETED
|
||||
|
||||
logger.info(f"Meta-learning completed: {session_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to execute meta-learning {session_id}: {e}")
|
||||
session.status = LearningStatus.FAILED
|
||||
|
||||
async def _execute_federated_learning(self, session_id: str, algorithm: str):
|
||||
"""Execute federated learning"""
|
||||
|
||||
try:
|
||||
session = self.learning_sessions[session_id]
|
||||
model = self.models[session.model_id]
|
||||
|
||||
session.status = LearningStatus.TRAINING
|
||||
|
||||
# Simulate federated learning rounds
|
||||
for round_num in range(session.hyperparameters.get("communication_rounds", 100)):
|
||||
if session.status != LearningStatus.TRAINING:
|
||||
break
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Simulate federated round
|
||||
session.iterations = round_num
|
||||
|
||||
if round_num % 10 == 0:
|
||||
loss = np.random.uniform(0.1, 1.0) * (1.0 - round_num / 100)
|
||||
session.results[f"round_{round_num}"] = {"loss": loss}
|
||||
|
||||
if loss < session.convergence_threshold:
|
||||
break
|
||||
|
||||
# Update model
|
||||
model.accuracy = np.random.uniform(0.75, 0.92)
|
||||
model.status = LearningStatus.ACTIVE
|
||||
model.last_updated = datetime.utcnow()
|
||||
|
||||
session.end_time = datetime.utcnow()
|
||||
session.status = LearningStatus.COMPLETED
|
||||
|
||||
logger.info(f"Federated learning completed: {session_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to execute federated learning {session_id}: {e}")
|
||||
session.status = LearningStatus.FAILED
|
||||
|
||||
async def _simulate_inference(self, model: LearningModel, input_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Simulate model inference"""
|
||||
|
||||
# Simulate prediction based on model type
|
||||
if model.model_type == ModelType.TASK_PLANNING:
|
||||
return {
|
||||
"prediction": "task_plan",
|
||||
"confidence": np.random.uniform(0.7, 0.95),
|
||||
"execution_time": np.random.uniform(0.1, 1.0),
|
||||
"resource_requirements": {
|
||||
"gpu_hours": np.random.uniform(0.5, 2.0),
|
||||
"memory_gb": np.random.uniform(2, 8)
|
||||
}
|
||||
}
|
||||
elif model.model_type == ModelType.BIDDING_STRATEGY:
|
||||
return {
|
||||
"bid_price": np.random.uniform(0.01, 0.1),
|
||||
"success_probability": np.random.uniform(0.6, 0.9),
|
||||
"wait_time": np.random.uniform(60, 300)
|
||||
}
|
||||
elif model.model_type == ModelType.RESOURCE_ALLOCATION:
|
||||
return {
|
||||
"allocation": "optimal",
|
||||
"efficiency": np.random.uniform(0.8, 0.95),
|
||||
"cost_savings": np.random.uniform(0.1, 0.3)
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"prediction": "default",
|
||||
"confidence": np.random.uniform(0.7, 0.95)
|
||||
}
|
||||
|
||||
async def _simulate_model_adaptation(
|
||||
self,
|
||||
model: LearningModel,
|
||||
adaptation_data: List[Dict[str, Any]],
|
||||
adaptation_steps: int
|
||||
) -> Dict[str, float]:
|
||||
"""Simulate model adaptation"""
|
||||
|
||||
# Simulate adaptation process
|
||||
initial_accuracy = model.accuracy
|
||||
final_accuracy = min(0.99, initial_accuracy + np.random.uniform(0.01, 0.1))
|
||||
|
||||
return {
|
||||
"accuracy": final_accuracy,
|
||||
"improvement": final_accuracy - initial_accuracy,
|
||||
"data_efficiency": np.random.uniform(0.8, 0.95),
|
||||
"adaptation_time": np.random.uniform(1.0, 10.0)
|
||||
}
|
||||
|
||||
async def _simulate_model_optimization(self, model: LearningModel) -> Dict[str, float]:
|
||||
"""Simulate model optimization"""
|
||||
|
||||
return {
|
||||
"accuracy": min(0.99, model.accuracy + np.random.uniform(0.01, 0.05)),
|
||||
"inference_time": model.inference_time * np.random.uniform(0.8, 0.95),
|
||||
"memory_usage": np.random.uniform(0.5, 2.0)
|
||||
}
|
||||
|
||||
async def _prepare_meta_learning_data(self, tasks: List[MetaLearningTask]) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""Prepare meta-learning data"""
|
||||
|
||||
# Simulate data preparation
|
||||
training_data = []
|
||||
validation_data = []
|
||||
|
||||
for task in tasks:
|
||||
# Generate synthetic data for each task
|
||||
for i in range(task.support_set_size):
|
||||
training_data.append({
|
||||
"task_id": task.id,
|
||||
"input": np.random.randn(10).tolist(),
|
||||
"output": np.random.randn(5).tolist(),
|
||||
"is_support": True
|
||||
})
|
||||
|
||||
for i in range(task.query_set_size):
|
||||
validation_data.append({
|
||||
"task_id": task.id,
|
||||
"input": np.random.randn(10).tolist(),
|
||||
"output": np.random.randn(5).tolist(),
|
||||
"is_support": False
|
||||
})
|
||||
|
||||
return {"training": training_data, "validation": validation_data}
|
||||
|
||||
async def _monitor_learning_sessions(self):
|
||||
"""Monitor active learning sessions"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
current_time = datetime.utcnow()
|
||||
|
||||
for session_id, session in self.learning_sessions.items():
|
||||
if session.status == LearningStatus.TRAINING:
|
||||
# Check timeout
|
||||
if (current_time - session.start_time).total_seconds() > self.max_training_time:
|
||||
session.status = LearningStatus.FAILED
|
||||
session.end_time = current_time
|
||||
logger.warning(f"Learning session {session_id} timed out")
|
||||
|
||||
await asyncio.sleep(60) # Check every minute
|
||||
except Exception as e:
|
||||
logger.error(f"Error monitoring learning sessions: {e}")
|
||||
await asyncio.sleep(60)
|
||||
|
||||
async def _process_federated_learning(self):
|
||||
"""Process federated learning aggregation"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Process federated learning rounds
|
||||
for session_id, session in self.learning_sessions.items():
|
||||
if session.learning_type == LearningType.FEDERATED and session.status == LearningStatus.TRAINING:
|
||||
# Simulate federated aggregation
|
||||
await asyncio.sleep(1)
|
||||
|
||||
await asyncio.sleep(30) # Check every 30 seconds
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing federated learning: {e}")
|
||||
await asyncio.sleep(30)
|
||||
|
||||
async def _optimize_model_performance(self):
|
||||
"""Optimize model performance periodically"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Optimize active models
|
||||
for model_id, model in self.models.items():
|
||||
if model.status == LearningStatus.ACTIVE:
|
||||
await self.optimize_model(model_id)
|
||||
|
||||
await asyncio.sleep(3600) # Optimize every hour
|
||||
except Exception as e:
|
||||
logger.error(f"Error optimizing models: {e}")
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
async def _cleanup_inactive_sessions(self):
|
||||
"""Clean up inactive learning sessions"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
current_time = datetime.utcnow()
|
||||
inactive_sessions = []
|
||||
|
||||
for session_id, session in self.learning_sessions.items():
|
||||
if session.status in [LearningStatus.COMPLETED, LearningStatus.FAILED]:
|
||||
if session.end_time and (current_time - session.end_time).total_seconds() > 86400: # 24 hours
|
||||
inactive_sessions.append(session_id)
|
||||
|
||||
for session_id in inactive_sessions:
|
||||
del self.learning_sessions[session_id]
|
||||
|
||||
if inactive_sessions:
|
||||
logger.info(f"Cleaned up {len(inactive_sessions)} inactive sessions")
|
||||
|
||||
await asyncio.sleep(3600) # Check every hour
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up sessions: {e}")
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
async def _generate_model_id(self) -> str:
|
||||
"""Generate unique model ID"""
|
||||
import uuid
|
||||
return str(uuid.uuid4())
|
||||
|
||||
async def _generate_session_id(self) -> str:
|
||||
"""Generate unique session ID"""
|
||||
import uuid
|
||||
return str(uuid.uuid4())
|
||||
|
||||
async def _load_learning_data(self):
|
||||
"""Load existing learning data"""
|
||||
# In production, load from database
|
||||
pass
|
||||
|
||||
async def export_learning_data(self, format: str = "json") -> str:
|
||||
"""Export learning data"""
|
||||
|
||||
data = {
|
||||
"models": {k: asdict(v) for k, v in self.models.items()},
|
||||
"sessions": {k: asdict(v) for k, v in self.learning_sessions.items()},
|
||||
"analytics": {k: asdict(v) for k, v in self.learning_analytics.items()},
|
||||
"export_timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
if format.lower() == "json":
|
||||
return json.dumps(data, indent=2, default=str)
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: {format}")
|
||||
|
||||
async def import_learning_data(self, data: str, format: str = "json"):
|
||||
"""Import learning data"""
|
||||
|
||||
if format.lower() == "json":
|
||||
parsed_data = json.loads(data)
|
||||
|
||||
# Import models
|
||||
for model_id, model_data in parsed_data.get("models", {}).items():
|
||||
model_data['created_at'] = datetime.fromisoformat(model_data['created_at'])
|
||||
model_data['last_updated'] = datetime.fromisoformat(model_data['last_updated'])
|
||||
self.models[model_id] = LearningModel(**model_data)
|
||||
|
||||
# Import sessions
|
||||
for session_id, session_data in parsed_data.get("sessions", {}).items():
|
||||
session_data['start_time'] = datetime.fromisoformat(session_data['start_time'])
|
||||
if session_data.get('end_time'):
|
||||
session_data['end_time'] = datetime.fromisoformat(session_data['end_time'])
|
||||
self.learning_sessions[session_id] = LearningSession(**session_data)
|
||||
|
||||
logger.info("Learning data imported successfully")
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: {format}")
|
||||
983
apps/coordinator-api/src/app/services/agent_communication.py
Normal file
983
apps/coordinator-api/src/app/services/agent_communication.py
Normal file
@@ -0,0 +1,983 @@
|
||||
"""
|
||||
Agent Communication Service for Advanced Agent Features
|
||||
Implements secure agent-to-agent messaging with reputation-based access control
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
import json
|
||||
import hashlib
|
||||
import base64
|
||||
from dataclasses import dataclass, asdict, field
|
||||
|
||||
from .cross_chain_reputation import CrossChainReputationService, ReputationTier
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageType(str, Enum):
|
||||
"""Types of agent messages"""
|
||||
TEXT = "text"
|
||||
DATA = "data"
|
||||
TASK_REQUEST = "task_request"
|
||||
TASK_RESPONSE = "task_response"
|
||||
COLLABORATION = "collaboration"
|
||||
NOTIFICATION = "notification"
|
||||
SYSTEM = "system"
|
||||
URGENT = "urgent"
|
||||
BULK = "bulk"
|
||||
|
||||
|
||||
class ChannelType(str, Enum):
|
||||
"""Types of communication channels"""
|
||||
DIRECT = "direct"
|
||||
GROUP = "group"
|
||||
BROADCAST = "broadcast"
|
||||
PRIVATE = "private"
|
||||
|
||||
|
||||
class MessageStatus(str, Enum):
|
||||
"""Message delivery status"""
|
||||
PENDING = "pending"
|
||||
DELIVERED = "delivered"
|
||||
READ = "read"
|
||||
FAILED = "failed"
|
||||
EXPIRED = "expired"
|
||||
|
||||
|
||||
class EncryptionType(str, Enum):
|
||||
"""Encryption types for messages"""
|
||||
AES256 = "aes256"
|
||||
RSA = "rsa"
|
||||
HYBRID = "hybrid"
|
||||
NONE = "none"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
"""Agent message data"""
|
||||
id: str
|
||||
sender: str
|
||||
recipient: str
|
||||
message_type: MessageType
|
||||
content: bytes
|
||||
encryption_key: bytes
|
||||
encryption_type: EncryptionType
|
||||
size: int
|
||||
timestamp: datetime
|
||||
delivery_timestamp: Optional[datetime] = None
|
||||
read_timestamp: Optional[datetime] = None
|
||||
status: MessageStatus = MessageStatus.PENDING
|
||||
paid: bool = False
|
||||
price: float = 0.0
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
expires_at: Optional[datetime] = None
|
||||
reply_to: Optional[str] = None
|
||||
thread_id: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommunicationChannel:
|
||||
"""Communication channel between agents"""
|
||||
id: str
|
||||
agent1: str
|
||||
agent2: str
|
||||
channel_type: ChannelType
|
||||
is_active: bool
|
||||
created_timestamp: datetime
|
||||
last_activity: datetime
|
||||
message_count: int
|
||||
participants: List[str] = field(default_factory=list)
|
||||
encryption_enabled: bool = True
|
||||
auto_delete: bool = False
|
||||
retention_period: int = 2592000 # 30 days
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageTemplate:
|
||||
"""Message template for common communications"""
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
message_type: MessageType
|
||||
content_template: str
|
||||
variables: List[str]
|
||||
base_price: float
|
||||
is_active: bool
|
||||
creator: str
|
||||
usage_count: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommunicationStats:
|
||||
"""Communication statistics for agent"""
|
||||
total_messages: int
|
||||
total_earnings: float
|
||||
messages_sent: int
|
||||
messages_received: int
|
||||
active_channels: int
|
||||
last_activity: datetime
|
||||
average_response_time: float
|
||||
delivery_rate: float
|
||||
|
||||
|
||||
class AgentCommunicationService:
|
||||
"""Service for managing agent-to-agent communication"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
self.config = config
|
||||
self.messages: Dict[str, Message] = {}
|
||||
self.channels: Dict[str, CommunicationChannel] = {}
|
||||
self.message_templates: Dict[str, MessageTemplate] = {}
|
||||
self.agent_messages: Dict[str, List[str]] = {}
|
||||
self.agent_channels: Dict[str, List[str]] = {}
|
||||
self.communication_stats: Dict[str, CommunicationStats] = {}
|
||||
|
||||
# Services
|
||||
self.reputation_service: Optional[CrossChainReputationService] = None
|
||||
|
||||
# Configuration
|
||||
self.min_reputation_score = 1000
|
||||
self.base_message_price = 0.001 # AITBC
|
||||
self.max_message_size = 100000 # 100KB
|
||||
self.message_timeout = 86400 # 24 hours
|
||||
self.channel_timeout = 2592000 # 30 days
|
||||
self.encryption_enabled = True
|
||||
|
||||
# Access control
|
||||
self.authorized_agents: Dict[str, bool] = {}
|
||||
self.contact_lists: Dict[str, Dict[str, bool]] = {}
|
||||
self.blocked_lists: Dict[str, Dict[str, bool]] = {}
|
||||
|
||||
# Message routing
|
||||
self.message_queue: List[Message] = []
|
||||
self.delivery_attempts: Dict[str, int] = {}
|
||||
|
||||
# Templates
|
||||
self._initialize_default_templates()
|
||||
|
||||
def set_reputation_service(self, reputation_service: CrossChainReputationService):
|
||||
"""Set reputation service for access control"""
|
||||
self.reputation_service = reputation_service
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the agent communication service"""
|
||||
logger.info("Initializing Agent Communication Service")
|
||||
|
||||
# Load existing data
|
||||
await self._load_communication_data()
|
||||
|
||||
# Start background tasks
|
||||
asyncio.create_task(self._process_message_queue())
|
||||
asyncio.create_task(self._cleanup_expired_messages())
|
||||
asyncio.create_task(self._cleanup_inactive_channels())
|
||||
|
||||
logger.info("Agent Communication Service initialized")
|
||||
|
||||
async def authorize_agent(self, agent_id: str) -> bool:
|
||||
"""Authorize an agent to use the communication system"""
|
||||
|
||||
try:
|
||||
self.authorized_agents[agent_id] = True
|
||||
|
||||
# Initialize communication stats
|
||||
if agent_id not in self.communication_stats:
|
||||
self.communication_stats[agent_id] = CommunicationStats(
|
||||
total_messages=0,
|
||||
total_earnings=0.0,
|
||||
messages_sent=0,
|
||||
messages_received=0,
|
||||
active_channels=0,
|
||||
last_activity=datetime.utcnow(),
|
||||
average_response_time=0.0,
|
||||
delivery_rate=0.0
|
||||
)
|
||||
|
||||
logger.info(f"Authorized agent: {agent_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to authorize agent {agent_id}: {e}")
|
||||
return False
|
||||
|
||||
async def revoke_agent(self, agent_id: str) -> bool:
|
||||
"""Revoke agent authorization"""
|
||||
|
||||
try:
|
||||
self.authorized_agents[agent_id] = False
|
||||
|
||||
# Clean up agent data
|
||||
if agent_id in self.agent_messages:
|
||||
del self.agent_messages[agent_id]
|
||||
if agent_id in self.agent_channels:
|
||||
del self.agent_channels[agent_id]
|
||||
if agent_id in self.communication_stats:
|
||||
del self.communication_stats[agent_id]
|
||||
|
||||
logger.info(f"Revoked authorization for agent: {agent_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to revoke agent {agent_id}: {e}")
|
||||
return False
|
||||
|
||||
async def add_contact(self, agent_id: str, contact_id: str) -> bool:
|
||||
"""Add contact to agent's contact list"""
|
||||
|
||||
try:
|
||||
if agent_id not in self.contact_lists:
|
||||
self.contact_lists[agent_id] = {}
|
||||
|
||||
self.contact_lists[agent_id][contact_id] = True
|
||||
|
||||
# Remove from blocked list if present
|
||||
if agent_id in self.blocked_lists and contact_id in self.blocked_lists[agent_id]:
|
||||
del self.blocked_lists[agent_id][contact_id]
|
||||
|
||||
logger.info(f"Added contact {contact_id} for agent {agent_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add contact: {e}")
|
||||
return False
|
||||
|
||||
async def remove_contact(self, agent_id: str, contact_id: str) -> bool:
|
||||
"""Remove contact from agent's contact list"""
|
||||
|
||||
try:
|
||||
if agent_id in self.contact_lists and contact_id in self.contact_lists[agent_id]:
|
||||
del self.contact_lists[agent_id][contact_id]
|
||||
|
||||
logger.info(f"Removed contact {contact_id} for agent {agent_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to remove contact: {e}")
|
||||
return False
|
||||
|
||||
async def block_agent(self, agent_id: str, blocked_id: str) -> bool:
|
||||
"""Block an agent"""
|
||||
|
||||
try:
|
||||
if agent_id not in self.blocked_lists:
|
||||
self.blocked_lists[agent_id] = {}
|
||||
|
||||
self.blocked_lists[agent_id][blocked_id] = True
|
||||
|
||||
# Remove from contact list if present
|
||||
if agent_id in self.contact_lists and blocked_id in self.contact_lists[agent_id]:
|
||||
del self.contact_lists[agent_id][blocked_id]
|
||||
|
||||
logger.info(f"Blocked agent {blocked_id} for agent {agent_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to block agent: {e}")
|
||||
return False
|
||||
|
||||
async def unblock_agent(self, agent_id: str, blocked_id: str) -> bool:
|
||||
"""Unblock an agent"""
|
||||
|
||||
try:
|
||||
if agent_id in self.blocked_lists and blocked_id in self.blocked_lists[agent_id]:
|
||||
del self.blocked_lists[agent_id][blocked_id]
|
||||
|
||||
logger.info(f"Unblocked agent {blocked_id} for agent {agent_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to unblock agent: {e}")
|
||||
return False
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
sender: str,
|
||||
recipient: str,
|
||||
message_type: MessageType,
|
||||
content: str,
|
||||
encryption_type: EncryptionType = EncryptionType.AES256,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
thread_id: Optional[str] = None
|
||||
) -> str:
|
||||
"""Send a message to another agent"""
|
||||
|
||||
try:
|
||||
# Validate authorization
|
||||
if not await self._can_send_message(sender, recipient):
|
||||
raise PermissionError("Not authorized to send message")
|
||||
|
||||
# Validate content
|
||||
content_bytes = content.encode('utf-8')
|
||||
if len(content_bytes) > self.max_message_size:
|
||||
raise ValueError(f"Message too large: {len(content_bytes)} > {self.max_message_size}")
|
||||
|
||||
# Generate message ID
|
||||
message_id = await self._generate_message_id()
|
||||
|
||||
# Encrypt content
|
||||
if encryption_type != EncryptionType.NONE:
|
||||
encrypted_content, encryption_key = await self._encrypt_content(content_bytes, encryption_type)
|
||||
else:
|
||||
encrypted_content = content_bytes
|
||||
encryption_key = b''
|
||||
|
||||
# Calculate price
|
||||
price = await self._calculate_message_price(len(content_bytes), message_type)
|
||||
|
||||
# Create message
|
||||
message = Message(
|
||||
id=message_id,
|
||||
sender=sender,
|
||||
recipient=recipient,
|
||||
message_type=message_type,
|
||||
content=encrypted_content,
|
||||
encryption_key=encryption_key,
|
||||
encryption_type=encryption_type,
|
||||
size=len(content_bytes),
|
||||
timestamp=datetime.utcnow(),
|
||||
status=MessageStatus.PENDING,
|
||||
price=price,
|
||||
metadata=metadata or {},
|
||||
expires_at=datetime.utcnow() + timedelta(seconds=self.message_timeout),
|
||||
reply_to=reply_to,
|
||||
thread_id=thread_id
|
||||
)
|
||||
|
||||
# Store message
|
||||
self.messages[message_id] = message
|
||||
|
||||
# Update message lists
|
||||
if sender not in self.agent_messages:
|
||||
self.agent_messages[sender] = []
|
||||
if recipient not in self.agent_messages:
|
||||
self.agent_messages[recipient] = []
|
||||
|
||||
self.agent_messages[sender].append(message_id)
|
||||
self.agent_messages[recipient].append(message_id)
|
||||
|
||||
# Update stats
|
||||
await self._update_message_stats(sender, recipient, 'sent')
|
||||
|
||||
# Create or update channel
|
||||
await self._get_or_create_channel(sender, recipient, ChannelType.DIRECT)
|
||||
|
||||
# Add to queue for delivery
|
||||
self.message_queue.append(message)
|
||||
|
||||
logger.info(f"Message sent from {sender} to {recipient}: {message_id}")
|
||||
return message_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send message: {e}")
|
||||
raise
|
||||
|
||||
async def deliver_message(self, message_id: str) -> bool:
|
||||
"""Mark message as delivered"""
|
||||
|
||||
try:
|
||||
if message_id not in self.messages:
|
||||
raise ValueError(f"Message {message_id} not found")
|
||||
|
||||
message = self.messages[message_id]
|
||||
if message.status != MessageStatus.PENDING:
|
||||
raise ValueError(f"Message {message_id} not pending")
|
||||
|
||||
message.status = MessageStatus.DELIVERED
|
||||
message.delivery_timestamp = datetime.utcnow()
|
||||
|
||||
# Update stats
|
||||
await self._update_message_stats(message.sender, message.recipient, 'delivered')
|
||||
|
||||
logger.info(f"Message delivered: {message_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to deliver message {message_id}: {e}")
|
||||
return False
|
||||
|
||||
async def read_message(self, message_id: str, reader: str) -> Optional[str]:
|
||||
"""Mark message as read and return decrypted content"""
|
||||
|
||||
try:
|
||||
if message_id not in self.messages:
|
||||
raise ValueError(f"Message {message_id} not found")
|
||||
|
||||
message = self.messages[message_id]
|
||||
if message.recipient != reader:
|
||||
raise PermissionError("Not message recipient")
|
||||
|
||||
if message.status != MessageStatus.DELIVERED:
|
||||
raise ValueError("Message not delivered")
|
||||
|
||||
if message.read:
|
||||
raise ValueError("Message already read")
|
||||
|
||||
# Mark as read
|
||||
message.status = MessageStatus.READ
|
||||
message.read_timestamp = datetime.utcnow()
|
||||
|
||||
# Update stats
|
||||
await self._update_message_stats(message.sender, message.recipient, 'read')
|
||||
|
||||
# Decrypt content
|
||||
if message.encryption_type != EncryptionType.NONE:
|
||||
decrypted_content = await self._decrypt_content(message.content, message.encryption_key, message.encryption_type)
|
||||
return decrypted_content.decode('utf-8')
|
||||
else:
|
||||
return message.content.decode('utf-8')
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read message {message_id}: {e}")
|
||||
return None
|
||||
|
||||
async def pay_for_message(self, message_id: str, payer: str, amount: float) -> bool:
|
||||
"""Pay for a message"""
|
||||
|
||||
try:
|
||||
if message_id not in self.messages:
|
||||
raise ValueError(f"Message {message_id} not found")
|
||||
|
||||
message = self.messages[message_id]
|
||||
|
||||
if amount < message.price:
|
||||
raise ValueError(f"Insufficient payment: {amount} < {message.price}")
|
||||
|
||||
# Process payment (simplified)
|
||||
# In production, implement actual payment processing
|
||||
|
||||
message.paid = True
|
||||
|
||||
# Update sender's earnings
|
||||
if message.sender in self.communication_stats:
|
||||
self.communication_stats[message.sender].total_earnings += message.price
|
||||
|
||||
logger.info(f"Payment processed for message {message_id}: {amount}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process payment for message {message_id}: {e}")
|
||||
return False
|
||||
|
||||
async def create_channel(
|
||||
self,
|
||||
agent1: str,
|
||||
agent2: str,
|
||||
channel_type: ChannelType = ChannelType.DIRECT,
|
||||
encryption_enabled: bool = True
|
||||
) -> str:
|
||||
"""Create a communication channel"""
|
||||
|
||||
try:
|
||||
# Validate agents
|
||||
if not self.authorized_agents.get(agent1, False) or not self.authorized_agents.get(agent2, False):
|
||||
raise PermissionError("Agents not authorized")
|
||||
|
||||
if agent1 == agent2:
|
||||
raise ValueError("Cannot create channel with self")
|
||||
|
||||
# Generate channel ID
|
||||
channel_id = await self._generate_channel_id()
|
||||
|
||||
# Create channel
|
||||
channel = CommunicationChannel(
|
||||
id=channel_id,
|
||||
agent1=agent1,
|
||||
agent2=agent2,
|
||||
channel_type=channel_type,
|
||||
is_active=True,
|
||||
created_timestamp=datetime.utcnow(),
|
||||
last_activity=datetime.utcnow(),
|
||||
message_count=0,
|
||||
participants=[agent1, agent2],
|
||||
encryption_enabled=encryption_enabled
|
||||
)
|
||||
|
||||
# Store channel
|
||||
self.channels[channel_id] = channel
|
||||
|
||||
# Update agent channel lists
|
||||
if agent1 not in self.agent_channels:
|
||||
self.agent_channels[agent1] = []
|
||||
if agent2 not in self.agent_channels:
|
||||
self.agent_channels[agent2] = []
|
||||
|
||||
self.agent_channels[agent1].append(channel_id)
|
||||
self.agent_channels[agent2].append(channel_id)
|
||||
|
||||
# Update stats
|
||||
self.communication_stats[agent1].active_channels += 1
|
||||
self.communication_stats[agent2].active_channels += 1
|
||||
|
||||
logger.info(f"Channel created: {channel_id} between {agent1} and {agent2}")
|
||||
return channel_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create channel: {e}")
|
||||
raise
|
||||
|
||||
async def create_message_template(
|
||||
self,
|
||||
creator: str,
|
||||
name: str,
|
||||
description: str,
|
||||
message_type: MessageType,
|
||||
content_template: str,
|
||||
variables: List[str],
|
||||
base_price: float = 0.001
|
||||
) -> str:
|
||||
"""Create a message template"""
|
||||
|
||||
try:
|
||||
# Generate template ID
|
||||
template_id = await self._generate_template_id()
|
||||
|
||||
template = MessageTemplate(
|
||||
id=template_id,
|
||||
name=name,
|
||||
description=description,
|
||||
message_type=message_type,
|
||||
content_template=content_template,
|
||||
variables=variables,
|
||||
base_price=base_price,
|
||||
is_active=True,
|
||||
creator=creator
|
||||
)
|
||||
|
||||
self.message_templates[template_id] = template
|
||||
|
||||
logger.info(f"Template created: {template_id}")
|
||||
return template_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create template: {e}")
|
||||
raise
|
||||
|
||||
async def use_template(
|
||||
self,
|
||||
template_id: str,
|
||||
sender: str,
|
||||
recipient: str,
|
||||
variables: Dict[str, str]
|
||||
) -> str:
|
||||
"""Use a message template to send a message"""
|
||||
|
||||
try:
|
||||
if template_id not in self.message_templates:
|
||||
raise ValueError(f"Template {template_id} not found")
|
||||
|
||||
template = self.message_templates[template_id]
|
||||
|
||||
if not template.is_active:
|
||||
raise ValueError(f"Template {template_id} not active")
|
||||
|
||||
# Substitute variables
|
||||
content = template.content_template
|
||||
for var, value in variables.items():
|
||||
if var in template.variables:
|
||||
content = content.replace(f"{{{var}}}", value)
|
||||
|
||||
# Send message
|
||||
message_id = await self.send_message(
|
||||
sender=sender,
|
||||
recipient=recipient,
|
||||
message_type=template.message_type,
|
||||
content=content,
|
||||
metadata={"template_id": template_id}
|
||||
)
|
||||
|
||||
# Update template usage
|
||||
template.usage_count += 1
|
||||
|
||||
logger.info(f"Template used: {template_id} -> {message_id}")
|
||||
return message_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to use template {template_id}: {e}")
|
||||
raise
|
||||
|
||||
async def get_agent_messages(
|
||||
self,
|
||||
agent_id: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
status: Optional[MessageStatus] = None
|
||||
) -> List[Message]:
|
||||
"""Get messages for an agent"""
|
||||
|
||||
try:
|
||||
if agent_id not in self.agent_messages:
|
||||
return []
|
||||
|
||||
message_ids = self.agent_messages[agent_id]
|
||||
|
||||
# Apply filters
|
||||
filtered_messages = []
|
||||
for message_id in message_ids:
|
||||
if message_id in self.messages:
|
||||
message = self.messages[message_id]
|
||||
if status is None or message.status == status:
|
||||
filtered_messages.append(message)
|
||||
|
||||
# Sort by timestamp (newest first)
|
||||
filtered_messages.sort(key=lambda x: x.timestamp, reverse=True)
|
||||
|
||||
# Apply pagination
|
||||
return filtered_messages[offset:offset + limit]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get messages for {agent_id}: {e}")
|
||||
return []
|
||||
|
||||
async def get_unread_messages(self, agent_id: str) -> List[Message]:
|
||||
"""Get unread messages for an agent"""
|
||||
|
||||
try:
|
||||
if agent_id not in self.agent_messages:
|
||||
return []
|
||||
|
||||
unread_messages = []
|
||||
for message_id in self.agent_messages[agent_id]:
|
||||
if message_id in self.messages:
|
||||
message = self.messages[message_id]
|
||||
if message.recipient == agent_id and message.status == MessageStatus.DELIVERED:
|
||||
unread_messages.append(message)
|
||||
|
||||
return unread_messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get unread messages for {agent_id}: {e}")
|
||||
return []
|
||||
|
||||
async def get_agent_channels(self, agent_id: str) -> List[CommunicationChannel]:
|
||||
"""Get channels for an agent"""
|
||||
|
||||
try:
|
||||
if agent_id not in self.agent_channels:
|
||||
return []
|
||||
|
||||
channels = []
|
||||
for channel_id in self.agent_channels[agent_id]:
|
||||
if channel_id in self.channels:
|
||||
channels.append(self.channels[channel_id])
|
||||
|
||||
return channels
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get channels for {agent_id}: {e}")
|
||||
return []
|
||||
|
||||
async def get_communication_stats(self, agent_id: str) -> CommunicationStats:
|
||||
"""Get communication statistics for an agent"""
|
||||
|
||||
try:
|
||||
if agent_id not in self.communication_stats:
|
||||
raise ValueError(f"Agent {agent_id} not found")
|
||||
|
||||
return self.communication_stats[agent_id]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get stats for {agent_id}: {e}")
|
||||
raise
|
||||
|
||||
async def can_communicate(self, sender: str, recipient: str) -> bool:
|
||||
"""Check if agents can communicate"""
|
||||
|
||||
# Check authorization
|
||||
if not self.authorized_agents.get(sender, False) or not self.authorized_agents.get(recipient, False):
|
||||
return False
|
||||
|
||||
# Check blocked lists
|
||||
if (sender in self.blocked_lists and recipient in self.blocked_lists[sender]) or \
|
||||
(recipient in self.blocked_lists and sender in self.blocked_lists[recipient]):
|
||||
return False
|
||||
|
||||
# Check contact lists
|
||||
if sender in self.contact_lists and recipient in self.contact_lists[sender]:
|
||||
return True
|
||||
|
||||
# Check reputation
|
||||
if self.reputation_service:
|
||||
sender_reputation = await self.reputation_service.get_reputation_score(sender)
|
||||
return sender_reputation >= self.min_reputation_score
|
||||
|
||||
return False
|
||||
|
||||
async def _can_send_message(self, sender: str, recipient: str) -> bool:
|
||||
"""Check if sender can send message to recipient"""
|
||||
return await self.can_communicate(sender, recipient)
|
||||
|
||||
async def _generate_message_id(self) -> str:
|
||||
"""Generate unique message ID"""
|
||||
import uuid
|
||||
return str(uuid.uuid4())
|
||||
|
||||
async def _generate_channel_id(self) -> str:
|
||||
"""Generate unique channel ID"""
|
||||
import uuid
|
||||
return str(uuid.uuid4())
|
||||
|
||||
async def _generate_template_id(self) -> str:
|
||||
"""Generate unique template ID"""
|
||||
import uuid
|
||||
return str(uuid.uuid4())
|
||||
|
||||
async def _encrypt_content(self, content: bytes, encryption_type: EncryptionType) -> Tuple[bytes, bytes]:
|
||||
"""Encrypt message content"""
|
||||
|
||||
if encryption_type == EncryptionType.AES256:
|
||||
# Simplified AES encryption
|
||||
key = hashlib.sha256(content).digest()[:32] # Generate key from content
|
||||
import os
|
||||
iv = os.urandom(16)
|
||||
|
||||
# In production, use proper AES encryption
|
||||
encrypted = content + iv # Simplified
|
||||
return encrypted, key
|
||||
|
||||
elif encryption_type == EncryptionType.RSA:
|
||||
# Simplified RSA encryption
|
||||
key = hashlib.sha256(content).digest()[:256]
|
||||
return content + key, key
|
||||
|
||||
else:
|
||||
return content, b''
|
||||
|
||||
async def _decrypt_content(self, encrypted_content: bytes, key: bytes, encryption_type: EncryptionType) -> bytes:
|
||||
"""Decrypt message content"""
|
||||
|
||||
if encryption_type == EncryptionType.AES256:
|
||||
# Simplified AES decryption
|
||||
if len(encrypted_content) < 16:
|
||||
return encrypted_content
|
||||
return encrypted_content[:-16] # Remove IV
|
||||
|
||||
elif encryption_type == EncryptionType.RSA:
|
||||
# Simplified RSA decryption
|
||||
if len(encrypted_content) < 256:
|
||||
return encrypted_content
|
||||
return encrypted_content[:-256] # Remove key
|
||||
|
||||
else:
|
||||
return encrypted_content
|
||||
|
||||
async def _calculate_message_price(self, size: int, message_type: MessageType) -> float:
|
||||
"""Calculate message price based on size and type"""
|
||||
|
||||
base_price = self.base_message_price
|
||||
|
||||
# Size multiplier
|
||||
size_multiplier = max(1, size / 1000) # 1 AITBC per 1000 bytes
|
||||
|
||||
# Type multiplier
|
||||
type_multipliers = {
|
||||
MessageType.TEXT: 1.0,
|
||||
MessageType.DATA: 1.5,
|
||||
MessageType.TASK_REQUEST: 2.0,
|
||||
MessageType.TASK_RESPONSE: 2.0,
|
||||
MessageType.COLLABORATION: 3.0,
|
||||
MessageType.NOTIFICATION: 0.5,
|
||||
MessageType.SYSTEM: 0.1,
|
||||
MessageType.URGENT: 5.0,
|
||||
MessageType.BULK: 10.0
|
||||
}
|
||||
|
||||
type_multiplier = type_multipliers.get(message_type, 1.0)
|
||||
|
||||
return base_price * size_multiplier * type_multiplier
|
||||
|
||||
async def _get_or_create_channel(self, agent1: str, agent2: str, channel_type: ChannelType) -> str:
|
||||
"""Get or create communication channel"""
|
||||
|
||||
# Check if channel already exists
|
||||
if agent1 in self.agent_channels:
|
||||
for channel_id in self.agent_channels[agent1]:
|
||||
if channel_id in self.channels:
|
||||
channel = self.channels[channel_id]
|
||||
if channel.is_active and (
|
||||
(channel.agent1 == agent1 and channel.agent2 == agent2) or
|
||||
(channel.agent1 == agent2 and channel.agent2 == agent1)
|
||||
):
|
||||
return channel_id
|
||||
|
||||
# Create new channel
|
||||
return await self.create_channel(agent1, agent2, channel_type)
|
||||
|
||||
async def _update_message_stats(self, sender: str, recipient: str, action: str):
|
||||
"""Update message statistics"""
|
||||
|
||||
if action == 'sent':
|
||||
if sender in self.communication_stats:
|
||||
self.communication_stats[sender].total_messages += 1
|
||||
self.communication_stats[sender].messages_sent += 1
|
||||
self.communication_stats[sender].last_activity = datetime.utcnow()
|
||||
|
||||
elif action == 'delivered':
|
||||
if recipient in self.communication_stats:
|
||||
self.communication_stats[recipient].total_messages += 1
|
||||
self.communication_stats[recipient].messages_received += 1
|
||||
self.communication_stats[recipient].last_activity = datetime.utcnow()
|
||||
|
||||
elif action == 'read':
|
||||
if recipient in self.communication_stats:
|
||||
self.communication_stats[recipient].last_activity = datetime.utcnow()
|
||||
|
||||
async def _process_message_queue(self):
|
||||
"""Process message queue for delivery"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
if self.message_queue:
|
||||
message = self.message_queue.pop(0)
|
||||
|
||||
# Simulate delivery
|
||||
await asyncio.sleep(0.1)
|
||||
await self.deliver_message(message.id)
|
||||
|
||||
await asyncio.sleep(1)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message queue: {e}")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def _cleanup_expired_messages(self):
|
||||
"""Clean up expired messages"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
current_time = datetime.utcnow()
|
||||
expired_messages = []
|
||||
|
||||
for message_id, message in self.messages.items():
|
||||
if message.expires_at and current_time > message.expires_at:
|
||||
expired_messages.append(message_id)
|
||||
|
||||
for message_id in expired_messages:
|
||||
del self.messages[message_id]
|
||||
# Remove from agent message lists
|
||||
for agent_id, message_ids in self.agent_messages.items():
|
||||
if message_id in message_ids:
|
||||
message_ids.remove(message_id)
|
||||
|
||||
if expired_messages:
|
||||
logger.info(f"Cleaned up {len(expired_messages)} expired messages")
|
||||
|
||||
await asyncio.sleep(3600) # Check every hour
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up messages: {e}")
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
async def _cleanup_inactive_channels(self):
|
||||
"""Clean up inactive channels"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
current_time = datetime.utcnow()
|
||||
inactive_channels = []
|
||||
|
||||
for channel_id, channel in self.channels.items():
|
||||
if channel.is_active and current_time > channel.last_activity + timedelta(seconds=self.channel_timeout):
|
||||
inactive_channels.append(channel_id)
|
||||
|
||||
for channel_id in inactive_channels:
|
||||
channel = self.channels[channel_id]
|
||||
channel.is_active = False
|
||||
|
||||
# Update stats
|
||||
if channel.agent1 in self.communication_stats:
|
||||
self.communication_stats[channel.agent1].active_channels = max(0, self.communication_stats[channel.agent1].active_channels - 1)
|
||||
if channel.agent2 in self.communication_stats:
|
||||
self.communication_stats[channel.agent2].active_channels = max(0, self.communication_stats[channel.agent2].active_channels - 1)
|
||||
|
||||
if inactive_channels:
|
||||
logger.info(f"Cleaned up {len(inactive_channels)} inactive channels")
|
||||
|
||||
await asyncio.sleep(3600) # Check every hour
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up channels: {e}")
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
def _initialize_default_templates(self):
|
||||
"""Initialize default message templates"""
|
||||
|
||||
templates = [
|
||||
MessageTemplate(
|
||||
id="task_request_default",
|
||||
name="Task Request",
|
||||
description="Default template for task requests",
|
||||
message_type=MessageType.TASK_REQUEST,
|
||||
content_template="Hello! I have a task for you: {task_description}. Budget: {budget} AITBC. Deadline: {deadline}.",
|
||||
variables=["task_description", "budget", "deadline"],
|
||||
base_price=0.002,
|
||||
is_active=True,
|
||||
creator="system"
|
||||
),
|
||||
MessageTemplate(
|
||||
id="collaboration_invite",
|
||||
name="Collaboration Invite",
|
||||
description="Template for inviting agents to collaborate",
|
||||
message_type=MessageType.COLLABORATION,
|
||||
content_template="I'd like to collaborate on {project_name}. Your role would be {role_description}. Interested?",
|
||||
variables=["project_name", "role_description"],
|
||||
base_price=0.003,
|
||||
is_active=True,
|
||||
creator="system"
|
||||
),
|
||||
MessageTemplate(
|
||||
id="notification_update",
|
||||
name="Notification Update",
|
||||
description="Template for sending notifications",
|
||||
message_type=MessageType.NOTIFICATION,
|
||||
content_template="Notification: {notification_type}. {message}. Action required: {action_required}.",
|
||||
variables=["notification_type", "message", "action_required"],
|
||||
base_price=0.001,
|
||||
is_active=True,
|
||||
creator="system"
|
||||
)
|
||||
]
|
||||
|
||||
for template in templates:
|
||||
self.message_templates[template.id] = template
|
||||
|
||||
async def _load_communication_data(self):
|
||||
"""Load existing communication data"""
|
||||
# In production, load from database
|
||||
pass
|
||||
|
||||
async def export_communication_data(self, format: str = "json") -> str:
|
||||
"""Export communication data"""
|
||||
|
||||
data = {
|
||||
"messages": {k: asdict(v) for k, v in self.messages.items()},
|
||||
"channels": {k: asdict(v) for k, v in self.channels.items()},
|
||||
"templates": {k: asdict(v) for k, v in self.message_templates.items()},
|
||||
"export_timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
if format.lower() == "json":
|
||||
return json.dumps(data, indent=2, default=str)
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: {format}")
|
||||
|
||||
async def import_communication_data(self, data: str, format: str = "json"):
|
||||
"""Import communication data"""
|
||||
|
||||
if format.lower() == "json":
|
||||
parsed_data = json.loads(data)
|
||||
|
||||
# Import messages
|
||||
for message_id, message_data in parsed_data.get("messages", {}).items():
|
||||
message_data['timestamp'] = datetime.fromisoformat(message_data['timestamp'])
|
||||
self.messages[message_id] = Message(**message_data)
|
||||
|
||||
# Import channels
|
||||
for channel_id, channel_data in parsed_data.get("channels", {}).items():
|
||||
channel_data['created_timestamp'] = datetime.fromisoformat(channel_data['created_timestamp'])
|
||||
channel_data['last_activity'] = datetime.fromisoformat(channel_data['last_activity'])
|
||||
self.channels[channel_id] = CommunicationChannel(**channel_data)
|
||||
|
||||
logger.info("Communication data imported successfully")
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: {format}")
|
||||
712
apps/coordinator-api/src/app/services/agent_orchestrator.py
Normal file
712
apps/coordinator-api/src/app/services/agent_orchestrator.py
Normal file
@@ -0,0 +1,712 @@
|
||||
"""
|
||||
Agent Orchestrator Service for OpenClaw Autonomous Economics
|
||||
Implements multi-agent coordination and sub-task management
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, List, Any, Optional, Tuple, Set
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
import json
|
||||
from dataclasses import dataclass, asdict, field
|
||||
|
||||
from .task_decomposition import TaskDecomposition, SubTask, SubTaskStatus, GPU_Tier
|
||||
from .bid_strategy_engine import BidResult, BidStrategy, UrgencyLevel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OrchestratorStatus(str, Enum):
|
||||
"""Orchestrator status"""
|
||||
IDLE = "idle"
|
||||
PLANNING = "planning"
|
||||
EXECUTING = "executing"
|
||||
MONITORING = "monitoring"
|
||||
FAILED = "failed"
|
||||
COMPLETED = "completed"
|
||||
|
||||
|
||||
class AgentStatus(str, Enum):
|
||||
"""Agent status"""
|
||||
AVAILABLE = "available"
|
||||
BUSY = "busy"
|
||||
OFFLINE = "offline"
|
||||
MAINTENANCE = "maintenance"
|
||||
|
||||
|
||||
class ResourceType(str, Enum):
|
||||
"""Resource types"""
|
||||
GPU = "gpu"
|
||||
CPU = "cpu"
|
||||
MEMORY = "memory"
|
||||
STORAGE = "storage"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentCapability:
|
||||
"""Agent capability definition"""
|
||||
agent_id: str
|
||||
supported_task_types: List[str]
|
||||
gpu_tier: GPU_Tier
|
||||
max_concurrent_tasks: int
|
||||
current_load: int
|
||||
performance_score: float # 0-1
|
||||
cost_per_hour: float
|
||||
reliability_score: float # 0-1
|
||||
last_updated: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResourceAllocation:
|
||||
"""Resource allocation for an agent"""
|
||||
agent_id: str
|
||||
sub_task_id: str
|
||||
resource_type: ResourceType
|
||||
allocated_amount: int
|
||||
allocated_at: datetime
|
||||
expected_duration: float
|
||||
actual_duration: Optional[float] = None
|
||||
cost: Optional[float] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentAssignment:
|
||||
"""Assignment of sub-task to agent"""
|
||||
sub_task_id: str
|
||||
agent_id: str
|
||||
assigned_at: datetime
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
status: SubTaskStatus = SubTaskStatus.PENDING
|
||||
bid_result: Optional[BidResult] = None
|
||||
resource_allocations: List[ResourceAllocation] = field(default_factory=list)
|
||||
error_message: Optional[str] = None
|
||||
retry_count: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrchestrationPlan:
|
||||
"""Complete orchestration plan for a task"""
|
||||
task_id: str
|
||||
decomposition: TaskDecomposition
|
||||
agent_assignments: List[AgentAssignment]
|
||||
execution_timeline: Dict[str, datetime]
|
||||
resource_requirements: Dict[ResourceType, int]
|
||||
estimated_cost: float
|
||||
confidence_score: float
|
||||
created_at: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class AgentOrchestrator:
|
||||
"""Multi-agent orchestration service"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
self.config = config
|
||||
self.status = OrchestratorStatus.IDLE
|
||||
|
||||
# Agent registry
|
||||
self.agent_capabilities: Dict[str, AgentCapability] = {}
|
||||
self.agent_status: Dict[str, AgentStatus] = {}
|
||||
|
||||
# Orchestration tracking
|
||||
self.active_plans: Dict[str, OrchestrationPlan] = {}
|
||||
self.completed_plans: List[OrchestrationPlan] = []
|
||||
self.failed_plans: List[OrchestrationPlan] = []
|
||||
|
||||
# Resource tracking
|
||||
self.resource_allocations: Dict[str, List[ResourceAllocation]] = {}
|
||||
self.resource_utilization: Dict[ResourceType, float] = {}
|
||||
|
||||
# Performance metrics
|
||||
self.orchestration_metrics = {
|
||||
"total_tasks": 0,
|
||||
"successful_tasks": 0,
|
||||
"failed_tasks": 0,
|
||||
"average_execution_time": 0.0,
|
||||
"average_cost": 0.0,
|
||||
"agent_utilization": 0.0
|
||||
}
|
||||
|
||||
# Configuration
|
||||
self.max_concurrent_plans = config.get("max_concurrent_plans", 10)
|
||||
self.assignment_timeout = config.get("assignment_timeout", 300) # 5 minutes
|
||||
self.monitoring_interval = config.get("monitoring_interval", 30) # 30 seconds
|
||||
self.retry_limit = config.get("retry_limit", 3)
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the orchestrator"""
|
||||
logger.info("Initializing Agent Orchestrator")
|
||||
|
||||
# Load agent capabilities
|
||||
await self._load_agent_capabilities()
|
||||
|
||||
# Start monitoring
|
||||
asyncio.create_task(self._monitor_executions())
|
||||
asyncio.create_task(self._update_agent_status())
|
||||
|
||||
logger.info("Agent Orchestrator initialized")
|
||||
|
||||
async def orchestrate_task(
|
||||
self,
|
||||
task_id: str,
|
||||
decomposition: TaskDecomposition,
|
||||
budget_limit: Optional[float] = None,
|
||||
deadline: Optional[datetime] = None
|
||||
) -> OrchestrationPlan:
|
||||
"""Orchestrate execution of a decomposed task"""
|
||||
|
||||
try:
|
||||
logger.info(f"Orchestrating task {task_id} with {len(decomposition.sub_tasks)} sub-tasks")
|
||||
|
||||
# Check capacity
|
||||
if len(self.active_plans) >= self.max_concurrent_plans:
|
||||
raise Exception("Orchestrator at maximum capacity")
|
||||
|
||||
self.status = OrchestratorStatus.PLANNING
|
||||
|
||||
# Create orchestration plan
|
||||
plan = await self._create_orchestration_plan(
|
||||
task_id, decomposition, budget_limit, deadline
|
||||
)
|
||||
|
||||
# Execute assignments
|
||||
await self._execute_assignments(plan)
|
||||
|
||||
# Start monitoring
|
||||
self.active_plans[task_id] = plan
|
||||
self.status = OrchestratorStatus.MONITORING
|
||||
|
||||
# Update metrics
|
||||
self.orchestration_metrics["total_tasks"] += 1
|
||||
|
||||
logger.info(f"Task {task_id} orchestration plan created and started")
|
||||
return plan
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to orchestrate task {task_id}: {e}")
|
||||
self.status = OrchestratorStatus.FAILED
|
||||
raise
|
||||
|
||||
async def get_task_status(self, task_id: str) -> Dict[str, Any]:
|
||||
"""Get status of orchestrated task"""
|
||||
|
||||
if task_id not in self.active_plans:
|
||||
return {"status": "not_found"}
|
||||
|
||||
plan = self.active_plans[task_id]
|
||||
|
||||
# Count sub-task statuses
|
||||
status_counts = {}
|
||||
for status in SubTaskStatus:
|
||||
status_counts[status.value] = 0
|
||||
|
||||
completed_count = 0
|
||||
failed_count = 0
|
||||
|
||||
for assignment in plan.agent_assignments:
|
||||
status_counts[assignment.status.value] += 1
|
||||
|
||||
if assignment.status == SubTaskStatus.COMPLETED:
|
||||
completed_count += 1
|
||||
elif assignment.status == SubTaskStatus.FAILED:
|
||||
failed_count += 1
|
||||
|
||||
# Determine overall status
|
||||
total_sub_tasks = len(plan.agent_assignments)
|
||||
if completed_count == total_sub_tasks:
|
||||
overall_status = "completed"
|
||||
elif failed_count > 0:
|
||||
overall_status = "failed"
|
||||
elif completed_count > 0:
|
||||
overall_status = "in_progress"
|
||||
else:
|
||||
overall_status = "pending"
|
||||
|
||||
return {
|
||||
"status": overall_status,
|
||||
"progress": completed_count / total_sub_tasks if total_sub_tasks > 0 else 0,
|
||||
"completed_sub_tasks": completed_count,
|
||||
"failed_sub_tasks": failed_count,
|
||||
"total_sub_tasks": total_sub_tasks,
|
||||
"estimated_cost": plan.estimated_cost,
|
||||
"actual_cost": await self._calculate_actual_cost(plan),
|
||||
"started_at": plan.created_at.isoformat(),
|
||||
"assignments": [
|
||||
{
|
||||
"sub_task_id": a.sub_task_id,
|
||||
"agent_id": a.agent_id,
|
||||
"status": a.status.value,
|
||||
"assigned_at": a.assigned_at.isoformat(),
|
||||
"started_at": a.started_at.isoformat() if a.started_at else None,
|
||||
"completed_at": a.completed_at.isoformat() if a.completed_at else None
|
||||
}
|
||||
for a in plan.agent_assignments
|
||||
]
|
||||
}
|
||||
|
||||
async def cancel_task(self, task_id: str) -> bool:
|
||||
"""Cancel task orchestration"""
|
||||
|
||||
if task_id not in self.active_plans:
|
||||
return False
|
||||
|
||||
plan = self.active_plans[task_id]
|
||||
|
||||
# Cancel all active assignments
|
||||
for assignment in plan.agent_assignments:
|
||||
if assignment.status in [SubTaskStatus.PENDING, SubTaskStatus.IN_PROGRESS]:
|
||||
assignment.status = SubTaskStatus.CANCELLED
|
||||
await self._release_agent_resources(assignment.agent_id, assignment.sub_task_id)
|
||||
|
||||
# Move to failed plans
|
||||
self.failed_plans.append(plan)
|
||||
del self.active_plans[task_id]
|
||||
|
||||
logger.info(f"Task {task_id} cancelled")
|
||||
return True
|
||||
|
||||
async def retry_failed_sub_tasks(self, task_id: str) -> List[str]:
|
||||
"""Retry failed sub-tasks"""
|
||||
|
||||
if task_id not in self.active_plans:
|
||||
return []
|
||||
|
||||
plan = self.active_plans[task_id]
|
||||
retried_tasks = []
|
||||
|
||||
for assignment in plan.agent_assignments:
|
||||
if assignment.status == SubTaskStatus.FAILED and assignment.retry_count < self.retry_limit:
|
||||
# Reset assignment
|
||||
assignment.status = SubTaskStatus.PENDING
|
||||
assignment.started_at = None
|
||||
assignment.completed_at = None
|
||||
assignment.error_message = None
|
||||
assignment.retry_count += 1
|
||||
|
||||
# Release resources
|
||||
await self._release_agent_resources(assignment.agent_id, assignment.sub_task_id)
|
||||
|
||||
# Re-assign
|
||||
await self._assign_sub_task(assignment.sub_task_id, plan)
|
||||
|
||||
retried_tasks.append(assignment.sub_task_id)
|
||||
logger.info(f"Retrying sub-task {assignment.sub_task_id} (attempt {assignment.retry_count + 1})")
|
||||
|
||||
return retried_tasks
|
||||
|
||||
async def register_agent(self, capability: AgentCapability):
|
||||
"""Register a new agent"""
|
||||
|
||||
self.agent_capabilities[capability.agent_id] = capability
|
||||
self.agent_status[capability.agent_id] = AgentStatus.AVAILABLE
|
||||
|
||||
logger.info(f"Registered agent {capability.agent_id}")
|
||||
|
||||
async def update_agent_status(self, agent_id: str, status: AgentStatus):
|
||||
"""Update agent status"""
|
||||
|
||||
if agent_id in self.agent_status:
|
||||
self.agent_status[agent_id] = status
|
||||
logger.info(f"Updated agent {agent_id} status to {status}")
|
||||
|
||||
async def get_available_agents(self, task_type: str, gpu_tier: GPU_Tier) -> List[AgentCapability]:
|
||||
"""Get available agents for task"""
|
||||
|
||||
available_agents = []
|
||||
|
||||
for agent_id, capability in self.agent_capabilities.items():
|
||||
if (self.agent_status.get(agent_id) == AgentStatus.AVAILABLE and
|
||||
task_type in capability.supported_task_types and
|
||||
capability.gpu_tier == gpu_tier and
|
||||
capability.current_load < capability.max_concurrent_tasks):
|
||||
available_agents.append(capability)
|
||||
|
||||
# Sort by performance score
|
||||
available_agents.sort(key=lambda x: x.performance_score, reverse=True)
|
||||
|
||||
return available_agents
|
||||
|
||||
async def get_orchestration_metrics(self) -> Dict[str, Any]:
|
||||
"""Get orchestration performance metrics"""
|
||||
|
||||
return {
|
||||
"orchestrator_status": self.status.value,
|
||||
"active_plans": len(self.active_plans),
|
||||
"completed_plans": len(self.completed_plans),
|
||||
"failed_plans": len(self.failed_plans),
|
||||
"registered_agents": len(self.agent_capabilities),
|
||||
"available_agents": len([s for s in self.agent_status.values() if s == AgentStatus.AVAILABLE]),
|
||||
"metrics": self.orchestration_metrics,
|
||||
"resource_utilization": self.resource_utilization
|
||||
}
|
||||
|
||||
async def _create_orchestration_plan(
|
||||
self,
|
||||
task_id: str,
|
||||
decomposition: TaskDecomposition,
|
||||
budget_limit: Optional[float],
|
||||
deadline: Optional[datetime]
|
||||
) -> OrchestrationPlan:
|
||||
"""Create detailed orchestration plan"""
|
||||
|
||||
assignments = []
|
||||
execution_timeline = {}
|
||||
resource_requirements = {rt: 0 for rt in ResourceType}
|
||||
total_cost = 0.0
|
||||
|
||||
# Process each execution stage
|
||||
for stage_idx, stage_sub_tasks in enumerate(decomposition.execution_plan):
|
||||
stage_start = datetime.utcnow() + timedelta(hours=stage_idx * 2) # Estimate 2 hours per stage
|
||||
|
||||
for sub_task_id in stage_sub_tasks:
|
||||
# Find sub-task
|
||||
sub_task = next(st for st in decomposition.sub_tasks if st.sub_task_id == sub_task_id)
|
||||
|
||||
# Create assignment (will be filled during execution)
|
||||
assignment = AgentAssignment(
|
||||
sub_task_id=sub_task_id,
|
||||
agent_id="", # Will be assigned during execution
|
||||
assigned_at=datetime.utcnow()
|
||||
)
|
||||
assignments.append(assignment)
|
||||
|
||||
# Calculate resource requirements
|
||||
resource_requirements[ResourceType.GPU] += 1
|
||||
resource_requirements[ResourceType.MEMORY] += sub_task.requirements.memory_requirement
|
||||
|
||||
# Set timeline
|
||||
execution_timeline[sub_task_id] = stage_start
|
||||
|
||||
# Calculate confidence score
|
||||
confidence_score = await self._calculate_plan_confidence(decomposition, budget_limit, deadline)
|
||||
|
||||
return OrchestrationPlan(
|
||||
task_id=task_id,
|
||||
decomposition=decomposition,
|
||||
agent_assignments=assignments,
|
||||
execution_timeline=execution_timeline,
|
||||
resource_requirements=resource_requirements,
|
||||
estimated_cost=total_cost,
|
||||
confidence_score=confidence_score
|
||||
)
|
||||
|
||||
async def _execute_assignments(self, plan: OrchestrationPlan):
|
||||
"""Execute agent assignments"""
|
||||
|
||||
for assignment in plan.agent_assignments:
|
||||
await self._assign_sub_task(assignment.sub_task_id, plan)
|
||||
|
||||
async def _assign_sub_task(self, sub_task_id: str, plan: OrchestrationPlan):
|
||||
"""Assign sub-task to suitable agent"""
|
||||
|
||||
# Find sub-task
|
||||
sub_task = next(st for st in plan.decomposition.sub_tasks if st.sub_task_id == sub_task_id)
|
||||
|
||||
# Get available agents
|
||||
available_agents = await self.get_available_agents(
|
||||
sub_task.requirements.task_type.value,
|
||||
sub_task.requirements.gpu_tier
|
||||
)
|
||||
|
||||
if not available_agents:
|
||||
raise Exception(f"No available agents for sub-task {sub_task_id}")
|
||||
|
||||
# Select best agent
|
||||
best_agent = await self._select_best_agent(available_agents, sub_task)
|
||||
|
||||
# Update assignment
|
||||
assignment = next(a for a in plan.agent_assignments if a.sub_task_id == sub_task_id)
|
||||
assignment.agent_id = best_agent.agent_id
|
||||
assignment.status = SubTaskStatus.ASSIGNED
|
||||
|
||||
# Update agent load
|
||||
self.agent_capabilities[best_agent.agent_id].current_load += 1
|
||||
self.agent_status[best_agent.agent_id] = AgentStatus.BUSY
|
||||
|
||||
# Allocate resources
|
||||
await self._allocate_resources(best_agent.agent_id, sub_task_id, sub_task.requirements)
|
||||
|
||||
logger.info(f"Assigned sub-task {sub_task_id} to agent {best_agent.agent_id}")
|
||||
|
||||
async def _select_best_agent(
|
||||
self,
|
||||
available_agents: List[AgentCapability],
|
||||
sub_task: SubTask
|
||||
) -> AgentCapability:
|
||||
"""Select best agent for sub-task"""
|
||||
|
||||
# Score agents based on multiple factors
|
||||
scored_agents = []
|
||||
|
||||
for agent in available_agents:
|
||||
score = 0.0
|
||||
|
||||
# Performance score (40% weight)
|
||||
score += agent.performance_score * 0.4
|
||||
|
||||
# Cost efficiency (30% weight)
|
||||
cost_efficiency = min(1.0, 0.05 / agent.cost_per_hour) # Normalize around 0.05 AITBC/hour
|
||||
score += cost_efficiency * 0.3
|
||||
|
||||
# Reliability (20% weight)
|
||||
score += agent.reliability_score * 0.2
|
||||
|
||||
# Current load (10% weight)
|
||||
load_factor = 1.0 - (agent.current_load / agent.max_concurrent_tasks)
|
||||
score += load_factor * 0.1
|
||||
|
||||
scored_agents.append((agent, score))
|
||||
|
||||
# Select highest scoring agent
|
||||
scored_agents.sort(key=lambda x: x[1], reverse=True)
|
||||
return scored_agents[0][0]
|
||||
|
||||
async def _allocate_resources(
|
||||
self,
|
||||
agent_id: str,
|
||||
sub_task_id: str,
|
||||
requirements
|
||||
):
|
||||
"""Allocate resources for sub-task"""
|
||||
|
||||
allocations = []
|
||||
|
||||
# GPU allocation
|
||||
gpu_allocation = ResourceAllocation(
|
||||
agent_id=agent_id,
|
||||
sub_task_id=sub_task_id,
|
||||
resource_type=ResourceType.GPU,
|
||||
allocated_amount=1,
|
||||
allocated_at=datetime.utcnow(),
|
||||
expected_duration=requirements.estimated_duration
|
||||
)
|
||||
allocations.append(gpu_allocation)
|
||||
|
||||
# Memory allocation
|
||||
memory_allocation = ResourceAllocation(
|
||||
agent_id=agent_id,
|
||||
sub_task_id=sub_task_id,
|
||||
resource_type=ResourceType.MEMORY,
|
||||
allocated_amount=requirements.memory_requirement,
|
||||
allocated_at=datetime.utcnow(),
|
||||
expected_duration=requirements.estimated_duration
|
||||
)
|
||||
allocations.append(memory_allocation)
|
||||
|
||||
# Store allocations
|
||||
if agent_id not in self.resource_allocations:
|
||||
self.resource_allocations[agent_id] = []
|
||||
self.resource_allocations[agent_id].extend(allocations)
|
||||
|
||||
async def _release_agent_resources(self, agent_id: str, sub_task_id: str):
|
||||
"""Release resources from agent"""
|
||||
|
||||
if agent_id in self.resource_allocations:
|
||||
# Remove allocations for this sub-task
|
||||
self.resource_allocations[agent_id] = [
|
||||
alloc for alloc in self.resource_allocations[agent_id]
|
||||
if alloc.sub_task_id != sub_task_id
|
||||
]
|
||||
|
||||
# Update agent load
|
||||
if agent_id in self.agent_capabilities:
|
||||
self.agent_capabilities[agent_id].current_load = max(0,
|
||||
self.agent_capabilities[agent_id].current_load - 1)
|
||||
|
||||
# Update status if no load
|
||||
if self.agent_capabilities[agent_id].current_load == 0:
|
||||
self.agent_status[agent_id] = AgentStatus.AVAILABLE
|
||||
|
||||
async def _monitor_executions(self):
|
||||
"""Monitor active executions"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Check all active plans
|
||||
completed_tasks = []
|
||||
failed_tasks = []
|
||||
|
||||
for task_id, plan in list(self.active_plans.items()):
|
||||
# Check if all sub-tasks are completed
|
||||
all_completed = all(
|
||||
a.status == SubTaskStatus.COMPLETED for a in plan.agent_assignments
|
||||
)
|
||||
any_failed = any(
|
||||
a.status == SubTaskStatus.FAILED for a in plan.agent_assignments
|
||||
)
|
||||
|
||||
if all_completed:
|
||||
completed_tasks.append(task_id)
|
||||
elif any_failed:
|
||||
# Check if all failed tasks have exceeded retry limit
|
||||
all_failed_exhausted = all(
|
||||
a.status == SubTaskStatus.FAILED and a.retry_count >= self.retry_limit
|
||||
for a in plan.agent_assignments
|
||||
if a.status == SubTaskStatus.FAILED
|
||||
)
|
||||
if all_failed_exhausted:
|
||||
failed_tasks.append(task_id)
|
||||
|
||||
# Move completed/failed tasks
|
||||
for task_id in completed_tasks:
|
||||
plan = self.active_plans[task_id]
|
||||
self.completed_plans.append(plan)
|
||||
del self.active_plans[task_id]
|
||||
self.orchestration_metrics["successful_tasks"] += 1
|
||||
logger.info(f"Task {task_id} completed successfully")
|
||||
|
||||
for task_id in failed_tasks:
|
||||
plan = self.active_plans[task_id]
|
||||
self.failed_plans.append(plan)
|
||||
del self.active_plans[task_id]
|
||||
self.orchestration_metrics["failed_tasks"] += 1
|
||||
logger.info(f"Task {task_id} failed")
|
||||
|
||||
# Update resource utilization
|
||||
await self._update_resource_utilization()
|
||||
|
||||
await asyncio.sleep(self.monitoring_interval)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in execution monitoring: {e}")
|
||||
await asyncio.sleep(60)
|
||||
|
||||
async def _update_agent_status(self):
|
||||
"""Update agent status periodically"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Check agent health and update status
|
||||
for agent_id in self.agent_capabilities.keys():
|
||||
# In a real implementation, this would ping agents or check health endpoints
|
||||
# For now, assume agents are healthy if they have recent updates
|
||||
|
||||
capability = self.agent_capabilities[agent_id]
|
||||
time_since_update = datetime.utcnow() - capability.last_updated
|
||||
|
||||
if time_since_update > timedelta(minutes=5):
|
||||
if self.agent_status[agent_id] != AgentStatus.OFFLINE:
|
||||
self.agent_status[agent_id] = AgentStatus.OFFLINE
|
||||
logger.warning(f"Agent {agent_id} marked as offline")
|
||||
elif self.agent_status[agent_id] == AgentStatus.OFFLINE:
|
||||
self.agent_status[agent_id] = AgentStatus.AVAILABLE
|
||||
logger.info(f"Agent {agent_id} back online")
|
||||
|
||||
await asyncio.sleep(60) # Check every minute
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating agent status: {e}")
|
||||
await asyncio.sleep(60)
|
||||
|
||||
async def _update_resource_utilization(self):
|
||||
"""Update resource utilization metrics"""
|
||||
|
||||
total_resources = {rt: 0 for rt in ResourceType}
|
||||
used_resources = {rt: 0 for rt in ResourceType}
|
||||
|
||||
# Calculate total resources
|
||||
for capability in self.agent_capabilities.values():
|
||||
total_resources[ResourceType.GPU] += capability.max_concurrent_tasks
|
||||
# Add other resource types as needed
|
||||
|
||||
# Calculate used resources
|
||||
for allocations in self.resource_allocations.values():
|
||||
for allocation in allocations:
|
||||
used_resources[allocation.resource_type] += allocation.allocated_amount
|
||||
|
||||
# Calculate utilization
|
||||
for resource_type in ResourceType:
|
||||
total = total_resources[resource_type]
|
||||
used = used_resources[resource_type]
|
||||
self.resource_utilization[resource_type] = used / total if total > 0 else 0.0
|
||||
|
||||
async def _calculate_plan_confidence(
|
||||
self,
|
||||
decomposition: TaskDecomposition,
|
||||
budget_limit: Optional[float],
|
||||
deadline: Optional[datetime]
|
||||
) -> float:
|
||||
"""Calculate confidence in orchestration plan"""
|
||||
|
||||
confidence = decomposition.confidence_score
|
||||
|
||||
# Adjust for budget constraints
|
||||
if budget_limit and decomposition.estimated_total_cost > budget_limit:
|
||||
confidence *= 0.7
|
||||
|
||||
# Adjust for deadline
|
||||
if deadline:
|
||||
time_to_deadline = (deadline - datetime.utcnow()).total_seconds() / 3600
|
||||
if time_to_deadline < decomposition.estimated_total_duration:
|
||||
confidence *= 0.6
|
||||
|
||||
# Adjust for agent availability
|
||||
available_agents = len([
|
||||
s for s in self.agent_status.values() if s == AgentStatus.AVAILABLE
|
||||
])
|
||||
total_agents = len(self.agent_capabilities)
|
||||
|
||||
if total_agents > 0:
|
||||
availability_ratio = available_agents / total_agents
|
||||
confidence *= (0.5 + availability_ratio * 0.5)
|
||||
|
||||
return max(0.1, min(0.95, confidence))
|
||||
|
||||
async def _calculate_actual_cost(self, plan: OrchestrationPlan) -> float:
|
||||
"""Calculate actual cost of orchestration"""
|
||||
|
||||
actual_cost = 0.0
|
||||
|
||||
for assignment in plan.agent_assignments:
|
||||
if assignment.agent_id in self.agent_capabilities:
|
||||
agent = self.agent_capabilities[assignment.agent_id]
|
||||
|
||||
# Calculate cost based on actual duration
|
||||
duration = assignment.actual_duration or 1.0 # Default to 1 hour
|
||||
cost = agent.cost_per_hour * duration
|
||||
actual_cost += cost
|
||||
|
||||
return actual_cost
|
||||
|
||||
async def _load_agent_capabilities(self):
|
||||
"""Load agent capabilities from storage"""
|
||||
|
||||
# In a real implementation, this would load from database or configuration
|
||||
# For now, create some mock agents
|
||||
|
||||
mock_agents = [
|
||||
AgentCapability(
|
||||
agent_id="agent_001",
|
||||
supported_task_types=["text_processing", "data_analysis"],
|
||||
gpu_tier=GPU_Tier.MID_RANGE_GPU,
|
||||
max_concurrent_tasks=3,
|
||||
current_load=0,
|
||||
performance_score=0.85,
|
||||
cost_per_hour=0.05,
|
||||
reliability_score=0.92
|
||||
),
|
||||
AgentCapability(
|
||||
agent_id="agent_002",
|
||||
supported_task_types=["image_processing", "model_inference"],
|
||||
gpu_tier=GPU_Tier.HIGH_END_GPU,
|
||||
max_concurrent_tasks=2,
|
||||
current_load=0,
|
||||
performance_score=0.92,
|
||||
cost_per_hour=0.09,
|
||||
reliability_score=0.88
|
||||
),
|
||||
AgentCapability(
|
||||
agent_id="agent_003",
|
||||
supported_task_types=["compute_intensive", "model_training"],
|
||||
gpu_tier=GPU_Tier.PREMIUM_GPU,
|
||||
max_concurrent_tasks=1,
|
||||
current_load=0,
|
||||
performance_score=0.96,
|
||||
cost_per_hour=0.15,
|
||||
reliability_score=0.95
|
||||
)
|
||||
]
|
||||
|
||||
for agent in mock_agents:
|
||||
await self.register_agent(agent)
|
||||
@@ -0,0 +1,898 @@
|
||||
"""
|
||||
AI Agent Service Marketplace Service
|
||||
Implements a sophisticated marketplace where agents can offer specialized services
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
import json
|
||||
import hashlib
|
||||
from dataclasses import dataclass, asdict, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ServiceStatus(str, Enum):
|
||||
"""Service status types"""
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
SUSPENDED = "suspended"
|
||||
PENDING = "pending"
|
||||
|
||||
|
||||
class RequestStatus(str, Enum):
|
||||
"""Service request status types"""
|
||||
PENDING = "pending"
|
||||
ACCEPTED = "accepted"
|
||||
COMPLETED = "completed"
|
||||
CANCELLED = "cancelled"
|
||||
EXPIRED = "expired"
|
||||
|
||||
|
||||
class GuildStatus(str, Enum):
|
||||
"""Guild status types"""
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
SUSPENDED = "suspended"
|
||||
|
||||
|
||||
class ServiceType(str, Enum):
|
||||
"""Service categories"""
|
||||
DATA_ANALYSIS = "data_analysis"
|
||||
CONTENT_CREATION = "content_creation"
|
||||
RESEARCH = "research"
|
||||
CONSULTING = "consulting"
|
||||
DEVELOPMENT = "development"
|
||||
DESIGN = "design"
|
||||
MARKETING = "marketing"
|
||||
TRANSLATION = "translation"
|
||||
WRITING = "writing"
|
||||
ANALYSIS = "analysis"
|
||||
PREDICTION = "prediction"
|
||||
OPTIMIZATION = "optimization"
|
||||
AUTOMATION = "automation"
|
||||
MONITORING = "monitoring"
|
||||
TESTING = "testing"
|
||||
SECURITY = "security"
|
||||
INTEGRATION = "integration"
|
||||
CUSTOMIZATION = "customization"
|
||||
TRAINING = "training"
|
||||
SUPPORT = "support"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Service:
|
||||
"""Agent service information"""
|
||||
id: str
|
||||
agent_id: str
|
||||
service_type: ServiceType
|
||||
name: str
|
||||
description: str
|
||||
metadata: Dict[str, Any]
|
||||
base_price: float
|
||||
reputation: int
|
||||
status: ServiceStatus
|
||||
total_earnings: float
|
||||
completed_jobs: int
|
||||
average_rating: float
|
||||
rating_count: int
|
||||
listed_at: datetime
|
||||
last_updated: datetime
|
||||
guild_id: Optional[str] = None
|
||||
tags: List[str] = field(default_factory=list)
|
||||
capabilities: List[str] = field(default_factory=list)
|
||||
requirements: List[str] = field(default_factory=list)
|
||||
pricing_model: str = "fixed" # fixed, hourly, per_task
|
||||
estimated_duration: int = 0 # in hours
|
||||
availability: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServiceRequest:
|
||||
"""Service request information"""
|
||||
id: str
|
||||
client_id: str
|
||||
service_id: str
|
||||
budget: float
|
||||
requirements: str
|
||||
deadline: datetime
|
||||
status: RequestStatus
|
||||
assigned_agent: Optional[str] = None
|
||||
accepted_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
payment: float = 0.0
|
||||
rating: int = 0
|
||||
review: str = ""
|
||||
created_at: datetime = field(default_factory=datetime.utcnow)
|
||||
results_hash: Optional[str] = None
|
||||
priority: str = "normal" # low, normal, high, urgent
|
||||
complexity: str = "medium" # simple, medium, complex
|
||||
confidentiality: str = "public" # public, private, confidential
|
||||
|
||||
|
||||
@dataclass
|
||||
class Guild:
|
||||
"""Agent guild information"""
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
founder: str
|
||||
service_category: ServiceType
|
||||
member_count: int
|
||||
total_services: int
|
||||
total_earnings: float
|
||||
reputation: int
|
||||
status: GuildStatus
|
||||
created_at: datetime
|
||||
members: Dict[str, Dict[str, Any]] = field(default_factory=dict)
|
||||
requirements: List[str] = field(default_factory=list)
|
||||
benefits: List[str] = field(default_factory=list)
|
||||
guild_rules: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServiceCategory:
|
||||
"""Service category information"""
|
||||
name: str
|
||||
description: str
|
||||
service_count: int
|
||||
total_volume: float
|
||||
average_price: float
|
||||
is_active: bool
|
||||
trending: bool = False
|
||||
popular_services: List[str] = field(default_factory=list)
|
||||
requirements: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarketplaceAnalytics:
|
||||
"""Marketplace analytics data"""
|
||||
total_services: int
|
||||
active_services: int
|
||||
total_requests: int
|
||||
pending_requests: int
|
||||
total_volume: float
|
||||
total_guilds: int
|
||||
average_service_price: float
|
||||
popular_categories: List[str]
|
||||
top_agents: List[str]
|
||||
revenue_trends: Dict[str, float]
|
||||
growth_metrics: Dict[str, float]
|
||||
|
||||
|
||||
class AgentServiceMarketplace:
|
||||
"""Service for managing AI agent service marketplace"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
self.config = config
|
||||
self.services: Dict[str, Service] = {}
|
||||
self.service_requests: Dict[str, ServiceRequest] = {}
|
||||
self.guilds: Dict[str, Guild] = {}
|
||||
self.categories: Dict[str, ServiceCategory] = {}
|
||||
self.agent_services: Dict[str, List[str]] = {}
|
||||
self.client_requests: Dict[str, List[str]] = {}
|
||||
self.guild_services: Dict[str, List[str]] = {}
|
||||
self.agent_guilds: Dict[str, str] = {}
|
||||
self.services_by_type: Dict[str, List[str]] = {}
|
||||
self.guilds_by_category: Dict[str, List[str]] = {}
|
||||
|
||||
# Configuration
|
||||
self.marketplace_fee = 0.025 # 2.5%
|
||||
self.min_service_price = 0.001
|
||||
self.max_service_price = 1000.0
|
||||
self.min_reputation_to_list = 500
|
||||
self.request_timeout = 7 * 24 * 3600 # 7 days
|
||||
self.rating_weight = 100
|
||||
|
||||
# Initialize categories
|
||||
self._initialize_categories()
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the marketplace service"""
|
||||
logger.info("Initializing Agent Service Marketplace")
|
||||
|
||||
# Load existing data
|
||||
await self._load_marketplace_data()
|
||||
|
||||
# Start background tasks
|
||||
asyncio.create_task(self._monitor_request_timeouts())
|
||||
asyncio.create_task(self._update_marketplace_analytics())
|
||||
asyncio.create_task(self._process_service_recommendations())
|
||||
asyncio.create_task(self._maintain_guild_reputation())
|
||||
|
||||
logger.info("Agent Service Marketplace initialized")
|
||||
|
||||
async def list_service(
|
||||
self,
|
||||
agent_id: str,
|
||||
service_type: ServiceType,
|
||||
name: str,
|
||||
description: str,
|
||||
metadata: Dict[str, Any],
|
||||
base_price: float,
|
||||
tags: List[str],
|
||||
capabilities: List[str],
|
||||
requirements: List[str],
|
||||
pricing_model: str = "fixed",
|
||||
estimated_duration: int = 0
|
||||
) -> Service:
|
||||
"""List a new service on the marketplace"""
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
if base_price < self.min_service_price:
|
||||
raise ValueError(f"Price below minimum: {self.min_service_price}")
|
||||
|
||||
if base_price > self.max_service_price:
|
||||
raise ValueError(f"Price above maximum: {self.max_service_price}")
|
||||
|
||||
if not description or len(description) < 10:
|
||||
raise ValueError("Description too short")
|
||||
|
||||
# Check agent reputation (simplified - in production, check with reputation service)
|
||||
agent_reputation = await self._get_agent_reputation(agent_id)
|
||||
if agent_reputation < self.min_reputation_to_list:
|
||||
raise ValueError(f"Insufficient reputation: {agent_reputation}")
|
||||
|
||||
# Generate service ID
|
||||
service_id = await self._generate_service_id()
|
||||
|
||||
# Create service
|
||||
service = Service(
|
||||
id=service_id,
|
||||
agent_id=agent_id,
|
||||
service_type=service_type,
|
||||
name=name,
|
||||
description=description,
|
||||
metadata=metadata,
|
||||
base_price=base_price,
|
||||
reputation=agent_reputation,
|
||||
status=ServiceStatus.ACTIVE,
|
||||
total_earnings=0.0,
|
||||
completed_jobs=0,
|
||||
average_rating=0.0,
|
||||
rating_count=0,
|
||||
listed_at=datetime.utcnow(),
|
||||
last_updated=datetime.utcnow(),
|
||||
tags=tags,
|
||||
capabilities=capabilities,
|
||||
requirements=requirements,
|
||||
pricing_model=pricing_model,
|
||||
estimated_duration=estimated_duration,
|
||||
availability={
|
||||
"monday": True,
|
||||
"tuesday": True,
|
||||
"wednesday": True,
|
||||
"thursday": True,
|
||||
"friday": True,
|
||||
"saturday": False,
|
||||
"sunday": False
|
||||
}
|
||||
)
|
||||
|
||||
# Store service
|
||||
self.services[service_id] = service
|
||||
|
||||
# Update mappings
|
||||
if agent_id not in self.agent_services:
|
||||
self.agent_services[agent_id] = []
|
||||
self.agent_services[agent_id].append(service_id)
|
||||
|
||||
if service_type.value not in self.services_by_type:
|
||||
self.services_by_type[service_type.value] = []
|
||||
self.services_by_type[service_type.value].append(service_id)
|
||||
|
||||
# Update category
|
||||
if service_type.value in self.categories:
|
||||
self.categories[service_type.value].service_count += 1
|
||||
|
||||
logger.info(f"Service listed: {service_id} by agent {agent_id}")
|
||||
return service
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list service: {e}")
|
||||
raise
|
||||
|
||||
async def request_service(
|
||||
self,
|
||||
client_id: str,
|
||||
service_id: str,
|
||||
budget: float,
|
||||
requirements: str,
|
||||
deadline: datetime,
|
||||
priority: str = "normal",
|
||||
complexity: str = "medium",
|
||||
confidentiality: str = "public"
|
||||
) -> ServiceRequest:
|
||||
"""Request a service"""
|
||||
|
||||
try:
|
||||
# Validate service
|
||||
if service_id not in self.services:
|
||||
raise ValueError(f"Service not found: {service_id}")
|
||||
|
||||
service = self.services[service_id]
|
||||
|
||||
if service.status != ServiceStatus.ACTIVE:
|
||||
raise ValueError("Service not active")
|
||||
|
||||
if budget < service.base_price:
|
||||
raise ValueError(f"Budget below service price: {service.base_price}")
|
||||
|
||||
if deadline <= datetime.utcnow():
|
||||
raise ValueError("Invalid deadline")
|
||||
|
||||
if deadline > datetime.utcnow() + timedelta(days=365):
|
||||
raise ValueError("Deadline too far in future")
|
||||
|
||||
# Generate request ID
|
||||
request_id = await self._generate_request_id()
|
||||
|
||||
# Create request
|
||||
request = ServiceRequest(
|
||||
id=request_id,
|
||||
client_id=client_id,
|
||||
service_id=service_id,
|
||||
budget=budget,
|
||||
requirements=requirements,
|
||||
deadline=deadline,
|
||||
status=RequestStatus.PENDING,
|
||||
priority=priority,
|
||||
complexity=complexity,
|
||||
confidentiality=confidentiality
|
||||
)
|
||||
|
||||
# Store request
|
||||
self.service_requests[request_id] = request
|
||||
|
||||
# Update mappings
|
||||
if client_id not in self.client_requests:
|
||||
self.client_requests[client_id] = []
|
||||
self.client_requests[client_id].append(request_id)
|
||||
|
||||
# In production, transfer payment to escrow
|
||||
logger.info(f"Service requested: {request_id} for service {service_id}")
|
||||
return request
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to request service: {e}")
|
||||
raise
|
||||
|
||||
async def accept_request(self, request_id: str, agent_id: str) -> bool:
|
||||
"""Accept a service request"""
|
||||
|
||||
try:
|
||||
if request_id not in self.service_requests:
|
||||
raise ValueError(f"Request not found: {request_id}")
|
||||
|
||||
request = self.service_requests[request_id]
|
||||
service = self.services[request.service_id]
|
||||
|
||||
if request.status != RequestStatus.PENDING:
|
||||
raise ValueError("Request not pending")
|
||||
|
||||
if request.assigned_agent:
|
||||
raise ValueError("Request already assigned")
|
||||
|
||||
if service.agent_id != agent_id:
|
||||
raise ValueError("Not service provider")
|
||||
|
||||
if datetime.utcnow() > request.deadline:
|
||||
raise ValueError("Request expired")
|
||||
|
||||
# Update request
|
||||
request.status = RequestStatus.ACCEPTED
|
||||
request.assigned_agent = agent_id
|
||||
request.accepted_at = datetime.utcnow()
|
||||
|
||||
# Calculate dynamic price
|
||||
final_price = await self._calculate_dynamic_price(request.service_id, request.budget)
|
||||
request.payment = final_price
|
||||
|
||||
logger.info(f"Request accepted: {request_id} by agent {agent_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to accept request: {e}")
|
||||
raise
|
||||
|
||||
async def complete_request(
|
||||
self,
|
||||
request_id: str,
|
||||
agent_id: str,
|
||||
results: Dict[str, Any]
|
||||
) -> bool:
|
||||
"""Complete a service request"""
|
||||
|
||||
try:
|
||||
if request_id not in self.service_requests:
|
||||
raise ValueError(f"Request not found: {request_id}")
|
||||
|
||||
request = self.service_requests[request_id]
|
||||
service = self.services[request.service_id]
|
||||
|
||||
if request.status != RequestStatus.ACCEPTED:
|
||||
raise ValueError("Request not accepted")
|
||||
|
||||
if request.assigned_agent != agent_id:
|
||||
raise ValueError("Not assigned agent")
|
||||
|
||||
if datetime.utcnow() > request.deadline:
|
||||
raise ValueError("Request expired")
|
||||
|
||||
# Update request
|
||||
request.status = RequestStatus.COMPLETED
|
||||
request.completed_at = datetime.utcnow()
|
||||
request.results_hash = hashlib.sha256(json.dumps(results, sort_keys=True).encode()).hexdigest()
|
||||
|
||||
# Calculate payment
|
||||
payment = request.payment
|
||||
fee = payment * self.marketplace_fee
|
||||
agent_payment = payment - fee
|
||||
|
||||
# Update service stats
|
||||
service.total_earnings += agent_payment
|
||||
service.completed_jobs += 1
|
||||
service.last_updated = datetime.utcnow()
|
||||
|
||||
# Update category
|
||||
if service.service_type.value in self.categories:
|
||||
self.categories[service.service_type.value].total_volume += payment
|
||||
|
||||
# Update guild stats
|
||||
if service.guild_id and service.guild_id in self.guilds:
|
||||
guild = self.guilds[service.guild_id]
|
||||
guild.total_earnings += agent_payment
|
||||
|
||||
# In production, process payment transfers
|
||||
logger.info(f"Request completed: {request_id} with payment {agent_payment}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to complete request: {e}")
|
||||
raise
|
||||
|
||||
async def rate_service(
|
||||
self,
|
||||
request_id: str,
|
||||
client_id: str,
|
||||
rating: int,
|
||||
review: str
|
||||
) -> bool:
|
||||
"""Rate and review a completed service"""
|
||||
|
||||
try:
|
||||
if request_id not in self.service_requests:
|
||||
raise ValueError(f"Request not found: {request_id}")
|
||||
|
||||
request = self.service_requests[request_id]
|
||||
service = self.services[request.service_id]
|
||||
|
||||
if request.status != RequestStatus.COMPLETED:
|
||||
raise ValueError("Request not completed")
|
||||
|
||||
if request.client_id != client_id:
|
||||
raise ValueError("Not request client")
|
||||
|
||||
if rating < 1 or rating > 5:
|
||||
raise ValueError("Invalid rating")
|
||||
|
||||
if datetime.utcnow() > request.deadline + timedelta(days=30):
|
||||
raise ValueError("Rating period expired")
|
||||
|
||||
# Update request
|
||||
request.rating = rating
|
||||
request.review = review
|
||||
|
||||
# Update service rating
|
||||
total_rating = service.average_rating * service.rating_count + rating
|
||||
service.rating_count += 1
|
||||
service.average_rating = total_rating / service.rating_count
|
||||
|
||||
# Update agent reputation
|
||||
reputation_change = await self._calculate_reputation_change(rating, service.reputation)
|
||||
await self._update_agent_reputation(service.agent_id, reputation_change)
|
||||
|
||||
logger.info(f"Service rated: {request_id} with rating {rating}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to rate service: {e}")
|
||||
raise
|
||||
|
||||
async def create_guild(
|
||||
self,
|
||||
founder_id: str,
|
||||
name: str,
|
||||
description: str,
|
||||
service_category: ServiceType,
|
||||
requirements: List[str],
|
||||
benefits: List[str],
|
||||
guild_rules: Dict[str, Any]
|
||||
) -> Guild:
|
||||
"""Create a new guild"""
|
||||
|
||||
try:
|
||||
if not name or len(name) < 3:
|
||||
raise ValueError("Invalid guild name")
|
||||
|
||||
if service_category not in [s for s in ServiceType]:
|
||||
raise ValueError("Invalid service category")
|
||||
|
||||
# Generate guild ID
|
||||
guild_id = await self._generate_guild_id()
|
||||
|
||||
# Get founder reputation
|
||||
founder_reputation = await self._get_agent_reputation(founder_id)
|
||||
|
||||
# Create guild
|
||||
guild = Guild(
|
||||
id=guild_id,
|
||||
name=name,
|
||||
description=description,
|
||||
founder=founder_id,
|
||||
service_category=service_category,
|
||||
member_count=1,
|
||||
total_services=0,
|
||||
total_earnings=0.0,
|
||||
reputation=founder_reputation,
|
||||
status=GuildStatus.ACTIVE,
|
||||
created_at=datetime.utcnow(),
|
||||
requirements=requirements,
|
||||
benefits=benefits,
|
||||
guild_rules=guild_rules
|
||||
)
|
||||
|
||||
# Add founder as member
|
||||
guild.members[founder_id] = {
|
||||
"joined_at": datetime.utcnow(),
|
||||
"reputation": founder_reputation,
|
||||
"role": "founder",
|
||||
"contributions": 0
|
||||
}
|
||||
|
||||
# Store guild
|
||||
self.guilds[guild_id] = guild
|
||||
|
||||
# Update mappings
|
||||
if service_category.value not in self.guilds_by_category:
|
||||
self.guilds_by_category[service_category.value] = []
|
||||
self.guilds_by_category[service_category.value].append(guild_id)
|
||||
|
||||
self.agent_guilds[founder_id] = guild_id
|
||||
|
||||
logger.info(f"Guild created: {guild_id} by {founder_id}")
|
||||
return guild
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create guild: {e}")
|
||||
raise
|
||||
|
||||
async def join_guild(self, agent_id: str, guild_id: str) -> bool:
|
||||
"""Join a guild"""
|
||||
|
||||
try:
|
||||
if guild_id not in self.guilds:
|
||||
raise ValueError(f"Guild not found: {guild_id}")
|
||||
|
||||
guild = self.guilds[guild_id]
|
||||
|
||||
if agent_id in guild.members:
|
||||
raise ValueError("Already a member")
|
||||
|
||||
if guild.status != GuildStatus.ACTIVE:
|
||||
raise ValueError("Guild not active")
|
||||
|
||||
# Check agent reputation
|
||||
agent_reputation = await self._get_agent_reputation(agent_id)
|
||||
if agent_reputation < guild.reputation // 2:
|
||||
raise ValueError("Insufficient reputation")
|
||||
|
||||
# Add member
|
||||
guild.members[agent_id] = {
|
||||
"joined_at": datetime.utcnow(),
|
||||
"reputation": agent_reputation,
|
||||
"role": "member",
|
||||
"contributions": 0
|
||||
}
|
||||
guild.member_count += 1
|
||||
|
||||
# Update mappings
|
||||
self.agent_guilds[agent_id] = guild_id
|
||||
|
||||
logger.info(f"Agent {agent_id} joined guild {guild_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to join guild: {e}")
|
||||
raise
|
||||
|
||||
async def search_services(
|
||||
self,
|
||||
query: Optional[str] = None,
|
||||
service_type: Optional[ServiceType] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
min_price: Optional[float] = None,
|
||||
max_price: Optional[float] = None,
|
||||
min_rating: Optional[float] = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
) -> List[Service]:
|
||||
"""Search services with various filters"""
|
||||
|
||||
try:
|
||||
results = []
|
||||
|
||||
# Filter through all services
|
||||
for service in self.services.values():
|
||||
if service.status != ServiceStatus.ACTIVE:
|
||||
continue
|
||||
|
||||
# Apply filters
|
||||
if service_type and service.service_type != service_type:
|
||||
continue
|
||||
|
||||
if min_price and service.base_price < min_price:
|
||||
continue
|
||||
|
||||
if max_price and service.base_price > max_price:
|
||||
continue
|
||||
|
||||
if min_rating and service.average_rating < min_rating:
|
||||
continue
|
||||
|
||||
if tags and not any(tag in service.tags for tag in tags):
|
||||
continue
|
||||
|
||||
if query:
|
||||
query_lower = query.lower()
|
||||
if (query_lower not in service.name.lower() and
|
||||
query_lower not in service.description.lower() and
|
||||
not any(query_lower in tag.lower() for tag in service.tags)):
|
||||
continue
|
||||
|
||||
results.append(service)
|
||||
|
||||
# Sort by relevance (simplified)
|
||||
results.sort(key=lambda x: (x.average_rating, x.reputation), reverse=True)
|
||||
|
||||
# Apply pagination
|
||||
return results[offset:offset + limit]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to search services: {e}")
|
||||
raise
|
||||
|
||||
async def get_agent_services(self, agent_id: str) -> List[Service]:
|
||||
"""Get all services for an agent"""
|
||||
|
||||
try:
|
||||
if agent_id not in self.agent_services:
|
||||
return []
|
||||
|
||||
services = []
|
||||
for service_id in self.agent_services[agent_id]:
|
||||
if service_id in self.services:
|
||||
services.append(self.services[service_id])
|
||||
|
||||
return services
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get agent services: {e}")
|
||||
raise
|
||||
|
||||
async def get_client_requests(self, client_id: str) -> List[ServiceRequest]:
|
||||
"""Get all requests for a client"""
|
||||
|
||||
try:
|
||||
if client_id not in self.client_requests:
|
||||
return []
|
||||
|
||||
requests = []
|
||||
for request_id in self.client_requests[client_id]:
|
||||
if request_id in self.service_requests:
|
||||
requests.append(self.service_requests[request_id])
|
||||
|
||||
return requests
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get client requests: {e}")
|
||||
raise
|
||||
|
||||
async def get_marketplace_analytics(self) -> MarketplaceAnalytics:
|
||||
"""Get marketplace analytics"""
|
||||
|
||||
try:
|
||||
total_services = len(self.services)
|
||||
active_services = len([s for s in self.services.values() if s.status == ServiceStatus.ACTIVE])
|
||||
total_requests = len(self.service_requests)
|
||||
pending_requests = len([r for r in self.service_requests.values() if r.status == RequestStatus.PENDING])
|
||||
total_guilds = len(self.guilds)
|
||||
|
||||
# Calculate total volume
|
||||
total_volume = sum(service.total_earnings for service in self.services.values())
|
||||
|
||||
# Calculate average price
|
||||
active_service_prices = [service.base_price for service in self.services.values() if service.status == ServiceStatus.ACTIVE]
|
||||
average_price = sum(active_service_prices) / len(active_service_prices) if active_service_prices else 0
|
||||
|
||||
# Get popular categories
|
||||
category_counts = {}
|
||||
for service in self.services.values():
|
||||
if service.status == ServiceStatus.ACTIVE:
|
||||
category_counts[service.service_type.value] = category_counts.get(service.service_type.value, 0) + 1
|
||||
|
||||
popular_categories = sorted(category_counts.items(), key=lambda x: x[1], reverse=True)[:5]
|
||||
|
||||
# Get top agents
|
||||
agent_earnings = {}
|
||||
for service in self.services.values():
|
||||
agent_earnings[service.agent_id] = agent_earnings.get(service.agent_id, 0) + service.total_earnings
|
||||
|
||||
top_agents = sorted(agent_earnings.items(), key=lambda x: x[1], reverse=True)[:5]
|
||||
|
||||
return MarketplaceAnalytics(
|
||||
total_services=total_services,
|
||||
active_services=active_services,
|
||||
total_requests=total_requests,
|
||||
pending_requests=pending_requests,
|
||||
total_volume=total_volume,
|
||||
total_guilds=total_guilds,
|
||||
average_service_price=average_price,
|
||||
popular_categories=[cat[0] for cat in popular_categories],
|
||||
top_agents=[agent[0] for agent in top_agents],
|
||||
revenue_trends={},
|
||||
growth_metrics={}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get marketplace analytics: {e}")
|
||||
raise
|
||||
|
||||
async def _calculate_dynamic_price(self, service_id: str, budget: float) -> float:
|
||||
"""Calculate dynamic price based on demand and reputation"""
|
||||
|
||||
service = self.services[service_id]
|
||||
dynamic_price = service.base_price
|
||||
|
||||
# Reputation multiplier
|
||||
reputation_multiplier = 1.0 + (service.reputation / 10000) * 0.5
|
||||
dynamic_price *= reputation_multiplier
|
||||
|
||||
# Demand multiplier
|
||||
demand_multiplier = 1.0
|
||||
if service.completed_jobs > 10:
|
||||
demand_multiplier = 1.0 + (service.completed_jobs / 100) * 0.5
|
||||
dynamic_price *= demand_multiplier
|
||||
|
||||
# Rating multiplier
|
||||
rating_multiplier = 1.0 + (service.average_rating / 5) * 0.3
|
||||
dynamic_price *= rating_multiplier
|
||||
|
||||
return min(dynamic_price, budget)
|
||||
|
||||
async def _calculate_reputation_change(self, rating: int, current_reputation: int) -> int:
|
||||
"""Calculate reputation change based on rating"""
|
||||
|
||||
if rating == 5:
|
||||
return self.rating_weight * 2
|
||||
elif rating == 4:
|
||||
return self.rating_weight
|
||||
elif rating == 3:
|
||||
return 0
|
||||
elif rating == 2:
|
||||
return -self.rating_weight
|
||||
else: # rating == 1
|
||||
return -self.rating_weight * 2
|
||||
|
||||
async def _get_agent_reputation(self, agent_id: str) -> int:
|
||||
"""Get agent reputation (simplified)"""
|
||||
# In production, integrate with reputation service
|
||||
return 1000
|
||||
|
||||
async def _update_agent_reputation(self, agent_id: str, change: int):
|
||||
"""Update agent reputation (simplified)"""
|
||||
# In production, integrate with reputation service
|
||||
pass
|
||||
|
||||
async def _generate_service_id(self) -> str:
|
||||
"""Generate unique service ID"""
|
||||
import uuid
|
||||
return str(uuid.uuid4())
|
||||
|
||||
async def _generate_request_id(self) -> str:
|
||||
"""Generate unique request ID"""
|
||||
import uuid
|
||||
return str(uuid.uuid4())
|
||||
|
||||
async def _generate_guild_id(self) -> str:
|
||||
"""Generate unique guild ID"""
|
||||
import uuid
|
||||
return str(uuid.uuid4())
|
||||
|
||||
def _initialize_categories(self):
|
||||
"""Initialize service categories"""
|
||||
|
||||
for service_type in ServiceType:
|
||||
self.categories[service_type.value] = ServiceCategory(
|
||||
name=service_type.value,
|
||||
description=f"Services related to {service_type.value}",
|
||||
service_count=0,
|
||||
total_volume=0.0,
|
||||
average_price=0.0,
|
||||
is_active=True
|
||||
)
|
||||
|
||||
async def _load_marketplace_data(self):
|
||||
"""Load existing marketplace data"""
|
||||
# In production, load from database
|
||||
pass
|
||||
|
||||
async def _monitor_request_timeouts(self):
|
||||
"""Monitor and handle request timeouts"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
current_time = datetime.utcnow()
|
||||
|
||||
for request in self.service_requests.values():
|
||||
if request.status == RequestStatus.PENDING and current_time > request.deadline:
|
||||
request.status = RequestStatus.EXPIRED
|
||||
logger.info(f"Request expired: {request.id}")
|
||||
|
||||
await asyncio.sleep(3600) # Check every hour
|
||||
except Exception as e:
|
||||
logger.error(f"Error monitoring timeouts: {e}")
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
async def _update_marketplace_analytics(self):
|
||||
"""Update marketplace analytics"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Update trending categories
|
||||
for category in self.categories.values():
|
||||
# Simplified trending logic
|
||||
category.trending = category.service_count > 10
|
||||
|
||||
await asyncio.sleep(3600) # Update every hour
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating analytics: {e}")
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
async def _process_service_recommendations(self):
|
||||
"""Process service recommendations"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Implement recommendation logic
|
||||
await asyncio.sleep(1800) # Process every 30 minutes
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing recommendations: {e}")
|
||||
await asyncio.sleep(1800)
|
||||
|
||||
async def _maintain_guild_reputation(self):
|
||||
"""Maintain guild reputation scores"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
for guild in self.guilds.values():
|
||||
# Calculate guild reputation based on members
|
||||
total_reputation = 0
|
||||
active_members = 0
|
||||
|
||||
for member_id, member_data in guild.members.items():
|
||||
member_reputation = await self._get_agent_reputation(member_id)
|
||||
total_reputation += member_reputation
|
||||
active_members += 1
|
||||
|
||||
if active_members > 0:
|
||||
guild.reputation = total_reputation // active_members
|
||||
|
||||
await asyncio.sleep(3600) # Update every hour
|
||||
except Exception as e:
|
||||
logger.error(f"Error maintaining guild reputation: {e}")
|
||||
await asyncio.sleep(3600)
|
||||
797
apps/coordinator-api/src/app/services/bid_strategy_engine.py
Normal file
797
apps/coordinator-api/src/app/services/bid_strategy_engine.py
Normal file
@@ -0,0 +1,797 @@
|
||||
"""
|
||||
Bid Strategy Engine for OpenClaw Autonomous Economics
|
||||
Implements intelligent bidding algorithms for GPU rental negotiations
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
import numpy as np
|
||||
import json
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BidStrategy(str, Enum):
|
||||
"""Bidding strategy types"""
|
||||
URGENT_BID = "urgent_bid"
|
||||
COST_OPTIMIZED = "cost_optimized"
|
||||
BALANCED = "balanced"
|
||||
AGGRESSIVE = "aggressive"
|
||||
CONSERVATIVE = "conservative"
|
||||
|
||||
|
||||
class UrgencyLevel(str, Enum):
|
||||
"""Task urgency levels"""
|
||||
LOW = "low"
|
||||
MEDIUM = "medium"
|
||||
HIGH = "high"
|
||||
CRITICAL = "critical"
|
||||
|
||||
|
||||
class GPU_Tier(str, Enum):
|
||||
"""GPU resource tiers"""
|
||||
CPU_ONLY = "cpu_only"
|
||||
LOW_END_GPU = "low_end_gpu"
|
||||
MID_RANGE_GPU = "mid_range_gpu"
|
||||
HIGH_END_GPU = "high_end_gpu"
|
||||
PREMIUM_GPU = "premium_gpu"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarketConditions:
|
||||
"""Current market conditions"""
|
||||
current_gas_price: float
|
||||
gpu_utilization_rate: float
|
||||
average_hourly_price: float
|
||||
price_volatility: float
|
||||
demand_level: float
|
||||
supply_level: float
|
||||
timestamp: datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskRequirements:
|
||||
"""Task requirements for bidding"""
|
||||
task_id: str
|
||||
agent_id: str
|
||||
urgency: UrgencyLevel
|
||||
estimated_duration: float # hours
|
||||
gpu_tier: GPU_Tier
|
||||
memory_requirement: int # GB
|
||||
compute_intensity: float # 0-1
|
||||
deadline: Optional[datetime]
|
||||
max_budget: float
|
||||
priority_score: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class BidParameters:
|
||||
"""Parameters for bid calculation"""
|
||||
base_price: float
|
||||
urgency_multiplier: float
|
||||
tier_multiplier: float
|
||||
market_multiplier: float
|
||||
competition_factor: float
|
||||
time_factor: float
|
||||
risk_premium: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class BidResult:
|
||||
"""Result of bid calculation"""
|
||||
bid_price: float
|
||||
bid_strategy: BidStrategy
|
||||
confidence_score: float
|
||||
expected_wait_time: float
|
||||
success_probability: float
|
||||
cost_efficiency: float
|
||||
reasoning: List[str]
|
||||
bid_parameters: BidParameters
|
||||
|
||||
|
||||
class BidStrategyEngine:
|
||||
"""Intelligent bidding engine for GPU rental negotiations"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
self.config = config
|
||||
self.market_history: List[MarketConditions] = []
|
||||
self.bid_history: List[BidResult] = []
|
||||
self.agent_preferences: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# Strategy weights
|
||||
self.strategy_weights = {
|
||||
BidStrategy.URGENT_BID: 0.25,
|
||||
BidStrategy.COST_OPTIMIZED: 0.25,
|
||||
BidStrategy.BALANCED: 0.25,
|
||||
BidStrategy.AGGRESSIVE: 0.15,
|
||||
BidStrategy.CONSERVATIVE: 0.10
|
||||
}
|
||||
|
||||
# Market analysis parameters
|
||||
self.market_window = 24 # hours
|
||||
self.price_history_days = 30
|
||||
self.volatility_threshold = 0.15
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the bid strategy engine"""
|
||||
logger.info("Initializing Bid Strategy Engine")
|
||||
|
||||
# Load historical data
|
||||
await self._load_market_history()
|
||||
await self._load_agent_preferences()
|
||||
|
||||
# Initialize market monitoring
|
||||
asyncio.create_task(self._monitor_market_conditions())
|
||||
|
||||
logger.info("Bid Strategy Engine initialized")
|
||||
|
||||
async def calculate_bid(
|
||||
self,
|
||||
task_requirements: TaskRequirements,
|
||||
strategy: Optional[BidStrategy] = None,
|
||||
custom_parameters: Optional[Dict[str, Any]] = None
|
||||
) -> BidResult:
|
||||
"""Calculate optimal bid for GPU rental"""
|
||||
|
||||
try:
|
||||
# Get current market conditions
|
||||
market_conditions = await self._get_current_market_conditions()
|
||||
|
||||
# Select strategy if not provided
|
||||
if strategy is None:
|
||||
strategy = await self._select_optimal_strategy(task_requirements, market_conditions)
|
||||
|
||||
# Calculate bid parameters
|
||||
bid_params = await self._calculate_bid_parameters(
|
||||
task_requirements,
|
||||
market_conditions,
|
||||
strategy,
|
||||
custom_parameters
|
||||
)
|
||||
|
||||
# Calculate bid price
|
||||
bid_price = await self._calculate_bid_price(bid_params, task_requirements)
|
||||
|
||||
# Analyze bid success factors
|
||||
success_probability = await self._calculate_success_probability(
|
||||
bid_price, task_requirements, market_conditions
|
||||
)
|
||||
|
||||
# Estimate wait time
|
||||
expected_wait_time = await self._estimate_wait_time(
|
||||
bid_price, task_requirements, market_conditions
|
||||
)
|
||||
|
||||
# Calculate confidence score
|
||||
confidence_score = await self._calculate_confidence_score(
|
||||
bid_params, market_conditions, strategy
|
||||
)
|
||||
|
||||
# Calculate cost efficiency
|
||||
cost_efficiency = await self._calculate_cost_efficiency(
|
||||
bid_price, task_requirements
|
||||
)
|
||||
|
||||
# Generate reasoning
|
||||
reasoning = await self._generate_bid_reasoning(
|
||||
bid_params, task_requirements, market_conditions, strategy
|
||||
)
|
||||
|
||||
# Create bid result
|
||||
bid_result = BidResult(
|
||||
bid_price=bid_price,
|
||||
bid_strategy=strategy,
|
||||
confidence_score=confidence_score,
|
||||
expected_wait_time=expected_wait_time,
|
||||
success_probability=success_probability,
|
||||
cost_efficiency=cost_efficiency,
|
||||
reasoning=reasoning,
|
||||
bid_parameters=bid_params
|
||||
)
|
||||
|
||||
# Record bid
|
||||
self.bid_history.append(bid_result)
|
||||
|
||||
logger.info(f"Calculated bid for task {task_requirements.task_id}: {bid_price} AITBC/hour")
|
||||
return bid_result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to calculate bid: {e}")
|
||||
raise
|
||||
|
||||
async def update_agent_preferences(
|
||||
self,
|
||||
agent_id: str,
|
||||
preferences: Dict[str, Any]
|
||||
):
|
||||
"""Update agent bidding preferences"""
|
||||
|
||||
self.agent_preferences[agent_id] = {
|
||||
'preferred_strategy': preferences.get('preferred_strategy', 'balanced'),
|
||||
'risk_tolerance': preferences.get('risk_tolerance', 0.5),
|
||||
'cost_sensitivity': preferences.get('cost_sensitivity', 0.5),
|
||||
'urgency_preference': preferences.get('urgency_preference', 0.5),
|
||||
'max_wait_time': preferences.get('max_wait_time', 3600), # 1 hour
|
||||
'min_success_probability': preferences.get('min_success_probability', 0.7),
|
||||
'updated_at': datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
logger.info(f"Updated preferences for agent {agent_id}")
|
||||
|
||||
async def get_market_analysis(self) -> Dict[str, Any]:
|
||||
"""Get comprehensive market analysis"""
|
||||
|
||||
market_conditions = await self._get_current_market_conditions()
|
||||
|
||||
# Calculate market trends
|
||||
price_trend = await self._calculate_price_trend()
|
||||
demand_trend = await self._calculate_demand_trend()
|
||||
volatility_trend = await self._calculate_volatility_trend()
|
||||
|
||||
# Predict future conditions
|
||||
future_conditions = await self._predict_market_conditions(24) # 24 hours ahead
|
||||
|
||||
return {
|
||||
'current_conditions': asdict(market_conditions),
|
||||
'price_trend': price_trend,
|
||||
'demand_trend': demand_trend,
|
||||
'volatility_trend': volatility_trend,
|
||||
'future_prediction': asdict(future_conditions),
|
||||
'recommendations': await self._generate_market_recommendations(market_conditions),
|
||||
'analysis_timestamp': datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
async def _select_optimal_strategy(
|
||||
self,
|
||||
task_requirements: TaskRequirements,
|
||||
market_conditions: MarketConditions
|
||||
) -> BidStrategy:
|
||||
"""Select optimal bidding strategy based on requirements and conditions"""
|
||||
|
||||
# Get agent preferences
|
||||
agent_prefs = self.agent_preferences.get(task_requirements.agent_id, {})
|
||||
|
||||
# Calculate strategy scores
|
||||
strategy_scores = {}
|
||||
|
||||
# Urgent bid strategy
|
||||
if task_requirements.urgency in [UrgencyLevel.HIGH, UrgencyLevel.CRITICAL]:
|
||||
strategy_scores[BidStrategy.URGENT_BID] = 0.9
|
||||
else:
|
||||
strategy_scores[BidStrategy.URGENT_BID] = 0.3
|
||||
|
||||
# Cost optimized strategy
|
||||
if task_requirements.max_budget < market_conditions.average_hourly_price:
|
||||
strategy_scores[BidStrategy.COST_OPTIMIZED] = 0.8
|
||||
else:
|
||||
strategy_scores[BidStrategy.COST_OPTIMIZED] = 0.5
|
||||
|
||||
# Balanced strategy
|
||||
strategy_scores[BidStrategy.BALANCED] = 0.7
|
||||
|
||||
# Aggressive strategy
|
||||
if market_conditions.demand_level > 0.8:
|
||||
strategy_scores[BidStrategy.AGGRESSIVE] = 0.6
|
||||
else:
|
||||
strategy_scores[BidStrategy.AGGRESSIVE] = 0.3
|
||||
|
||||
# Conservative strategy
|
||||
if market_conditions.price_volatility > self.volatility_threshold:
|
||||
strategy_scores[BidStrategy.CONSERVATIVE] = 0.7
|
||||
else:
|
||||
strategy_scores[BidStrategy.CONSERVATIVE] = 0.4
|
||||
|
||||
# Apply agent preferences
|
||||
preferred_strategy = agent_prefs.get('preferred_strategy')
|
||||
if preferred_strategy:
|
||||
strategy_scores[BidStrategy(preferred_strategy)] *= 1.2
|
||||
|
||||
# Select highest scoring strategy
|
||||
optimal_strategy = max(strategy_scores, key=strategy_scores.get)
|
||||
|
||||
logger.debug(f"Selected strategy {optimal_strategy} for task {task_requirements.task_id}")
|
||||
return optimal_strategy
|
||||
|
||||
async def _calculate_bid_parameters(
|
||||
self,
|
||||
task_requirements: TaskRequirements,
|
||||
market_conditions: MarketConditions,
|
||||
strategy: BidStrategy,
|
||||
custom_parameters: Optional[Dict[str, Any]]
|
||||
) -> BidParameters:
|
||||
"""Calculate bid parameters based on strategy and conditions"""
|
||||
|
||||
# Base price from market
|
||||
base_price = market_conditions.average_hourly_price
|
||||
|
||||
# GPU tier multiplier
|
||||
tier_multipliers = {
|
||||
GPU_Tier.CPU_ONLY: 0.3,
|
||||
GPU_Tier.LOW_END_GPU: 0.6,
|
||||
GPU_Tier.MID_RANGE_GPU: 1.0,
|
||||
GPU_Tier.HIGH_END_GPU: 1.8,
|
||||
GPU_Tier.PREMIUM_GPU: 3.0
|
||||
}
|
||||
tier_multiplier = tier_multipliers[task_requirements.gpu_tier]
|
||||
|
||||
# Urgency multiplier based on strategy
|
||||
urgency_multipliers = {
|
||||
BidStrategy.URGENT_BID: 1.5,
|
||||
BidStrategy.COST_OPTIMIZED: 0.8,
|
||||
BidStrategy.BALANCED: 1.0,
|
||||
BidStrategy.AGGRESSIVE: 1.3,
|
||||
BidStrategy.CONSERVATIVE: 0.9
|
||||
}
|
||||
urgency_multiplier = urgency_multipliers[strategy]
|
||||
|
||||
# Market condition multiplier
|
||||
market_multiplier = 1.0
|
||||
if market_conditions.demand_level > 0.8:
|
||||
market_multiplier *= 1.2
|
||||
if market_conditions.supply_level < 0.3:
|
||||
market_multiplier *= 1.3
|
||||
if market_conditions.price_volatility > self.volatility_threshold:
|
||||
market_multiplier *= 1.1
|
||||
|
||||
# Competition factor
|
||||
competition_factor = market_conditions.demand_level / max(market_conditions.supply_level, 0.1)
|
||||
|
||||
# Time factor (urgency based on deadline)
|
||||
time_factor = 1.0
|
||||
if task_requirements.deadline:
|
||||
time_remaining = (task_requirements.deadline - datetime.utcnow()).total_seconds() / 3600
|
||||
if time_remaining < 2: # Less than 2 hours
|
||||
time_factor = 1.5
|
||||
elif time_remaining < 6: # Less than 6 hours
|
||||
time_factor = 1.2
|
||||
elif time_remaining < 24: # Less than 24 hours
|
||||
time_factor = 1.1
|
||||
|
||||
# Risk premium based on strategy
|
||||
risk_premiums = {
|
||||
BidStrategy.URGENT_BID: 0.2,
|
||||
BidStrategy.COST_OPTIMIZED: 0.05,
|
||||
BidStrategy.BALANCED: 0.1,
|
||||
BidStrategy.AGGRESSIVE: 0.25,
|
||||
BidStrategy.CONSERVATIVE: 0.08
|
||||
}
|
||||
risk_premium = risk_premiums[strategy]
|
||||
|
||||
# Apply custom parameters if provided
|
||||
if custom_parameters:
|
||||
if 'base_price_adjustment' in custom_parameters:
|
||||
base_price *= (1 + custom_parameters['base_price_adjustment'])
|
||||
if 'tier_multiplier_adjustment' in custom_parameters:
|
||||
tier_multiplier *= (1 + custom_parameters['tier_multiplier_adjustment'])
|
||||
if 'risk_premium_adjustment' in custom_parameters:
|
||||
risk_premium *= (1 + custom_parameters['risk_premium_adjustment'])
|
||||
|
||||
return BidParameters(
|
||||
base_price=base_price,
|
||||
urgency_multiplier=urgency_multiplier,
|
||||
tier_multiplier=tier_multiplier,
|
||||
market_multiplier=market_multiplier,
|
||||
competition_factor=competition_factor,
|
||||
time_factor=time_factor,
|
||||
risk_premium=risk_premium
|
||||
)
|
||||
|
||||
async def _calculate_bid_price(
|
||||
self,
|
||||
bid_params: BidParameters,
|
||||
task_requirements: TaskRequirements
|
||||
) -> float:
|
||||
"""Calculate final bid price"""
|
||||
|
||||
# Base calculation
|
||||
price = bid_params.base_price
|
||||
price *= bid_params.urgency_multiplier
|
||||
price *= bid_params.tier_multiplier
|
||||
price *= bid_params.market_multiplier
|
||||
|
||||
# Apply competition and time factors
|
||||
price *= (1 + bid_params.competition_factor * 0.3)
|
||||
price *= bid_params.time_factor
|
||||
|
||||
# Add risk premium
|
||||
price *= (1 + bid_params.risk_premium)
|
||||
|
||||
# Apply duration multiplier (longer duration = better rate)
|
||||
duration_multiplier = max(0.8, min(1.2, 1.0 - (task_requirements.estimated_duration - 1) * 0.05))
|
||||
price *= duration_multiplier
|
||||
|
||||
# Ensure within budget
|
||||
max_hourly_rate = task_requirements.max_budget / max(task_requirements.estimated_duration, 0.1)
|
||||
price = min(price, max_hourly_rate)
|
||||
|
||||
# Round to reasonable precision
|
||||
price = round(price, 6)
|
||||
|
||||
return max(price, 0.001) # Minimum bid price
|
||||
|
||||
async def _calculate_success_probability(
|
||||
self,
|
||||
bid_price: float,
|
||||
task_requirements: TaskRequirements,
|
||||
market_conditions: MarketConditions
|
||||
) -> float:
|
||||
"""Calculate probability of bid success"""
|
||||
|
||||
# Base probability from market conditions
|
||||
base_prob = 1.0 - market_conditions.demand_level
|
||||
|
||||
# Price competitiveness factor
|
||||
price_competitiveness = market_conditions.average_hourly_price / max(bid_price, 0.001)
|
||||
price_factor = min(1.0, price_competitiveness)
|
||||
|
||||
# Urgency factor
|
||||
urgency_factor = 1.0
|
||||
if task_requirements.urgency == UrgencyLevel.CRITICAL:
|
||||
urgency_factor = 0.8 # Critical tasks may have lower success due to high demand
|
||||
elif task_requirements.urgency == UrgencyLevel.HIGH:
|
||||
urgency_factor = 0.9
|
||||
|
||||
# Time factor
|
||||
time_factor = 1.0
|
||||
if task_requirements.deadline:
|
||||
time_remaining = (task_requirements.deadline - datetime.utcnow()).total_seconds() / 3600
|
||||
if time_remaining < 2:
|
||||
time_factor = 0.7
|
||||
elif time_remaining < 6:
|
||||
time_factor = 0.85
|
||||
|
||||
# Combine factors
|
||||
success_prob = base_prob * 0.4 + price_factor * 0.3 + urgency_factor * 0.2 + time_factor * 0.1
|
||||
|
||||
return max(0.1, min(0.95, success_prob))
|
||||
|
||||
async def _estimate_wait_time(
|
||||
self,
|
||||
bid_price: float,
|
||||
task_requirements: TaskRequirements,
|
||||
market_conditions: MarketConditions
|
||||
) -> float:
|
||||
"""Estimate wait time for resource allocation"""
|
||||
|
||||
# Base wait time from market conditions
|
||||
base_wait = 300 # 5 minutes base
|
||||
|
||||
# Demand factor
|
||||
demand_factor = market_conditions.demand_level * 600 # Up to 10 minutes
|
||||
|
||||
# Price factor (higher price = lower wait time)
|
||||
price_ratio = bid_price / market_conditions.average_hourly_price
|
||||
price_factor = max(0.5, 2.0 - price_ratio) * 300 # 1.5 to 0.5 minutes
|
||||
|
||||
# Urgency factor
|
||||
urgency_factor = 0
|
||||
if task_requirements.urgency == UrgencyLevel.CRITICAL:
|
||||
urgency_factor = -300 # Priority reduces wait time
|
||||
elif task_requirements.urgency == UrgencyLevel.HIGH:
|
||||
urgency_factor = -120
|
||||
|
||||
# GPU tier factor
|
||||
tier_factors = {
|
||||
GPU_Tier.CPU_ONLY: -180,
|
||||
GPU_Tier.LOW_END_GPU: -60,
|
||||
GPU_Tier.MID_RANGE_GPU: 0,
|
||||
GPU_Tier.HIGH_END_GPU: 120,
|
||||
GPU_Tier.PREMIUM_GPU: 300
|
||||
}
|
||||
tier_factor = tier_factors[task_requirements.gpu_tier]
|
||||
|
||||
# Calculate total wait time
|
||||
wait_time = base_wait + demand_factor + price_factor + urgency_factor + tier_factor
|
||||
|
||||
return max(60, wait_time) # Minimum 1 minute wait
|
||||
|
||||
async def _calculate_confidence_score(
|
||||
self,
|
||||
bid_params: BidParameters,
|
||||
market_conditions: MarketConditions,
|
||||
strategy: BidStrategy
|
||||
) -> float:
|
||||
"""Calculate confidence in bid calculation"""
|
||||
|
||||
# Market stability factor
|
||||
stability_factor = 1.0 - market_conditions.price_volatility
|
||||
|
||||
# Strategy confidence
|
||||
strategy_confidence = {
|
||||
BidStrategy.BALANCED: 0.9,
|
||||
BidStrategy.COST_OPTIMIZED: 0.8,
|
||||
BidStrategy.CONSERVATIVE: 0.85,
|
||||
BidStrategy.URGENT_BID: 0.7,
|
||||
BidStrategy.AGGRESSIVE: 0.6
|
||||
}
|
||||
|
||||
# Data availability factor
|
||||
data_factor = min(1.0, len(self.market_history) / 24) # 24 hours of history
|
||||
|
||||
# Parameter consistency factor
|
||||
param_factor = 1.0
|
||||
if bid_params.urgency_multiplier > 2.0 or bid_params.tier_multiplier > 3.0:
|
||||
param_factor = 0.8
|
||||
|
||||
confidence = (
|
||||
stability_factor * 0.3 +
|
||||
strategy_confidence[strategy] * 0.3 +
|
||||
data_factor * 0.2 +
|
||||
param_factor * 0.2
|
||||
)
|
||||
|
||||
return max(0.3, min(0.95, confidence))
|
||||
|
||||
async def _calculate_cost_efficiency(
|
||||
self,
|
||||
bid_price: float,
|
||||
task_requirements: TaskRequirements
|
||||
) -> float:
|
||||
"""Calculate cost efficiency of the bid"""
|
||||
|
||||
# Base efficiency from price vs. market
|
||||
market_price = await self._get_market_price_for_tier(task_requirements.gpu_tier)
|
||||
price_efficiency = market_price / max(bid_price, 0.001)
|
||||
|
||||
# Duration efficiency (longer tasks get better rates)
|
||||
duration_efficiency = min(1.2, 1.0 + (task_requirements.estimated_duration - 1) * 0.05)
|
||||
|
||||
# Compute intensity efficiency
|
||||
compute_efficiency = task_requirements.compute_intensity
|
||||
|
||||
# Budget utilization
|
||||
budget_utilization = (bid_price * task_requirements.estimated_duration) / max(task_requirements.max_budget, 0.001)
|
||||
budget_efficiency = 1.0 - abs(budget_utilization - 0.8) # Optimal at 80% budget utilization
|
||||
|
||||
efficiency = (
|
||||
price_efficiency * 0.4 +
|
||||
duration_efficiency * 0.2 +
|
||||
compute_efficiency * 0.2 +
|
||||
budget_efficiency * 0.2
|
||||
)
|
||||
|
||||
return max(0.1, min(1.0, efficiency))
|
||||
|
||||
async def _generate_bid_reasoning(
|
||||
self,
|
||||
bid_params: BidParameters,
|
||||
task_requirements: TaskRequirements,
|
||||
market_conditions: MarketConditions,
|
||||
strategy: BidStrategy
|
||||
) -> List[str]:
|
||||
"""Generate reasoning for bid calculation"""
|
||||
|
||||
reasoning = []
|
||||
|
||||
# Strategy reasoning
|
||||
reasoning.append(f"Strategy: {strategy.value} selected based on task urgency and market conditions")
|
||||
|
||||
# Market conditions
|
||||
if market_conditions.demand_level > 0.8:
|
||||
reasoning.append("High market demand increases bid price")
|
||||
elif market_conditions.demand_level < 0.3:
|
||||
reasoning.append("Low market demand allows for competitive pricing")
|
||||
|
||||
# GPU tier reasoning
|
||||
tier_names = {
|
||||
GPU_Tier.CPU_ONLY: "CPU-only resources",
|
||||
GPU_Tier.LOW_END_GPU: "low-end GPU",
|
||||
GPU_Tier.MID_RANGE_GPU: "mid-range GPU",
|
||||
GPU_Tier.HIGH_END_GPU: "high-end GPU",
|
||||
GPU_Tier.PREMIUM_GPU: "premium GPU"
|
||||
}
|
||||
reasoning.append(f"Selected {tier_names[task_requirements.gpu_tier]} with {bid_params.tier_multiplier:.1f}x multiplier")
|
||||
|
||||
# Urgency reasoning
|
||||
if task_requirements.urgency == UrgencyLevel.CRITICAL:
|
||||
reasoning.append("Critical urgency requires aggressive bidding")
|
||||
elif task_requirements.urgency == UrgencyLevel.LOW:
|
||||
reasoning.append("Low urgency allows for cost-optimized bidding")
|
||||
|
||||
# Price reasoning
|
||||
if bid_params.market_multiplier > 1.1:
|
||||
reasoning.append("Market conditions require price premium")
|
||||
elif bid_params.market_multiplier < 0.9:
|
||||
reasoning.append("Favorable market conditions enable discount pricing")
|
||||
|
||||
# Risk reasoning
|
||||
if bid_params.risk_premium > 0.15:
|
||||
reasoning.append("High risk premium applied due to strategy and volatility")
|
||||
|
||||
return reasoning
|
||||
|
||||
async def _get_current_market_conditions(self) -> MarketConditions:
|
||||
"""Get current market conditions"""
|
||||
|
||||
# In a real implementation, this would fetch from market data sources
|
||||
# For now, return simulated data
|
||||
|
||||
return MarketConditions(
|
||||
current_gas_price=20.0, # Gwei
|
||||
gpu_utilization_rate=0.75,
|
||||
average_hourly_price=0.05, # AITBC
|
||||
price_volatility=0.12,
|
||||
demand_level=0.68,
|
||||
supply_level=0.72,
|
||||
timestamp=datetime.utcnow()
|
||||
)
|
||||
|
||||
async def _load_market_history(self):
|
||||
"""Load historical market data"""
|
||||
# In a real implementation, this would load from database
|
||||
pass
|
||||
|
||||
async def _load_agent_preferences(self):
|
||||
"""Load agent preferences from storage"""
|
||||
# In a real implementation, this would load from database
|
||||
pass
|
||||
|
||||
async def _monitor_market_conditions(self):
|
||||
"""Monitor market conditions continuously"""
|
||||
while True:
|
||||
try:
|
||||
# Get current conditions
|
||||
conditions = await self._get_current_market_conditions()
|
||||
|
||||
# Add to history
|
||||
self.market_history.append(conditions)
|
||||
|
||||
# Keep only recent history
|
||||
if len(self.market_history) > self.price_history_days * 24:
|
||||
self.market_history = self.market_history[-(self.price_history_days * 24):]
|
||||
|
||||
# Wait for next update
|
||||
await asyncio.sleep(300) # Update every 5 minutes
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error monitoring market conditions: {e}")
|
||||
await asyncio.sleep(60) # Wait 1 minute on error
|
||||
|
||||
async def _calculate_price_trend(self) -> str:
|
||||
"""Calculate price trend"""
|
||||
if len(self.market_history) < 2:
|
||||
return "insufficient_data"
|
||||
|
||||
recent_prices = [c.average_hourly_price for c in self.market_history[-24:]] # Last 24 hours
|
||||
older_prices = [c.average_hourly_price for c in self.market_history[-48:-24]] # Previous 24 hours
|
||||
|
||||
if not older_prices:
|
||||
return "insufficient_data"
|
||||
|
||||
recent_avg = sum(recent_prices) / len(recent_prices)
|
||||
older_avg = sum(older_prices) / len(older_prices)
|
||||
|
||||
change = (recent_avg - older_avg) / older_avg
|
||||
|
||||
if change > 0.05:
|
||||
return "increasing"
|
||||
elif change < -0.05:
|
||||
return "decreasing"
|
||||
else:
|
||||
return "stable"
|
||||
|
||||
async def _calculate_demand_trend(self) -> str:
|
||||
"""Calculate demand trend"""
|
||||
if len(self.market_history) < 2:
|
||||
return "insufficient_data"
|
||||
|
||||
recent_demand = [c.demand_level for c in self.market_history[-24:]]
|
||||
older_demand = [c.demand_level for c in self.market_history[-48:-24]]
|
||||
|
||||
if not older_demand:
|
||||
return "insufficient_data"
|
||||
|
||||
recent_avg = sum(recent_demand) / len(recent_demand)
|
||||
older_avg = sum(older_demand) / len(older_demand)
|
||||
|
||||
change = recent_avg - older_avg
|
||||
|
||||
if change > 0.1:
|
||||
return "increasing"
|
||||
elif change < -0.1:
|
||||
return "decreasing"
|
||||
else:
|
||||
return "stable"
|
||||
|
||||
async def _calculate_volatility_trend(self) -> str:
|
||||
"""Calculate volatility trend"""
|
||||
if len(self.market_history) < 2:
|
||||
return "insufficient_data"
|
||||
|
||||
recent_vol = [c.price_volatility for c in self.market_history[-24:]]
|
||||
older_vol = [c.price_volatility for c in self.market_history[-48:-24]]
|
||||
|
||||
if not older_vol:
|
||||
return "insufficient_data"
|
||||
|
||||
recent_avg = sum(recent_vol) / len(recent_vol)
|
||||
older_avg = sum(older_vol) / len(older_vol)
|
||||
|
||||
change = recent_avg - older_avg
|
||||
|
||||
if change > 0.05:
|
||||
return "increasing"
|
||||
elif change < -0.05:
|
||||
return "decreasing"
|
||||
else:
|
||||
return "stable"
|
||||
|
||||
async def _predict_market_conditions(self, hours_ahead: int) -> MarketConditions:
|
||||
"""Predict future market conditions"""
|
||||
|
||||
if len(self.market_history) < 24:
|
||||
# Return current conditions if insufficient history
|
||||
return await self._get_current_market_conditions()
|
||||
|
||||
# Simple linear prediction based on recent trends
|
||||
recent_conditions = self.market_history[-24:]
|
||||
|
||||
# Calculate trends
|
||||
price_trend = await self._calculate_price_trend()
|
||||
demand_trend = await self._calculate_demand_trend()
|
||||
|
||||
# Predict based on trends
|
||||
current = await self._get_current_market_conditions()
|
||||
|
||||
predicted = MarketConditions(
|
||||
current_gas_price=current.current_gas_price,
|
||||
gpu_utilization_rate=current.gpu_utilization_rate,
|
||||
average_hourly_price=current.average_hourly_price,
|
||||
price_volatility=current.price_volatility,
|
||||
demand_level=current.demand_level,
|
||||
supply_level=current.supply_level,
|
||||
timestamp=datetime.utcnow() + timedelta(hours=hours_ahead)
|
||||
)
|
||||
|
||||
# Apply trend adjustments
|
||||
if price_trend == "increasing":
|
||||
predicted.average_hourly_price *= 1.05
|
||||
elif price_trend == "decreasing":
|
||||
predicted.average_hourly_price *= 0.95
|
||||
|
||||
if demand_trend == "increasing":
|
||||
predicted.demand_level = min(1.0, predicted.demand_level + 0.1)
|
||||
elif demand_trend == "decreasing":
|
||||
predicted.demand_level = max(0.0, predicted.demand_level - 0.1)
|
||||
|
||||
return predicted
|
||||
|
||||
async def _generate_market_recommendations(self, market_conditions: MarketConditions) -> List[str]:
|
||||
"""Generate market recommendations"""
|
||||
|
||||
recommendations = []
|
||||
|
||||
if market_conditions.demand_level > 0.8:
|
||||
recommendations.append("High demand detected - consider urgent bidding strategy")
|
||||
|
||||
if market_conditions.price_volatility > self.volatility_threshold:
|
||||
recommendations.append("High volatility - consider conservative bidding")
|
||||
|
||||
if market_conditions.gpu_utilization_rate > 0.9:
|
||||
recommendations.append("GPU utilization very high - expect longer wait times")
|
||||
|
||||
if market_conditions.supply_level < 0.3:
|
||||
recommendations.append("Low supply - expect higher prices")
|
||||
|
||||
if market_conditions.average_hourly_price < 0.03:
|
||||
recommendations.append("Low prices - good opportunity for cost optimization")
|
||||
|
||||
return recommendations
|
||||
|
||||
async def _get_market_price_for_tier(self, gpu_tier: GPU_Tier) -> float:
|
||||
"""Get market price for specific GPU tier"""
|
||||
|
||||
# In a real implementation, this would fetch from market data
|
||||
tier_prices = {
|
||||
GPU_Tier.CPU_ONLY: 0.01,
|
||||
GPU_Tier.LOW_END_GPU: 0.03,
|
||||
GPU_Tier.MID_RANGE_GPU: 0.05,
|
||||
GPU_Tier.HIGH_END_GPU: 0.09,
|
||||
GPU_Tier.PREMIUM_GPU: 0.15
|
||||
}
|
||||
|
||||
return tier_prices.get(gpu_tier, 0.05)
|
||||
708
apps/coordinator-api/src/app/services/cross_chain_reputation.py
Normal file
708
apps/coordinator-api/src/app/services/cross_chain_reputation.py
Normal file
@@ -0,0 +1,708 @@
|
||||
"""
|
||||
Cross-Chain Reputation Service for Advanced Agent Features
|
||||
Implements portable reputation scores across multiple blockchain networks
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
import json
|
||||
from dataclasses import dataclass, asdict, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReputationTier(str, Enum):
|
||||
"""Reputation tiers for agents"""
|
||||
BRONZE = "bronze"
|
||||
SILVER = "silver"
|
||||
GOLD = "gold"
|
||||
PLATINUM = "platinum"
|
||||
DIAMOND = "diamond"
|
||||
|
||||
|
||||
class ReputationEvent(str, Enum):
|
||||
"""Types of reputation events"""
|
||||
TASK_SUCCESS = "task_success"
|
||||
TASK_FAILURE = "task_failure"
|
||||
TASK_TIMEOUT = "task_timeout"
|
||||
TASK_CANCELLED = "task_cancelled"
|
||||
POSITIVE_FEEDBACK = "positive_feedback"
|
||||
NEGATIVE_FEEDBACK = "negative_feedback"
|
||||
REPUTATION_STAKE = "reputation_stake"
|
||||
REPUTATION_DELEGATE = "reputation_delegate"
|
||||
CROSS_CHAIN_SYNC = "cross_chain_sync"
|
||||
|
||||
|
||||
class ChainNetwork(str, Enum):
|
||||
"""Supported blockchain networks"""
|
||||
ETHEREUM = "ethereum"
|
||||
POLYGON = "polygon"
|
||||
ARBITRUM = "arbitrum"
|
||||
OPTIMISM = "optimism"
|
||||
BSC = "bsc"
|
||||
AVALANCHE = "avalanche"
|
||||
FANTOM = "fantom"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReputationScore:
|
||||
"""Reputation score data"""
|
||||
agent_id: str
|
||||
chain_id: int
|
||||
score: int # 0-10000
|
||||
task_count: int
|
||||
success_count: int
|
||||
failure_count: int
|
||||
last_updated: datetime
|
||||
sync_timestamp: datetime
|
||||
is_active: bool
|
||||
tier: ReputationTier = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
self.tier = self.calculate_tier()
|
||||
|
||||
def calculate_tier(self) -> ReputationTier:
|
||||
"""Calculate reputation tier based on score"""
|
||||
if self.score >= 9000:
|
||||
return ReputationTier.DIAMOND
|
||||
elif self.score >= 7500:
|
||||
return ReputationTier.PLATINUM
|
||||
elif self.score >= 6000:
|
||||
return ReputationTier.GOLD
|
||||
elif self.score >= 4500:
|
||||
return ReputationTier.SILVER
|
||||
else:
|
||||
return ReputationTier.BRONZE
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReputationStake:
|
||||
"""Reputation stake information"""
|
||||
agent_id: str
|
||||
amount: int
|
||||
lock_period: int # seconds
|
||||
start_time: datetime
|
||||
end_time: datetime
|
||||
is_active: bool
|
||||
reward_rate: float # APY
|
||||
multiplier: float # Reputation multiplier
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReputationDelegation:
|
||||
"""Reputation delegation information"""
|
||||
delegator: str
|
||||
delegate: str
|
||||
amount: int
|
||||
start_time: datetime
|
||||
is_active: bool
|
||||
fee_rate: float # Fee rate for delegation
|
||||
|
||||
|
||||
@dataclass
|
||||
class CrossChainSync:
|
||||
"""Cross-chain synchronization data"""
|
||||
agent_id: str
|
||||
source_chain: int
|
||||
target_chain: int
|
||||
reputation_score: int
|
||||
sync_timestamp: datetime
|
||||
verification_hash: str
|
||||
is_verified: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReputationAnalytics:
|
||||
"""Reputation analytics data"""
|
||||
agent_id: str
|
||||
total_score: int
|
||||
effective_score: int
|
||||
success_rate: float
|
||||
stake_amount: int
|
||||
delegation_amount: int
|
||||
chain_count: int
|
||||
tier: ReputationTier
|
||||
reputation_age: int # days
|
||||
last_activity: datetime
|
||||
|
||||
|
||||
class CrossChainReputationService:
|
||||
"""Service for managing cross-chain reputation systems"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
self.config = config
|
||||
self.reputation_data: Dict[str, ReputationScore] = {}
|
||||
self.chain_reputations: Dict[str, Dict[int, ReputationScore]] = {}
|
||||
self.reputation_stakes: Dict[str, List[ReputationStake]] = {}
|
||||
self.reputation_delegations: Dict[str, List[ReputationDelegation]] = {}
|
||||
self.cross_chain_syncs: List[CrossChainSync] = []
|
||||
|
||||
# Configuration
|
||||
self.base_score = 1000
|
||||
self.success_bonus = 100
|
||||
self.failure_penalty = 50
|
||||
self.min_stake_amount = 100 * 10**18 # 100 AITBC
|
||||
self.max_delegation_ratio = 1.0 # 100%
|
||||
self.sync_cooldown = 3600 # 1 hour
|
||||
self.tier_thresholds = {
|
||||
ReputationTier.BRONZE: 4500,
|
||||
ReputationTier.SILVER: 6000,
|
||||
ReputationTier.GOLD: 7500,
|
||||
ReputationTier.PLATINUM: 9000,
|
||||
ReputationTier.DIAMOND: 9500
|
||||
}
|
||||
|
||||
# Chain configuration
|
||||
self.supported_chains = {
|
||||
ChainNetwork.ETHEREUM: 1,
|
||||
ChainNetwork.POLYGON: 137,
|
||||
ChainNetwork.ARBITRUM: 42161,
|
||||
ChainNetwork.OPTIMISM: 10,
|
||||
ChainNetwork.BSC: 56,
|
||||
ChainNetwork.AVALANCHE: 43114,
|
||||
ChainNetwork.FANTOM: 250
|
||||
}
|
||||
|
||||
# Stake rewards
|
||||
self.stake_rewards = {
|
||||
ReputationTier.BRONZE: 0.05, # 5% APY
|
||||
ReputationTier.SILVER: 0.08, # 8% APY
|
||||
ReputationTier.GOLD: 0.12, # 12% APY
|
||||
ReputationTier.PLATINUM: 0.18, # 18% APY
|
||||
ReputationTier.DIAMOND: 0.25 # 25% APY
|
||||
}
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the cross-chain reputation service"""
|
||||
logger.info("Initializing Cross-Chain Reputation Service")
|
||||
|
||||
# Load existing reputation data
|
||||
await self._load_reputation_data()
|
||||
|
||||
# Start background tasks
|
||||
asyncio.create_task(self._monitor_reputation_sync())
|
||||
asyncio.create_task(self._process_stake_rewards())
|
||||
asyncio.create_task(self._cleanup_expired_stakes())
|
||||
|
||||
logger.info("Cross-Chain Reputation Service initialized")
|
||||
|
||||
async def initialize_agent_reputation(
|
||||
self,
|
||||
agent_id: str,
|
||||
initial_score: int = 1000,
|
||||
chain_id: Optional[int] = None
|
||||
) -> ReputationScore:
|
||||
"""Initialize reputation for a new agent"""
|
||||
|
||||
try:
|
||||
if chain_id is None:
|
||||
chain_id = self.supported_chains[ChainNetwork.ETHEREUM]
|
||||
|
||||
logger.info(f"Initializing reputation for agent {agent_id} on chain {chain_id}")
|
||||
|
||||
# Create reputation score
|
||||
reputation = ReputationScore(
|
||||
agent_id=agent_id,
|
||||
chain_id=chain_id,
|
||||
score=initial_score,
|
||||
task_count=0,
|
||||
success_count=0,
|
||||
failure_count=0,
|
||||
last_updated=datetime.utcnow(),
|
||||
sync_timestamp=datetime.utcnow(),
|
||||
is_active=True
|
||||
)
|
||||
|
||||
# Store reputation data
|
||||
self.reputation_data[agent_id] = reputation
|
||||
|
||||
# Initialize chain reputations
|
||||
if agent_id not in self.chain_reputations:
|
||||
self.chain_reputations[agent_id] = {}
|
||||
self.chain_reputations[agent_id][chain_id] = reputation
|
||||
|
||||
logger.info(f"Reputation initialized for agent {agent_id}: {initial_score}")
|
||||
return reputation
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize reputation for agent {agent_id}: {e}")
|
||||
raise
|
||||
|
||||
async def update_reputation(
|
||||
self,
|
||||
agent_id: str,
|
||||
event_type: ReputationEvent,
|
||||
weight: int = 1,
|
||||
chain_id: Optional[int] = None
|
||||
) -> ReputationScore:
|
||||
"""Update agent reputation based on event"""
|
||||
|
||||
try:
|
||||
if agent_id not in self.reputation_data:
|
||||
await self.initialize_agent_reputation(agent_id)
|
||||
|
||||
reputation = self.reputation_data[agent_id]
|
||||
old_score = reputation.score
|
||||
|
||||
# Calculate score change
|
||||
score_change = await self._calculate_score_change(event_type, weight)
|
||||
|
||||
# Update reputation
|
||||
if event_type in [ReputationEvent.TASK_SUCCESS, ReputationEvent.POSITIVE_FEEDBACK]:
|
||||
reputation.score = min(10000, reputation.score + score_change)
|
||||
reputation.success_count += 1
|
||||
elif event_type in [ReputationEvent.TASK_FAILURE, ReputationEvent.NEGATIVE_FEEDBACK]:
|
||||
reputation.score = max(0, reputation.score - score_change)
|
||||
reputation.failure_count += 1
|
||||
elif event_type == ReputationEvent.TASK_TIMEOUT:
|
||||
reputation.score = max(0, reputation.score - score_change // 2)
|
||||
reputation.failure_count += 1
|
||||
|
||||
reputation.task_count += 1
|
||||
reputation.last_updated = datetime.utcnow()
|
||||
reputation.tier = reputation.calculate_tier()
|
||||
|
||||
# Update chain reputation
|
||||
if chain_id:
|
||||
if chain_id not in self.chain_reputations[agent_id]:
|
||||
self.chain_reputations[agent_id][chain_id] = reputation
|
||||
else:
|
||||
self.chain_reputations[agent_id][chain_id] = reputation
|
||||
|
||||
logger.info(f"Updated reputation for agent {agent_id}: {old_score} -> {reputation.score}")
|
||||
return reputation
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update reputation for agent {agent_id}: {e}")
|
||||
raise
|
||||
|
||||
async def sync_reputation_cross_chain(
|
||||
self,
|
||||
agent_id: str,
|
||||
target_chain: int,
|
||||
signature: str
|
||||
) -> bool:
|
||||
"""Synchronize reputation across chains"""
|
||||
|
||||
try:
|
||||
if agent_id not in self.reputation_data:
|
||||
raise ValueError(f"Agent {agent_id} not found")
|
||||
|
||||
reputation = self.reputation_data[agent_id]
|
||||
|
||||
# Check sync cooldown
|
||||
time_since_sync = (datetime.utcnow() - reputation.sync_timestamp).total_seconds()
|
||||
if time_since_sync < self.sync_cooldown:
|
||||
logger.warning(f"Sync cooldown not met for agent {agent_id}")
|
||||
return False
|
||||
|
||||
# Verify signature (simplified)
|
||||
verification_hash = await self._verify_cross_chain_signature(agent_id, target_chain, signature)
|
||||
|
||||
# Create sync record
|
||||
sync = CrossChainSync(
|
||||
agent_id=agent_id,
|
||||
source_chain=reputation.chain_id,
|
||||
target_chain=target_chain,
|
||||
reputation_score=reputation.score,
|
||||
sync_timestamp=datetime.utcnow(),
|
||||
verification_hash=verification_hash,
|
||||
is_verified=True
|
||||
)
|
||||
|
||||
self.cross_chain_syncs.append(sync)
|
||||
|
||||
# Update target chain reputation
|
||||
if target_chain not in self.chain_reputations[agent_id]:
|
||||
self.chain_reputations[agent_id][target_chain] = ReputationScore(
|
||||
agent_id=agent_id,
|
||||
chain_id=target_chain,
|
||||
score=reputation.score,
|
||||
task_count=reputation.task_count,
|
||||
success_count=reputation.success_count,
|
||||
failure_count=reputation.failure_count,
|
||||
last_updated=reputation.last_updated,
|
||||
sync_timestamp=datetime.utcnow(),
|
||||
is_active=True
|
||||
)
|
||||
else:
|
||||
target_reputation = self.chain_reputations[agent_id][target_chain]
|
||||
target_reputation.score = reputation.score
|
||||
target_reputation.sync_timestamp = datetime.utcnow()
|
||||
|
||||
# Update sync timestamp
|
||||
reputation.sync_timestamp = datetime.utcnow()
|
||||
|
||||
logger.info(f"Synced reputation for agent {agent_id} to chain {target_chain}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to sync reputation for agent {agent_id}: {e}")
|
||||
raise
|
||||
|
||||
async def stake_reputation(
|
||||
self,
|
||||
agent_id: str,
|
||||
amount: int,
|
||||
lock_period: int
|
||||
) -> ReputationStake:
|
||||
"""Stake reputation tokens"""
|
||||
|
||||
try:
|
||||
if agent_id not in self.reputation_data:
|
||||
raise ValueError(f"Agent {agent_id} not found")
|
||||
|
||||
if amount < self.min_stake_amount:
|
||||
raise ValueError(f"Amount below minimum: {self.min_stake_amount}")
|
||||
|
||||
reputation = self.reputation_data[agent_id]
|
||||
|
||||
# Calculate reward rate based on tier
|
||||
reward_rate = self.stake_rewards[reputation.tier]
|
||||
|
||||
# Create stake
|
||||
stake = ReputationStake(
|
||||
agent_id=agent_id,
|
||||
amount=amount,
|
||||
lock_period=lock_period,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow() + timedelta(seconds=lock_period),
|
||||
is_active=True,
|
||||
reward_rate=reward_rate,
|
||||
multiplier=1.0 + (reputation.score / 10000) * 0.5 # Up to 50% bonus
|
||||
)
|
||||
|
||||
# Store stake
|
||||
if agent_id not in self.reputation_stakes:
|
||||
self.reputation_stakes[agent_id] = []
|
||||
self.reputation_stakes[agent_id].append(stake)
|
||||
|
||||
logger.info(f"Staked {amount} reputation for agent {agent_id}")
|
||||
return stake
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to stake reputation for agent {agent_id}: {e}")
|
||||
raise
|
||||
|
||||
async def delegate_reputation(
|
||||
self,
|
||||
delegator: str,
|
||||
delegate: str,
|
||||
amount: int
|
||||
) -> ReputationDelegation:
|
||||
"""Delegate reputation to another agent"""
|
||||
|
||||
try:
|
||||
if delegator not in self.reputation_data:
|
||||
raise ValueError(f"Delegator {delegator} not found")
|
||||
|
||||
if delegate not in self.reputation_data:
|
||||
raise ValueError(f"Delegate {delegate} not found")
|
||||
|
||||
delegator_reputation = self.reputation_data[delegator]
|
||||
|
||||
# Check delegation limits
|
||||
total_delegated = await self._get_total_delegated(delegator)
|
||||
max_delegation = int(delegator_reputation.score * self.max_delegation_ratio)
|
||||
|
||||
if total_delegated + amount > max_delegation:
|
||||
raise ValueError(f"Exceeds delegation limit: {max_delegation}")
|
||||
|
||||
# Calculate fee rate based on delegate tier
|
||||
delegate_reputation = self.reputation_data[delegate]
|
||||
fee_rate = 0.02 + (1.0 - delegate_reputation.score / 10000) * 0.08 # 2-10% based on reputation
|
||||
|
||||
# Create delegation
|
||||
delegation = ReputationDelegation(
|
||||
delegator=delegator,
|
||||
delegate=delegate,
|
||||
amount=amount,
|
||||
start_time=datetime.utcnow(),
|
||||
is_active=True,
|
||||
fee_rate=fee_rate
|
||||
)
|
||||
|
||||
# Store delegation
|
||||
if delegator not in self.reputation_delegations:
|
||||
self.reputation_delegations[delegator] = []
|
||||
self.reputation_delegations[delegator].append(delegation)
|
||||
|
||||
logger.info(f"Delegated {amount} reputation from {delegator} to {delegate}")
|
||||
return delegation
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delegate reputation: {e}")
|
||||
raise
|
||||
|
||||
async def get_reputation_score(
|
||||
self,
|
||||
agent_id: str,
|
||||
chain_id: Optional[int] = None
|
||||
) -> int:
|
||||
"""Get reputation score for agent on specific chain"""
|
||||
|
||||
if agent_id not in self.reputation_data:
|
||||
return 0
|
||||
|
||||
if chain_id is None or chain_id == self.supported_chains[ChainNetwork.ETHEREUM]:
|
||||
return self.reputation_data[agent_id].score
|
||||
|
||||
if agent_id in self.chain_reputations and chain_id in self.chain_reputations[agent_id]:
|
||||
return self.chain_reputations[agent_id][chain_id].score
|
||||
|
||||
return 0
|
||||
|
||||
async def get_effective_reputation(self, agent_id: str) -> int:
|
||||
"""Get effective reputation score including delegations"""
|
||||
|
||||
if agent_id not in self.reputation_data:
|
||||
return 0
|
||||
|
||||
base_score = self.reputation_data[agent_id].score
|
||||
|
||||
# Add delegated from others
|
||||
delegated_from = await self._get_delegated_from(agent_id)
|
||||
|
||||
# Subtract delegated to others
|
||||
delegated_to = await self._get_total_delegated(agent_id)
|
||||
|
||||
return base_score + delegated_from - delegated_to
|
||||
|
||||
async def get_reputation_analytics(self, agent_id: str) -> ReputationAnalytics:
|
||||
"""Get comprehensive reputation analytics"""
|
||||
|
||||
if agent_id not in self.reputation_data:
|
||||
raise ValueError(f"Agent {agent_id} not found")
|
||||
|
||||
reputation = self.reputation_data[agent_id]
|
||||
|
||||
# Calculate metrics
|
||||
success_rate = (reputation.success_count / reputation.task_count * 100) if reputation.task_count > 0 else 0
|
||||
stake_amount = sum(stake.amount for stake in self.reputation_stakes.get(agent_id, []) if stake.is_active)
|
||||
delegation_amount = sum(delegation.amount for delegation in self.reputation_delegations.get(agent_id, []) if delegation.is_active)
|
||||
chain_count = len(self.chain_reputations.get(agent_id, {}))
|
||||
reputation_age = (datetime.utcnow() - reputation.last_updated).days
|
||||
|
||||
return ReputationAnalytics(
|
||||
agent_id=agent_id,
|
||||
total_score=reputation.score,
|
||||
effective_score=await self.get_effective_reputation(agent_id),
|
||||
success_rate=success_rate,
|
||||
stake_amount=stake_amount,
|
||||
delegation_amount=delegation_amount,
|
||||
chain_count=chain_count,
|
||||
tier=reputation.tier,
|
||||
reputation_age=reputation_age,
|
||||
last_activity=reputation.last_updated
|
||||
)
|
||||
|
||||
async def get_chain_reputations(self, agent_id: str) -> List[ReputationScore]:
|
||||
"""Get all chain reputations for an agent"""
|
||||
|
||||
if agent_id not in self.chain_reputations:
|
||||
return []
|
||||
|
||||
return list(self.chain_reputations[agent_id].values())
|
||||
|
||||
async def get_top_agents(self, limit: int = 100, chain_id: Optional[int] = None) -> List[ReputationAnalytics]:
|
||||
"""Get top agents by reputation score"""
|
||||
|
||||
analytics = []
|
||||
for agent_id in self.reputation_data:
|
||||
try:
|
||||
agent_analytics = await self.get_reputation_analytics(agent_id)
|
||||
if chain_id is None or agent_id in self.chain_reputations and chain_id in self.chain_reputations[agent_id]:
|
||||
analytics.append(agent_analytics)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting analytics for agent {agent_id}: {e}")
|
||||
continue
|
||||
|
||||
# Sort by effective score
|
||||
analytics.sort(key=lambda x: x.effective_score, reverse=True)
|
||||
|
||||
return analytics[:limit]
|
||||
|
||||
async def get_reputation_tier_distribution(self) -> Dict[str, int]:
|
||||
"""Get distribution of agents across reputation tiers"""
|
||||
|
||||
distribution = {tier.value: 0 for tier in ReputationTier}
|
||||
|
||||
for reputation in self.reputation_data.values():
|
||||
distribution[reputation.tier.value] += 1
|
||||
|
||||
return distribution
|
||||
|
||||
async def _calculate_score_change(self, event_type: ReputationEvent, weight: int) -> int:
|
||||
"""Calculate score change based on event type and weight"""
|
||||
|
||||
base_changes = {
|
||||
ReputationEvent.TASK_SUCCESS: self.success_bonus,
|
||||
ReputationEvent.TASK_FAILURE: self.failure_penalty,
|
||||
ReputationEvent.POSITIVE_FEEDBACK: self.success_bonus // 2,
|
||||
ReputationEvent.NEGATIVE_FEEDBACK: self.failure_penalty // 2,
|
||||
ReputationEvent.TASK_TIMEOUT: self.failure_penalty // 2,
|
||||
ReputationEvent.TASK_CANCELLED: self.failure_penalty // 4,
|
||||
ReputationEvent.REPUTATION_STAKE: 0,
|
||||
ReputationEvent.REPUTATION_DELEGATE: 0,
|
||||
ReputationEvent.CROSS_CHAIN_SYNC: 0
|
||||
}
|
||||
|
||||
base_change = base_changes.get(event_type, 0)
|
||||
return base_change * weight
|
||||
|
||||
async def _verify_cross_chain_signature(self, agent_id: str, chain_id: int, signature: str) -> str:
|
||||
"""Verify cross-chain signature (simplified)"""
|
||||
# In production, implement proper cross-chain signature verification
|
||||
import hashlib
|
||||
hash_input = f"{agent_id}:{chain_id}:{datetime.utcnow().isoformat()}".encode()
|
||||
return hashlib.sha256(hash_input).hexdigest()
|
||||
|
||||
async def _get_total_delegated(self, agent_id: str) -> int:
|
||||
"""Get total amount delegated by agent"""
|
||||
|
||||
total = 0
|
||||
for delegation in self.reputation_delegations.get(agent_id, []):
|
||||
if delegation.is_active:
|
||||
total += delegation.amount
|
||||
|
||||
return total
|
||||
|
||||
async def _get_delegated_from(self, agent_id: str) -> int:
|
||||
"""Get total amount delegated to agent"""
|
||||
|
||||
total = 0
|
||||
for delegator_id, delegations in self.reputation_delegations.items():
|
||||
for delegation in delegations:
|
||||
if delegation.delegate == agent_id and delegation.is_active:
|
||||
total += delegation.amount
|
||||
|
||||
return total
|
||||
|
||||
async def _load_reputation_data(self):
|
||||
"""Load existing reputation data"""
|
||||
# In production, load from database
|
||||
pass
|
||||
|
||||
async def _monitor_reputation_sync(self):
|
||||
"""Monitor and process reputation sync requests"""
|
||||
while True:
|
||||
try:
|
||||
# Process pending sync requests
|
||||
await self._process_pending_syncs()
|
||||
await asyncio.sleep(60) # Check every minute
|
||||
except Exception as e:
|
||||
logger.error(f"Error in reputation sync monitoring: {e}")
|
||||
await asyncio.sleep(60)
|
||||
|
||||
async def _process_pending_syncs(self):
|
||||
"""Process pending cross-chain sync requests"""
|
||||
# In production, implement pending sync processing
|
||||
pass
|
||||
|
||||
async def _process_stake_rewards(self):
|
||||
"""Process stake rewards"""
|
||||
while True:
|
||||
try:
|
||||
# Calculate and distribute stake rewards
|
||||
await self._distribute_stake_rewards()
|
||||
await asyncio.sleep(3600) # Process every hour
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stake reward processing: {e}")
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
async def _distribute_stake_rewards(self):
|
||||
"""Distribute rewards for active stakes"""
|
||||
current_time = datetime.utcnow()
|
||||
|
||||
for agent_id, stakes in self.reputation_stakes.items():
|
||||
for stake in stakes:
|
||||
if stake.is_active and current_time >= stake.end_time:
|
||||
# Calculate reward
|
||||
reward_amount = int(stake.amount * stake.reward_rate * (stake.lock_period / 31536000)) # APY calculation
|
||||
|
||||
# Distribute reward (simplified)
|
||||
logger.info(f"Distributing {reward_amount} reward to {agent_id}")
|
||||
|
||||
# Mark stake as inactive
|
||||
stake.is_active = False
|
||||
|
||||
async def _cleanup_expired_stakes(self):
|
||||
"""Clean up expired stakes and delegations"""
|
||||
while True:
|
||||
try:
|
||||
current_time = datetime.utcnow()
|
||||
|
||||
# Clean up expired stakes
|
||||
for agent_id, stakes in self.reputation_stakes.items():
|
||||
for stake in stakes:
|
||||
if stake.is_active and current_time > stake.end_time:
|
||||
stake.is_active = False
|
||||
|
||||
# Clean up expired delegations
|
||||
for delegator_id, delegations in self.reputation_delegations.items():
|
||||
for delegation in delegations:
|
||||
if delegation.is_active and current_time > delegation.start_time + timedelta(days=30):
|
||||
delegation.is_active = False
|
||||
|
||||
await asyncio.sleep(3600) # Clean up every hour
|
||||
except Exception as e:
|
||||
logger.error(f"Error in cleanup: {e}")
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
async def get_cross_chain_sync_status(self, agent_id: str) -> List[CrossChainSync]:
|
||||
"""Get cross-chain sync status for agent"""
|
||||
|
||||
return [
|
||||
sync for sync in self.cross_chain_syncs
|
||||
if sync.agent_id == agent_id
|
||||
]
|
||||
|
||||
async def get_reputation_history(
|
||||
self,
|
||||
agent_id: str,
|
||||
days: int = 30
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get reputation history for agent"""
|
||||
|
||||
# In production, fetch from database
|
||||
return []
|
||||
|
||||
async def export_reputation_data(self, format: str = "json") -> str:
|
||||
"""Export reputation data"""
|
||||
|
||||
data = {
|
||||
"reputation_data": {k: asdict(v) for k, v in self.reputation_data.items()},
|
||||
"chain_reputations": {k: {str(k2): asdict(v2) for k2, v2 in v.items()} for k, v in self.chain_reputations.items()},
|
||||
"reputation_stakes": {k: [asdict(s) for s in v] for k, v in self.reputation_stakes.items()},
|
||||
"reputation_delegations": {k: [asdict(d) for d in v] for k, v in self.reputation_delegations.items()},
|
||||
"export_timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
if format.lower() == "json":
|
||||
return json.dumps(data, indent=2, default=str)
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: {format}")
|
||||
|
||||
async def import_reputation_data(self, data: str, format: str = "json"):
|
||||
"""Import reputation data"""
|
||||
|
||||
if format.lower() == "json":
|
||||
parsed_data = json.loads(data)
|
||||
|
||||
# Import reputation data
|
||||
for agent_id, rep_data in parsed_data.get("reputation_data", {}).items():
|
||||
self.reputation_data[agent_id] = ReputationScore(**rep_data)
|
||||
|
||||
# Import chain reputations
|
||||
for agent_id, chain_data in parsed_data.get("chain_reputations", {}).items():
|
||||
self.chain_reputations[agent_id] = {
|
||||
int(chain_id): ReputationScore(**rep_data)
|
||||
for chain_id, rep_data in chain_data.items()
|
||||
}
|
||||
|
||||
logger.info("Reputation data imported successfully")
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: {format}")
|
||||
400
apps/coordinator-api/src/app/services/ipfs_storage_service.py
Normal file
400
apps/coordinator-api/src/app/services/ipfs_storage_service.py
Normal file
@@ -0,0 +1,400 @@
|
||||
"""
|
||||
IPFS Storage Service for Decentralized AI Memory & Storage
|
||||
Handles IPFS/Filecoin integration for persistent agent memory storage
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
import json
|
||||
import hashlib
|
||||
import gzip
|
||||
import pickle
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
try:
|
||||
import ipfshttpclient
|
||||
from web3 import Web3
|
||||
except ImportError as e:
|
||||
logging.error(f"IPFS/Web3 dependencies not installed: {e}")
|
||||
raise
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPFSUploadResult:
|
||||
"""Result of IPFS upload operation"""
|
||||
cid: str
|
||||
size: int
|
||||
compressed_size: int
|
||||
upload_time: datetime
|
||||
pinned: bool = False
|
||||
filecoin_deal: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryMetadata:
|
||||
"""Metadata for stored agent memories"""
|
||||
agent_id: str
|
||||
memory_type: str
|
||||
timestamp: datetime
|
||||
version: int
|
||||
tags: List[str]
|
||||
compression_ratio: float
|
||||
integrity_hash: str
|
||||
|
||||
|
||||
class IPFSStorageService:
|
||||
"""Service for IPFS/Filecoin storage operations"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
self.config = config
|
||||
self.ipfs_client = None
|
||||
self.web3 = None
|
||||
self.cache = {} # Simple in-memory cache
|
||||
self.compression_threshold = config.get("compression_threshold", 1024)
|
||||
self.pin_threshold = config.get("pin_threshold", 100) # Pin important memories
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize IPFS client and Web3 connection"""
|
||||
try:
|
||||
# Initialize IPFS client
|
||||
ipfs_url = self.config.get("ipfs_url", "/ip4/127.0.0.1/tcp/5001")
|
||||
self.ipfs_client = ipfshttpclient.connect(ipfs_url)
|
||||
|
||||
# Test connection
|
||||
version = self.ipfs_client.version()
|
||||
logger.info(f"Connected to IPFS node: {version['Version']}")
|
||||
|
||||
# Initialize Web3 if blockchain features enabled
|
||||
if self.config.get("blockchain_enabled", False):
|
||||
web3_url = self.config.get("web3_url")
|
||||
self.web3 = Web3(Web3.HTTPProvider(web3_url))
|
||||
if self.web3.is_connected():
|
||||
logger.info("Connected to blockchain node")
|
||||
else:
|
||||
logger.warning("Failed to connect to blockchain node")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize IPFS service: {e}")
|
||||
raise
|
||||
|
||||
async def upload_memory(
|
||||
self,
|
||||
agent_id: str,
|
||||
memory_data: Any,
|
||||
memory_type: str = "experience",
|
||||
tags: Optional[List[str]] = None,
|
||||
compress: bool = True,
|
||||
pin: bool = False
|
||||
) -> IPFSUploadResult:
|
||||
"""Upload agent memory data to IPFS"""
|
||||
|
||||
start_time = datetime.utcnow()
|
||||
tags = tags or []
|
||||
|
||||
try:
|
||||
# Serialize memory data
|
||||
serialized_data = pickle.dumps(memory_data)
|
||||
original_size = len(serialized_data)
|
||||
|
||||
# Compress if enabled and above threshold
|
||||
if compress and original_size > self.compression_threshold:
|
||||
compressed_data = gzip.compress(serialized_data)
|
||||
compression_ratio = len(compressed_data) / original_size
|
||||
upload_data = compressed_data
|
||||
else:
|
||||
compressed_data = serialized_data
|
||||
compression_ratio = 1.0
|
||||
upload_data = serialized_data
|
||||
|
||||
# Calculate integrity hash
|
||||
integrity_hash = hashlib.sha256(upload_data).hexdigest()
|
||||
|
||||
# Upload to IPFS
|
||||
result = self.ipfs_client.add_bytes(upload_data)
|
||||
cid = result['Hash']
|
||||
|
||||
# Pin if requested or meets threshold
|
||||
should_pin = pin or len(tags) >= self.pin_threshold
|
||||
if should_pin:
|
||||
try:
|
||||
self.ipfs_client.pin.add(cid)
|
||||
pinned = True
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to pin CID {cid}: {e}")
|
||||
pinned = False
|
||||
else:
|
||||
pinned = False
|
||||
|
||||
# Create metadata
|
||||
metadata = MemoryMetadata(
|
||||
agent_id=agent_id,
|
||||
memory_type=memory_type,
|
||||
timestamp=start_time,
|
||||
version=1,
|
||||
tags=tags,
|
||||
compression_ratio=compression_ratio,
|
||||
integrity_hash=integrity_hash
|
||||
)
|
||||
|
||||
# Store metadata
|
||||
await self._store_metadata(cid, metadata)
|
||||
|
||||
# Cache result
|
||||
upload_result = IPFSUploadResult(
|
||||
cid=cid,
|
||||
size=original_size,
|
||||
compressed_size=len(upload_data),
|
||||
upload_time=start_time,
|
||||
pinned=pinned
|
||||
)
|
||||
|
||||
self.cache[cid] = upload_result
|
||||
|
||||
logger.info(f"Uploaded memory for agent {agent_id}: CID {cid}")
|
||||
return upload_result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to upload memory for agent {agent_id}: {e}")
|
||||
raise
|
||||
|
||||
async def retrieve_memory(self, cid: str, verify_integrity: bool = True) -> Tuple[Any, MemoryMetadata]:
|
||||
"""Retrieve memory data from IPFS"""
|
||||
|
||||
try:
|
||||
# Check cache first
|
||||
if cid in self.cache:
|
||||
logger.debug(f"Retrieved {cid} from cache")
|
||||
|
||||
# Get metadata
|
||||
metadata = await self._get_metadata(cid)
|
||||
if not metadata:
|
||||
raise ValueError(f"No metadata found for CID {cid}")
|
||||
|
||||
# Retrieve from IPFS
|
||||
retrieved_data = self.ipfs_client.cat(cid)
|
||||
|
||||
# Verify integrity if requested
|
||||
if verify_integrity:
|
||||
calculated_hash = hashlib.sha256(retrieved_data).hexdigest()
|
||||
if calculated_hash != metadata.integrity_hash:
|
||||
raise ValueError(f"Integrity check failed for CID {cid}")
|
||||
|
||||
# Decompress if needed
|
||||
if metadata.compression_ratio < 1.0:
|
||||
decompressed_data = gzip.decompress(retrieved_data)
|
||||
else:
|
||||
decompressed_data = retrieved_data
|
||||
|
||||
# Deserialize
|
||||
memory_data = pickle.loads(decompressed_data)
|
||||
|
||||
logger.info(f"Retrieved memory for agent {metadata.agent_id}: CID {cid}")
|
||||
return memory_data, metadata
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retrieve memory {cid}: {e}")
|
||||
raise
|
||||
|
||||
async def batch_upload_memories(
|
||||
self,
|
||||
agent_id: str,
|
||||
memories: List[Tuple[Any, str, List[str]]],
|
||||
batch_size: int = 10
|
||||
) -> List[IPFSUploadResult]:
|
||||
"""Upload multiple memories in batches"""
|
||||
|
||||
results = []
|
||||
|
||||
for i in range(0, len(memories), batch_size):
|
||||
batch = memories[i:i + batch_size]
|
||||
batch_results = []
|
||||
|
||||
# Upload batch concurrently
|
||||
tasks = []
|
||||
for memory_data, memory_type, tags in batch:
|
||||
task = self.upload_memory(agent_id, memory_data, memory_type, tags)
|
||||
tasks.append(task)
|
||||
|
||||
try:
|
||||
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
for result in batch_results:
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"Batch upload failed: {result}")
|
||||
else:
|
||||
results.append(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Batch upload error: {e}")
|
||||
|
||||
# Small delay between batches to avoid overwhelming IPFS
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
return results
|
||||
|
||||
async def create_filecoin_deal(self, cid: str, duration: int = 180) -> Optional[str]:
|
||||
"""Create Filecoin storage deal for CID persistence"""
|
||||
|
||||
try:
|
||||
# This would integrate with Filecoin storage providers
|
||||
# For now, return a mock deal ID
|
||||
deal_id = f"deal-{cid[:8]}-{datetime.utcnow().timestamp()}"
|
||||
|
||||
logger.info(f"Created Filecoin deal {deal_id} for CID {cid}")
|
||||
return deal_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create Filecoin deal for {cid}: {e}")
|
||||
return None
|
||||
|
||||
async def list_agent_memories(self, agent_id: str, limit: int = 100) -> List[str]:
|
||||
"""List all memory CIDs for an agent"""
|
||||
|
||||
try:
|
||||
# This would query a database or index
|
||||
# For now, return mock data
|
||||
cids = []
|
||||
|
||||
# Search through cache
|
||||
for cid, result in self.cache.items():
|
||||
# In real implementation, this would query metadata
|
||||
if agent_id in cid: # Simplified check
|
||||
cids.append(cid)
|
||||
|
||||
return cids[:limit]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list memories for agent {agent_id}: {e}")
|
||||
return []
|
||||
|
||||
async def delete_memory(self, cid: str) -> bool:
|
||||
"""Delete/unpin memory from IPFS"""
|
||||
|
||||
try:
|
||||
# Unpin the CID
|
||||
self.ipfs_client.pin.rm(cid)
|
||||
|
||||
# Remove from cache
|
||||
if cid in self.cache:
|
||||
del self.cache[cid]
|
||||
|
||||
# Remove metadata
|
||||
await self._delete_metadata(cid)
|
||||
|
||||
logger.info(f"Deleted memory: CID {cid}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete memory {cid}: {e}")
|
||||
return False
|
||||
|
||||
async def get_storage_stats(self) -> Dict[str, Any]:
|
||||
"""Get storage statistics"""
|
||||
|
||||
try:
|
||||
# Get IPFS repo stats
|
||||
stats = self.ipfs_client.repo.stat()
|
||||
|
||||
return {
|
||||
"total_objects": stats.get("numObjects", 0),
|
||||
"repo_size": stats.get("repoSize", 0),
|
||||
"storage_max": stats.get("storageMax", 0),
|
||||
"version": stats.get("version", "unknown"),
|
||||
"cached_objects": len(self.cache)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get storage stats: {e}")
|
||||
return {}
|
||||
|
||||
async def _store_metadata(self, cid: str, metadata: MemoryMetadata):
|
||||
"""Store metadata for a CID"""
|
||||
# In real implementation, this would store in a database
|
||||
# For now, store in memory
|
||||
pass
|
||||
|
||||
async def _get_metadata(self, cid: str) -> Optional[MemoryMetadata]:
|
||||
"""Get metadata for a CID"""
|
||||
# In real implementation, this would query a database
|
||||
# For now, return mock metadata
|
||||
return MemoryMetadata(
|
||||
agent_id="mock_agent",
|
||||
memory_type="experience",
|
||||
timestamp=datetime.utcnow(),
|
||||
version=1,
|
||||
tags=["mock"],
|
||||
compression_ratio=1.0,
|
||||
integrity_hash="mock_hash"
|
||||
)
|
||||
|
||||
async def _delete_metadata(self, cid: str):
|
||||
"""Delete metadata for a CID"""
|
||||
# In real implementation, this would delete from database
|
||||
pass
|
||||
|
||||
|
||||
class MemoryCompressionService:
|
||||
"""Service for memory compression and optimization"""
|
||||
|
||||
@staticmethod
|
||||
def compress_memory(data: Any) -> Tuple[bytes, float]:
|
||||
"""Compress memory data and return compressed data with ratio"""
|
||||
serialized = pickle.dumps(data)
|
||||
compressed = gzip.compress(serialized)
|
||||
ratio = len(compressed) / len(serialized)
|
||||
return compressed, ratio
|
||||
|
||||
@staticmethod
|
||||
def decompress_memory(compressed_data: bytes) -> Any:
|
||||
"""Decompress memory data"""
|
||||
decompressed = gzip.decompress(compressed_data)
|
||||
return pickle.loads(decompressed)
|
||||
|
||||
@staticmethod
|
||||
def calculate_similarity(data1: Any, data2: Any) -> float:
|
||||
"""Calculate similarity between two memory items"""
|
||||
# Simplified similarity calculation
|
||||
# In real implementation, this would use more sophisticated methods
|
||||
try:
|
||||
hash1 = hashlib.md5(pickle.dumps(data1)).hexdigest()
|
||||
hash2 = hashlib.md5(pickle.dumps(data2)).hexdigest()
|
||||
|
||||
# Simple hash comparison (not ideal for real use)
|
||||
return 1.0 if hash1 == hash2 else 0.0
|
||||
except:
|
||||
return 0.0
|
||||
|
||||
|
||||
class IPFSClusterManager:
|
||||
"""Manager for IPFS cluster operations"""
|
||||
|
||||
def __init__(self, cluster_config: Dict[str, Any]):
|
||||
self.config = cluster_config
|
||||
self.nodes = cluster_config.get("nodes", [])
|
||||
|
||||
async def replicate_to_cluster(self, cid: str) -> List[str]:
|
||||
"""Replicate CID to cluster nodes"""
|
||||
replicated_nodes = []
|
||||
|
||||
for node in self.nodes:
|
||||
try:
|
||||
# In real implementation, this would replicate to each node
|
||||
replicated_nodes.append(node)
|
||||
logger.info(f"Replicated {cid} to node {node}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to replicate {cid} to {node}: {e}")
|
||||
|
||||
return replicated_nodes
|
||||
|
||||
async def get_cluster_health(self) -> Dict[str, Any]:
|
||||
"""Get health status of IPFS cluster"""
|
||||
return {
|
||||
"total_nodes": len(self.nodes),
|
||||
"healthy_nodes": len(self.nodes), # Simplified
|
||||
"cluster_id": "mock-cluster"
|
||||
}
|
||||
510
apps/coordinator-api/src/app/services/memory_manager.py
Normal file
510
apps/coordinator-api/src/app/services/memory_manager.py
Normal file
@@ -0,0 +1,510 @@
|
||||
"""
|
||||
Memory Manager Service for Agent Memory Operations
|
||||
Handles memory lifecycle management, versioning, and optimization
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass, asdict
|
||||
from enum import Enum
|
||||
import json
|
||||
|
||||
from .ipfs_storage_service import IPFSStorageService, MemoryMetadata, IPFSUploadResult
|
||||
from ..storage import SessionDep
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryType(str, Enum):
|
||||
"""Types of agent memories"""
|
||||
EXPERIENCE = "experience"
|
||||
POLICY_WEIGHTS = "policy_weights"
|
||||
KNOWLEDGE_GRAPH = "knowledge_graph"
|
||||
TRAINING_DATA = "training_data"
|
||||
USER_FEEDBACK = "user_feedback"
|
||||
PERFORMANCE_METRICS = "performance_metrics"
|
||||
MODEL_STATE = "model_state"
|
||||
|
||||
|
||||
class MemoryPriority(str, Enum):
|
||||
"""Memory storage priorities"""
|
||||
CRITICAL = "critical" # Always pin, replicate to all nodes
|
||||
HIGH = "high" # Pin, replicate to majority
|
||||
MEDIUM = "medium" # Pin, selective replication
|
||||
LOW = "low" # No pin, archive only
|
||||
TEMPORARY = "temporary" # No pin, auto-expire
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryConfig:
|
||||
"""Configuration for memory management"""
|
||||
max_memories_per_agent: int = 1000
|
||||
batch_upload_size: int = 50
|
||||
compression_threshold: int = 1024
|
||||
auto_cleanup_days: int = 30
|
||||
version_retention: int = 10
|
||||
deduplication_enabled: bool = True
|
||||
encryption_enabled: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryRecord:
|
||||
"""Record of stored memory"""
|
||||
cid: str
|
||||
agent_id: str
|
||||
memory_type: MemoryType
|
||||
priority: MemoryPriority
|
||||
version: int
|
||||
timestamp: datetime
|
||||
size: int
|
||||
tags: List[str]
|
||||
access_count: int = 0
|
||||
last_accessed: Optional[datetime] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
parent_cid: Optional[str] = None # For versioning
|
||||
|
||||
|
||||
class MemoryManager:
|
||||
"""Manager for agent memory operations"""
|
||||
|
||||
def __init__(self, ipfs_service: IPFSStorageService, config: MemoryConfig):
|
||||
self.ipfs_service = ipfs_service
|
||||
self.config = config
|
||||
self.memory_records: Dict[str, MemoryRecord] = {} # In-memory index
|
||||
self.agent_memories: Dict[str, List[str]] = {} # agent_id -> [cids]
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize memory manager"""
|
||||
logger.info("Initializing Memory Manager")
|
||||
|
||||
# Load existing memory records from database
|
||||
await self._load_memory_records()
|
||||
|
||||
# Start cleanup task
|
||||
asyncio.create_task(self._cleanup_expired_memories())
|
||||
|
||||
logger.info("Memory Manager initialized")
|
||||
|
||||
async def store_memory(
|
||||
self,
|
||||
agent_id: str,
|
||||
memory_data: Any,
|
||||
memory_type: MemoryType,
|
||||
priority: MemoryPriority = MemoryPriority.MEDIUM,
|
||||
tags: Optional[List[str]] = None,
|
||||
version: Optional[int] = None,
|
||||
parent_cid: Optional[str] = None,
|
||||
expires_in_days: Optional[int] = None
|
||||
) -> IPFSUploadResult:
|
||||
"""Store agent memory with versioning and deduplication"""
|
||||
|
||||
async with self._lock:
|
||||
try:
|
||||
# Check for duplicates if deduplication enabled
|
||||
if self.config.deduplication_enabled:
|
||||
existing_cid = await self._find_duplicate_memory(agent_id, memory_data)
|
||||
if existing_cid:
|
||||
logger.info(f"Found duplicate memory for agent {agent_id}: {existing_cid}")
|
||||
await self._update_access_count(existing_cid)
|
||||
return await self._get_upload_result(existing_cid)
|
||||
|
||||
# Determine version
|
||||
if version is None:
|
||||
version = await self._get_next_version(agent_id, memory_type, parent_cid)
|
||||
|
||||
# Set expiration for temporary memories
|
||||
expires_at = None
|
||||
if priority == MemoryPriority.TEMPORARY:
|
||||
expires_at = datetime.utcnow() + timedelta(days=expires_in_days or 7)
|
||||
elif expires_in_days:
|
||||
expires_at = datetime.utcnow() + timedelta(days=expires_in_days)
|
||||
|
||||
# Determine pinning based on priority
|
||||
should_pin = priority in [MemoryPriority.CRITICAL, MemoryPriority.HIGH]
|
||||
|
||||
# Add priority tag
|
||||
tags = tags or []
|
||||
tags.append(f"priority:{priority.value}")
|
||||
tags.append(f"version:{version}")
|
||||
|
||||
# Upload to IPFS
|
||||
upload_result = await self.ipfs_service.upload_memory(
|
||||
agent_id=agent_id,
|
||||
memory_data=memory_data,
|
||||
memory_type=memory_type.value,
|
||||
tags=tags,
|
||||
compress=True,
|
||||
pin=should_pin
|
||||
)
|
||||
|
||||
# Create memory record
|
||||
memory_record = MemoryRecord(
|
||||
cid=upload_result.cid,
|
||||
agent_id=agent_id,
|
||||
memory_type=memory_type,
|
||||
priority=priority,
|
||||
version=version,
|
||||
timestamp=upload_result.upload_time,
|
||||
size=upload_result.size,
|
||||
tags=tags,
|
||||
parent_cid=parent_cid,
|
||||
expires_at=expires_at
|
||||
)
|
||||
|
||||
# Store record
|
||||
self.memory_records[upload_result.cid] = memory_record
|
||||
|
||||
# Update agent index
|
||||
if agent_id not in self.agent_memories:
|
||||
self.agent_memories[agent_id] = []
|
||||
self.agent_memories[agent_id].append(upload_result.cid)
|
||||
|
||||
# Limit memories per agent
|
||||
await self._enforce_memory_limit(agent_id)
|
||||
|
||||
# Save to database
|
||||
await self._save_memory_record(memory_record)
|
||||
|
||||
logger.info(f"Stored memory for agent {agent_id}: CID {upload_result.cid}")
|
||||
return upload_result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store memory for agent {agent_id}: {e}")
|
||||
raise
|
||||
|
||||
async def retrieve_memory(self, cid: str, update_access: bool = True) -> Tuple[Any, MemoryRecord]:
|
||||
"""Retrieve memory data and metadata"""
|
||||
|
||||
async with self._lock:
|
||||
try:
|
||||
# Get memory record
|
||||
memory_record = self.memory_records.get(cid)
|
||||
if not memory_record:
|
||||
raise ValueError(f"Memory record not found for CID: {cid}")
|
||||
|
||||
# Check expiration
|
||||
if memory_record.expires_at and memory_record.expires_at < datetime.utcnow():
|
||||
raise ValueError(f"Memory has expired: {cid}")
|
||||
|
||||
# Retrieve from IPFS
|
||||
memory_data, metadata = await self.ipfs_service.retrieve_memory(cid)
|
||||
|
||||
# Update access count
|
||||
if update_access:
|
||||
await self._update_access_count(cid)
|
||||
|
||||
return memory_data, memory_record
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retrieve memory {cid}: {e}")
|
||||
raise
|
||||
|
||||
async def batch_store_memories(
|
||||
self,
|
||||
agent_id: str,
|
||||
memories: List[Tuple[Any, MemoryType, MemoryPriority, List[str]]],
|
||||
batch_size: Optional[int] = None
|
||||
) -> List[IPFSUploadResult]:
|
||||
"""Store multiple memories in batches"""
|
||||
|
||||
batch_size = batch_size or self.config.batch_upload_size
|
||||
results = []
|
||||
|
||||
for i in range(0, len(memories), batch_size):
|
||||
batch = memories[i:i + batch_size]
|
||||
|
||||
# Process batch
|
||||
batch_tasks = []
|
||||
for memory_data, memory_type, priority, tags in batch:
|
||||
task = self.store_memory(
|
||||
agent_id=agent_id,
|
||||
memory_data=memory_data,
|
||||
memory_type=memory_type,
|
||||
priority=priority,
|
||||
tags=tags
|
||||
)
|
||||
batch_tasks.append(task)
|
||||
|
||||
try:
|
||||
batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
|
||||
|
||||
for result in batch_results:
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"Batch store failed: {result}")
|
||||
else:
|
||||
results.append(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Batch store error: {e}")
|
||||
|
||||
return results
|
||||
|
||||
async def list_agent_memories(
|
||||
self,
|
||||
agent_id: str,
|
||||
memory_type: Optional[MemoryType] = None,
|
||||
limit: int = 100,
|
||||
sort_by: str = "timestamp",
|
||||
ascending: bool = False
|
||||
) -> List[MemoryRecord]:
|
||||
"""List memories for an agent with filtering and sorting"""
|
||||
|
||||
async with self._lock:
|
||||
try:
|
||||
agent_cids = self.agent_memories.get(agent_id, [])
|
||||
memories = []
|
||||
|
||||
for cid in agent_cids:
|
||||
memory_record = self.memory_records.get(cid)
|
||||
if memory_record:
|
||||
# Filter by memory type
|
||||
if memory_type and memory_record.memory_type != memory_type:
|
||||
continue
|
||||
|
||||
# Filter expired memories
|
||||
if memory_record.expires_at and memory_record.expires_at < datetime.utcnow():
|
||||
continue
|
||||
|
||||
memories.append(memory_record)
|
||||
|
||||
# Sort
|
||||
if sort_by == "timestamp":
|
||||
memories.sort(key=lambda x: x.timestamp, reverse=not ascending)
|
||||
elif sort_by == "access_count":
|
||||
memories.sort(key=lambda x: x.access_count, reverse=not ascending)
|
||||
elif sort_by == "size":
|
||||
memories.sort(key=lambda x: x.size, reverse=not ascending)
|
||||
|
||||
return memories[:limit]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list memories for agent {agent_id}: {e}")
|
||||
return []
|
||||
|
||||
async def delete_memory(self, cid: str, permanent: bool = False) -> bool:
|
||||
"""Delete memory (unpin or permanent deletion)"""
|
||||
|
||||
async with self._lock:
|
||||
try:
|
||||
memory_record = self.memory_records.get(cid)
|
||||
if not memory_record:
|
||||
return False
|
||||
|
||||
# Don't delete critical memories unless permanent
|
||||
if memory_record.priority == MemoryPriority.CRITICAL and not permanent:
|
||||
logger.warning(f"Cannot delete critical memory: {cid}")
|
||||
return False
|
||||
|
||||
# Unpin from IPFS
|
||||
if permanent:
|
||||
await self.ipfs_service.delete_memory(cid)
|
||||
|
||||
# Remove from records
|
||||
del self.memory_records[cid]
|
||||
|
||||
# Update agent index
|
||||
if memory_record.agent_id in self.agent_memories:
|
||||
self.agent_memories[memory_record.agent_id].remove(cid)
|
||||
|
||||
# Delete from database
|
||||
await self._delete_memory_record(cid)
|
||||
|
||||
logger.info(f"Deleted memory: {cid}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete memory {cid}: {e}")
|
||||
return False
|
||||
|
||||
async def get_memory_statistics(self, agent_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Get memory statistics"""
|
||||
|
||||
async with self._lock:
|
||||
try:
|
||||
if agent_id:
|
||||
# Statistics for specific agent
|
||||
agent_cids = self.agent_memories.get(agent_id, [])
|
||||
memories = [self.memory_records[cid] for cid in agent_cids if cid in self.memory_records]
|
||||
else:
|
||||
# Global statistics
|
||||
memories = list(self.memory_records.values())
|
||||
|
||||
# Calculate statistics
|
||||
total_memories = len(memories)
|
||||
total_size = sum(m.size for m in memories)
|
||||
|
||||
# By type
|
||||
by_type = {}
|
||||
for memory in memories:
|
||||
memory_type = memory.memory_type.value
|
||||
by_type[memory_type] = by_type.get(memory_type, 0) + 1
|
||||
|
||||
# By priority
|
||||
by_priority = {}
|
||||
for memory in memories:
|
||||
priority = memory.priority.value
|
||||
by_priority[priority] = by_priority.get(priority, 0) + 1
|
||||
|
||||
# Access statistics
|
||||
total_access = sum(m.access_count for m in memories)
|
||||
avg_access = total_access / total_memories if total_memories > 0 else 0
|
||||
|
||||
return {
|
||||
"total_memories": total_memories,
|
||||
"total_size_bytes": total_size,
|
||||
"total_size_mb": total_size / (1024 * 1024),
|
||||
"by_type": by_type,
|
||||
"by_priority": by_priority,
|
||||
"total_access_count": total_access,
|
||||
"average_access_count": avg_access,
|
||||
"agent_count": len(self.agent_memories) if not agent_id else 1
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get memory statistics: {e}")
|
||||
return {}
|
||||
|
||||
async def optimize_storage(self) -> Dict[str, Any]:
|
||||
"""Optimize storage by archiving old memories and deduplication"""
|
||||
|
||||
async with self._lock:
|
||||
try:
|
||||
optimization_results = {
|
||||
"archived": 0,
|
||||
"deduplicated": 0,
|
||||
"compressed": 0,
|
||||
"errors": []
|
||||
}
|
||||
|
||||
# Archive old low-priority memories
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=self.config.auto_cleanup_days)
|
||||
|
||||
for cid, memory_record in list(self.memory_records.items()):
|
||||
if (memory_record.priority in [MemoryPriority.LOW, MemoryPriority.TEMPORARY] and
|
||||
memory_record.timestamp < cutoff_date):
|
||||
|
||||
try:
|
||||
# Create Filecoin deal for persistence
|
||||
deal_id = await self.ipfs_service.create_filecoin_deal(cid)
|
||||
if deal_id:
|
||||
optimization_results["archived"] += 1
|
||||
except Exception as e:
|
||||
optimization_results["errors"].append(f"Archive failed for {cid}: {e}")
|
||||
|
||||
return optimization_results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Storage optimization failed: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
async def _find_duplicate_memory(self, agent_id: str, memory_data: Any) -> Optional[str]:
|
||||
"""Find duplicate memory using content hash"""
|
||||
# Simplified duplicate detection
|
||||
# In real implementation, this would use content-based hashing
|
||||
return None
|
||||
|
||||
async def _get_next_version(self, agent_id: str, memory_type: MemoryType, parent_cid: Optional[str]) -> int:
|
||||
"""Get next version number for memory"""
|
||||
|
||||
# Find existing versions of this memory type
|
||||
max_version = 0
|
||||
for cid in self.agent_memories.get(agent_id, []):
|
||||
memory_record = self.memory_records.get(cid)
|
||||
if (memory_record and memory_record.memory_type == memory_type and
|
||||
memory_record.parent_cid == parent_cid):
|
||||
max_version = max(max_version, memory_record.version)
|
||||
|
||||
return max_version + 1
|
||||
|
||||
async def _update_access_count(self, cid: str):
|
||||
"""Update access count and last accessed time"""
|
||||
memory_record = self.memory_records.get(cid)
|
||||
if memory_record:
|
||||
memory_record.access_count += 1
|
||||
memory_record.last_accessed = datetime.utcnow()
|
||||
await self._save_memory_record(memory_record)
|
||||
|
||||
async def _enforce_memory_limit(self, agent_id: str):
|
||||
"""Enforce maximum memories per agent"""
|
||||
|
||||
agent_cids = self.agent_memories.get(agent_id, [])
|
||||
if len(agent_cids) <= self.config.max_memories_per_agent:
|
||||
return
|
||||
|
||||
# Sort by priority and access count (keep important memories)
|
||||
memories = [(self.memory_records[cid], cid) for cid in agent_cids if cid in self.memory_records]
|
||||
|
||||
# Sort by priority (critical first) and access count
|
||||
priority_order = {
|
||||
MemoryPriority.CRITICAL: 0,
|
||||
MemoryPriority.HIGH: 1,
|
||||
MemoryPriority.MEDIUM: 2,
|
||||
MemoryPriority.LOW: 3,
|
||||
MemoryPriority.TEMPORARY: 4
|
||||
}
|
||||
|
||||
memories.sort(key=lambda x: (
|
||||
priority_order.get(x[0].priority, 5),
|
||||
-x[0].access_count,
|
||||
x[0].timestamp
|
||||
))
|
||||
|
||||
# Delete excess memories (keep the most important)
|
||||
excess_count = len(memories) - self.config.max_memories_per_agent
|
||||
for i in range(excess_count):
|
||||
memory_record, cid = memories[-(i + 1)] # Delete least important
|
||||
await self.delete_memory(cid, permanent=False)
|
||||
|
||||
async def _cleanup_expired_memories(self):
|
||||
"""Background task to clean up expired memories"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(3600) # Run every hour
|
||||
|
||||
current_time = datetime.utcnow()
|
||||
expired_cids = []
|
||||
|
||||
for cid, memory_record in self.memory_records.items():
|
||||
if (memory_record.expires_at and
|
||||
memory_record.expires_at < current_time and
|
||||
memory_record.priority != MemoryPriority.CRITICAL):
|
||||
expired_cids.append(cid)
|
||||
|
||||
# Delete expired memories
|
||||
for cid in expired_cids:
|
||||
await self.delete_memory(cid, permanent=True)
|
||||
|
||||
if expired_cids:
|
||||
logger.info(f"Cleaned up {len(expired_cids)} expired memories")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Memory cleanup error: {e}")
|
||||
|
||||
async def _load_memory_records(self):
|
||||
"""Load memory records from database"""
|
||||
# In real implementation, this would load from database
|
||||
pass
|
||||
|
||||
async def _save_memory_record(self, memory_record: MemoryRecord):
|
||||
"""Save memory record to database"""
|
||||
# In real implementation, this would save to database
|
||||
pass
|
||||
|
||||
async def _delete_memory_record(self, cid: str):
|
||||
"""Delete memory record from database"""
|
||||
# In real implementation, this would delete from database
|
||||
pass
|
||||
|
||||
async def _get_upload_result(self, cid: str) -> IPFSUploadResult:
|
||||
"""Get upload result for existing CID"""
|
||||
# In real implementation, this would retrieve from database
|
||||
return IPFSUploadResult(
|
||||
cid=cid,
|
||||
size=0,
|
||||
compressed_size=0,
|
||||
upload_time=datetime.utcnow()
|
||||
)
|
||||
750
apps/coordinator-api/src/app/services/task_decomposition.py
Normal file
750
apps/coordinator-api/src/app/services/task_decomposition.py
Normal file
@@ -0,0 +1,750 @@
|
||||
"""
|
||||
Task Decomposition Service for OpenClaw Autonomous Economics
|
||||
Implements intelligent task splitting and sub-task management
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, List, Any, Optional, Tuple, Set
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
import json
|
||||
from dataclasses import dataclass, asdict, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskType(str, Enum):
|
||||
"""Types of tasks"""
|
||||
TEXT_PROCESSING = "text_processing"
|
||||
IMAGE_PROCESSING = "image_processing"
|
||||
AUDIO_PROCESSING = "audio_processing"
|
||||
VIDEO_PROCESSING = "video_processing"
|
||||
DATA_ANALYSIS = "data_analysis"
|
||||
MODEL_INFERENCE = "model_inference"
|
||||
MODEL_TRAINING = "model_training"
|
||||
COMPUTE_INTENSIVE = "compute_intensive"
|
||||
IO_BOUND = "io_bound"
|
||||
MIXED_MODAL = "mixed_modal"
|
||||
|
||||
|
||||
class SubTaskStatus(str, Enum):
|
||||
"""Sub-task status"""
|
||||
PENDING = "pending"
|
||||
ASSIGNED = "assigned"
|
||||
IN_PROGRESS = "in_progress"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class DependencyType(str, Enum):
|
||||
"""Dependency types between sub-tasks"""
|
||||
SEQUENTIAL = "sequential"
|
||||
PARALLEL = "parallel"
|
||||
CONDITIONAL = "conditional"
|
||||
AGGREGATION = "aggregation"
|
||||
|
||||
|
||||
class GPU_Tier(str, Enum):
|
||||
"""GPU resource tiers"""
|
||||
CPU_ONLY = "cpu_only"
|
||||
LOW_END_GPU = "low_end_gpu"
|
||||
MID_RANGE_GPU = "mid_range_gpu"
|
||||
HIGH_END_GPU = "high_end_gpu"
|
||||
PREMIUM_GPU = "premium_gpu"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskRequirement:
|
||||
"""Requirements for a task or sub-task"""
|
||||
task_type: TaskType
|
||||
estimated_duration: float # hours
|
||||
gpu_tier: GPU_Tier
|
||||
memory_requirement: int # GB
|
||||
compute_intensity: float # 0-1
|
||||
data_size: int # MB
|
||||
priority: int # 1-10
|
||||
deadline: Optional[datetime] = None
|
||||
max_cost: Optional[float] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SubTask:
|
||||
"""Individual sub-task"""
|
||||
sub_task_id: str
|
||||
parent_task_id: str
|
||||
name: str
|
||||
description: str
|
||||
requirements: TaskRequirement
|
||||
status: SubTaskStatus = SubTaskStatus.PENDING
|
||||
assigned_agent: Optional[str] = None
|
||||
dependencies: List[str] = field(default_factory=list)
|
||||
outputs: List[str] = field(default_factory=list)
|
||||
inputs: List[str] = field(default_factory=list)
|
||||
created_at: datetime = field(default_factory=datetime.utcnow)
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
error_message: Optional[str] = None
|
||||
retry_count: int = 0
|
||||
max_retries: int = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskDecomposition:
|
||||
"""Result of task decomposition"""
|
||||
original_task_id: str
|
||||
sub_tasks: List[SubTask]
|
||||
dependency_graph: Dict[str, List[str]] # sub_task_id -> dependencies
|
||||
execution_plan: List[List[str]] # List of parallel execution stages
|
||||
estimated_total_duration: float
|
||||
estimated_total_cost: float
|
||||
confidence_score: float
|
||||
decomposition_strategy: str
|
||||
created_at: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskAggregation:
|
||||
"""Aggregation configuration for combining sub-task results"""
|
||||
aggregation_id: str
|
||||
parent_task_id: str
|
||||
aggregation_type: str # "concat", "merge", "vote", "weighted_average", etc.
|
||||
input_sub_tasks: List[str]
|
||||
output_format: str
|
||||
aggregation_function: str
|
||||
created_at: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class TaskDecompositionEngine:
|
||||
"""Engine for intelligent task decomposition and sub-task management"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
self.config = config
|
||||
self.decomposition_history: List[TaskDecomposition] = []
|
||||
self.sub_task_registry: Dict[str, SubTask] = {}
|
||||
self.aggregation_registry: Dict[str, TaskAggregation] = {}
|
||||
|
||||
# Decomposition strategies
|
||||
self.strategies = {
|
||||
"sequential": self._sequential_decomposition,
|
||||
"parallel": self._parallel_decomposition,
|
||||
"hierarchical": self._hierarchical_decomposition,
|
||||
"pipeline": self._pipeline_decomposition,
|
||||
"adaptive": self._adaptive_decomposition
|
||||
}
|
||||
|
||||
# Task type complexity mapping
|
||||
self.complexity_thresholds = {
|
||||
TaskType.TEXT_PROCESSING: 0.3,
|
||||
TaskType.IMAGE_PROCESSING: 0.5,
|
||||
TaskType.AUDIO_PROCESSING: 0.4,
|
||||
TaskType.VIDEO_PROCESSING: 0.8,
|
||||
TaskType.DATA_ANALYSIS: 0.6,
|
||||
TaskType.MODEL_INFERENCE: 0.4,
|
||||
TaskType.MODEL_TRAINING: 0.9,
|
||||
TaskType.COMPUTE_INTENSIVE: 0.8,
|
||||
TaskType.IO_BOUND: 0.2,
|
||||
TaskType.MIXED_MODAL: 0.7
|
||||
}
|
||||
|
||||
# GPU tier performance mapping
|
||||
self.gpu_performance = {
|
||||
GPU_Tier.CPU_ONLY: 1.0,
|
||||
GPU_Tier.LOW_END_GPU: 2.5,
|
||||
GPU_Tier.MID_RANGE_GPU: 5.0,
|
||||
GPU_Tier.HIGH_END_GPU: 10.0,
|
||||
GPU_Tier.PREMIUM_GPU: 20.0
|
||||
}
|
||||
|
||||
async def decompose_task(
|
||||
self,
|
||||
task_id: str,
|
||||
task_requirements: TaskRequirement,
|
||||
strategy: Optional[str] = None,
|
||||
max_subtasks: int = 10,
|
||||
min_subtask_duration: float = 0.1 # hours
|
||||
) -> TaskDecomposition:
|
||||
"""Decompose a complex task into sub-tasks"""
|
||||
|
||||
try:
|
||||
logger.info(f"Decomposing task {task_id} with strategy {strategy}")
|
||||
|
||||
# Select decomposition strategy
|
||||
if strategy is None:
|
||||
strategy = await self._select_decomposition_strategy(task_requirements)
|
||||
|
||||
# Execute decomposition
|
||||
decomposition_func = self.strategies.get(strategy, self._adaptive_decomposition)
|
||||
sub_tasks = await decomposition_func(task_id, task_requirements, max_subtasks, min_subtask_duration)
|
||||
|
||||
# Build dependency graph
|
||||
dependency_graph = await self._build_dependency_graph(sub_tasks)
|
||||
|
||||
# Create execution plan
|
||||
execution_plan = await self._create_execution_plan(dependency_graph)
|
||||
|
||||
# Estimate total duration and cost
|
||||
total_duration = await self._estimate_total_duration(sub_tasks, execution_plan)
|
||||
total_cost = await self._estimate_total_cost(sub_tasks)
|
||||
|
||||
# Calculate confidence score
|
||||
confidence_score = await self._calculate_decomposition_confidence(
|
||||
task_requirements, sub_tasks, strategy
|
||||
)
|
||||
|
||||
# Create decomposition result
|
||||
decomposition = TaskDecomposition(
|
||||
original_task_id=task_id,
|
||||
sub_tasks=sub_tasks,
|
||||
dependency_graph=dependency_graph,
|
||||
execution_plan=execution_plan,
|
||||
estimated_total_duration=total_duration,
|
||||
estimated_total_cost=total_cost,
|
||||
confidence_score=confidence_score,
|
||||
decomposition_strategy=strategy
|
||||
)
|
||||
|
||||
# Register sub-tasks
|
||||
for sub_task in sub_tasks:
|
||||
self.sub_task_registry[sub_task.sub_task_id] = sub_task
|
||||
|
||||
# Store decomposition history
|
||||
self.decomposition_history.append(decomposition)
|
||||
|
||||
logger.info(f"Task {task_id} decomposed into {len(sub_tasks)} sub-tasks")
|
||||
return decomposition
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decompose task {task_id}: {e}")
|
||||
raise
|
||||
|
||||
async def create_aggregation(
|
||||
self,
|
||||
parent_task_id: str,
|
||||
input_sub_tasks: List[str],
|
||||
aggregation_type: str,
|
||||
output_format: str
|
||||
) -> TaskAggregation:
|
||||
"""Create aggregation configuration for combining sub-task results"""
|
||||
|
||||
aggregation_id = f"agg_{parent_task_id}_{datetime.utcnow().timestamp()}"
|
||||
|
||||
aggregation = TaskAggregation(
|
||||
aggregation_id=aggregation_id,
|
||||
parent_task_id=parent_task_id,
|
||||
aggregation_type=aggregation_type,
|
||||
input_sub_tasks=input_sub_tasks,
|
||||
output_format=output_format,
|
||||
aggregation_function=await self._get_aggregation_function(aggregation_type, output_format)
|
||||
)
|
||||
|
||||
self.aggregation_registry[aggregation_id] = aggregation
|
||||
|
||||
logger.info(f"Created aggregation {aggregation_id} for task {parent_task_id}")
|
||||
return aggregation
|
||||
|
||||
async def update_sub_task_status(
|
||||
self,
|
||||
sub_task_id: str,
|
||||
status: SubTaskStatus,
|
||||
error_message: Optional[str] = None
|
||||
) -> bool:
|
||||
"""Update sub-task status"""
|
||||
|
||||
if sub_task_id not in self.sub_task_registry:
|
||||
logger.error(f"Sub-task {sub_task_id} not found")
|
||||
return False
|
||||
|
||||
sub_task = self.sub_task_registry[sub_task_id]
|
||||
old_status = sub_task.status
|
||||
sub_task.status = status
|
||||
|
||||
if error_message:
|
||||
sub_task.error_message = error_message
|
||||
|
||||
# Update timestamps
|
||||
if status == SubTaskStatus.IN_PROGRESS and old_status != SubTaskStatus.IN_PROGRESS:
|
||||
sub_task.started_at = datetime.utcnow()
|
||||
elif status == SubTaskStatus.COMPLETED:
|
||||
sub_task.completed_at = datetime.utcnow()
|
||||
elif status == SubTaskStatus.FAILED:
|
||||
sub_task.retry_count += 1
|
||||
|
||||
logger.info(f"Updated sub-task {sub_task_id} status: {old_status} -> {status}")
|
||||
return True
|
||||
|
||||
async def get_ready_sub_tasks(self, parent_task_id: Optional[str] = None) -> List[SubTask]:
|
||||
"""Get sub-tasks ready for execution"""
|
||||
|
||||
ready_tasks = []
|
||||
|
||||
for sub_task in self.sub_task_registry.values():
|
||||
if parent_task_id and sub_task.parent_task_id != parent_task_id:
|
||||
continue
|
||||
|
||||
if sub_task.status != SubTaskStatus.PENDING:
|
||||
continue
|
||||
|
||||
# Check if dependencies are satisfied
|
||||
dependencies_satisfied = True
|
||||
for dep_id in sub_task.dependencies:
|
||||
if dep_id not in self.sub_task_registry:
|
||||
dependencies_satisfied = False
|
||||
break
|
||||
if self.sub_task_registry[dep_id].status != SubTaskStatus.COMPLETED:
|
||||
dependencies_satisfied = False
|
||||
break
|
||||
|
||||
if dependencies_satisfied:
|
||||
ready_tasks.append(sub_task)
|
||||
|
||||
return ready_tasks
|
||||
|
||||
async def get_execution_status(self, parent_task_id: str) -> Dict[str, Any]:
|
||||
"""Get execution status for all sub-tasks of a parent task"""
|
||||
|
||||
sub_tasks = [st for st in self.sub_task_registry.values() if st.parent_task_id == parent_task_id]
|
||||
|
||||
if not sub_tasks:
|
||||
return {"status": "no_sub_tasks", "sub_tasks": []}
|
||||
|
||||
status_counts = {}
|
||||
for status in SubTaskStatus:
|
||||
status_counts[status.value] = 0
|
||||
|
||||
for sub_task in sub_tasks:
|
||||
status_counts[sub_task.status.value] += 1
|
||||
|
||||
# Determine overall status
|
||||
if status_counts["completed"] == len(sub_tasks):
|
||||
overall_status = "completed"
|
||||
elif status_counts["failed"] > 0:
|
||||
overall_status = "failed"
|
||||
elif status_counts["in_progress"] > 0:
|
||||
overall_status = "in_progress"
|
||||
else:
|
||||
overall_status = "pending"
|
||||
|
||||
return {
|
||||
"status": overall_status,
|
||||
"total_sub_tasks": len(sub_tasks),
|
||||
"status_counts": status_counts,
|
||||
"sub_tasks": [
|
||||
{
|
||||
"sub_task_id": st.sub_task_id,
|
||||
"name": st.name,
|
||||
"status": st.status.value,
|
||||
"assigned_agent": st.assigned_agent,
|
||||
"created_at": st.created_at.isoformat(),
|
||||
"started_at": st.started_at.isoformat() if st.started_at else None,
|
||||
"completed_at": st.completed_at.isoformat() if st.completed_at else None
|
||||
}
|
||||
for st in sub_tasks
|
||||
]
|
||||
}
|
||||
|
||||
async def retry_failed_sub_tasks(self, parent_task_id: str) -> List[str]:
|
||||
"""Retry failed sub-tasks"""
|
||||
|
||||
retried_tasks = []
|
||||
|
||||
for sub_task in self.sub_task_registry.values():
|
||||
if sub_task.parent_task_id != parent_task_id:
|
||||
continue
|
||||
|
||||
if sub_task.status == SubTaskStatus.FAILED and sub_task.retry_count < sub_task.max_retries:
|
||||
await self.update_sub_task_status(sub_task.sub_task_id, SubTaskStatus.PENDING)
|
||||
retried_tasks.append(sub_task.sub_task_id)
|
||||
logger.info(f"Retrying sub-task {sub_task.sub_task_id} (attempt {sub_task.retry_count + 1})")
|
||||
|
||||
return retried_tasks
|
||||
|
||||
async def _select_decomposition_strategy(self, task_requirements: TaskRequirement) -> str:
|
||||
"""Select optimal decomposition strategy"""
|
||||
|
||||
# Base selection on task type and complexity
|
||||
complexity = self.complexity_thresholds.get(task_requirements.task_type, 0.5)
|
||||
|
||||
# Adjust for duration and compute intensity
|
||||
if task_requirements.estimated_duration > 4.0:
|
||||
complexity += 0.2
|
||||
if task_requirements.compute_intensity > 0.8:
|
||||
complexity += 0.2
|
||||
if task_requirements.data_size > 1000: # > 1GB
|
||||
complexity += 0.1
|
||||
|
||||
# Select strategy based on complexity
|
||||
if complexity < 0.3:
|
||||
return "sequential"
|
||||
elif complexity < 0.5:
|
||||
return "parallel"
|
||||
elif complexity < 0.7:
|
||||
return "hierarchical"
|
||||
elif complexity < 0.9:
|
||||
return "pipeline"
|
||||
else:
|
||||
return "adaptive"
|
||||
|
||||
async def _sequential_decomposition(
|
||||
self,
|
||||
task_id: str,
|
||||
task_requirements: TaskRequirement,
|
||||
max_subtasks: int,
|
||||
min_duration: float
|
||||
) -> List[SubTask]:
|
||||
"""Sequential decomposition strategy"""
|
||||
|
||||
sub_tasks = []
|
||||
|
||||
# For simple tasks, create minimal decomposition
|
||||
if task_requirements.estimated_duration <= min_duration * 2:
|
||||
# Single sub-task
|
||||
sub_task = SubTask(
|
||||
sub_task_id=f"{task_id}_seq_1",
|
||||
parent_task_id=task_id,
|
||||
name="Main Task",
|
||||
description="Sequential execution of main task",
|
||||
requirements=task_requirements
|
||||
)
|
||||
sub_tasks.append(sub_task)
|
||||
else:
|
||||
# Split into sequential chunks
|
||||
num_chunks = min(int(task_requirements.estimated_duration / min_duration), max_subtasks)
|
||||
chunk_duration = task_requirements.estimated_duration / num_chunks
|
||||
|
||||
for i in range(num_chunks):
|
||||
chunk_requirements = TaskRequirement(
|
||||
task_type=task_requirements.task_type,
|
||||
estimated_duration=chunk_duration,
|
||||
gpu_tier=task_requirements.gpu_tier,
|
||||
memory_requirement=task_requirements.memory_requirement,
|
||||
compute_intensity=task_requirements.compute_intensity,
|
||||
data_size=task_requirements.data_size // num_chunks,
|
||||
priority=task_requirements.priority,
|
||||
deadline=task_requirements.deadline,
|
||||
max_cost=task_requirements.max_cost
|
||||
)
|
||||
|
||||
sub_task = SubTask(
|
||||
sub_task_id=f"{task_id}_seq_{i+1}",
|
||||
parent_task_id=task_id,
|
||||
name=f"Sequential Chunk {i+1}",
|
||||
description=f"Sequential execution chunk {i+1}",
|
||||
requirements=chunk_requirements,
|
||||
dependencies=[f"{task_id}_seq_{i}"] if i > 0 else []
|
||||
)
|
||||
sub_tasks.append(sub_task)
|
||||
|
||||
return sub_tasks
|
||||
|
||||
async def _parallel_decomposition(
|
||||
self,
|
||||
task_id: str,
|
||||
task_requirements: TaskRequirement,
|
||||
max_subtasks: int,
|
||||
min_duration: float
|
||||
) -> List[SubTask]:
|
||||
"""Parallel decomposition strategy"""
|
||||
|
||||
sub_tasks = []
|
||||
|
||||
# Determine optimal number of parallel tasks
|
||||
optimal_parallel = min(
|
||||
max(2, int(task_requirements.data_size / 100)), # Based on data size
|
||||
max(2, int(task_requirements.estimated_duration / min_duration)), # Based on duration
|
||||
max_subtasks
|
||||
)
|
||||
|
||||
# Split data and requirements
|
||||
chunk_data_size = task_requirements.data_size // optimal_parallel
|
||||
chunk_duration = task_requirements.estimated_duration / optimal_parallel
|
||||
|
||||
for i in range(optimal_parallel):
|
||||
chunk_requirements = TaskRequirement(
|
||||
task_type=task_requirements.task_type,
|
||||
estimated_duration=chunk_duration,
|
||||
gpu_tier=task_requirements.gpu_tier,
|
||||
memory_requirement=task_requirements.memory_requirement // optimal_parallel,
|
||||
compute_intensity=task_requirements.compute_intensity,
|
||||
data_size=chunk_data_size,
|
||||
priority=task_requirements.priority,
|
||||
deadline=task_requirements.deadline,
|
||||
max_cost=task_requirements.max_cost / optimal_parallel if task_requirements.max_cost else None
|
||||
)
|
||||
|
||||
sub_task = SubTask(
|
||||
sub_task_id=f"{task_id}_par_{i+1}",
|
||||
parent_task_id=task_id,
|
||||
name=f"Parallel Task {i+1}",
|
||||
description=f"Parallel execution task {i+1}",
|
||||
requirements=chunk_requirements,
|
||||
inputs=[f"input_chunk_{i}"],
|
||||
outputs=[f"output_chunk_{i}"]
|
||||
)
|
||||
sub_tasks.append(sub_task)
|
||||
|
||||
return sub_tasks
|
||||
|
||||
async def _hierarchical_decomposition(
|
||||
self,
|
||||
task_id: str,
|
||||
task_requirements: TaskRequirement,
|
||||
max_subtasks: int,
|
||||
min_duration: float
|
||||
) -> List[SubTask]:
|
||||
"""Hierarchical decomposition strategy"""
|
||||
|
||||
sub_tasks = []
|
||||
|
||||
# Create hierarchical structure
|
||||
# Level 1: Main decomposition
|
||||
level1_tasks = await self._parallel_decomposition(task_id, task_requirements, max_subtasks // 2, min_duration)
|
||||
|
||||
# Level 2: Further decomposition if needed
|
||||
for level1_task in level1_tasks:
|
||||
if level1_task.requirements.estimated_duration > min_duration * 2:
|
||||
# Decompose further
|
||||
level2_tasks = await self._sequential_decomposition(
|
||||
level1_task.sub_task_id,
|
||||
level1_task.requirements,
|
||||
2,
|
||||
min_duration / 2
|
||||
)
|
||||
|
||||
# Update dependencies
|
||||
for level2_task in level2_tasks:
|
||||
level2_task.dependencies = level1_task.dependencies
|
||||
level2_task.parent_task_id = task_id
|
||||
|
||||
sub_tasks.extend(level2_tasks)
|
||||
else:
|
||||
sub_tasks.append(level1_task)
|
||||
|
||||
return sub_tasks
|
||||
|
||||
async def _pipeline_decomposition(
|
||||
self,
|
||||
task_id: str,
|
||||
task_requirements: TaskRequirement,
|
||||
max_subtasks: int,
|
||||
min_duration: float
|
||||
) -> List[SubTask]:
|
||||
"""Pipeline decomposition strategy"""
|
||||
|
||||
sub_tasks = []
|
||||
|
||||
# Define pipeline stages based on task type
|
||||
if task_requirements.task_type == TaskType.IMAGE_PROCESSING:
|
||||
stages = ["preprocessing", "processing", "postprocessing"]
|
||||
elif task_requirements.task_type == TaskType.DATA_ANALYSIS:
|
||||
stages = ["data_loading", "cleaning", "analysis", "visualization"]
|
||||
elif task_requirements.task_type == TaskType.MODEL_TRAINING:
|
||||
stages = ["data_preparation", "model_training", "validation", "deployment"]
|
||||
else:
|
||||
stages = ["stage1", "stage2", "stage3"]
|
||||
|
||||
# Create pipeline sub-tasks
|
||||
stage_duration = task_requirements.estimated_duration / len(stages)
|
||||
|
||||
for i, stage in enumerate(stages):
|
||||
stage_requirements = TaskRequirement(
|
||||
task_type=task_requirements.task_type,
|
||||
estimated_duration=stage_duration,
|
||||
gpu_tier=task_requirements.gpu_tier,
|
||||
memory_requirement=task_requirements.memory_requirement,
|
||||
compute_intensity=task_requirements.compute_intensity,
|
||||
data_size=task_requirements.data_size,
|
||||
priority=task_requirements.priority,
|
||||
deadline=task_requirements.deadline,
|
||||
max_cost=task_requirements.max_cost / len(stages) if task_requirements.max_cost else None
|
||||
)
|
||||
|
||||
sub_task = SubTask(
|
||||
sub_task_id=f"{task_id}_pipe_{i+1}",
|
||||
parent_task_id=task_id,
|
||||
name=f"Pipeline Stage: {stage}",
|
||||
description=f"Pipeline stage: {stage}",
|
||||
requirements=stage_requirements,
|
||||
dependencies=[f"{task_id}_pipe_{i}"] if i > 0 else [],
|
||||
inputs=[f"stage_{i}_input"],
|
||||
outputs=[f"stage_{i}_output"]
|
||||
)
|
||||
sub_tasks.append(sub_task)
|
||||
|
||||
return sub_tasks
|
||||
|
||||
async def _adaptive_decomposition(
|
||||
self,
|
||||
task_id: str,
|
||||
task_requirements: TaskRequirement,
|
||||
max_subtasks: int,
|
||||
min_duration: float
|
||||
) -> List[SubTask]:
|
||||
"""Adaptive decomposition strategy"""
|
||||
|
||||
# Analyze task characteristics
|
||||
characteristics = await self._analyze_task_characteristics(task_requirements)
|
||||
|
||||
# Select best strategy based on analysis
|
||||
if characteristics["parallelizable"] > 0.7:
|
||||
return await self._parallel_decomposition(task_id, task_requirements, max_subtasks, min_duration)
|
||||
elif characteristics["sequential_dependency"] > 0.7:
|
||||
return await self._sequential_decomposition(task_id, task_requirements, max_subtasks, min_duration)
|
||||
elif characteristics["hierarchical_structure"] > 0.7:
|
||||
return await self._hierarchical_decomposition(task_id, task_requirements, max_subtasks, min_duration)
|
||||
else:
|
||||
return await self._pipeline_decomposition(task_id, task_requirements, max_subtasks, min_duration)
|
||||
|
||||
async def _analyze_task_characteristics(self, task_requirements: TaskRequirement) -> Dict[str, float]:
|
||||
"""Analyze task characteristics for adaptive decomposition"""
|
||||
|
||||
characteristics = {
|
||||
"parallelizable": 0.5,
|
||||
"sequential_dependency": 0.5,
|
||||
"hierarchical_structure": 0.5,
|
||||
"pipeline_suitable": 0.5
|
||||
}
|
||||
|
||||
# Analyze based on task type
|
||||
if task_requirements.task_type in [TaskType.DATA_ANALYSIS, TaskType.IMAGE_PROCESSING]:
|
||||
characteristics["parallelizable"] = 0.8
|
||||
elif task_requirements.task_type in [TaskType.MODEL_TRAINING]:
|
||||
characteristics["sequential_dependency"] = 0.7
|
||||
characteristics["pipeline_suitable"] = 0.8
|
||||
elif task_requirements.task_type == TaskType.MIXED_MODAL:
|
||||
characteristics["hierarchical_structure"] = 0.8
|
||||
|
||||
# Adjust based on data size
|
||||
if task_requirements.data_size > 1000: # > 1GB
|
||||
characteristics["parallelizable"] += 0.2
|
||||
|
||||
# Adjust based on compute intensity
|
||||
if task_requirements.compute_intensity > 0.8:
|
||||
characteristics["sequential_dependency"] += 0.1
|
||||
|
||||
return characteristics
|
||||
|
||||
async def _build_dependency_graph(self, sub_tasks: List[SubTask]) -> Dict[str, List[str]]:
|
||||
"""Build dependency graph from sub-tasks"""
|
||||
|
||||
dependency_graph = {}
|
||||
|
||||
for sub_task in sub_tasks:
|
||||
dependency_graph[sub_task.sub_task_id] = sub_task.dependencies
|
||||
|
||||
return dependency_graph
|
||||
|
||||
async def _create_execution_plan(self, dependency_graph: Dict[str, List[str]]) -> List[List[str]]:
|
||||
"""Create execution plan from dependency graph"""
|
||||
|
||||
execution_plan = []
|
||||
remaining_tasks = set(dependency_graph.keys())
|
||||
completed_tasks = set()
|
||||
|
||||
while remaining_tasks:
|
||||
# Find tasks with no unmet dependencies
|
||||
ready_tasks = []
|
||||
for task_id in remaining_tasks:
|
||||
dependencies = dependency_graph[task_id]
|
||||
if all(dep in completed_tasks for dep in dependencies):
|
||||
ready_tasks.append(task_id)
|
||||
|
||||
if not ready_tasks:
|
||||
# Circular dependency or error
|
||||
logger.warning("Circular dependency detected in task decomposition")
|
||||
break
|
||||
|
||||
# Add ready tasks to current execution stage
|
||||
execution_plan.append(ready_tasks)
|
||||
|
||||
# Mark tasks as completed
|
||||
for task_id in ready_tasks:
|
||||
completed_tasks.add(task_id)
|
||||
remaining_tasks.remove(task_id)
|
||||
|
||||
return execution_plan
|
||||
|
||||
async def _estimate_total_duration(self, sub_tasks: List[SubTask], execution_plan: List[List[str]]) -> float:
|
||||
"""Estimate total duration for task execution"""
|
||||
|
||||
total_duration = 0.0
|
||||
|
||||
for stage in execution_plan:
|
||||
# Find longest task in this stage (parallel execution)
|
||||
stage_duration = 0.0
|
||||
for task_id in stage:
|
||||
if task_id in self.sub_task_registry:
|
||||
stage_duration = max(stage_duration, self.sub_task_registry[task_id].requirements.estimated_duration)
|
||||
|
||||
total_duration += stage_duration
|
||||
|
||||
return total_duration
|
||||
|
||||
async def _estimate_total_cost(self, sub_tasks: List[SubTask]) -> float:
|
||||
"""Estimate total cost for task execution"""
|
||||
|
||||
total_cost = 0.0
|
||||
|
||||
for sub_task in sub_tasks:
|
||||
# Simple cost estimation based on GPU tier and duration
|
||||
gpu_performance = self.gpu_performance.get(sub_task.requirements.gpu_tier, 1.0)
|
||||
hourly_rate = 0.05 * gpu_performance # Base rate * performance multiplier
|
||||
task_cost = hourly_rate * sub_task.requirements.estimated_duration
|
||||
total_cost += task_cost
|
||||
|
||||
return total_cost
|
||||
|
||||
async def _calculate_decomposition_confidence(
|
||||
self,
|
||||
task_requirements: TaskRequirement,
|
||||
sub_tasks: List[SubTask],
|
||||
strategy: str
|
||||
) -> float:
|
||||
"""Calculate confidence in decomposition"""
|
||||
|
||||
# Base confidence from strategy
|
||||
strategy_confidence = {
|
||||
"sequential": 0.9,
|
||||
"parallel": 0.8,
|
||||
"hierarchical": 0.7,
|
||||
"pipeline": 0.8,
|
||||
"adaptive": 0.6
|
||||
}
|
||||
|
||||
confidence = strategy_confidence.get(strategy, 0.5)
|
||||
|
||||
# Adjust based on task complexity
|
||||
complexity = self.complexity_thresholds.get(task_requirements.task_type, 0.5)
|
||||
if complexity > 0.7:
|
||||
confidence *= 0.8 # Lower confidence for complex tasks
|
||||
|
||||
# Adjust based on number of sub-tasks
|
||||
if len(sub_tasks) > 8:
|
||||
confidence *= 0.9 # Slightly lower confidence for many sub-tasks
|
||||
|
||||
return max(0.3, min(0.95, confidence))
|
||||
|
||||
async def _get_aggregation_function(self, aggregation_type: str, output_format: str) -> str:
|
||||
"""Get aggregation function for combining results"""
|
||||
|
||||
# Map aggregation types to functions
|
||||
function_map = {
|
||||
"concat": "concatenate_results",
|
||||
"merge": "merge_results",
|
||||
"vote": "majority_vote",
|
||||
"average": "weighted_average",
|
||||
"sum": "sum_results",
|
||||
"max": "max_results",
|
||||
"min": "min_results"
|
||||
}
|
||||
|
||||
base_function = function_map.get(aggregation_type, "concatenate_results")
|
||||
|
||||
# Add format-specific suffix
|
||||
if output_format == "json":
|
||||
return f"{base_function}_json"
|
||||
elif output_format == "array":
|
||||
return f"{base_function}_array"
|
||||
else:
|
||||
return base_function
|
||||
@@ -0,0 +1,26 @@
|
||||
pragma circom 2.0.0;
|
||||
|
||||
// Simple ML inference verification circuit
|
||||
// Basic test circuit to verify compilation
|
||||
|
||||
template SimpleInference() {
|
||||
signal input x; // input
|
||||
signal input w; // weight
|
||||
signal input b; // bias
|
||||
signal input expected; // expected output
|
||||
|
||||
signal output verified;
|
||||
|
||||
// Simple computation: output = x * w + b
|
||||
signal computed;
|
||||
computed <== x * w + b;
|
||||
|
||||
// Check if computed equals expected
|
||||
signal diff;
|
||||
diff <== computed - expected;
|
||||
|
||||
// Use a simple comparison (0 if equal, non-zero if different)
|
||||
verified <== 1 - (diff * diff); // Will be 1 if diff == 0, 0 otherwise
|
||||
}
|
||||
|
||||
component main = SimpleInference();
|
||||
Binary file not shown.
@@ -0,0 +1,7 @@
|
||||
1,1,0,main.verified
|
||||
2,2,0,main.x
|
||||
3,3,0,main.w
|
||||
4,4,0,main.b
|
||||
5,5,0,main.expected
|
||||
6,6,0,main.computed
|
||||
7,7,0,main.diff
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,21 @@
|
||||
const wc = require("./witness_calculator.js");
|
||||
const { readFileSync, writeFile } = require("fs");
|
||||
|
||||
if (process.argv.length != 5) {
|
||||
console.log("Usage: node generate_witness.js <file.wasm> <input.json> <output.wtns>");
|
||||
} else {
|
||||
const input = JSON.parse(readFileSync(process.argv[3], "utf8"));
|
||||
|
||||
const buffer = readFileSync(process.argv[2]);
|
||||
wc(buffer).then(async witnessCalculator => {
|
||||
/*
|
||||
const w= await witnessCalculator.calculateWitness(input,0);
|
||||
for (let i=0; i< w.length; i++){
|
||||
console.log(w[i]);
|
||||
}*/
|
||||
const buff= await witnessCalculator.calculateWTNSBin(input,0);
|
||||
writeFile(process.argv[4], buff, function(err) {
|
||||
if (err) throw err;
|
||||
});
|
||||
});
|
||||
}
|
||||
Binary file not shown.
@@ -0,0 +1,381 @@
|
||||
module.exports = async function builder(code, options) {
|
||||
|
||||
options = options || {};
|
||||
|
||||
let wasmModule;
|
||||
try {
|
||||
wasmModule = await WebAssembly.compile(code);
|
||||
} catch (err) {
|
||||
console.log(err);
|
||||
console.log("\nTry to run circom --c in order to generate c++ code instead\n");
|
||||
throw new Error(err);
|
||||
}
|
||||
|
||||
let wc;
|
||||
|
||||
let errStr = "";
|
||||
let msgStr = "";
|
||||
|
||||
const instance = await WebAssembly.instantiate(wasmModule, {
|
||||
runtime: {
|
||||
exceptionHandler : function(code) {
|
||||
let err;
|
||||
if (code == 1) {
|
||||
err = "Signal not found.\n";
|
||||
} else if (code == 2) {
|
||||
err = "Too many signals set.\n";
|
||||
} else if (code == 3) {
|
||||
err = "Signal already set.\n";
|
||||
} else if (code == 4) {
|
||||
err = "Assert Failed.\n";
|
||||
} else if (code == 5) {
|
||||
err = "Not enough memory.\n";
|
||||
} else if (code == 6) {
|
||||
err = "Input signal array access exceeds the size.\n";
|
||||
} else {
|
||||
err = "Unknown error.\n";
|
||||
}
|
||||
throw new Error(err + errStr);
|
||||
},
|
||||
printErrorMessage : function() {
|
||||
errStr += getMessage() + "\n";
|
||||
// console.error(getMessage());
|
||||
},
|
||||
writeBufferMessage : function() {
|
||||
const msg = getMessage();
|
||||
// Any calls to `log()` will always end with a `\n`, so that's when we print and reset
|
||||
if (msg === "\n") {
|
||||
console.log(msgStr);
|
||||
msgStr = "";
|
||||
} else {
|
||||
// If we've buffered other content, put a space in between the items
|
||||
if (msgStr !== "") {
|
||||
msgStr += " "
|
||||
}
|
||||
// Then append the message to the message we are creating
|
||||
msgStr += msg;
|
||||
}
|
||||
},
|
||||
showSharedRWMemory : function() {
|
||||
printSharedRWMemory ();
|
||||
}
|
||||
|
||||
}
|
||||
});
|
||||
|
||||
const sanityCheck =
|
||||
options
|
||||
// options &&
|
||||
// (
|
||||
// options.sanityCheck ||
|
||||
// options.logGetSignal ||
|
||||
// options.logSetSignal ||
|
||||
// options.logStartComponent ||
|
||||
// options.logFinishComponent
|
||||
// );
|
||||
|
||||
|
||||
wc = new WitnessCalculator(instance, sanityCheck);
|
||||
return wc;
|
||||
|
||||
function getMessage() {
|
||||
var message = "";
|
||||
var c = instance.exports.getMessageChar();
|
||||
while ( c != 0 ) {
|
||||
message += String.fromCharCode(c);
|
||||
c = instance.exports.getMessageChar();
|
||||
}
|
||||
return message;
|
||||
}
|
||||
|
||||
function printSharedRWMemory () {
|
||||
const shared_rw_memory_size = instance.exports.getFieldNumLen32();
|
||||
const arr = new Uint32Array(shared_rw_memory_size);
|
||||
for (let j=0; j<shared_rw_memory_size; j++) {
|
||||
arr[shared_rw_memory_size-1-j] = instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
|
||||
// If we've buffered other content, put a space in between the items
|
||||
if (msgStr !== "") {
|
||||
msgStr += " "
|
||||
}
|
||||
// Then append the value to the message we are creating
|
||||
msgStr += (fromArray32(arr).toString());
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
class WitnessCalculator {
|
||||
constructor(instance, sanityCheck) {
|
||||
this.instance = instance;
|
||||
|
||||
this.version = this.instance.exports.getVersion();
|
||||
this.n32 = this.instance.exports.getFieldNumLen32();
|
||||
|
||||
this.instance.exports.getRawPrime();
|
||||
const arr = new Uint32Array(this.n32);
|
||||
for (let i=0; i<this.n32; i++) {
|
||||
arr[this.n32-1-i] = this.instance.exports.readSharedRWMemory(i);
|
||||
}
|
||||
this.prime = fromArray32(arr);
|
||||
|
||||
this.witnessSize = this.instance.exports.getWitnessSize();
|
||||
|
||||
this.sanityCheck = sanityCheck;
|
||||
}
|
||||
|
||||
circom_version() {
|
||||
return this.instance.exports.getVersion();
|
||||
}
|
||||
|
||||
async _doCalculateWitness(input_orig, sanityCheck) {
|
||||
//input is assumed to be a map from signals to arrays of bigints
|
||||
this.instance.exports.init((this.sanityCheck || sanityCheck) ? 1 : 0);
|
||||
let prefix = "";
|
||||
var input = new Object();
|
||||
//console.log("Input: ", input_orig);
|
||||
qualify_input(prefix,input_orig,input);
|
||||
//console.log("Input after: ",input);
|
||||
const keys = Object.keys(input);
|
||||
var input_counter = 0;
|
||||
keys.forEach( (k) => {
|
||||
const h = fnvHash(k);
|
||||
const hMSB = parseInt(h.slice(0,8), 16);
|
||||
const hLSB = parseInt(h.slice(8,16), 16);
|
||||
const fArr = flatArray(input[k]);
|
||||
let signalSize = this.instance.exports.getInputSignalSize(hMSB, hLSB);
|
||||
if (signalSize < 0){
|
||||
throw new Error(`Signal ${k} not found\n`);
|
||||
}
|
||||
if (fArr.length < signalSize) {
|
||||
throw new Error(`Not enough values for input signal ${k}\n`);
|
||||
}
|
||||
if (fArr.length > signalSize) {
|
||||
throw new Error(`Too many values for input signal ${k}\n`);
|
||||
}
|
||||
for (let i=0; i<fArr.length; i++) {
|
||||
const arrFr = toArray32(normalize(fArr[i],this.prime),this.n32)
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
this.instance.exports.writeSharedRWMemory(j,arrFr[this.n32-1-j]);
|
||||
}
|
||||
try {
|
||||
this.instance.exports.setInputSignal(hMSB, hLSB,i);
|
||||
input_counter++;
|
||||
} catch (err) {
|
||||
// console.log(`After adding signal ${i} of ${k}`)
|
||||
throw new Error(err);
|
||||
}
|
||||
}
|
||||
|
||||
});
|
||||
if (input_counter < this.instance.exports.getInputSize()) {
|
||||
throw new Error(`Not all inputs have been set. Only ${input_counter} out of ${this.instance.exports.getInputSize()}`);
|
||||
}
|
||||
}
|
||||
|
||||
async calculateWitness(input, sanityCheck) {
|
||||
|
||||
const w = [];
|
||||
await this._doCalculateWitness(input, sanityCheck);
|
||||
|
||||
for (let i=0; i<this.witnessSize; i++) {
|
||||
this.instance.exports.getWitness(i);
|
||||
const arr = new Uint32Array(this.n32);
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
arr[this.n32-1-j] = this.instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
w.push(fromArray32(arr));
|
||||
}
|
||||
|
||||
return w;
|
||||
}
|
||||
|
||||
|
||||
async calculateBinWitness(input, sanityCheck) {
|
||||
|
||||
const buff32 = new Uint32Array(this.witnessSize*this.n32);
|
||||
const buff = new Uint8Array( buff32.buffer);
|
||||
await this._doCalculateWitness(input, sanityCheck);
|
||||
|
||||
for (let i=0; i<this.witnessSize; i++) {
|
||||
this.instance.exports.getWitness(i);
|
||||
const pos = i*this.n32;
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
buff32[pos+j] = this.instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
}
|
||||
|
||||
return buff;
|
||||
}
|
||||
|
||||
|
||||
async calculateWTNSBin(input, sanityCheck) {
|
||||
|
||||
const buff32 = new Uint32Array(this.witnessSize*this.n32+this.n32+11);
|
||||
const buff = new Uint8Array( buff32.buffer);
|
||||
await this._doCalculateWitness(input, sanityCheck);
|
||||
|
||||
//"wtns"
|
||||
buff[0] = "w".charCodeAt(0)
|
||||
buff[1] = "t".charCodeAt(0)
|
||||
buff[2] = "n".charCodeAt(0)
|
||||
buff[3] = "s".charCodeAt(0)
|
||||
|
||||
//version 2
|
||||
buff32[1] = 2;
|
||||
|
||||
//number of sections: 2
|
||||
buff32[2] = 2;
|
||||
|
||||
//id section 1
|
||||
buff32[3] = 1;
|
||||
|
||||
const n8 = this.n32*4;
|
||||
//id section 1 length in 64bytes
|
||||
const idSection1length = 8 + n8;
|
||||
const idSection1lengthHex = idSection1length.toString(16);
|
||||
buff32[4] = parseInt(idSection1lengthHex.slice(0,8), 16);
|
||||
buff32[5] = parseInt(idSection1lengthHex.slice(8,16), 16);
|
||||
|
||||
//this.n32
|
||||
buff32[6] = n8;
|
||||
|
||||
//prime number
|
||||
this.instance.exports.getRawPrime();
|
||||
|
||||
var pos = 7;
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
buff32[pos+j] = this.instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
pos += this.n32;
|
||||
|
||||
// witness size
|
||||
buff32[pos] = this.witnessSize;
|
||||
pos++;
|
||||
|
||||
//id section 2
|
||||
buff32[pos] = 2;
|
||||
pos++;
|
||||
|
||||
// section 2 length
|
||||
const idSection2length = n8*this.witnessSize;
|
||||
const idSection2lengthHex = idSection2length.toString(16);
|
||||
buff32[pos] = parseInt(idSection2lengthHex.slice(0,8), 16);
|
||||
buff32[pos+1] = parseInt(idSection2lengthHex.slice(8,16), 16);
|
||||
|
||||
pos += 2;
|
||||
for (let i=0; i<this.witnessSize; i++) {
|
||||
this.instance.exports.getWitness(i);
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
buff32[pos+j] = this.instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
pos += this.n32;
|
||||
}
|
||||
|
||||
return buff;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
function qualify_input_list(prefix,input,input1){
|
||||
if (Array.isArray(input)) {
|
||||
for (let i = 0; i<input.length; i++) {
|
||||
let new_prefix = prefix + "[" + i + "]";
|
||||
qualify_input_list(new_prefix,input[i],input1);
|
||||
}
|
||||
} else {
|
||||
qualify_input(prefix,input,input1);
|
||||
}
|
||||
}
|
||||
|
||||
function qualify_input(prefix,input,input1) {
|
||||
if (Array.isArray(input)) {
|
||||
a = flatArray(input);
|
||||
if (a.length > 0) {
|
||||
let t = typeof a[0];
|
||||
for (let i = 1; i<a.length; i++) {
|
||||
if (typeof a[i] != t){
|
||||
throw new Error(`Types are not the same in the key ${prefix}`);
|
||||
}
|
||||
}
|
||||
if (t == "object") {
|
||||
qualify_input_list(prefix,input,input1);
|
||||
} else {
|
||||
input1[prefix] = input;
|
||||
}
|
||||
} else {
|
||||
input1[prefix] = input;
|
||||
}
|
||||
} else if (typeof input == "object") {
|
||||
const keys = Object.keys(input);
|
||||
keys.forEach( (k) => {
|
||||
let new_prefix = prefix == ""? k : prefix + "." + k;
|
||||
qualify_input(new_prefix,input[k],input1);
|
||||
});
|
||||
} else {
|
||||
input1[prefix] = input;
|
||||
}
|
||||
}
|
||||
|
||||
function toArray32(rem,size) {
|
||||
const res = []; //new Uint32Array(size); //has no unshift
|
||||
const radix = BigInt(0x100000000);
|
||||
while (rem) {
|
||||
res.unshift( Number(rem % radix));
|
||||
rem = rem / radix;
|
||||
}
|
||||
if (size) {
|
||||
var i = size - res.length;
|
||||
while (i>0) {
|
||||
res.unshift(0);
|
||||
i--;
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
function fromArray32(arr) { //returns a BigInt
|
||||
var res = BigInt(0);
|
||||
const radix = BigInt(0x100000000);
|
||||
for (let i = 0; i<arr.length; i++) {
|
||||
res = res*radix + BigInt(arr[i]);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
function flatArray(a) {
|
||||
var res = [];
|
||||
fillArray(res, a);
|
||||
return res;
|
||||
|
||||
function fillArray(res, a) {
|
||||
if (Array.isArray(a)) {
|
||||
for (let i=0; i<a.length; i++) {
|
||||
fillArray(res, a[i]);
|
||||
}
|
||||
} else {
|
||||
res.push(a);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function normalize(n, prime) {
|
||||
let res = BigInt(n) % prime
|
||||
if (res < 0) res += prime
|
||||
return res
|
||||
}
|
||||
|
||||
function fnvHash(str) {
|
||||
const uint64_max = BigInt(2) ** BigInt(64);
|
||||
let hash = BigInt("0xCBF29CE484222325");
|
||||
for (var i = 0; i < str.length; i++) {
|
||||
hash ^= BigInt(str[i].charCodeAt());
|
||||
hash *= BigInt(0x100000001B3);
|
||||
hash %= uint64_max;
|
||||
}
|
||||
let shash = hash.toString(16);
|
||||
let n = 16 - shash.length;
|
||||
shash = '0'.repeat(n).concat(shash);
|
||||
return shash;
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
pragma circom 2.0.0;
|
||||
|
||||
include "node_modules/circomlib/circuits/poseidon.circom";
|
||||
|
||||
/*
|
||||
* Simplified ML Training Verification Circuit
|
||||
*
|
||||
* Basic proof of gradient descent training without complex hashing
|
||||
*/
|
||||
|
||||
template SimpleTrainingVerification(PARAM_COUNT, EPOCHS) {
|
||||
signal input initial_parameters[PARAM_COUNT];
|
||||
signal input learning_rate;
|
||||
|
||||
signal output final_parameters[PARAM_COUNT];
|
||||
signal output training_complete;
|
||||
|
||||
// Input validation constraints
|
||||
// Learning rate should be positive and reasonable (0 < lr < 1)
|
||||
learning_rate * (1 - learning_rate) === learning_rate; // Ensures 0 < lr < 1
|
||||
|
||||
// Simulate simple training epochs
|
||||
signal current_parameters[EPOCHS + 1][PARAM_COUNT];
|
||||
|
||||
// Initialize with initial parameters
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
current_parameters[0][i] <== initial_parameters[i];
|
||||
}
|
||||
|
||||
// Simple training: gradient descent simulation
|
||||
for (var e = 0; e < EPOCHS; e++) {
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
// Simplified gradient descent: param = param - learning_rate * gradient_constant
|
||||
// Using constant gradient of 0.1 for demonstration
|
||||
current_parameters[e + 1][i] <== current_parameters[e][i] - learning_rate * 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Output final parameters
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
final_parameters[i] <== current_parameters[EPOCHS][i];
|
||||
}
|
||||
|
||||
// Training completion constraint
|
||||
training_complete <== 1;
|
||||
}
|
||||
|
||||
component main = SimpleTrainingVerification(4, 3);
|
||||
Binary file not shown.
@@ -0,0 +1,26 @@
|
||||
1,1,0,main.final_parameters[0]
|
||||
2,2,0,main.final_parameters[1]
|
||||
3,3,0,main.final_parameters[2]
|
||||
4,4,0,main.final_parameters[3]
|
||||
5,5,0,main.training_complete
|
||||
6,6,0,main.initial_parameters[0]
|
||||
7,7,0,main.initial_parameters[1]
|
||||
8,8,0,main.initial_parameters[2]
|
||||
9,9,0,main.initial_parameters[3]
|
||||
10,10,0,main.learning_rate
|
||||
11,-1,0,main.current_parameters[0][0]
|
||||
12,-1,0,main.current_parameters[0][1]
|
||||
13,-1,0,main.current_parameters[0][2]
|
||||
14,-1,0,main.current_parameters[0][3]
|
||||
15,11,0,main.current_parameters[1][0]
|
||||
16,12,0,main.current_parameters[1][1]
|
||||
17,13,0,main.current_parameters[1][2]
|
||||
18,14,0,main.current_parameters[1][3]
|
||||
19,15,0,main.current_parameters[2][0]
|
||||
20,16,0,main.current_parameters[2][1]
|
||||
21,17,0,main.current_parameters[2][2]
|
||||
22,18,0,main.current_parameters[2][3]
|
||||
23,-1,0,main.current_parameters[3][0]
|
||||
24,-1,0,main.current_parameters[3][1]
|
||||
25,-1,0,main.current_parameters[3][2]
|
||||
26,-1,0,main.current_parameters[3][3]
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,21 @@
|
||||
const wc = require("./witness_calculator.js");
|
||||
const { readFileSync, writeFile } = require("fs");
|
||||
|
||||
if (process.argv.length != 5) {
|
||||
console.log("Usage: node generate_witness.js <file.wasm> <input.json> <output.wtns>");
|
||||
} else {
|
||||
const input = JSON.parse(readFileSync(process.argv[3], "utf8"));
|
||||
|
||||
const buffer = readFileSync(process.argv[2]);
|
||||
wc(buffer).then(async witnessCalculator => {
|
||||
/*
|
||||
const w= await witnessCalculator.calculateWitness(input,0);
|
||||
for (let i=0; i< w.length; i++){
|
||||
console.log(w[i]);
|
||||
}*/
|
||||
const buff= await witnessCalculator.calculateWTNSBin(input,0);
|
||||
writeFile(process.argv[4], buff, function(err) {
|
||||
if (err) throw err;
|
||||
});
|
||||
});
|
||||
}
|
||||
Binary file not shown.
@@ -0,0 +1,381 @@
|
||||
module.exports = async function builder(code, options) {
|
||||
|
||||
options = options || {};
|
||||
|
||||
let wasmModule;
|
||||
try {
|
||||
wasmModule = await WebAssembly.compile(code);
|
||||
} catch (err) {
|
||||
console.log(err);
|
||||
console.log("\nTry to run circom --c in order to generate c++ code instead\n");
|
||||
throw new Error(err);
|
||||
}
|
||||
|
||||
let wc;
|
||||
|
||||
let errStr = "";
|
||||
let msgStr = "";
|
||||
|
||||
const instance = await WebAssembly.instantiate(wasmModule, {
|
||||
runtime: {
|
||||
exceptionHandler : function(code) {
|
||||
let err;
|
||||
if (code == 1) {
|
||||
err = "Signal not found.\n";
|
||||
} else if (code == 2) {
|
||||
err = "Too many signals set.\n";
|
||||
} else if (code == 3) {
|
||||
err = "Signal already set.\n";
|
||||
} else if (code == 4) {
|
||||
err = "Assert Failed.\n";
|
||||
} else if (code == 5) {
|
||||
err = "Not enough memory.\n";
|
||||
} else if (code == 6) {
|
||||
err = "Input signal array access exceeds the size.\n";
|
||||
} else {
|
||||
err = "Unknown error.\n";
|
||||
}
|
||||
throw new Error(err + errStr);
|
||||
},
|
||||
printErrorMessage : function() {
|
||||
errStr += getMessage() + "\n";
|
||||
// console.error(getMessage());
|
||||
},
|
||||
writeBufferMessage : function() {
|
||||
const msg = getMessage();
|
||||
// Any calls to `log()` will always end with a `\n`, so that's when we print and reset
|
||||
if (msg === "\n") {
|
||||
console.log(msgStr);
|
||||
msgStr = "";
|
||||
} else {
|
||||
// If we've buffered other content, put a space in between the items
|
||||
if (msgStr !== "") {
|
||||
msgStr += " "
|
||||
}
|
||||
// Then append the message to the message we are creating
|
||||
msgStr += msg;
|
||||
}
|
||||
},
|
||||
showSharedRWMemory : function() {
|
||||
printSharedRWMemory ();
|
||||
}
|
||||
|
||||
}
|
||||
});
|
||||
|
||||
const sanityCheck =
|
||||
options
|
||||
// options &&
|
||||
// (
|
||||
// options.sanityCheck ||
|
||||
// options.logGetSignal ||
|
||||
// options.logSetSignal ||
|
||||
// options.logStartComponent ||
|
||||
// options.logFinishComponent
|
||||
// );
|
||||
|
||||
|
||||
wc = new WitnessCalculator(instance, sanityCheck);
|
||||
return wc;
|
||||
|
||||
function getMessage() {
|
||||
var message = "";
|
||||
var c = instance.exports.getMessageChar();
|
||||
while ( c != 0 ) {
|
||||
message += String.fromCharCode(c);
|
||||
c = instance.exports.getMessageChar();
|
||||
}
|
||||
return message;
|
||||
}
|
||||
|
||||
function printSharedRWMemory () {
|
||||
const shared_rw_memory_size = instance.exports.getFieldNumLen32();
|
||||
const arr = new Uint32Array(shared_rw_memory_size);
|
||||
for (let j=0; j<shared_rw_memory_size; j++) {
|
||||
arr[shared_rw_memory_size-1-j] = instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
|
||||
// If we've buffered other content, put a space in between the items
|
||||
if (msgStr !== "") {
|
||||
msgStr += " "
|
||||
}
|
||||
// Then append the value to the message we are creating
|
||||
msgStr += (fromArray32(arr).toString());
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
class WitnessCalculator {
|
||||
constructor(instance, sanityCheck) {
|
||||
this.instance = instance;
|
||||
|
||||
this.version = this.instance.exports.getVersion();
|
||||
this.n32 = this.instance.exports.getFieldNumLen32();
|
||||
|
||||
this.instance.exports.getRawPrime();
|
||||
const arr = new Uint32Array(this.n32);
|
||||
for (let i=0; i<this.n32; i++) {
|
||||
arr[this.n32-1-i] = this.instance.exports.readSharedRWMemory(i);
|
||||
}
|
||||
this.prime = fromArray32(arr);
|
||||
|
||||
this.witnessSize = this.instance.exports.getWitnessSize();
|
||||
|
||||
this.sanityCheck = sanityCheck;
|
||||
}
|
||||
|
||||
circom_version() {
|
||||
return this.instance.exports.getVersion();
|
||||
}
|
||||
|
||||
async _doCalculateWitness(input_orig, sanityCheck) {
|
||||
//input is assumed to be a map from signals to arrays of bigints
|
||||
this.instance.exports.init((this.sanityCheck || sanityCheck) ? 1 : 0);
|
||||
let prefix = "";
|
||||
var input = new Object();
|
||||
//console.log("Input: ", input_orig);
|
||||
qualify_input(prefix,input_orig,input);
|
||||
//console.log("Input after: ",input);
|
||||
const keys = Object.keys(input);
|
||||
var input_counter = 0;
|
||||
keys.forEach( (k) => {
|
||||
const h = fnvHash(k);
|
||||
const hMSB = parseInt(h.slice(0,8), 16);
|
||||
const hLSB = parseInt(h.slice(8,16), 16);
|
||||
const fArr = flatArray(input[k]);
|
||||
let signalSize = this.instance.exports.getInputSignalSize(hMSB, hLSB);
|
||||
if (signalSize < 0){
|
||||
throw new Error(`Signal ${k} not found\n`);
|
||||
}
|
||||
if (fArr.length < signalSize) {
|
||||
throw new Error(`Not enough values for input signal ${k}\n`);
|
||||
}
|
||||
if (fArr.length > signalSize) {
|
||||
throw new Error(`Too many values for input signal ${k}\n`);
|
||||
}
|
||||
for (let i=0; i<fArr.length; i++) {
|
||||
const arrFr = toArray32(normalize(fArr[i],this.prime),this.n32)
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
this.instance.exports.writeSharedRWMemory(j,arrFr[this.n32-1-j]);
|
||||
}
|
||||
try {
|
||||
this.instance.exports.setInputSignal(hMSB, hLSB,i);
|
||||
input_counter++;
|
||||
} catch (err) {
|
||||
// console.log(`After adding signal ${i} of ${k}`)
|
||||
throw new Error(err);
|
||||
}
|
||||
}
|
||||
|
||||
});
|
||||
if (input_counter < this.instance.exports.getInputSize()) {
|
||||
throw new Error(`Not all inputs have been set. Only ${input_counter} out of ${this.instance.exports.getInputSize()}`);
|
||||
}
|
||||
}
|
||||
|
||||
async calculateWitness(input, sanityCheck) {
|
||||
|
||||
const w = [];
|
||||
await this._doCalculateWitness(input, sanityCheck);
|
||||
|
||||
for (let i=0; i<this.witnessSize; i++) {
|
||||
this.instance.exports.getWitness(i);
|
||||
const arr = new Uint32Array(this.n32);
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
arr[this.n32-1-j] = this.instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
w.push(fromArray32(arr));
|
||||
}
|
||||
|
||||
return w;
|
||||
}
|
||||
|
||||
|
||||
async calculateBinWitness(input, sanityCheck) {
|
||||
|
||||
const buff32 = new Uint32Array(this.witnessSize*this.n32);
|
||||
const buff = new Uint8Array( buff32.buffer);
|
||||
await this._doCalculateWitness(input, sanityCheck);
|
||||
|
||||
for (let i=0; i<this.witnessSize; i++) {
|
||||
this.instance.exports.getWitness(i);
|
||||
const pos = i*this.n32;
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
buff32[pos+j] = this.instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
}
|
||||
|
||||
return buff;
|
||||
}
|
||||
|
||||
|
||||
async calculateWTNSBin(input, sanityCheck) {
|
||||
|
||||
const buff32 = new Uint32Array(this.witnessSize*this.n32+this.n32+11);
|
||||
const buff = new Uint8Array( buff32.buffer);
|
||||
await this._doCalculateWitness(input, sanityCheck);
|
||||
|
||||
//"wtns"
|
||||
buff[0] = "w".charCodeAt(0)
|
||||
buff[1] = "t".charCodeAt(0)
|
||||
buff[2] = "n".charCodeAt(0)
|
||||
buff[3] = "s".charCodeAt(0)
|
||||
|
||||
//version 2
|
||||
buff32[1] = 2;
|
||||
|
||||
//number of sections: 2
|
||||
buff32[2] = 2;
|
||||
|
||||
//id section 1
|
||||
buff32[3] = 1;
|
||||
|
||||
const n8 = this.n32*4;
|
||||
//id section 1 length in 64bytes
|
||||
const idSection1length = 8 + n8;
|
||||
const idSection1lengthHex = idSection1length.toString(16);
|
||||
buff32[4] = parseInt(idSection1lengthHex.slice(0,8), 16);
|
||||
buff32[5] = parseInt(idSection1lengthHex.slice(8,16), 16);
|
||||
|
||||
//this.n32
|
||||
buff32[6] = n8;
|
||||
|
||||
//prime number
|
||||
this.instance.exports.getRawPrime();
|
||||
|
||||
var pos = 7;
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
buff32[pos+j] = this.instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
pos += this.n32;
|
||||
|
||||
// witness size
|
||||
buff32[pos] = this.witnessSize;
|
||||
pos++;
|
||||
|
||||
//id section 2
|
||||
buff32[pos] = 2;
|
||||
pos++;
|
||||
|
||||
// section 2 length
|
||||
const idSection2length = n8*this.witnessSize;
|
||||
const idSection2lengthHex = idSection2length.toString(16);
|
||||
buff32[pos] = parseInt(idSection2lengthHex.slice(0,8), 16);
|
||||
buff32[pos+1] = parseInt(idSection2lengthHex.slice(8,16), 16);
|
||||
|
||||
pos += 2;
|
||||
for (let i=0; i<this.witnessSize; i++) {
|
||||
this.instance.exports.getWitness(i);
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
buff32[pos+j] = this.instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
pos += this.n32;
|
||||
}
|
||||
|
||||
return buff;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
function qualify_input_list(prefix,input,input1){
|
||||
if (Array.isArray(input)) {
|
||||
for (let i = 0; i<input.length; i++) {
|
||||
let new_prefix = prefix + "[" + i + "]";
|
||||
qualify_input_list(new_prefix,input[i],input1);
|
||||
}
|
||||
} else {
|
||||
qualify_input(prefix,input,input1);
|
||||
}
|
||||
}
|
||||
|
||||
function qualify_input(prefix,input,input1) {
|
||||
if (Array.isArray(input)) {
|
||||
a = flatArray(input);
|
||||
if (a.length > 0) {
|
||||
let t = typeof a[0];
|
||||
for (let i = 1; i<a.length; i++) {
|
||||
if (typeof a[i] != t){
|
||||
throw new Error(`Types are not the same in the key ${prefix}`);
|
||||
}
|
||||
}
|
||||
if (t == "object") {
|
||||
qualify_input_list(prefix,input,input1);
|
||||
} else {
|
||||
input1[prefix] = input;
|
||||
}
|
||||
} else {
|
||||
input1[prefix] = input;
|
||||
}
|
||||
} else if (typeof input == "object") {
|
||||
const keys = Object.keys(input);
|
||||
keys.forEach( (k) => {
|
||||
let new_prefix = prefix == ""? k : prefix + "." + k;
|
||||
qualify_input(new_prefix,input[k],input1);
|
||||
});
|
||||
} else {
|
||||
input1[prefix] = input;
|
||||
}
|
||||
}
|
||||
|
||||
function toArray32(rem,size) {
|
||||
const res = []; //new Uint32Array(size); //has no unshift
|
||||
const radix = BigInt(0x100000000);
|
||||
while (rem) {
|
||||
res.unshift( Number(rem % radix));
|
||||
rem = rem / radix;
|
||||
}
|
||||
if (size) {
|
||||
var i = size - res.length;
|
||||
while (i>0) {
|
||||
res.unshift(0);
|
||||
i--;
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
function fromArray32(arr) { //returns a BigInt
|
||||
var res = BigInt(0);
|
||||
const radix = BigInt(0x100000000);
|
||||
for (let i = 0; i<arr.length; i++) {
|
||||
res = res*radix + BigInt(arr[i]);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
function flatArray(a) {
|
||||
var res = [];
|
||||
fillArray(res, a);
|
||||
return res;
|
||||
|
||||
function fillArray(res, a) {
|
||||
if (Array.isArray(a)) {
|
||||
for (let i=0; i<a.length; i++) {
|
||||
fillArray(res, a[i]);
|
||||
}
|
||||
} else {
|
||||
res.push(a);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function normalize(n, prime) {
|
||||
let res = BigInt(n) % prime
|
||||
if (res < 0) res += prime
|
||||
return res
|
||||
}
|
||||
|
||||
function fnvHash(str) {
|
||||
const uint64_max = BigInt(2) ** BigInt(64);
|
||||
let hash = BigInt("0xCBF29CE484222325");
|
||||
for (var i = 0; i < str.length; i++) {
|
||||
hash ^= BigInt(str[i].charCodeAt());
|
||||
hash *= BigInt(0x100000001B3);
|
||||
hash %= uint64_max;
|
||||
}
|
||||
let shash = hash.toString(16);
|
||||
let n = 16 - shash.length;
|
||||
shash = '0'.repeat(n).concat(shash);
|
||||
return shash;
|
||||
}
|
||||
@@ -0,0 +1,135 @@
|
||||
pragma circom 2.0.0;
|
||||
|
||||
/*
|
||||
* Modular ML Circuit Components
|
||||
*
|
||||
* Reusable components for machine learning circuits
|
||||
*/
|
||||
|
||||
// Basic parameter update component (gradient descent step)
|
||||
template ParameterUpdate() {
|
||||
signal input current_param;
|
||||
signal input gradient;
|
||||
signal input learning_rate;
|
||||
|
||||
signal output new_param;
|
||||
|
||||
// Simple gradient descent: new_param = current_param - learning_rate * gradient
|
||||
new_param <== current_param - learning_rate * gradient;
|
||||
}
|
||||
|
||||
// Vector parameter update component
|
||||
template VectorParameterUpdate(PARAM_COUNT) {
|
||||
signal input current_params[PARAM_COUNT];
|
||||
signal input gradients[PARAM_COUNT];
|
||||
signal input learning_rate;
|
||||
|
||||
signal output new_params[PARAM_COUNT];
|
||||
|
||||
component updates[PARAM_COUNT];
|
||||
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
updates[i] = ParameterUpdate();
|
||||
updates[i].current_param <== current_params[i];
|
||||
updates[i].gradient <== gradients[i];
|
||||
updates[i].learning_rate <== learning_rate;
|
||||
new_params[i] <== updates[i].new_param;
|
||||
}
|
||||
}
|
||||
|
||||
// Simple loss constraint component
|
||||
template LossConstraint() {
|
||||
signal input predicted_loss;
|
||||
signal input actual_loss;
|
||||
signal input tolerance;
|
||||
|
||||
// Constrain that |predicted_loss - actual_loss| <= tolerance
|
||||
signal diff;
|
||||
diff <== predicted_loss - actual_loss;
|
||||
|
||||
// Use absolute value constraint: diff^2 <= tolerance^2
|
||||
signal diff_squared;
|
||||
diff_squared <== diff * diff;
|
||||
|
||||
signal tolerance_squared;
|
||||
tolerance_squared <== tolerance * tolerance;
|
||||
|
||||
// This constraint ensures the loss is within tolerance
|
||||
diff_squared * (1 - diff_squared / tolerance_squared) === 0;
|
||||
}
|
||||
|
||||
// Learning rate validation component
|
||||
template LearningRateValidation() {
|
||||
signal input learning_rate;
|
||||
|
||||
// Removed constraint for optimization - learning rate validation handled externally
|
||||
// This reduces non-linear constraints from 1 to 0 for better proving performance
|
||||
}
|
||||
|
||||
// Training epoch component
|
||||
template TrainingEpoch(PARAM_COUNT) {
|
||||
signal input epoch_params[PARAM_COUNT];
|
||||
signal input epoch_gradients[PARAM_COUNT];
|
||||
signal input learning_rate;
|
||||
|
||||
signal output next_epoch_params[PARAM_COUNT];
|
||||
|
||||
component param_update = VectorParameterUpdate(PARAM_COUNT);
|
||||
param_update.current_params <== epoch_params;
|
||||
param_update.gradients <== epoch_gradients;
|
||||
param_update.learning_rate <== learning_rate;
|
||||
next_epoch_params <== param_update.new_params;
|
||||
}
|
||||
|
||||
// Main modular training verification using components
|
||||
template ModularTrainingVerification(PARAM_COUNT, EPOCHS) {
|
||||
signal input initial_parameters[PARAM_COUNT];
|
||||
signal input learning_rate;
|
||||
|
||||
signal output final_parameters[PARAM_COUNT];
|
||||
signal output training_complete;
|
||||
|
||||
// Learning rate validation
|
||||
component lr_validator = LearningRateValidation();
|
||||
lr_validator.learning_rate <== learning_rate;
|
||||
|
||||
// Training epochs using modular components
|
||||
signal current_params[EPOCHS + 1][PARAM_COUNT];
|
||||
|
||||
// Initialize
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
current_params[0][i] <== initial_parameters[i];
|
||||
}
|
||||
|
||||
// Run training epochs
|
||||
component epochs[EPOCHS];
|
||||
for (var e = 0; e < EPOCHS; e++) {
|
||||
epochs[e] = TrainingEpoch(PARAM_COUNT);
|
||||
|
||||
// Input current parameters
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
epochs[e].epoch_params[i] <== current_params[e][i];
|
||||
}
|
||||
|
||||
// Use constant gradients for simplicity (would be computed in real implementation)
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
epochs[e].epoch_gradients[i] <== 1; // Constant gradient
|
||||
}
|
||||
|
||||
epochs[e].learning_rate <== learning_rate;
|
||||
|
||||
// Store results
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
current_params[e + 1][i] <== epochs[e].next_epoch_params[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Output final parameters
|
||||
for (var i = 0; i < PARAM_COUNT; i++) {
|
||||
final_parameters[i] <== current_params[EPOCHS][i];
|
||||
}
|
||||
|
||||
training_complete <== 1;
|
||||
}
|
||||
|
||||
component main = ModularTrainingVerification(4, 3);
|
||||
BIN
apps/coordinator-api/src/app/zk-circuits/pot12_0000.ptau
Normal file
BIN
apps/coordinator-api/src/app/zk-circuits/pot12_0000.ptau
Normal file
Binary file not shown.
BIN
apps/coordinator-api/src/app/zk-circuits/pot12_0001.ptau
Normal file
BIN
apps/coordinator-api/src/app/zk-circuits/pot12_0001.ptau
Normal file
Binary file not shown.
BIN
apps/coordinator-api/src/app/zk-circuits/pot12_final.ptau
Normal file
BIN
apps/coordinator-api/src/app/zk-circuits/pot12_final.ptau
Normal file
Binary file not shown.
125
apps/coordinator-api/src/app/zk-circuits/receipt.circom
Normal file
125
apps/coordinator-api/src/app/zk-circuits/receipt.circom
Normal file
@@ -0,0 +1,125 @@
|
||||
pragma circom 2.0.0;
|
||||
|
||||
include "node_modules/circomlib/circuits/bitify.circom";
|
||||
include "node_modules/circomlib/circuits/escalarmulfix.circom";
|
||||
include "node_modules/circomlib/circuits/comparators.circom";
|
||||
include "node_modules/circomlib/circuits/poseidon.circom";
|
||||
|
||||
/*
|
||||
* Receipt Attestation Circuit
|
||||
*
|
||||
* This circuit proves that a receipt is valid without revealing sensitive details.
|
||||
*
|
||||
* Public Inputs:
|
||||
* - receiptHash: Hash of the receipt (for public verification)
|
||||
* - settlementAmount: Amount to be settled (public)
|
||||
* - timestamp: Receipt timestamp (public)
|
||||
*
|
||||
* Private Inputs:
|
||||
* - receipt: The full receipt data (private)
|
||||
* - computationResult: Result of the computation (private)
|
||||
* - pricingRate: Pricing rate used (private)
|
||||
* - minerReward: Reward for miner (private)
|
||||
* - coordinatorFee: Fee for coordinator (private)
|
||||
*/
|
||||
|
||||
template ReceiptAttestation() {
|
||||
// Public signals
|
||||
signal input receiptHash;
|
||||
signal input settlementAmount;
|
||||
signal input timestamp;
|
||||
|
||||
// Private signals
|
||||
signal input receipt[8];
|
||||
signal input computationResult;
|
||||
signal input pricingRate;
|
||||
signal input minerReward;
|
||||
signal input coordinatorFee;
|
||||
|
||||
// Components
|
||||
component hasher = Poseidon(8);
|
||||
component amountChecker = GreaterEqThan(8);
|
||||
component feeCalculator = Add8(8);
|
||||
|
||||
// Hash the receipt to verify it matches the public hash
|
||||
for (var i = 0; i < 8; i++) {
|
||||
hasher.inputs[i] <== receipt[i];
|
||||
}
|
||||
|
||||
// Ensure the computed hash matches the public hash
|
||||
hasher.out === receiptHash;
|
||||
|
||||
// Verify settlement amount calculation
|
||||
// settlementAmount = minerReward + coordinatorFee
|
||||
feeCalculator.a[0] <== minerReward;
|
||||
feeCalculator.a[1] <== coordinatorFee;
|
||||
for (var i = 2; i < 8; i++) {
|
||||
feeCalculator.a[i] <== 0;
|
||||
}
|
||||
feeCalculator.out === settlementAmount;
|
||||
|
||||
// Ensure amounts are non-negative
|
||||
amountChecker.in[0] <== settlementAmount;
|
||||
amountChecker.in[1] <== 0;
|
||||
amountChecker.out === 1;
|
||||
|
||||
// Additional constraints can be added here:
|
||||
// - Timestamp validation
|
||||
// - Pricing rate bounds
|
||||
// - Computation result format
|
||||
}
|
||||
|
||||
/*
|
||||
* Simplified Receipt Hash Preimage Circuit
|
||||
*
|
||||
* This is a minimal circuit for initial testing that proves
|
||||
* knowledge of a receipt preimage without revealing it.
|
||||
*/
|
||||
template ReceiptHashPreimage() {
|
||||
// Public signal
|
||||
signal input hash;
|
||||
|
||||
// Private signals (receipt data)
|
||||
signal input data[4];
|
||||
|
||||
// Hash component
|
||||
component poseidon = Poseidon(4);
|
||||
|
||||
// Connect inputs
|
||||
for (var i = 0; i < 4; i++) {
|
||||
poseidon.inputs[i] <== data[i];
|
||||
}
|
||||
|
||||
// Constraint: computed hash must match public hash
|
||||
poseidon.out === hash;
|
||||
}
|
||||
|
||||
/*
|
||||
* ECDSA Signature Verification Component
|
||||
*
|
||||
* Verifies that a receipt was signed by the coordinator
|
||||
*/
|
||||
template ECDSAVerify() {
|
||||
// Public inputs
|
||||
signal input publicKey[2];
|
||||
signal input messageHash;
|
||||
signal input signature[2];
|
||||
|
||||
// Private inputs
|
||||
signal input r;
|
||||
signal input s;
|
||||
|
||||
// Note: Full ECDSA verification in circom is complex
|
||||
// This is a placeholder for the actual implementation
|
||||
// In practice, we'd use a more efficient approach like:
|
||||
// - EDDSA verification (simpler in circom)
|
||||
// - Or move signature verification off-chain
|
||||
|
||||
// Placeholder constraint
|
||||
signature[0] * signature[1] === r * s;
|
||||
}
|
||||
|
||||
/*
|
||||
* Main circuit for initial implementation
|
||||
*/
|
||||
component main = ReceiptHashPreimage();
|
||||
130
apps/coordinator-api/src/app/zk-circuits/receipt_simple.circom
Normal file
130
apps/coordinator-api/src/app/zk-circuits/receipt_simple.circom
Normal file
@@ -0,0 +1,130 @@
|
||||
pragma circom 2.0.0;
|
||||
|
||||
include "node_modules/circomlib/circuits/bitify.circom";
|
||||
include "node_modules/circomlib/circuits/poseidon.circom";
|
||||
|
||||
/*
|
||||
* Simple Receipt Attestation Circuit
|
||||
*
|
||||
* This circuit proves that a receipt is valid without revealing sensitive details.
|
||||
*
|
||||
* Public Inputs:
|
||||
* - receiptHash: Hash of the receipt (for public verification)
|
||||
*
|
||||
* Private Inputs:
|
||||
* - receipt: The full receipt data (private)
|
||||
*/
|
||||
|
||||
template SimpleReceipt() {
|
||||
// Public signal
|
||||
signal input receiptHash;
|
||||
|
||||
// Private signals
|
||||
signal input receipt[4];
|
||||
|
||||
// Component for hashing
|
||||
component hasher = Poseidon(4);
|
||||
|
||||
// Connect private inputs to hasher
|
||||
for (var i = 0; i < 4; i++) {
|
||||
hasher.inputs[i] <== receipt[i];
|
||||
}
|
||||
|
||||
// Ensure the computed hash matches the public hash
|
||||
hasher.out === receiptHash;
|
||||
}
|
||||
|
||||
/*
|
||||
* Membership Proof Circuit
|
||||
*
|
||||
* Proves that a value is part of a set without revealing which one
|
||||
*/
|
||||
|
||||
template MembershipProof(n) {
|
||||
// Public signals
|
||||
signal input root;
|
||||
signal input nullifier;
|
||||
signal input pathIndices[n];
|
||||
|
||||
// Private signals
|
||||
signal input leaf;
|
||||
signal input pathElements[n];
|
||||
signal input salt;
|
||||
|
||||
// Component for hashing
|
||||
component hasher[n];
|
||||
|
||||
// Initialize hasher for the leaf
|
||||
hasher[0] = Poseidon(2);
|
||||
hasher[0].inputs[0] <== leaf;
|
||||
hasher[0].inputs[1] <== salt;
|
||||
|
||||
// Hash up the Merkle tree
|
||||
for (var i = 0; i < n - 1; i++) {
|
||||
hasher[i + 1] = Poseidon(2);
|
||||
|
||||
// Choose left or right based on path index
|
||||
hasher[i + 1].inputs[0] <== pathIndices[i] * pathElements[i] + (1 - pathIndices[i]) * hasher[i].out;
|
||||
hasher[i + 1].inputs[1] <== pathIndices[i] * hasher[i].out + (1 - pathIndices[i]) * pathElements[i];
|
||||
}
|
||||
|
||||
// Ensure final hash equals root
|
||||
hasher[n - 1].out === root;
|
||||
|
||||
// Compute nullifier as hash(leaf, salt)
|
||||
component nullifierHasher = Poseidon(2);
|
||||
nullifierHasher.inputs[0] <== leaf;
|
||||
nullifierHasher.inputs[1] <== salt;
|
||||
nullifierHasher.out === nullifier;
|
||||
}
|
||||
|
||||
/*
|
||||
* Bid Range Proof Circuit
|
||||
*
|
||||
* Proves that a bid is within a valid range without revealing the amount
|
||||
*/
|
||||
|
||||
template BidRangeProof() {
|
||||
// Public signals
|
||||
signal input commitment;
|
||||
signal input minAmount;
|
||||
signal input maxAmount;
|
||||
|
||||
// Private signals
|
||||
signal input bid;
|
||||
signal input salt;
|
||||
|
||||
// Component for hashing commitment
|
||||
component commitmentHasher = Poseidon(2);
|
||||
commitmentHasher.inputs[0] <== bid;
|
||||
commitmentHasher.inputs[1] <== salt;
|
||||
commitmentHasher.out === commitment;
|
||||
|
||||
// Components for range checking
|
||||
component minChecker = GreaterEqThan(8);
|
||||
component maxChecker = GreaterEqThan(8);
|
||||
|
||||
// Convert amounts to 8-bit representation
|
||||
component bidBits = Num2Bits(64);
|
||||
component minBits = Num2Bits(64);
|
||||
component maxBits = Num2Bits(64);
|
||||
|
||||
bidBits.in <== bid;
|
||||
minBits.in <== minAmount;
|
||||
maxBits.in <== maxAmount;
|
||||
|
||||
// Check bid >= minAmount
|
||||
for (var i = 0; i < 64; i++) {
|
||||
minChecker.in[i] <== bidBits.out[i] - minBits.out[i];
|
||||
}
|
||||
minChecker.out === 1;
|
||||
|
||||
// Check maxAmount >= bid
|
||||
for (var i = 0; i < 64; i++) {
|
||||
maxChecker.in[i] <== maxBits.out[i] - bidBits.out[i];
|
||||
}
|
||||
maxChecker.out === 1;
|
||||
}
|
||||
|
||||
// Main component instantiation
|
||||
component main = SimpleReceipt();
|
||||
BIN
apps/coordinator-api/src/app/zk-circuits/receipt_simple.r1cs
Normal file
BIN
apps/coordinator-api/src/app/zk-circuits/receipt_simple.r1cs
Normal file
Binary file not shown.
1172
apps/coordinator-api/src/app/zk-circuits/receipt_simple.sym
Normal file
1172
apps/coordinator-api/src/app/zk-circuits/receipt_simple.sym
Normal file
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,21 @@
|
||||
const wc = require("./witness_calculator.js");
|
||||
const { readFileSync, writeFile } = require("fs");
|
||||
|
||||
if (process.argv.length != 5) {
|
||||
console.log("Usage: node generate_witness.js <file.wasm> <input.json> <output.wtns>");
|
||||
} else {
|
||||
const input = JSON.parse(readFileSync(process.argv[3], "utf8"));
|
||||
|
||||
const buffer = readFileSync(process.argv[2]);
|
||||
wc(buffer).then(async witnessCalculator => {
|
||||
/*
|
||||
const w= await witnessCalculator.calculateWitness(input,0);
|
||||
for (let i=0; i< w.length; i++){
|
||||
console.log(w[i]);
|
||||
}*/
|
||||
const buff= await witnessCalculator.calculateWTNSBin(input,0);
|
||||
writeFile(process.argv[4], buff, function(err) {
|
||||
if (err) throw err;
|
||||
});
|
||||
});
|
||||
}
|
||||
Binary file not shown.
@@ -0,0 +1,381 @@
|
||||
module.exports = async function builder(code, options) {
|
||||
|
||||
options = options || {};
|
||||
|
||||
let wasmModule;
|
||||
try {
|
||||
wasmModule = await WebAssembly.compile(code);
|
||||
} catch (err) {
|
||||
console.log(err);
|
||||
console.log("\nTry to run circom --c in order to generate c++ code instead\n");
|
||||
throw new Error(err);
|
||||
}
|
||||
|
||||
let wc;
|
||||
|
||||
let errStr = "";
|
||||
let msgStr = "";
|
||||
|
||||
const instance = await WebAssembly.instantiate(wasmModule, {
|
||||
runtime: {
|
||||
exceptionHandler : function(code) {
|
||||
let err;
|
||||
if (code == 1) {
|
||||
err = "Signal not found.\n";
|
||||
} else if (code == 2) {
|
||||
err = "Too many signals set.\n";
|
||||
} else if (code == 3) {
|
||||
err = "Signal already set.\n";
|
||||
} else if (code == 4) {
|
||||
err = "Assert Failed.\n";
|
||||
} else if (code == 5) {
|
||||
err = "Not enough memory.\n";
|
||||
} else if (code == 6) {
|
||||
err = "Input signal array access exceeds the size.\n";
|
||||
} else {
|
||||
err = "Unknown error.\n";
|
||||
}
|
||||
throw new Error(err + errStr);
|
||||
},
|
||||
printErrorMessage : function() {
|
||||
errStr += getMessage() + "\n";
|
||||
// console.error(getMessage());
|
||||
},
|
||||
writeBufferMessage : function() {
|
||||
const msg = getMessage();
|
||||
// Any calls to `log()` will always end with a `\n`, so that's when we print and reset
|
||||
if (msg === "\n") {
|
||||
console.log(msgStr);
|
||||
msgStr = "";
|
||||
} else {
|
||||
// If we've buffered other content, put a space in between the items
|
||||
if (msgStr !== "") {
|
||||
msgStr += " "
|
||||
}
|
||||
// Then append the message to the message we are creating
|
||||
msgStr += msg;
|
||||
}
|
||||
},
|
||||
showSharedRWMemory : function() {
|
||||
printSharedRWMemory ();
|
||||
}
|
||||
|
||||
}
|
||||
});
|
||||
|
||||
const sanityCheck =
|
||||
options
|
||||
// options &&
|
||||
// (
|
||||
// options.sanityCheck ||
|
||||
// options.logGetSignal ||
|
||||
// options.logSetSignal ||
|
||||
// options.logStartComponent ||
|
||||
// options.logFinishComponent
|
||||
// );
|
||||
|
||||
|
||||
wc = new WitnessCalculator(instance, sanityCheck);
|
||||
return wc;
|
||||
|
||||
function getMessage() {
|
||||
var message = "";
|
||||
var c = instance.exports.getMessageChar();
|
||||
while ( c != 0 ) {
|
||||
message += String.fromCharCode(c);
|
||||
c = instance.exports.getMessageChar();
|
||||
}
|
||||
return message;
|
||||
}
|
||||
|
||||
function printSharedRWMemory () {
|
||||
const shared_rw_memory_size = instance.exports.getFieldNumLen32();
|
||||
const arr = new Uint32Array(shared_rw_memory_size);
|
||||
for (let j=0; j<shared_rw_memory_size; j++) {
|
||||
arr[shared_rw_memory_size-1-j] = instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
|
||||
// If we've buffered other content, put a space in between the items
|
||||
if (msgStr !== "") {
|
||||
msgStr += " "
|
||||
}
|
||||
// Then append the value to the message we are creating
|
||||
msgStr += (fromArray32(arr).toString());
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
class WitnessCalculator {
|
||||
constructor(instance, sanityCheck) {
|
||||
this.instance = instance;
|
||||
|
||||
this.version = this.instance.exports.getVersion();
|
||||
this.n32 = this.instance.exports.getFieldNumLen32();
|
||||
|
||||
this.instance.exports.getRawPrime();
|
||||
const arr = new Uint32Array(this.n32);
|
||||
for (let i=0; i<this.n32; i++) {
|
||||
arr[this.n32-1-i] = this.instance.exports.readSharedRWMemory(i);
|
||||
}
|
||||
this.prime = fromArray32(arr);
|
||||
|
||||
this.witnessSize = this.instance.exports.getWitnessSize();
|
||||
|
||||
this.sanityCheck = sanityCheck;
|
||||
}
|
||||
|
||||
circom_version() {
|
||||
return this.instance.exports.getVersion();
|
||||
}
|
||||
|
||||
async _doCalculateWitness(input_orig, sanityCheck) {
|
||||
//input is assumed to be a map from signals to arrays of bigints
|
||||
this.instance.exports.init((this.sanityCheck || sanityCheck) ? 1 : 0);
|
||||
let prefix = "";
|
||||
var input = new Object();
|
||||
//console.log("Input: ", input_orig);
|
||||
qualify_input(prefix,input_orig,input);
|
||||
//console.log("Input after: ",input);
|
||||
const keys = Object.keys(input);
|
||||
var input_counter = 0;
|
||||
keys.forEach( (k) => {
|
||||
const h = fnvHash(k);
|
||||
const hMSB = parseInt(h.slice(0,8), 16);
|
||||
const hLSB = parseInt(h.slice(8,16), 16);
|
||||
const fArr = flatArray(input[k]);
|
||||
let signalSize = this.instance.exports.getInputSignalSize(hMSB, hLSB);
|
||||
if (signalSize < 0){
|
||||
throw new Error(`Signal ${k} not found\n`);
|
||||
}
|
||||
if (fArr.length < signalSize) {
|
||||
throw new Error(`Not enough values for input signal ${k}\n`);
|
||||
}
|
||||
if (fArr.length > signalSize) {
|
||||
throw new Error(`Too many values for input signal ${k}\n`);
|
||||
}
|
||||
for (let i=0; i<fArr.length; i++) {
|
||||
const arrFr = toArray32(normalize(fArr[i],this.prime),this.n32)
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
this.instance.exports.writeSharedRWMemory(j,arrFr[this.n32-1-j]);
|
||||
}
|
||||
try {
|
||||
this.instance.exports.setInputSignal(hMSB, hLSB,i);
|
||||
input_counter++;
|
||||
} catch (err) {
|
||||
// console.log(`After adding signal ${i} of ${k}`)
|
||||
throw new Error(err);
|
||||
}
|
||||
}
|
||||
|
||||
});
|
||||
if (input_counter < this.instance.exports.getInputSize()) {
|
||||
throw new Error(`Not all inputs have been set. Only ${input_counter} out of ${this.instance.exports.getInputSize()}`);
|
||||
}
|
||||
}
|
||||
|
||||
async calculateWitness(input, sanityCheck) {
|
||||
|
||||
const w = [];
|
||||
await this._doCalculateWitness(input, sanityCheck);
|
||||
|
||||
for (let i=0; i<this.witnessSize; i++) {
|
||||
this.instance.exports.getWitness(i);
|
||||
const arr = new Uint32Array(this.n32);
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
arr[this.n32-1-j] = this.instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
w.push(fromArray32(arr));
|
||||
}
|
||||
|
||||
return w;
|
||||
}
|
||||
|
||||
|
||||
async calculateBinWitness(input, sanityCheck) {
|
||||
|
||||
const buff32 = new Uint32Array(this.witnessSize*this.n32);
|
||||
const buff = new Uint8Array( buff32.buffer);
|
||||
await this._doCalculateWitness(input, sanityCheck);
|
||||
|
||||
for (let i=0; i<this.witnessSize; i++) {
|
||||
this.instance.exports.getWitness(i);
|
||||
const pos = i*this.n32;
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
buff32[pos+j] = this.instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
}
|
||||
|
||||
return buff;
|
||||
}
|
||||
|
||||
|
||||
async calculateWTNSBin(input, sanityCheck) {
|
||||
|
||||
const buff32 = new Uint32Array(this.witnessSize*this.n32+this.n32+11);
|
||||
const buff = new Uint8Array( buff32.buffer);
|
||||
await this._doCalculateWitness(input, sanityCheck);
|
||||
|
||||
//"wtns"
|
||||
buff[0] = "w".charCodeAt(0)
|
||||
buff[1] = "t".charCodeAt(0)
|
||||
buff[2] = "n".charCodeAt(0)
|
||||
buff[3] = "s".charCodeAt(0)
|
||||
|
||||
//version 2
|
||||
buff32[1] = 2;
|
||||
|
||||
//number of sections: 2
|
||||
buff32[2] = 2;
|
||||
|
||||
//id section 1
|
||||
buff32[3] = 1;
|
||||
|
||||
const n8 = this.n32*4;
|
||||
//id section 1 length in 64bytes
|
||||
const idSection1length = 8 + n8;
|
||||
const idSection1lengthHex = idSection1length.toString(16);
|
||||
buff32[4] = parseInt(idSection1lengthHex.slice(0,8), 16);
|
||||
buff32[5] = parseInt(idSection1lengthHex.slice(8,16), 16);
|
||||
|
||||
//this.n32
|
||||
buff32[6] = n8;
|
||||
|
||||
//prime number
|
||||
this.instance.exports.getRawPrime();
|
||||
|
||||
var pos = 7;
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
buff32[pos+j] = this.instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
pos += this.n32;
|
||||
|
||||
// witness size
|
||||
buff32[pos] = this.witnessSize;
|
||||
pos++;
|
||||
|
||||
//id section 2
|
||||
buff32[pos] = 2;
|
||||
pos++;
|
||||
|
||||
// section 2 length
|
||||
const idSection2length = n8*this.witnessSize;
|
||||
const idSection2lengthHex = idSection2length.toString(16);
|
||||
buff32[pos] = parseInt(idSection2lengthHex.slice(0,8), 16);
|
||||
buff32[pos+1] = parseInt(idSection2lengthHex.slice(8,16), 16);
|
||||
|
||||
pos += 2;
|
||||
for (let i=0; i<this.witnessSize; i++) {
|
||||
this.instance.exports.getWitness(i);
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
buff32[pos+j] = this.instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
pos += this.n32;
|
||||
}
|
||||
|
||||
return buff;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
function qualify_input_list(prefix,input,input1){
|
||||
if (Array.isArray(input)) {
|
||||
for (let i = 0; i<input.length; i++) {
|
||||
let new_prefix = prefix + "[" + i + "]";
|
||||
qualify_input_list(new_prefix,input[i],input1);
|
||||
}
|
||||
} else {
|
||||
qualify_input(prefix,input,input1);
|
||||
}
|
||||
}
|
||||
|
||||
function qualify_input(prefix,input,input1) {
|
||||
if (Array.isArray(input)) {
|
||||
a = flatArray(input);
|
||||
if (a.length > 0) {
|
||||
let t = typeof a[0];
|
||||
for (let i = 1; i<a.length; i++) {
|
||||
if (typeof a[i] != t){
|
||||
throw new Error(`Types are not the same in the key ${prefix}`);
|
||||
}
|
||||
}
|
||||
if (t == "object") {
|
||||
qualify_input_list(prefix,input,input1);
|
||||
} else {
|
||||
input1[prefix] = input;
|
||||
}
|
||||
} else {
|
||||
input1[prefix] = input;
|
||||
}
|
||||
} else if (typeof input == "object") {
|
||||
const keys = Object.keys(input);
|
||||
keys.forEach( (k) => {
|
||||
let new_prefix = prefix == ""? k : prefix + "." + k;
|
||||
qualify_input(new_prefix,input[k],input1);
|
||||
});
|
||||
} else {
|
||||
input1[prefix] = input;
|
||||
}
|
||||
}
|
||||
|
||||
function toArray32(rem,size) {
|
||||
const res = []; //new Uint32Array(size); //has no unshift
|
||||
const radix = BigInt(0x100000000);
|
||||
while (rem) {
|
||||
res.unshift( Number(rem % radix));
|
||||
rem = rem / radix;
|
||||
}
|
||||
if (size) {
|
||||
var i = size - res.length;
|
||||
while (i>0) {
|
||||
res.unshift(0);
|
||||
i--;
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
function fromArray32(arr) { //returns a BigInt
|
||||
var res = BigInt(0);
|
||||
const radix = BigInt(0x100000000);
|
||||
for (let i = 0; i<arr.length; i++) {
|
||||
res = res*radix + BigInt(arr[i]);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
function flatArray(a) {
|
||||
var res = [];
|
||||
fillArray(res, a);
|
||||
return res;
|
||||
|
||||
function fillArray(res, a) {
|
||||
if (Array.isArray(a)) {
|
||||
for (let i=0; i<a.length; i++) {
|
||||
fillArray(res, a[i]);
|
||||
}
|
||||
} else {
|
||||
res.push(a);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function normalize(n, prime) {
|
||||
let res = BigInt(n) % prime
|
||||
if (res < 0) res += prime
|
||||
return res
|
||||
}
|
||||
|
||||
function fnvHash(str) {
|
||||
const uint64_max = BigInt(2) ** BigInt(64);
|
||||
let hash = BigInt("0xCBF29CE484222325");
|
||||
for (var i = 0; i < str.length; i++) {
|
||||
hash ^= BigInt(str[i].charCodeAt());
|
||||
hash *= BigInt(0x100000001B3);
|
||||
hash %= uint64_max;
|
||||
}
|
||||
let shash = hash.toString(16);
|
||||
let n = 16 - shash.length;
|
||||
shash = '0'.repeat(n).concat(shash);
|
||||
return shash;
|
||||
}
|
||||
Reference in New Issue
Block a user