diff --git a/apps/agent-coordinator/Dockerfile b/apps/agent-coordinator/Dockerfile new file mode 100644 index 00000000..522897e1 --- /dev/null +++ b/apps/agent-coordinator/Dockerfile @@ -0,0 +1,39 @@ +FROM python:3.11-slim + +# Set working directory +WORKDIR /app + +# Set environment variables +ENV PYTHONDONTWRITEBYTECODE=1 +ENV PYTHONUNBUFFERED=1 +ENV PYTHONPATH=/app/src + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + gcc \ + g++ \ + && rm -rf /var/lib/apt/lists/* + +# Install Python dependencies +COPY pyproject.toml poetry.lock ./ +RUN pip install poetry && \ + poetry config virtualenvs.create false && \ + poetry install --no-dev --no-interaction --no-ansi + +# Copy application code +COPY src/ ./src/ + +# Create non-root user +RUN useradd --create-home --shell /bin/bash app && \ + chown -R app:app /app +USER app + +# Health check +HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:9001/health || exit 1 + +# Expose port +EXPOSE 9001 + +# Start the application +CMD ["poetry", "run", "python", "-m", "uvicorn", "src.app.main:app", "--host", "0.0.0.0", "--port", "9001"] diff --git a/apps/agent-coordinator/src/app/config.py b/apps/agent-coordinator/src/app/config.py new file mode 100644 index 00000000..a0f17842 --- /dev/null +++ b/apps/agent-coordinator/src/app/config.py @@ -0,0 +1,460 @@ +""" +Configuration Management for AITBC Agent Coordinator +""" + +import os +from typing import Dict, Any, Optional +from pydantic import BaseSettings, Field +from enum import Enum + +class Environment(str, Enum): + """Environment types""" + DEVELOPMENT = "development" + TESTING = "testing" + STAGING = "staging" + PRODUCTION = "production" + +class LogLevel(str, Enum): + """Log levels""" + DEBUG = "DEBUG" + INFO = "INFO" + WARNING = "WARNING" + ERROR = "ERROR" + CRITICAL = "CRITICAL" + +class Settings(BaseSettings): + """Application settings""" + + # Application settings + app_name: str = "AITBC Agent 
Coordinator" + app_version: str = "1.0.0" + environment: Environment = Environment.DEVELOPMENT + debug: bool = False + + # Server settings + host: str = "0.0.0.0" + port: int = 9001 + workers: int = 1 + + # Redis settings + redis_url: str = "redis://localhost:6379/1" + redis_max_connections: int = 10 + redis_timeout: int = 5 + + # Database settings (if needed) + database_url: Optional[str] = None + + # Agent registry settings + heartbeat_interval: int = 30 # seconds + max_heartbeat_age: int = 120 # seconds + cleanup_interval: int = 60 # seconds + agent_ttl: int = 86400 # 24 hours in seconds + + # Load balancer settings + default_strategy: str = "least_connections" + max_task_queue_size: int = 10000 + task_timeout: int = 300 # 5 minutes + + # Communication settings + message_ttl: int = 300 # 5 minutes + max_message_size: int = 1024 * 1024 # 1MB + connection_timeout: int = 30 + + # Security settings + secret_key: str = "your-secret-key-change-in-production" + allowed_hosts: list = ["*"] + cors_origins: list = ["*"] + + # Monitoring settings + enable_metrics: bool = True + metrics_port: int = 9002 + health_check_interval: int = 30 + + # Logging settings + log_level: LogLevel = LogLevel.INFO + log_format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + log_file: Optional[str] = None + + # Performance settings + max_concurrent_tasks: int = 100 + task_batch_size: int = 10 + load_balancer_cache_size: int = 1000 + + class Config: + env_file = ".env" + env_file_encoding = "utf-8" + case_sensitive = False + +# Global settings instance +settings = Settings() + +# Configuration constants +class ConfigConstants: + """Configuration constants""" + + # Agent types + AGENT_TYPES = [ + "coordinator", + "worker", + "specialist", + "monitor", + "gateway", + "orchestrator" + ] + + # Agent statuses + AGENT_STATUSES = [ + "active", + "inactive", + "busy", + "maintenance", + "error" + ] + + # Message types + MESSAGE_TYPES = [ + "coordination", + "task_assignment", + 
"status_update", + "discovery", + "heartbeat", + "consensus", + "broadcast", + "direct", + "peer_to_peer", + "hierarchical" + ] + + # Task priorities + TASK_PRIORITIES = [ + "low", + "normal", + "high", + "critical", + "urgent" + ] + + # Load balancing strategies + LOAD_BALANCING_STRATEGIES = [ + "round_robin", + "least_connections", + "least_response_time", + "weighted_round_robin", + "resource_based", + "capability_based", + "predictive", + "consistent_hash" + ] + + # Default ports + DEFAULT_PORTS = { + "agent_coordinator": 9001, + "agent_registry": 9002, + "task_distributor": 9003, + "metrics": 9004, + "health": 9005 + } + + # Timeouts (in seconds) + TIMEOUTS = { + "connection": 30, + "message": 300, + "task": 600, + "heartbeat": 120, + "cleanup": 3600 + } + + # Limits + LIMITS = { + "max_message_size": 1024 * 1024, # 1MB + "max_task_queue_size": 10000, + "max_concurrent_tasks": 100, + "max_agent_connections": 1000, + "max_redis_connections": 10 + } + +# Environment-specific configurations +class EnvironmentConfig: + """Environment-specific configurations""" + + @staticmethod + def get_development_config() -> Dict[str, Any]: + """Development environment configuration""" + return { + "debug": True, + "log_level": LogLevel.DEBUG, + "reload": True, + "workers": 1, + "redis_url": "redis://localhost:6379/1", + "enable_metrics": True + } + + @staticmethod + def get_testing_config() -> Dict[str, Any]: + """Testing environment configuration""" + return { + "debug": True, + "log_level": LogLevel.DEBUG, + "redis_url": "redis://localhost:6379/15", # Separate DB for testing + "enable_metrics": False, + "heartbeat_interval": 5, # Faster for testing + "cleanup_interval": 10 + } + + @staticmethod + def get_staging_config() -> Dict[str, Any]: + """Staging environment configuration""" + return { + "debug": False, + "log_level": LogLevel.INFO, + "redis_url": "redis://localhost:6379/2", + "enable_metrics": True, + "workers": 2, + "cors_origins": ["https://staging.aitbc.com"] + } + 
+ @staticmethod + def get_production_config() -> Dict[str, Any]: + """Production environment configuration""" + return { + "debug": False, + "log_level": LogLevel.WARNING, + "redis_url": os.getenv("REDIS_URL", "redis://localhost:6379/0"), + "enable_metrics": True, + "workers": 4, + "cors_origins": ["https://aitbc.com"], + "secret_key": os.getenv("SECRET_KEY", "change-this-in-production"), + "allowed_hosts": ["aitbc.com", "www.aitbc.com"] + } + +# Configuration loader +class ConfigLoader: + """Configuration loader and validator""" + + @staticmethod + def load_config() -> Settings: + """Load and validate configuration""" + # Get environment-specific config + env_config = {} + if settings.environment == Environment.DEVELOPMENT: + env_config = EnvironmentConfig.get_development_config() + elif settings.environment == Environment.TESTING: + env_config = EnvironmentConfig.get_testing_config() + elif settings.environment == Environment.STAGING: + env_config = EnvironmentConfig.get_staging_config() + elif settings.environment == Environment.PRODUCTION: + env_config = EnvironmentConfig.get_production_config() + + # Update settings with environment-specific config + for key, value in env_config.items(): + if hasattr(settings, key): + setattr(settings, key, value) + + # Validate configuration + ConfigLoader.validate_config() + + return settings + + @staticmethod + def validate_config(): + """Validate configuration settings""" + errors = [] + + # Validate required settings + if not settings.secret_key or settings.secret_key == "your-secret-key-change-in-production": + if settings.environment == Environment.PRODUCTION: + errors.append("SECRET_KEY must be set in production") + + # Validate ports + if settings.port < 1 or settings.port > 65535: + errors.append("Port must be between 1 and 65535") + + # Validate Redis URL + if not settings.redis_url: + errors.append("Redis URL is required") + + # Validate timeouts + if settings.heartbeat_interval <= 0: + errors.append("Heartbeat 
interval must be positive") + + if settings.max_heartbeat_age <= settings.heartbeat_interval: + errors.append("Max heartbeat age must be greater than heartbeat interval") + + # Validate limits + if settings.max_message_size <= 0: + errors.append("Max message size must be positive") + + if settings.max_task_queue_size <= 0: + errors.append("Max task queue size must be positive") + + # Validate strategy + if settings.default_strategy not in ConfigConstants.LOAD_BALANCING_STRATEGIES: + errors.append(f"Invalid load balancing strategy: {settings.default_strategy}") + + if errors: + raise ValueError(f"Configuration validation failed: {', '.join(errors)}") + + @staticmethod + def get_redis_config() -> Dict[str, Any]: + """Get Redis configuration""" + return { + "url": settings.redis_url, + "max_connections": settings.redis_max_connections, + "timeout": settings.redis_timeout, + "decode_responses": True, + "socket_keepalive": True, + "socket_keepalive_options": {}, + "health_check_interval": 30 + } + + @staticmethod + def get_logging_config() -> Dict[str, Any]: + """Get logging configuration""" + return { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": { + "format": settings.log_format, + "datefmt": "%Y-%m-%d %H:%M:%S" + }, + "detailed": { + "format": "%(asctime)s - %(name)s - %(levelname)s - %(module)s - %(funcName)s - %(message)s", + "datefmt": "%Y-%m-%d %H:%M:%S" + } + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "level": settings.log_level.value, + "formatter": "default", + "stream": "ext://sys.stdout" + } + }, + "loggers": { + "": { + "level": settings.log_level.value, + "handlers": ["console"] + }, + "uvicorn": { + "level": "INFO", + "handlers": ["console"], + "propagate": False + }, + "fastapi": { + "level": "INFO", + "handlers": ["console"], + "propagate": False + } + } + } + +# Configuration utilities +class ConfigUtils: + """Configuration utilities""" + + @staticmethod + def 
get_agent_config(agent_type: str) -> Dict[str, Any]: + """Get configuration for specific agent type""" + base_config = { + "heartbeat_interval": settings.heartbeat_interval, + "max_connections": 100, + "timeout": settings.connection_timeout + } + + # Agent-specific configurations + agent_configs = { + "coordinator": { + **base_config, + "max_connections": 1000, + "heartbeat_interval": 15, + "enable_coordination": True + }, + "worker": { + **base_config, + "max_connections": 50, + "task_timeout": 300, + "enable_coordination": False + }, + "specialist": { + **base_config, + "max_connections": 25, + "specialization_timeout": 600, + "enable_coordination": True + }, + "monitor": { + **base_config, + "heartbeat_interval": 10, + "enable_coordination": True, + "monitoring_interval": 30 + }, + "gateway": { + **base_config, + "max_connections": 2000, + "enable_coordination": True, + "gateway_timeout": 60 + }, + "orchestrator": { + **base_config, + "max_connections": 500, + "heartbeat_interval": 5, + "enable_coordination": True, + "orchestration_timeout": 120 + } + } + + return agent_configs.get(agent_type, base_config) + + @staticmethod + def get_service_config(service_name: str) -> Dict[str, Any]: + """Get configuration for specific service""" + base_config = { + "host": settings.host, + "port": settings.port, + "workers": settings.workers, + "timeout": settings.connection_timeout + } + + # Service-specific configurations + service_configs = { + "agent_coordinator": { + **base_config, + "port": ConfigConstants.DEFAULT_PORTS["agent_coordinator"], + "enable_metrics": settings.enable_metrics + }, + "agent_registry": { + **base_config, + "port": ConfigConstants.DEFAULT_PORTS["agent_registry"], + "enable_metrics": False + }, + "task_distributor": { + **base_config, + "port": ConfigConstants.DEFAULT_PORTS["task_distributor"], + "max_queue_size": settings.max_task_queue_size + }, + "metrics": { + **base_config, + "port": ConfigConstants.DEFAULT_PORTS["metrics"], + 
"enable_metrics": True + }, + "health": { + **base_config, + "port": ConfigConstants.DEFAULT_PORTS["health"], + "enable_metrics": False + } + } + + return service_configs.get(service_name, base_config) + +# Load configuration +config = ConfigLoader.load_config() + +# Export settings and utilities +__all__ = [ + "settings", + "config", + "ConfigConstants", + "EnvironmentConfig", + "ConfigLoader", + "ConfigUtils" +] diff --git a/apps/agent-coordinator/src/app/main.py b/apps/agent-coordinator/src/app/main.py new file mode 100644 index 00000000..b6eb6ca8 --- /dev/null +++ b/apps/agent-coordinator/src/app/main.py @@ -0,0 +1,518 @@ +""" +Main FastAPI Application for AITBC Agent Coordinator +""" + +import asyncio +import logging +from contextlib import asynccontextmanager +from datetime import datetime +from typing import Dict, List, Optional, Any +import uuid + +from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends, status +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from pydantic import BaseModel, Field +import uvicorn + +from .protocols.communication import CommunicationManager, create_protocol, MessageType +from .protocols.message_types import MessageProcessor, create_task_message, create_status_message +from .routing.agent_discovery import AgentRegistry, AgentDiscoveryService, create_agent_info +from .routing.load_balancer import LoadBalancer, TaskDistributor, TaskPriority, LoadBalancingStrategy + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + +# Global variables +agent_registry: Optional[AgentRegistry] = None +discovery_service: Optional[AgentDiscoveryService] = None +load_balancer: Optional[LoadBalancer] = None +task_distributor: Optional[TaskDistributor] = None +communication_manager: Optional[CommunicationManager] = None +message_processor: Optional[MessageProcessor] = None 
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan management: builds all services on startup and
    tears them down on shutdown."""
    # Startup
    logger.info("Starting AITBC Agent Coordinator...")

    global agent_registry, discovery_service, load_balancer, task_distributor, communication_manager, message_processor

    # Start agent registry
    agent_registry = AgentRegistry()
    await agent_registry.start()

    # Initialize discovery service
    discovery_service = AgentDiscoveryService(agent_registry)

    # Initialize load balancer
    load_balancer = LoadBalancer(agent_registry)
    load_balancer.set_strategy(LoadBalancingStrategy.LEAST_CONNECTIONS)

    # Initialize task distributor
    task_distributor = TaskDistributor(load_balancer)

    # Initialize communication manager
    communication_manager = CommunicationManager("agent-coordinator")

    # Initialize message processor
    message_processor = MessageProcessor("agent-coordinator")

    # Start background tasks.
    # FIX: keep references to the tasks — the result of a bare
    # asyncio.create_task() may be garbage-collected, silently killing the
    # task — and cancel them on shutdown.
    background_tasks = [
        asyncio.create_task(task_distributor.start_distribution()),
        asyncio.create_task(message_processor.start_processing()),
    ]

    logger.info("Agent Coordinator started successfully")

    yield

    # Shutdown
    logger.info("Shutting down AITBC Agent Coordinator...")

    for task in background_tasks:
        task.cancel()
    await asyncio.gather(*background_tasks, return_exceptions=True)

    if agent_registry:
        await agent_registry.stop()

    logger.info("Agent Coordinator shut down")


# Create FastAPI app
app = FastAPI(
    title="AITBC Agent Coordinator",
    description="Advanced multi-agent coordination and management system",
    version="1.0.0",
    lifespan=lifespan
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Pydantic request models
class AgentRegistrationRequest(BaseModel):
    agent_id: str = Field(..., description="Unique agent identifier")
    agent_type: str = Field(..., description="Type of agent")
    capabilities: List[str] = Field(default_factory=list, description="Agent capabilities")
    services: List[str] = Field(default_factory=list, description="Available services")
    endpoints: Dict[str, str] = Field(default_factory=dict, description="Service endpoints")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")


class AgentStatusUpdate(BaseModel):
    status: str = Field(..., description="Agent status")
    load_metrics: Dict[str, float] = Field(default_factory=dict, description="Load metrics")


class TaskSubmission(BaseModel):
    task_data: Dict[str, Any] = Field(..., description="Task data")
    priority: str = Field("normal", description="Task priority")
    requirements: Optional[Dict[str, Any]] = Field(None, description="Task requirements")


class MessageRequest(BaseModel):
    receiver_id: str = Field(..., description="Receiver agent ID")
    message_type: str = Field(..., description="Message type")
    payload: Dict[str, Any] = Field(..., description="Message payload")
    priority: str = Field("normal", description="Message priority")


# Health check endpoint
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "service": "agent-coordinator",
        "timestamp": datetime.utcnow().isoformat(),
        "version": "1.0.0"
    }


# Root endpoint
@app.get("/")
async def root():
    """Root endpoint with service information"""
    return {
        "service": "AITBC Agent Coordinator",
        "description": "Advanced multi-agent coordination and management system",
        "version": "1.0.0",
        "endpoints": [
            "/health",
            "/agents/register",
            "/agents/discover",
            "/agents/{agent_id}",
            "/agents/{agent_id}/status",
            "/tasks/submit",
            "/tasks/status",
            "/messages/send",
            "/load-balancer/stats",
            "/registry/stats"
        ]
    }


# Agent registration
@app.post("/agents/register")
async def register_agent(request: AgentRegistrationRequest):
    """Register a new agent"""
    try:
        if not agent_registry:
            raise HTTPException(status_code=503, detail="Agent registry not available")

        # Create agent info
        agent_info = create_agent_info(
            agent_id=request.agent_id,
            agent_type=request.agent_type,
            capabilities=request.capabilities,
            services=request.services,
            endpoints=request.endpoints
        )
        agent_info.metadata = request.metadata

        # Register agent
        success = await agent_registry.register_agent(agent_info)

        if success:
            return {
                "status": "success",
                "message": f"Agent {request.agent_id} registered successfully",
                "agent_id": request.agent_id,
                "registered_at": datetime.utcnow().isoformat()
            }
        else:
            raise HTTPException(status_code=500, detail="Failed to register agent")

    except HTTPException:
        # FIX: re-raise deliberate HTTP errors (503/500 above) instead of
        # letting the generic handler below rewrap them as 500s.
        raise
    except Exception as e:
        logger.error(f"Error registering agent: {e}")
        raise HTTPException(status_code=500, detail=str(e))


# Agent discovery
@app.post("/agents/discover")
async def discover_agents(query: Dict[str, Any]):
    """Discover agents based on criteria"""
    try:
        if not agent_registry:
            raise HTTPException(status_code=503, detail="Agent registry not available")

        agents = await agent_registry.discover_agents(query)

        return {
            "status": "success",
            "query": query,
            "agents": [agent.to_dict() for agent in agents],
            "count": len(agents),
            "timestamp": datetime.utcnow().isoformat()
        }

    except HTTPException:
        # FIX: keep the 503 as a 503 (was converted into a 500 below).
        raise
    except Exception as e:
        logger.error(f"Error discovering agents: {e}")
        raise HTTPException(status_code=500, detail=str(e))


# Get agent by ID
@app.get("/agents/{agent_id}")
async def get_agent(agent_id: str):
    """Get agent information by ID"""
    try:
        if not agent_registry:
            raise HTTPException(status_code=503, detail="Agent registry not available")

        agent = await agent_registry.get_agent_by_id(agent_id)

        if not agent:
            raise HTTPException(status_code=404, detail="Agent not found")

        return {
            "status": "success",
            "agent": agent.to_dict(),
            "timestamp": datetime.utcnow().isoformat()
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting agent: {e}")
        raise HTTPException(status_code=500, detail=str(e))
+# Update agent status +@app.put("/agents/{agent_id}/status") +async def update_agent_status(agent_id: str, request: AgentStatusUpdate): + """Update agent status""" + try: + if not agent_registry: + raise HTTPException(status_code=503, detail="Agent registry not available") + + from .routing.agent_discovery import AgentStatus + + success = await agent_registry.update_agent_status( + agent_id, + AgentStatus(request.status), + request.load_metrics + ) + + if success: + return { + "status": "success", + "message": f"Agent {agent_id} status updated", + "agent_id": agent_id, + "new_status": request.status, + "updated_at": datetime.utcnow().isoformat() + } + else: + raise HTTPException(status_code=500, detail="Failed to update agent status") + + except Exception as e: + logger.error(f"Error updating agent status: {e}") + raise HTTPException(status_code=500, detail=str(e)) + +# Submit task +@app.post("/tasks/submit") +async def submit_task(request: TaskSubmission, background_tasks: BackgroundTasks): + """Submit a task for distribution""" + try: + if not task_distributor: + raise HTTPException(status_code=503, detail="Task distributor not available") + + # Convert priority string to enum + try: + priority = TaskPriority(request.priority.lower()) + except ValueError: + raise HTTPException(status_code=400, detail=f"Invalid priority: {request.priority}") + + # Submit task + await task_distributor.submit_task( + request.task_data, + priority, + request.requirements + ) + + return { + "status": "success", + "message": "Task submitted successfully", + "task_id": request.task_data.get("task_id", str(uuid.uuid4())), + "priority": request.priority, + "submitted_at": datetime.utcnow().isoformat() + } + + except Exception as e: + logger.error(f"Error submitting task: {e}") + raise HTTPException(status_code=500, detail=str(e)) + +# Get task status +@app.get("/tasks/status") +async def get_task_status(): + """Get task distribution statistics""" + try: + if not task_distributor: + raise 
HTTPException(status_code=503, detail="Task distributor not available") + + stats = task_distributor.get_distribution_stats() + + return { + "status": "success", + "stats": stats, + "timestamp": datetime.utcnow().isoformat() + } + + except Exception as e: + logger.error(f"Error getting task status: {e}") + raise HTTPException(status_code=500, detail=str(e)) + +# Send message +@app.post("/messages/send") +async def send_message(request: MessageRequest): + """Send message to agent""" + try: + if not communication_manager: + raise HTTPException(status_code=503, detail="Communication manager not available") + + from .protocols.communication import AgentMessage, Priority + + # Convert message type + try: + message_type = MessageType(request.message_type) + except ValueError: + raise HTTPException(status_code=400, detail=f"Invalid message type: {request.message_type}") + + # Convert priority + try: + priority = Priority(request.priority.lower()) + except ValueError: + raise HTTPException(status_code=400, detail=f"Invalid priority: {request.priority}") + + # Create message + message = AgentMessage( + sender_id="agent-coordinator", + receiver_id=request.receiver_id, + message_type=message_type, + priority=priority, + payload=request.payload + ) + + # Send message + success = await communication_manager.send_message("hierarchical", message) + + if success: + return { + "status": "success", + "message": "Message sent successfully", + "message_id": message.id, + "receiver_id": request.receiver_id, + "sent_at": datetime.utcnow().isoformat() + } + else: + raise HTTPException(status_code=500, detail="Failed to send message") + + except Exception as e: + logger.error(f"Error sending message: {e}") + raise HTTPException(status_code=500, detail=str(e)) + +# Load balancer statistics +@app.get("/load-balancer/stats") +async def get_load_balancer_stats(): + """Get load balancer statistics""" + try: + if not load_balancer: + raise HTTPException(status_code=503, detail="Load balancer 
not available") + + stats = load_balancer.get_load_balancing_stats() + + return { + "status": "success", + "stats": stats, + "timestamp": datetime.utcnow().isoformat() + } + + except Exception as e: + logger.error(f"Error getting load balancer stats: {e}") + raise HTTPException(status_code=500, detail=str(e)) + +# Registry statistics +@app.get("/registry/stats") +async def get_registry_stats(): + """Get agent registry statistics""" + try: + if not agent_registry: + raise HTTPException(status_code=503, detail="Agent registry not available") + + stats = await agent_registry.get_registry_stats() + + return { + "status": "success", + "stats": stats, + "timestamp": datetime.utcnow().isoformat() + } + + except Exception as e: + logger.error(f"Error getting registry stats: {e}") + raise HTTPException(status_code=500, detail=str(e)) + +# Get agents by service +@app.get("/agents/service/{service}") +async def get_agents_by_service(service: str): + """Get agents that provide a specific service""" + try: + if not agent_registry: + raise HTTPException(status_code=503, detail="Agent registry not available") + + agents = await agent_registry.get_agents_by_service(service) + + return { + "status": "success", + "service": service, + "agents": [agent.to_dict() for agent in agents], + "count": len(agents), + "timestamp": datetime.utcnow().isoformat() + } + + except Exception as e: + logger.error(f"Error getting agents by service: {e}") + raise HTTPException(status_code=500, detail=str(e)) + +# Get agents by capability +@app.get("/agents/capability/{capability}") +async def get_agents_by_capability(capability: str): + """Get agents that have a specific capability""" + try: + if not agent_registry: + raise HTTPException(status_code=503, detail="Agent registry not available") + + agents = await agent_registry.get_agents_by_capability(capability) + + return { + "status": "success", + "capability": capability, + "agents": [agent.to_dict() for agent in agents], + "count": len(agents), + 
"timestamp": datetime.utcnow().isoformat() + } + + except Exception as e: + logger.error(f"Error getting agents by capability: {e}") + raise HTTPException(status_code=500, detail=str(e)) + +# Set load balancing strategy +@app.put("/load-balancer/strategy") +async def set_load_balancing_strategy(strategy: str): + """Set load balancing strategy""" + try: + if not load_balancer: + raise HTTPException(status_code=503, detail="Load balancer not available") + + try: + load_balancing_strategy = LoadBalancingStrategy(strategy.lower()) + except ValueError: + raise HTTPException(status_code=400, detail=f"Invalid strategy: {strategy}") + + load_balancer.set_strategy(load_balancing_strategy) + + return { + "status": "success", + "message": f"Load balancing strategy set to {strategy}", + "strategy": strategy, + "updated_at": datetime.utcnow().isoformat() + } + + except Exception as e: + logger.error(f"Error setting load balancing strategy: {e}") + raise HTTPException(status_code=500, detail=str(e)) + +# Error handlers +@app.exception_handler(404) +async def not_found_handler(request, exc): + return JSONResponse( + status_code=404, + content={ + "status": "error", + "message": "Resource not found", + "timestamp": datetime.utcnow().isoformat() + } + ) + +@app.exception_handler(500) +async def internal_error_handler(request, exc): + logger.error(f"Internal server error: {exc}") + return JSONResponse( + status_code=500, + content={ + "status": "error", + "message": "Internal server error", + "timestamp": datetime.utcnow().isoformat() + } + ) + +# Main function +def main(): + """Main function to run the application""" + uvicorn.run( + "main:app", + host="0.0.0.0", + port=9001, + reload=True, + log_level="info" + ) + +if __name__ == "__main__": + main() diff --git a/apps/agent-coordinator/src/app/protocols/communication.py b/apps/agent-coordinator/src/app/protocols/communication.py new file mode 100644 index 00000000..9cbde448 --- /dev/null +++ 
b/apps/agent-coordinator/src/app/protocols/communication.py @@ -0,0 +1,443 @@ +""" +Multi-Agent Communication Protocols for AITBC Agent Coordination +""" + +import asyncio +import json +import logging +from enum import Enum +from typing import Dict, List, Optional, Any, Callable +from dataclasses import dataclass, field +from datetime import datetime +import uuid +import websockets +from pydantic import BaseModel, Field + +logger = logging.getLogger(__name__) + +class MessageType(str, Enum): + """Message types for agent communication""" + COORDINATION = "coordination" + TASK_ASSIGNMENT = "task_assignment" + STATUS_UPDATE = "status_update" + DISCOVERY = "discovery" + HEARTBEAT = "heartbeat" + CONSENSUS = "consensus" + BROADCAST = "broadcast" + DIRECT = "direct" + PEER_TO_PEER = "peer_to_peer" + HIERARCHICAL = "hierarchical" + +class Priority(str, Enum): + """Message priority levels""" + LOW = "low" + NORMAL = "normal" + HIGH = "high" + CRITICAL = "critical" + +@dataclass +class AgentMessage: + """Base message structure for agent communication""" + id: str = field(default_factory=lambda: str(uuid.uuid4())) + sender_id: str = "" + receiver_id: Optional[str] = None + message_type: MessageType = MessageType.DIRECT + priority: Priority = Priority.NORMAL + timestamp: datetime = field(default_factory=datetime.utcnow) + payload: Dict[str, Any] = field(default_factory=dict) + correlation_id: Optional[str] = None + reply_to: Optional[str] = None + ttl: int = 300 # Time to live in seconds + + def to_dict(self) -> Dict[str, Any]: + """Convert message to dictionary""" + return { + "id": self.id, + "sender_id": self.sender_id, + "receiver_id": self.receiver_id, + "message_type": self.message_type.value, + "priority": self.priority.value, + "timestamp": self.timestamp.isoformat(), + "payload": self.payload, + "correlation_id": self.correlation_id, + "reply_to": self.reply_to, + "ttl": self.ttl + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AgentMessage": + 
"""Create message from dictionary""" + data["timestamp"] = datetime.fromisoformat(data["timestamp"]) + data["message_type"] = MessageType(data["message_type"]) + data["priority"] = Priority(data["priority"]) + return cls(**data) + +class CommunicationProtocol: + """Base class for communication protocols""" + + def __init__(self, agent_id: str): + self.agent_id = agent_id + self.message_handlers: Dict[MessageType, List[Callable]] = {} + self.active_connections: Dict[str, Any] = {} + + async def register_handler(self, message_type: MessageType, handler: Callable): + """Register a message handler for a specific message type""" + if message_type not in self.message_handlers: + self.message_handlers[message_type] = [] + self.message_handlers[message_type].append(handler) + + async def send_message(self, message: AgentMessage) -> bool: + """Send a message to another agent""" + try: + if message.receiver_id and message.receiver_id in self.active_connections: + await self._send_to_agent(message) + return True + elif message.message_type == MessageType.BROADCAST: + await self._broadcast_message(message) + return True + else: + logger.warning(f"Cannot send message to {message.receiver_id}: not connected") + return False + except Exception as e: + logger.error(f"Error sending message: {e}") + return False + + async def receive_message(self, message: AgentMessage): + """Process received message""" + try: + # Check TTL + if self._is_message_expired(message): + logger.warning(f"Message {message.id} expired, ignoring") + return + + # Handle message + handlers = self.message_handlers.get(message.message_type, []) + for handler in handlers: + try: + await handler(message) + except Exception as e: + logger.error(f"Error in message handler: {e}") + + except Exception as e: + logger.error(f"Error processing message: {e}") + + def _is_message_expired(self, message: AgentMessage) -> bool: + """Check if message has expired""" + age = (datetime.utcnow() - 
message.timestamp).total_seconds() + return age > message.ttl + + async def _send_to_agent(self, message: AgentMessage): + """Send message to specific agent""" + raise NotImplementedError("Subclasses must implement _send_to_agent") + + async def _broadcast_message(self, message: AgentMessage): + """Broadcast message to all connected agents""" + raise NotImplementedError("Subclasses must implement _broadcast_message") + +class HierarchicalProtocol(CommunicationProtocol): + """Hierarchical communication protocol (master-agent → sub-agents)""" + + def __init__(self, agent_id: str, is_master: bool = False): + super().__init__(agent_id) + self.is_master = is_master + self.sub_agents: List[str] = [] + self.master_agent: Optional[str] = None + + async def add_sub_agent(self, agent_id: str): + """Add a sub-agent to this master agent""" + if self.is_master: + self.sub_agents.append(agent_id) + logger.info(f"Added sub-agent {agent_id} to master {self.agent_id}") + else: + logger.warning(f"Agent {self.agent_id} is not a master, cannot add sub-agents") + + async def send_to_sub_agents(self, message: AgentMessage): + """Send message to all sub-agents""" + if not self.is_master: + logger.warning(f"Agent {self.agent_id} is not a master") + return + + message.message_type = MessageType.HIERARCHICAL + for sub_agent_id in self.sub_agents: + message.receiver_id = sub_agent_id + await self.send_message(message) + + async def send_to_master(self, message: AgentMessage): + """Send message to master agent""" + if self.is_master: + logger.warning(f"Agent {self.agent_id} is a master, cannot send to master") + return + + if self.master_agent: + message.receiver_id = self.master_agent + message.message_type = MessageType.HIERARCHICAL + await self.send_message(message) + else: + logger.warning(f"Agent {self.agent_id} has no master agent") + +class PeerToPeerProtocol(CommunicationProtocol): + """Peer-to-peer communication protocol (agent ↔ agent)""" + + def __init__(self, agent_id: str): + 
super().__init__(agent_id) + self.peers: Dict[str, Dict[str, Any]] = {} + + async def add_peer(self, peer_id: str, connection_info: Dict[str, Any]): + """Add a peer to the peer network""" + self.peers[peer_id] = connection_info + logger.info(f"Added peer {peer_id} to agent {self.agent_id}") + + async def remove_peer(self, peer_id: str): + """Remove a peer from the peer network""" + if peer_id in self.peers: + del self.peers[peer_id] + logger.info(f"Removed peer {peer_id} from agent {self.agent_id}") + + async def send_to_peer(self, message: AgentMessage, peer_id: str): + """Send message to specific peer""" + if peer_id not in self.peers: + logger.warning(f"Peer {peer_id} not found") + return False + + message.receiver_id = peer_id + message.message_type = MessageType.PEER_TO_PEER + return await self.send_message(message) + + async def broadcast_to_peers(self, message: AgentMessage): + """Broadcast message to all peers""" + message.message_type = MessageType.PEER_TO_PEER + for peer_id in self.peers: + message.receiver_id = peer_id + await self.send_message(message) + +class BroadcastProtocol(CommunicationProtocol): + """Broadcast communication protocol (agent → all agents)""" + + def __init__(self, agent_id: str, broadcast_channel: str = "global"): + super().__init__(agent_id) + self.broadcast_channel = broadcast_channel + self.subscribers: List[str] = [] + + async def subscribe(self, agent_id: str): + """Subscribe to broadcast channel""" + if agent_id not in self.subscribers: + self.subscribers.append(agent_id) + logger.info(f"Agent {agent_id} subscribed to {self.broadcast_channel}") + + async def unsubscribe(self, agent_id: str): + """Unsubscribe from broadcast channel""" + if agent_id in self.subscribers: + self.subscribers.remove(agent_id) + logger.info(f"Agent {agent_id} unsubscribed from {self.broadcast_channel}") + + async def broadcast(self, message: AgentMessage): + """Broadcast message to all subscribers""" + message.message_type = MessageType.BROADCAST + 
message.receiver_id = None # Broadcast to all + + for subscriber_id in self.subscribers: + if subscriber_id != self.agent_id: # Don't send to self + message_copy = AgentMessage(**message.__dict__) + message_copy.receiver_id = subscriber_id + await self.send_message(message_copy) + +class CommunicationManager: + """Manages multiple communication protocols for an agent""" + + def __init__(self, agent_id: str): + self.agent_id = agent_id + self.protocols: Dict[str, CommunicationProtocol] = {} + + def add_protocol(self, name: str, protocol: CommunicationProtocol): + """Add a communication protocol""" + self.protocols[name] = protocol + logger.info(f"Added protocol {name} to agent {self.agent_id}") + + def get_protocol(self, name: str) -> Optional[CommunicationProtocol]: + """Get a communication protocol by name""" + return self.protocols.get(name) + + async def send_message(self, protocol_name: str, message: AgentMessage) -> bool: + """Send message using specific protocol""" + protocol = self.get_protocol(protocol_name) + if protocol: + return await protocol.send_message(message) + return False + + async def register_handler(self, protocol_name: str, message_type: MessageType, handler: Callable): + """Register message handler for specific protocol""" + protocol = self.get_protocol(protocol_name) + if protocol: + await protocol.register_handler(message_type, handler) + else: + logger.error(f"Protocol {protocol_name} not found") + +# Message templates for common operations +class MessageTemplates: + """Pre-defined message templates""" + + @staticmethod + def create_heartbeat(sender_id: str) -> AgentMessage: + """Create heartbeat message""" + return AgentMessage( + sender_id=sender_id, + message_type=MessageType.HEARTBEAT, + priority=Priority.LOW, + payload={"timestamp": datetime.utcnow().isoformat()} + ) + + @staticmethod + def create_task_assignment(sender_id: str, receiver_id: str, task_data: Dict[str, Any]) -> AgentMessage: + """Create task assignment message""" + 
return AgentMessage( + sender_id=sender_id, + receiver_id=receiver_id, + message_type=MessageType.TASK_ASSIGNMENT, + priority=Priority.NORMAL, + payload=task_data + ) + + @staticmethod + def create_status_update(sender_id: str, status_data: Dict[str, Any]) -> AgentMessage: + """Create status update message""" + return AgentMessage( + sender_id=sender_id, + message_type=MessageType.STATUS_UPDATE, + priority=Priority.NORMAL, + payload=status_data + ) + + @staticmethod + def create_discovery(sender_id: str) -> AgentMessage: + """Create discovery message""" + return AgentMessage( + sender_id=sender_id, + message_type=MessageType.DISCOVERY, + priority=Priority.NORMAL, + payload={"agent_id": sender_id} + ) + + @staticmethod + def create_consensus_request(sender_id: str, proposal_data: Dict[str, Any]) -> AgentMessage: + """Create consensus request message""" + return AgentMessage( + sender_id=sender_id, + message_type=MessageType.CONSENSUS, + priority=Priority.HIGH, + payload=proposal_data + ) + +# WebSocket connection handler for real-time communication +class WebSocketHandler: + """WebSocket handler for real-time agent communication""" + + def __init__(self, communication_manager: CommunicationManager): + self.communication_manager = communication_manager + self.websocket_connections: Dict[str, Any] = {} + + async def handle_connection(self, websocket, agent_id: str): + """Handle WebSocket connection from agent""" + self.websocket_connections[agent_id] = websocket + logger.info(f"WebSocket connection established for agent {agent_id}") + + try: + async for message in websocket: + data = json.loads(message) + agent_message = AgentMessage.from_dict(data) + await self.communication_manager.receive_message(agent_message) + except websockets.exceptions.ConnectionClosed: + logger.info(f"WebSocket connection closed for agent {agent_id}") + finally: + if agent_id in self.websocket_connections: + del self.websocket_connections[agent_id] + + async def send_to_agent(self, agent_id: 
str, message: AgentMessage): + """Send message to agent via WebSocket""" + if agent_id in self.websocket_connections: + websocket = self.websocket_connections[agent_id] + await websocket.send(json.dumps(message.to_dict())) + return True + return False + + async def broadcast_message(self, message: AgentMessage): + """Broadcast message to all connected agents""" + for websocket in self.websocket_connections.values(): + await websocket.send(json.dumps(message.to_dict())) + +# Redis-based message broker for scalable communication +class RedisMessageBroker: + """Redis-based message broker for agent communication""" + + def __init__(self, redis_url: str): + self.redis_url = redis_url + self.channels: Dict[str, Any] = {} + + async def publish_message(self, channel: str, message: AgentMessage): + """Publish message to Redis channel""" + import redis.asyncio as redis + redis_client = redis.from_url(self.redis_url) + + await redis_client.publish(channel, json.dumps(message.to_dict())) + await redis_client.close() + + async def subscribe_to_channel(self, channel: str, handler: Callable): + """Subscribe to Redis channel""" + import redis.asyncio as redis + redis_client = redis.from_url(self.redis_url) + + pubsub = redis_client.pubsub() + await pubsub.subscribe(channel) + + self.channels[channel] = {"pubsub": pubsub, "handler": handler} + + # Start listening for messages + asyncio.create_task(self._listen_to_channel(channel, pubsub, handler)) + + async def _listen_to_channel(self, channel: str, pubsub: Any, handler: Callable): + """Listen for messages on channel""" + async for message in pubsub.listen(): + if message["type"] == "message": + data = json.loads(message["data"]) + agent_message = AgentMessage.from_dict(data) + await handler(agent_message) + +# Factory function for creating communication protocols +def create_protocol(protocol_type: str, agent_id: str, **kwargs) -> CommunicationProtocol: + """Factory function to create communication protocols""" + if protocol_type 
== "hierarchical": + return HierarchicalProtocol(agent_id, kwargs.get("is_master", False)) + elif protocol_type == "peer_to_peer": + return PeerToPeerProtocol(agent_id) + elif protocol_type == "broadcast": + return BroadcastProtocol(agent_id, kwargs.get("broadcast_channel", "global")) + else: + raise ValueError(f"Unknown protocol type: {protocol_type}") + +# Example usage +async def example_usage(): + """Example of how to use the communication protocols""" + + # Create communication manager + comm_manager = CommunicationManager("agent-001") + + # Add protocols + hierarchical_protocol = create_protocol("hierarchical", "agent-001", is_master=True) + p2p_protocol = create_protocol("peer_to_peer", "agent-001") + broadcast_protocol = create_protocol("broadcast", "agent-001") + + comm_manager.add_protocol("hierarchical", hierarchical_protocol) + comm_manager.add_protocol("peer_to_peer", p2p_protocol) + comm_manager.add_protocol("broadcast", broadcast_protocol) + + # Register message handlers + async def handle_heartbeat(message: AgentMessage): + logger.info(f"Received heartbeat from {message.sender_id}") + + await comm_manager.register_handler("hierarchical", MessageType.HEARTBEAT, handle_heartbeat) + + # Send messages + heartbeat = MessageTemplates.create_heartbeat("agent-001") + await comm_manager.send_message("hierarchical", heartbeat) + +if __name__ == "__main__": + asyncio.run(example_usage()) diff --git a/apps/agent-coordinator/src/app/protocols/message_types.py b/apps/agent-coordinator/src/app/protocols/message_types.py new file mode 100644 index 00000000..b6610071 --- /dev/null +++ b/apps/agent-coordinator/src/app/protocols/message_types.py @@ -0,0 +1,586 @@ +""" +Message Types and Routing System for AITBC Agent Coordination +""" + +import asyncio +import json +import logging +from enum import Enum +from typing import Dict, List, Optional, Any, Callable, Union +from dataclasses import dataclass, field +from datetime import datetime, timedelta +import uuid +import 
from pydantic import BaseModel, Field, validator
from .communication import AgentMessage, MessageType, Priority

logger = logging.getLogger(__name__)

class MessageStatus(str, Enum):
    """Message processing status."""
    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
    EXPIRED = "expired"
    CANCELLED = "cancelled"

class RoutingStrategy(str, Enum):
    """Message routing strategies."""
    ROUND_ROBIN = "round_robin"
    LOAD_BALANCED = "load_balanced"
    PRIORITY_BASED = "priority_based"
    RANDOM = "random"
    DIRECT = "direct"
    BROADCAST = "broadcast"

class DeliveryMode(str, Enum):
    """Message delivery modes."""
    FIRE_AND_FORGET = "fire_and_forget"
    AT_LEAST_ONCE = "at_least_once"
    EXACTLY_ONCE = "exactly_once"
    PERSISTENT = "persistent"

@dataclass
class RoutingRule:
    """Routing rule for message processing.

    `condition` maps AgentMessage attribute names to required values;
    `action` is one of: forward, transform, filter, route.
    """
    rule_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    name: str = ""
    condition: Dict[str, Any] = field(default_factory=dict)
    action: str = "forward"  # forward, transform, filter, route
    target: Optional[str] = None
    priority: int = 0
    enabled: bool = True
    created_at: datetime = field(default_factory=datetime.utcnow)

    def matches(self, message: AgentMessage) -> bool:
        """Return True when every condition attribute equals its value."""
        return all(
            getattr(message, key, None) == value
            for key, value in self.condition.items()
        )

class TaskMessage(BaseModel):
    """Task-specific message structure."""
    task_id: str = Field(..., description="Unique task identifier")
    task_type: str = Field(..., description="Type of task")
    task_data: Dict[str, Any] = Field(default_factory=dict, description="Task data")
    requirements: Dict[str, Any] = Field(default_factory=dict, description="Task requirements")
    deadline: Optional[datetime] = Field(None, description="Task deadline")
    priority: Priority = Field(Priority.NORMAL, description="Task priority")
    assigned_agent: Optional[str] = Field(None, description="Assigned agent ID")
    status: str = Field("pending", description="Task status")
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)

    @validator('deadline')
    def validate_deadline(cls, v):
        """Reject deadlines already in the past (naive UTC comparison)."""
        if v and v < datetime.utcnow():
            raise ValueError("Deadline cannot be in the past")
        return v

class CoordinationMessage(BaseModel):
    """Coordination-specific message structure."""
    coordination_id: str = Field(..., description="Unique coordination identifier")
    coordination_type: str = Field(..., description="Type of coordination")
    participants: List[str] = Field(default_factory=list, description="Participating agents")
    coordination_data: Dict[str, Any] = Field(default_factory=dict, description="Coordination data")
    decision_deadline: Optional[datetime] = Field(None, description="Decision deadline")
    consensus_threshold: float = Field(0.5, description="Consensus threshold")
    status: str = Field("pending", description="Coordination status")
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)

class StatusMessage(BaseModel):
    """Status update message structure."""
    agent_id: str = Field(..., description="Agent ID")
    status_type: str = Field(..., description="Type of status")
    status_data: Dict[str, Any] = Field(default_factory=dict, description="Status data")
    health_score: float = Field(1.0, description="Agent health score")
    load_metrics: Dict[str, float] = Field(default_factory=dict, description="Load metrics")
    capabilities: List[str] = Field(default_factory=list, description="Agent capabilities")
    timestamp: datetime = Field(default_factory=datetime.utcnow)

class DiscoveryMessage(BaseModel):
    """Agent discovery message structure."""
    agent_id: str = Field(..., description="Agent ID")
    agent_type: str = Field(..., description="Type of agent")
    capabilities: List[str] = Field(default_factory=list, description="Agent capabilities")
    services: List[str] = Field(default_factory=list, description="Available services")
    endpoints: Dict[str, str] = Field(default_factory=dict, description="Service endpoints")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
    timestamp: datetime = Field(default_factory=datetime.utcnow)

class ConsensusMessage(BaseModel):
    """Consensus message structure."""
    consensus_id: str = Field(..., description="Unique consensus identifier")
    proposal: Dict[str, Any] = Field(..., description="Consensus proposal")
    voting_options: List[Dict[str, Any]] = Field(default_factory=list, description="Voting options")
    votes: Dict[str, str] = Field(default_factory=dict, description="Agent votes")
    voting_deadline: datetime = Field(..., description="Voting deadline")
    consensus_algorithm: str = Field("majority", description="Consensus algorithm")
    status: str = Field("pending", description="Consensus status")
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)

class MessageRouter:
    """Advanced message routing system.

    Applies routing rules in priority order, falls back to direct/broadcast
    routing, and sends unroutable or expired messages to a dead-letter queue.
    """

    def __init__(self, agent_id: str):
        self.agent_id = agent_id
        self.routing_rules: List[RoutingRule] = []
        self.message_queue: asyncio.Queue = asyncio.Queue(maxsize=10000)
        self.dead_letter_queue: asyncio.Queue = asyncio.Queue(maxsize=1000)
        self.routing_stats: Dict[str, Any] = {
            "messages_processed": 0,
            "messages_failed": 0,
            "messages_expired": 0,
            "routing_time_total": 0.0
        }
        self.active_routes: Dict[str, str] = {}  # message_id -> route
        self.load_balancer_index = 0

    def add_routing_rule(self, rule: RoutingRule):
        """Add a routing rule and keep rules sorted by descending priority."""
        self.routing_rules.append(rule)
        self.routing_rules.sort(key=lambda r: r.priority, reverse=True)
        logger.info(f"Added routing rule: {rule.name}")

    def remove_routing_rule(self, rule_id: str):
        """Remove the rule with the given id (no-op if absent)."""
        self.routing_rules = [r for r in self.routing_rules if r.rule_id != rule_id]
        logger.info(f"Removed routing rule: {rule_id}")

    async def route_message(self, message: AgentMessage) -> Optional[str]:
        """Route message based on routing rules.

        Returns the chosen route (receiver id or "broadcast"), or None when
        the message expired, failed, or matched no route.
        """
        start_time = datetime.utcnow()

        try:
            # Expired messages go straight to the dead-letter queue
            if self._is_message_expired(message):
                await self.dead_letter_queue.put(message)
                self.routing_stats["messages_expired"] += 1
                return None

            # Apply routing rules (already sorted by priority)
            for rule in self.routing_rules:
                if rule.enabled and rule.matches(message):
                    route = await self._apply_routing_rule(rule, message)
                    if route:
                        self.active_routes[message.id] = route
                        self.routing_stats["messages_processed"] += 1
                        return route

            # Default routing
            default_route = await self._default_routing(message)
            if default_route:
                self.active_routes[message.id] = default_route
                self.routing_stats["messages_processed"] += 1
                return default_route

            # No route found
            await self.dead_letter_queue.put(message)
            self.routing_stats["messages_failed"] += 1
            return None

        except Exception as e:
            logger.error(f"Error routing message {message.id}: {e}")
            await self.dead_letter_queue.put(message)
            self.routing_stats["messages_failed"] += 1
            return None
        finally:
            routing_time = (datetime.utcnow() - start_time).total_seconds()
            self.routing_stats["routing_time_total"] += routing_time

    async def _apply_routing_rule(self, rule: RoutingRule, message: AgentMessage) -> Optional[str]:
        """Dispatch on the rule's action; unknown actions yield no route."""
        if rule.action == "forward":
            return rule.target
        elif rule.action == "transform":
            return await self._transform_message(message, rule)
        elif rule.action == "filter":
            return await self._filter_message(message, rule)
        elif rule.action == "route":
            return await self._custom_routing(message, rule)
        return None

    async def _transform_message(self, message: AgentMessage, rule: RoutingRule) -> Optional[str]:
        """Merge rule's "transform" payload into the message, then route it."""
        transformed_message = AgentMessage(
            sender_id=message.sender_id,
            receiver_id=message.receiver_id,
            message_type=message.message_type,
            priority=message.priority,
            payload={**message.payload, **rule.condition.get("transform", {})}
        )
        return await self._default_routing(transformed_message)

    async def _filter_message(self, message: AgentMessage, rule: RoutingRule) -> Optional[str]:
        """Drop the message unless its payload matches the "filter" condition."""
        filter_condition = rule.condition.get("filter", {})
        for key, value in filter_condition.items():
            if message.payload.get(key) != value:
                return None  # Filter out message
        return await self._default_routing(message)

    async def _custom_routing(self, message: AgentMessage, rule: RoutingRule) -> Optional[str]:
        """Custom routing hook; currently just returns the rule's target."""
        return rule.target

    async def _default_routing(self, message: AgentMessage) -> Optional[str]:
        """Direct route when a receiver is set; "broadcast" for broadcasts."""
        if message.receiver_id:
            return message.receiver_id
        elif message.message_type == MessageType.BROADCAST:
            return "broadcast"
        else:
            return None

    def _is_message_expired(self, message: AgentMessage) -> bool:
        """Return True when the message is older than its TTL (seconds)."""
        age = (datetime.utcnow() - message.timestamp).total_seconds()
        return age > message.ttl

    def collect_routing_stats(self) -> Dict[str, Any]:
        """Build the routing statistics snapshot synchronously.

        Added so synchronous callers (MessageProcessor.get_processing_stats)
        do not embed an un-awaited coroutine in their results.
        """
        total_messages = self.routing_stats["messages_processed"]
        avg_routing_time = (
            self.routing_stats["routing_time_total"] / total_messages
            if total_messages > 0 else 0
        )

        return {
            **self.routing_stats,
            "avg_routing_time": avg_routing_time,
            "active_routes": len(self.active_routes),
            "queue_size": self.message_queue.qsize(),
            "dead_letter_queue_size": self.dead_letter_queue.qsize()
        }

    async def get_routing_stats(self) -> Dict[str, Any]:
        """Async wrapper kept for backward compatibility with await callers."""
        return self.collect_routing_stats()

class LoadBalancer:
    """Load balancer for message distribution across agents."""

    def __init__(self):
        self.agent_loads: Dict[str, float] = {}
        self.agent_weights: Dict[str, float] = {}
        self.last_updated = datetime.utcnow()
        # FIX: round-robin cursor was read but never initialized on this
        # class (AttributeError on first ROUND_ROBIN selection).
        self.load_balancer_index = 0

    def update_agent_load(self, agent_id: str, load: float):
        """Record the latest load figure for an agent."""
        self.agent_loads[agent_id] = load
        self.last_updated = datetime.utcnow()

    def set_agent_weight(self, agent_id: str, weight: float):
        """Set an agent's weight (higher weight = more capacity/priority)."""
        self.agent_weights[agent_id] = weight

    def select_agent(self, available_agents: List[str], strategy: RoutingStrategy = RoutingStrategy.LOAD_BALANCED) -> Optional[str]:
        """Select an agent using the given strategy; None if list is empty."""
        if not available_agents:
            return None

        if strategy == RoutingStrategy.ROUND_ROBIN:
            return self._round_robin_selection(available_agents)
        elif strategy == RoutingStrategy.LOAD_BALANCED:
            return self._load_balanced_selection(available_agents)
        elif strategy == RoutingStrategy.PRIORITY_BASED:
            return self._priority_based_selection(available_agents)
        elif strategy == RoutingStrategy.RANDOM:
            return self._random_selection(available_agents)
        else:
            return available_agents[0]

    def _round_robin_selection(self, agents: List[str]) -> str:
        """Cycle through agents in order."""
        agent = agents[self.load_balancer_index % len(agents)]
        self.load_balancer_index += 1
        return agent

    def _load_balanced_selection(self, agents: List[str]) -> str:
        """Pick the agent with the lowest weight-adjusted load."""
        min_load = float('inf')
        selected_agent = None

        for agent in agents:
            load = self.agent_loads.get(agent, 0.0)
            weight = self.agent_weights.get(agent, 1.0)
            weighted_load = load / weight

            if weighted_load < min_load:
                min_load = weighted_load
                selected_agent = agent

        return selected_agent or agents[0]

    def _priority_based_selection(self, agents: List[str]) -> str:
        """Pick the agent with the highest weight."""
        return max(agents, key=lambda a: self.agent_weights.get(a, 1.0))

    def _random_selection(self, agents: List[str]) -> str:
        """Pick an agent uniformly at random."""
        import random
        return random.choice(agents)

class MessageQueue:
    """Priority message queue with optional message persistence."""

    def __init__(self, max_size: int = 10000):
        self.max_size = max_size
        # FIX: the NORMAL and LOW queues were constructed with the undefined
        # name `maxsize` (`asyncio.Queue(maxsize // 2)`), a NameError at
        # construction time; the parameter is `max_size`.
        self.queues: Dict[Priority, asyncio.Queue] = {
            Priority.CRITICAL: asyncio.Queue(maxsize=max_size // 4),
            Priority.HIGH: asyncio.Queue(maxsize=max_size // 4),
            Priority.NORMAL: asyncio.Queue(maxsize=max_size // 2),
            Priority.LOW: asyncio.Queue(maxsize=max_size // 4)
        }
        # message_id -> message, for at-least-once/exactly-once/persistent
        self.message_store: Dict[str, AgentMessage] = {}
        self.delivery_confirmations: Dict[str, bool] = {}

    async def enqueue(self, message: AgentMessage, delivery_mode: DeliveryMode = DeliveryMode.AT_LEAST_ONCE) -> bool:
        """Enqueue message into its priority queue; False when that queue is full.

        FIX: used `await queue.put(...)`, which waits when the queue is full
        and never raises QueueFull, so the rejection branch below was dead
        code; put_nowait restores the intended reject-on-full behavior.
        """
        try:
            # Store message for persistence
            if delivery_mode in [DeliveryMode.AT_LEAST_ONCE, DeliveryMode.EXACTLY_ONCE, DeliveryMode.PERSISTENT]:
                self.message_store[message.id] = message

            # Add to appropriate priority queue
            queue = self.queues[message.priority]
            queue.put_nowait(message)

            logger.debug(f"Enqueued message {message.id} with priority {message.priority}")
            return True

        except asyncio.QueueFull:
            logger.error(f"Queue full, cannot enqueue message {message.id}")
            return False

    async def dequeue(self) -> Optional[AgentMessage]:
        """Dequeue the next message in strict priority order, or None."""
        for priority in [Priority.CRITICAL, Priority.HIGH, Priority.NORMAL, Priority.LOW]:
            queue = self.queues[priority]
            try:
                message = queue.get_nowait()
                logger.debug(f"Dequeued message {message.id} with priority {priority}")
                return message
            except asyncio.QueueEmpty:
                continue

        return None

    async def confirm_delivery(self, message_id: str):
        """Confirm message delivery and release the persisted copy."""
        self.delivery_confirmations[message_id] = True

        # Clean up if exactly once delivery
        if message_id in self.message_store:
            del self.message_store[message_id]

    def get_queue_stats(self) -> Dict[str, Any]:
        """Get queue statistics."""
        return {
            "queue_sizes": {
                priority.value: queue.qsize()
                for priority, queue in self.queues.items()
            },
            "stored_messages": len(self.message_store),
            "delivery_confirmations": len(self.delivery_confirmations),
            "max_size": self.max_size
        }

class MessageProcessor:
    """Message processor with async handling: route, dispatch, account."""

    def __init__(self, agent_id: str):
        self.agent_id = agent_id
        self.router = MessageRouter(agent_id)
        self.load_balancer = LoadBalancer()
        self.message_queue = MessageQueue()
        # message_type value -> async processor callable
        self.processors: Dict[str, Callable] = {}
        self.processing_stats: Dict[str, Any] = {
            "messages_processed": 0,
            "processing_time_total": 0.0,
            "errors": 0
        }

    def register_processor(self, message_type: MessageType, processor: Callable):
        """Register the async processor for a message type."""
        self.processors[message_type.value] = processor
        logger.info(f"Registered processor for {message_type.value}")

    async def process_message(self, message: AgentMessage) -> bool:
        """Route then process a message; returns True only on full success."""
        start_time = datetime.utcnow()

        try:
            # Route message
            route = await self.router.route_message(message)
            if not route:
                logger.warning(f"No route found for message {message.id}")
                return False

            # Process message
            processor = self.processors.get(message.message_type.value)
            if processor:
                await processor(message)
            else:
                logger.warning(f"No processor found for {message.message_type.value}")
                return False

            # Update stats
            self.processing_stats["messages_processed"] += 1
            processing_time = (datetime.utcnow() - start_time).total_seconds()
            self.processing_stats["processing_time_total"] += processing_time

            return True

        except Exception as e:
            logger.error(f"Error processing message {message.id}: {e}")
            self.processing_stats["errors"] += 1
            return False

    async def start_processing(self):
        """Run the processing loop forever, draining the priority queue."""
        while True:
            try:
                message = await self.message_queue.dequeue()
                if message:
                    await self.process_message(message)
                else:
                    await asyncio.sleep(0.01)  # Small delay if no messages

            except Exception as e:
                logger.error(f"Error in processing loop: {e}")
                await asyncio.sleep(1)

    def get_processing_stats(self) -> Dict[str, Any]:
        """Aggregate processing, queue and routing statistics.

        FIX: previously called the async `get_routing_stats()` without
        awaiting, storing an un-awaited coroutine object in the dict.
        """
        total_processed = self.processing_stats["messages_processed"]
        avg_processing_time = (
            self.processing_stats["processing_time_total"] / total_processed
            if total_processed > 0 else 0
        )

        return {
            **self.processing_stats,
            "avg_processing_time": avg_processing_time,
            "queue_stats": self.message_queue.get_queue_stats(),
            "routing_stats": self.router.collect_routing_stats()
        }

# Factory functions for creating message types
def create_task_message(sender_id: str, receiver_id: str, task_type: str, task_data: Dict[str, Any]) -> AgentMessage:
    """Wrap a new TaskMessage in an AgentMessage addressed to `receiver_id`."""
    task_msg = TaskMessage(
        task_id=str(uuid.uuid4()),
        task_type=task_type,
        task_data=task_data
    )

    return AgentMessage(
        sender_id=sender_id,
        receiver_id=receiver_id,
        message_type=MessageType.TASK_ASSIGNMENT,
        payload=task_msg.dict()
    )

def create_coordination_message(sender_id: str, coordination_type: str, participants: List[str], data: Dict[str, Any]) -> AgentMessage:
    """Wrap a new CoordinationMessage in an AgentMessage."""
    coord_msg = CoordinationMessage(
        coordination_id=str(uuid.uuid4()),
        coordination_type=coordination_type,
        participants=participants,
        coordination_data=data
    )

    return AgentMessage(
        sender_id=sender_id,
        message_type=MessageType.COORDINATION,
        payload=coord_msg.dict()
    )

def create_status_message(agent_id: str, status_type: str, status_data: Dict[str, Any]) -> AgentMessage:
    """Wrap a new StatusMessage in an AgentMessage."""
    status_msg = StatusMessage(
        agent_id=agent_id,
        status_type=status_type,
        status_data=status_data
    )

    return AgentMessage(
        sender_id=agent_id,
        message_type=MessageType.STATUS_UPDATE,
        payload=status_msg.dict()
    )

def create_discovery_message(agent_id: str, agent_type: str, capabilities: List[str], services: List[str]) -> AgentMessage:
    """Wrap a new DiscoveryMessage in an AgentMessage."""
    discovery_msg = DiscoveryMessage(
        agent_id=agent_id,
        agent_type=agent_type,
        capabilities=capabilities,
        services=services
    )

    return AgentMessage(
        sender_id=agent_id,
        message_type=MessageType.DISCOVERY,
        payload=discovery_msg.dict()
    )

def create_consensus_message(sender_id: str, proposal: Dict[str, Any], voting_options: List[Dict[str, Any]], deadline: datetime) -> AgentMessage:
    """Wrap a new ConsensusMessage in an AgentMessage."""
    consensus_msg = ConsensusMessage(
        consensus_id=str(uuid.uuid4()),
        proposal=proposal,
        voting_options=voting_options,
        voting_deadline=deadline
    )

    return AgentMessage(
        sender_id=sender_id,
        message_type=MessageType.CONSENSUS,
        payload=consensus_msg.dict()
    )

# Example usage
async def example_usage():
    """Example of how to use the message routing system."""

    # Create message processor
    processor = MessageProcessor("agent-001")

    # Register processors
    async def process_task(message: AgentMessage):
        task_data = TaskMessage(**message.payload)
        logger.info(f"Processing task: {task_data.task_id}")

    processor.register_processor(MessageType.TASK_ASSIGNMENT, process_task)

    # Create and route message
    task_message = create_task_message(
        sender_id="agent-001",
        receiver_id="agent-002",
        task_type="data_processing",
        task_data={"input": "test_data"}
    )

    await processor.message_queue.enqueue(task_message)

    # Start processing (in real implementation, this would run in background)
    # await processor.start_processing()

if __name__ == "__main__":
    asyncio.run(example_usage())
# --- apps/agent-coordinator/src/app/routing/agent_discovery.py (new file) ---

"""
Agent Discovery and Registration System for AITBC Agent Coordination
"""

import asyncio
import json
import logging
from typing import Dict, List, Optional, Set, Callable, Any
from dataclasses import dataclass, field
from datetime import datetime, timedelta
import uuid
import hashlib
from enum import Enum
import redis.asyncio as redis
from pydantic import BaseModel, Field

from ..protocols.message_types import DiscoveryMessage, create_discovery_message
from ..protocols.communication import AgentMessage, MessageType

logger = logging.getLogger(__name__)

class AgentStatus(str, Enum):
    """Agent status enumeration."""
    ACTIVE = "active"
    INACTIVE = "inactive"
    BUSY = "busy"
    MAINTENANCE = "maintenance"
    ERROR = "error"

class AgentType(str, Enum):
    """Agent type enumeration."""
    COORDINATOR = "coordinator"
    WORKER = "worker"
    SPECIALIST = "specialist"
    MONITOR = "monitor"
    GATEWAY = "gateway"
    ORCHESTRATOR = "orchestrator"

@dataclass
class AgentInfo:
    """Agent information structure held by the registry."""
    agent_id: str
    agent_type: AgentType
    status: AgentStatus
    capabilities: List[str]
    services: List[str]
    endpoints: Dict[str, str]
    metadata: Dict[str, Any]
    last_heartbeat: datetime
    registration_time: datetime
    load_metrics: Dict[str, float] = field(default_factory=dict)
    health_score: float = 1.0
    version: str = "1.0.0"
    tags: Set[str] = field(default_factory=set)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dictionary."""
        return {
            "agent_id": self.agent_id,
            "agent_type": self.agent_type.value,
            "status": self.status.value,
            "capabilities": self.capabilities,
            "services": self.services,
            "endpoints": self.endpoints,
            "metadata": self.metadata,
            "last_heartbeat": self.last_heartbeat.isoformat(),
            "registration_time": self.registration_time.isoformat(),
            "load_metrics": self.load_metrics,
            "health_score": self.health_score,
            "version": self.version,
            "tags": list(self.tags)
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AgentInfo":
        """Create an AgentInfo from its dict form (inverse of to_dict).

        FIX: previously mutated the caller's dictionary in place while
        converting fields; work on a shallow copy instead.
        """
        data = dict(data)
        data["agent_type"] = AgentType(data["agent_type"])
        data["status"] = AgentStatus(data["status"])
        data["last_heartbeat"] = datetime.fromisoformat(data["last_heartbeat"])
        data["registration_time"] = datetime.fromisoformat(data["registration_time"])
        data["tags"] = set(data.get("tags", []))
        return cls(**data)

class AgentRegistry:
    """Central agent registry for discovery and management.

    Keeps an in-memory map of agents mirrored to Redis, plus secondary
    indexes by service, capability, and agent type.
    """

    def __init__(self, redis_url: str = "redis://localhost:6379/1"):
        self.redis_url = redis_url
        self.redis_client: Optional[redis.Redis] = None
        self.agents: Dict[str, AgentInfo] = {}
        self.service_index: Dict[str, Set[str]] = {}  # service -> agent_ids
        self.capability_index: Dict[str, Set[str]] = {}  # capability -> agent_ids
        self.type_index: Dict[AgentType, Set[str]] = {}  # agent_type -> agent_ids
        self.heartbeat_interval = 30  # seconds
        self.cleanup_interval = 60  # seconds
        self.max_heartbeat_age = 120  # seconds
        # FIX: keep strong references to background tasks; bare
        # asyncio.create_task() results can be garbage-collected mid-run.
        self._background_tasks: Set[asyncio.Task] = set()

    async def start(self):
        """Start the registry service: connect, reload state, spawn monitors."""
        self.redis_client = redis.from_url(self.redis_url)

        # Load existing agents from Redis
        await self._load_agents_from_redis()

        # Start background tasks, retaining references (see __init__)
        for coro in (self._heartbeat_monitor(), self._cleanup_inactive_agents()):
            task = asyncio.create_task(coro)
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)

        logger.info("Agent registry started")

    async def stop(self):
        """Stop the registry service, cancelling background monitors."""
        for task in self._background_tasks:
            task.cancel()
        if self.redis_client:
            await self.redis_client.close()
        logger.info("Agent registry stopped")

    async def register_agent(self, agent_info: AgentInfo) -> bool:
        """Register a new agent; returns False (and logs) on any failure."""
        try:
            # Add to local registry
            self.agents[agent_info.agent_id] = agent_info

            # Update indexes
            self._update_indexes(agent_info)

            # Save to Redis
            await self._save_agent_to_redis(agent_info)

            # Publish registration event
            await self._publish_agent_event("agent_registered", agent_info)

            logger.info(f"Agent {agent_info.agent_id} registered successfully")
            return True

        except Exception as e:
            logger.error(f"Error registering agent {agent_info.agent_id}: {e}")
            return False

    async def unregister_agent(self, agent_id: str) -> bool:
        """Unregister an agent; False when unknown or on failure."""
        try:
            if agent_id not in self.agents:
                logger.warning(f"Agent {agent_id} not found for unregistration")
                return False

            agent_info = self.agents[agent_id]

            # Remove from local registry
            del self.agents[agent_id]

            # Update indexes
            self._remove_from_indexes(agent_info)

            # Remove from Redis
            await self._remove_agent_from_redis(agent_id)

            # Publish unregistration event
            await self._publish_agent_event("agent_unregistered", agent_info)

            logger.info(f"Agent {agent_id} unregistered successfully")
            return True

        except Exception as e:
            logger.error(f"Error unregistering agent {agent_id}: {e}")
            return False

    async def update_agent_status(self, agent_id: str, status: AgentStatus, load_metrics: Optional[Dict[str, float]] = None) -> bool:
        """Update an agent's status, heartbeat, metrics and health score."""
        try:
            if agent_id not in self.agents:
                logger.warning(f"Agent {agent_id} not found for status update")
                return False

            agent_info = self.agents[agent_id]
            agent_info.status = status
            agent_info.last_heartbeat = datetime.utcnow()

            if load_metrics:
                agent_info.load_metrics.update(load_metrics)

            # Update health score
            agent_info.health_score = self._calculate_health_score(agent_info)

            # Save to Redis
            await self._save_agent_to_redis(agent_info)

            # Publish status update event
            await self._publish_agent_event("agent_status_updated", agent_info)

            return True

        except Exception as e:
            logger.error(f"Error updating agent status {agent_id}: {e}")
            return False
+ + async def update_agent_heartbeat(self, agent_id: str) -> bool: + """Update agent heartbeat""" + try: + if agent_id not in self.agents: + logger.warning(f"Agent {agent_id} not found for heartbeat") + return False + + agent_info = self.agents[agent_id] + agent_info.last_heartbeat = datetime.utcnow() + + # Update health score + agent_info.health_score = self._calculate_health_score(agent_info) + + # Save to Redis + await self._save_agent_to_redis(agent_info) + + return True + + except Exception as e: + logger.error(f"Error updating heartbeat for {agent_id}: {e}") + return False + + async def discover_agents(self, query: Dict[str, Any]) -> List[AgentInfo]: + """Discover agents based on query criteria""" + results = [] + + try: + # Start with all agents + candidate_agents = list(self.agents.values()) + + # Apply filters + if "agent_type" in query: + agent_type = AgentType(query["agent_type"]) + candidate_agents = [a for a in candidate_agents if a.agent_type == agent_type] + + if "status" in query: + status = AgentStatus(query["status"]) + candidate_agents = [a for a in candidate_agents if a.status == status] + + if "capabilities" in query: + required_capabilities = set(query["capabilities"]) + candidate_agents = [a for a in candidate_agents if required_capabilities.issubset(a.capabilities)] + + if "services" in query: + required_services = set(query["services"]) + candidate_agents = [a for a in candidate_agents if required_services.issubset(a.services)] + + if "tags" in query: + required_tags = set(query["tags"]) + candidate_agents = [a for a in candidate_agents if required_tags.issubset(a.tags)] + + if "min_health_score" in query: + min_score = query["min_health_score"] + candidate_agents = [a for a in candidate_agents if a.health_score >= min_score] + + # Sort by health score (highest first) + results = sorted(candidate_agents, key=lambda a: a.health_score, reverse=True) + + # Limit results if specified + if "limit" in query: + results = results[:query["limit"]] + 
+ logger.info(f"Discovered {len(results)} agents for query: {query}") + return results + + except Exception as e: + logger.error(f"Error discovering agents: {e}") + return [] + + async def get_agent_by_id(self, agent_id: str) -> Optional[AgentInfo]: + """Get agent information by ID""" + return self.agents.get(agent_id) + + async def get_agents_by_service(self, service: str) -> List[AgentInfo]: + """Get agents that provide a specific service""" + agent_ids = self.service_index.get(service, set()) + return [self.agents[agent_id] for agent_id in agent_ids if agent_id in self.agents] + + async def get_agents_by_capability(self, capability: str) -> List[AgentInfo]: + """Get agents that have a specific capability""" + agent_ids = self.capability_index.get(capability, set()) + return [self.agents[agent_id] for agent_id in agent_ids if agent_id in self.agents] + + async def get_agents_by_type(self, agent_type: AgentType) -> List[AgentInfo]: + """Get agents of a specific type""" + agent_ids = self.type_index.get(agent_type, set()) + return [self.agents[agent_id] for agent_id in agent_ids if agent_id in self.agents] + + async def get_registry_stats(self) -> Dict[str, Any]: + """Get registry statistics""" + total_agents = len(self.agents) + status_counts = {} + type_counts = {} + + for agent_info in self.agents.values(): + # Count by status + status = agent_info.status.value + status_counts[status] = status_counts.get(status, 0) + 1 + + # Count by type + agent_type = agent_info.agent_type.value + type_counts[agent_type] = type_counts.get(agent_type, 0) + 1 + + return { + "total_agents": total_agents, + "status_counts": status_counts, + "type_counts": type_counts, + "service_count": len(self.service_index), + "capability_count": len(self.capability_index), + "last_cleanup": datetime.utcnow().isoformat() + } + + def _update_indexes(self, agent_info: AgentInfo): + """Update search indexes""" + # Service index + for service in agent_info.services: + if service not in 
self.service_index: + self.service_index[service] = set() + self.service_index[service].add(agent_info.agent_id) + + # Capability index + for capability in agent_info.capabilities: + if capability not in self.capability_index: + self.capability_index[capability] = set() + self.capability_index[capability].add(agent_info.agent_id) + + # Type index + if agent_info.agent_type not in self.type_index: + self.type_index[agent_info.agent_type] = set() + self.type_index[agent_info.agent_type].add(agent_info.agent_id) + + def _remove_from_indexes(self, agent_info: AgentInfo): + """Remove agent from search indexes""" + # Service index + for service in agent_info.services: + if service in self.service_index: + self.service_index[service].discard(agent_info.agent_id) + if not self.service_index[service]: + del self.service_index[service] + + # Capability index + for capability in agent_info.capabilities: + if capability in self.capability_index: + self.capability_index[capability].discard(agent_info.agent_id) + if not self.capability_index[capability]: + del self.capability_index[capability] + + # Type index + if agent_info.agent_type in self.type_index: + self.type_index[agent_info.agent_type].discard(agent_info.agent_id) + if not self.type_index[agent_info.agent_type]: + del self.type_index[agent_info.agent_type] + + def _calculate_health_score(self, agent_info: AgentInfo) -> float: + """Calculate agent health score""" + base_score = 1.0 + + # Penalty for high load + if agent_info.load_metrics: + avg_load = sum(agent_info.load_metrics.values()) / len(agent_info.load_metrics) + if avg_load > 0.8: + base_score -= 0.3 + elif avg_load > 0.6: + base_score -= 0.1 + + # Penalty for error status + if agent_info.status == AgentStatus.ERROR: + base_score -= 0.5 + elif agent_info.status == AgentStatus.MAINTENANCE: + base_score -= 0.2 + elif agent_info.status == AgentStatus.BUSY: + base_score -= 0.1 + + # Penalty for old heartbeat + heartbeat_age = (datetime.utcnow() - 
agent_info.last_heartbeat).total_seconds() + if heartbeat_age > self.max_heartbeat_age: + base_score -= 0.5 + elif heartbeat_age > self.max_heartbeat_age / 2: + base_score -= 0.2 + + return max(0.0, min(1.0, base_score)) + + async def _save_agent_to_redis(self, agent_info: AgentInfo): + """Save agent information to Redis""" + if not self.redis_client: + return + + key = f"agent:{agent_info.agent_id}" + await self.redis_client.setex( + key, + timedelta(hours=24), # 24 hour TTL + json.dumps(agent_info.to_dict()) + ) + + async def _remove_agent_from_redis(self, agent_id: str): + """Remove agent from Redis""" + if not self.redis_client: + return + + key = f"agent:{agent_id}" + await self.redis_client.delete(key) + + async def _load_agents_from_redis(self): + """Load agents from Redis""" + if not self.redis_client: + return + + try: + # Get all agent keys + keys = await self.redis_client.keys("agent:*") + + for key in keys: + data = await self.redis_client.get(key) + if data: + agent_info = AgentInfo.from_dict(json.loads(data)) + self.agents[agent_info.agent_id] = agent_info + self._update_indexes(agent_info) + + logger.info(f"Loaded {len(self.agents)} agents from Redis") + + except Exception as e: + logger.error(f"Error loading agents from Redis: {e}") + + async def _publish_agent_event(self, event_type: str, agent_info: AgentInfo): + """Publish agent event to Redis""" + if not self.redis_client: + return + + event = { + "event_type": event_type, + "timestamp": datetime.utcnow().isoformat(), + "agent_info": agent_info.to_dict() + } + + await self.redis_client.publish("agent_events", json.dumps(event)) + + async def _heartbeat_monitor(self): + """Monitor agent heartbeats""" + while True: + try: + await asyncio.sleep(self.heartbeat_interval) + + # Check for agents with old heartbeats + now = datetime.utcnow() + for agent_id, agent_info in list(self.agents.items()): + heartbeat_age = (now - agent_info.last_heartbeat).total_seconds() + + if heartbeat_age > 
self.max_heartbeat_age: + # Mark as inactive + if agent_info.status != AgentStatus.INACTIVE: + await self.update_agent_status(agent_id, AgentStatus.INACTIVE) + logger.warning(f"Agent {agent_id} marked as inactive due to old heartbeat") + + except Exception as e: + logger.error(f"Error in heartbeat monitor: {e}") + await asyncio.sleep(5) + + async def _cleanup_inactive_agents(self): + """Clean up inactive agents""" + while True: + try: + await asyncio.sleep(self.cleanup_interval) + + # Remove agents that have been inactive too long + now = datetime.utcnow() + max_inactive_age = timedelta(hours=1) # 1 hour + + for agent_id, agent_info in list(self.agents.items()): + if agent_info.status == AgentStatus.INACTIVE: + inactive_age = now - agent_info.last_heartbeat + if inactive_age > max_inactive_age: + await self.unregister_agent(agent_id) + logger.info(f"Removed inactive agent {agent_id}") + + except Exception as e: + logger.error(f"Error in cleanup task: {e}") + await asyncio.sleep(5) + +class AgentDiscoveryService: + """Service for agent discovery and registration""" + + def __init__(self, registry: AgentRegistry): + self.registry = registry + self.discovery_handlers: Dict[str, Callable] = {} + + def register_discovery_handler(self, handler_name: str, handler: Callable): + """Register a discovery handler""" + self.discovery_handlers[handler_name] = handler + logger.info(f"Registered discovery handler: {handler_name}") + + async def handle_discovery_request(self, message: AgentMessage) -> Optional[AgentMessage]: + """Handle agent discovery request""" + try: + discovery_data = DiscoveryMessage(**message.payload) + + # Update or register agent + agent_info = AgentInfo( + agent_id=discovery_data.agent_id, + agent_type=AgentType(discovery_data.agent_type), + status=AgentStatus.ACTIVE, + capabilities=discovery_data.capabilities, + services=discovery_data.services, + endpoints=discovery_data.endpoints, + metadata=discovery_data.metadata, + last_heartbeat=datetime.utcnow(), + 
registration_time=datetime.utcnow() + ) + + # Register or update agent + if discovery_data.agent_id in self.registry.agents: + await self.registry.update_agent_status(discovery_data.agent_id, AgentStatus.ACTIVE) + else: + await self.registry.register_agent(agent_info) + + # Send response with available agents + available_agents = await self.registry.discover_agents({ + "status": "active", + "limit": 50 + }) + + response_data = { + "discovery_agents": [agent.to_dict() for agent in available_agents], + "registry_stats": await self.registry.get_registry_stats() + } + + response = AgentMessage( + sender_id="discovery_service", + receiver_id=message.sender_id, + message_type=MessageType.DISCOVERY, + payload=response_data, + correlation_id=message.id + ) + + return response + + except Exception as e: + logger.error(f"Error handling discovery request: {e}") + return None + + async def find_best_agent(self, requirements: Dict[str, Any]) -> Optional[AgentInfo]: + """Find the best agent for given requirements""" + try: + # Build discovery query + query = {} + + if "agent_type" in requirements: + query["agent_type"] = requirements["agent_type"] + + if "capabilities" in requirements: + query["capabilities"] = requirements["capabilities"] + + if "services" in requirements: + query["services"] = requirements["services"] + + if "min_health_score" in requirements: + query["min_health_score"] = requirements["min_health_score"] + + # Discover agents + agents = await self.registry.discover_agents(query) + + if not agents: + return None + + # Select best agent (highest health score) + return agents[0] + + except Exception as e: + logger.error(f"Error finding best agent: {e}") + return None + + async def get_service_endpoints(self, service: str) -> Dict[str, List[str]]: + """Get all endpoints for a specific service""" + try: + agents = await self.registry.get_agents_by_service(service) + endpoints = {} + + for agent in agents: + for service_name, endpoint in agent.endpoints.items(): + 
if service_name not in endpoints: + endpoints[service_name] = [] + endpoints[service_name].append(endpoint) + + return endpoints + + except Exception as e: + logger.error(f"Error getting service endpoints: {e}") + return {} + +# Factory functions +def create_agent_info(agent_id: str, agent_type: str, capabilities: List[str], services: List[str], endpoints: Dict[str, str]) -> AgentInfo: + """Create agent information""" + return AgentInfo( + agent_id=agent_id, + agent_type=AgentType(agent_type), + status=AgentStatus.ACTIVE, + capabilities=capabilities, + services=services, + endpoints=endpoints, + metadata={}, + last_heartbeat=datetime.utcnow(), + registration_time=datetime.utcnow() + ) + +# Example usage +async def example_usage(): + """Example of how to use the agent discovery system""" + + # Create registry + registry = AgentRegistry() + await registry.start() + + # Create discovery service + discovery_service = AgentDiscoveryService(registry) + + # Register an agent + agent_info = create_agent_info( + agent_id="agent-001", + agent_type="worker", + capabilities=["data_processing", "analysis"], + services=["process_data", "analyze_results"], + endpoints={"http": "http://localhost:8001", "ws": "ws://localhost:8002"} + ) + + await registry.register_agent(agent_info) + + # Discover agents + agents = await registry.discover_agents({ + "capabilities": ["data_processing"], + "status": "active" + }) + + print(f"Found {len(agents)} agents") + + # Find best agent + best_agent = await discovery_service.find_best_agent({ + "capabilities": ["data_processing"], + "min_health_score": 0.8 + }) + + if best_agent: + print(f"Best agent: {best_agent.agent_id}") + + await registry.stop() + +if __name__ == "__main__": + asyncio.run(example_usage()) diff --git a/apps/agent-coordinator/src/app/routing/load_balancer.py b/apps/agent-coordinator/src/app/routing/load_balancer.py new file mode 100644 index 00000000..5308da66 --- /dev/null +++ 
b/apps/agent-coordinator/src/app/routing/load_balancer.py @@ -0,0 +1,716 @@ +""" +Load Balancer for Agent Distribution and Task Assignment +""" + +import asyncio +import json +import logging +from typing import Dict, List, Optional, Tuple, Any, Callable +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from enum import Enum +import statistics +import uuid +from collections import defaultdict, deque + +from .agent_discovery import AgentRegistry, AgentInfo, AgentStatus, AgentType +from ..protocols.message_types import TaskMessage, create_task_message +from ..protocols.communication import AgentMessage, MessageType, Priority + +logger = logging.getLogger(__name__) + +class LoadBalancingStrategy(str, Enum): + """Load balancing strategies""" + ROUND_ROBIN = "round_robin" + LEAST_CONNECTIONS = "least_connections" + LEAST_RESPONSE_TIME = "least_response_time" + WEIGHTED_ROUND_ROBIN = "weighted_round_robin" + RESOURCE_BASED = "resource_based" + CAPABILITY_BASED = "capability_based" + PREDICTIVE = "predictive" + CONSISTENT_HASH = "consistent_hash" + +class TaskPriority(str, Enum): + """Task priority levels""" + LOW = "low" + NORMAL = "normal" + HIGH = "high" + CRITICAL = "critical" + URGENT = "urgent" + +@dataclass +class LoadMetrics: + """Agent load metrics""" + cpu_usage: float = 0.0 + memory_usage: float = 0.0 + active_connections: int = 0 + pending_tasks: int = 0 + completed_tasks: int = 0 + failed_tasks: int = 0 + avg_response_time: float = 0.0 + last_updated: datetime = field(default_factory=datetime.utcnow) + + def to_dict(self) -> Dict[str, Any]: + return { + "cpu_usage": self.cpu_usage, + "memory_usage": self.memory_usage, + "active_connections": self.active_connections, + "pending_tasks": self.pending_tasks, + "completed_tasks": self.completed_tasks, + "failed_tasks": self.failed_tasks, + "avg_response_time": self.avg_response_time, + "last_updated": self.last_updated.isoformat() + } + +@dataclass +class TaskAssignment: + """Task 
assignment record""" + task_id: str + agent_id: str + assigned_at: datetime + completed_at: Optional[datetime] = None + status: str = "pending" + response_time: Optional[float] = None + success: bool = False + error_message: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + return { + "task_id": self.task_id, + "agent_id": self.agent_id, + "assigned_at": self.assigned_at.isoformat(), + "completed_at": self.completed_at.isoformat() if self.completed_at else None, + "status": self.status, + "response_time": self.response_time, + "success": self.success, + "error_message": self.error_message + } + +@dataclass +class AgentWeight: + """Agent weight for load balancing""" + agent_id: str + weight: float = 1.0 + capacity: int = 100 + performance_score: float = 1.0 + reliability_score: float = 1.0 + last_updated: datetime = field(default_factory=datetime.utcnow) + +class LoadBalancer: + """Advanced load balancer for agent distribution""" + + def __init__(self, registry: AgentRegistry): + self.registry = registry + self.strategy = LoadBalancingStrategy.LEAST_CONNECTIONS + self.agent_weights: Dict[str, AgentWeight] = {} + self.agent_metrics: Dict[str, LoadMetrics] = {} + self.task_assignments: Dict[str, TaskAssignment] = {} + self.assignment_history: deque = deque(maxlen=1000) + self.round_robin_index = 0 + self.consistent_hash_ring: Dict[int, str] = {} + self.prediction_models: Dict[str, Any] = {} + + # Statistics + self.total_assignments = 0 + self.successful_assignments = 0 + self.failed_assignments = 0 + + def set_strategy(self, strategy: LoadBalancingStrategy): + """Set load balancing strategy""" + self.strategy = strategy + logger.info(f"Load balancing strategy changed to: {strategy.value}") + + def set_agent_weight(self, agent_id: str, weight: float, capacity: int = 100): + """Set agent weight and capacity""" + self.agent_weights[agent_id] = AgentWeight( + agent_id=agent_id, + weight=weight, + capacity=capacity + ) + logger.info(f"Set weight for agent 
{agent_id}: {weight}, capacity: {capacity}") + + def update_agent_metrics(self, agent_id: str, metrics: LoadMetrics): + """Update agent load metrics""" + self.agent_metrics[agent_id] = metrics + self.agent_metrics[agent_id].last_updated = datetime.utcnow() + + # Update performance score based on metrics + self._update_performance_score(agent_id, metrics) + + def _update_performance_score(self, agent_id: str, metrics: LoadMetrics): + """Update agent performance score based on metrics""" + if agent_id not in self.agent_weights: + self.agent_weights[agent_id] = AgentWeight(agent_id=agent_id) + + weight = self.agent_weights[agent_id] + + # Calculate performance score (0.0 to 1.0) + performance_factors = [] + + # CPU usage factor (lower is better) + cpu_factor = max(0.0, 1.0 - metrics.cpu_usage) + performance_factors.append(cpu_factor) + + # Memory usage factor (lower is better) + memory_factor = max(0.0, 1.0 - metrics.memory_usage) + performance_factors.append(memory_factor) + + # Response time factor (lower is better) + if metrics.avg_response_time > 0: + response_factor = max(0.0, 1.0 - (metrics.avg_response_time / 10.0)) # 10s max + performance_factors.append(response_factor) + + # Success rate factor (higher is better) + total_tasks = metrics.completed_tasks + metrics.failed_tasks + if total_tasks > 0: + success_rate = metrics.completed_tasks / total_tasks + performance_factors.append(success_rate) + + # Update performance score + if performance_factors: + weight.performance_score = statistics.mean(performance_factors) + + # Update reliability score + if total_tasks > 10: # Only update after enough tasks + weight.reliability_score = success_rate + + async def assign_task(self, task_data: Dict[str, Any], requirements: Optional[Dict[str, Any]] = None) -> Optional[str]: + """Assign task to best available agent""" + try: + # Find eligible agents + eligible_agents = await self._find_eligible_agents(task_data, requirements) + + if not eligible_agents: + 
logger.warning("No eligible agents found for task assignment") + return None + + # Select best agent based on strategy + selected_agent = await self._select_agent(eligible_agents, task_data) + + if not selected_agent: + logger.warning("No agent selected for task assignment") + return None + + # Create task assignment + task_id = str(uuid.uuid4()) + assignment = TaskAssignment( + task_id=task_id, + agent_id=selected_agent, + assigned_at=datetime.utcnow() + ) + + # Record assignment + self.task_assignments[task_id] = assignment + self.assignment_history.append(assignment) + self.total_assignments += 1 + + # Update agent metrics + if selected_agent not in self.agent_metrics: + self.agent_metrics[selected_agent] = LoadMetrics() + + self.agent_metrics[selected_agent].pending_tasks += 1 + + logger.info(f"Task {task_id} assigned to agent {selected_agent}") + return selected_agent + + except Exception as e: + logger.error(f"Error assigning task: {e}") + self.failed_assignments += 1 + return None + + async def complete_task(self, task_id: str, success: bool, response_time: Optional[float] = None, error_message: Optional[str] = None): + """Mark task as completed""" + try: + if task_id not in self.task_assignments: + logger.warning(f"Task assignment {task_id} not found") + return + + assignment = self.task_assignments[task_id] + assignment.completed_at = datetime.utcnow() + assignment.status = "completed" + assignment.success = success + assignment.response_time = response_time + assignment.error_message = error_message + + # Update agent metrics + agent_id = assignment.agent_id + if agent_id in self.agent_metrics: + metrics = self.agent_metrics[agent_id] + metrics.pending_tasks = max(0, metrics.pending_tasks - 1) + + if success: + metrics.completed_tasks += 1 + self.successful_assignments += 1 + else: + metrics.failed_tasks += 1 + self.failed_assignments += 1 + + # Update average response time + if response_time: + total_completed = metrics.completed_tasks + 
metrics.failed_tasks + if total_completed > 0: + metrics.avg_response_time = ( + (metrics.avg_response_time * (total_completed - 1) + response_time) / total_completed + ) + + logger.info(f"Task {task_id} completed by agent {assignment.agent_id}, success: {success}") + + except Exception as e: + logger.error(f"Error completing task {task_id}: {e}") + + async def _find_eligible_agents(self, task_data: Dict[str, Any], requirements: Optional[Dict[str, Any]] = None) -> List[str]: + """Find eligible agents for task""" + try: + # Build discovery query + query = {"status": AgentStatus.ACTIVE} + + if requirements: + if "agent_type" in requirements: + query["agent_type"] = requirements["agent_type"] + + if "capabilities" in requirements: + query["capabilities"] = requirements["capabilities"] + + if "services" in requirements: + query["services"] = requirements["services"] + + if "min_health_score" in requirements: + query["min_health_score"] = requirements["min_health_score"] + + # Discover agents + agents = await self.registry.discover_agents(query) + + # Filter by capacity and load + eligible_agents = [] + for agent in agents: + agent_id = agent.agent_id + + # Check capacity + if agent_id in self.agent_weights: + weight = self.agent_weights[agent_id] + current_load = self._get_agent_load(agent_id) + + if current_load < weight.capacity: + eligible_agents.append(agent_id) + else: + # Default capacity check + metrics = self.agent_metrics.get(agent_id, LoadMetrics()) + if metrics.pending_tasks < 100: # Default capacity + eligible_agents.append(agent_id) + + return eligible_agents + + except Exception as e: + logger.error(f"Error finding eligible agents: {e}") + return [] + + def _get_agent_load(self, agent_id: str) -> int: + """Get current load for agent""" + metrics = self.agent_metrics.get(agent_id, LoadMetrics()) + return metrics.active_connections + metrics.pending_tasks + + async def _select_agent(self, eligible_agents: List[str], task_data: Dict[str, Any]) -> 
Optional[str]: + """Select best agent based on current strategy""" + if not eligible_agents: + return None + + if self.strategy == LoadBalancingStrategy.ROUND_ROBIN: + return self._round_robin_selection(eligible_agents) + elif self.strategy == LoadBalancingStrategy.LEAST_CONNECTIONS: + return self._least_connections_selection(eligible_agents) + elif self.strategy == LoadBalancingStrategy.LEAST_RESPONSE_TIME: + return self._least_response_time_selection(eligible_agents) + elif self.strategy == LoadBalancingStrategy.WEIGHTED_ROUND_ROBIN: + return self._weighted_round_robin_selection(eligible_agents) + elif self.strategy == LoadBalancingStrategy.RESOURCE_BASED: + return self._resource_based_selection(eligible_agents) + elif self.strategy == LoadBalancingStrategy.CAPABILITY_BASED: + return self._capability_based_selection(eligible_agents, task_data) + elif self.strategy == LoadBalancingStrategy.PREDICTIVE: + return self._predictive_selection(eligible_agents, task_data) + elif self.strategy == LoadBalancingStrategy.CONSISTENT_HASH: + return self._consistent_hash_selection(eligible_agents, task_data) + else: + return eligible_agents[0] + + def _round_robin_selection(self, agents: List[str]) -> str: + """Round-robin agent selection""" + agent = agents[self.round_robin_index % len(agents)] + self.round_robin_index += 1 + return agent + + def _least_connections_selection(self, agents: List[str]) -> str: + """Select agent with least connections""" + min_connections = float('inf') + selected_agent = None + + for agent_id in agents: + metrics = self.agent_metrics.get(agent_id, LoadMetrics()) + connections = metrics.active_connections + + if connections < min_connections: + min_connections = connections + selected_agent = agent_id + + return selected_agent or agents[0] + + def _least_response_time_selection(self, agents: List[str]) -> str: + """Select agent with least average response time""" + min_response_time = float('inf') + selected_agent = None + + for agent_id in agents: 
+ metrics = self.agent_metrics.get(agent_id, LoadMetrics()) + response_time = metrics.avg_response_time + + if response_time < min_response_time: + min_response_time = response_time + selected_agent = agent_id + + return selected_agent or agents[0] + + def _weighted_round_robin_selection(self, agents: List[str]) -> str: + """Weighted round-robin selection""" + # Calculate total weight + total_weight = 0 + for agent_id in agents: + weight = self.agent_weights.get(agent_id, AgentWeight(agent_id=agent_id)) + total_weight += weight.weight + + if total_weight == 0: + return agents[0] + + # Select agent based on weight + current_weight = self.round_robin_index % total_weight + accumulated_weight = 0 + + for agent_id in agents: + weight = self.agent_weights.get(agent_id, AgentWeight(agent_id=agent_id)) + accumulated_weight += weight.weight + + if current_weight < accumulated_weight: + self.round_robin_index += 1 + return agent_id + + return agents[0] + + def _resource_based_selection(self, agents: List[str]) -> str: + """Resource-based selection considering CPU and memory""" + best_score = -1 + selected_agent = None + + for agent_id in agents: + metrics = self.agent_metrics.get(agent_id, LoadMetrics()) + + # Calculate resource score (lower usage is better) + cpu_score = max(0, 100 - metrics.cpu_usage) + memory_score = max(0, 100 - metrics.memory_usage) + resource_score = (cpu_score + memory_score) / 2 + + # Apply performance weight + weight = self.agent_weights.get(agent_id, AgentWeight(agent_id=agent_id)) + final_score = resource_score * weight.performance_score + + if final_score > best_score: + best_score = final_score + selected_agent = agent_id + + return selected_agent or agents[0] + + def _capability_based_selection(self, agents: List[str], task_data: Dict[str, Any]) -> str: + """Capability-based selection considering task requirements""" + required_capabilities = task_data.get("required_capabilities", []) + + if not required_capabilities: + return agents[0] + + 
best_score = -1 + selected_agent = None + + for agent_id in agents: + agent_info = self.registry.agents.get(agent_id) + if not agent_info: + continue + + # Calculate capability match score + agent_capabilities = set(agent_info.capabilities) + required_set = set(required_capabilities) + + if required_set.issubset(agent_capabilities): + # Perfect match + capability_score = 1.0 + else: + # Partial match + intersection = required_set.intersection(agent_capabilities) + capability_score = len(intersection) / len(required_set) + + # Apply performance weight + weight = self.agent_weights.get(agent_id, AgentWeight(agent_id=agent_id)) + final_score = capability_score * weight.performance_score + + if final_score > best_score: + best_score = final_score + selected_agent = agent_id + + return selected_agent or agents[0] + + def _predictive_selection(self, agents: List[str], task_data: Dict[str, Any]) -> str: + """Predictive selection using historical performance""" + task_type = task_data.get("task_type", "unknown") + + # Calculate predicted performance for each agent + best_score = -1 + selected_agent = None + + for agent_id in agents: + # Get historical performance for this task type + score = self._calculate_predicted_score(agent_id, task_type) + + if score > best_score: + best_score = score + selected_agent = agent_id + + return selected_agent or agents[0] + + def _calculate_predicted_score(self, agent_id: str, task_type: str) -> float: + """Calculate predicted performance score for agent""" + # Simple prediction based on recent performance + weight = self.agent_weights.get(agent_id, AgentWeight(agent_id=agent_id)) + + # Base score from performance and reliability + base_score = (weight.performance_score + weight.reliability_score) / 2 + + # Adjust based on recent assignments + recent_assignments = [a for a in self.assignment_history if a.agent_id == agent_id][-10:] + if recent_assignments: + success_rate = sum(1 for a in recent_assignments if a.success) / 
len(recent_assignments) + base_score = base_score * 0.7 + success_rate * 0.3 + + return base_score + + def _consistent_hash_selection(self, agents: List[str], task_data: Dict[str, Any]) -> str: + """Consistent hash selection for sticky routing""" + # Create hash key from task data + hash_key = json.dumps(task_data, sort_keys=True) + hash_value = int(hashlib.md5(hash_key.encode()).hexdigest(), 16) + + # Build hash ring if not exists + if not self.consistent_hash_ring: + self._build_hash_ring(agents) + + # Find agent on hash ring + for hash_pos in sorted(self.consistent_hash_ring.keys()): + if hash_value <= hash_pos: + return self.consistent_hash_ring[hash_pos] + + # Wrap around + return self.consistent_hash_ring[min(self.consistent_hash_ring.keys())] + + def _build_hash_ring(self, agents: List[str]): + """Build consistent hash ring""" + self.consistent_hash_ring = {} + + for agent_id in agents: + # Create multiple virtual nodes for better distribution + for i in range(100): + virtual_key = f"{agent_id}:{i}" + hash_value = int(hashlib.md5(virtual_key.encode()).hexdigest(), 16) + self.consistent_hash_ring[hash_value] = agent_id + + def get_load_balancing_stats(self) -> Dict[str, Any]: + """Get load balancing statistics""" + return { + "strategy": self.strategy.value, + "total_assignments": self.total_assignments, + "successful_assignments": self.successful_assignments, + "failed_assignments": self.failed_assignments, + "success_rate": self.successful_assignments / max(1, self.total_assignments), + "active_agents": len(self.agent_metrics), + "agent_weights": len(self.agent_weights), + "avg_agent_load": statistics.mean([self._get_agent_load(a) for a in self.agent_metrics]) if self.agent_metrics else 0 + } + + def get_agent_stats(self, agent_id: str) -> Optional[Dict[str, Any]]: + """Get detailed statistics for a specific agent""" + if agent_id not in self.agent_metrics: + return None + + metrics = self.agent_metrics[agent_id] + weight = self.agent_weights.get(agent_id, 
AgentWeight(agent_id=agent_id)) + + # Get recent assignments + recent_assignments = [a for a in self.assignment_history if a.agent_id == agent_id][-10:] + + return { + "agent_id": agent_id, + "metrics": metrics.to_dict(), + "weight": { + "weight": weight.weight, + "capacity": weight.capacity, + "performance_score": weight.performance_score, + "reliability_score": weight.reliability_score + }, + "recent_assignments": [a.to_dict() for a in recent_assignments], + "current_load": self._get_agent_load(agent_id) + } + +class TaskDistributor: + """Task distributor with advanced load balancing""" + + def __init__(self, load_balancer: LoadBalancer): + self.load_balancer = load_balancer + self.task_queue = asyncio.Queue() + self.priority_queues = { + TaskPriority.URGENT: asyncio.Queue(), + TaskPriority.CRITICAL: asyncio.Queue(), + TaskPriority.HIGH: asyncio.Queue(), + TaskPriority.NORMAL: asyncio.Queue(), + TaskPriority.LOW: asyncio.Queue() + } + self.distribution_stats = { + "tasks_distributed": 0, + "tasks_completed": 0, + "tasks_failed": 0, + "avg_distribution_time": 0.0 + } + + async def submit_task(self, task_data: Dict[str, Any], priority: TaskPriority = TaskPriority.NORMAL, requirements: Optional[Dict[str, Any]] = None): + """Submit task for distribution""" + task_info = { + "task_data": task_data, + "priority": priority, + "requirements": requirements, + "submitted_at": datetime.utcnow() + } + + await self.priority_queues[priority].put(task_info) + logger.info(f"Task submitted with priority {priority.value}") + + async def start_distribution(self): + """Start task distribution loop""" + while True: + try: + # Check queues in priority order + task_info = None + + for priority in [TaskPriority.URGENT, TaskPriority.CRITICAL, TaskPriority.HIGH, TaskPriority.NORMAL, TaskPriority.LOW]: + queue = self.priority_queues[priority] + try: + task_info = queue.get_nowait() + break + except asyncio.QueueEmpty: + continue + + if task_info: + await self._distribute_task(task_info) + 
else: + await asyncio.sleep(0.01) # Small delay if no tasks + + except Exception as e: + logger.error(f"Error in distribution loop: {e}") + await asyncio.sleep(1) + + async def _distribute_task(self, task_info: Dict[str, Any]): + """Distribute a single task""" + start_time = datetime.utcnow() + + try: + # Assign task + agent_id = await self.load_balancer.assign_task( + task_info["task_data"], + task_info["requirements"] + ) + + if agent_id: + # Create task message + task_message = create_task_message( + sender_id="task_distributor", + receiver_id=agent_id, + task_type=task_info["task_data"].get("task_type", "unknown"), + task_data=task_info["task_data"] + ) + + # Send task to agent (implementation depends on communication system) + # await self._send_task_to_agent(agent_id, task_message) + + self.distribution_stats["tasks_distributed"] += 1 + + # Simulate task completion (in real implementation, this would be event-driven) + asyncio.create_task(self._simulate_task_completion(task_info, agent_id)) + + else: + logger.warning(f"Failed to distribute task: no suitable agent found") + self.distribution_stats["tasks_failed"] += 1 + + except Exception as e: + logger.error(f"Error distributing task: {e}") + self.distribution_stats["tasks_failed"] += 1 + + finally: + # Update distribution time + distribution_time = (datetime.utcnow() - start_time).total_seconds() + total_distributed = self.distribution_stats["tasks_distributed"] + self.distribution_stats["avg_distribution_time"] = ( + (self.distribution_stats["avg_distribution_time"] * (total_distributed - 1) + distribution_time) / total_distributed + if total_distributed > 0 else distribution_time + ) + + async def _simulate_task_completion(self, task_info: Dict[str, Any], agent_id: str): + """Simulate task completion (for testing)""" + # Simulate task processing time + processing_time = 1.0 + (hash(task_info["task_data"].get("task_id", "")) % 5) + await asyncio.sleep(processing_time) + + # Mark task as completed + success 
= hash(agent_id) % 10 > 1 # 90% success rate + await self.load_balancer.complete_task( + task_info["task_data"].get("task_id", str(uuid.uuid4())), + success, + processing_time + ) + + if success: + self.distribution_stats["tasks_completed"] += 1 + else: + self.distribution_stats["tasks_failed"] += 1 + + def get_distribution_stats(self) -> Dict[str, Any]: + """Get distribution statistics""" + return { + **self.distribution_stats, + "load_balancer_stats": self.load_balancer.get_load_balancing_stats(), + "queue_sizes": { + priority.value: queue.qsize() + for priority, queue in self.priority_queues.items() + } + } + +# Example usage +async def example_usage(): + """Example of how to use the load balancer""" + + # Create registry and load balancer + registry = AgentRegistry() + await registry.start() + + load_balancer = LoadBalancer(registry) + load_balancer.set_strategy(LoadBalancingStrategy.LEAST_CONNECTIONS) + + # Create task distributor + distributor = TaskDistributor(load_balancer) + + # Submit some tasks + for i in range(10): + await distributor.submit_task({ + "task_id": f"task-{i}", + "task_type": "data_processing", + "data": f"sample_data_{i}" + }, TaskPriority.NORMAL) + + # Start distribution (in real implementation, this would run in background) + # await distributor.start_distribution() + + await registry.stop() + +if __name__ == "__main__": + asyncio.run(example_usage()) diff --git a/apps/agent-coordinator/tests/test_communication.py b/apps/agent-coordinator/tests/test_communication.py new file mode 100644 index 00000000..71145856 --- /dev/null +++ b/apps/agent-coordinator/tests/test_communication.py @@ -0,0 +1,326 @@ +""" +Tests for Agent Communication Protocols +""" + +import pytest +import asyncio +from datetime import datetime, timedelta +from unittest.mock import Mock, AsyncMock + +from src.app.protocols.communication import ( + AgentMessage, MessageType, Priority, CommunicationProtocol, + HierarchicalProtocol, PeerToPeerProtocol, BroadcastProtocol, + 
CommunicationManager, MessageTemplates +) + +class TestAgentMessage: + """Test AgentMessage class""" + + def test_message_creation(self): + """Test message creation""" + message = AgentMessage( + sender_id="agent-001", + receiver_id="agent-002", + message_type=MessageType.DIRECT, + priority=Priority.NORMAL, + payload={"data": "test"} + ) + + assert message.sender_id == "agent-001" + assert message.receiver_id == "agent-002" + assert message.message_type == MessageType.DIRECT + assert message.priority == Priority.NORMAL + assert message.payload["data"] == "test" + assert message.ttl == 300 + + def test_message_serialization(self): + """Test message serialization""" + message = AgentMessage( + sender_id="agent-001", + receiver_id="agent-002", + message_type=MessageType.DIRECT, + priority=Priority.NORMAL, + payload={"data": "test"} + ) + + # To dict + message_dict = message.to_dict() + assert message_dict["sender_id"] == "agent-001" + assert message_dict["message_type"] == "direct" + assert message_dict["priority"] == "normal" + + # From dict + restored_message = AgentMessage.from_dict(message_dict) + assert restored_message.sender_id == message.sender_id + assert restored_message.receiver_id == message.receiver_id + assert restored_message.message_type == message.message_type + assert restored_message.priority == message.priority + + def test_message_expiration(self): + """Test message expiration""" + old_message = AgentMessage( + sender_id="agent-001", + receiver_id="agent-002", + message_type=MessageType.DIRECT, + timestamp=datetime.utcnow() - timedelta(seconds=400), + ttl=300 + ) + + # Message should be expired + age = (datetime.utcnow() - old_message.timestamp).total_seconds() + assert age > old_message.ttl + +class TestHierarchicalProtocol: + """Test HierarchicalProtocol class""" + + @pytest.fixture + def master_protocol(self): + """Create master protocol""" + return HierarchicalProtocol("master-agent", is_master=True) + + @pytest.fixture + def 
sub_protocol(self): + """Create sub-agent protocol""" + return HierarchicalProtocol("sub-agent", is_master=False) + + def test_add_sub_agent(self, master_protocol): + """Test adding sub-agent""" + master_protocol.add_sub_agent("sub-agent-001") + assert "sub-agent-001" in master_protocol.sub_agents + + def test_send_to_sub_agents(self, master_protocol): + """Test sending to sub-agents""" + master_protocol.add_sub_agent("sub-agent-001") + master_protocol.add_sub_agent("sub-agent-002") + + message = MessageTemplates.create_heartbeat("master-agent") + + # Mock the send_message method + master_protocol.send_message = AsyncMock(return_value=True) + + # Should send to both sub-agents + asyncio.run(master_protocol.send_to_sub_agents(message)) + + # Check that send_message was called twice + assert master_protocol.send_message.call_count == 2 + + def test_send_to_master(self, sub_protocol): + """Test sending to master""" + sub_protocol.master_agent = "master-agent" + + message = MessageTemplates.create_status_update("sub-agent", {"status": "active"}) + + # Mock the send_message method + sub_protocol.send_message = AsyncMock(return_value=True) + + asyncio.run(sub_protocol.send_to_master(message)) + + # Check that send_message was called once + assert sub_protocol.send_message.call_count == 1 + +class TestPeerToPeerProtocol: + """Test PeerToPeerProtocol class""" + + @pytest.fixture + def p2p_protocol(self): + """Create P2P protocol""" + return PeerToPeerProtocol("agent-001") + + def test_add_peer(self, p2p_protocol): + """Test adding peer""" + p2p_protocol.add_peer("agent-002", {"endpoint": "http://localhost:8002"}) + assert "agent-002" in p2p_protocol.peers + assert p2p_protocol.peers["agent-002"]["endpoint"] == "http://localhost:8002" + + def test_remove_peer(self, p2p_protocol): + """Test removing peer""" + p2p_protocol.add_peer("agent-002", {"endpoint": "http://localhost:8002"}) + p2p_protocol.remove_peer("agent-002") + assert "agent-002" not in p2p_protocol.peers + + def 
test_send_to_peer(self, p2p_protocol): + """Test sending to peer""" + p2p_protocol.add_peer("agent-002", {"endpoint": "http://localhost:8002"}) + + message = MessageTemplates.create_task_assignment( + "agent-001", "agent-002", {"task": "test"} + ) + + # Mock the send_message method + p2p_protocol.send_message = AsyncMock(return_value=True) + + result = asyncio.run(p2p_protocol.send_to_peer(message, "agent-002")) + + assert result is True + assert p2p_protocol.send_message.call_count == 1 + +class TestBroadcastProtocol: + """Test BroadcastProtocol class""" + + @pytest.fixture + def broadcast_protocol(self): + """Create broadcast protocol""" + return BroadcastProtocol("agent-001", "test-channel") + + def test_subscribe_unsubscribe(self, broadcast_protocol): + """Test subscribe and unsubscribe""" + broadcast_protocol.subscribe("agent-002") + assert "agent-002" in broadcast_protocol.subscribers + + broadcast_protocol.unsubscribe("agent-002") + assert "agent-002" not in broadcast_protocol.subscribers + + def test_broadcast(self, broadcast_protocol): + """Test broadcasting""" + broadcast_protocol.subscribe("agent-002") + broadcast_protocol.subscribe("agent-003") + + message = MessageTemplates.create_discovery("agent-001") + + # Mock the send_message method + broadcast_protocol.send_message = AsyncMock(return_value=True) + + asyncio.run(broadcast_protocol.broadcast(message)) + + # Should send to 2 subscribers (not including self) + assert broadcast_protocol.send_message.call_count == 2 + +class TestCommunicationManager: + """Test CommunicationManager class""" + + @pytest.fixture + def comm_manager(self): + """Create communication manager""" + return CommunicationManager("agent-001") + + def test_add_protocol(self, comm_manager): + """Test adding protocol""" + protocol = Mock(spec=CommunicationProtocol) + comm_manager.add_protocol("test", protocol) + + assert "test" in comm_manager.protocols + assert comm_manager.protocols["test"] == protocol + + def 
test_get_protocol(self, comm_manager): + """Test getting protocol""" + protocol = Mock(spec=CommunicationProtocol) + comm_manager.add_protocol("test", protocol) + + retrieved_protocol = comm_manager.get_protocol("test") + assert retrieved_protocol == protocol + + # Test non-existent protocol + assert comm_manager.get_protocol("non-existent") is None + + @pytest.mark.asyncio + async def test_send_message(self, comm_manager): + """Test sending message""" + protocol = Mock(spec=CommunicationProtocol) + protocol.send_message = AsyncMock(return_value=True) + comm_manager.add_protocol("test", protocol) + + message = MessageTemplates.create_heartbeat("agent-001") + result = await comm_manager.send_message("test", message) + + assert result is True + protocol.send_message.assert_called_once_with(message) + + @pytest.mark.asyncio + async def test_register_handler(self, comm_manager): + """Test registering handler""" + protocol = Mock(spec=CommunicationProtocol) + protocol.register_handler = AsyncMock() + comm_manager.add_protocol("test", protocol) + + handler = AsyncMock() + await comm_manager.register_handler("test", MessageType.HEARTBEAT, handler) + + protocol.register_handler.assert_called_once_with(MessageType.HEARTBEAT, handler) + +class TestMessageTemplates: + """Test MessageTemplates class""" + + def test_create_heartbeat(self): + """Test creating heartbeat message""" + message = MessageTemplates.create_heartbeat("agent-001") + + assert message.sender_id == "agent-001" + assert message.message_type == MessageType.HEARTBEAT + assert message.priority == Priority.LOW + assert "timestamp" in message.payload + + def test_create_task_assignment(self): + """Test creating task assignment message""" + task_data = {"task_id": "task-001", "task_type": "process_data"} + message = MessageTemplates.create_task_assignment("agent-001", "agent-002", task_data) + + assert message.sender_id == "agent-001" + assert message.receiver_id == "agent-002" + assert message.message_type == 
MessageType.TASK_ASSIGNMENT + assert message.payload == task_data + + def test_create_status_update(self): + """Test creating status update message""" + status_data = {"status": "active", "load": 0.5} + message = MessageTemplates.create_status_update("agent-001", status_data) + + assert message.sender_id == "agent-001" + assert message.message_type == MessageType.STATUS_UPDATE + assert message.payload == status_data + + def test_create_discovery(self): + """Test creating discovery message""" + message = MessageTemplates.create_discovery("agent-001") + + assert message.sender_id == "agent-001" + assert message.message_type == MessageType.DISCOVERY + assert message.payload["agent_id"] == "agent-001" + + def test_create_consensus_request(self): + """Test creating consensus request message""" + proposal_data = {"proposal": "test_proposal"} + message = MessageTemplates.create_consensus_request("agent-001", proposal_data) + + assert message.sender_id == "agent-001" + assert message.message_type == MessageType.CONSENSUS + assert message.priority == Priority.HIGH + assert message.payload == proposal_data + +# Integration tests +class TestCommunicationIntegration: + """Integration tests for communication system""" + + @pytest.mark.asyncio + async def test_message_flow(self): + """Test complete message flow""" + # Create communication manager + comm_manager = CommunicationManager("agent-001") + + # Create protocols + hierarchical = HierarchicalProtocol("agent-001", is_master=True) + p2p = PeerToPeerProtocol("agent-001") + + # Add protocols + comm_manager.add_protocol("hierarchical", hierarchical) + comm_manager.add_protocol("p2p", p2p) + + # Mock message sending + hierarchical.send_message = AsyncMock(return_value=True) + p2p.send_message = AsyncMock(return_value=True) + + # Register handler + async def handle_heartbeat(message): + assert message.sender_id == "agent-002" + assert message.message_type == MessageType.HEARTBEAT + + await 
comm_manager.register_handler("hierarchical", MessageType.HEARTBEAT, handle_heartbeat) + + # Send heartbeat + heartbeat = MessageTemplates.create_heartbeat("agent-001") + result = await comm_manager.send_message("hierarchical", heartbeat) + + assert result is True + hierarchical.send_message.assert_called_once() + +if __name__ == "__main__": + pytest.main([__file__])