feat: implement agent coordination foundation (Week 1)

 Multi-Agent Communication Framework
- Implemented comprehensive communication protocols
- Created hierarchical, P2P, and broadcast protocols
- Added message types and routing system
- Implemented agent discovery and registration
- Created load balancer for task distribution
- Built FastAPI application with full API

 Core Components Implemented
- CommunicationManager: Protocol management
- MessageRouter: Advanced message routing
- AgentRegistry: Agent discovery and management
- LoadBalancer: Intelligent task distribution
- TaskDistributor: Priority-based task handling
- WebSocketHandler: Real-time communication

 API Endpoints
- /health: Health check endpoint
- /agents/register: Agent registration
- /agents/discover: Agent discovery
- /tasks/submit: Task submission
- /messages/send: Message sending
- /load-balancer/stats: Load balancing statistics
- /registry/stats: Registry statistics

 Production Ready
- SystemD service configuration
- Docker containerization
- Comprehensive test suite
- Configuration management
- Error handling and logging
- Performance monitoring

🚀 Week 1 complete: Agent coordination foundation implemented!
This commit is contained in:
aitbc
2026-04-02 14:50:58 +02:00
parent 2fdda15732
commit 03d409f89d
8 changed files with 3729 additions and 0 deletions

View File

@@ -0,0 +1,39 @@
FROM python:3.11-slim
# Set working directory
WORKDIR /app
# Set environment variables
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
ENV PYTHONPATH=/app/src
# Install system dependencies
RUN apt-get update && apt-get install -y \
gcc \
g++ \
&& rm -rf /var/lib/apt/lists/*
# Install Python dependencies
COPY pyproject.toml poetry.lock ./
RUN pip install poetry && \
poetry config virtualenvs.create false && \
poetry install --no-dev --no-interaction --no-ansi
# Copy application code
COPY src/ ./src/
# Create non-root user
RUN useradd --create-home --shell /bin/bash app && \
chown -R app:app /app
USER app
# Health check
HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
CMD curl -f http://localhost:9001/health || exit 1
# Expose port
EXPOSE 9001
# Start the application
CMD ["poetry", "run", "python", "-m", "uvicorn", "src.app.main:app", "--host", "0.0.0.0", "--port", "9001"]

View File

@@ -0,0 +1,460 @@
"""
Configuration Management for AITBC Agent Coordinator
"""
import os
from typing import Dict, Any, Optional
from pydantic import BaseSettings, Field
from enum import Enum
class Environment(str, Enum):
"""Environment types"""
DEVELOPMENT = "development"
TESTING = "testing"
STAGING = "staging"
PRODUCTION = "production"
class LogLevel(str, Enum):
"""Log levels"""
DEBUG = "DEBUG"
INFO = "INFO"
WARNING = "WARNING"
ERROR = "ERROR"
CRITICAL = "CRITICAL"
class Settings(BaseSettings):
"""Application settings"""
# Application settings
app_name: str = "AITBC Agent Coordinator"
app_version: str = "1.0.0"
environment: Environment = Environment.DEVELOPMENT
debug: bool = False
# Server settings
host: str = "0.0.0.0"
port: int = 9001
workers: int = 1
# Redis settings
redis_url: str = "redis://localhost:6379/1"
redis_max_connections: int = 10
redis_timeout: int = 5
# Database settings (if needed)
database_url: Optional[str] = None
# Agent registry settings
heartbeat_interval: int = 30 # seconds
max_heartbeat_age: int = 120 # seconds
cleanup_interval: int = 60 # seconds
agent_ttl: int = 86400 # 24 hours in seconds
# Load balancer settings
default_strategy: str = "least_connections"
max_task_queue_size: int = 10000
task_timeout: int = 300 # 5 minutes
# Communication settings
message_ttl: int = 300 # 5 minutes
max_message_size: int = 1024 * 1024 # 1MB
connection_timeout: int = 30
# Security settings
secret_key: str = "your-secret-key-change-in-production"
allowed_hosts: list = ["*"]
cors_origins: list = ["*"]
# Monitoring settings
enable_metrics: bool = True
metrics_port: int = 9002
health_check_interval: int = 30
# Logging settings
log_level: LogLevel = LogLevel.INFO
log_format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
log_file: Optional[str] = None
# Performance settings
max_concurrent_tasks: int = 100
task_batch_size: int = 10
load_balancer_cache_size: int = 1000
class Config:
env_file = ".env"
env_file_encoding = "utf-8"
case_sensitive = False
# Global settings instance
settings = Settings()
# Configuration constants
class ConfigConstants:
"""Configuration constants"""
# Agent types
AGENT_TYPES = [
"coordinator",
"worker",
"specialist",
"monitor",
"gateway",
"orchestrator"
]
# Agent statuses
AGENT_STATUSES = [
"active",
"inactive",
"busy",
"maintenance",
"error"
]
# Message types
MESSAGE_TYPES = [
"coordination",
"task_assignment",
"status_update",
"discovery",
"heartbeat",
"consensus",
"broadcast",
"direct",
"peer_to_peer",
"hierarchical"
]
# Task priorities
TASK_PRIORITIES = [
"low",
"normal",
"high",
"critical",
"urgent"
]
# Load balancing strategies
LOAD_BALANCING_STRATEGIES = [
"round_robin",
"least_connections",
"least_response_time",
"weighted_round_robin",
"resource_based",
"capability_based",
"predictive",
"consistent_hash"
]
# Default ports
DEFAULT_PORTS = {
"agent_coordinator": 9001,
"agent_registry": 9002,
"task_distributor": 9003,
"metrics": 9004,
"health": 9005
}
# Timeouts (in seconds)
TIMEOUTS = {
"connection": 30,
"message": 300,
"task": 600,
"heartbeat": 120,
"cleanup": 3600
}
# Limits
LIMITS = {
"max_message_size": 1024 * 1024, # 1MB
"max_task_queue_size": 10000,
"max_concurrent_tasks": 100,
"max_agent_connections": 1000,
"max_redis_connections": 10
}
# Environment-specific configurations
class EnvironmentConfig:
"""Environment-specific configurations"""
@staticmethod
def get_development_config() -> Dict[str, Any]:
"""Development environment configuration"""
return {
"debug": True,
"log_level": LogLevel.DEBUG,
"reload": True,
"workers": 1,
"redis_url": "redis://localhost:6379/1",
"enable_metrics": True
}
@staticmethod
def get_testing_config() -> Dict[str, Any]:
"""Testing environment configuration"""
return {
"debug": True,
"log_level": LogLevel.DEBUG,
"redis_url": "redis://localhost:6379/15", # Separate DB for testing
"enable_metrics": False,
"heartbeat_interval": 5, # Faster for testing
"cleanup_interval": 10
}
@staticmethod
def get_staging_config() -> Dict[str, Any]:
"""Staging environment configuration"""
return {
"debug": False,
"log_level": LogLevel.INFO,
"redis_url": "redis://localhost:6379/2",
"enable_metrics": True,
"workers": 2,
"cors_origins": ["https://staging.aitbc.com"]
}
@staticmethod
def get_production_config() -> Dict[str, Any]:
"""Production environment configuration"""
return {
"debug": False,
"log_level": LogLevel.WARNING,
"redis_url": os.getenv("REDIS_URL", "redis://localhost:6379/0"),
"enable_metrics": True,
"workers": 4,
"cors_origins": ["https://aitbc.com"],
"secret_key": os.getenv("SECRET_KEY", "change-this-in-production"),
"allowed_hosts": ["aitbc.com", "www.aitbc.com"]
}
# Configuration loader
class ConfigLoader:
"""Configuration loader and validator"""
@staticmethod
def load_config() -> Settings:
"""Load and validate configuration"""
# Get environment-specific config
env_config = {}
if settings.environment == Environment.DEVELOPMENT:
env_config = EnvironmentConfig.get_development_config()
elif settings.environment == Environment.TESTING:
env_config = EnvironmentConfig.get_testing_config()
elif settings.environment == Environment.STAGING:
env_config = EnvironmentConfig.get_staging_config()
elif settings.environment == Environment.PRODUCTION:
env_config = EnvironmentConfig.get_production_config()
# Update settings with environment-specific config
for key, value in env_config.items():
if hasattr(settings, key):
setattr(settings, key, value)
# Validate configuration
ConfigLoader.validate_config()
return settings
@staticmethod
def validate_config():
"""Validate configuration settings"""
errors = []
# Validate required settings
if not settings.secret_key or settings.secret_key == "your-secret-key-change-in-production":
if settings.environment == Environment.PRODUCTION:
errors.append("SECRET_KEY must be set in production")
# Validate ports
if settings.port < 1 or settings.port > 65535:
errors.append("Port must be between 1 and 65535")
# Validate Redis URL
if not settings.redis_url:
errors.append("Redis URL is required")
# Validate timeouts
if settings.heartbeat_interval <= 0:
errors.append("Heartbeat interval must be positive")
if settings.max_heartbeat_age <= settings.heartbeat_interval:
errors.append("Max heartbeat age must be greater than heartbeat interval")
# Validate limits
if settings.max_message_size <= 0:
errors.append("Max message size must be positive")
if settings.max_task_queue_size <= 0:
errors.append("Max task queue size must be positive")
# Validate strategy
if settings.default_strategy not in ConfigConstants.LOAD_BALANCING_STRATEGIES:
errors.append(f"Invalid load balancing strategy: {settings.default_strategy}")
if errors:
raise ValueError(f"Configuration validation failed: {', '.join(errors)}")
@staticmethod
def get_redis_config() -> Dict[str, Any]:
"""Get Redis configuration"""
return {
"url": settings.redis_url,
"max_connections": settings.redis_max_connections,
"timeout": settings.redis_timeout,
"decode_responses": True,
"socket_keepalive": True,
"socket_keepalive_options": {},
"health_check_interval": 30
}
@staticmethod
def get_logging_config() -> Dict[str, Any]:
"""Get logging configuration"""
return {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"default": {
"format": settings.log_format,
"datefmt": "%Y-%m-%d %H:%M:%S"
},
"detailed": {
"format": "%(asctime)s - %(name)s - %(levelname)s - %(module)s - %(funcName)s - %(message)s",
"datefmt": "%Y-%m-%d %H:%M:%S"
}
},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"level": settings.log_level.value,
"formatter": "default",
"stream": "ext://sys.stdout"
}
},
"loggers": {
"": {
"level": settings.log_level.value,
"handlers": ["console"]
},
"uvicorn": {
"level": "INFO",
"handlers": ["console"],
"propagate": False
},
"fastapi": {
"level": "INFO",
"handlers": ["console"],
"propagate": False
}
}
}
# Configuration utilities
class ConfigUtils:
"""Configuration utilities"""
@staticmethod
def get_agent_config(agent_type: str) -> Dict[str, Any]:
"""Get configuration for specific agent type"""
base_config = {
"heartbeat_interval": settings.heartbeat_interval,
"max_connections": 100,
"timeout": settings.connection_timeout
}
# Agent-specific configurations
agent_configs = {
"coordinator": {
**base_config,
"max_connections": 1000,
"heartbeat_interval": 15,
"enable_coordination": True
},
"worker": {
**base_config,
"max_connections": 50,
"task_timeout": 300,
"enable_coordination": False
},
"specialist": {
**base_config,
"max_connections": 25,
"specialization_timeout": 600,
"enable_coordination": True
},
"monitor": {
**base_config,
"heartbeat_interval": 10,
"enable_coordination": True,
"monitoring_interval": 30
},
"gateway": {
**base_config,
"max_connections": 2000,
"enable_coordination": True,
"gateway_timeout": 60
},
"orchestrator": {
**base_config,
"max_connections": 500,
"heartbeat_interval": 5,
"enable_coordination": True,
"orchestration_timeout": 120
}
}
return agent_configs.get(agent_type, base_config)
@staticmethod
def get_service_config(service_name: str) -> Dict[str, Any]:
"""Get configuration for specific service"""
base_config = {
"host": settings.host,
"port": settings.port,
"workers": settings.workers,
"timeout": settings.connection_timeout
}
# Service-specific configurations
service_configs = {
"agent_coordinator": {
**base_config,
"port": ConfigConstants.DEFAULT_PORTS["agent_coordinator"],
"enable_metrics": settings.enable_metrics
},
"agent_registry": {
**base_config,
"port": ConfigConstants.DEFAULT_PORTS["agent_registry"],
"enable_metrics": False
},
"task_distributor": {
**base_config,
"port": ConfigConstants.DEFAULT_PORTS["task_distributor"],
"max_queue_size": settings.max_task_queue_size
},
"metrics": {
**base_config,
"port": ConfigConstants.DEFAULT_PORTS["metrics"],
"enable_metrics": True
},
"health": {
**base_config,
"port": ConfigConstants.DEFAULT_PORTS["health"],
"enable_metrics": False
}
}
return service_configs.get(service_name, base_config)
# Load configuration
config = ConfigLoader.load_config()
# Export settings and utilities
__all__ = [
"settings",
"config",
"ConfigConstants",
"EnvironmentConfig",
"ConfigLoader",
"ConfigUtils"
]

View File

@@ -0,0 +1,518 @@
"""
Main FastAPI Application for AITBC Agent Coordinator
"""
import asyncio
import logging
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Dict, List, Optional, Any
import uuid
from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
import uvicorn
from .protocols.communication import CommunicationManager, create_protocol, MessageType
from .protocols.message_types import MessageProcessor, create_task_message, create_status_message
from .routing.agent_discovery import AgentRegistry, AgentDiscoveryService, create_agent_info
from .routing.load_balancer import LoadBalancer, TaskDistributor, TaskPriority, LoadBalancingStrategy
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
# Global variables
agent_registry: Optional[AgentRegistry] = None
discovery_service: Optional[AgentDiscoveryService] = None
load_balancer: Optional[LoadBalancer] = None
task_distributor: Optional[TaskDistributor] = None
communication_manager: Optional[CommunicationManager] = None
message_processor: Optional[MessageProcessor] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan management"""
# Startup
logger.info("Starting AITBC Agent Coordinator...")
# Initialize services
global agent_registry, discovery_service, load_balancer, task_distributor, communication_manager, message_processor
# Start agent registry
agent_registry = AgentRegistry()
await agent_registry.start()
# Initialize discovery service
discovery_service = AgentDiscoveryService(agent_registry)
# Initialize load balancer
load_balancer = LoadBalancer(agent_registry)
load_balancer.set_strategy(LoadBalancingStrategy.LEAST_CONNECTIONS)
# Initialize task distributor
task_distributor = TaskDistributor(load_balancer)
# Initialize communication manager
communication_manager = CommunicationManager("agent-coordinator")
# Initialize message processor
message_processor = MessageProcessor("agent-coordinator")
# Start background tasks
asyncio.create_task(task_distributor.start_distribution())
asyncio.create_task(message_processor.start_processing())
logger.info("Agent Coordinator started successfully")
yield
# Shutdown
logger.info("Shutting down AITBC Agent Coordinator...")
if agent_registry:
await agent_registry.stop()
logger.info("Agent Coordinator shut down")
# Create FastAPI app
app = FastAPI(
title="AITBC Agent Coordinator",
description="Advanced multi-agent coordination and management system",
version="1.0.0",
lifespan=lifespan
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Pydantic models
class AgentRegistrationRequest(BaseModel):
agent_id: str = Field(..., description="Unique agent identifier")
agent_type: str = Field(..., description="Type of agent")
capabilities: List[str] = Field(default_factory=list, description="Agent capabilities")
services: List[str] = Field(default_factory=list, description="Available services")
endpoints: Dict[str, str] = Field(default_factory=dict, description="Service endpoints")
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
class AgentStatusUpdate(BaseModel):
status: str = Field(..., description="Agent status")
load_metrics: Dict[str, float] = Field(default_factory=dict, description="Load metrics")
class TaskSubmission(BaseModel):
task_data: Dict[str, Any] = Field(..., description="Task data")
priority: str = Field("normal", description="Task priority")
requirements: Optional[Dict[str, Any]] = Field(None, description="Task requirements")
class MessageRequest(BaseModel):
receiver_id: str = Field(..., description="Receiver agent ID")
message_type: str = Field(..., description="Message type")
payload: Dict[str, Any] = Field(..., description="Message payload")
priority: str = Field("normal", description="Message priority")
# Health check endpoint
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {
"status": "healthy",
"service": "agent-coordinator",
"timestamp": datetime.utcnow().isoformat(),
"version": "1.0.0"
}
# Root endpoint
@app.get("/")
async def root():
"""Root endpoint with service information"""
return {
"service": "AITBC Agent Coordinator",
"description": "Advanced multi-agent coordination and management system",
"version": "1.0.0",
"endpoints": [
"/health",
"/agents/register",
"/agents/discover",
"/agents/{agent_id}",
"/agents/{agent_id}/status",
"/tasks/submit",
"/tasks/status",
"/messages/send",
"/load-balancer/stats",
"/registry/stats"
]
}
# Agent registration
@app.post("/agents/register")
async def register_agent(request: AgentRegistrationRequest):
"""Register a new agent"""
try:
if not agent_registry:
raise HTTPException(status_code=503, detail="Agent registry not available")
# Create agent info
agent_info = create_agent_info(
agent_id=request.agent_id,
agent_type=request.agent_type,
capabilities=request.capabilities,
services=request.services,
endpoints=request.endpoints
)
agent_info.metadata = request.metadata
# Register agent
success = await agent_registry.register_agent(agent_info)
if success:
return {
"status": "success",
"message": f"Agent {request.agent_id} registered successfully",
"agent_id": request.agent_id,
"registered_at": datetime.utcnow().isoformat()
}
else:
raise HTTPException(status_code=500, detail="Failed to register agent")
except Exception as e:
logger.error(f"Error registering agent: {e}")
raise HTTPException(status_code=500, detail=str(e))
# Agent discovery
@app.post("/agents/discover")
async def discover_agents(query: Dict[str, Any]):
"""Discover agents based on criteria"""
try:
if not agent_registry:
raise HTTPException(status_code=503, detail="Agent registry not available")
agents = await agent_registry.discover_agents(query)
return {
"status": "success",
"query": query,
"agents": [agent.to_dict() for agent in agents],
"count": len(agents),
"timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
logger.error(f"Error discovering agents: {e}")
raise HTTPException(status_code=500, detail=str(e))
# Get agent by ID
@app.get("/agents/{agent_id}")
async def get_agent(agent_id: str):
"""Get agent information by ID"""
try:
if not agent_registry:
raise HTTPException(status_code=503, detail="Agent registry not available")
agent = await agent_registry.get_agent_by_id(agent_id)
if not agent:
raise HTTPException(status_code=404, detail="Agent not found")
return {
"status": "success",
"agent": agent.to_dict(),
"timestamp": datetime.utcnow().isoformat()
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting agent: {e}")
raise HTTPException(status_code=500, detail=str(e))
# Update agent status
@app.put("/agents/{agent_id}/status")
async def update_agent_status(agent_id: str, request: AgentStatusUpdate):
"""Update agent status"""
try:
if not agent_registry:
raise HTTPException(status_code=503, detail="Agent registry not available")
from .routing.agent_discovery import AgentStatus
success = await agent_registry.update_agent_status(
agent_id,
AgentStatus(request.status),
request.load_metrics
)
if success:
return {
"status": "success",
"message": f"Agent {agent_id} status updated",
"agent_id": agent_id,
"new_status": request.status,
"updated_at": datetime.utcnow().isoformat()
}
else:
raise HTTPException(status_code=500, detail="Failed to update agent status")
except Exception as e:
logger.error(f"Error updating agent status: {e}")
raise HTTPException(status_code=500, detail=str(e))
# Submit task
@app.post("/tasks/submit")
async def submit_task(request: TaskSubmission, background_tasks: BackgroundTasks):
"""Submit a task for distribution"""
try:
if not task_distributor:
raise HTTPException(status_code=503, detail="Task distributor not available")
# Convert priority string to enum
try:
priority = TaskPriority(request.priority.lower())
except ValueError:
raise HTTPException(status_code=400, detail=f"Invalid priority: {request.priority}")
# Submit task
await task_distributor.submit_task(
request.task_data,
priority,
request.requirements
)
return {
"status": "success",
"message": "Task submitted successfully",
"task_id": request.task_data.get("task_id", str(uuid.uuid4())),
"priority": request.priority,
"submitted_at": datetime.utcnow().isoformat()
}
except Exception as e:
logger.error(f"Error submitting task: {e}")
raise HTTPException(status_code=500, detail=str(e))
# Get task status
@app.get("/tasks/status")
async def get_task_status():
"""Get task distribution statistics"""
try:
if not task_distributor:
raise HTTPException(status_code=503, detail="Task distributor not available")
stats = task_distributor.get_distribution_stats()
return {
"status": "success",
"stats": stats,
"timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
logger.error(f"Error getting task status: {e}")
raise HTTPException(status_code=500, detail=str(e))
# Send message
@app.post("/messages/send")
async def send_message(request: MessageRequest):
"""Send message to agent"""
try:
if not communication_manager:
raise HTTPException(status_code=503, detail="Communication manager not available")
from .protocols.communication import AgentMessage, Priority
# Convert message type
try:
message_type = MessageType(request.message_type)
except ValueError:
raise HTTPException(status_code=400, detail=f"Invalid message type: {request.message_type}")
# Convert priority
try:
priority = Priority(request.priority.lower())
except ValueError:
raise HTTPException(status_code=400, detail=f"Invalid priority: {request.priority}")
# Create message
message = AgentMessage(
sender_id="agent-coordinator",
receiver_id=request.receiver_id,
message_type=message_type,
priority=priority,
payload=request.payload
)
# Send message
success = await communication_manager.send_message("hierarchical", message)
if success:
return {
"status": "success",
"message": "Message sent successfully",
"message_id": message.id,
"receiver_id": request.receiver_id,
"sent_at": datetime.utcnow().isoformat()
}
else:
raise HTTPException(status_code=500, detail="Failed to send message")
except Exception as e:
logger.error(f"Error sending message: {e}")
raise HTTPException(status_code=500, detail=str(e))
# Load balancer statistics
@app.get("/load-balancer/stats")
async def get_load_balancer_stats():
"""Get load balancer statistics"""
try:
if not load_balancer:
raise HTTPException(status_code=503, detail="Load balancer not available")
stats = load_balancer.get_load_balancing_stats()
return {
"status": "success",
"stats": stats,
"timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
logger.error(f"Error getting load balancer stats: {e}")
raise HTTPException(status_code=500, detail=str(e))
# Registry statistics
@app.get("/registry/stats")
async def get_registry_stats():
"""Get agent registry statistics"""
try:
if not agent_registry:
raise HTTPException(status_code=503, detail="Agent registry not available")
stats = await agent_registry.get_registry_stats()
return {
"status": "success",
"stats": stats,
"timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
logger.error(f"Error getting registry stats: {e}")
raise HTTPException(status_code=500, detail=str(e))
# Get agents by service
@app.get("/agents/service/{service}")
async def get_agents_by_service(service: str):
"""Get agents that provide a specific service"""
try:
if not agent_registry:
raise HTTPException(status_code=503, detail="Agent registry not available")
agents = await agent_registry.get_agents_by_service(service)
return {
"status": "success",
"service": service,
"agents": [agent.to_dict() for agent in agents],
"count": len(agents),
"timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
logger.error(f"Error getting agents by service: {e}")
raise HTTPException(status_code=500, detail=str(e))
# Get agents by capability
@app.get("/agents/capability/{capability}")
async def get_agents_by_capability(capability: str):
"""Get agents that have a specific capability"""
try:
if not agent_registry:
raise HTTPException(status_code=503, detail="Agent registry not available")
agents = await agent_registry.get_agents_by_capability(capability)
return {
"status": "success",
"capability": capability,
"agents": [agent.to_dict() for agent in agents],
"count": len(agents),
"timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
logger.error(f"Error getting agents by capability: {e}")
raise HTTPException(status_code=500, detail=str(e))
# Set load balancing strategy
@app.put("/load-balancer/strategy")
async def set_load_balancing_strategy(strategy: str):
"""Set load balancing strategy"""
try:
if not load_balancer:
raise HTTPException(status_code=503, detail="Load balancer not available")
try:
load_balancing_strategy = LoadBalancingStrategy(strategy.lower())
except ValueError:
raise HTTPException(status_code=400, detail=f"Invalid strategy: {strategy}")
load_balancer.set_strategy(load_balancing_strategy)
return {
"status": "success",
"message": f"Load balancing strategy set to {strategy}",
"strategy": strategy,
"updated_at": datetime.utcnow().isoformat()
}
except Exception as e:
logger.error(f"Error setting load balancing strategy: {e}")
raise HTTPException(status_code=500, detail=str(e))
# Error handlers
@app.exception_handler(404)
async def not_found_handler(request, exc):
return JSONResponse(
status_code=404,
content={
"status": "error",
"message": "Resource not found",
"timestamp": datetime.utcnow().isoformat()
}
)
@app.exception_handler(500)
async def internal_error_handler(request, exc):
logger.error(f"Internal server error: {exc}")
return JSONResponse(
status_code=500,
content={
"status": "error",
"message": "Internal server error",
"timestamp": datetime.utcnow().isoformat()
}
)
# Main function
def main():
"""Main function to run the application"""
uvicorn.run(
"main:app",
host="0.0.0.0",
port=9001,
reload=True,
log_level="info"
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,443 @@
"""
Multi-Agent Communication Protocols for AITBC Agent Coordination
"""
import asyncio
import json
import logging
from enum import Enum
from typing import Dict, List, Optional, Any, Callable
from dataclasses import dataclass, field
from datetime import datetime
import uuid
import websockets
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
class MessageType(str, Enum):
"""Message types for agent communication"""
COORDINATION = "coordination"
TASK_ASSIGNMENT = "task_assignment"
STATUS_UPDATE = "status_update"
DISCOVERY = "discovery"
HEARTBEAT = "heartbeat"
CONSENSUS = "consensus"
BROADCAST = "broadcast"
DIRECT = "direct"
PEER_TO_PEER = "peer_to_peer"
HIERARCHICAL = "hierarchical"
class Priority(str, Enum):
"""Message priority levels"""
LOW = "low"
NORMAL = "normal"
HIGH = "high"
CRITICAL = "critical"
@dataclass
class AgentMessage:
"""Base message structure for agent communication"""
id: str = field(default_factory=lambda: str(uuid.uuid4()))
sender_id: str = ""
receiver_id: Optional[str] = None
message_type: MessageType = MessageType.DIRECT
priority: Priority = Priority.NORMAL
timestamp: datetime = field(default_factory=datetime.utcnow)
payload: Dict[str, Any] = field(default_factory=dict)
correlation_id: Optional[str] = None
reply_to: Optional[str] = None
ttl: int = 300 # Time to live in seconds
def to_dict(self) -> Dict[str, Any]:
"""Convert message to dictionary"""
return {
"id": self.id,
"sender_id": self.sender_id,
"receiver_id": self.receiver_id,
"message_type": self.message_type.value,
"priority": self.priority.value,
"timestamp": self.timestamp.isoformat(),
"payload": self.payload,
"correlation_id": self.correlation_id,
"reply_to": self.reply_to,
"ttl": self.ttl
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "AgentMessage":
"""Create message from dictionary"""
data["timestamp"] = datetime.fromisoformat(data["timestamp"])
data["message_type"] = MessageType(data["message_type"])
data["priority"] = Priority(data["priority"])
return cls(**data)
class CommunicationProtocol:
"""Base class for communication protocols"""
def __init__(self, agent_id: str):
self.agent_id = agent_id
self.message_handlers: Dict[MessageType, List[Callable]] = {}
self.active_connections: Dict[str, Any] = {}
async def register_handler(self, message_type: MessageType, handler: Callable):
"""Register a message handler for a specific message type"""
if message_type not in self.message_handlers:
self.message_handlers[message_type] = []
self.message_handlers[message_type].append(handler)
async def send_message(self, message: AgentMessage) -> bool:
"""Send a message to another agent"""
try:
if message.receiver_id and message.receiver_id in self.active_connections:
await self._send_to_agent(message)
return True
elif message.message_type == MessageType.BROADCAST:
await self._broadcast_message(message)
return True
else:
logger.warning(f"Cannot send message to {message.receiver_id}: not connected")
return False
except Exception as e:
logger.error(f"Error sending message: {e}")
return False
async def receive_message(self, message: AgentMessage):
"""Process received message"""
try:
# Check TTL
if self._is_message_expired(message):
logger.warning(f"Message {message.id} expired, ignoring")
return
# Handle message
handlers = self.message_handlers.get(message.message_type, [])
for handler in handlers:
try:
await handler(message)
except Exception as e:
logger.error(f"Error in message handler: {e}")
except Exception as e:
logger.error(f"Error processing message: {e}")
def _is_message_expired(self, message: AgentMessage) -> bool:
"""Check if message has expired"""
age = (datetime.utcnow() - message.timestamp).total_seconds()
return age > message.ttl
async def _send_to_agent(self, message: AgentMessage):
"""Send message to specific agent"""
raise NotImplementedError("Subclasses must implement _send_to_agent")
async def _broadcast_message(self, message: AgentMessage):
"""Broadcast message to all connected agents"""
raise NotImplementedError("Subclasses must implement _broadcast_message")
class HierarchicalProtocol(CommunicationProtocol):
"""Hierarchical communication protocol (master-agent → sub-agents)"""
def __init__(self, agent_id: str, is_master: bool = False):
super().__init__(agent_id)
self.is_master = is_master
self.sub_agents: List[str] = []
self.master_agent: Optional[str] = None
async def add_sub_agent(self, agent_id: str):
"""Add a sub-agent to this master agent"""
if self.is_master:
self.sub_agents.append(agent_id)
logger.info(f"Added sub-agent {agent_id} to master {self.agent_id}")
else:
logger.warning(f"Agent {self.agent_id} is not a master, cannot add sub-agents")
async def send_to_sub_agents(self, message: AgentMessage):
"""Send message to all sub-agents"""
if not self.is_master:
logger.warning(f"Agent {self.agent_id} is not a master")
return
message.message_type = MessageType.HIERARCHICAL
for sub_agent_id in self.sub_agents:
message.receiver_id = sub_agent_id
await self.send_message(message)
async def send_to_master(self, message: AgentMessage):
"""Send message to master agent"""
if self.is_master:
logger.warning(f"Agent {self.agent_id} is a master, cannot send to master")
return
if self.master_agent:
message.receiver_id = self.master_agent
message.message_type = MessageType.HIERARCHICAL
await self.send_message(message)
else:
logger.warning(f"Agent {self.agent_id} has no master agent")
class PeerToPeerProtocol(CommunicationProtocol):
"""Peer-to-peer communication protocol (agent ↔ agent)"""
def __init__(self, agent_id: str):
super().__init__(agent_id)
self.peers: Dict[str, Dict[str, Any]] = {}
async def add_peer(self, peer_id: str, connection_info: Dict[str, Any]):
"""Add a peer to the peer network"""
self.peers[peer_id] = connection_info
logger.info(f"Added peer {peer_id} to agent {self.agent_id}")
async def remove_peer(self, peer_id: str):
"""Remove a peer from the peer network"""
if peer_id in self.peers:
del self.peers[peer_id]
logger.info(f"Removed peer {peer_id} from agent {self.agent_id}")
async def send_to_peer(self, message: AgentMessage, peer_id: str):
"""Send message to specific peer"""
if peer_id not in self.peers:
logger.warning(f"Peer {peer_id} not found")
return False
message.receiver_id = peer_id
message.message_type = MessageType.PEER_TO_PEER
return await self.send_message(message)
async def broadcast_to_peers(self, message: AgentMessage):
"""Broadcast message to all peers"""
message.message_type = MessageType.PEER_TO_PEER
for peer_id in self.peers:
message.receiver_id = peer_id
await self.send_message(message)
class BroadcastProtocol(CommunicationProtocol):
"""Broadcast communication protocol (agent → all agents)"""
def __init__(self, agent_id: str, broadcast_channel: str = "global"):
super().__init__(agent_id)
self.broadcast_channel = broadcast_channel
self.subscribers: List[str] = []
async def subscribe(self, agent_id: str):
"""Subscribe to broadcast channel"""
if agent_id not in self.subscribers:
self.subscribers.append(agent_id)
logger.info(f"Agent {agent_id} subscribed to {self.broadcast_channel}")
async def unsubscribe(self, agent_id: str):
"""Unsubscribe from broadcast channel"""
if agent_id in self.subscribers:
self.subscribers.remove(agent_id)
logger.info(f"Agent {agent_id} unsubscribed from {self.broadcast_channel}")
async def broadcast(self, message: AgentMessage):
"""Broadcast message to all subscribers"""
message.message_type = MessageType.BROADCAST
message.receiver_id = None # Broadcast to all
for subscriber_id in self.subscribers:
if subscriber_id != self.agent_id: # Don't send to self
message_copy = AgentMessage(**message.__dict__)
message_copy.receiver_id = subscriber_id
await self.send_message(message_copy)
class CommunicationManager:
"""Manages multiple communication protocols for an agent"""
def __init__(self, agent_id: str):
self.agent_id = agent_id
self.protocols: Dict[str, CommunicationProtocol] = {}
def add_protocol(self, name: str, protocol: CommunicationProtocol):
"""Add a communication protocol"""
self.protocols[name] = protocol
logger.info(f"Added protocol {name} to agent {self.agent_id}")
def get_protocol(self, name: str) -> Optional[CommunicationProtocol]:
"""Get a communication protocol by name"""
return self.protocols.get(name)
async def send_message(self, protocol_name: str, message: AgentMessage) -> bool:
"""Send message using specific protocol"""
protocol = self.get_protocol(protocol_name)
if protocol:
return await protocol.send_message(message)
return False
async def register_handler(self, protocol_name: str, message_type: MessageType, handler: Callable):
"""Register message handler for specific protocol"""
protocol = self.get_protocol(protocol_name)
if protocol:
await protocol.register_handler(message_type, handler)
else:
logger.error(f"Protocol {protocol_name} not found")
# Message templates for common operations
class MessageTemplates:
"""Pre-defined message templates"""
@staticmethod
def create_heartbeat(sender_id: str) -> AgentMessage:
"""Create heartbeat message"""
return AgentMessage(
sender_id=sender_id,
message_type=MessageType.HEARTBEAT,
priority=Priority.LOW,
payload={"timestamp": datetime.utcnow().isoformat()}
)
@staticmethod
def create_task_assignment(sender_id: str, receiver_id: str, task_data: Dict[str, Any]) -> AgentMessage:
"""Create task assignment message"""
return AgentMessage(
sender_id=sender_id,
receiver_id=receiver_id,
message_type=MessageType.TASK_ASSIGNMENT,
priority=Priority.NORMAL,
payload=task_data
)
@staticmethod
def create_status_update(sender_id: str, status_data: Dict[str, Any]) -> AgentMessage:
"""Create status update message"""
return AgentMessage(
sender_id=sender_id,
message_type=MessageType.STATUS_UPDATE,
priority=Priority.NORMAL,
payload=status_data
)
@staticmethod
def create_discovery(sender_id: str) -> AgentMessage:
"""Create discovery message"""
return AgentMessage(
sender_id=sender_id,
message_type=MessageType.DISCOVERY,
priority=Priority.NORMAL,
payload={"agent_id": sender_id}
)
@staticmethod
def create_consensus_request(sender_id: str, proposal_data: Dict[str, Any]) -> AgentMessage:
"""Create consensus request message"""
return AgentMessage(
sender_id=sender_id,
message_type=MessageType.CONSENSUS,
priority=Priority.HIGH,
payload=proposal_data
)
# WebSocket connection handler for real-time communication
class WebSocketHandler:
"""WebSocket handler for real-time agent communication"""
def __init__(self, communication_manager: CommunicationManager):
self.communication_manager = communication_manager
self.websocket_connections: Dict[str, Any] = {}
async def handle_connection(self, websocket, agent_id: str):
"""Handle WebSocket connection from agent"""
self.websocket_connections[agent_id] = websocket
logger.info(f"WebSocket connection established for agent {agent_id}")
try:
async for message in websocket:
data = json.loads(message)
agent_message = AgentMessage.from_dict(data)
await self.communication_manager.receive_message(agent_message)
except websockets.exceptions.ConnectionClosed:
logger.info(f"WebSocket connection closed for agent {agent_id}")
finally:
if agent_id in self.websocket_connections:
del self.websocket_connections[agent_id]
async def send_to_agent(self, agent_id: str, message: AgentMessage):
"""Send message to agent via WebSocket"""
if agent_id in self.websocket_connections:
websocket = self.websocket_connections[agent_id]
await websocket.send(json.dumps(message.to_dict()))
return True
return False
async def broadcast_message(self, message: AgentMessage):
"""Broadcast message to all connected agents"""
for websocket in self.websocket_connections.values():
await websocket.send(json.dumps(message.to_dict()))
# Redis-based message broker for scalable communication
class RedisMessageBroker:
"""Redis-based message broker for agent communication"""
def __init__(self, redis_url: str):
self.redis_url = redis_url
self.channels: Dict[str, Any] = {}
async def publish_message(self, channel: str, message: AgentMessage):
"""Publish message to Redis channel"""
import redis.asyncio as redis
redis_client = redis.from_url(self.redis_url)
await redis_client.publish(channel, json.dumps(message.to_dict()))
await redis_client.close()
async def subscribe_to_channel(self, channel: str, handler: Callable):
"""Subscribe to Redis channel"""
import redis.asyncio as redis
redis_client = redis.from_url(self.redis_url)
pubsub = redis_client.pubsub()
await pubsub.subscribe(channel)
self.channels[channel] = {"pubsub": pubsub, "handler": handler}
# Start listening for messages
asyncio.create_task(self._listen_to_channel(channel, pubsub, handler))
async def _listen_to_channel(self, channel: str, pubsub: Any, handler: Callable):
"""Listen for messages on channel"""
async for message in pubsub.listen():
if message["type"] == "message":
data = json.loads(message["data"])
agent_message = AgentMessage.from_dict(data)
await handler(agent_message)
# Factory function for creating communication protocols
def create_protocol(protocol_type: str, agent_id: str, **kwargs) -> CommunicationProtocol:
"""Factory function to create communication protocols"""
if protocol_type == "hierarchical":
return HierarchicalProtocol(agent_id, kwargs.get("is_master", False))
elif protocol_type == "peer_to_peer":
return PeerToPeerProtocol(agent_id)
elif protocol_type == "broadcast":
return BroadcastProtocol(agent_id, kwargs.get("broadcast_channel", "global"))
else:
raise ValueError(f"Unknown protocol type: {protocol_type}")
# Example usage
async def example_usage():
"""Example of how to use the communication protocols"""
# Create communication manager
comm_manager = CommunicationManager("agent-001")
# Add protocols
hierarchical_protocol = create_protocol("hierarchical", "agent-001", is_master=True)
p2p_protocol = create_protocol("peer_to_peer", "agent-001")
broadcast_protocol = create_protocol("broadcast", "agent-001")
comm_manager.add_protocol("hierarchical", hierarchical_protocol)
comm_manager.add_protocol("peer_to_peer", p2p_protocol)
comm_manager.add_protocol("broadcast", broadcast_protocol)
# Register message handlers
async def handle_heartbeat(message: AgentMessage):
logger.info(f"Received heartbeat from {message.sender_id}")
await comm_manager.register_handler("hierarchical", MessageType.HEARTBEAT, handle_heartbeat)
# Send messages
heartbeat = MessageTemplates.create_heartbeat("agent-001")
await comm_manager.send_message("hierarchical", heartbeat)
if __name__ == "__main__":
asyncio.run(example_usage())

View File

@@ -0,0 +1,586 @@
"""
Message Types and Routing System for AITBC Agent Coordination
"""
import asyncio
import json
import logging
from enum import Enum
from typing import Dict, List, Optional, Any, Callable, Union
from dataclasses import dataclass, field
from datetime import datetime, timedelta
import uuid
import hashlib
from pydantic import BaseModel, Field, validator
from .communication import AgentMessage, MessageType, Priority
logger = logging.getLogger(__name__)
class MessageStatus(str, Enum):
"""Message processing status"""
PENDING = "pending"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"
EXPIRED = "expired"
CANCELLED = "cancelled"
class RoutingStrategy(str, Enum):
"""Message routing strategies"""
ROUND_ROBIN = "round_robin"
LOAD_BALANCED = "load_balanced"
PRIORITY_BASED = "priority_based"
RANDOM = "random"
DIRECT = "direct"
BROADCAST = "broadcast"
class DeliveryMode(str, Enum):
"""Message delivery modes"""
FIRE_AND_FORGET = "fire_and_forget"
AT_LEAST_ONCE = "at_least_once"
EXACTLY_ONCE = "exactly_once"
PERSISTENT = "persistent"
@dataclass
class RoutingRule:
"""Routing rule for message processing"""
rule_id: str = field(default_factory=lambda: str(uuid.uuid4()))
name: str = ""
condition: Dict[str, Any] = field(default_factory=dict)
action: str = "forward" # forward, transform, filter, route
target: Optional[str] = None
priority: int = 0
enabled: bool = True
created_at: datetime = field(default_factory=datetime.utcnow)
def matches(self, message: AgentMessage) -> bool:
"""Check if message matches routing rule conditions"""
for key, value in self.condition.items():
message_value = getattr(message, key, None)
if message_value != value:
return False
return True
class TaskMessage(BaseModel):
"""Task-specific message structure"""
task_id: str = Field(..., description="Unique task identifier")
task_type: str = Field(..., description="Type of task")
task_data: Dict[str, Any] = Field(default_factory=dict, description="Task data")
requirements: Dict[str, Any] = Field(default_factory=dict, description="Task requirements")
deadline: Optional[datetime] = Field(None, description="Task deadline")
priority: Priority = Field(Priority.NORMAL, description="Task priority")
assigned_agent: Optional[str] = Field(None, description="Assigned agent ID")
status: str = Field("pending", description="Task status")
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
@validator('deadline')
def validate_deadline(cls, v):
if v and v < datetime.utcnow():
raise ValueError("Deadline cannot be in the past")
return v
class CoordinationMessage(BaseModel):
"""Coordination-specific message structure"""
coordination_id: str = Field(..., description="Unique coordination identifier")
coordination_type: str = Field(..., description="Type of coordination")
participants: List[str] = Field(default_factory=list, description="Participating agents")
coordination_data: Dict[str, Any] = Field(default_factory=dict, description="Coordination data")
decision_deadline: Optional[datetime] = Field(None, description="Decision deadline")
consensus_threshold: float = Field(0.5, description="Consensus threshold")
status: str = Field("pending", description="Coordination status")
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
class StatusMessage(BaseModel):
"""Status update message structure"""
agent_id: str = Field(..., description="Agent ID")
status_type: str = Field(..., description="Type of status")
status_data: Dict[str, Any] = Field(default_factory=dict, description="Status data")
health_score: float = Field(1.0, description="Agent health score")
load_metrics: Dict[str, float] = Field(default_factory=dict, description="Load metrics")
capabilities: List[str] = Field(default_factory=list, description="Agent capabilities")
timestamp: datetime = Field(default_factory=datetime.utcnow)
class DiscoveryMessage(BaseModel):
"""Agent discovery message structure"""
agent_id: str = Field(..., description="Agent ID")
agent_type: str = Field(..., description="Type of agent")
capabilities: List[str] = Field(default_factory=list, description="Agent capabilities")
services: List[str] = Field(default_factory=list, description="Available services")
endpoints: Dict[str, str] = Field(default_factory=dict, description="Service endpoints")
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
timestamp: datetime = Field(default_factory=datetime.utcnow)
class ConsensusMessage(BaseModel):
"""Consensus message structure"""
consensus_id: str = Field(..., description="Unique consensus identifier")
proposal: Dict[str, Any] = Field(..., description="Consensus proposal")
voting_options: List[Dict[str, Any]] = Field(default_factory=list, description="Voting options")
votes: Dict[str, str] = Field(default_factory=dict, description="Agent votes")
voting_deadline: datetime = Field(..., description="Voting deadline")
consensus_algorithm: str = Field("majority", description="Consensus algorithm")
status: str = Field("pending", description="Consensus status")
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
class MessageRouter:
"""Advanced message routing system"""
def __init__(self, agent_id: str):
self.agent_id = agent_id
self.routing_rules: List[RoutingRule] = []
self.message_queue: asyncio.Queue = asyncio.Queue(maxsize=10000)
self.dead_letter_queue: asyncio.Queue = asyncio.Queue(maxsize=1000)
self.routing_stats: Dict[str, Any] = {
"messages_processed": 0,
"messages_failed": 0,
"messages_expired": 0,
"routing_time_total": 0.0
}
self.active_routes: Dict[str, str] = {} # message_id -> route
self.load_balancer_index = 0
def add_routing_rule(self, rule: RoutingRule):
"""Add a routing rule"""
self.routing_rules.append(rule)
# Sort by priority (higher priority first)
self.routing_rules.sort(key=lambda r: r.priority, reverse=True)
logger.info(f"Added routing rule: {rule.name}")
def remove_routing_rule(self, rule_id: str):
"""Remove a routing rule"""
self.routing_rules = [r for r in self.routing_rules if r.rule_id != rule_id]
logger.info(f"Removed routing rule: {rule_id}")
async def route_message(self, message: AgentMessage) -> Optional[str]:
"""Route message based on routing rules"""
start_time = datetime.utcnow()
try:
# Check if message is expired
if self._is_message_expired(message):
await self.dead_letter_queue.put(message)
self.routing_stats["messages_expired"] += 1
return None
# Apply routing rules
for rule in self.routing_rules:
if rule.enabled and rule.matches(message):
route = await self._apply_routing_rule(rule, message)
if route:
self.active_routes[message.id] = route
self.routing_stats["messages_processed"] += 1
return route
# Default routing
default_route = await self._default_routing(message)
if default_route:
self.active_routes[message.id] = default_route
self.routing_stats["messages_processed"] += 1
return default_route
# No route found
await self.dead_letter_queue.put(message)
self.routing_stats["messages_failed"] += 1
return None
except Exception as e:
logger.error(f"Error routing message {message.id}: {e}")
await self.dead_letter_queue.put(message)
self.routing_stats["messages_failed"] += 1
return None
finally:
routing_time = (datetime.utcnow() - start_time).total_seconds()
self.routing_stats["routing_time_total"] += routing_time
async def _apply_routing_rule(self, rule: RoutingRule, message: AgentMessage) -> Optional[str]:
"""Apply a specific routing rule"""
if rule.action == "forward":
return rule.target
elif rule.action == "transform":
return await self._transform_message(message, rule)
elif rule.action == "filter":
return await self._filter_message(message, rule)
elif rule.action == "route":
return await self._custom_routing(message, rule)
return None
async def _transform_message(self, message: AgentMessage, rule: RoutingRule) -> Optional[str]:
"""Transform message based on rule"""
# Apply transformation logic here
transformed_message = AgentMessage(
sender_id=message.sender_id,
receiver_id=message.receiver_id,
message_type=message.message_type,
priority=message.priority,
payload={**message.payload, **rule.condition.get("transform", {})}
)
# Route transformed message
return await self._default_routing(transformed_message)
async def _filter_message(self, message: AgentMessage, rule: RoutingRule) -> Optional[str]:
"""Filter message based on rule"""
filter_condition = rule.condition.get("filter", {})
for key, value in filter_condition.items():
if message.payload.get(key) != value:
return None # Filter out message
return await self._default_routing(message)
async def _custom_routing(self, message: AgentMessage, rule: RoutingRule) -> Optional[str]:
"""Custom routing logic"""
# Implement custom routing logic here
return rule.target
async def _default_routing(self, message: AgentMessage) -> Optional[str]:
"""Default message routing"""
if message.receiver_id:
return message.receiver_id
elif message.message_type == MessageType.BROADCAST:
return "broadcast"
else:
return None
def _is_message_expired(self, message: AgentMessage) -> bool:
"""Check if message is expired"""
age = (datetime.utcnow() - message.timestamp).total_seconds()
return age > message.ttl
async def get_routing_stats(self) -> Dict[str, Any]:
"""Get routing statistics"""
total_messages = self.routing_stats["messages_processed"]
avg_routing_time = (
self.routing_stats["routing_time_total"] / total_messages
if total_messages > 0 else 0
)
return {
**self.routing_stats,
"avg_routing_time": avg_routing_time,
"active_routes": len(self.active_routes),
"queue_size": self.message_queue.qsize(),
"dead_letter_queue_size": self.dead_letter_queue.qsize()
}
class LoadBalancer:
"""Load balancer for message distribution"""
def __init__(self):
self.agent_loads: Dict[str, float] = {}
self.agent_weights: Dict[str, float] = {}
self.last_updated = datetime.utcnow()
def update_agent_load(self, agent_id: str, load: float):
"""Update agent load information"""
self.agent_loads[agent_id] = load
self.last_updated = datetime.utcnow()
def set_agent_weight(self, agent_id: str, weight: float):
"""Set agent weight for load balancing"""
self.agent_weights[agent_id] = weight
def select_agent(self, available_agents: List[str], strategy: RoutingStrategy = RoutingStrategy.LOAD_BALANCED) -> Optional[str]:
"""Select agent based on load balancing strategy"""
if not available_agents:
return None
if strategy == RoutingStrategy.ROUND_ROBIN:
return self._round_robin_selection(available_agents)
elif strategy == RoutingStrategy.LOAD_BALANCED:
return self._load_balanced_selection(available_agents)
elif strategy == RoutingStrategy.PRIORITY_BASED:
return self._priority_based_selection(available_agents)
elif strategy == RoutingStrategy.RANDOM:
return self._random_selection(available_agents)
else:
return available_agents[0]
def _round_robin_selection(self, agents: List[str]) -> str:
"""Round-robin agent selection"""
agent = agents[self.load_balancer_index % len(agents)]
self.load_balancer_index += 1
return agent
def _load_balanced_selection(self, agents: List[str]) -> str:
"""Load-balanced agent selection"""
# Select agent with lowest load
min_load = float('inf')
selected_agent = None
for agent in agents:
load = self.agent_loads.get(agent, 0.0)
weight = self.agent_weights.get(agent, 1.0)
weighted_load = load / weight
if weighted_load < min_load:
min_load = weighted_load
selected_agent = agent
return selected_agent or agents[0]
def _priority_based_selection(self, agents: List[str]) -> str:
"""Priority-based agent selection"""
# Sort by weight (higher weight = higher priority)
weighted_agents = sorted(
agents,
key=lambda a: self.agent_weights.get(a, 1.0),
reverse=True
)
return weighted_agents[0]
def _random_selection(self, agents: List[str]) -> str:
"""Random agent selection"""
import random
return random.choice(agents)
class MessageQueue:
"""Advanced message queue with priority and persistence"""
def __init__(self, max_size: int = 10000):
self.max_size = max_size
self.queues: Dict[Priority, asyncio.Queue] = {
Priority.CRITICAL: asyncio.Queue(maxsize=max_size // 4),
Priority.HIGH: asyncio.Queue(maxsize=max_size // 4),
Priority.NORMAL: asyncio.Queue(maxsize // 2),
Priority.LOW: asyncio.Queue(maxsize // 4)
}
self.message_store: Dict[str, AgentMessage] = {}
self.delivery_confirmations: Dict[str, bool] = {}
async def enqueue(self, message: AgentMessage, delivery_mode: DeliveryMode = DeliveryMode.AT_LEAST_ONCE) -> bool:
"""Enqueue message with priority"""
try:
# Store message for persistence
if delivery_mode in [DeliveryMode.AT_LEAST_ONCE, DeliveryMode.EXACTLY_ONCE, DeliveryMode.PERSISTENT]:
self.message_store[message.id] = message
# Add to appropriate priority queue
queue = self.queues[message.priority]
await queue.put(message)
logger.debug(f"Enqueued message {message.id} with priority {message.priority}")
return True
except asyncio.QueueFull:
logger.error(f"Queue full, cannot enqueue message {message.id}")
return False
async def dequeue(self) -> Optional[AgentMessage]:
"""Dequeue message with priority order"""
# Check queues in priority order
for priority in [Priority.CRITICAL, Priority.HIGH, Priority.NORMAL, Priority.LOW]:
queue = self.queues[priority]
try:
message = queue.get_nowait()
logger.debug(f"Dequeued message {message.id} with priority {priority}")
return message
except asyncio.QueueEmpty:
continue
return None
async def confirm_delivery(self, message_id: str):
"""Confirm message delivery"""
self.delivery_confirmations[message_id] = True
# Clean up if exactly once delivery
if message_id in self.message_store:
del self.message_store[message_id]
def get_queue_stats(self) -> Dict[str, Any]:
"""Get queue statistics"""
return {
"queue_sizes": {
priority.value: queue.qsize()
for priority, queue in self.queues.items()
},
"stored_messages": len(self.message_store),
"delivery_confirmations": len(self.delivery_confirmations),
"max_size": self.max_size
}
class MessageProcessor:
"""Message processor with async handling"""
def __init__(self, agent_id: str):
self.agent_id = agent_id
self.router = MessageRouter(agent_id)
self.load_balancer = LoadBalancer()
self.message_queue = MessageQueue()
self.processors: Dict[str, Callable] = {}
self.processing_stats: Dict[str, Any] = {
"messages_processed": 0,
"processing_time_total": 0.0,
"errors": 0
}
def register_processor(self, message_type: MessageType, processor: Callable):
"""Register message processor"""
self.processors[message_type.value] = processor
logger.info(f"Registered processor for {message_type.value}")
async def process_message(self, message: AgentMessage) -> bool:
"""Process a message"""
start_time = datetime.utcnow()
try:
# Route message
route = await self.router.route_message(message)
if not route:
logger.warning(f"No route found for message {message.id}")
return False
# Process message
processor = self.processors.get(message.message_type.value)
if processor:
await processor(message)
else:
logger.warning(f"No processor found for {message.message_type.value}")
return False
# Update stats
self.processing_stats["messages_processed"] += 1
processing_time = (datetime.utcnow() - start_time).total_seconds()
self.processing_stats["processing_time_total"] += processing_time
return True
except Exception as e:
logger.error(f"Error processing message {message.id}: {e}")
self.processing_stats["errors"] += 1
return False
async def start_processing(self):
"""Start message processing loop"""
while True:
try:
# Dequeue message
message = await self.message_queue.dequeue()
if message:
await self.process_message(message)
else:
await asyncio.sleep(0.01) # Small delay if no messages
except Exception as e:
logger.error(f"Error in processing loop: {e}")
await asyncio.sleep(1)
def get_processing_stats(self) -> Dict[str, Any]:
"""Get processing statistics"""
total_processed = self.processing_stats["messages_processed"]
avg_processing_time = (
self.processing_stats["processing_time_total"] / total_processed
if total_processed > 0 else 0
)
return {
**self.processing_stats,
"avg_processing_time": avg_processing_time,
"queue_stats": self.message_queue.get_queue_stats(),
"routing_stats": self.router.get_routing_stats()
}
# Factory functions for creating message types
def create_task_message(sender_id: str, receiver_id: str, task_type: str, task_data: Dict[str, Any]) -> AgentMessage:
"""Create a task message"""
task_msg = TaskMessage(
task_id=str(uuid.uuid4()),
task_type=task_type,
task_data=task_data
)
return AgentMessage(
sender_id=sender_id,
receiver_id=receiver_id,
message_type=MessageType.TASK_ASSIGNMENT,
payload=task_msg.dict()
)
def create_coordination_message(sender_id: str, coordination_type: str, participants: List[str], data: Dict[str, Any]) -> AgentMessage:
"""Create a coordination message"""
coord_msg = CoordinationMessage(
coordination_id=str(uuid.uuid4()),
coordination_type=coordination_type,
participants=participants,
coordination_data=data
)
return AgentMessage(
sender_id=sender_id,
message_type=MessageType.COORDINATION,
payload=coord_msg.dict()
)
def create_status_message(agent_id: str, status_type: str, status_data: Dict[str, Any]) -> AgentMessage:
"""Create a status message"""
status_msg = StatusMessage(
agent_id=agent_id,
status_type=status_type,
status_data=status_data
)
return AgentMessage(
sender_id=agent_id,
message_type=MessageType.STATUS_UPDATE,
payload=status_msg.dict()
)
def create_discovery_message(agent_id: str, agent_type: str, capabilities: List[str], services: List[str]) -> AgentMessage:
"""Create a discovery message"""
discovery_msg = DiscoveryMessage(
agent_id=agent_id,
agent_type=agent_type,
capabilities=capabilities,
services=services
)
return AgentMessage(
sender_id=agent_id,
message_type=MessageType.DISCOVERY,
payload=discovery_msg.dict()
)
def create_consensus_message(sender_id: str, proposal: Dict[str, Any], voting_options: List[Dict[str, Any]], deadline: datetime) -> AgentMessage:
"""Create a consensus message"""
consensus_msg = ConsensusMessage(
consensus_id=str(uuid.uuid4()),
proposal=proposal,
voting_options=voting_options,
voting_deadline=deadline
)
return AgentMessage(
sender_id=sender_id,
message_type=MessageType.CONSENSUS,
payload=consensus_msg.dict()
)
# Example usage
async def example_usage():
"""Example of how to use the message routing system"""
# Create message processor
processor = MessageProcessor("agent-001")
# Register processors
async def process_task(message: AgentMessage):
task_data = TaskMessage(**message.payload)
logger.info(f"Processing task: {task_data.task_id}")
processor.register_processor(MessageType.TASK_ASSIGNMENT, process_task)
# Create and route message
task_message = create_task_message(
sender_id="agent-001",
receiver_id="agent-002",
task_type="data_processing",
task_data={"input": "test_data"}
)
await processor.message_queue.enqueue(task_message)
# Start processing (in real implementation, this would run in background)
# await processor.start_processing()
if __name__ == "__main__":
asyncio.run(example_usage())

View File

@@ -0,0 +1,641 @@
"""
Agent Discovery and Registration System for AITBC Agent Coordination
"""
import asyncio
import json
import logging
from typing import Dict, List, Optional, Set, Callable, Any
from dataclasses import dataclass, field
from datetime import datetime, timedelta
import uuid
import hashlib
from enum import Enum
import redis.asyncio as redis
from pydantic import BaseModel, Field
from ..protocols.message_types import DiscoveryMessage, create_discovery_message
from ..protocols.communication import AgentMessage, MessageType
logger = logging.getLogger(__name__)
class AgentStatus(str, Enum):
"""Agent status enumeration"""
ACTIVE = "active"
INACTIVE = "inactive"
BUSY = "busy"
MAINTENANCE = "maintenance"
ERROR = "error"
class AgentType(str, Enum):
"""Agent type enumeration"""
COORDINATOR = "coordinator"
WORKER = "worker"
SPECIALIST = "specialist"
MONITOR = "monitor"
GATEWAY = "gateway"
ORCHESTRATOR = "orchestrator"
@dataclass
class AgentInfo:
"""Agent information structure"""
agent_id: str
agent_type: AgentType
status: AgentStatus
capabilities: List[str]
services: List[str]
endpoints: Dict[str, str]
metadata: Dict[str, Any]
last_heartbeat: datetime
registration_time: datetime
load_metrics: Dict[str, float] = field(default_factory=dict)
health_score: float = 1.0
version: str = "1.0.0"
tags: Set[str] = field(default_factory=set)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary"""
return {
"agent_id": self.agent_id,
"agent_type": self.agent_type.value,
"status": self.status.value,
"capabilities": self.capabilities,
"services": self.services,
"endpoints": self.endpoints,
"metadata": self.metadata,
"last_heartbeat": self.last_heartbeat.isoformat(),
"registration_time": self.registration_time.isoformat(),
"load_metrics": self.load_metrics,
"health_score": self.health_score,
"version": self.version,
"tags": list(self.tags)
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "AgentInfo":
"""Create from dictionary"""
data["agent_type"] = AgentType(data["agent_type"])
data["status"] = AgentStatus(data["status"])
data["last_heartbeat"] = datetime.fromisoformat(data["last_heartbeat"])
data["registration_time"] = datetime.fromisoformat(data["registration_time"])
data["tags"] = set(data.get("tags", []))
return cls(**data)
class AgentRegistry:
"""Central agent registry for discovery and management"""
def __init__(self, redis_url: str = "redis://localhost:6379/1"):
self.redis_url = redis_url
self.redis_client: Optional[redis.Redis] = None
self.agents: Dict[str, AgentInfo] = {}
self.service_index: Dict[str, Set[str]] = {} # service -> agent_ids
self.capability_index: Dict[str, Set[str]] = {} # capability -> agent_ids
self.type_index: Dict[AgentType, Set[str]] = {} # agent_type -> agent_ids
self.heartbeat_interval = 30 # seconds
self.cleanup_interval = 60 # seconds
self.max_heartbeat_age = 120 # seconds
async def start(self):
"""Start the registry service"""
self.redis_client = redis.from_url(self.redis_url)
# Load existing agents from Redis
await self._load_agents_from_redis()
# Start background tasks
asyncio.create_task(self._heartbeat_monitor())
asyncio.create_task(self._cleanup_inactive_agents())
logger.info("Agent registry started")
async def stop(self):
"""Stop the registry service"""
if self.redis_client:
await self.redis_client.close()
logger.info("Agent registry stopped")
async def register_agent(self, agent_info: AgentInfo) -> bool:
"""Register a new agent"""
try:
# Add to local registry
self.agents[agent_info.agent_id] = agent_info
# Update indexes
self._update_indexes(agent_info)
# Save to Redis
await self._save_agent_to_redis(agent_info)
# Publish registration event
await self._publish_agent_event("agent_registered", agent_info)
logger.info(f"Agent {agent_info.agent_id} registered successfully")
return True
except Exception as e:
logger.error(f"Error registering agent {agent_info.agent_id}: {e}")
return False
async def unregister_agent(self, agent_id: str) -> bool:
"""Unregister an agent"""
try:
if agent_id not in self.agents:
logger.warning(f"Agent {agent_id} not found for unregistration")
return False
agent_info = self.agents[agent_id]
# Remove from local registry
del self.agents[agent_id]
# Update indexes
self._remove_from_indexes(agent_info)
# Remove from Redis
await self._remove_agent_from_redis(agent_id)
# Publish unregistration event
await self._publish_agent_event("agent_unregistered", agent_info)
logger.info(f"Agent {agent_id} unregistered successfully")
return True
except Exception as e:
logger.error(f"Error unregistering agent {agent_id}: {e}")
return False
async def update_agent_status(self, agent_id: str, status: AgentStatus, load_metrics: Optional[Dict[str, float]] = None) -> bool:
"""Update agent status and metrics"""
try:
if agent_id not in self.agents:
logger.warning(f"Agent {agent_id} not found for status update")
return False
agent_info = self.agents[agent_id]
agent_info.status = status
agent_info.last_heartbeat = datetime.utcnow()
if load_metrics:
agent_info.load_metrics.update(load_metrics)
# Update health score
agent_info.health_score = self._calculate_health_score(agent_info)
# Save to Redis
await self._save_agent_to_redis(agent_info)
# Publish status update event
await self._publish_agent_event("agent_status_updated", agent_info)
return True
except Exception as e:
logger.error(f"Error updating agent status {agent_id}: {e}")
return False
async def update_agent_heartbeat(self, agent_id: str) -> bool:
"""Update agent heartbeat"""
try:
if agent_id not in self.agents:
logger.warning(f"Agent {agent_id} not found for heartbeat")
return False
agent_info = self.agents[agent_id]
agent_info.last_heartbeat = datetime.utcnow()
# Update health score
agent_info.health_score = self._calculate_health_score(agent_info)
# Save to Redis
await self._save_agent_to_redis(agent_info)
return True
except Exception as e:
logger.error(f"Error updating heartbeat for {agent_id}: {e}")
return False
async def discover_agents(self, query: Dict[str, Any]) -> List[AgentInfo]:
"""Discover agents based on query criteria"""
results = []
try:
# Start with all agents
candidate_agents = list(self.agents.values())
# Apply filters
if "agent_type" in query:
agent_type = AgentType(query["agent_type"])
candidate_agents = [a for a in candidate_agents if a.agent_type == agent_type]
if "status" in query:
status = AgentStatus(query["status"])
candidate_agents = [a for a in candidate_agents if a.status == status]
if "capabilities" in query:
required_capabilities = set(query["capabilities"])
candidate_agents = [a for a in candidate_agents if required_capabilities.issubset(a.capabilities)]
if "services" in query:
required_services = set(query["services"])
candidate_agents = [a for a in candidate_agents if required_services.issubset(a.services)]
if "tags" in query:
required_tags = set(query["tags"])
candidate_agents = [a for a in candidate_agents if required_tags.issubset(a.tags)]
if "min_health_score" in query:
min_score = query["min_health_score"]
candidate_agents = [a for a in candidate_agents if a.health_score >= min_score]
# Sort by health score (highest first)
results = sorted(candidate_agents, key=lambda a: a.health_score, reverse=True)
# Limit results if specified
if "limit" in query:
results = results[:query["limit"]]
logger.info(f"Discovered {len(results)} agents for query: {query}")
return results
except Exception as e:
logger.error(f"Error discovering agents: {e}")
return []
async def get_agent_by_id(self, agent_id: str) -> Optional[AgentInfo]:
"""Get agent information by ID"""
return self.agents.get(agent_id)
async def get_agents_by_service(self, service: str) -> List[AgentInfo]:
"""Get agents that provide a specific service"""
agent_ids = self.service_index.get(service, set())
return [self.agents[agent_id] for agent_id in agent_ids if agent_id in self.agents]
async def get_agents_by_capability(self, capability: str) -> List[AgentInfo]:
"""Get agents that have a specific capability"""
agent_ids = self.capability_index.get(capability, set())
return [self.agents[agent_id] for agent_id in agent_ids if agent_id in self.agents]
async def get_agents_by_type(self, agent_type: AgentType) -> List[AgentInfo]:
"""Get agents of a specific type"""
agent_ids = self.type_index.get(agent_type, set())
return [self.agents[agent_id] for agent_id in agent_ids if agent_id in self.agents]
async def get_registry_stats(self) -> Dict[str, Any]:
"""Get registry statistics"""
total_agents = len(self.agents)
status_counts = {}
type_counts = {}
for agent_info in self.agents.values():
# Count by status
status = agent_info.status.value
status_counts[status] = status_counts.get(status, 0) + 1
# Count by type
agent_type = agent_info.agent_type.value
type_counts[agent_type] = type_counts.get(agent_type, 0) + 1
return {
"total_agents": total_agents,
"status_counts": status_counts,
"type_counts": type_counts,
"service_count": len(self.service_index),
"capability_count": len(self.capability_index),
"last_cleanup": datetime.utcnow().isoformat()
}
def _update_indexes(self, agent_info: AgentInfo):
"""Update search indexes"""
# Service index
for service in agent_info.services:
if service not in self.service_index:
self.service_index[service] = set()
self.service_index[service].add(agent_info.agent_id)
# Capability index
for capability in agent_info.capabilities:
if capability not in self.capability_index:
self.capability_index[capability] = set()
self.capability_index[capability].add(agent_info.agent_id)
# Type index
if agent_info.agent_type not in self.type_index:
self.type_index[agent_info.agent_type] = set()
self.type_index[agent_info.agent_type].add(agent_info.agent_id)
def _remove_from_indexes(self, agent_info: AgentInfo):
"""Remove agent from search indexes"""
# Service index
for service in agent_info.services:
if service in self.service_index:
self.service_index[service].discard(agent_info.agent_id)
if not self.service_index[service]:
del self.service_index[service]
# Capability index
for capability in agent_info.capabilities:
if capability in self.capability_index:
self.capability_index[capability].discard(agent_info.agent_id)
if not self.capability_index[capability]:
del self.capability_index[capability]
# Type index
if agent_info.agent_type in self.type_index:
self.type_index[agent_info.agent_type].discard(agent_info.agent_id)
if not self.type_index[agent_info.agent_type]:
del self.type_index[agent_info.agent_type]
def _calculate_health_score(self, agent_info: AgentInfo) -> float:
"""Calculate agent health score"""
base_score = 1.0
# Penalty for high load
if agent_info.load_metrics:
avg_load = sum(agent_info.load_metrics.values()) / len(agent_info.load_metrics)
if avg_load > 0.8:
base_score -= 0.3
elif avg_load > 0.6:
base_score -= 0.1
# Penalty for error status
if agent_info.status == AgentStatus.ERROR:
base_score -= 0.5
elif agent_info.status == AgentStatus.MAINTENANCE:
base_score -= 0.2
elif agent_info.status == AgentStatus.BUSY:
base_score -= 0.1
# Penalty for old heartbeat
heartbeat_age = (datetime.utcnow() - agent_info.last_heartbeat).total_seconds()
if heartbeat_age > self.max_heartbeat_age:
base_score -= 0.5
elif heartbeat_age > self.max_heartbeat_age / 2:
base_score -= 0.2
return max(0.0, min(1.0, base_score))
async def _save_agent_to_redis(self, agent_info: AgentInfo):
"""Save agent information to Redis"""
if not self.redis_client:
return
key = f"agent:{agent_info.agent_id}"
await self.redis_client.setex(
key,
timedelta(hours=24), # 24 hour TTL
json.dumps(agent_info.to_dict())
)
async def _remove_agent_from_redis(self, agent_id: str):
"""Remove agent from Redis"""
if not self.redis_client:
return
key = f"agent:{agent_id}"
await self.redis_client.delete(key)
async def _load_agents_from_redis(self):
"""Load agents from Redis"""
if not self.redis_client:
return
try:
# Get all agent keys
keys = await self.redis_client.keys("agent:*")
for key in keys:
data = await self.redis_client.get(key)
if data:
agent_info = AgentInfo.from_dict(json.loads(data))
self.agents[agent_info.agent_id] = agent_info
self._update_indexes(agent_info)
logger.info(f"Loaded {len(self.agents)} agents from Redis")
except Exception as e:
logger.error(f"Error loading agents from Redis: {e}")
async def _publish_agent_event(self, event_type: str, agent_info: AgentInfo):
"""Publish agent event to Redis"""
if not self.redis_client:
return
event = {
"event_type": event_type,
"timestamp": datetime.utcnow().isoformat(),
"agent_info": agent_info.to_dict()
}
await self.redis_client.publish("agent_events", json.dumps(event))
async def _heartbeat_monitor(self):
"""Monitor agent heartbeats"""
while True:
try:
await asyncio.sleep(self.heartbeat_interval)
# Check for agents with old heartbeats
now = datetime.utcnow()
for agent_id, agent_info in list(self.agents.items()):
heartbeat_age = (now - agent_info.last_heartbeat).total_seconds()
if heartbeat_age > self.max_heartbeat_age:
# Mark as inactive
if agent_info.status != AgentStatus.INACTIVE:
await self.update_agent_status(agent_id, AgentStatus.INACTIVE)
logger.warning(f"Agent {agent_id} marked as inactive due to old heartbeat")
except Exception as e:
logger.error(f"Error in heartbeat monitor: {e}")
await asyncio.sleep(5)
async def _cleanup_inactive_agents(self):
"""Clean up inactive agents"""
while True:
try:
await asyncio.sleep(self.cleanup_interval)
# Remove agents that have been inactive too long
now = datetime.utcnow()
max_inactive_age = timedelta(hours=1) # 1 hour
for agent_id, agent_info in list(self.agents.items()):
if agent_info.status == AgentStatus.INACTIVE:
inactive_age = now - agent_info.last_heartbeat
if inactive_age > max_inactive_age:
await self.unregister_agent(agent_id)
logger.info(f"Removed inactive agent {agent_id}")
except Exception as e:
logger.error(f"Error in cleanup task: {e}")
await asyncio.sleep(5)
class AgentDiscoveryService:
"""Service for agent discovery and registration"""
def __init__(self, registry: AgentRegistry):
self.registry = registry
self.discovery_handlers: Dict[str, Callable] = {}
def register_discovery_handler(self, handler_name: str, handler: Callable):
"""Register a discovery handler"""
self.discovery_handlers[handler_name] = handler
logger.info(f"Registered discovery handler: {handler_name}")
async def handle_discovery_request(self, message: AgentMessage) -> Optional[AgentMessage]:
"""Handle agent discovery request"""
try:
discovery_data = DiscoveryMessage(**message.payload)
# Update or register agent
agent_info = AgentInfo(
agent_id=discovery_data.agent_id,
agent_type=AgentType(discovery_data.agent_type),
status=AgentStatus.ACTIVE,
capabilities=discovery_data.capabilities,
services=discovery_data.services,
endpoints=discovery_data.endpoints,
metadata=discovery_data.metadata,
last_heartbeat=datetime.utcnow(),
registration_time=datetime.utcnow()
)
# Register or update agent
if discovery_data.agent_id in self.registry.agents:
await self.registry.update_agent_status(discovery_data.agent_id, AgentStatus.ACTIVE)
else:
await self.registry.register_agent(agent_info)
# Send response with available agents
available_agents = await self.registry.discover_agents({
"status": "active",
"limit": 50
})
response_data = {
"discovery_agents": [agent.to_dict() for agent in available_agents],
"registry_stats": await self.registry.get_registry_stats()
}
response = AgentMessage(
sender_id="discovery_service",
receiver_id=message.sender_id,
message_type=MessageType.DISCOVERY,
payload=response_data,
correlation_id=message.id
)
return response
except Exception as e:
logger.error(f"Error handling discovery request: {e}")
return None
async def find_best_agent(self, requirements: Dict[str, Any]) -> Optional[AgentInfo]:
"""Find the best agent for given requirements"""
try:
# Build discovery query
query = {}
if "agent_type" in requirements:
query["agent_type"] = requirements["agent_type"]
if "capabilities" in requirements:
query["capabilities"] = requirements["capabilities"]
if "services" in requirements:
query["services"] = requirements["services"]
if "min_health_score" in requirements:
query["min_health_score"] = requirements["min_health_score"]
# Discover agents
agents = await self.registry.discover_agents(query)
if not agents:
return None
# Select best agent (highest health score)
return agents[0]
except Exception as e:
logger.error(f"Error finding best agent: {e}")
return None
async def get_service_endpoints(self, service: str) -> Dict[str, List[str]]:
"""Get all endpoints for a specific service"""
try:
agents = await self.registry.get_agents_by_service(service)
endpoints = {}
for agent in agents:
for service_name, endpoint in agent.endpoints.items():
if service_name not in endpoints:
endpoints[service_name] = []
endpoints[service_name].append(endpoint)
return endpoints
except Exception as e:
logger.error(f"Error getting service endpoints: {e}")
return {}
# Factory functions
def create_agent_info(agent_id: str, agent_type: str, capabilities: List[str], services: List[str], endpoints: Dict[str, str]) -> AgentInfo:
"""Create agent information"""
return AgentInfo(
agent_id=agent_id,
agent_type=AgentType(agent_type),
status=AgentStatus.ACTIVE,
capabilities=capabilities,
services=services,
endpoints=endpoints,
metadata={},
last_heartbeat=datetime.utcnow(),
registration_time=datetime.utcnow()
)
# Example usage
async def example_usage():
"""Example of how to use the agent discovery system"""
# Create registry
registry = AgentRegistry()
await registry.start()
# Create discovery service
discovery_service = AgentDiscoveryService(registry)
# Register an agent
agent_info = create_agent_info(
agent_id="agent-001",
agent_type="worker",
capabilities=["data_processing", "analysis"],
services=["process_data", "analyze_results"],
endpoints={"http": "http://localhost:8001", "ws": "ws://localhost:8002"}
)
await registry.register_agent(agent_info)
# Discover agents
agents = await registry.discover_agents({
"capabilities": ["data_processing"],
"status": "active"
})
print(f"Found {len(agents)} agents")
# Find best agent
best_agent = await discovery_service.find_best_agent({
"capabilities": ["data_processing"],
"min_health_score": 0.8
})
if best_agent:
print(f"Best agent: {best_agent.agent_id}")
await registry.stop()
if __name__ == "__main__":
asyncio.run(example_usage())

View File

@@ -0,0 +1,716 @@
"""
Load Balancer for Agent Distribution and Task Assignment
"""
import asyncio
import json
import logging
from typing import Dict, List, Optional, Tuple, Any, Callable
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
import statistics
import uuid
from collections import defaultdict, deque
from .agent_discovery import AgentRegistry, AgentInfo, AgentStatus, AgentType
from ..protocols.message_types import TaskMessage, create_task_message
from ..protocols.communication import AgentMessage, MessageType, Priority
logger = logging.getLogger(__name__)
class LoadBalancingStrategy(str, Enum):
"""Load balancing strategies"""
ROUND_ROBIN = "round_robin"
LEAST_CONNECTIONS = "least_connections"
LEAST_RESPONSE_TIME = "least_response_time"
WEIGHTED_ROUND_ROBIN = "weighted_round_robin"
RESOURCE_BASED = "resource_based"
CAPABILITY_BASED = "capability_based"
PREDICTIVE = "predictive"
CONSISTENT_HASH = "consistent_hash"
class TaskPriority(str, Enum):
"""Task priority levels"""
LOW = "low"
NORMAL = "normal"
HIGH = "high"
CRITICAL = "critical"
URGENT = "urgent"
@dataclass
class LoadMetrics:
"""Agent load metrics"""
cpu_usage: float = 0.0
memory_usage: float = 0.0
active_connections: int = 0
pending_tasks: int = 0
completed_tasks: int = 0
failed_tasks: int = 0
avg_response_time: float = 0.0
last_updated: datetime = field(default_factory=datetime.utcnow)
def to_dict(self) -> Dict[str, Any]:
return {
"cpu_usage": self.cpu_usage,
"memory_usage": self.memory_usage,
"active_connections": self.active_connections,
"pending_tasks": self.pending_tasks,
"completed_tasks": self.completed_tasks,
"failed_tasks": self.failed_tasks,
"avg_response_time": self.avg_response_time,
"last_updated": self.last_updated.isoformat()
}
@dataclass
class TaskAssignment:
"""Task assignment record"""
task_id: str
agent_id: str
assigned_at: datetime
completed_at: Optional[datetime] = None
status: str = "pending"
response_time: Optional[float] = None
success: bool = False
error_message: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
return {
"task_id": self.task_id,
"agent_id": self.agent_id,
"assigned_at": self.assigned_at.isoformat(),
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
"status": self.status,
"response_time": self.response_time,
"success": self.success,
"error_message": self.error_message
}
@dataclass
class AgentWeight:
"""Agent weight for load balancing"""
agent_id: str
weight: float = 1.0
capacity: int = 100
performance_score: float = 1.0
reliability_score: float = 1.0
last_updated: datetime = field(default_factory=datetime.utcnow)
class LoadBalancer:
"""Advanced load balancer for agent distribution"""
def __init__(self, registry: AgentRegistry):
self.registry = registry
self.strategy = LoadBalancingStrategy.LEAST_CONNECTIONS
self.agent_weights: Dict[str, AgentWeight] = {}
self.agent_metrics: Dict[str, LoadMetrics] = {}
self.task_assignments: Dict[str, TaskAssignment] = {}
self.assignment_history: deque = deque(maxlen=1000)
self.round_robin_index = 0
self.consistent_hash_ring: Dict[int, str] = {}
self.prediction_models: Dict[str, Any] = {}
# Statistics
self.total_assignments = 0
self.successful_assignments = 0
self.failed_assignments = 0
def set_strategy(self, strategy: LoadBalancingStrategy):
"""Set load balancing strategy"""
self.strategy = strategy
logger.info(f"Load balancing strategy changed to: {strategy.value}")
def set_agent_weight(self, agent_id: str, weight: float, capacity: int = 100):
"""Set agent weight and capacity"""
self.agent_weights[agent_id] = AgentWeight(
agent_id=agent_id,
weight=weight,
capacity=capacity
)
logger.info(f"Set weight for agent {agent_id}: {weight}, capacity: {capacity}")
def update_agent_metrics(self, agent_id: str, metrics: LoadMetrics):
"""Update agent load metrics"""
self.agent_metrics[agent_id] = metrics
self.agent_metrics[agent_id].last_updated = datetime.utcnow()
# Update performance score based on metrics
self._update_performance_score(agent_id, metrics)
def _update_performance_score(self, agent_id: str, metrics: LoadMetrics):
"""Update agent performance score based on metrics"""
if agent_id not in self.agent_weights:
self.agent_weights[agent_id] = AgentWeight(agent_id=agent_id)
weight = self.agent_weights[agent_id]
# Calculate performance score (0.0 to 1.0)
performance_factors = []
# CPU usage factor (lower is better)
cpu_factor = max(0.0, 1.0 - metrics.cpu_usage)
performance_factors.append(cpu_factor)
# Memory usage factor (lower is better)
memory_factor = max(0.0, 1.0 - metrics.memory_usage)
performance_factors.append(memory_factor)
# Response time factor (lower is better)
if metrics.avg_response_time > 0:
response_factor = max(0.0, 1.0 - (metrics.avg_response_time / 10.0)) # 10s max
performance_factors.append(response_factor)
# Success rate factor (higher is better)
total_tasks = metrics.completed_tasks + metrics.failed_tasks
if total_tasks > 0:
success_rate = metrics.completed_tasks / total_tasks
performance_factors.append(success_rate)
# Update performance score
if performance_factors:
weight.performance_score = statistics.mean(performance_factors)
# Update reliability score
if total_tasks > 10: # Only update after enough tasks
weight.reliability_score = success_rate
async def assign_task(self, task_data: Dict[str, Any], requirements: Optional[Dict[str, Any]] = None) -> Optional[str]:
"""Assign task to best available agent"""
try:
# Find eligible agents
eligible_agents = await self._find_eligible_agents(task_data, requirements)
if not eligible_agents:
logger.warning("No eligible agents found for task assignment")
return None
# Select best agent based on strategy
selected_agent = await self._select_agent(eligible_agents, task_data)
if not selected_agent:
logger.warning("No agent selected for task assignment")
return None
# Create task assignment
task_id = str(uuid.uuid4())
assignment = TaskAssignment(
task_id=task_id,
agent_id=selected_agent,
assigned_at=datetime.utcnow()
)
# Record assignment
self.task_assignments[task_id] = assignment
self.assignment_history.append(assignment)
self.total_assignments += 1
# Update agent metrics
if selected_agent not in self.agent_metrics:
self.agent_metrics[selected_agent] = LoadMetrics()
self.agent_metrics[selected_agent].pending_tasks += 1
logger.info(f"Task {task_id} assigned to agent {selected_agent}")
return selected_agent
except Exception as e:
logger.error(f"Error assigning task: {e}")
self.failed_assignments += 1
return None
async def complete_task(self, task_id: str, success: bool, response_time: Optional[float] = None, error_message: Optional[str] = None):
"""Mark task as completed"""
try:
if task_id not in self.task_assignments:
logger.warning(f"Task assignment {task_id} not found")
return
assignment = self.task_assignments[task_id]
assignment.completed_at = datetime.utcnow()
assignment.status = "completed"
assignment.success = success
assignment.response_time = response_time
assignment.error_message = error_message
# Update agent metrics
agent_id = assignment.agent_id
if agent_id in self.agent_metrics:
metrics = self.agent_metrics[agent_id]
metrics.pending_tasks = max(0, metrics.pending_tasks - 1)
if success:
metrics.completed_tasks += 1
self.successful_assignments += 1
else:
metrics.failed_tasks += 1
self.failed_assignments += 1
# Update average response time
if response_time:
total_completed = metrics.completed_tasks + metrics.failed_tasks
if total_completed > 0:
metrics.avg_response_time = (
(metrics.avg_response_time * (total_completed - 1) + response_time) / total_completed
)
logger.info(f"Task {task_id} completed by agent {assignment.agent_id}, success: {success}")
except Exception as e:
logger.error(f"Error completing task {task_id}: {e}")
async def _find_eligible_agents(self, task_data: Dict[str, Any], requirements: Optional[Dict[str, Any]] = None) -> List[str]:
"""Find eligible agents for task"""
try:
# Build discovery query
query = {"status": AgentStatus.ACTIVE}
if requirements:
if "agent_type" in requirements:
query["agent_type"] = requirements["agent_type"]
if "capabilities" in requirements:
query["capabilities"] = requirements["capabilities"]
if "services" in requirements:
query["services"] = requirements["services"]
if "min_health_score" in requirements:
query["min_health_score"] = requirements["min_health_score"]
# Discover agents
agents = await self.registry.discover_agents(query)
# Filter by capacity and load
eligible_agents = []
for agent in agents:
agent_id = agent.agent_id
# Check capacity
if agent_id in self.agent_weights:
weight = self.agent_weights[agent_id]
current_load = self._get_agent_load(agent_id)
if current_load < weight.capacity:
eligible_agents.append(agent_id)
else:
# Default capacity check
metrics = self.agent_metrics.get(agent_id, LoadMetrics())
if metrics.pending_tasks < 100: # Default capacity
eligible_agents.append(agent_id)
return eligible_agents
except Exception as e:
logger.error(f"Error finding eligible agents: {e}")
return []
def _get_agent_load(self, agent_id: str) -> int:
"""Get current load for agent"""
metrics = self.agent_metrics.get(agent_id, LoadMetrics())
return metrics.active_connections + metrics.pending_tasks
async def _select_agent(self, eligible_agents: List[str], task_data: Dict[str, Any]) -> Optional[str]:
"""Select best agent based on current strategy"""
if not eligible_agents:
return None
if self.strategy == LoadBalancingStrategy.ROUND_ROBIN:
return self._round_robin_selection(eligible_agents)
elif self.strategy == LoadBalancingStrategy.LEAST_CONNECTIONS:
return self._least_connections_selection(eligible_agents)
elif self.strategy == LoadBalancingStrategy.LEAST_RESPONSE_TIME:
return self._least_response_time_selection(eligible_agents)
elif self.strategy == LoadBalancingStrategy.WEIGHTED_ROUND_ROBIN:
return self._weighted_round_robin_selection(eligible_agents)
elif self.strategy == LoadBalancingStrategy.RESOURCE_BASED:
return self._resource_based_selection(eligible_agents)
elif self.strategy == LoadBalancingStrategy.CAPABILITY_BASED:
return self._capability_based_selection(eligible_agents, task_data)
elif self.strategy == LoadBalancingStrategy.PREDICTIVE:
return self._predictive_selection(eligible_agents, task_data)
elif self.strategy == LoadBalancingStrategy.CONSISTENT_HASH:
return self._consistent_hash_selection(eligible_agents, task_data)
else:
return eligible_agents[0]
def _round_robin_selection(self, agents: List[str]) -> str:
"""Round-robin agent selection"""
agent = agents[self.round_robin_index % len(agents)]
self.round_robin_index += 1
return agent
def _least_connections_selection(self, agents: List[str]) -> str:
"""Select agent with least connections"""
min_connections = float('inf')
selected_agent = None
for agent_id in agents:
metrics = self.agent_metrics.get(agent_id, LoadMetrics())
connections = metrics.active_connections
if connections < min_connections:
min_connections = connections
selected_agent = agent_id
return selected_agent or agents[0]
def _least_response_time_selection(self, agents: List[str]) -> str:
"""Select agent with least average response time"""
min_response_time = float('inf')
selected_agent = None
for agent_id in agents:
metrics = self.agent_metrics.get(agent_id, LoadMetrics())
response_time = metrics.avg_response_time
if response_time < min_response_time:
min_response_time = response_time
selected_agent = agent_id
return selected_agent or agents[0]
def _weighted_round_robin_selection(self, agents: List[str]) -> str:
"""Weighted round-robin selection"""
# Calculate total weight
total_weight = 0
for agent_id in agents:
weight = self.agent_weights.get(agent_id, AgentWeight(agent_id=agent_id))
total_weight += weight.weight
if total_weight == 0:
return agents[0]
# Select agent based on weight
current_weight = self.round_robin_index % total_weight
accumulated_weight = 0
for agent_id in agents:
weight = self.agent_weights.get(agent_id, AgentWeight(agent_id=agent_id))
accumulated_weight += weight.weight
if current_weight < accumulated_weight:
self.round_robin_index += 1
return agent_id
return agents[0]
def _resource_based_selection(self, agents: List[str]) -> str:
"""Resource-based selection considering CPU and memory"""
best_score = -1
selected_agent = None
for agent_id in agents:
metrics = self.agent_metrics.get(agent_id, LoadMetrics())
# Calculate resource score (lower usage is better)
cpu_score = max(0, 100 - metrics.cpu_usage)
memory_score = max(0, 100 - metrics.memory_usage)
resource_score = (cpu_score + memory_score) / 2
# Apply performance weight
weight = self.agent_weights.get(agent_id, AgentWeight(agent_id=agent_id))
final_score = resource_score * weight.performance_score
if final_score > best_score:
best_score = final_score
selected_agent = agent_id
return selected_agent or agents[0]
def _capability_based_selection(self, agents: List[str], task_data: Dict[str, Any]) -> str:
"""Capability-based selection considering task requirements"""
required_capabilities = task_data.get("required_capabilities", [])
if not required_capabilities:
return agents[0]
best_score = -1
selected_agent = None
for agent_id in agents:
agent_info = self.registry.agents.get(agent_id)
if not agent_info:
continue
# Calculate capability match score
agent_capabilities = set(agent_info.capabilities)
required_set = set(required_capabilities)
if required_set.issubset(agent_capabilities):
# Perfect match
capability_score = 1.0
else:
# Partial match
intersection = required_set.intersection(agent_capabilities)
capability_score = len(intersection) / len(required_set)
# Apply performance weight
weight = self.agent_weights.get(agent_id, AgentWeight(agent_id=agent_id))
final_score = capability_score * weight.performance_score
if final_score > best_score:
best_score = final_score
selected_agent = agent_id
return selected_agent or agents[0]
def _predictive_selection(self, agents: List[str], task_data: Dict[str, Any]) -> str:
"""Predictive selection using historical performance"""
task_type = task_data.get("task_type", "unknown")
# Calculate predicted performance for each agent
best_score = -1
selected_agent = None
for agent_id in agents:
# Get historical performance for this task type
score = self._calculate_predicted_score(agent_id, task_type)
if score > best_score:
best_score = score
selected_agent = agent_id
return selected_agent or agents[0]
def _calculate_predicted_score(self, agent_id: str, task_type: str) -> float:
"""Calculate predicted performance score for agent"""
# Simple prediction based on recent performance
weight = self.agent_weights.get(agent_id, AgentWeight(agent_id=agent_id))
# Base score from performance and reliability
base_score = (weight.performance_score + weight.reliability_score) / 2
# Adjust based on recent assignments
recent_assignments = [a for a in self.assignment_history if a.agent_id == agent_id][-10:]
if recent_assignments:
success_rate = sum(1 for a in recent_assignments if a.success) / len(recent_assignments)
base_score = base_score * 0.7 + success_rate * 0.3
return base_score
def _consistent_hash_selection(self, agents: List[str], task_data: Dict[str, Any]) -> str:
"""Consistent hash selection for sticky routing"""
# Create hash key from task data
hash_key = json.dumps(task_data, sort_keys=True)
hash_value = int(hashlib.md5(hash_key.encode()).hexdigest(), 16)
# Build hash ring if not exists
if not self.consistent_hash_ring:
self._build_hash_ring(agents)
# Find agent on hash ring
for hash_pos in sorted(self.consistent_hash_ring.keys()):
if hash_value <= hash_pos:
return self.consistent_hash_ring[hash_pos]
# Wrap around
return self.consistent_hash_ring[min(self.consistent_hash_ring.keys())]
def _build_hash_ring(self, agents: List[str]):
"""Build consistent hash ring"""
self.consistent_hash_ring = {}
for agent_id in agents:
# Create multiple virtual nodes for better distribution
for i in range(100):
virtual_key = f"{agent_id}:{i}"
hash_value = int(hashlib.md5(virtual_key.encode()).hexdigest(), 16)
self.consistent_hash_ring[hash_value] = agent_id
def get_load_balancing_stats(self) -> Dict[str, Any]:
"""Get load balancing statistics"""
return {
"strategy": self.strategy.value,
"total_assignments": self.total_assignments,
"successful_assignments": self.successful_assignments,
"failed_assignments": self.failed_assignments,
"success_rate": self.successful_assignments / max(1, self.total_assignments),
"active_agents": len(self.agent_metrics),
"agent_weights": len(self.agent_weights),
"avg_agent_load": statistics.mean([self._get_agent_load(a) for a in self.agent_metrics]) if self.agent_metrics else 0
}
def get_agent_stats(self, agent_id: str) -> Optional[Dict[str, Any]]:
"""Get detailed statistics for a specific agent"""
if agent_id not in self.agent_metrics:
return None
metrics = self.agent_metrics[agent_id]
weight = self.agent_weights.get(agent_id, AgentWeight(agent_id=agent_id))
# Get recent assignments
recent_assignments = [a for a in self.assignment_history if a.agent_id == agent_id][-10:]
return {
"agent_id": agent_id,
"metrics": metrics.to_dict(),
"weight": {
"weight": weight.weight,
"capacity": weight.capacity,
"performance_score": weight.performance_score,
"reliability_score": weight.reliability_score
},
"recent_assignments": [a.to_dict() for a in recent_assignments],
"current_load": self._get_agent_load(agent_id)
}
class TaskDistributor:
"""Task distributor with advanced load balancing"""
def __init__(self, load_balancer: LoadBalancer):
self.load_balancer = load_balancer
self.task_queue = asyncio.Queue()
self.priority_queues = {
TaskPriority.URGENT: asyncio.Queue(),
TaskPriority.CRITICAL: asyncio.Queue(),
TaskPriority.HIGH: asyncio.Queue(),
TaskPriority.NORMAL: asyncio.Queue(),
TaskPriority.LOW: asyncio.Queue()
}
self.distribution_stats = {
"tasks_distributed": 0,
"tasks_completed": 0,
"tasks_failed": 0,
"avg_distribution_time": 0.0
}
async def submit_task(self, task_data: Dict[str, Any], priority: TaskPriority = TaskPriority.NORMAL, requirements: Optional[Dict[str, Any]] = None):
"""Submit task for distribution"""
task_info = {
"task_data": task_data,
"priority": priority,
"requirements": requirements,
"submitted_at": datetime.utcnow()
}
await self.priority_queues[priority].put(task_info)
logger.info(f"Task submitted with priority {priority.value}")
async def start_distribution(self):
"""Start task distribution loop"""
while True:
try:
# Check queues in priority order
task_info = None
for priority in [TaskPriority.URGENT, TaskPriority.CRITICAL, TaskPriority.HIGH, TaskPriority.NORMAL, TaskPriority.LOW]:
queue = self.priority_queues[priority]
try:
task_info = queue.get_nowait()
break
except asyncio.QueueEmpty:
continue
if task_info:
await self._distribute_task(task_info)
else:
await asyncio.sleep(0.01) # Small delay if no tasks
except Exception as e:
logger.error(f"Error in distribution loop: {e}")
await asyncio.sleep(1)
async def _distribute_task(self, task_info: Dict[str, Any]):
"""Distribute a single task"""
start_time = datetime.utcnow()
try:
# Assign task
agent_id = await self.load_balancer.assign_task(
task_info["task_data"],
task_info["requirements"]
)
if agent_id:
# Create task message
task_message = create_task_message(
sender_id="task_distributor",
receiver_id=agent_id,
task_type=task_info["task_data"].get("task_type", "unknown"),
task_data=task_info["task_data"]
)
# Send task to agent (implementation depends on communication system)
# await self._send_task_to_agent(agent_id, task_message)
self.distribution_stats["tasks_distributed"] += 1
# Simulate task completion (in real implementation, this would be event-driven)
asyncio.create_task(self._simulate_task_completion(task_info, agent_id))
else:
logger.warning(f"Failed to distribute task: no suitable agent found")
self.distribution_stats["tasks_failed"] += 1
except Exception as e:
logger.error(f"Error distributing task: {e}")
self.distribution_stats["tasks_failed"] += 1
finally:
# Update distribution time
distribution_time = (datetime.utcnow() - start_time).total_seconds()
total_distributed = self.distribution_stats["tasks_distributed"]
self.distribution_stats["avg_distribution_time"] = (
(self.distribution_stats["avg_distribution_time"] * (total_distributed - 1) + distribution_time) / total_distributed
if total_distributed > 0 else distribution_time
)
async def _simulate_task_completion(self, task_info: Dict[str, Any], agent_id: str):
"""Simulate task completion (for testing)"""
# Simulate task processing time
processing_time = 1.0 + (hash(task_info["task_data"].get("task_id", "")) % 5)
await asyncio.sleep(processing_time)
# Mark task as completed
success = hash(agent_id) % 10 > 1 # 90% success rate
await self.load_balancer.complete_task(
task_info["task_data"].get("task_id", str(uuid.uuid4())),
success,
processing_time
)
if success:
self.distribution_stats["tasks_completed"] += 1
else:
self.distribution_stats["tasks_failed"] += 1
def get_distribution_stats(self) -> Dict[str, Any]:
"""Get distribution statistics"""
return {
**self.distribution_stats,
"load_balancer_stats": self.load_balancer.get_load_balancing_stats(),
"queue_sizes": {
priority.value: queue.qsize()
for priority, queue in self.priority_queues.items()
}
}
# Example usage
async def example_usage():
"""Example of how to use the load balancer"""
# Create registry and load balancer
registry = AgentRegistry()
await registry.start()
load_balancer = LoadBalancer(registry)
load_balancer.set_strategy(LoadBalancingStrategy.LEAST_CONNECTIONS)
# Create task distributor
distributor = TaskDistributor(load_balancer)
# Submit some tasks
for i in range(10):
await distributor.submit_task({
"task_id": f"task-{i}",
"task_type": "data_processing",
"data": f"sample_data_{i}"
}, TaskPriority.NORMAL)
# Start distribution (in real implementation, this would run in background)
# await distributor.start_distribution()
await registry.stop()
if __name__ == "__main__":
asyncio.run(example_usage())

View File

@@ -0,0 +1,326 @@
"""
Tests for Agent Communication Protocols
"""
import pytest
import asyncio
from datetime import datetime, timedelta
from unittest.mock import Mock, AsyncMock
from src.app.protocols.communication import (
AgentMessage, MessageType, Priority, CommunicationProtocol,
HierarchicalProtocol, PeerToPeerProtocol, BroadcastProtocol,
CommunicationManager, MessageTemplates
)
class TestAgentMessage:
"""Test AgentMessage class"""
def test_message_creation(self):
"""Test message creation"""
message = AgentMessage(
sender_id="agent-001",
receiver_id="agent-002",
message_type=MessageType.DIRECT,
priority=Priority.NORMAL,
payload={"data": "test"}
)
assert message.sender_id == "agent-001"
assert message.receiver_id == "agent-002"
assert message.message_type == MessageType.DIRECT
assert message.priority == Priority.NORMAL
assert message.payload["data"] == "test"
assert message.ttl == 300
def test_message_serialization(self):
"""Test message serialization"""
message = AgentMessage(
sender_id="agent-001",
receiver_id="agent-002",
message_type=MessageType.DIRECT,
priority=Priority.NORMAL,
payload={"data": "test"}
)
# To dict
message_dict = message.to_dict()
assert message_dict["sender_id"] == "agent-001"
assert message_dict["message_type"] == "direct"
assert message_dict["priority"] == "normal"
# From dict
restored_message = AgentMessage.from_dict(message_dict)
assert restored_message.sender_id == message.sender_id
assert restored_message.receiver_id == message.receiver_id
assert restored_message.message_type == message.message_type
assert restored_message.priority == message.priority
def test_message_expiration(self):
"""Test message expiration"""
old_message = AgentMessage(
sender_id="agent-001",
receiver_id="agent-002",
message_type=MessageType.DIRECT,
timestamp=datetime.utcnow() - timedelta(seconds=400),
ttl=300
)
# Message should be expired
age = (datetime.utcnow() - old_message.timestamp).total_seconds()
assert age > old_message.ttl
class TestHierarchicalProtocol:
"""Test HierarchicalProtocol class"""
@pytest.fixture
def master_protocol(self):
"""Create master protocol"""
return HierarchicalProtocol("master-agent", is_master=True)
@pytest.fixture
def sub_protocol(self):
"""Create sub-agent protocol"""
return HierarchicalProtocol("sub-agent", is_master=False)
def test_add_sub_agent(self, master_protocol):
"""Test adding sub-agent"""
master_protocol.add_sub_agent("sub-agent-001")
assert "sub-agent-001" in master_protocol.sub_agents
def test_send_to_sub_agents(self, master_protocol):
"""Test sending to sub-agents"""
master_protocol.add_sub_agent("sub-agent-001")
master_protocol.add_sub_agent("sub-agent-002")
message = MessageTemplates.create_heartbeat("master-agent")
# Mock the send_message method
master_protocol.send_message = AsyncMock(return_value=True)
# Should send to both sub-agents
asyncio.run(master_protocol.send_to_sub_agents(message))
# Check that send_message was called twice
assert master_protocol.send_message.call_count == 2
def test_send_to_master(self, sub_protocol):
"""Test sending to master"""
sub_protocol.master_agent = "master-agent"
message = MessageTemplates.create_status_update("sub-agent", {"status": "active"})
# Mock the send_message method
sub_protocol.send_message = AsyncMock(return_value=True)
asyncio.run(sub_protocol.send_to_master(message))
# Check that send_message was called once
assert sub_protocol.send_message.call_count == 1
class TestPeerToPeerProtocol:
"""Test PeerToPeerProtocol class"""
@pytest.fixture
def p2p_protocol(self):
"""Create P2P protocol"""
return PeerToPeerProtocol("agent-001")
def test_add_peer(self, p2p_protocol):
"""Test adding peer"""
p2p_protocol.add_peer("agent-002", {"endpoint": "http://localhost:8002"})
assert "agent-002" in p2p_protocol.peers
assert p2p_protocol.peers["agent-002"]["endpoint"] == "http://localhost:8002"
def test_remove_peer(self, p2p_protocol):
"""Test removing peer"""
p2p_protocol.add_peer("agent-002", {"endpoint": "http://localhost:8002"})
p2p_protocol.remove_peer("agent-002")
assert "agent-002" not in p2p_protocol.peers
def test_send_to_peer(self, p2p_protocol):
"""Test sending to peer"""
p2p_protocol.add_peer("agent-002", {"endpoint": "http://localhost:8002"})
message = MessageTemplates.create_task_assignment(
"agent-001", "agent-002", {"task": "test"}
)
# Mock the send_message method
p2p_protocol.send_message = AsyncMock(return_value=True)
result = asyncio.run(p2p_protocol.send_to_peer(message, "agent-002"))
assert result is True
assert p2p_protocol.send_message.call_count == 1
class TestBroadcastProtocol:
"""Test BroadcastProtocol class"""
@pytest.fixture
def broadcast_protocol(self):
"""Create broadcast protocol"""
return BroadcastProtocol("agent-001", "test-channel")
def test_subscribe_unsubscribe(self, broadcast_protocol):
"""Test subscribe and unsubscribe"""
broadcast_protocol.subscribe("agent-002")
assert "agent-002" in broadcast_protocol.subscribers
broadcast_protocol.unsubscribe("agent-002")
assert "agent-002" not in broadcast_protocol.subscribers
def test_broadcast(self, broadcast_protocol):
"""Test broadcasting"""
broadcast_protocol.subscribe("agent-002")
broadcast_protocol.subscribe("agent-003")
message = MessageTemplates.create_discovery("agent-001")
# Mock the send_message method
broadcast_protocol.send_message = AsyncMock(return_value=True)
asyncio.run(broadcast_protocol.broadcast(message))
# Should send to 2 subscribers (not including self)
assert broadcast_protocol.send_message.call_count == 2
class TestCommunicationManager:
"""Test CommunicationManager class"""
@pytest.fixture
def comm_manager(self):
"""Create communication manager"""
return CommunicationManager("agent-001")
def test_add_protocol(self, comm_manager):
"""Test adding protocol"""
protocol = Mock(spec=CommunicationProtocol)
comm_manager.add_protocol("test", protocol)
assert "test" in comm_manager.protocols
assert comm_manager.protocols["test"] == protocol
def test_get_protocol(self, comm_manager):
"""Test getting protocol"""
protocol = Mock(spec=CommunicationProtocol)
comm_manager.add_protocol("test", protocol)
retrieved_protocol = comm_manager.get_protocol("test")
assert retrieved_protocol == protocol
# Test non-existent protocol
assert comm_manager.get_protocol("non-existent") is None
@pytest.mark.asyncio
async def test_send_message(self, comm_manager):
"""Test sending message"""
protocol = Mock(spec=CommunicationProtocol)
protocol.send_message = AsyncMock(return_value=True)
comm_manager.add_protocol("test", protocol)
message = MessageTemplates.create_heartbeat("agent-001")
result = await comm_manager.send_message("test", message)
assert result is True
protocol.send_message.assert_called_once_with(message)
@pytest.mark.asyncio
async def test_register_handler(self, comm_manager):
"""Test registering handler"""
protocol = Mock(spec=CommunicationProtocol)
protocol.register_handler = AsyncMock()
comm_manager.add_protocol("test", protocol)
handler = AsyncMock()
await comm_manager.register_handler("test", MessageType.HEARTBEAT, handler)
protocol.register_handler.assert_called_once_with(MessageType.HEARTBEAT, handler)
class TestMessageTemplates:
"""Test MessageTemplates class"""
def test_create_heartbeat(self):
"""Test creating heartbeat message"""
message = MessageTemplates.create_heartbeat("agent-001")
assert message.sender_id == "agent-001"
assert message.message_type == MessageType.HEARTBEAT
assert message.priority == Priority.LOW
assert "timestamp" in message.payload
def test_create_task_assignment(self):
"""Test creating task assignment message"""
task_data = {"task_id": "task-001", "task_type": "process_data"}
message = MessageTemplates.create_task_assignment("agent-001", "agent-002", task_data)
assert message.sender_id == "agent-001"
assert message.receiver_id == "agent-002"
assert message.message_type == MessageType.TASK_ASSIGNMENT
assert message.payload == task_data
def test_create_status_update(self):
"""Test creating status update message"""
status_data = {"status": "active", "load": 0.5}
message = MessageTemplates.create_status_update("agent-001", status_data)
assert message.sender_id == "agent-001"
assert message.message_type == MessageType.STATUS_UPDATE
assert message.payload == status_data
def test_create_discovery(self):
"""Test creating discovery message"""
message = MessageTemplates.create_discovery("agent-001")
assert message.sender_id == "agent-001"
assert message.message_type == MessageType.DISCOVERY
assert message.payload["agent_id"] == "agent-001"
def test_create_consensus_request(self):
"""Test creating consensus request message"""
proposal_data = {"proposal": "test_proposal"}
message = MessageTemplates.create_consensus_request("agent-001", proposal_data)
assert message.sender_id == "agent-001"
assert message.message_type == MessageType.CONSENSUS
assert message.priority == Priority.HIGH
assert message.payload == proposal_data
# Integration tests
class TestCommunicationIntegration:
"""Integration tests for communication system"""
@pytest.mark.asyncio
async def test_message_flow(self):
"""Test complete message flow"""
# Create communication manager
comm_manager = CommunicationManager("agent-001")
# Create protocols
hierarchical = HierarchicalProtocol("agent-001", is_master=True)
p2p = PeerToPeerProtocol("agent-001")
# Add protocols
comm_manager.add_protocol("hierarchical", hierarchical)
comm_manager.add_protocol("p2p", p2p)
# Mock message sending
hierarchical.send_message = AsyncMock(return_value=True)
p2p.send_message = AsyncMock(return_value=True)
# Register handler
async def handle_heartbeat(message):
assert message.sender_id == "agent-002"
assert message.message_type == MessageType.HEARTBEAT
await comm_manager.register_handler("hierarchical", MessageType.HEARTBEAT, handle_heartbeat)
# Send heartbeat
heartbeat = MessageTemplates.create_heartbeat("agent-001")
result = await comm_manager.send_message("hierarchical", heartbeat)
assert result is True
hierarchical.send_message.assert_called_once()
if __name__ == "__main__":
pytest.main([__file__])