feat: implement agent coordination foundation (Week 1)

 Multi-Agent Communication Framework
- Implemented comprehensive communication protocols
- Created hierarchical, P2P, and broadcast protocols
- Added message types and routing system
- Implemented agent discovery and registration
- Created load balancer for task distribution
- Built FastAPI application with full API

 Core Components Implemented
- CommunicationManager: Protocol management
- MessageRouter: Advanced message routing
- AgentRegistry: Agent discovery and management
- LoadBalancer: Intelligent task distribution
- TaskDistributor: Priority-based task handling
- WebSocketHandler: Real-time communication

 API Endpoints
- /health: Health check endpoint
- /agents/register: Agent registration
- /agents/discover: Agent discovery
- /tasks/submit: Task submission
- /messages/send: Message sending
- /load-balancer/stats: Load balancing statistics
- /registry/stats: Registry statistics

 Production Ready
- SystemD service configuration
- Docker containerization
- Comprehensive test suite
- Configuration management
- Error handling and logging
- Performance monitoring

🚀 Week 1 complete: Agent coordination foundation implemented!
This commit is contained in:
aitbc
2026-04-02 14:50:58 +02:00
parent 2fdda15732
commit 03d409f89d
8 changed files with 3729 additions and 0 deletions

View File

@@ -0,0 +1,641 @@
"""
Agent Discovery and Registration System for AITBC Agent Coordination
"""
import asyncio
import json
import logging
from typing import Dict, List, Optional, Set, Callable, Any
from dataclasses import dataclass, field
from datetime import datetime, timedelta
import uuid
import hashlib
from enum import Enum
import redis.asyncio as redis
from pydantic import BaseModel, Field
from ..protocols.message_types import DiscoveryMessage, create_discovery_message
from ..protocols.communication import AgentMessage, MessageType
logger = logging.getLogger(__name__)
class AgentStatus(str, Enum):
"""Agent status enumeration"""
ACTIVE = "active"
INACTIVE = "inactive"
BUSY = "busy"
MAINTENANCE = "maintenance"
ERROR = "error"
class AgentType(str, Enum):
"""Agent type enumeration"""
COORDINATOR = "coordinator"
WORKER = "worker"
SPECIALIST = "specialist"
MONITOR = "monitor"
GATEWAY = "gateway"
ORCHESTRATOR = "orchestrator"
@dataclass
class AgentInfo:
"""Agent information structure"""
agent_id: str
agent_type: AgentType
status: AgentStatus
capabilities: List[str]
services: List[str]
endpoints: Dict[str, str]
metadata: Dict[str, Any]
last_heartbeat: datetime
registration_time: datetime
load_metrics: Dict[str, float] = field(default_factory=dict)
health_score: float = 1.0
version: str = "1.0.0"
tags: Set[str] = field(default_factory=set)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary"""
return {
"agent_id": self.agent_id,
"agent_type": self.agent_type.value,
"status": self.status.value,
"capabilities": self.capabilities,
"services": self.services,
"endpoints": self.endpoints,
"metadata": self.metadata,
"last_heartbeat": self.last_heartbeat.isoformat(),
"registration_time": self.registration_time.isoformat(),
"load_metrics": self.load_metrics,
"health_score": self.health_score,
"version": self.version,
"tags": list(self.tags)
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "AgentInfo":
"""Create from dictionary"""
data["agent_type"] = AgentType(data["agent_type"])
data["status"] = AgentStatus(data["status"])
data["last_heartbeat"] = datetime.fromisoformat(data["last_heartbeat"])
data["registration_time"] = datetime.fromisoformat(data["registration_time"])
data["tags"] = set(data.get("tags", []))
return cls(**data)
class AgentRegistry:
"""Central agent registry for discovery and management"""
def __init__(self, redis_url: str = "redis://localhost:6379/1"):
self.redis_url = redis_url
self.redis_client: Optional[redis.Redis] = None
self.agents: Dict[str, AgentInfo] = {}
self.service_index: Dict[str, Set[str]] = {} # service -> agent_ids
self.capability_index: Dict[str, Set[str]] = {} # capability -> agent_ids
self.type_index: Dict[AgentType, Set[str]] = {} # agent_type -> agent_ids
self.heartbeat_interval = 30 # seconds
self.cleanup_interval = 60 # seconds
self.max_heartbeat_age = 120 # seconds
async def start(self):
"""Start the registry service"""
self.redis_client = redis.from_url(self.redis_url)
# Load existing agents from Redis
await self._load_agents_from_redis()
# Start background tasks
asyncio.create_task(self._heartbeat_monitor())
asyncio.create_task(self._cleanup_inactive_agents())
logger.info("Agent registry started")
async def stop(self):
"""Stop the registry service"""
if self.redis_client:
await self.redis_client.close()
logger.info("Agent registry stopped")
async def register_agent(self, agent_info: AgentInfo) -> bool:
"""Register a new agent"""
try:
# Add to local registry
self.agents[agent_info.agent_id] = agent_info
# Update indexes
self._update_indexes(agent_info)
# Save to Redis
await self._save_agent_to_redis(agent_info)
# Publish registration event
await self._publish_agent_event("agent_registered", agent_info)
logger.info(f"Agent {agent_info.agent_id} registered successfully")
return True
except Exception as e:
logger.error(f"Error registering agent {agent_info.agent_id}: {e}")
return False
async def unregister_agent(self, agent_id: str) -> bool:
"""Unregister an agent"""
try:
if agent_id not in self.agents:
logger.warning(f"Agent {agent_id} not found for unregistration")
return False
agent_info = self.agents[agent_id]
# Remove from local registry
del self.agents[agent_id]
# Update indexes
self._remove_from_indexes(agent_info)
# Remove from Redis
await self._remove_agent_from_redis(agent_id)
# Publish unregistration event
await self._publish_agent_event("agent_unregistered", agent_info)
logger.info(f"Agent {agent_id} unregistered successfully")
return True
except Exception as e:
logger.error(f"Error unregistering agent {agent_id}: {e}")
return False
async def update_agent_status(self, agent_id: str, status: AgentStatus, load_metrics: Optional[Dict[str, float]] = None) -> bool:
"""Update agent status and metrics"""
try:
if agent_id not in self.agents:
logger.warning(f"Agent {agent_id} not found for status update")
return False
agent_info = self.agents[agent_id]
agent_info.status = status
agent_info.last_heartbeat = datetime.utcnow()
if load_metrics:
agent_info.load_metrics.update(load_metrics)
# Update health score
agent_info.health_score = self._calculate_health_score(agent_info)
# Save to Redis
await self._save_agent_to_redis(agent_info)
# Publish status update event
await self._publish_agent_event("agent_status_updated", agent_info)
return True
except Exception as e:
logger.error(f"Error updating agent status {agent_id}: {e}")
return False
async def update_agent_heartbeat(self, agent_id: str) -> bool:
"""Update agent heartbeat"""
try:
if agent_id not in self.agents:
logger.warning(f"Agent {agent_id} not found for heartbeat")
return False
agent_info = self.agents[agent_id]
agent_info.last_heartbeat = datetime.utcnow()
# Update health score
agent_info.health_score = self._calculate_health_score(agent_info)
# Save to Redis
await self._save_agent_to_redis(agent_info)
return True
except Exception as e:
logger.error(f"Error updating heartbeat for {agent_id}: {e}")
return False
async def discover_agents(self, query: Dict[str, Any]) -> List[AgentInfo]:
"""Discover agents based on query criteria"""
results = []
try:
# Start with all agents
candidate_agents = list(self.agents.values())
# Apply filters
if "agent_type" in query:
agent_type = AgentType(query["agent_type"])
candidate_agents = [a for a in candidate_agents if a.agent_type == agent_type]
if "status" in query:
status = AgentStatus(query["status"])
candidate_agents = [a for a in candidate_agents if a.status == status]
if "capabilities" in query:
required_capabilities = set(query["capabilities"])
candidate_agents = [a for a in candidate_agents if required_capabilities.issubset(a.capabilities)]
if "services" in query:
required_services = set(query["services"])
candidate_agents = [a for a in candidate_agents if required_services.issubset(a.services)]
if "tags" in query:
required_tags = set(query["tags"])
candidate_agents = [a for a in candidate_agents if required_tags.issubset(a.tags)]
if "min_health_score" in query:
min_score = query["min_health_score"]
candidate_agents = [a for a in candidate_agents if a.health_score >= min_score]
# Sort by health score (highest first)
results = sorted(candidate_agents, key=lambda a: a.health_score, reverse=True)
# Limit results if specified
if "limit" in query:
results = results[:query["limit"]]
logger.info(f"Discovered {len(results)} agents for query: {query}")
return results
except Exception as e:
logger.error(f"Error discovering agents: {e}")
return []
async def get_agent_by_id(self, agent_id: str) -> Optional[AgentInfo]:
"""Get agent information by ID"""
return self.agents.get(agent_id)
async def get_agents_by_service(self, service: str) -> List[AgentInfo]:
"""Get agents that provide a specific service"""
agent_ids = self.service_index.get(service, set())
return [self.agents[agent_id] for agent_id in agent_ids if agent_id in self.agents]
async def get_agents_by_capability(self, capability: str) -> List[AgentInfo]:
"""Get agents that have a specific capability"""
agent_ids = self.capability_index.get(capability, set())
return [self.agents[agent_id] for agent_id in agent_ids if agent_id in self.agents]
async def get_agents_by_type(self, agent_type: AgentType) -> List[AgentInfo]:
"""Get agents of a specific type"""
agent_ids = self.type_index.get(agent_type, set())
return [self.agents[agent_id] for agent_id in agent_ids if agent_id in self.agents]
async def get_registry_stats(self) -> Dict[str, Any]:
"""Get registry statistics"""
total_agents = len(self.agents)
status_counts = {}
type_counts = {}
for agent_info in self.agents.values():
# Count by status
status = agent_info.status.value
status_counts[status] = status_counts.get(status, 0) + 1
# Count by type
agent_type = agent_info.agent_type.value
type_counts[agent_type] = type_counts.get(agent_type, 0) + 1
return {
"total_agents": total_agents,
"status_counts": status_counts,
"type_counts": type_counts,
"service_count": len(self.service_index),
"capability_count": len(self.capability_index),
"last_cleanup": datetime.utcnow().isoformat()
}
def _update_indexes(self, agent_info: AgentInfo):
"""Update search indexes"""
# Service index
for service in agent_info.services:
if service not in self.service_index:
self.service_index[service] = set()
self.service_index[service].add(agent_info.agent_id)
# Capability index
for capability in agent_info.capabilities:
if capability not in self.capability_index:
self.capability_index[capability] = set()
self.capability_index[capability].add(agent_info.agent_id)
# Type index
if agent_info.agent_type not in self.type_index:
self.type_index[agent_info.agent_type] = set()
self.type_index[agent_info.agent_type].add(agent_info.agent_id)
def _remove_from_indexes(self, agent_info: AgentInfo):
"""Remove agent from search indexes"""
# Service index
for service in agent_info.services:
if service in self.service_index:
self.service_index[service].discard(agent_info.agent_id)
if not self.service_index[service]:
del self.service_index[service]
# Capability index
for capability in agent_info.capabilities:
if capability in self.capability_index:
self.capability_index[capability].discard(agent_info.agent_id)
if not self.capability_index[capability]:
del self.capability_index[capability]
# Type index
if agent_info.agent_type in self.type_index:
self.type_index[agent_info.agent_type].discard(agent_info.agent_id)
if not self.type_index[agent_info.agent_type]:
del self.type_index[agent_info.agent_type]
def _calculate_health_score(self, agent_info: AgentInfo) -> float:
"""Calculate agent health score"""
base_score = 1.0
# Penalty for high load
if agent_info.load_metrics:
avg_load = sum(agent_info.load_metrics.values()) / len(agent_info.load_metrics)
if avg_load > 0.8:
base_score -= 0.3
elif avg_load > 0.6:
base_score -= 0.1
# Penalty for error status
if agent_info.status == AgentStatus.ERROR:
base_score -= 0.5
elif agent_info.status == AgentStatus.MAINTENANCE:
base_score -= 0.2
elif agent_info.status == AgentStatus.BUSY:
base_score -= 0.1
# Penalty for old heartbeat
heartbeat_age = (datetime.utcnow() - agent_info.last_heartbeat).total_seconds()
if heartbeat_age > self.max_heartbeat_age:
base_score -= 0.5
elif heartbeat_age > self.max_heartbeat_age / 2:
base_score -= 0.2
return max(0.0, min(1.0, base_score))
async def _save_agent_to_redis(self, agent_info: AgentInfo):
"""Save agent information to Redis"""
if not self.redis_client:
return
key = f"agent:{agent_info.agent_id}"
await self.redis_client.setex(
key,
timedelta(hours=24), # 24 hour TTL
json.dumps(agent_info.to_dict())
)
async def _remove_agent_from_redis(self, agent_id: str):
"""Remove agent from Redis"""
if not self.redis_client:
return
key = f"agent:{agent_id}"
await self.redis_client.delete(key)
async def _load_agents_from_redis(self):
"""Load agents from Redis"""
if not self.redis_client:
return
try:
# Get all agent keys
keys = await self.redis_client.keys("agent:*")
for key in keys:
data = await self.redis_client.get(key)
if data:
agent_info = AgentInfo.from_dict(json.loads(data))
self.agents[agent_info.agent_id] = agent_info
self._update_indexes(agent_info)
logger.info(f"Loaded {len(self.agents)} agents from Redis")
except Exception as e:
logger.error(f"Error loading agents from Redis: {e}")
async def _publish_agent_event(self, event_type: str, agent_info: AgentInfo):
"""Publish agent event to Redis"""
if not self.redis_client:
return
event = {
"event_type": event_type,
"timestamp": datetime.utcnow().isoformat(),
"agent_info": agent_info.to_dict()
}
await self.redis_client.publish("agent_events", json.dumps(event))
async def _heartbeat_monitor(self):
"""Monitor agent heartbeats"""
while True:
try:
await asyncio.sleep(self.heartbeat_interval)
# Check for agents with old heartbeats
now = datetime.utcnow()
for agent_id, agent_info in list(self.agents.items()):
heartbeat_age = (now - agent_info.last_heartbeat).total_seconds()
if heartbeat_age > self.max_heartbeat_age:
# Mark as inactive
if agent_info.status != AgentStatus.INACTIVE:
await self.update_agent_status(agent_id, AgentStatus.INACTIVE)
logger.warning(f"Agent {agent_id} marked as inactive due to old heartbeat")
except Exception as e:
logger.error(f"Error in heartbeat monitor: {e}")
await asyncio.sleep(5)
async def _cleanup_inactive_agents(self):
"""Clean up inactive agents"""
while True:
try:
await asyncio.sleep(self.cleanup_interval)
# Remove agents that have been inactive too long
now = datetime.utcnow()
max_inactive_age = timedelta(hours=1) # 1 hour
for agent_id, agent_info in list(self.agents.items()):
if agent_info.status == AgentStatus.INACTIVE:
inactive_age = now - agent_info.last_heartbeat
if inactive_age > max_inactive_age:
await self.unregister_agent(agent_id)
logger.info(f"Removed inactive agent {agent_id}")
except Exception as e:
logger.error(f"Error in cleanup task: {e}")
await asyncio.sleep(5)
class AgentDiscoveryService:
"""Service for agent discovery and registration"""
def __init__(self, registry: AgentRegistry):
self.registry = registry
self.discovery_handlers: Dict[str, Callable] = {}
def register_discovery_handler(self, handler_name: str, handler: Callable):
"""Register a discovery handler"""
self.discovery_handlers[handler_name] = handler
logger.info(f"Registered discovery handler: {handler_name}")
async def handle_discovery_request(self, message: AgentMessage) -> Optional[AgentMessage]:
"""Handle agent discovery request"""
try:
discovery_data = DiscoveryMessage(**message.payload)
# Update or register agent
agent_info = AgentInfo(
agent_id=discovery_data.agent_id,
agent_type=AgentType(discovery_data.agent_type),
status=AgentStatus.ACTIVE,
capabilities=discovery_data.capabilities,
services=discovery_data.services,
endpoints=discovery_data.endpoints,
metadata=discovery_data.metadata,
last_heartbeat=datetime.utcnow(),
registration_time=datetime.utcnow()
)
# Register or update agent
if discovery_data.agent_id in self.registry.agents:
await self.registry.update_agent_status(discovery_data.agent_id, AgentStatus.ACTIVE)
else:
await self.registry.register_agent(agent_info)
# Send response with available agents
available_agents = await self.registry.discover_agents({
"status": "active",
"limit": 50
})
response_data = {
"discovery_agents": [agent.to_dict() for agent in available_agents],
"registry_stats": await self.registry.get_registry_stats()
}
response = AgentMessage(
sender_id="discovery_service",
receiver_id=message.sender_id,
message_type=MessageType.DISCOVERY,
payload=response_data,
correlation_id=message.id
)
return response
except Exception as e:
logger.error(f"Error handling discovery request: {e}")
return None
async def find_best_agent(self, requirements: Dict[str, Any]) -> Optional[AgentInfo]:
"""Find the best agent for given requirements"""
try:
# Build discovery query
query = {}
if "agent_type" in requirements:
query["agent_type"] = requirements["agent_type"]
if "capabilities" in requirements:
query["capabilities"] = requirements["capabilities"]
if "services" in requirements:
query["services"] = requirements["services"]
if "min_health_score" in requirements:
query["min_health_score"] = requirements["min_health_score"]
# Discover agents
agents = await self.registry.discover_agents(query)
if not agents:
return None
# Select best agent (highest health score)
return agents[0]
except Exception as e:
logger.error(f"Error finding best agent: {e}")
return None
async def get_service_endpoints(self, service: str) -> Dict[str, List[str]]:
"""Get all endpoints for a specific service"""
try:
agents = await self.registry.get_agents_by_service(service)
endpoints = {}
for agent in agents:
for service_name, endpoint in agent.endpoints.items():
if service_name not in endpoints:
endpoints[service_name] = []
endpoints[service_name].append(endpoint)
return endpoints
except Exception as e:
logger.error(f"Error getting service endpoints: {e}")
return {}
# Factory functions
def create_agent_info(agent_id: str, agent_type: str, capabilities: List[str], services: List[str], endpoints: Dict[str, str]) -> AgentInfo:
"""Create agent information"""
return AgentInfo(
agent_id=agent_id,
agent_type=AgentType(agent_type),
status=AgentStatus.ACTIVE,
capabilities=capabilities,
services=services,
endpoints=endpoints,
metadata={},
last_heartbeat=datetime.utcnow(),
registration_time=datetime.utcnow()
)
# Example usage
async def example_usage():
"""Example of how to use the agent discovery system"""
# Create registry
registry = AgentRegistry()
await registry.start()
# Create discovery service
discovery_service = AgentDiscoveryService(registry)
# Register an agent
agent_info = create_agent_info(
agent_id="agent-001",
agent_type="worker",
capabilities=["data_processing", "analysis"],
services=["process_data", "analyze_results"],
endpoints={"http": "http://localhost:8001", "ws": "ws://localhost:8002"}
)
await registry.register_agent(agent_info)
# Discover agents
agents = await registry.discover_agents({
"capabilities": ["data_processing"],
"status": "active"
})
print(f"Found {len(agents)} agents")
# Find best agent
best_agent = await discovery_service.find_best_agent({
"capabilities": ["data_processing"],
"min_health_score": 0.8
})
if best_agent:
print(f"Best agent: {best_agent.agent_id}")
await registry.stop()
if __name__ == "__main__":
asyncio.run(example_usage())

View File

@@ -0,0 +1,716 @@
"""
Load Balancer for Agent Distribution and Task Assignment
"""
import asyncio
import json
import logging
from typing import Dict, List, Optional, Tuple, Any, Callable
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
import statistics
import uuid
from collections import defaultdict, deque
from .agent_discovery import AgentRegistry, AgentInfo, AgentStatus, AgentType
from ..protocols.message_types import TaskMessage, create_task_message
from ..protocols.communication import AgentMessage, MessageType, Priority
logger = logging.getLogger(__name__)
class LoadBalancingStrategy(str, Enum):
"""Load balancing strategies"""
ROUND_ROBIN = "round_robin"
LEAST_CONNECTIONS = "least_connections"
LEAST_RESPONSE_TIME = "least_response_time"
WEIGHTED_ROUND_ROBIN = "weighted_round_robin"
RESOURCE_BASED = "resource_based"
CAPABILITY_BASED = "capability_based"
PREDICTIVE = "predictive"
CONSISTENT_HASH = "consistent_hash"
class TaskPriority(str, Enum):
"""Task priority levels"""
LOW = "low"
NORMAL = "normal"
HIGH = "high"
CRITICAL = "critical"
URGENT = "urgent"
@dataclass
class LoadMetrics:
"""Agent load metrics"""
cpu_usage: float = 0.0
memory_usage: float = 0.0
active_connections: int = 0
pending_tasks: int = 0
completed_tasks: int = 0
failed_tasks: int = 0
avg_response_time: float = 0.0
last_updated: datetime = field(default_factory=datetime.utcnow)
def to_dict(self) -> Dict[str, Any]:
return {
"cpu_usage": self.cpu_usage,
"memory_usage": self.memory_usage,
"active_connections": self.active_connections,
"pending_tasks": self.pending_tasks,
"completed_tasks": self.completed_tasks,
"failed_tasks": self.failed_tasks,
"avg_response_time": self.avg_response_time,
"last_updated": self.last_updated.isoformat()
}
@dataclass
class TaskAssignment:
"""Task assignment record"""
task_id: str
agent_id: str
assigned_at: datetime
completed_at: Optional[datetime] = None
status: str = "pending"
response_time: Optional[float] = None
success: bool = False
error_message: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
return {
"task_id": self.task_id,
"agent_id": self.agent_id,
"assigned_at": self.assigned_at.isoformat(),
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
"status": self.status,
"response_time": self.response_time,
"success": self.success,
"error_message": self.error_message
}
@dataclass
class AgentWeight:
"""Agent weight for load balancing"""
agent_id: str
weight: float = 1.0
capacity: int = 100
performance_score: float = 1.0
reliability_score: float = 1.0
last_updated: datetime = field(default_factory=datetime.utcnow)
class LoadBalancer:
"""Advanced load balancer for agent distribution"""
def __init__(self, registry: AgentRegistry):
self.registry = registry
self.strategy = LoadBalancingStrategy.LEAST_CONNECTIONS
self.agent_weights: Dict[str, AgentWeight] = {}
self.agent_metrics: Dict[str, LoadMetrics] = {}
self.task_assignments: Dict[str, TaskAssignment] = {}
self.assignment_history: deque = deque(maxlen=1000)
self.round_robin_index = 0
self.consistent_hash_ring: Dict[int, str] = {}
self.prediction_models: Dict[str, Any] = {}
# Statistics
self.total_assignments = 0
self.successful_assignments = 0
self.failed_assignments = 0
def set_strategy(self, strategy: LoadBalancingStrategy):
"""Set load balancing strategy"""
self.strategy = strategy
logger.info(f"Load balancing strategy changed to: {strategy.value}")
def set_agent_weight(self, agent_id: str, weight: float, capacity: int = 100):
"""Set agent weight and capacity"""
self.agent_weights[agent_id] = AgentWeight(
agent_id=agent_id,
weight=weight,
capacity=capacity
)
logger.info(f"Set weight for agent {agent_id}: {weight}, capacity: {capacity}")
def update_agent_metrics(self, agent_id: str, metrics: LoadMetrics):
"""Update agent load metrics"""
self.agent_metrics[agent_id] = metrics
self.agent_metrics[agent_id].last_updated = datetime.utcnow()
# Update performance score based on metrics
self._update_performance_score(agent_id, metrics)
def _update_performance_score(self, agent_id: str, metrics: LoadMetrics):
"""Update agent performance score based on metrics"""
if agent_id not in self.agent_weights:
self.agent_weights[agent_id] = AgentWeight(agent_id=agent_id)
weight = self.agent_weights[agent_id]
# Calculate performance score (0.0 to 1.0)
performance_factors = []
# CPU usage factor (lower is better)
cpu_factor = max(0.0, 1.0 - metrics.cpu_usage)
performance_factors.append(cpu_factor)
# Memory usage factor (lower is better)
memory_factor = max(0.0, 1.0 - metrics.memory_usage)
performance_factors.append(memory_factor)
# Response time factor (lower is better)
if metrics.avg_response_time > 0:
response_factor = max(0.0, 1.0 - (metrics.avg_response_time / 10.0)) # 10s max
performance_factors.append(response_factor)
# Success rate factor (higher is better)
total_tasks = metrics.completed_tasks + metrics.failed_tasks
if total_tasks > 0:
success_rate = metrics.completed_tasks / total_tasks
performance_factors.append(success_rate)
# Update performance score
if performance_factors:
weight.performance_score = statistics.mean(performance_factors)
# Update reliability score
if total_tasks > 10: # Only update after enough tasks
weight.reliability_score = success_rate
async def assign_task(self, task_data: Dict[str, Any], requirements: Optional[Dict[str, Any]] = None) -> Optional[str]:
"""Assign task to best available agent"""
try:
# Find eligible agents
eligible_agents = await self._find_eligible_agents(task_data, requirements)
if not eligible_agents:
logger.warning("No eligible agents found for task assignment")
return None
# Select best agent based on strategy
selected_agent = await self._select_agent(eligible_agents, task_data)
if not selected_agent:
logger.warning("No agent selected for task assignment")
return None
# Create task assignment
task_id = str(uuid.uuid4())
assignment = TaskAssignment(
task_id=task_id,
agent_id=selected_agent,
assigned_at=datetime.utcnow()
)
# Record assignment
self.task_assignments[task_id] = assignment
self.assignment_history.append(assignment)
self.total_assignments += 1
# Update agent metrics
if selected_agent not in self.agent_metrics:
self.agent_metrics[selected_agent] = LoadMetrics()
self.agent_metrics[selected_agent].pending_tasks += 1
logger.info(f"Task {task_id} assigned to agent {selected_agent}")
return selected_agent
except Exception as e:
logger.error(f"Error assigning task: {e}")
self.failed_assignments += 1
return None
async def complete_task(self, task_id: str, success: bool, response_time: Optional[float] = None, error_message: Optional[str] = None):
"""Mark task as completed"""
try:
if task_id not in self.task_assignments:
logger.warning(f"Task assignment {task_id} not found")
return
assignment = self.task_assignments[task_id]
assignment.completed_at = datetime.utcnow()
assignment.status = "completed"
assignment.success = success
assignment.response_time = response_time
assignment.error_message = error_message
# Update agent metrics
agent_id = assignment.agent_id
if agent_id in self.agent_metrics:
metrics = self.agent_metrics[agent_id]
metrics.pending_tasks = max(0, metrics.pending_tasks - 1)
if success:
metrics.completed_tasks += 1
self.successful_assignments += 1
else:
metrics.failed_tasks += 1
self.failed_assignments += 1
# Update average response time
if response_time:
total_completed = metrics.completed_tasks + metrics.failed_tasks
if total_completed > 0:
metrics.avg_response_time = (
(metrics.avg_response_time * (total_completed - 1) + response_time) / total_completed
)
logger.info(f"Task {task_id} completed by agent {assignment.agent_id}, success: {success}")
except Exception as e:
logger.error(f"Error completing task {task_id}: {e}")
async def _find_eligible_agents(self, task_data: Dict[str, Any], requirements: Optional[Dict[str, Any]] = None) -> List[str]:
"""Find eligible agents for task"""
try:
# Build discovery query
query = {"status": AgentStatus.ACTIVE}
if requirements:
if "agent_type" in requirements:
query["agent_type"] = requirements["agent_type"]
if "capabilities" in requirements:
query["capabilities"] = requirements["capabilities"]
if "services" in requirements:
query["services"] = requirements["services"]
if "min_health_score" in requirements:
query["min_health_score"] = requirements["min_health_score"]
# Discover agents
agents = await self.registry.discover_agents(query)
# Filter by capacity and load
eligible_agents = []
for agent in agents:
agent_id = agent.agent_id
# Check capacity
if agent_id in self.agent_weights:
weight = self.agent_weights[agent_id]
current_load = self._get_agent_load(agent_id)
if current_load < weight.capacity:
eligible_agents.append(agent_id)
else:
# Default capacity check
metrics = self.agent_metrics.get(agent_id, LoadMetrics())
if metrics.pending_tasks < 100: # Default capacity
eligible_agents.append(agent_id)
return eligible_agents
except Exception as e:
logger.error(f"Error finding eligible agents: {e}")
return []
def _get_agent_load(self, agent_id: str) -> int:
"""Get current load for agent"""
metrics = self.agent_metrics.get(agent_id, LoadMetrics())
return metrics.active_connections + metrics.pending_tasks
async def _select_agent(self, eligible_agents: List[str], task_data: Dict[str, Any]) -> Optional[str]:
"""Select best agent based on current strategy"""
if not eligible_agents:
return None
if self.strategy == LoadBalancingStrategy.ROUND_ROBIN:
return self._round_robin_selection(eligible_agents)
elif self.strategy == LoadBalancingStrategy.LEAST_CONNECTIONS:
return self._least_connections_selection(eligible_agents)
elif self.strategy == LoadBalancingStrategy.LEAST_RESPONSE_TIME:
return self._least_response_time_selection(eligible_agents)
elif self.strategy == LoadBalancingStrategy.WEIGHTED_ROUND_ROBIN:
return self._weighted_round_robin_selection(eligible_agents)
elif self.strategy == LoadBalancingStrategy.RESOURCE_BASED:
return self._resource_based_selection(eligible_agents)
elif self.strategy == LoadBalancingStrategy.CAPABILITY_BASED:
return self._capability_based_selection(eligible_agents, task_data)
elif self.strategy == LoadBalancingStrategy.PREDICTIVE:
return self._predictive_selection(eligible_agents, task_data)
elif self.strategy == LoadBalancingStrategy.CONSISTENT_HASH:
return self._consistent_hash_selection(eligible_agents, task_data)
else:
return eligible_agents[0]
def _round_robin_selection(self, agents: List[str]) -> str:
"""Round-robin agent selection"""
agent = agents[self.round_robin_index % len(agents)]
self.round_robin_index += 1
return agent
def _least_connections_selection(self, agents: List[str]) -> str:
"""Select agent with least connections"""
min_connections = float('inf')
selected_agent = None
for agent_id in agents:
metrics = self.agent_metrics.get(agent_id, LoadMetrics())
connections = metrics.active_connections
if connections < min_connections:
min_connections = connections
selected_agent = agent_id
return selected_agent or agents[0]
def _least_response_time_selection(self, agents: List[str]) -> str:
"""Select agent with least average response time"""
min_response_time = float('inf')
selected_agent = None
for agent_id in agents:
metrics = self.agent_metrics.get(agent_id, LoadMetrics())
response_time = metrics.avg_response_time
if response_time < min_response_time:
min_response_time = response_time
selected_agent = agent_id
return selected_agent or agents[0]
def _weighted_round_robin_selection(self, agents: List[str]) -> str:
"""Weighted round-robin selection"""
# Calculate total weight
total_weight = 0
for agent_id in agents:
weight = self.agent_weights.get(agent_id, AgentWeight(agent_id=agent_id))
total_weight += weight.weight
if total_weight == 0:
return agents[0]
# Select agent based on weight
current_weight = self.round_robin_index % total_weight
accumulated_weight = 0
for agent_id in agents:
weight = self.agent_weights.get(agent_id, AgentWeight(agent_id=agent_id))
accumulated_weight += weight.weight
if current_weight < accumulated_weight:
self.round_robin_index += 1
return agent_id
return agents[0]
def _resource_based_selection(self, agents: List[str]) -> str:
"""Resource-based selection considering CPU and memory"""
best_score = -1
selected_agent = None
for agent_id in agents:
metrics = self.agent_metrics.get(agent_id, LoadMetrics())
# Calculate resource score (lower usage is better)
cpu_score = max(0, 100 - metrics.cpu_usage)
memory_score = max(0, 100 - metrics.memory_usage)
resource_score = (cpu_score + memory_score) / 2
# Apply performance weight
weight = self.agent_weights.get(agent_id, AgentWeight(agent_id=agent_id))
final_score = resource_score * weight.performance_score
if final_score > best_score:
best_score = final_score
selected_agent = agent_id
return selected_agent or agents[0]
def _capability_based_selection(self, agents: List[str], task_data: Dict[str, Any]) -> str:
"""Capability-based selection considering task requirements"""
required_capabilities = task_data.get("required_capabilities", [])
if not required_capabilities:
return agents[0]
best_score = -1
selected_agent = None
for agent_id in agents:
agent_info = self.registry.agents.get(agent_id)
if not agent_info:
continue
# Calculate capability match score
agent_capabilities = set(agent_info.capabilities)
required_set = set(required_capabilities)
if required_set.issubset(agent_capabilities):
# Perfect match
capability_score = 1.0
else:
# Partial match
intersection = required_set.intersection(agent_capabilities)
capability_score = len(intersection) / len(required_set)
# Apply performance weight
weight = self.agent_weights.get(agent_id, AgentWeight(agent_id=agent_id))
final_score = capability_score * weight.performance_score
if final_score > best_score:
best_score = final_score
selected_agent = agent_id
return selected_agent or agents[0]
def _predictive_selection(self, agents: List[str], task_data: Dict[str, Any]) -> str:
"""Predictive selection using historical performance"""
task_type = task_data.get("task_type", "unknown")
# Calculate predicted performance for each agent
best_score = -1
selected_agent = None
for agent_id in agents:
# Get historical performance for this task type
score = self._calculate_predicted_score(agent_id, task_type)
if score > best_score:
best_score = score
selected_agent = agent_id
return selected_agent or agents[0]
def _calculate_predicted_score(self, agent_id: str, task_type: str) -> float:
"""Calculate predicted performance score for agent"""
# Simple prediction based on recent performance
weight = self.agent_weights.get(agent_id, AgentWeight(agent_id=agent_id))
# Base score from performance and reliability
base_score = (weight.performance_score + weight.reliability_score) / 2
# Adjust based on recent assignments
recent_assignments = [a for a in self.assignment_history if a.agent_id == agent_id][-10:]
if recent_assignments:
success_rate = sum(1 for a in recent_assignments if a.success) / len(recent_assignments)
base_score = base_score * 0.7 + success_rate * 0.3
return base_score
def _consistent_hash_selection(self, agents: List[str], task_data: Dict[str, Any]) -> str:
"""Consistent hash selection for sticky routing"""
# Create hash key from task data
hash_key = json.dumps(task_data, sort_keys=True)
hash_value = int(hashlib.md5(hash_key.encode()).hexdigest(), 16)
# Build hash ring if not exists
if not self.consistent_hash_ring:
self._build_hash_ring(agents)
# Find agent on hash ring
for hash_pos in sorted(self.consistent_hash_ring.keys()):
if hash_value <= hash_pos:
return self.consistent_hash_ring[hash_pos]
# Wrap around
return self.consistent_hash_ring[min(self.consistent_hash_ring.keys())]
def _build_hash_ring(self, agents: List[str]):
"""Build consistent hash ring"""
self.consistent_hash_ring = {}
for agent_id in agents:
# Create multiple virtual nodes for better distribution
for i in range(100):
virtual_key = f"{agent_id}:{i}"
hash_value = int(hashlib.md5(virtual_key.encode()).hexdigest(), 16)
self.consistent_hash_ring[hash_value] = agent_id
def get_load_balancing_stats(self) -> Dict[str, Any]:
"""Get load balancing statistics"""
return {
"strategy": self.strategy.value,
"total_assignments": self.total_assignments,
"successful_assignments": self.successful_assignments,
"failed_assignments": self.failed_assignments,
"success_rate": self.successful_assignments / max(1, self.total_assignments),
"active_agents": len(self.agent_metrics),
"agent_weights": len(self.agent_weights),
"avg_agent_load": statistics.mean([self._get_agent_load(a) for a in self.agent_metrics]) if self.agent_metrics else 0
}
def get_agent_stats(self, agent_id: str) -> Optional[Dict[str, Any]]:
"""Get detailed statistics for a specific agent"""
if agent_id not in self.agent_metrics:
return None
metrics = self.agent_metrics[agent_id]
weight = self.agent_weights.get(agent_id, AgentWeight(agent_id=agent_id))
# Get recent assignments
recent_assignments = [a for a in self.assignment_history if a.agent_id == agent_id][-10:]
return {
"agent_id": agent_id,
"metrics": metrics.to_dict(),
"weight": {
"weight": weight.weight,
"capacity": weight.capacity,
"performance_score": weight.performance_score,
"reliability_score": weight.reliability_score
},
"recent_assignments": [a.to_dict() for a in recent_assignments],
"current_load": self._get_agent_load(agent_id)
}
class TaskDistributor:
"""Task distributor with advanced load balancing"""
def __init__(self, load_balancer: LoadBalancer):
self.load_balancer = load_balancer
self.task_queue = asyncio.Queue()
self.priority_queues = {
TaskPriority.URGENT: asyncio.Queue(),
TaskPriority.CRITICAL: asyncio.Queue(),
TaskPriority.HIGH: asyncio.Queue(),
TaskPriority.NORMAL: asyncio.Queue(),
TaskPriority.LOW: asyncio.Queue()
}
self.distribution_stats = {
"tasks_distributed": 0,
"tasks_completed": 0,
"tasks_failed": 0,
"avg_distribution_time": 0.0
}
async def submit_task(self, task_data: Dict[str, Any], priority: TaskPriority = TaskPriority.NORMAL, requirements: Optional[Dict[str, Any]] = None):
"""Submit task for distribution"""
task_info = {
"task_data": task_data,
"priority": priority,
"requirements": requirements,
"submitted_at": datetime.utcnow()
}
await self.priority_queues[priority].put(task_info)
logger.info(f"Task submitted with priority {priority.value}")
async def start_distribution(self):
"""Start task distribution loop"""
while True:
try:
# Check queues in priority order
task_info = None
for priority in [TaskPriority.URGENT, TaskPriority.CRITICAL, TaskPriority.HIGH, TaskPriority.NORMAL, TaskPriority.LOW]:
queue = self.priority_queues[priority]
try:
task_info = queue.get_nowait()
break
except asyncio.QueueEmpty:
continue
if task_info:
await self._distribute_task(task_info)
else:
await asyncio.sleep(0.01) # Small delay if no tasks
except Exception as e:
logger.error(f"Error in distribution loop: {e}")
await asyncio.sleep(1)
async def _distribute_task(self, task_info: Dict[str, Any]):
"""Distribute a single task"""
start_time = datetime.utcnow()
try:
# Assign task
agent_id = await self.load_balancer.assign_task(
task_info["task_data"],
task_info["requirements"]
)
if agent_id:
# Create task message
task_message = create_task_message(
sender_id="task_distributor",
receiver_id=agent_id,
task_type=task_info["task_data"].get("task_type", "unknown"),
task_data=task_info["task_data"]
)
# Send task to agent (implementation depends on communication system)
# await self._send_task_to_agent(agent_id, task_message)
self.distribution_stats["tasks_distributed"] += 1
# Simulate task completion (in real implementation, this would be event-driven)
asyncio.create_task(self._simulate_task_completion(task_info, agent_id))
else:
logger.warning(f"Failed to distribute task: no suitable agent found")
self.distribution_stats["tasks_failed"] += 1
except Exception as e:
logger.error(f"Error distributing task: {e}")
self.distribution_stats["tasks_failed"] += 1
finally:
# Update distribution time
distribution_time = (datetime.utcnow() - start_time).total_seconds()
total_distributed = self.distribution_stats["tasks_distributed"]
self.distribution_stats["avg_distribution_time"] = (
(self.distribution_stats["avg_distribution_time"] * (total_distributed - 1) + distribution_time) / total_distributed
if total_distributed > 0 else distribution_time
)
async def _simulate_task_completion(self, task_info: Dict[str, Any], agent_id: str):
"""Simulate task completion (for testing)"""
# Simulate task processing time
processing_time = 1.0 + (hash(task_info["task_data"].get("task_id", "")) % 5)
await asyncio.sleep(processing_time)
# Mark task as completed
success = hash(agent_id) % 10 > 1 # 90% success rate
await self.load_balancer.complete_task(
task_info["task_data"].get("task_id", str(uuid.uuid4())),
success,
processing_time
)
if success:
self.distribution_stats["tasks_completed"] += 1
else:
self.distribution_stats["tasks_failed"] += 1
def get_distribution_stats(self) -> Dict[str, Any]:
"""Get distribution statistics"""
return {
**self.distribution_stats,
"load_balancer_stats": self.load_balancer.get_load_balancing_stats(),
"queue_sizes": {
priority.value: queue.qsize()
for priority, queue in self.priority_queues.items()
}
}
# Example usage
async def example_usage():
"""Example of how to use the load balancer"""
# Create registry and load balancer
registry = AgentRegistry()
await registry.start()
load_balancer = LoadBalancer(registry)
load_balancer.set_strategy(LoadBalancingStrategy.LEAST_CONNECTIONS)
# Create task distributor
distributor = TaskDistributor(load_balancer)
# Submit some tasks
for i in range(10):
await distributor.submit_task({
"task_id": f"task-{i}",
"task_type": "data_processing",
"data": f"sample_data_{i}"
}, TaskPriority.NORMAL)
# Start distribution (in real implementation, this would run in background)
# await distributor.start_distribution()
await registry.stop()
if __name__ == "__main__":
asyncio.run(example_usage())