feat: implement agent coordination foundation (Week 1)
✅ Multi-Agent Communication Framework - Implemented comprehensive communication protocols - Created hierarchical, P2P, and broadcast protocols - Added message types and routing system - Implemented agent discovery and registration - Created load balancer for task distribution - Built FastAPI application with full API ✅ Core Components Implemented - CommunicationManager: Protocol management - MessageRouter: Advanced message routing - AgentRegistry: Agent discovery and management - LoadBalancer: Intelligent task distribution - TaskDistributor: Priority-based task handling - WebSocketHandler: Real-time communication ✅ API Endpoints - /health: Health check endpoint - /agents/register: Agent registration - /agents/discover: Agent discovery - /tasks/submit: Task submission - /messages/send: Message sending - /load-balancer/stats: Load balancing statistics - /registry/stats: Registry statistics ✅ Production Ready - SystemD service configuration - Docker containerization - Comprehensive test suite - Configuration management - Error handling and logging - Performance monitoring 🚀 Week 1 complete: Agent coordination foundation implemented!
This commit is contained in:
39
apps/agent-coordinator/Dockerfile
Normal file
39
apps/agent-coordinator/Dockerfile
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
FROM python:3.11-slim
|
||||||
|
|
||||||
|
# Set working directory
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Set environment variables
|
||||||
|
ENV PYTHONDONTWRITEBYTECODE=1
|
||||||
|
ENV PYTHONUNBUFFERED=1
|
||||||
|
ENV PYTHONPATH=/app/src
|
||||||
|
|
||||||
|
# Install system dependencies
|
||||||
|
RUN apt-get update && apt-get install -y \
|
||||||
|
gcc \
|
||||||
|
g++ \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Install Python dependencies
|
||||||
|
COPY pyproject.toml poetry.lock ./
|
||||||
|
RUN pip install poetry && \
|
||||||
|
poetry config virtualenvs.create false && \
|
||||||
|
poetry install --no-dev --no-interaction --no-ansi
|
||||||
|
|
||||||
|
# Copy application code
|
||||||
|
COPY src/ ./src/
|
||||||
|
|
||||||
|
# Create non-root user
|
||||||
|
RUN useradd --create-home --shell /bin/bash app && \
|
||||||
|
chown -R app:app /app
|
||||||
|
USER app
|
||||||
|
|
||||||
|
# Health check
|
||||||
|
HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
|
||||||
|
CMD curl -f http://localhost:9001/health || exit 1
|
||||||
|
|
||||||
|
# Expose port
|
||||||
|
EXPOSE 9001
|
||||||
|
|
||||||
|
# Start the application
|
||||||
|
CMD ["poetry", "run", "python", "-m", "uvicorn", "src.app.main:app", "--host", "0.0.0.0", "--port", "9001"]
|
||||||
460
apps/agent-coordinator/src/app/config.py
Normal file
460
apps/agent-coordinator/src/app/config.py
Normal file
@@ -0,0 +1,460 @@
|
|||||||
|
"""
|
||||||
|
Configuration Management for AITBC Agent Coordinator
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
from pydantic import BaseSettings, Field
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
class Environment(str, Enum):
|
||||||
|
"""Environment types"""
|
||||||
|
DEVELOPMENT = "development"
|
||||||
|
TESTING = "testing"
|
||||||
|
STAGING = "staging"
|
||||||
|
PRODUCTION = "production"
|
||||||
|
|
||||||
|
class LogLevel(str, Enum):
|
||||||
|
"""Log levels"""
|
||||||
|
DEBUG = "DEBUG"
|
||||||
|
INFO = "INFO"
|
||||||
|
WARNING = "WARNING"
|
||||||
|
ERROR = "ERROR"
|
||||||
|
CRITICAL = "CRITICAL"
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
"""Application settings"""
|
||||||
|
|
||||||
|
# Application settings
|
||||||
|
app_name: str = "AITBC Agent Coordinator"
|
||||||
|
app_version: str = "1.0.0"
|
||||||
|
environment: Environment = Environment.DEVELOPMENT
|
||||||
|
debug: bool = False
|
||||||
|
|
||||||
|
# Server settings
|
||||||
|
host: str = "0.0.0.0"
|
||||||
|
port: int = 9001
|
||||||
|
workers: int = 1
|
||||||
|
|
||||||
|
# Redis settings
|
||||||
|
redis_url: str = "redis://localhost:6379/1"
|
||||||
|
redis_max_connections: int = 10
|
||||||
|
redis_timeout: int = 5
|
||||||
|
|
||||||
|
# Database settings (if needed)
|
||||||
|
database_url: Optional[str] = None
|
||||||
|
|
||||||
|
# Agent registry settings
|
||||||
|
heartbeat_interval: int = 30 # seconds
|
||||||
|
max_heartbeat_age: int = 120 # seconds
|
||||||
|
cleanup_interval: int = 60 # seconds
|
||||||
|
agent_ttl: int = 86400 # 24 hours in seconds
|
||||||
|
|
||||||
|
# Load balancer settings
|
||||||
|
default_strategy: str = "least_connections"
|
||||||
|
max_task_queue_size: int = 10000
|
||||||
|
task_timeout: int = 300 # 5 minutes
|
||||||
|
|
||||||
|
# Communication settings
|
||||||
|
message_ttl: int = 300 # 5 minutes
|
||||||
|
max_message_size: int = 1024 * 1024 # 1MB
|
||||||
|
connection_timeout: int = 30
|
||||||
|
|
||||||
|
# Security settings
|
||||||
|
secret_key: str = "your-secret-key-change-in-production"
|
||||||
|
allowed_hosts: list = ["*"]
|
||||||
|
cors_origins: list = ["*"]
|
||||||
|
|
||||||
|
# Monitoring settings
|
||||||
|
enable_metrics: bool = True
|
||||||
|
metrics_port: int = 9002
|
||||||
|
health_check_interval: int = 30
|
||||||
|
|
||||||
|
# Logging settings
|
||||||
|
log_level: LogLevel = LogLevel.INFO
|
||||||
|
log_format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
|
log_file: Optional[str] = None
|
||||||
|
|
||||||
|
# Performance settings
|
||||||
|
max_concurrent_tasks: int = 100
|
||||||
|
task_batch_size: int = 10
|
||||||
|
load_balancer_cache_size: int = 1000
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
env_file = ".env"
|
||||||
|
env_file_encoding = "utf-8"
|
||||||
|
case_sensitive = False
|
||||||
|
|
||||||
|
# Global settings instance
|
||||||
|
settings = Settings()
|
||||||
|
|
||||||
|
# Configuration constants
|
||||||
|
class ConfigConstants:
|
||||||
|
"""Configuration constants"""
|
||||||
|
|
||||||
|
# Agent types
|
||||||
|
AGENT_TYPES = [
|
||||||
|
"coordinator",
|
||||||
|
"worker",
|
||||||
|
"specialist",
|
||||||
|
"monitor",
|
||||||
|
"gateway",
|
||||||
|
"orchestrator"
|
||||||
|
]
|
||||||
|
|
||||||
|
# Agent statuses
|
||||||
|
AGENT_STATUSES = [
|
||||||
|
"active",
|
||||||
|
"inactive",
|
||||||
|
"busy",
|
||||||
|
"maintenance",
|
||||||
|
"error"
|
||||||
|
]
|
||||||
|
|
||||||
|
# Message types
|
||||||
|
MESSAGE_TYPES = [
|
||||||
|
"coordination",
|
||||||
|
"task_assignment",
|
||||||
|
"status_update",
|
||||||
|
"discovery",
|
||||||
|
"heartbeat",
|
||||||
|
"consensus",
|
||||||
|
"broadcast",
|
||||||
|
"direct",
|
||||||
|
"peer_to_peer",
|
||||||
|
"hierarchical"
|
||||||
|
]
|
||||||
|
|
||||||
|
# Task priorities
|
||||||
|
TASK_PRIORITIES = [
|
||||||
|
"low",
|
||||||
|
"normal",
|
||||||
|
"high",
|
||||||
|
"critical",
|
||||||
|
"urgent"
|
||||||
|
]
|
||||||
|
|
||||||
|
# Load balancing strategies
|
||||||
|
LOAD_BALANCING_STRATEGIES = [
|
||||||
|
"round_robin",
|
||||||
|
"least_connections",
|
||||||
|
"least_response_time",
|
||||||
|
"weighted_round_robin",
|
||||||
|
"resource_based",
|
||||||
|
"capability_based",
|
||||||
|
"predictive",
|
||||||
|
"consistent_hash"
|
||||||
|
]
|
||||||
|
|
||||||
|
# Default ports
|
||||||
|
DEFAULT_PORTS = {
|
||||||
|
"agent_coordinator": 9001,
|
||||||
|
"agent_registry": 9002,
|
||||||
|
"task_distributor": 9003,
|
||||||
|
"metrics": 9004,
|
||||||
|
"health": 9005
|
||||||
|
}
|
||||||
|
|
||||||
|
# Timeouts (in seconds)
|
||||||
|
TIMEOUTS = {
|
||||||
|
"connection": 30,
|
||||||
|
"message": 300,
|
||||||
|
"task": 600,
|
||||||
|
"heartbeat": 120,
|
||||||
|
"cleanup": 3600
|
||||||
|
}
|
||||||
|
|
||||||
|
# Limits
|
||||||
|
LIMITS = {
|
||||||
|
"max_message_size": 1024 * 1024, # 1MB
|
||||||
|
"max_task_queue_size": 10000,
|
||||||
|
"max_concurrent_tasks": 100,
|
||||||
|
"max_agent_connections": 1000,
|
||||||
|
"max_redis_connections": 10
|
||||||
|
}
|
||||||
|
|
||||||
|
# Environment-specific configurations
|
||||||
|
class EnvironmentConfig:
|
||||||
|
"""Environment-specific configurations"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_development_config() -> Dict[str, Any]:
|
||||||
|
"""Development environment configuration"""
|
||||||
|
return {
|
||||||
|
"debug": True,
|
||||||
|
"log_level": LogLevel.DEBUG,
|
||||||
|
"reload": True,
|
||||||
|
"workers": 1,
|
||||||
|
"redis_url": "redis://localhost:6379/1",
|
||||||
|
"enable_metrics": True
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_testing_config() -> Dict[str, Any]:
|
||||||
|
"""Testing environment configuration"""
|
||||||
|
return {
|
||||||
|
"debug": True,
|
||||||
|
"log_level": LogLevel.DEBUG,
|
||||||
|
"redis_url": "redis://localhost:6379/15", # Separate DB for testing
|
||||||
|
"enable_metrics": False,
|
||||||
|
"heartbeat_interval": 5, # Faster for testing
|
||||||
|
"cleanup_interval": 10
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_staging_config() -> Dict[str, Any]:
|
||||||
|
"""Staging environment configuration"""
|
||||||
|
return {
|
||||||
|
"debug": False,
|
||||||
|
"log_level": LogLevel.INFO,
|
||||||
|
"redis_url": "redis://localhost:6379/2",
|
||||||
|
"enable_metrics": True,
|
||||||
|
"workers": 2,
|
||||||
|
"cors_origins": ["https://staging.aitbc.com"]
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_production_config() -> Dict[str, Any]:
|
||||||
|
"""Production environment configuration"""
|
||||||
|
return {
|
||||||
|
"debug": False,
|
||||||
|
"log_level": LogLevel.WARNING,
|
||||||
|
"redis_url": os.getenv("REDIS_URL", "redis://localhost:6379/0"),
|
||||||
|
"enable_metrics": True,
|
||||||
|
"workers": 4,
|
||||||
|
"cors_origins": ["https://aitbc.com"],
|
||||||
|
"secret_key": os.getenv("SECRET_KEY", "change-this-in-production"),
|
||||||
|
"allowed_hosts": ["aitbc.com", "www.aitbc.com"]
|
||||||
|
}
|
||||||
|
|
||||||
|
# Configuration loader
|
||||||
|
class ConfigLoader:
|
||||||
|
"""Configuration loader and validator"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_config() -> Settings:
|
||||||
|
"""Load and validate configuration"""
|
||||||
|
# Get environment-specific config
|
||||||
|
env_config = {}
|
||||||
|
if settings.environment == Environment.DEVELOPMENT:
|
||||||
|
env_config = EnvironmentConfig.get_development_config()
|
||||||
|
elif settings.environment == Environment.TESTING:
|
||||||
|
env_config = EnvironmentConfig.get_testing_config()
|
||||||
|
elif settings.environment == Environment.STAGING:
|
||||||
|
env_config = EnvironmentConfig.get_staging_config()
|
||||||
|
elif settings.environment == Environment.PRODUCTION:
|
||||||
|
env_config = EnvironmentConfig.get_production_config()
|
||||||
|
|
||||||
|
# Update settings with environment-specific config
|
||||||
|
for key, value in env_config.items():
|
||||||
|
if hasattr(settings, key):
|
||||||
|
setattr(settings, key, value)
|
||||||
|
|
||||||
|
# Validate configuration
|
||||||
|
ConfigLoader.validate_config()
|
||||||
|
|
||||||
|
return settings
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate_config():
|
||||||
|
"""Validate configuration settings"""
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
# Validate required settings
|
||||||
|
if not settings.secret_key or settings.secret_key == "your-secret-key-change-in-production":
|
||||||
|
if settings.environment == Environment.PRODUCTION:
|
||||||
|
errors.append("SECRET_KEY must be set in production")
|
||||||
|
|
||||||
|
# Validate ports
|
||||||
|
if settings.port < 1 or settings.port > 65535:
|
||||||
|
errors.append("Port must be between 1 and 65535")
|
||||||
|
|
||||||
|
# Validate Redis URL
|
||||||
|
if not settings.redis_url:
|
||||||
|
errors.append("Redis URL is required")
|
||||||
|
|
||||||
|
# Validate timeouts
|
||||||
|
if settings.heartbeat_interval <= 0:
|
||||||
|
errors.append("Heartbeat interval must be positive")
|
||||||
|
|
||||||
|
if settings.max_heartbeat_age <= settings.heartbeat_interval:
|
||||||
|
errors.append("Max heartbeat age must be greater than heartbeat interval")
|
||||||
|
|
||||||
|
# Validate limits
|
||||||
|
if settings.max_message_size <= 0:
|
||||||
|
errors.append("Max message size must be positive")
|
||||||
|
|
||||||
|
if settings.max_task_queue_size <= 0:
|
||||||
|
errors.append("Max task queue size must be positive")
|
||||||
|
|
||||||
|
# Validate strategy
|
||||||
|
if settings.default_strategy not in ConfigConstants.LOAD_BALANCING_STRATEGIES:
|
||||||
|
errors.append(f"Invalid load balancing strategy: {settings.default_strategy}")
|
||||||
|
|
||||||
|
if errors:
|
||||||
|
raise ValueError(f"Configuration validation failed: {', '.join(errors)}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_redis_config() -> Dict[str, Any]:
|
||||||
|
"""Get Redis configuration"""
|
||||||
|
return {
|
||||||
|
"url": settings.redis_url,
|
||||||
|
"max_connections": settings.redis_max_connections,
|
||||||
|
"timeout": settings.redis_timeout,
|
||||||
|
"decode_responses": True,
|
||||||
|
"socket_keepalive": True,
|
||||||
|
"socket_keepalive_options": {},
|
||||||
|
"health_check_interval": 30
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_logging_config() -> Dict[str, Any]:
|
||||||
|
"""Get logging configuration"""
|
||||||
|
return {
|
||||||
|
"version": 1,
|
||||||
|
"disable_existing_loggers": False,
|
||||||
|
"formatters": {
|
||||||
|
"default": {
|
||||||
|
"format": settings.log_format,
|
||||||
|
"datefmt": "%Y-%m-%d %H:%M:%S"
|
||||||
|
},
|
||||||
|
"detailed": {
|
||||||
|
"format": "%(asctime)s - %(name)s - %(levelname)s - %(module)s - %(funcName)s - %(message)s",
|
||||||
|
"datefmt": "%Y-%m-%d %H:%M:%S"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"handlers": {
|
||||||
|
"console": {
|
||||||
|
"class": "logging.StreamHandler",
|
||||||
|
"level": settings.log_level.value,
|
||||||
|
"formatter": "default",
|
||||||
|
"stream": "ext://sys.stdout"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"loggers": {
|
||||||
|
"": {
|
||||||
|
"level": settings.log_level.value,
|
||||||
|
"handlers": ["console"]
|
||||||
|
},
|
||||||
|
"uvicorn": {
|
||||||
|
"level": "INFO",
|
||||||
|
"handlers": ["console"],
|
||||||
|
"propagate": False
|
||||||
|
},
|
||||||
|
"fastapi": {
|
||||||
|
"level": "INFO",
|
||||||
|
"handlers": ["console"],
|
||||||
|
"propagate": False
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Configuration utilities
|
||||||
|
class ConfigUtils:
|
||||||
|
"""Configuration utilities"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_agent_config(agent_type: str) -> Dict[str, Any]:
|
||||||
|
"""Get configuration for specific agent type"""
|
||||||
|
base_config = {
|
||||||
|
"heartbeat_interval": settings.heartbeat_interval,
|
||||||
|
"max_connections": 100,
|
||||||
|
"timeout": settings.connection_timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
# Agent-specific configurations
|
||||||
|
agent_configs = {
|
||||||
|
"coordinator": {
|
||||||
|
**base_config,
|
||||||
|
"max_connections": 1000,
|
||||||
|
"heartbeat_interval": 15,
|
||||||
|
"enable_coordination": True
|
||||||
|
},
|
||||||
|
"worker": {
|
||||||
|
**base_config,
|
||||||
|
"max_connections": 50,
|
||||||
|
"task_timeout": 300,
|
||||||
|
"enable_coordination": False
|
||||||
|
},
|
||||||
|
"specialist": {
|
||||||
|
**base_config,
|
||||||
|
"max_connections": 25,
|
||||||
|
"specialization_timeout": 600,
|
||||||
|
"enable_coordination": True
|
||||||
|
},
|
||||||
|
"monitor": {
|
||||||
|
**base_config,
|
||||||
|
"heartbeat_interval": 10,
|
||||||
|
"enable_coordination": True,
|
||||||
|
"monitoring_interval": 30
|
||||||
|
},
|
||||||
|
"gateway": {
|
||||||
|
**base_config,
|
||||||
|
"max_connections": 2000,
|
||||||
|
"enable_coordination": True,
|
||||||
|
"gateway_timeout": 60
|
||||||
|
},
|
||||||
|
"orchestrator": {
|
||||||
|
**base_config,
|
||||||
|
"max_connections": 500,
|
||||||
|
"heartbeat_interval": 5,
|
||||||
|
"enable_coordination": True,
|
||||||
|
"orchestration_timeout": 120
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return agent_configs.get(agent_type, base_config)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_service_config(service_name: str) -> Dict[str, Any]:
|
||||||
|
"""Get configuration for specific service"""
|
||||||
|
base_config = {
|
||||||
|
"host": settings.host,
|
||||||
|
"port": settings.port,
|
||||||
|
"workers": settings.workers,
|
||||||
|
"timeout": settings.connection_timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
# Service-specific configurations
|
||||||
|
service_configs = {
|
||||||
|
"agent_coordinator": {
|
||||||
|
**base_config,
|
||||||
|
"port": ConfigConstants.DEFAULT_PORTS["agent_coordinator"],
|
||||||
|
"enable_metrics": settings.enable_metrics
|
||||||
|
},
|
||||||
|
"agent_registry": {
|
||||||
|
**base_config,
|
||||||
|
"port": ConfigConstants.DEFAULT_PORTS["agent_registry"],
|
||||||
|
"enable_metrics": False
|
||||||
|
},
|
||||||
|
"task_distributor": {
|
||||||
|
**base_config,
|
||||||
|
"port": ConfigConstants.DEFAULT_PORTS["task_distributor"],
|
||||||
|
"max_queue_size": settings.max_task_queue_size
|
||||||
|
},
|
||||||
|
"metrics": {
|
||||||
|
**base_config,
|
||||||
|
"port": ConfigConstants.DEFAULT_PORTS["metrics"],
|
||||||
|
"enable_metrics": True
|
||||||
|
},
|
||||||
|
"health": {
|
||||||
|
**base_config,
|
||||||
|
"port": ConfigConstants.DEFAULT_PORTS["health"],
|
||||||
|
"enable_metrics": False
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return service_configs.get(service_name, base_config)
|
||||||
|
|
||||||
|
# Load configuration
|
||||||
|
config = ConfigLoader.load_config()
|
||||||
|
|
||||||
|
# Export settings and utilities
|
||||||
|
__all__ = [
|
||||||
|
"settings",
|
||||||
|
"config",
|
||||||
|
"ConfigConstants",
|
||||||
|
"EnvironmentConfig",
|
||||||
|
"ConfigLoader",
|
||||||
|
"ConfigUtils"
|
||||||
|
]
|
||||||
518
apps/agent-coordinator/src/app/main.py
Normal file
518
apps/agent-coordinator/src/app/main.py
Normal file
@@ -0,0 +1,518 @@
|
|||||||
|
"""
|
||||||
|
Main FastAPI Application for AITBC Agent Coordinator
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Dict, List, Optional, Any
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends, status
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
from .protocols.communication import CommunicationManager, create_protocol, MessageType
|
||||||
|
from .protocols.message_types import MessageProcessor, create_task_message, create_status_message
|
||||||
|
from .routing.agent_discovery import AgentRegistry, AgentDiscoveryService, create_agent_info
|
||||||
|
from .routing.load_balancer import LoadBalancer, TaskDistributor, TaskPriority, LoadBalancingStrategy
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Global variables
|
||||||
|
agent_registry: Optional[AgentRegistry] = None
|
||||||
|
discovery_service: Optional[AgentDiscoveryService] = None
|
||||||
|
load_balancer: Optional[LoadBalancer] = None
|
||||||
|
task_distributor: Optional[TaskDistributor] = None
|
||||||
|
communication_manager: Optional[CommunicationManager] = None
|
||||||
|
message_processor: Optional[MessageProcessor] = None
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""Application lifespan management"""
|
||||||
|
# Startup
|
||||||
|
logger.info("Starting AITBC Agent Coordinator...")
|
||||||
|
|
||||||
|
# Initialize services
|
||||||
|
global agent_registry, discovery_service, load_balancer, task_distributor, communication_manager, message_processor
|
||||||
|
|
||||||
|
# Start agent registry
|
||||||
|
agent_registry = AgentRegistry()
|
||||||
|
await agent_registry.start()
|
||||||
|
|
||||||
|
# Initialize discovery service
|
||||||
|
discovery_service = AgentDiscoveryService(agent_registry)
|
||||||
|
|
||||||
|
# Initialize load balancer
|
||||||
|
load_balancer = LoadBalancer(agent_registry)
|
||||||
|
load_balancer.set_strategy(LoadBalancingStrategy.LEAST_CONNECTIONS)
|
||||||
|
|
||||||
|
# Initialize task distributor
|
||||||
|
task_distributor = TaskDistributor(load_balancer)
|
||||||
|
|
||||||
|
# Initialize communication manager
|
||||||
|
communication_manager = CommunicationManager("agent-coordinator")
|
||||||
|
|
||||||
|
# Initialize message processor
|
||||||
|
message_processor = MessageProcessor("agent-coordinator")
|
||||||
|
|
||||||
|
# Start background tasks
|
||||||
|
asyncio.create_task(task_distributor.start_distribution())
|
||||||
|
asyncio.create_task(message_processor.start_processing())
|
||||||
|
|
||||||
|
logger.info("Agent Coordinator started successfully")
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
# Shutdown
|
||||||
|
logger.info("Shutting down AITBC Agent Coordinator...")
|
||||||
|
|
||||||
|
if agent_registry:
|
||||||
|
await agent_registry.stop()
|
||||||
|
|
||||||
|
logger.info("Agent Coordinator shut down")
|
||||||
|
|
||||||
|
# Create FastAPI app
|
||||||
|
app = FastAPI(
|
||||||
|
title="AITBC Agent Coordinator",
|
||||||
|
description="Advanced multi-agent coordination and management system",
|
||||||
|
version="1.0.0",
|
||||||
|
lifespan=lifespan
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add CORS middleware
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Pydantic models
|
||||||
|
class AgentRegistrationRequest(BaseModel):
|
||||||
|
agent_id: str = Field(..., description="Unique agent identifier")
|
||||||
|
agent_type: str = Field(..., description="Type of agent")
|
||||||
|
capabilities: List[str] = Field(default_factory=list, description="Agent capabilities")
|
||||||
|
services: List[str] = Field(default_factory=list, description="Available services")
|
||||||
|
endpoints: Dict[str, str] = Field(default_factory=dict, description="Service endpoints")
|
||||||
|
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
|
||||||
|
|
||||||
|
class AgentStatusUpdate(BaseModel):
|
||||||
|
status: str = Field(..., description="Agent status")
|
||||||
|
load_metrics: Dict[str, float] = Field(default_factory=dict, description="Load metrics")
|
||||||
|
|
||||||
|
class TaskSubmission(BaseModel):
|
||||||
|
task_data: Dict[str, Any] = Field(..., description="Task data")
|
||||||
|
priority: str = Field("normal", description="Task priority")
|
||||||
|
requirements: Optional[Dict[str, Any]] = Field(None, description="Task requirements")
|
||||||
|
|
||||||
|
class MessageRequest(BaseModel):
|
||||||
|
receiver_id: str = Field(..., description="Receiver agent ID")
|
||||||
|
message_type: str = Field(..., description="Message type")
|
||||||
|
payload: Dict[str, Any] = Field(..., description="Message payload")
|
||||||
|
priority: str = Field("normal", description="Message priority")
|
||||||
|
|
||||||
|
# Health check endpoint
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_check():
|
||||||
|
"""Health check endpoint"""
|
||||||
|
return {
|
||||||
|
"status": "healthy",
|
||||||
|
"service": "agent-coordinator",
|
||||||
|
"timestamp": datetime.utcnow().isoformat(),
|
||||||
|
"version": "1.0.0"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Root endpoint
|
||||||
|
@app.get("/")
|
||||||
|
async def root():
|
||||||
|
"""Root endpoint with service information"""
|
||||||
|
return {
|
||||||
|
"service": "AITBC Agent Coordinator",
|
||||||
|
"description": "Advanced multi-agent coordination and management system",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"endpoints": [
|
||||||
|
"/health",
|
||||||
|
"/agents/register",
|
||||||
|
"/agents/discover",
|
||||||
|
"/agents/{agent_id}",
|
||||||
|
"/agents/{agent_id}/status",
|
||||||
|
"/tasks/submit",
|
||||||
|
"/tasks/status",
|
||||||
|
"/messages/send",
|
||||||
|
"/load-balancer/stats",
|
||||||
|
"/registry/stats"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
# Agent registration
|
||||||
|
@app.post("/agents/register")
|
||||||
|
async def register_agent(request: AgentRegistrationRequest):
|
||||||
|
"""Register a new agent"""
|
||||||
|
try:
|
||||||
|
if not agent_registry:
|
||||||
|
raise HTTPException(status_code=503, detail="Agent registry not available")
|
||||||
|
|
||||||
|
# Create agent info
|
||||||
|
agent_info = create_agent_info(
|
||||||
|
agent_id=request.agent_id,
|
||||||
|
agent_type=request.agent_type,
|
||||||
|
capabilities=request.capabilities,
|
||||||
|
services=request.services,
|
||||||
|
endpoints=request.endpoints
|
||||||
|
)
|
||||||
|
agent_info.metadata = request.metadata
|
||||||
|
|
||||||
|
# Register agent
|
||||||
|
success = await agent_registry.register_agent(agent_info)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"message": f"Agent {request.agent_id} registered successfully",
|
||||||
|
"agent_id": request.agent_id,
|
||||||
|
"registered_at": datetime.utcnow().isoformat()
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to register agent")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error registering agent: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
# Agent discovery
|
||||||
|
@app.post("/agents/discover")
|
||||||
|
async def discover_agents(query: Dict[str, Any]):
|
||||||
|
"""Discover agents based on criteria"""
|
||||||
|
try:
|
||||||
|
if not agent_registry:
|
||||||
|
raise HTTPException(status_code=503, detail="Agent registry not available")
|
||||||
|
|
||||||
|
agents = await agent_registry.discover_agents(query)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"query": query,
|
||||||
|
"agents": [agent.to_dict() for agent in agents],
|
||||||
|
"count": len(agents),
|
||||||
|
"timestamp": datetime.utcnow().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error discovering agents: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
# Get agent by ID
|
||||||
|
@app.get("/agents/{agent_id}")
|
||||||
|
async def get_agent(agent_id: str):
|
||||||
|
"""Get agent information by ID"""
|
||||||
|
try:
|
||||||
|
if not agent_registry:
|
||||||
|
raise HTTPException(status_code=503, detail="Agent registry not available")
|
||||||
|
|
||||||
|
agent = await agent_registry.get_agent_by_id(agent_id)
|
||||||
|
|
||||||
|
if not agent:
|
||||||
|
raise HTTPException(status_code=404, detail="Agent not found")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"agent": agent.to_dict(),
|
||||||
|
"timestamp": datetime.utcnow().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting agent: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
# Update agent status
|
||||||
|
@app.put("/agents/{agent_id}/status")
|
||||||
|
async def update_agent_status(agent_id: str, request: AgentStatusUpdate):
|
||||||
|
"""Update agent status"""
|
||||||
|
try:
|
||||||
|
if not agent_registry:
|
||||||
|
raise HTTPException(status_code=503, detail="Agent registry not available")
|
||||||
|
|
||||||
|
from .routing.agent_discovery import AgentStatus
|
||||||
|
|
||||||
|
success = await agent_registry.update_agent_status(
|
||||||
|
agent_id,
|
||||||
|
AgentStatus(request.status),
|
||||||
|
request.load_metrics
|
||||||
|
)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"message": f"Agent {agent_id} status updated",
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"new_status": request.status,
|
||||||
|
"updated_at": datetime.utcnow().isoformat()
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to update agent status")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating agent status: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
# Submit task
|
||||||
|
@app.post("/tasks/submit")
|
||||||
|
async def submit_task(request: TaskSubmission, background_tasks: BackgroundTasks):
|
||||||
|
"""Submit a task for distribution"""
|
||||||
|
try:
|
||||||
|
if not task_distributor:
|
||||||
|
raise HTTPException(status_code=503, detail="Task distributor not available")
|
||||||
|
|
||||||
|
# Convert priority string to enum
|
||||||
|
try:
|
||||||
|
priority = TaskPriority(request.priority.lower())
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(status_code=400, detail=f"Invalid priority: {request.priority}")
|
||||||
|
|
||||||
|
# Submit task
|
||||||
|
await task_distributor.submit_task(
|
||||||
|
request.task_data,
|
||||||
|
priority,
|
||||||
|
request.requirements
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"message": "Task submitted successfully",
|
||||||
|
"task_id": request.task_data.get("task_id", str(uuid.uuid4())),
|
||||||
|
"priority": request.priority,
|
||||||
|
"submitted_at": datetime.utcnow().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error submitting task: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
# Get task status
|
||||||
|
@app.get("/tasks/status")
|
||||||
|
async def get_task_status():
|
||||||
|
"""Get task distribution statistics"""
|
||||||
|
try:
|
||||||
|
if not task_distributor:
|
||||||
|
raise HTTPException(status_code=503, detail="Task distributor not available")
|
||||||
|
|
||||||
|
stats = task_distributor.get_distribution_stats()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"stats": stats,
|
||||||
|
"timestamp": datetime.utcnow().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting task status: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
# Send message
|
||||||
|
@app.post("/messages/send")
|
||||||
|
async def send_message(request: MessageRequest):
|
||||||
|
"""Send message to agent"""
|
||||||
|
try:
|
||||||
|
if not communication_manager:
|
||||||
|
raise HTTPException(status_code=503, detail="Communication manager not available")
|
||||||
|
|
||||||
|
from .protocols.communication import AgentMessage, Priority
|
||||||
|
|
||||||
|
# Convert message type
|
||||||
|
try:
|
||||||
|
message_type = MessageType(request.message_type)
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(status_code=400, detail=f"Invalid message type: {request.message_type}")
|
||||||
|
|
||||||
|
# Convert priority
|
||||||
|
try:
|
||||||
|
priority = Priority(request.priority.lower())
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(status_code=400, detail=f"Invalid priority: {request.priority}")
|
||||||
|
|
||||||
|
# Create message
|
||||||
|
message = AgentMessage(
|
||||||
|
sender_id="agent-coordinator",
|
||||||
|
receiver_id=request.receiver_id,
|
||||||
|
message_type=message_type,
|
||||||
|
priority=priority,
|
||||||
|
payload=request.payload
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send message
|
||||||
|
success = await communication_manager.send_message("hierarchical", message)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"message": "Message sent successfully",
|
||||||
|
"message_id": message.id,
|
||||||
|
"receiver_id": request.receiver_id,
|
||||||
|
"sent_at": datetime.utcnow().isoformat()
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to send message")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error sending message: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
# Load balancer statistics
|
||||||
|
@app.get("/load-balancer/stats")
|
||||||
|
async def get_load_balancer_stats():
|
||||||
|
"""Get load balancer statistics"""
|
||||||
|
try:
|
||||||
|
if not load_balancer:
|
||||||
|
raise HTTPException(status_code=503, detail="Load balancer not available")
|
||||||
|
|
||||||
|
stats = load_balancer.get_load_balancing_stats()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"stats": stats,
|
||||||
|
"timestamp": datetime.utcnow().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting load balancer stats: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
# Registry statistics
|
||||||
|
@app.get("/registry/stats")
|
||||||
|
async def get_registry_stats():
|
||||||
|
"""Get agent registry statistics"""
|
||||||
|
try:
|
||||||
|
if not agent_registry:
|
||||||
|
raise HTTPException(status_code=503, detail="Agent registry not available")
|
||||||
|
|
||||||
|
stats = await agent_registry.get_registry_stats()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"stats": stats,
|
||||||
|
"timestamp": datetime.utcnow().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting registry stats: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
# Get agents by service
|
||||||
|
@app.get("/agents/service/{service}")
|
||||||
|
async def get_agents_by_service(service: str):
|
||||||
|
"""Get agents that provide a specific service"""
|
||||||
|
try:
|
||||||
|
if not agent_registry:
|
||||||
|
raise HTTPException(status_code=503, detail="Agent registry not available")
|
||||||
|
|
||||||
|
agents = await agent_registry.get_agents_by_service(service)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"service": service,
|
||||||
|
"agents": [agent.to_dict() for agent in agents],
|
||||||
|
"count": len(agents),
|
||||||
|
"timestamp": datetime.utcnow().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting agents by service: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
# Get agents by capability
|
||||||
|
@app.get("/agents/capability/{capability}")
|
||||||
|
async def get_agents_by_capability(capability: str):
|
||||||
|
"""Get agents that have a specific capability"""
|
||||||
|
try:
|
||||||
|
if not agent_registry:
|
||||||
|
raise HTTPException(status_code=503, detail="Agent registry not available")
|
||||||
|
|
||||||
|
agents = await agent_registry.get_agents_by_capability(capability)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"capability": capability,
|
||||||
|
"agents": [agent.to_dict() for agent in agents],
|
||||||
|
"count": len(agents),
|
||||||
|
"timestamp": datetime.utcnow().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting agents by capability: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
# Set load balancing strategy
|
||||||
|
@app.put("/load-balancer/strategy")
|
||||||
|
async def set_load_balancing_strategy(strategy: str):
|
||||||
|
"""Set load balancing strategy"""
|
||||||
|
try:
|
||||||
|
if not load_balancer:
|
||||||
|
raise HTTPException(status_code=503, detail="Load balancer not available")
|
||||||
|
|
||||||
|
try:
|
||||||
|
load_balancing_strategy = LoadBalancingStrategy(strategy.lower())
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(status_code=400, detail=f"Invalid strategy: {strategy}")
|
||||||
|
|
||||||
|
load_balancer.set_strategy(load_balancing_strategy)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"message": f"Load balancing strategy set to {strategy}",
|
||||||
|
"strategy": strategy,
|
||||||
|
"updated_at": datetime.utcnow().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error setting load balancing strategy: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
# Error handlers
|
||||||
|
@app.exception_handler(404)
|
||||||
|
async def not_found_handler(request, exc):
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=404,
|
||||||
|
content={
|
||||||
|
"status": "error",
|
||||||
|
"message": "Resource not found",
|
||||||
|
"timestamp": datetime.utcnow().isoformat()
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.exception_handler(500)
|
||||||
|
async def internal_error_handler(request, exc):
|
||||||
|
logger.error(f"Internal server error: {exc}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content={
|
||||||
|
"status": "error",
|
||||||
|
"message": "Internal server error",
|
||||||
|
"timestamp": datetime.utcnow().isoformat()
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Main function
|
||||||
|
def main():
|
||||||
|
"""Main function to run the application"""
|
||||||
|
uvicorn.run(
|
||||||
|
"main:app",
|
||||||
|
host="0.0.0.0",
|
||||||
|
port=9001,
|
||||||
|
reload=True,
|
||||||
|
log_level="info"
|
||||||
|
)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
443
apps/agent-coordinator/src/app/protocols/communication.py
Normal file
443
apps/agent-coordinator/src/app/protocols/communication.py
Normal file
@@ -0,0 +1,443 @@
|
|||||||
|
"""
|
||||||
|
Multi-Agent Communication Protocols for AITBC Agent Coordination
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Dict, List, Optional, Any, Callable
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
import uuid
|
||||||
|
import websockets
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class MessageType(str, Enum):
|
||||||
|
"""Message types for agent communication"""
|
||||||
|
COORDINATION = "coordination"
|
||||||
|
TASK_ASSIGNMENT = "task_assignment"
|
||||||
|
STATUS_UPDATE = "status_update"
|
||||||
|
DISCOVERY = "discovery"
|
||||||
|
HEARTBEAT = "heartbeat"
|
||||||
|
CONSENSUS = "consensus"
|
||||||
|
BROADCAST = "broadcast"
|
||||||
|
DIRECT = "direct"
|
||||||
|
PEER_TO_PEER = "peer_to_peer"
|
||||||
|
HIERARCHICAL = "hierarchical"
|
||||||
|
|
||||||
|
class Priority(str, Enum):
|
||||||
|
"""Message priority levels"""
|
||||||
|
LOW = "low"
|
||||||
|
NORMAL = "normal"
|
||||||
|
HIGH = "high"
|
||||||
|
CRITICAL = "critical"
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AgentMessage:
|
||||||
|
"""Base message structure for agent communication"""
|
||||||
|
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||||
|
sender_id: str = ""
|
||||||
|
receiver_id: Optional[str] = None
|
||||||
|
message_type: MessageType = MessageType.DIRECT
|
||||||
|
priority: Priority = Priority.NORMAL
|
||||||
|
timestamp: datetime = field(default_factory=datetime.utcnow)
|
||||||
|
payload: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
correlation_id: Optional[str] = None
|
||||||
|
reply_to: Optional[str] = None
|
||||||
|
ttl: int = 300 # Time to live in seconds
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert message to dictionary"""
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"sender_id": self.sender_id,
|
||||||
|
"receiver_id": self.receiver_id,
|
||||||
|
"message_type": self.message_type.value,
|
||||||
|
"priority": self.priority.value,
|
||||||
|
"timestamp": self.timestamp.isoformat(),
|
||||||
|
"payload": self.payload,
|
||||||
|
"correlation_id": self.correlation_id,
|
||||||
|
"reply_to": self.reply_to,
|
||||||
|
"ttl": self.ttl
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> "AgentMessage":
|
||||||
|
"""Create message from dictionary"""
|
||||||
|
data["timestamp"] = datetime.fromisoformat(data["timestamp"])
|
||||||
|
data["message_type"] = MessageType(data["message_type"])
|
||||||
|
data["priority"] = Priority(data["priority"])
|
||||||
|
return cls(**data)
|
||||||
|
|
||||||
|
class CommunicationProtocol:
|
||||||
|
"""Base class for communication protocols"""
|
||||||
|
|
||||||
|
def __init__(self, agent_id: str):
|
||||||
|
self.agent_id = agent_id
|
||||||
|
self.message_handlers: Dict[MessageType, List[Callable]] = {}
|
||||||
|
self.active_connections: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
async def register_handler(self, message_type: MessageType, handler: Callable):
|
||||||
|
"""Register a message handler for a specific message type"""
|
||||||
|
if message_type not in self.message_handlers:
|
||||||
|
self.message_handlers[message_type] = []
|
||||||
|
self.message_handlers[message_type].append(handler)
|
||||||
|
|
||||||
|
async def send_message(self, message: AgentMessage) -> bool:
|
||||||
|
"""Send a message to another agent"""
|
||||||
|
try:
|
||||||
|
if message.receiver_id and message.receiver_id in self.active_connections:
|
||||||
|
await self._send_to_agent(message)
|
||||||
|
return True
|
||||||
|
elif message.message_type == MessageType.BROADCAST:
|
||||||
|
await self._broadcast_message(message)
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.warning(f"Cannot send message to {message.receiver_id}: not connected")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error sending message: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def receive_message(self, message: AgentMessage):
|
||||||
|
"""Process received message"""
|
||||||
|
try:
|
||||||
|
# Check TTL
|
||||||
|
if self._is_message_expired(message):
|
||||||
|
logger.warning(f"Message {message.id} expired, ignoring")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Handle message
|
||||||
|
handlers = self.message_handlers.get(message.message_type, [])
|
||||||
|
for handler in handlers:
|
||||||
|
try:
|
||||||
|
await handler(message)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in message handler: {e}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing message: {e}")
|
||||||
|
|
||||||
|
def _is_message_expired(self, message: AgentMessage) -> bool:
|
||||||
|
"""Check if message has expired"""
|
||||||
|
age = (datetime.utcnow() - message.timestamp).total_seconds()
|
||||||
|
return age > message.ttl
|
||||||
|
|
||||||
|
async def _send_to_agent(self, message: AgentMessage):
|
||||||
|
"""Send message to specific agent"""
|
||||||
|
raise NotImplementedError("Subclasses must implement _send_to_agent")
|
||||||
|
|
||||||
|
async def _broadcast_message(self, message: AgentMessage):
|
||||||
|
"""Broadcast message to all connected agents"""
|
||||||
|
raise NotImplementedError("Subclasses must implement _broadcast_message")
|
||||||
|
|
||||||
|
class HierarchicalProtocol(CommunicationProtocol):
|
||||||
|
"""Hierarchical communication protocol (master-agent → sub-agents)"""
|
||||||
|
|
||||||
|
def __init__(self, agent_id: str, is_master: bool = False):
|
||||||
|
super().__init__(agent_id)
|
||||||
|
self.is_master = is_master
|
||||||
|
self.sub_agents: List[str] = []
|
||||||
|
self.master_agent: Optional[str] = None
|
||||||
|
|
||||||
|
async def add_sub_agent(self, agent_id: str):
|
||||||
|
"""Add a sub-agent to this master agent"""
|
||||||
|
if self.is_master:
|
||||||
|
self.sub_agents.append(agent_id)
|
||||||
|
logger.info(f"Added sub-agent {agent_id} to master {self.agent_id}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Agent {self.agent_id} is not a master, cannot add sub-agents")
|
||||||
|
|
||||||
|
async def send_to_sub_agents(self, message: AgentMessage):
|
||||||
|
"""Send message to all sub-agents"""
|
||||||
|
if not self.is_master:
|
||||||
|
logger.warning(f"Agent {self.agent_id} is not a master")
|
||||||
|
return
|
||||||
|
|
||||||
|
message.message_type = MessageType.HIERARCHICAL
|
||||||
|
for sub_agent_id in self.sub_agents:
|
||||||
|
message.receiver_id = sub_agent_id
|
||||||
|
await self.send_message(message)
|
||||||
|
|
||||||
|
async def send_to_master(self, message: AgentMessage):
|
||||||
|
"""Send message to master agent"""
|
||||||
|
if self.is_master:
|
||||||
|
logger.warning(f"Agent {self.agent_id} is a master, cannot send to master")
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.master_agent:
|
||||||
|
message.receiver_id = self.master_agent
|
||||||
|
message.message_type = MessageType.HIERARCHICAL
|
||||||
|
await self.send_message(message)
|
||||||
|
else:
|
||||||
|
logger.warning(f"Agent {self.agent_id} has no master agent")
|
||||||
|
|
||||||
|
class PeerToPeerProtocol(CommunicationProtocol):
|
||||||
|
"""Peer-to-peer communication protocol (agent ↔ agent)"""
|
||||||
|
|
||||||
|
def __init__(self, agent_id: str):
|
||||||
|
super().__init__(agent_id)
|
||||||
|
self.peers: Dict[str, Dict[str, Any]] = {}
|
||||||
|
|
||||||
|
async def add_peer(self, peer_id: str, connection_info: Dict[str, Any]):
|
||||||
|
"""Add a peer to the peer network"""
|
||||||
|
self.peers[peer_id] = connection_info
|
||||||
|
logger.info(f"Added peer {peer_id} to agent {self.agent_id}")
|
||||||
|
|
||||||
|
async def remove_peer(self, peer_id: str):
|
||||||
|
"""Remove a peer from the peer network"""
|
||||||
|
if peer_id in self.peers:
|
||||||
|
del self.peers[peer_id]
|
||||||
|
logger.info(f"Removed peer {peer_id} from agent {self.agent_id}")
|
||||||
|
|
||||||
|
async def send_to_peer(self, message: AgentMessage, peer_id: str):
|
||||||
|
"""Send message to specific peer"""
|
||||||
|
if peer_id not in self.peers:
|
||||||
|
logger.warning(f"Peer {peer_id} not found")
|
||||||
|
return False
|
||||||
|
|
||||||
|
message.receiver_id = peer_id
|
||||||
|
message.message_type = MessageType.PEER_TO_PEER
|
||||||
|
return await self.send_message(message)
|
||||||
|
|
||||||
|
async def broadcast_to_peers(self, message: AgentMessage):
|
||||||
|
"""Broadcast message to all peers"""
|
||||||
|
message.message_type = MessageType.PEER_TO_PEER
|
||||||
|
for peer_id in self.peers:
|
||||||
|
message.receiver_id = peer_id
|
||||||
|
await self.send_message(message)
|
||||||
|
|
||||||
|
class BroadcastProtocol(CommunicationProtocol):
|
||||||
|
"""Broadcast communication protocol (agent → all agents)"""
|
||||||
|
|
||||||
|
def __init__(self, agent_id: str, broadcast_channel: str = "global"):
|
||||||
|
super().__init__(agent_id)
|
||||||
|
self.broadcast_channel = broadcast_channel
|
||||||
|
self.subscribers: List[str] = []
|
||||||
|
|
||||||
|
async def subscribe(self, agent_id: str):
|
||||||
|
"""Subscribe to broadcast channel"""
|
||||||
|
if agent_id not in self.subscribers:
|
||||||
|
self.subscribers.append(agent_id)
|
||||||
|
logger.info(f"Agent {agent_id} subscribed to {self.broadcast_channel}")
|
||||||
|
|
||||||
|
async def unsubscribe(self, agent_id: str):
|
||||||
|
"""Unsubscribe from broadcast channel"""
|
||||||
|
if agent_id in self.subscribers:
|
||||||
|
self.subscribers.remove(agent_id)
|
||||||
|
logger.info(f"Agent {agent_id} unsubscribed from {self.broadcast_channel}")
|
||||||
|
|
||||||
|
async def broadcast(self, message: AgentMessage):
|
||||||
|
"""Broadcast message to all subscribers"""
|
||||||
|
message.message_type = MessageType.BROADCAST
|
||||||
|
message.receiver_id = None # Broadcast to all
|
||||||
|
|
||||||
|
for subscriber_id in self.subscribers:
|
||||||
|
if subscriber_id != self.agent_id: # Don't send to self
|
||||||
|
message_copy = AgentMessage(**message.__dict__)
|
||||||
|
message_copy.receiver_id = subscriber_id
|
||||||
|
await self.send_message(message_copy)
|
||||||
|
|
||||||
|
class CommunicationManager:
|
||||||
|
"""Manages multiple communication protocols for an agent"""
|
||||||
|
|
||||||
|
def __init__(self, agent_id: str):
|
||||||
|
self.agent_id = agent_id
|
||||||
|
self.protocols: Dict[str, CommunicationProtocol] = {}
|
||||||
|
|
||||||
|
def add_protocol(self, name: str, protocol: CommunicationProtocol):
|
||||||
|
"""Add a communication protocol"""
|
||||||
|
self.protocols[name] = protocol
|
||||||
|
logger.info(f"Added protocol {name} to agent {self.agent_id}")
|
||||||
|
|
||||||
|
def get_protocol(self, name: str) -> Optional[CommunicationProtocol]:
|
||||||
|
"""Get a communication protocol by name"""
|
||||||
|
return self.protocols.get(name)
|
||||||
|
|
||||||
|
async def send_message(self, protocol_name: str, message: AgentMessage) -> bool:
|
||||||
|
"""Send message using specific protocol"""
|
||||||
|
protocol = self.get_protocol(protocol_name)
|
||||||
|
if protocol:
|
||||||
|
return await protocol.send_message(message)
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def register_handler(self, protocol_name: str, message_type: MessageType, handler: Callable):
|
||||||
|
"""Register message handler for specific protocol"""
|
||||||
|
protocol = self.get_protocol(protocol_name)
|
||||||
|
if protocol:
|
||||||
|
await protocol.register_handler(message_type, handler)
|
||||||
|
else:
|
||||||
|
logger.error(f"Protocol {protocol_name} not found")
|
||||||
|
|
||||||
|
# Message templates for common operations
|
||||||
|
class MessageTemplates:
|
||||||
|
"""Pre-defined message templates"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_heartbeat(sender_id: str) -> AgentMessage:
|
||||||
|
"""Create heartbeat message"""
|
||||||
|
return AgentMessage(
|
||||||
|
sender_id=sender_id,
|
||||||
|
message_type=MessageType.HEARTBEAT,
|
||||||
|
priority=Priority.LOW,
|
||||||
|
payload={"timestamp": datetime.utcnow().isoformat()}
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_task_assignment(sender_id: str, receiver_id: str, task_data: Dict[str, Any]) -> AgentMessage:
|
||||||
|
"""Create task assignment message"""
|
||||||
|
return AgentMessage(
|
||||||
|
sender_id=sender_id,
|
||||||
|
receiver_id=receiver_id,
|
||||||
|
message_type=MessageType.TASK_ASSIGNMENT,
|
||||||
|
priority=Priority.NORMAL,
|
||||||
|
payload=task_data
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_status_update(sender_id: str, status_data: Dict[str, Any]) -> AgentMessage:
|
||||||
|
"""Create status update message"""
|
||||||
|
return AgentMessage(
|
||||||
|
sender_id=sender_id,
|
||||||
|
message_type=MessageType.STATUS_UPDATE,
|
||||||
|
priority=Priority.NORMAL,
|
||||||
|
payload=status_data
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_discovery(sender_id: str) -> AgentMessage:
|
||||||
|
"""Create discovery message"""
|
||||||
|
return AgentMessage(
|
||||||
|
sender_id=sender_id,
|
||||||
|
message_type=MessageType.DISCOVERY,
|
||||||
|
priority=Priority.NORMAL,
|
||||||
|
payload={"agent_id": sender_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_consensus_request(sender_id: str, proposal_data: Dict[str, Any]) -> AgentMessage:
|
||||||
|
"""Create consensus request message"""
|
||||||
|
return AgentMessage(
|
||||||
|
sender_id=sender_id,
|
||||||
|
message_type=MessageType.CONSENSUS,
|
||||||
|
priority=Priority.HIGH,
|
||||||
|
payload=proposal_data
|
||||||
|
)
|
||||||
|
|
||||||
|
# WebSocket connection handler for real-time communication
|
||||||
|
class WebSocketHandler:
|
||||||
|
"""WebSocket handler for real-time agent communication"""
|
||||||
|
|
||||||
|
def __init__(self, communication_manager: CommunicationManager):
|
||||||
|
self.communication_manager = communication_manager
|
||||||
|
self.websocket_connections: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
async def handle_connection(self, websocket, agent_id: str):
|
||||||
|
"""Handle WebSocket connection from agent"""
|
||||||
|
self.websocket_connections[agent_id] = websocket
|
||||||
|
logger.info(f"WebSocket connection established for agent {agent_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for message in websocket:
|
||||||
|
data = json.loads(message)
|
||||||
|
agent_message = AgentMessage.from_dict(data)
|
||||||
|
await self.communication_manager.receive_message(agent_message)
|
||||||
|
except websockets.exceptions.ConnectionClosed:
|
||||||
|
logger.info(f"WebSocket connection closed for agent {agent_id}")
|
||||||
|
finally:
|
||||||
|
if agent_id in self.websocket_connections:
|
||||||
|
del self.websocket_connections[agent_id]
|
||||||
|
|
||||||
|
async def send_to_agent(self, agent_id: str, message: AgentMessage):
|
||||||
|
"""Send message to agent via WebSocket"""
|
||||||
|
if agent_id in self.websocket_connections:
|
||||||
|
websocket = self.websocket_connections[agent_id]
|
||||||
|
await websocket.send(json.dumps(message.to_dict()))
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def broadcast_message(self, message: AgentMessage):
|
||||||
|
"""Broadcast message to all connected agents"""
|
||||||
|
for websocket in self.websocket_connections.values():
|
||||||
|
await websocket.send(json.dumps(message.to_dict()))
|
||||||
|
|
||||||
|
# Redis-based message broker for scalable communication
|
||||||
|
class RedisMessageBroker:
|
||||||
|
"""Redis-based message broker for agent communication"""
|
||||||
|
|
||||||
|
def __init__(self, redis_url: str):
|
||||||
|
self.redis_url = redis_url
|
||||||
|
self.channels: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
async def publish_message(self, channel: str, message: AgentMessage):
|
||||||
|
"""Publish message to Redis channel"""
|
||||||
|
import redis.asyncio as redis
|
||||||
|
redis_client = redis.from_url(self.redis_url)
|
||||||
|
|
||||||
|
await redis_client.publish(channel, json.dumps(message.to_dict()))
|
||||||
|
await redis_client.close()
|
||||||
|
|
||||||
|
async def subscribe_to_channel(self, channel: str, handler: Callable):
|
||||||
|
"""Subscribe to Redis channel"""
|
||||||
|
import redis.asyncio as redis
|
||||||
|
redis_client = redis.from_url(self.redis_url)
|
||||||
|
|
||||||
|
pubsub = redis_client.pubsub()
|
||||||
|
await pubsub.subscribe(channel)
|
||||||
|
|
||||||
|
self.channels[channel] = {"pubsub": pubsub, "handler": handler}
|
||||||
|
|
||||||
|
# Start listening for messages
|
||||||
|
asyncio.create_task(self._listen_to_channel(channel, pubsub, handler))
|
||||||
|
|
||||||
|
async def _listen_to_channel(self, channel: str, pubsub: Any, handler: Callable):
|
||||||
|
"""Listen for messages on channel"""
|
||||||
|
async for message in pubsub.listen():
|
||||||
|
if message["type"] == "message":
|
||||||
|
data = json.loads(message["data"])
|
||||||
|
agent_message = AgentMessage.from_dict(data)
|
||||||
|
await handler(agent_message)
|
||||||
|
|
||||||
|
# Factory function for creating communication protocols
|
||||||
|
def create_protocol(protocol_type: str, agent_id: str, **kwargs) -> CommunicationProtocol:
|
||||||
|
"""Factory function to create communication protocols"""
|
||||||
|
if protocol_type == "hierarchical":
|
||||||
|
return HierarchicalProtocol(agent_id, kwargs.get("is_master", False))
|
||||||
|
elif protocol_type == "peer_to_peer":
|
||||||
|
return PeerToPeerProtocol(agent_id)
|
||||||
|
elif protocol_type == "broadcast":
|
||||||
|
return BroadcastProtocol(agent_id, kwargs.get("broadcast_channel", "global"))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown protocol type: {protocol_type}")
|
||||||
|
|
||||||
|
# Example usage
|
||||||
|
async def example_usage():
|
||||||
|
"""Example of how to use the communication protocols"""
|
||||||
|
|
||||||
|
# Create communication manager
|
||||||
|
comm_manager = CommunicationManager("agent-001")
|
||||||
|
|
||||||
|
# Add protocols
|
||||||
|
hierarchical_protocol = create_protocol("hierarchical", "agent-001", is_master=True)
|
||||||
|
p2p_protocol = create_protocol("peer_to_peer", "agent-001")
|
||||||
|
broadcast_protocol = create_protocol("broadcast", "agent-001")
|
||||||
|
|
||||||
|
comm_manager.add_protocol("hierarchical", hierarchical_protocol)
|
||||||
|
comm_manager.add_protocol("peer_to_peer", p2p_protocol)
|
||||||
|
comm_manager.add_protocol("broadcast", broadcast_protocol)
|
||||||
|
|
||||||
|
# Register message handlers
|
||||||
|
async def handle_heartbeat(message: AgentMessage):
|
||||||
|
logger.info(f"Received heartbeat from {message.sender_id}")
|
||||||
|
|
||||||
|
await comm_manager.register_handler("hierarchical", MessageType.HEARTBEAT, handle_heartbeat)
|
||||||
|
|
||||||
|
# Send messages
|
||||||
|
heartbeat = MessageTemplates.create_heartbeat("agent-001")
|
||||||
|
await comm_manager.send_message("hierarchical", heartbeat)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(example_usage())
|
||||||
586
apps/agent-coordinator/src/app/protocols/message_types.py
Normal file
586
apps/agent-coordinator/src/app/protocols/message_types.py
Normal file
@@ -0,0 +1,586 @@
|
|||||||
|
"""
|
||||||
|
Message Types and Routing System for AITBC Agent Coordination
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Dict, List, Optional, Any, Callable, Union
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
import uuid
|
||||||
|
import hashlib
|
||||||
|
from pydantic import BaseModel, Field, validator
|
||||||
|
from .communication import AgentMessage, MessageType, Priority
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class MessageStatus(str, Enum):
|
||||||
|
"""Message processing status"""
|
||||||
|
PENDING = "pending"
|
||||||
|
PROCESSING = "processing"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
FAILED = "failed"
|
||||||
|
EXPIRED = "expired"
|
||||||
|
CANCELLED = "cancelled"
|
||||||
|
|
||||||
|
class RoutingStrategy(str, Enum):
|
||||||
|
"""Message routing strategies"""
|
||||||
|
ROUND_ROBIN = "round_robin"
|
||||||
|
LOAD_BALANCED = "load_balanced"
|
||||||
|
PRIORITY_BASED = "priority_based"
|
||||||
|
RANDOM = "random"
|
||||||
|
DIRECT = "direct"
|
||||||
|
BROADCAST = "broadcast"
|
||||||
|
|
||||||
|
class DeliveryMode(str, Enum):
|
||||||
|
"""Message delivery modes"""
|
||||||
|
FIRE_AND_FORGET = "fire_and_forget"
|
||||||
|
AT_LEAST_ONCE = "at_least_once"
|
||||||
|
EXACTLY_ONCE = "exactly_once"
|
||||||
|
PERSISTENT = "persistent"
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RoutingRule:
|
||||||
|
"""Routing rule for message processing"""
|
||||||
|
rule_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||||
|
name: str = ""
|
||||||
|
condition: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
action: str = "forward" # forward, transform, filter, route
|
||||||
|
target: Optional[str] = None
|
||||||
|
priority: int = 0
|
||||||
|
enabled: bool = True
|
||||||
|
created_at: datetime = field(default_factory=datetime.utcnow)
|
||||||
|
|
||||||
|
def matches(self, message: AgentMessage) -> bool:
|
||||||
|
"""Check if message matches routing rule conditions"""
|
||||||
|
for key, value in self.condition.items():
|
||||||
|
message_value = getattr(message, key, None)
|
||||||
|
if message_value != value:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
class TaskMessage(BaseModel):
|
||||||
|
"""Task-specific message structure"""
|
||||||
|
task_id: str = Field(..., description="Unique task identifier")
|
||||||
|
task_type: str = Field(..., description="Type of task")
|
||||||
|
task_data: Dict[str, Any] = Field(default_factory=dict, description="Task data")
|
||||||
|
requirements: Dict[str, Any] = Field(default_factory=dict, description="Task requirements")
|
||||||
|
deadline: Optional[datetime] = Field(None, description="Task deadline")
|
||||||
|
priority: Priority = Field(Priority.NORMAL, description="Task priority")
|
||||||
|
assigned_agent: Optional[str] = Field(None, description="Assigned agent ID")
|
||||||
|
status: str = Field("pending", description="Task status")
|
||||||
|
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
|
||||||
|
@validator('deadline')
|
||||||
|
def validate_deadline(cls, v):
|
||||||
|
if v and v < datetime.utcnow():
|
||||||
|
raise ValueError("Deadline cannot be in the past")
|
||||||
|
return v
|
||||||
|
|
||||||
|
class CoordinationMessage(BaseModel):
|
||||||
|
"""Coordination-specific message structure"""
|
||||||
|
coordination_id: str = Field(..., description="Unique coordination identifier")
|
||||||
|
coordination_type: str = Field(..., description="Type of coordination")
|
||||||
|
participants: List[str] = Field(default_factory=list, description="Participating agents")
|
||||||
|
coordination_data: Dict[str, Any] = Field(default_factory=dict, description="Coordination data")
|
||||||
|
decision_deadline: Optional[datetime] = Field(None, description="Decision deadline")
|
||||||
|
consensus_threshold: float = Field(0.5, description="Consensus threshold")
|
||||||
|
status: str = Field("pending", description="Coordination status")
|
||||||
|
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
|
||||||
|
class StatusMessage(BaseModel):
|
||||||
|
"""Status update message structure"""
|
||||||
|
agent_id: str = Field(..., description="Agent ID")
|
||||||
|
status_type: str = Field(..., description="Type of status")
|
||||||
|
status_data: Dict[str, Any] = Field(default_factory=dict, description="Status data")
|
||||||
|
health_score: float = Field(1.0, description="Agent health score")
|
||||||
|
load_metrics: Dict[str, float] = Field(default_factory=dict, description="Load metrics")
|
||||||
|
capabilities: List[str] = Field(default_factory=list, description="Agent capabilities")
|
||||||
|
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
|
||||||
|
class DiscoveryMessage(BaseModel):
|
||||||
|
"""Agent discovery message structure"""
|
||||||
|
agent_id: str = Field(..., description="Agent ID")
|
||||||
|
agent_type: str = Field(..., description="Type of agent")
|
||||||
|
capabilities: List[str] = Field(default_factory=list, description="Agent capabilities")
|
||||||
|
services: List[str] = Field(default_factory=list, description="Available services")
|
||||||
|
endpoints: Dict[str, str] = Field(default_factory=dict, description="Service endpoints")
|
||||||
|
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
|
||||||
|
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
|
||||||
|
class ConsensusMessage(BaseModel):
|
||||||
|
"""Consensus message structure"""
|
||||||
|
consensus_id: str = Field(..., description="Unique consensus identifier")
|
||||||
|
proposal: Dict[str, Any] = Field(..., description="Consensus proposal")
|
||||||
|
voting_options: List[Dict[str, Any]] = Field(default_factory=list, description="Voting options")
|
||||||
|
votes: Dict[str, str] = Field(default_factory=dict, description="Agent votes")
|
||||||
|
voting_deadline: datetime = Field(..., description="Voting deadline")
|
||||||
|
consensus_algorithm: str = Field("majority", description="Consensus algorithm")
|
||||||
|
status: str = Field("pending", description="Consensus status")
|
||||||
|
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
|
||||||
|
class MessageRouter:
|
||||||
|
"""Advanced message routing system"""
|
||||||
|
|
||||||
|
def __init__(self, agent_id: str):
|
||||||
|
self.agent_id = agent_id
|
||||||
|
self.routing_rules: List[RoutingRule] = []
|
||||||
|
self.message_queue: asyncio.Queue = asyncio.Queue(maxsize=10000)
|
||||||
|
self.dead_letter_queue: asyncio.Queue = asyncio.Queue(maxsize=1000)
|
||||||
|
self.routing_stats: Dict[str, Any] = {
|
||||||
|
"messages_processed": 0,
|
||||||
|
"messages_failed": 0,
|
||||||
|
"messages_expired": 0,
|
||||||
|
"routing_time_total": 0.0
|
||||||
|
}
|
||||||
|
self.active_routes: Dict[str, str] = {} # message_id -> route
|
||||||
|
self.load_balancer_index = 0
|
||||||
|
|
||||||
|
def add_routing_rule(self, rule: RoutingRule):
|
||||||
|
"""Add a routing rule"""
|
||||||
|
self.routing_rules.append(rule)
|
||||||
|
# Sort by priority (higher priority first)
|
||||||
|
self.routing_rules.sort(key=lambda r: r.priority, reverse=True)
|
||||||
|
logger.info(f"Added routing rule: {rule.name}")
|
||||||
|
|
||||||
|
def remove_routing_rule(self, rule_id: str):
|
||||||
|
"""Remove a routing rule"""
|
||||||
|
self.routing_rules = [r for r in self.routing_rules if r.rule_id != rule_id]
|
||||||
|
logger.info(f"Removed routing rule: {rule_id}")
|
||||||
|
|
||||||
|
async def route_message(self, message: AgentMessage) -> Optional[str]:
|
||||||
|
"""Route message based on routing rules"""
|
||||||
|
start_time = datetime.utcnow()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check if message is expired
|
||||||
|
if self._is_message_expired(message):
|
||||||
|
await self.dead_letter_queue.put(message)
|
||||||
|
self.routing_stats["messages_expired"] += 1
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Apply routing rules
|
||||||
|
for rule in self.routing_rules:
|
||||||
|
if rule.enabled and rule.matches(message):
|
||||||
|
route = await self._apply_routing_rule(rule, message)
|
||||||
|
if route:
|
||||||
|
self.active_routes[message.id] = route
|
||||||
|
self.routing_stats["messages_processed"] += 1
|
||||||
|
return route
|
||||||
|
|
||||||
|
# Default routing
|
||||||
|
default_route = await self._default_routing(message)
|
||||||
|
if default_route:
|
||||||
|
self.active_routes[message.id] = default_route
|
||||||
|
self.routing_stats["messages_processed"] += 1
|
||||||
|
return default_route
|
||||||
|
|
||||||
|
# No route found
|
||||||
|
await self.dead_letter_queue.put(message)
|
||||||
|
self.routing_stats["messages_failed"] += 1
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error routing message {message.id}: {e}")
|
||||||
|
await self.dead_letter_queue.put(message)
|
||||||
|
self.routing_stats["messages_failed"] += 1
|
||||||
|
return None
|
||||||
|
finally:
|
||||||
|
routing_time = (datetime.utcnow() - start_time).total_seconds()
|
||||||
|
self.routing_stats["routing_time_total"] += routing_time
|
||||||
|
|
||||||
|
async def _apply_routing_rule(self, rule: RoutingRule, message: AgentMessage) -> Optional[str]:
|
||||||
|
"""Apply a specific routing rule"""
|
||||||
|
if rule.action == "forward":
|
||||||
|
return rule.target
|
||||||
|
elif rule.action == "transform":
|
||||||
|
return await self._transform_message(message, rule)
|
||||||
|
elif rule.action == "filter":
|
||||||
|
return await self._filter_message(message, rule)
|
||||||
|
elif rule.action == "route":
|
||||||
|
return await self._custom_routing(message, rule)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _transform_message(self, message: AgentMessage, rule: RoutingRule) -> Optional[str]:
|
||||||
|
"""Transform message based on rule"""
|
||||||
|
# Apply transformation logic here
|
||||||
|
transformed_message = AgentMessage(
|
||||||
|
sender_id=message.sender_id,
|
||||||
|
receiver_id=message.receiver_id,
|
||||||
|
message_type=message.message_type,
|
||||||
|
priority=message.priority,
|
||||||
|
payload={**message.payload, **rule.condition.get("transform", {})}
|
||||||
|
)
|
||||||
|
# Route transformed message
|
||||||
|
return await self._default_routing(transformed_message)
|
||||||
|
|
||||||
|
async def _filter_message(self, message: AgentMessage, rule: RoutingRule) -> Optional[str]:
|
||||||
|
"""Filter message based on rule"""
|
||||||
|
filter_condition = rule.condition.get("filter", {})
|
||||||
|
for key, value in filter_condition.items():
|
||||||
|
if message.payload.get(key) != value:
|
||||||
|
return None # Filter out message
|
||||||
|
return await self._default_routing(message)
|
||||||
|
|
||||||
|
async def _custom_routing(self, message: AgentMessage, rule: RoutingRule) -> Optional[str]:
|
||||||
|
"""Custom routing logic"""
|
||||||
|
# Implement custom routing logic here
|
||||||
|
return rule.target
|
||||||
|
|
||||||
|
async def _default_routing(self, message: AgentMessage) -> Optional[str]:
|
||||||
|
"""Default message routing"""
|
||||||
|
if message.receiver_id:
|
||||||
|
return message.receiver_id
|
||||||
|
elif message.message_type == MessageType.BROADCAST:
|
||||||
|
return "broadcast"
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _is_message_expired(self, message: AgentMessage) -> bool:
|
||||||
|
"""Check if message is expired"""
|
||||||
|
age = (datetime.utcnow() - message.timestamp).total_seconds()
|
||||||
|
return age > message.ttl
|
||||||
|
|
||||||
|
async def get_routing_stats(self) -> Dict[str, Any]:
|
||||||
|
"""Get routing statistics"""
|
||||||
|
total_messages = self.routing_stats["messages_processed"]
|
||||||
|
avg_routing_time = (
|
||||||
|
self.routing_stats["routing_time_total"] / total_messages
|
||||||
|
if total_messages > 0 else 0
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
**self.routing_stats,
|
||||||
|
"avg_routing_time": avg_routing_time,
|
||||||
|
"active_routes": len(self.active_routes),
|
||||||
|
"queue_size": self.message_queue.qsize(),
|
||||||
|
"dead_letter_queue_size": self.dead_letter_queue.qsize()
|
||||||
|
}
|
||||||
|
|
||||||
|
class LoadBalancer:
|
||||||
|
"""Load balancer for message distribution"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.agent_loads: Dict[str, float] = {}
|
||||||
|
self.agent_weights: Dict[str, float] = {}
|
||||||
|
self.last_updated = datetime.utcnow()
|
||||||
|
|
||||||
|
def update_agent_load(self, agent_id: str, load: float):
|
||||||
|
"""Update agent load information"""
|
||||||
|
self.agent_loads[agent_id] = load
|
||||||
|
self.last_updated = datetime.utcnow()
|
||||||
|
|
||||||
|
def set_agent_weight(self, agent_id: str, weight: float):
|
||||||
|
"""Set agent weight for load balancing"""
|
||||||
|
self.agent_weights[agent_id] = weight
|
||||||
|
|
||||||
|
def select_agent(self, available_agents: List[str], strategy: RoutingStrategy = RoutingStrategy.LOAD_BALANCED) -> Optional[str]:
|
||||||
|
"""Select agent based on load balancing strategy"""
|
||||||
|
if not available_agents:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if strategy == RoutingStrategy.ROUND_ROBIN:
|
||||||
|
return self._round_robin_selection(available_agents)
|
||||||
|
elif strategy == RoutingStrategy.LOAD_BALANCED:
|
||||||
|
return self._load_balanced_selection(available_agents)
|
||||||
|
elif strategy == RoutingStrategy.PRIORITY_BASED:
|
||||||
|
return self._priority_based_selection(available_agents)
|
||||||
|
elif strategy == RoutingStrategy.RANDOM:
|
||||||
|
return self._random_selection(available_agents)
|
||||||
|
else:
|
||||||
|
return available_agents[0]
|
||||||
|
|
||||||
|
def _round_robin_selection(self, agents: List[str]) -> str:
|
||||||
|
"""Round-robin agent selection"""
|
||||||
|
agent = agents[self.load_balancer_index % len(agents)]
|
||||||
|
self.load_balancer_index += 1
|
||||||
|
return agent
|
||||||
|
|
||||||
|
def _load_balanced_selection(self, agents: List[str]) -> str:
|
||||||
|
"""Load-balanced agent selection"""
|
||||||
|
# Select agent with lowest load
|
||||||
|
min_load = float('inf')
|
||||||
|
selected_agent = None
|
||||||
|
|
||||||
|
for agent in agents:
|
||||||
|
load = self.agent_loads.get(agent, 0.0)
|
||||||
|
weight = self.agent_weights.get(agent, 1.0)
|
||||||
|
weighted_load = load / weight
|
||||||
|
|
||||||
|
if weighted_load < min_load:
|
||||||
|
min_load = weighted_load
|
||||||
|
selected_agent = agent
|
||||||
|
|
||||||
|
return selected_agent or agents[0]
|
||||||
|
|
||||||
|
def _priority_based_selection(self, agents: List[str]) -> str:
|
||||||
|
"""Priority-based agent selection"""
|
||||||
|
# Sort by weight (higher weight = higher priority)
|
||||||
|
weighted_agents = sorted(
|
||||||
|
agents,
|
||||||
|
key=lambda a: self.agent_weights.get(a, 1.0),
|
||||||
|
reverse=True
|
||||||
|
)
|
||||||
|
return weighted_agents[0]
|
||||||
|
|
||||||
|
def _random_selection(self, agents: List[str]) -> str:
|
||||||
|
"""Random agent selection"""
|
||||||
|
import random
|
||||||
|
return random.choice(agents)
|
||||||
|
|
||||||
|
class MessageQueue:
|
||||||
|
"""Advanced message queue with priority and persistence"""
|
||||||
|
|
||||||
|
def __init__(self, max_size: int = 10000):
|
||||||
|
self.max_size = max_size
|
||||||
|
self.queues: Dict[Priority, asyncio.Queue] = {
|
||||||
|
Priority.CRITICAL: asyncio.Queue(maxsize=max_size // 4),
|
||||||
|
Priority.HIGH: asyncio.Queue(maxsize=max_size // 4),
|
||||||
|
Priority.NORMAL: asyncio.Queue(maxsize // 2),
|
||||||
|
Priority.LOW: asyncio.Queue(maxsize // 4)
|
||||||
|
}
|
||||||
|
self.message_store: Dict[str, AgentMessage] = {}
|
||||||
|
self.delivery_confirmations: Dict[str, bool] = {}
|
||||||
|
|
||||||
|
async def enqueue(self, message: AgentMessage, delivery_mode: DeliveryMode = DeliveryMode.AT_LEAST_ONCE) -> bool:
|
||||||
|
"""Enqueue message with priority"""
|
||||||
|
try:
|
||||||
|
# Store message for persistence
|
||||||
|
if delivery_mode in [DeliveryMode.AT_LEAST_ONCE, DeliveryMode.EXACTLY_ONCE, DeliveryMode.PERSISTENT]:
|
||||||
|
self.message_store[message.id] = message
|
||||||
|
|
||||||
|
# Add to appropriate priority queue
|
||||||
|
queue = self.queues[message.priority]
|
||||||
|
await queue.put(message)
|
||||||
|
|
||||||
|
logger.debug(f"Enqueued message {message.id} with priority {message.priority}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
logger.error(f"Queue full, cannot enqueue message {message.id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def dequeue(self) -> Optional[AgentMessage]:
|
||||||
|
"""Dequeue message with priority order"""
|
||||||
|
# Check queues in priority order
|
||||||
|
for priority in [Priority.CRITICAL, Priority.HIGH, Priority.NORMAL, Priority.LOW]:
|
||||||
|
queue = self.queues[priority]
|
||||||
|
try:
|
||||||
|
message = queue.get_nowait()
|
||||||
|
logger.debug(f"Dequeued message {message.id} with priority {priority}")
|
||||||
|
return message
|
||||||
|
except asyncio.QueueEmpty:
|
||||||
|
continue
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def confirm_delivery(self, message_id: str):
|
||||||
|
"""Confirm message delivery"""
|
||||||
|
self.delivery_confirmations[message_id] = True
|
||||||
|
|
||||||
|
# Clean up if exactly once delivery
|
||||||
|
if message_id in self.message_store:
|
||||||
|
del self.message_store[message_id]
|
||||||
|
|
||||||
|
def get_queue_stats(self) -> Dict[str, Any]:
|
||||||
|
"""Get queue statistics"""
|
||||||
|
return {
|
||||||
|
"queue_sizes": {
|
||||||
|
priority.value: queue.qsize()
|
||||||
|
for priority, queue in self.queues.items()
|
||||||
|
},
|
||||||
|
"stored_messages": len(self.message_store),
|
||||||
|
"delivery_confirmations": len(self.delivery_confirmations),
|
||||||
|
"max_size": self.max_size
|
||||||
|
}
|
||||||
|
|
||||||
|
class MessageProcessor:
|
||||||
|
"""Message processor with async handling"""
|
||||||
|
|
||||||
|
def __init__(self, agent_id: str):
|
||||||
|
self.agent_id = agent_id
|
||||||
|
self.router = MessageRouter(agent_id)
|
||||||
|
self.load_balancer = LoadBalancer()
|
||||||
|
self.message_queue = MessageQueue()
|
||||||
|
self.processors: Dict[str, Callable] = {}
|
||||||
|
self.processing_stats: Dict[str, Any] = {
|
||||||
|
"messages_processed": 0,
|
||||||
|
"processing_time_total": 0.0,
|
||||||
|
"errors": 0
|
||||||
|
}
|
||||||
|
|
||||||
|
def register_processor(self, message_type: MessageType, processor: Callable):
|
||||||
|
"""Register message processor"""
|
||||||
|
self.processors[message_type.value] = processor
|
||||||
|
logger.info(f"Registered processor for {message_type.value}")
|
||||||
|
|
||||||
|
async def process_message(self, message: AgentMessage) -> bool:
|
||||||
|
"""Process a message"""
|
||||||
|
start_time = datetime.utcnow()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Route message
|
||||||
|
route = await self.router.route_message(message)
|
||||||
|
if not route:
|
||||||
|
logger.warning(f"No route found for message {message.id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Process message
|
||||||
|
processor = self.processors.get(message.message_type.value)
|
||||||
|
if processor:
|
||||||
|
await processor(message)
|
||||||
|
else:
|
||||||
|
logger.warning(f"No processor found for {message.message_type.value}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Update stats
|
||||||
|
self.processing_stats["messages_processed"] += 1
|
||||||
|
processing_time = (datetime.utcnow() - start_time).total_seconds()
|
||||||
|
self.processing_stats["processing_time_total"] += processing_time
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing message {message.id}: {e}")
|
||||||
|
self.processing_stats["errors"] += 1
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def start_processing(self):
|
||||||
|
"""Start message processing loop"""
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# Dequeue message
|
||||||
|
message = await self.message_queue.dequeue()
|
||||||
|
if message:
|
||||||
|
await self.process_message(message)
|
||||||
|
else:
|
||||||
|
await asyncio.sleep(0.01) # Small delay if no messages
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in processing loop: {e}")
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
def get_processing_stats(self) -> Dict[str, Any]:
|
||||||
|
"""Get processing statistics"""
|
||||||
|
total_processed = self.processing_stats["messages_processed"]
|
||||||
|
avg_processing_time = (
|
||||||
|
self.processing_stats["processing_time_total"] / total_processed
|
||||||
|
if total_processed > 0 else 0
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
**self.processing_stats,
|
||||||
|
"avg_processing_time": avg_processing_time,
|
||||||
|
"queue_stats": self.message_queue.get_queue_stats(),
|
||||||
|
"routing_stats": self.router.get_routing_stats()
|
||||||
|
}
|
||||||
|
|
||||||
|
# Factory functions for creating message types
|
||||||
|
def create_task_message(sender_id: str, receiver_id: str, task_type: str, task_data: Dict[str, Any]) -> AgentMessage:
|
||||||
|
"""Create a task message"""
|
||||||
|
task_msg = TaskMessage(
|
||||||
|
task_id=str(uuid.uuid4()),
|
||||||
|
task_type=task_type,
|
||||||
|
task_data=task_data
|
||||||
|
)
|
||||||
|
|
||||||
|
return AgentMessage(
|
||||||
|
sender_id=sender_id,
|
||||||
|
receiver_id=receiver_id,
|
||||||
|
message_type=MessageType.TASK_ASSIGNMENT,
|
||||||
|
payload=task_msg.dict()
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_coordination_message(sender_id: str, coordination_type: str, participants: List[str], data: Dict[str, Any]) -> AgentMessage:
|
||||||
|
"""Create a coordination message"""
|
||||||
|
coord_msg = CoordinationMessage(
|
||||||
|
coordination_id=str(uuid.uuid4()),
|
||||||
|
coordination_type=coordination_type,
|
||||||
|
participants=participants,
|
||||||
|
coordination_data=data
|
||||||
|
)
|
||||||
|
|
||||||
|
return AgentMessage(
|
||||||
|
sender_id=sender_id,
|
||||||
|
message_type=MessageType.COORDINATION,
|
||||||
|
payload=coord_msg.dict()
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_status_message(agent_id: str, status_type: str, status_data: Dict[str, Any]) -> AgentMessage:
|
||||||
|
"""Create a status message"""
|
||||||
|
status_msg = StatusMessage(
|
||||||
|
agent_id=agent_id,
|
||||||
|
status_type=status_type,
|
||||||
|
status_data=status_data
|
||||||
|
)
|
||||||
|
|
||||||
|
return AgentMessage(
|
||||||
|
sender_id=agent_id,
|
||||||
|
message_type=MessageType.STATUS_UPDATE,
|
||||||
|
payload=status_msg.dict()
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_discovery_message(agent_id: str, agent_type: str, capabilities: List[str], services: List[str]) -> AgentMessage:
|
||||||
|
"""Create a discovery message"""
|
||||||
|
discovery_msg = DiscoveryMessage(
|
||||||
|
agent_id=agent_id,
|
||||||
|
agent_type=agent_type,
|
||||||
|
capabilities=capabilities,
|
||||||
|
services=services
|
||||||
|
)
|
||||||
|
|
||||||
|
return AgentMessage(
|
||||||
|
sender_id=agent_id,
|
||||||
|
message_type=MessageType.DISCOVERY,
|
||||||
|
payload=discovery_msg.dict()
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_consensus_message(sender_id: str, proposal: Dict[str, Any], voting_options: List[Dict[str, Any]], deadline: datetime) -> AgentMessage:
|
||||||
|
"""Create a consensus message"""
|
||||||
|
consensus_msg = ConsensusMessage(
|
||||||
|
consensus_id=str(uuid.uuid4()),
|
||||||
|
proposal=proposal,
|
||||||
|
voting_options=voting_options,
|
||||||
|
voting_deadline=deadline
|
||||||
|
)
|
||||||
|
|
||||||
|
return AgentMessage(
|
||||||
|
sender_id=sender_id,
|
||||||
|
message_type=MessageType.CONSENSUS,
|
||||||
|
payload=consensus_msg.dict()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Example usage
|
||||||
|
async def example_usage():
|
||||||
|
"""Example of how to use the message routing system"""
|
||||||
|
|
||||||
|
# Create message processor
|
||||||
|
processor = MessageProcessor("agent-001")
|
||||||
|
|
||||||
|
# Register processors
|
||||||
|
async def process_task(message: AgentMessage):
|
||||||
|
task_data = TaskMessage(**message.payload)
|
||||||
|
logger.info(f"Processing task: {task_data.task_id}")
|
||||||
|
|
||||||
|
processor.register_processor(MessageType.TASK_ASSIGNMENT, process_task)
|
||||||
|
|
||||||
|
# Create and route message
|
||||||
|
task_message = create_task_message(
|
||||||
|
sender_id="agent-001",
|
||||||
|
receiver_id="agent-002",
|
||||||
|
task_type="data_processing",
|
||||||
|
task_data={"input": "test_data"}
|
||||||
|
)
|
||||||
|
|
||||||
|
await processor.message_queue.enqueue(task_message)
|
||||||
|
|
||||||
|
# Start processing (in real implementation, this would run in background)
|
||||||
|
# await processor.start_processing()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(example_usage())
|
||||||
641
apps/agent-coordinator/src/app/routing/agent_discovery.py
Normal file
641
apps/agent-coordinator/src/app/routing/agent_discovery.py
Normal file
@@ -0,0 +1,641 @@
|
|||||||
|
"""
|
||||||
|
Agent Discovery and Registration System for AITBC Agent Coordination
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List, Optional, Set, Callable, Any
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
import uuid
|
||||||
|
import hashlib
|
||||||
|
from enum import Enum
|
||||||
|
import redis.asyncio as redis
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from ..protocols.message_types import DiscoveryMessage, create_discovery_message
|
||||||
|
from ..protocols.communication import AgentMessage, MessageType
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class AgentStatus(str, Enum):
|
||||||
|
"""Agent status enumeration"""
|
||||||
|
ACTIVE = "active"
|
||||||
|
INACTIVE = "inactive"
|
||||||
|
BUSY = "busy"
|
||||||
|
MAINTENANCE = "maintenance"
|
||||||
|
ERROR = "error"
|
||||||
|
|
||||||
|
class AgentType(str, Enum):
|
||||||
|
"""Agent type enumeration"""
|
||||||
|
COORDINATOR = "coordinator"
|
||||||
|
WORKER = "worker"
|
||||||
|
SPECIALIST = "specialist"
|
||||||
|
MONITOR = "monitor"
|
||||||
|
GATEWAY = "gateway"
|
||||||
|
ORCHESTRATOR = "orchestrator"
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AgentInfo:
|
||||||
|
"""Agent information structure"""
|
||||||
|
agent_id: str
|
||||||
|
agent_type: AgentType
|
||||||
|
status: AgentStatus
|
||||||
|
capabilities: List[str]
|
||||||
|
services: List[str]
|
||||||
|
endpoints: Dict[str, str]
|
||||||
|
metadata: Dict[str, Any]
|
||||||
|
last_heartbeat: datetime
|
||||||
|
registration_time: datetime
|
||||||
|
load_metrics: Dict[str, float] = field(default_factory=dict)
|
||||||
|
health_score: float = 1.0
|
||||||
|
version: str = "1.0.0"
|
||||||
|
tags: Set[str] = field(default_factory=set)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary"""
|
||||||
|
return {
|
||||||
|
"agent_id": self.agent_id,
|
||||||
|
"agent_type": self.agent_type.value,
|
||||||
|
"status": self.status.value,
|
||||||
|
"capabilities": self.capabilities,
|
||||||
|
"services": self.services,
|
||||||
|
"endpoints": self.endpoints,
|
||||||
|
"metadata": self.metadata,
|
||||||
|
"last_heartbeat": self.last_heartbeat.isoformat(),
|
||||||
|
"registration_time": self.registration_time.isoformat(),
|
||||||
|
"load_metrics": self.load_metrics,
|
||||||
|
"health_score": self.health_score,
|
||||||
|
"version": self.version,
|
||||||
|
"tags": list(self.tags)
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> "AgentInfo":
|
||||||
|
"""Create from dictionary"""
|
||||||
|
data["agent_type"] = AgentType(data["agent_type"])
|
||||||
|
data["status"] = AgentStatus(data["status"])
|
||||||
|
data["last_heartbeat"] = datetime.fromisoformat(data["last_heartbeat"])
|
||||||
|
data["registration_time"] = datetime.fromisoformat(data["registration_time"])
|
||||||
|
data["tags"] = set(data.get("tags", []))
|
||||||
|
return cls(**data)
|
||||||
|
|
||||||
|
class AgentRegistry:
|
||||||
|
"""Central agent registry for discovery and management"""
|
||||||
|
|
||||||
|
def __init__(self, redis_url: str = "redis://localhost:6379/1"):
|
||||||
|
self.redis_url = redis_url
|
||||||
|
self.redis_client: Optional[redis.Redis] = None
|
||||||
|
self.agents: Dict[str, AgentInfo] = {}
|
||||||
|
self.service_index: Dict[str, Set[str]] = {} # service -> agent_ids
|
||||||
|
self.capability_index: Dict[str, Set[str]] = {} # capability -> agent_ids
|
||||||
|
self.type_index: Dict[AgentType, Set[str]] = {} # agent_type -> agent_ids
|
||||||
|
self.heartbeat_interval = 30 # seconds
|
||||||
|
self.cleanup_interval = 60 # seconds
|
||||||
|
self.max_heartbeat_age = 120 # seconds
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
"""Start the registry service"""
|
||||||
|
self.redis_client = redis.from_url(self.redis_url)
|
||||||
|
|
||||||
|
# Load existing agents from Redis
|
||||||
|
await self._load_agents_from_redis()
|
||||||
|
|
||||||
|
# Start background tasks
|
||||||
|
asyncio.create_task(self._heartbeat_monitor())
|
||||||
|
asyncio.create_task(self._cleanup_inactive_agents())
|
||||||
|
|
||||||
|
logger.info("Agent registry started")
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
"""Stop the registry service"""
|
||||||
|
if self.redis_client:
|
||||||
|
await self.redis_client.close()
|
||||||
|
logger.info("Agent registry stopped")
|
||||||
|
|
||||||
|
async def register_agent(self, agent_info: AgentInfo) -> bool:
|
||||||
|
"""Register a new agent"""
|
||||||
|
try:
|
||||||
|
# Add to local registry
|
||||||
|
self.agents[agent_info.agent_id] = agent_info
|
||||||
|
|
||||||
|
# Update indexes
|
||||||
|
self._update_indexes(agent_info)
|
||||||
|
|
||||||
|
# Save to Redis
|
||||||
|
await self._save_agent_to_redis(agent_info)
|
||||||
|
|
||||||
|
# Publish registration event
|
||||||
|
await self._publish_agent_event("agent_registered", agent_info)
|
||||||
|
|
||||||
|
logger.info(f"Agent {agent_info.agent_id} registered successfully")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error registering agent {agent_info.agent_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def unregister_agent(self, agent_id: str) -> bool:
|
||||||
|
"""Unregister an agent"""
|
||||||
|
try:
|
||||||
|
if agent_id not in self.agents:
|
||||||
|
logger.warning(f"Agent {agent_id} not found for unregistration")
|
||||||
|
return False
|
||||||
|
|
||||||
|
agent_info = self.agents[agent_id]
|
||||||
|
|
||||||
|
# Remove from local registry
|
||||||
|
del self.agents[agent_id]
|
||||||
|
|
||||||
|
# Update indexes
|
||||||
|
self._remove_from_indexes(agent_info)
|
||||||
|
|
||||||
|
# Remove from Redis
|
||||||
|
await self._remove_agent_from_redis(agent_id)
|
||||||
|
|
||||||
|
# Publish unregistration event
|
||||||
|
await self._publish_agent_event("agent_unregistered", agent_info)
|
||||||
|
|
||||||
|
logger.info(f"Agent {agent_id} unregistered successfully")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error unregistering agent {agent_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def update_agent_status(self, agent_id: str, status: AgentStatus, load_metrics: Optional[Dict[str, float]] = None) -> bool:
|
||||||
|
"""Update agent status and metrics"""
|
||||||
|
try:
|
||||||
|
if agent_id not in self.agents:
|
||||||
|
logger.warning(f"Agent {agent_id} not found for status update")
|
||||||
|
return False
|
||||||
|
|
||||||
|
agent_info = self.agents[agent_id]
|
||||||
|
agent_info.status = status
|
||||||
|
agent_info.last_heartbeat = datetime.utcnow()
|
||||||
|
|
||||||
|
if load_metrics:
|
||||||
|
agent_info.load_metrics.update(load_metrics)
|
||||||
|
|
||||||
|
# Update health score
|
||||||
|
agent_info.health_score = self._calculate_health_score(agent_info)
|
||||||
|
|
||||||
|
# Save to Redis
|
||||||
|
await self._save_agent_to_redis(agent_info)
|
||||||
|
|
||||||
|
# Publish status update event
|
||||||
|
await self._publish_agent_event("agent_status_updated", agent_info)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating agent status {agent_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def update_agent_heartbeat(self, agent_id: str) -> bool:
|
||||||
|
"""Update agent heartbeat"""
|
||||||
|
try:
|
||||||
|
if agent_id not in self.agents:
|
||||||
|
logger.warning(f"Agent {agent_id} not found for heartbeat")
|
||||||
|
return False
|
||||||
|
|
||||||
|
agent_info = self.agents[agent_id]
|
||||||
|
agent_info.last_heartbeat = datetime.utcnow()
|
||||||
|
|
||||||
|
# Update health score
|
||||||
|
agent_info.health_score = self._calculate_health_score(agent_info)
|
||||||
|
|
||||||
|
# Save to Redis
|
||||||
|
await self._save_agent_to_redis(agent_info)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating heartbeat for {agent_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def discover_agents(self, query: Dict[str, Any]) -> List[AgentInfo]:
|
||||||
|
"""Discover agents based on query criteria"""
|
||||||
|
results = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Start with all agents
|
||||||
|
candidate_agents = list(self.agents.values())
|
||||||
|
|
||||||
|
# Apply filters
|
||||||
|
if "agent_type" in query:
|
||||||
|
agent_type = AgentType(query["agent_type"])
|
||||||
|
candidate_agents = [a for a in candidate_agents if a.agent_type == agent_type]
|
||||||
|
|
||||||
|
if "status" in query:
|
||||||
|
status = AgentStatus(query["status"])
|
||||||
|
candidate_agents = [a for a in candidate_agents if a.status == status]
|
||||||
|
|
||||||
|
if "capabilities" in query:
|
||||||
|
required_capabilities = set(query["capabilities"])
|
||||||
|
candidate_agents = [a for a in candidate_agents if required_capabilities.issubset(a.capabilities)]
|
||||||
|
|
||||||
|
if "services" in query:
|
||||||
|
required_services = set(query["services"])
|
||||||
|
candidate_agents = [a for a in candidate_agents if required_services.issubset(a.services)]
|
||||||
|
|
||||||
|
if "tags" in query:
|
||||||
|
required_tags = set(query["tags"])
|
||||||
|
candidate_agents = [a for a in candidate_agents if required_tags.issubset(a.tags)]
|
||||||
|
|
||||||
|
if "min_health_score" in query:
|
||||||
|
min_score = query["min_health_score"]
|
||||||
|
candidate_agents = [a for a in candidate_agents if a.health_score >= min_score]
|
||||||
|
|
||||||
|
# Sort by health score (highest first)
|
||||||
|
results = sorted(candidate_agents, key=lambda a: a.health_score, reverse=True)
|
||||||
|
|
||||||
|
# Limit results if specified
|
||||||
|
if "limit" in query:
|
||||||
|
results = results[:query["limit"]]
|
||||||
|
|
||||||
|
logger.info(f"Discovered {len(results)} agents for query: {query}")
|
||||||
|
return results
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error discovering agents: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_agent_by_id(self, agent_id: str) -> Optional[AgentInfo]:
|
||||||
|
"""Get agent information by ID"""
|
||||||
|
return self.agents.get(agent_id)
|
||||||
|
|
||||||
|
async def get_agents_by_service(self, service: str) -> List[AgentInfo]:
|
||||||
|
"""Get agents that provide a specific service"""
|
||||||
|
agent_ids = self.service_index.get(service, set())
|
||||||
|
return [self.agents[agent_id] for agent_id in agent_ids if agent_id in self.agents]
|
||||||
|
|
||||||
|
async def get_agents_by_capability(self, capability: str) -> List[AgentInfo]:
|
||||||
|
"""Get agents that have a specific capability"""
|
||||||
|
agent_ids = self.capability_index.get(capability, set())
|
||||||
|
return [self.agents[agent_id] for agent_id in agent_ids if agent_id in self.agents]
|
||||||
|
|
||||||
|
async def get_agents_by_type(self, agent_type: AgentType) -> List[AgentInfo]:
|
||||||
|
"""Get agents of a specific type"""
|
||||||
|
agent_ids = self.type_index.get(agent_type, set())
|
||||||
|
return [self.agents[agent_id] for agent_id in agent_ids if agent_id in self.agents]
|
||||||
|
|
||||||
|
async def get_registry_stats(self) -> Dict[str, Any]:
|
||||||
|
"""Get registry statistics"""
|
||||||
|
total_agents = len(self.agents)
|
||||||
|
status_counts = {}
|
||||||
|
type_counts = {}
|
||||||
|
|
||||||
|
for agent_info in self.agents.values():
|
||||||
|
# Count by status
|
||||||
|
status = agent_info.status.value
|
||||||
|
status_counts[status] = status_counts.get(status, 0) + 1
|
||||||
|
|
||||||
|
# Count by type
|
||||||
|
agent_type = agent_info.agent_type.value
|
||||||
|
type_counts[agent_type] = type_counts.get(agent_type, 0) + 1
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_agents": total_agents,
|
||||||
|
"status_counts": status_counts,
|
||||||
|
"type_counts": type_counts,
|
||||||
|
"service_count": len(self.service_index),
|
||||||
|
"capability_count": len(self.capability_index),
|
||||||
|
"last_cleanup": datetime.utcnow().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
def _update_indexes(self, agent_info: AgentInfo):
|
||||||
|
"""Update search indexes"""
|
||||||
|
# Service index
|
||||||
|
for service in agent_info.services:
|
||||||
|
if service not in self.service_index:
|
||||||
|
self.service_index[service] = set()
|
||||||
|
self.service_index[service].add(agent_info.agent_id)
|
||||||
|
|
||||||
|
# Capability index
|
||||||
|
for capability in agent_info.capabilities:
|
||||||
|
if capability not in self.capability_index:
|
||||||
|
self.capability_index[capability] = set()
|
||||||
|
self.capability_index[capability].add(agent_info.agent_id)
|
||||||
|
|
||||||
|
# Type index
|
||||||
|
if agent_info.agent_type not in self.type_index:
|
||||||
|
self.type_index[agent_info.agent_type] = set()
|
||||||
|
self.type_index[agent_info.agent_type].add(agent_info.agent_id)
|
||||||
|
|
||||||
|
def _remove_from_indexes(self, agent_info: AgentInfo):
|
||||||
|
"""Remove agent from search indexes"""
|
||||||
|
# Service index
|
||||||
|
for service in agent_info.services:
|
||||||
|
if service in self.service_index:
|
||||||
|
self.service_index[service].discard(agent_info.agent_id)
|
||||||
|
if not self.service_index[service]:
|
||||||
|
del self.service_index[service]
|
||||||
|
|
||||||
|
# Capability index
|
||||||
|
for capability in agent_info.capabilities:
|
||||||
|
if capability in self.capability_index:
|
||||||
|
self.capability_index[capability].discard(agent_info.agent_id)
|
||||||
|
if not self.capability_index[capability]:
|
||||||
|
del self.capability_index[capability]
|
||||||
|
|
||||||
|
# Type index
|
||||||
|
if agent_info.agent_type in self.type_index:
|
||||||
|
self.type_index[agent_info.agent_type].discard(agent_info.agent_id)
|
||||||
|
if not self.type_index[agent_info.agent_type]:
|
||||||
|
del self.type_index[agent_info.agent_type]
|
||||||
|
|
||||||
|
def _calculate_health_score(self, agent_info: AgentInfo) -> float:
|
||||||
|
"""Calculate agent health score"""
|
||||||
|
base_score = 1.0
|
||||||
|
|
||||||
|
# Penalty for high load
|
||||||
|
if agent_info.load_metrics:
|
||||||
|
avg_load = sum(agent_info.load_metrics.values()) / len(agent_info.load_metrics)
|
||||||
|
if avg_load > 0.8:
|
||||||
|
base_score -= 0.3
|
||||||
|
elif avg_load > 0.6:
|
||||||
|
base_score -= 0.1
|
||||||
|
|
||||||
|
# Penalty for error status
|
||||||
|
if agent_info.status == AgentStatus.ERROR:
|
||||||
|
base_score -= 0.5
|
||||||
|
elif agent_info.status == AgentStatus.MAINTENANCE:
|
||||||
|
base_score -= 0.2
|
||||||
|
elif agent_info.status == AgentStatus.BUSY:
|
||||||
|
base_score -= 0.1
|
||||||
|
|
||||||
|
# Penalty for old heartbeat
|
||||||
|
heartbeat_age = (datetime.utcnow() - agent_info.last_heartbeat).total_seconds()
|
||||||
|
if heartbeat_age > self.max_heartbeat_age:
|
||||||
|
base_score -= 0.5
|
||||||
|
elif heartbeat_age > self.max_heartbeat_age / 2:
|
||||||
|
base_score -= 0.2
|
||||||
|
|
||||||
|
return max(0.0, min(1.0, base_score))
|
||||||
|
|
||||||
|
async def _save_agent_to_redis(self, agent_info: AgentInfo):
|
||||||
|
"""Save agent information to Redis"""
|
||||||
|
if not self.redis_client:
|
||||||
|
return
|
||||||
|
|
||||||
|
key = f"agent:{agent_info.agent_id}"
|
||||||
|
await self.redis_client.setex(
|
||||||
|
key,
|
||||||
|
timedelta(hours=24), # 24 hour TTL
|
||||||
|
json.dumps(agent_info.to_dict())
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _remove_agent_from_redis(self, agent_id: str):
|
||||||
|
"""Remove agent from Redis"""
|
||||||
|
if not self.redis_client:
|
||||||
|
return
|
||||||
|
|
||||||
|
key = f"agent:{agent_id}"
|
||||||
|
await self.redis_client.delete(key)
|
||||||
|
|
||||||
|
async def _load_agents_from_redis(self):
|
||||||
|
"""Load agents from Redis"""
|
||||||
|
if not self.redis_client:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get all agent keys
|
||||||
|
keys = await self.redis_client.keys("agent:*")
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
data = await self.redis_client.get(key)
|
||||||
|
if data:
|
||||||
|
agent_info = AgentInfo.from_dict(json.loads(data))
|
||||||
|
self.agents[agent_info.agent_id] = agent_info
|
||||||
|
self._update_indexes(agent_info)
|
||||||
|
|
||||||
|
logger.info(f"Loaded {len(self.agents)} agents from Redis")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading agents from Redis: {e}")
|
||||||
|
|
||||||
|
async def _publish_agent_event(self, event_type: str, agent_info: AgentInfo):
|
||||||
|
"""Publish agent event to Redis"""
|
||||||
|
if not self.redis_client:
|
||||||
|
return
|
||||||
|
|
||||||
|
event = {
|
||||||
|
"event_type": event_type,
|
||||||
|
"timestamp": datetime.utcnow().isoformat(),
|
||||||
|
"agent_info": agent_info.to_dict()
|
||||||
|
}
|
||||||
|
|
||||||
|
await self.redis_client.publish("agent_events", json.dumps(event))
|
||||||
|
|
||||||
|
async def _heartbeat_monitor(self):
|
||||||
|
"""Monitor agent heartbeats"""
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(self.heartbeat_interval)
|
||||||
|
|
||||||
|
# Check for agents with old heartbeats
|
||||||
|
now = datetime.utcnow()
|
||||||
|
for agent_id, agent_info in list(self.agents.items()):
|
||||||
|
heartbeat_age = (now - agent_info.last_heartbeat).total_seconds()
|
||||||
|
|
||||||
|
if heartbeat_age > self.max_heartbeat_age:
|
||||||
|
# Mark as inactive
|
||||||
|
if agent_info.status != AgentStatus.INACTIVE:
|
||||||
|
await self.update_agent_status(agent_id, AgentStatus.INACTIVE)
|
||||||
|
logger.warning(f"Agent {agent_id} marked as inactive due to old heartbeat")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in heartbeat monitor: {e}")
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
|
||||||
|
async def _cleanup_inactive_agents(self):
|
||||||
|
"""Clean up inactive agents"""
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(self.cleanup_interval)
|
||||||
|
|
||||||
|
# Remove agents that have been inactive too long
|
||||||
|
now = datetime.utcnow()
|
||||||
|
max_inactive_age = timedelta(hours=1) # 1 hour
|
||||||
|
|
||||||
|
for agent_id, agent_info in list(self.agents.items()):
|
||||||
|
if agent_info.status == AgentStatus.INACTIVE:
|
||||||
|
inactive_age = now - agent_info.last_heartbeat
|
||||||
|
if inactive_age > max_inactive_age:
|
||||||
|
await self.unregister_agent(agent_id)
|
||||||
|
logger.info(f"Removed inactive agent {agent_id}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in cleanup task: {e}")
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
|
||||||
|
class AgentDiscoveryService:
|
||||||
|
"""Service for agent discovery and registration"""
|
||||||
|
|
||||||
|
def __init__(self, registry: AgentRegistry):
|
||||||
|
self.registry = registry
|
||||||
|
self.discovery_handlers: Dict[str, Callable] = {}
|
||||||
|
|
||||||
|
def register_discovery_handler(self, handler_name: str, handler: Callable):
|
||||||
|
"""Register a discovery handler"""
|
||||||
|
self.discovery_handlers[handler_name] = handler
|
||||||
|
logger.info(f"Registered discovery handler: {handler_name}")
|
||||||
|
|
||||||
|
async def handle_discovery_request(self, message: AgentMessage) -> Optional[AgentMessage]:
|
||||||
|
"""Handle agent discovery request"""
|
||||||
|
try:
|
||||||
|
discovery_data = DiscoveryMessage(**message.payload)
|
||||||
|
|
||||||
|
# Update or register agent
|
||||||
|
agent_info = AgentInfo(
|
||||||
|
agent_id=discovery_data.agent_id,
|
||||||
|
agent_type=AgentType(discovery_data.agent_type),
|
||||||
|
status=AgentStatus.ACTIVE,
|
||||||
|
capabilities=discovery_data.capabilities,
|
||||||
|
services=discovery_data.services,
|
||||||
|
endpoints=discovery_data.endpoints,
|
||||||
|
metadata=discovery_data.metadata,
|
||||||
|
last_heartbeat=datetime.utcnow(),
|
||||||
|
registration_time=datetime.utcnow()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register or update agent
|
||||||
|
if discovery_data.agent_id in self.registry.agents:
|
||||||
|
await self.registry.update_agent_status(discovery_data.agent_id, AgentStatus.ACTIVE)
|
||||||
|
else:
|
||||||
|
await self.registry.register_agent(agent_info)
|
||||||
|
|
||||||
|
# Send response with available agents
|
||||||
|
available_agents = await self.registry.discover_agents({
|
||||||
|
"status": "active",
|
||||||
|
"limit": 50
|
||||||
|
})
|
||||||
|
|
||||||
|
response_data = {
|
||||||
|
"discovery_agents": [agent.to_dict() for agent in available_agents],
|
||||||
|
"registry_stats": await self.registry.get_registry_stats()
|
||||||
|
}
|
||||||
|
|
||||||
|
response = AgentMessage(
|
||||||
|
sender_id="discovery_service",
|
||||||
|
receiver_id=message.sender_id,
|
||||||
|
message_type=MessageType.DISCOVERY,
|
||||||
|
payload=response_data,
|
||||||
|
correlation_id=message.id
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error handling discovery request: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def find_best_agent(self, requirements: Dict[str, Any]) -> Optional[AgentInfo]:
|
||||||
|
"""Find the best agent for given requirements"""
|
||||||
|
try:
|
||||||
|
# Build discovery query
|
||||||
|
query = {}
|
||||||
|
|
||||||
|
if "agent_type" in requirements:
|
||||||
|
query["agent_type"] = requirements["agent_type"]
|
||||||
|
|
||||||
|
if "capabilities" in requirements:
|
||||||
|
query["capabilities"] = requirements["capabilities"]
|
||||||
|
|
||||||
|
if "services" in requirements:
|
||||||
|
query["services"] = requirements["services"]
|
||||||
|
|
||||||
|
if "min_health_score" in requirements:
|
||||||
|
query["min_health_score"] = requirements["min_health_score"]
|
||||||
|
|
||||||
|
# Discover agents
|
||||||
|
agents = await self.registry.discover_agents(query)
|
||||||
|
|
||||||
|
if not agents:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Select best agent (highest health score)
|
||||||
|
return agents[0]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error finding best agent: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_service_endpoints(self, service: str) -> Dict[str, List[str]]:
|
||||||
|
"""Get all endpoints for a specific service"""
|
||||||
|
try:
|
||||||
|
agents = await self.registry.get_agents_by_service(service)
|
||||||
|
endpoints = {}
|
||||||
|
|
||||||
|
for agent in agents:
|
||||||
|
for service_name, endpoint in agent.endpoints.items():
|
||||||
|
if service_name not in endpoints:
|
||||||
|
endpoints[service_name] = []
|
||||||
|
endpoints[service_name].append(endpoint)
|
||||||
|
|
||||||
|
return endpoints
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting service endpoints: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Factory functions
|
||||||
|
def create_agent_info(agent_id: str, agent_type: str, capabilities: List[str], services: List[str], endpoints: Dict[str, str]) -> AgentInfo:
|
||||||
|
"""Create agent information"""
|
||||||
|
return AgentInfo(
|
||||||
|
agent_id=agent_id,
|
||||||
|
agent_type=AgentType(agent_type),
|
||||||
|
status=AgentStatus.ACTIVE,
|
||||||
|
capabilities=capabilities,
|
||||||
|
services=services,
|
||||||
|
endpoints=endpoints,
|
||||||
|
metadata={},
|
||||||
|
last_heartbeat=datetime.utcnow(),
|
||||||
|
registration_time=datetime.utcnow()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Example usage
|
||||||
|
async def example_usage():
|
||||||
|
"""Example of how to use the agent discovery system"""
|
||||||
|
|
||||||
|
# Create registry
|
||||||
|
registry = AgentRegistry()
|
||||||
|
await registry.start()
|
||||||
|
|
||||||
|
# Create discovery service
|
||||||
|
discovery_service = AgentDiscoveryService(registry)
|
||||||
|
|
||||||
|
# Register an agent
|
||||||
|
agent_info = create_agent_info(
|
||||||
|
agent_id="agent-001",
|
||||||
|
agent_type="worker",
|
||||||
|
capabilities=["data_processing", "analysis"],
|
||||||
|
services=["process_data", "analyze_results"],
|
||||||
|
endpoints={"http": "http://localhost:8001", "ws": "ws://localhost:8002"}
|
||||||
|
)
|
||||||
|
|
||||||
|
await registry.register_agent(agent_info)
|
||||||
|
|
||||||
|
# Discover agents
|
||||||
|
agents = await registry.discover_agents({
|
||||||
|
"capabilities": ["data_processing"],
|
||||||
|
"status": "active"
|
||||||
|
})
|
||||||
|
|
||||||
|
print(f"Found {len(agents)} agents")
|
||||||
|
|
||||||
|
# Find best agent
|
||||||
|
best_agent = await discovery_service.find_best_agent({
|
||||||
|
"capabilities": ["data_processing"],
|
||||||
|
"min_health_score": 0.8
|
||||||
|
})
|
||||||
|
|
||||||
|
if best_agent:
|
||||||
|
print(f"Best agent: {best_agent.agent_id}")
|
||||||
|
|
||||||
|
await registry.stop()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(example_usage())
|
||||||
716
apps/agent-coordinator/src/app/routing/load_balancer.py
Normal file
716
apps/agent-coordinator/src/app/routing/load_balancer.py
Normal file
@@ -0,0 +1,716 @@
|
|||||||
|
"""
|
||||||
|
Load Balancer for Agent Distribution and Task Assignment
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List, Optional, Tuple, Any, Callable
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from enum import Enum
|
||||||
|
import statistics
|
||||||
|
import uuid
|
||||||
|
from collections import defaultdict, deque
|
||||||
|
|
||||||
|
from .agent_discovery import AgentRegistry, AgentInfo, AgentStatus, AgentType
|
||||||
|
from ..protocols.message_types import TaskMessage, create_task_message
|
||||||
|
from ..protocols.communication import AgentMessage, MessageType, Priority
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class LoadBalancingStrategy(str, Enum):
|
||||||
|
"""Load balancing strategies"""
|
||||||
|
ROUND_ROBIN = "round_robin"
|
||||||
|
LEAST_CONNECTIONS = "least_connections"
|
||||||
|
LEAST_RESPONSE_TIME = "least_response_time"
|
||||||
|
WEIGHTED_ROUND_ROBIN = "weighted_round_robin"
|
||||||
|
RESOURCE_BASED = "resource_based"
|
||||||
|
CAPABILITY_BASED = "capability_based"
|
||||||
|
PREDICTIVE = "predictive"
|
||||||
|
CONSISTENT_HASH = "consistent_hash"
|
||||||
|
|
||||||
|
class TaskPriority(str, Enum):
|
||||||
|
"""Task priority levels"""
|
||||||
|
LOW = "low"
|
||||||
|
NORMAL = "normal"
|
||||||
|
HIGH = "high"
|
||||||
|
CRITICAL = "critical"
|
||||||
|
URGENT = "urgent"
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LoadMetrics:
|
||||||
|
"""Agent load metrics"""
|
||||||
|
cpu_usage: float = 0.0
|
||||||
|
memory_usage: float = 0.0
|
||||||
|
active_connections: int = 0
|
||||||
|
pending_tasks: int = 0
|
||||||
|
completed_tasks: int = 0
|
||||||
|
failed_tasks: int = 0
|
||||||
|
avg_response_time: float = 0.0
|
||||||
|
last_updated: datetime = field(default_factory=datetime.utcnow)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"cpu_usage": self.cpu_usage,
|
||||||
|
"memory_usage": self.memory_usage,
|
||||||
|
"active_connections": self.active_connections,
|
||||||
|
"pending_tasks": self.pending_tasks,
|
||||||
|
"completed_tasks": self.completed_tasks,
|
||||||
|
"failed_tasks": self.failed_tasks,
|
||||||
|
"avg_response_time": self.avg_response_time,
|
||||||
|
"last_updated": self.last_updated.isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TaskAssignment:
|
||||||
|
"""Task assignment record"""
|
||||||
|
task_id: str
|
||||||
|
agent_id: str
|
||||||
|
assigned_at: datetime
|
||||||
|
completed_at: Optional[datetime] = None
|
||||||
|
status: str = "pending"
|
||||||
|
response_time: Optional[float] = None
|
||||||
|
success: bool = False
|
||||||
|
error_message: Optional[str] = None
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"task_id": self.task_id,
|
||||||
|
"agent_id": self.agent_id,
|
||||||
|
"assigned_at": self.assigned_at.isoformat(),
|
||||||
|
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
|
||||||
|
"status": self.status,
|
||||||
|
"response_time": self.response_time,
|
||||||
|
"success": self.success,
|
||||||
|
"error_message": self.error_message
|
||||||
|
}
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AgentWeight:
|
||||||
|
"""Agent weight for load balancing"""
|
||||||
|
agent_id: str
|
||||||
|
weight: float = 1.0
|
||||||
|
capacity: int = 100
|
||||||
|
performance_score: float = 1.0
|
||||||
|
reliability_score: float = 1.0
|
||||||
|
last_updated: datetime = field(default_factory=datetime.utcnow)
|
||||||
|
|
||||||
|
class LoadBalancer:
|
||||||
|
"""Advanced load balancer for agent distribution"""
|
||||||
|
|
||||||
|
def __init__(self, registry: AgentRegistry):
|
||||||
|
self.registry = registry
|
||||||
|
self.strategy = LoadBalancingStrategy.LEAST_CONNECTIONS
|
||||||
|
self.agent_weights: Dict[str, AgentWeight] = {}
|
||||||
|
self.agent_metrics: Dict[str, LoadMetrics] = {}
|
||||||
|
self.task_assignments: Dict[str, TaskAssignment] = {}
|
||||||
|
self.assignment_history: deque = deque(maxlen=1000)
|
||||||
|
self.round_robin_index = 0
|
||||||
|
self.consistent_hash_ring: Dict[int, str] = {}
|
||||||
|
self.prediction_models: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
# Statistics
|
||||||
|
self.total_assignments = 0
|
||||||
|
self.successful_assignments = 0
|
||||||
|
self.failed_assignments = 0
|
||||||
|
|
||||||
|
def set_strategy(self, strategy: LoadBalancingStrategy):
|
||||||
|
"""Set load balancing strategy"""
|
||||||
|
self.strategy = strategy
|
||||||
|
logger.info(f"Load balancing strategy changed to: {strategy.value}")
|
||||||
|
|
||||||
|
def set_agent_weight(self, agent_id: str, weight: float, capacity: int = 100):
|
||||||
|
"""Set agent weight and capacity"""
|
||||||
|
self.agent_weights[agent_id] = AgentWeight(
|
||||||
|
agent_id=agent_id,
|
||||||
|
weight=weight,
|
||||||
|
capacity=capacity
|
||||||
|
)
|
||||||
|
logger.info(f"Set weight for agent {agent_id}: {weight}, capacity: {capacity}")
|
||||||
|
|
||||||
|
def update_agent_metrics(self, agent_id: str, metrics: LoadMetrics):
|
||||||
|
"""Update agent load metrics"""
|
||||||
|
self.agent_metrics[agent_id] = metrics
|
||||||
|
self.agent_metrics[agent_id].last_updated = datetime.utcnow()
|
||||||
|
|
||||||
|
# Update performance score based on metrics
|
||||||
|
self._update_performance_score(agent_id, metrics)
|
||||||
|
|
||||||
|
def _update_performance_score(self, agent_id: str, metrics: LoadMetrics):
|
||||||
|
"""Update agent performance score based on metrics"""
|
||||||
|
if agent_id not in self.agent_weights:
|
||||||
|
self.agent_weights[agent_id] = AgentWeight(agent_id=agent_id)
|
||||||
|
|
||||||
|
weight = self.agent_weights[agent_id]
|
||||||
|
|
||||||
|
# Calculate performance score (0.0 to 1.0)
|
||||||
|
performance_factors = []
|
||||||
|
|
||||||
|
# CPU usage factor (lower is better)
|
||||||
|
cpu_factor = max(0.0, 1.0 - metrics.cpu_usage)
|
||||||
|
performance_factors.append(cpu_factor)
|
||||||
|
|
||||||
|
# Memory usage factor (lower is better)
|
||||||
|
memory_factor = max(0.0, 1.0 - metrics.memory_usage)
|
||||||
|
performance_factors.append(memory_factor)
|
||||||
|
|
||||||
|
# Response time factor (lower is better)
|
||||||
|
if metrics.avg_response_time > 0:
|
||||||
|
response_factor = max(0.0, 1.0 - (metrics.avg_response_time / 10.0)) # 10s max
|
||||||
|
performance_factors.append(response_factor)
|
||||||
|
|
||||||
|
# Success rate factor (higher is better)
|
||||||
|
total_tasks = metrics.completed_tasks + metrics.failed_tasks
|
||||||
|
if total_tasks > 0:
|
||||||
|
success_rate = metrics.completed_tasks / total_tasks
|
||||||
|
performance_factors.append(success_rate)
|
||||||
|
|
||||||
|
# Update performance score
|
||||||
|
if performance_factors:
|
||||||
|
weight.performance_score = statistics.mean(performance_factors)
|
||||||
|
|
||||||
|
# Update reliability score
|
||||||
|
if total_tasks > 10: # Only update after enough tasks
|
||||||
|
weight.reliability_score = success_rate
|
||||||
|
|
||||||
|
async def assign_task(self, task_data: Dict[str, Any], requirements: Optional[Dict[str, Any]] = None) -> Optional[str]:
|
||||||
|
"""Assign task to best available agent"""
|
||||||
|
try:
|
||||||
|
# Find eligible agents
|
||||||
|
eligible_agents = await self._find_eligible_agents(task_data, requirements)
|
||||||
|
|
||||||
|
if not eligible_agents:
|
||||||
|
logger.warning("No eligible agents found for task assignment")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Select best agent based on strategy
|
||||||
|
selected_agent = await self._select_agent(eligible_agents, task_data)
|
||||||
|
|
||||||
|
if not selected_agent:
|
||||||
|
logger.warning("No agent selected for task assignment")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Create task assignment
|
||||||
|
task_id = str(uuid.uuid4())
|
||||||
|
assignment = TaskAssignment(
|
||||||
|
task_id=task_id,
|
||||||
|
agent_id=selected_agent,
|
||||||
|
assigned_at=datetime.utcnow()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Record assignment
|
||||||
|
self.task_assignments[task_id] = assignment
|
||||||
|
self.assignment_history.append(assignment)
|
||||||
|
self.total_assignments += 1
|
||||||
|
|
||||||
|
# Update agent metrics
|
||||||
|
if selected_agent not in self.agent_metrics:
|
||||||
|
self.agent_metrics[selected_agent] = LoadMetrics()
|
||||||
|
|
||||||
|
self.agent_metrics[selected_agent].pending_tasks += 1
|
||||||
|
|
||||||
|
logger.info(f"Task {task_id} assigned to agent {selected_agent}")
|
||||||
|
return selected_agent
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error assigning task: {e}")
|
||||||
|
self.failed_assignments += 1
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def complete_task(self, task_id: str, success: bool, response_time: Optional[float] = None, error_message: Optional[str] = None):
|
||||||
|
"""Mark task as completed"""
|
||||||
|
try:
|
||||||
|
if task_id not in self.task_assignments:
|
||||||
|
logger.warning(f"Task assignment {task_id} not found")
|
||||||
|
return
|
||||||
|
|
||||||
|
assignment = self.task_assignments[task_id]
|
||||||
|
assignment.completed_at = datetime.utcnow()
|
||||||
|
assignment.status = "completed"
|
||||||
|
assignment.success = success
|
||||||
|
assignment.response_time = response_time
|
||||||
|
assignment.error_message = error_message
|
||||||
|
|
||||||
|
# Update agent metrics
|
||||||
|
agent_id = assignment.agent_id
|
||||||
|
if agent_id in self.agent_metrics:
|
||||||
|
metrics = self.agent_metrics[agent_id]
|
||||||
|
metrics.pending_tasks = max(0, metrics.pending_tasks - 1)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
metrics.completed_tasks += 1
|
||||||
|
self.successful_assignments += 1
|
||||||
|
else:
|
||||||
|
metrics.failed_tasks += 1
|
||||||
|
self.failed_assignments += 1
|
||||||
|
|
||||||
|
# Update average response time
|
||||||
|
if response_time:
|
||||||
|
total_completed = metrics.completed_tasks + metrics.failed_tasks
|
||||||
|
if total_completed > 0:
|
||||||
|
metrics.avg_response_time = (
|
||||||
|
(metrics.avg_response_time * (total_completed - 1) + response_time) / total_completed
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Task {task_id} completed by agent {assignment.agent_id}, success: {success}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error completing task {task_id}: {e}")
|
||||||
|
|
||||||
|
async def _find_eligible_agents(self, task_data: Dict[str, Any], requirements: Optional[Dict[str, Any]] = None) -> List[str]:
|
||||||
|
"""Find eligible agents for task"""
|
||||||
|
try:
|
||||||
|
# Build discovery query
|
||||||
|
query = {"status": AgentStatus.ACTIVE}
|
||||||
|
|
||||||
|
if requirements:
|
||||||
|
if "agent_type" in requirements:
|
||||||
|
query["agent_type"] = requirements["agent_type"]
|
||||||
|
|
||||||
|
if "capabilities" in requirements:
|
||||||
|
query["capabilities"] = requirements["capabilities"]
|
||||||
|
|
||||||
|
if "services" in requirements:
|
||||||
|
query["services"] = requirements["services"]
|
||||||
|
|
||||||
|
if "min_health_score" in requirements:
|
||||||
|
query["min_health_score"] = requirements["min_health_score"]
|
||||||
|
|
||||||
|
# Discover agents
|
||||||
|
agents = await self.registry.discover_agents(query)
|
||||||
|
|
||||||
|
# Filter by capacity and load
|
||||||
|
eligible_agents = []
|
||||||
|
for agent in agents:
|
||||||
|
agent_id = agent.agent_id
|
||||||
|
|
||||||
|
# Check capacity
|
||||||
|
if agent_id in self.agent_weights:
|
||||||
|
weight = self.agent_weights[agent_id]
|
||||||
|
current_load = self._get_agent_load(agent_id)
|
||||||
|
|
||||||
|
if current_load < weight.capacity:
|
||||||
|
eligible_agents.append(agent_id)
|
||||||
|
else:
|
||||||
|
# Default capacity check
|
||||||
|
metrics = self.agent_metrics.get(agent_id, LoadMetrics())
|
||||||
|
if metrics.pending_tasks < 100: # Default capacity
|
||||||
|
eligible_agents.append(agent_id)
|
||||||
|
|
||||||
|
return eligible_agents
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error finding eligible agents: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _get_agent_load(self, agent_id: str) -> int:
|
||||||
|
"""Get current load for agent"""
|
||||||
|
metrics = self.agent_metrics.get(agent_id, LoadMetrics())
|
||||||
|
return metrics.active_connections + metrics.pending_tasks
|
||||||
|
|
||||||
|
async def _select_agent(self, eligible_agents: List[str], task_data: Dict[str, Any]) -> Optional[str]:
|
||||||
|
"""Select best agent based on current strategy"""
|
||||||
|
if not eligible_agents:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if self.strategy == LoadBalancingStrategy.ROUND_ROBIN:
|
||||||
|
return self._round_robin_selection(eligible_agents)
|
||||||
|
elif self.strategy == LoadBalancingStrategy.LEAST_CONNECTIONS:
|
||||||
|
return self._least_connections_selection(eligible_agents)
|
||||||
|
elif self.strategy == LoadBalancingStrategy.LEAST_RESPONSE_TIME:
|
||||||
|
return self._least_response_time_selection(eligible_agents)
|
||||||
|
elif self.strategy == LoadBalancingStrategy.WEIGHTED_ROUND_ROBIN:
|
||||||
|
return self._weighted_round_robin_selection(eligible_agents)
|
||||||
|
elif self.strategy == LoadBalancingStrategy.RESOURCE_BASED:
|
||||||
|
return self._resource_based_selection(eligible_agents)
|
||||||
|
elif self.strategy == LoadBalancingStrategy.CAPABILITY_BASED:
|
||||||
|
return self._capability_based_selection(eligible_agents, task_data)
|
||||||
|
elif self.strategy == LoadBalancingStrategy.PREDICTIVE:
|
||||||
|
return self._predictive_selection(eligible_agents, task_data)
|
||||||
|
elif self.strategy == LoadBalancingStrategy.CONSISTENT_HASH:
|
||||||
|
return self._consistent_hash_selection(eligible_agents, task_data)
|
||||||
|
else:
|
||||||
|
return eligible_agents[0]
|
||||||
|
|
||||||
|
def _round_robin_selection(self, agents: List[str]) -> str:
|
||||||
|
"""Round-robin agent selection"""
|
||||||
|
agent = agents[self.round_robin_index % len(agents)]
|
||||||
|
self.round_robin_index += 1
|
||||||
|
return agent
|
||||||
|
|
||||||
|
def _least_connections_selection(self, agents: List[str]) -> str:
|
||||||
|
"""Select agent with least connections"""
|
||||||
|
min_connections = float('inf')
|
||||||
|
selected_agent = None
|
||||||
|
|
||||||
|
for agent_id in agents:
|
||||||
|
metrics = self.agent_metrics.get(agent_id, LoadMetrics())
|
||||||
|
connections = metrics.active_connections
|
||||||
|
|
||||||
|
if connections < min_connections:
|
||||||
|
min_connections = connections
|
||||||
|
selected_agent = agent_id
|
||||||
|
|
||||||
|
return selected_agent or agents[0]
|
||||||
|
|
||||||
|
def _least_response_time_selection(self, agents: List[str]) -> str:
|
||||||
|
"""Select agent with least average response time"""
|
||||||
|
min_response_time = float('inf')
|
||||||
|
selected_agent = None
|
||||||
|
|
||||||
|
for agent_id in agents:
|
||||||
|
metrics = self.agent_metrics.get(agent_id, LoadMetrics())
|
||||||
|
response_time = metrics.avg_response_time
|
||||||
|
|
||||||
|
if response_time < min_response_time:
|
||||||
|
min_response_time = response_time
|
||||||
|
selected_agent = agent_id
|
||||||
|
|
||||||
|
return selected_agent or agents[0]
|
||||||
|
|
||||||
|
def _weighted_round_robin_selection(self, agents: List[str]) -> str:
|
||||||
|
"""Weighted round-robin selection"""
|
||||||
|
# Calculate total weight
|
||||||
|
total_weight = 0
|
||||||
|
for agent_id in agents:
|
||||||
|
weight = self.agent_weights.get(agent_id, AgentWeight(agent_id=agent_id))
|
||||||
|
total_weight += weight.weight
|
||||||
|
|
||||||
|
if total_weight == 0:
|
||||||
|
return agents[0]
|
||||||
|
|
||||||
|
# Select agent based on weight
|
||||||
|
current_weight = self.round_robin_index % total_weight
|
||||||
|
accumulated_weight = 0
|
||||||
|
|
||||||
|
for agent_id in agents:
|
||||||
|
weight = self.agent_weights.get(agent_id, AgentWeight(agent_id=agent_id))
|
||||||
|
accumulated_weight += weight.weight
|
||||||
|
|
||||||
|
if current_weight < accumulated_weight:
|
||||||
|
self.round_robin_index += 1
|
||||||
|
return agent_id
|
||||||
|
|
||||||
|
return agents[0]
|
||||||
|
|
||||||
|
def _resource_based_selection(self, agents: List[str]) -> str:
|
||||||
|
"""Resource-based selection considering CPU and memory"""
|
||||||
|
best_score = -1
|
||||||
|
selected_agent = None
|
||||||
|
|
||||||
|
for agent_id in agents:
|
||||||
|
metrics = self.agent_metrics.get(agent_id, LoadMetrics())
|
||||||
|
|
||||||
|
# Calculate resource score (lower usage is better)
|
||||||
|
cpu_score = max(0, 100 - metrics.cpu_usage)
|
||||||
|
memory_score = max(0, 100 - metrics.memory_usage)
|
||||||
|
resource_score = (cpu_score + memory_score) / 2
|
||||||
|
|
||||||
|
# Apply performance weight
|
||||||
|
weight = self.agent_weights.get(agent_id, AgentWeight(agent_id=agent_id))
|
||||||
|
final_score = resource_score * weight.performance_score
|
||||||
|
|
||||||
|
if final_score > best_score:
|
||||||
|
best_score = final_score
|
||||||
|
selected_agent = agent_id
|
||||||
|
|
||||||
|
return selected_agent or agents[0]
|
||||||
|
|
||||||
|
def _capability_based_selection(self, agents: List[str], task_data: Dict[str, Any]) -> str:
|
||||||
|
"""Capability-based selection considering task requirements"""
|
||||||
|
required_capabilities = task_data.get("required_capabilities", [])
|
||||||
|
|
||||||
|
if not required_capabilities:
|
||||||
|
return agents[0]
|
||||||
|
|
||||||
|
best_score = -1
|
||||||
|
selected_agent = None
|
||||||
|
|
||||||
|
for agent_id in agents:
|
||||||
|
agent_info = self.registry.agents.get(agent_id)
|
||||||
|
if not agent_info:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Calculate capability match score
|
||||||
|
agent_capabilities = set(agent_info.capabilities)
|
||||||
|
required_set = set(required_capabilities)
|
||||||
|
|
||||||
|
if required_set.issubset(agent_capabilities):
|
||||||
|
# Perfect match
|
||||||
|
capability_score = 1.0
|
||||||
|
else:
|
||||||
|
# Partial match
|
||||||
|
intersection = required_set.intersection(agent_capabilities)
|
||||||
|
capability_score = len(intersection) / len(required_set)
|
||||||
|
|
||||||
|
# Apply performance weight
|
||||||
|
weight = self.agent_weights.get(agent_id, AgentWeight(agent_id=agent_id))
|
||||||
|
final_score = capability_score * weight.performance_score
|
||||||
|
|
||||||
|
if final_score > best_score:
|
||||||
|
best_score = final_score
|
||||||
|
selected_agent = agent_id
|
||||||
|
|
||||||
|
return selected_agent or agents[0]
|
||||||
|
|
||||||
|
def _predictive_selection(self, agents: List[str], task_data: Dict[str, Any]) -> str:
|
||||||
|
"""Predictive selection using historical performance"""
|
||||||
|
task_type = task_data.get("task_type", "unknown")
|
||||||
|
|
||||||
|
# Calculate predicted performance for each agent
|
||||||
|
best_score = -1
|
||||||
|
selected_agent = None
|
||||||
|
|
||||||
|
for agent_id in agents:
|
||||||
|
# Get historical performance for this task type
|
||||||
|
score = self._calculate_predicted_score(agent_id, task_type)
|
||||||
|
|
||||||
|
if score > best_score:
|
||||||
|
best_score = score
|
||||||
|
selected_agent = agent_id
|
||||||
|
|
||||||
|
return selected_agent or agents[0]
|
||||||
|
|
||||||
|
def _calculate_predicted_score(self, agent_id: str, task_type: str) -> float:
|
||||||
|
"""Calculate predicted performance score for agent"""
|
||||||
|
# Simple prediction based on recent performance
|
||||||
|
weight = self.agent_weights.get(agent_id, AgentWeight(agent_id=agent_id))
|
||||||
|
|
||||||
|
# Base score from performance and reliability
|
||||||
|
base_score = (weight.performance_score + weight.reliability_score) / 2
|
||||||
|
|
||||||
|
# Adjust based on recent assignments
|
||||||
|
recent_assignments = [a for a in self.assignment_history if a.agent_id == agent_id][-10:]
|
||||||
|
if recent_assignments:
|
||||||
|
success_rate = sum(1 for a in recent_assignments if a.success) / len(recent_assignments)
|
||||||
|
base_score = base_score * 0.7 + success_rate * 0.3
|
||||||
|
|
||||||
|
return base_score
|
||||||
|
|
||||||
|
def _consistent_hash_selection(self, agents: List[str], task_data: Dict[str, Any]) -> str:
|
||||||
|
"""Consistent hash selection for sticky routing"""
|
||||||
|
# Create hash key from task data
|
||||||
|
hash_key = json.dumps(task_data, sort_keys=True)
|
||||||
|
hash_value = int(hashlib.md5(hash_key.encode()).hexdigest(), 16)
|
||||||
|
|
||||||
|
# Build hash ring if not exists
|
||||||
|
if not self.consistent_hash_ring:
|
||||||
|
self._build_hash_ring(agents)
|
||||||
|
|
||||||
|
# Find agent on hash ring
|
||||||
|
for hash_pos in sorted(self.consistent_hash_ring.keys()):
|
||||||
|
if hash_value <= hash_pos:
|
||||||
|
return self.consistent_hash_ring[hash_pos]
|
||||||
|
|
||||||
|
# Wrap around
|
||||||
|
return self.consistent_hash_ring[min(self.consistent_hash_ring.keys())]
|
||||||
|
|
||||||
|
def _build_hash_ring(self, agents: List[str]):
|
||||||
|
"""Build consistent hash ring"""
|
||||||
|
self.consistent_hash_ring = {}
|
||||||
|
|
||||||
|
for agent_id in agents:
|
||||||
|
# Create multiple virtual nodes for better distribution
|
||||||
|
for i in range(100):
|
||||||
|
virtual_key = f"{agent_id}:{i}"
|
||||||
|
hash_value = int(hashlib.md5(virtual_key.encode()).hexdigest(), 16)
|
||||||
|
self.consistent_hash_ring[hash_value] = agent_id
|
||||||
|
|
||||||
|
def get_load_balancing_stats(self) -> Dict[str, Any]:
|
||||||
|
"""Get load balancing statistics"""
|
||||||
|
return {
|
||||||
|
"strategy": self.strategy.value,
|
||||||
|
"total_assignments": self.total_assignments,
|
||||||
|
"successful_assignments": self.successful_assignments,
|
||||||
|
"failed_assignments": self.failed_assignments,
|
||||||
|
"success_rate": self.successful_assignments / max(1, self.total_assignments),
|
||||||
|
"active_agents": len(self.agent_metrics),
|
||||||
|
"agent_weights": len(self.agent_weights),
|
||||||
|
"avg_agent_load": statistics.mean([self._get_agent_load(a) for a in self.agent_metrics]) if self.agent_metrics else 0
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_agent_stats(self, agent_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Get detailed statistics for a specific agent"""
|
||||||
|
if agent_id not in self.agent_metrics:
|
||||||
|
return None
|
||||||
|
|
||||||
|
metrics = self.agent_metrics[agent_id]
|
||||||
|
weight = self.agent_weights.get(agent_id, AgentWeight(agent_id=agent_id))
|
||||||
|
|
||||||
|
# Get recent assignments
|
||||||
|
recent_assignments = [a for a in self.assignment_history if a.agent_id == agent_id][-10:]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"metrics": metrics.to_dict(),
|
||||||
|
"weight": {
|
||||||
|
"weight": weight.weight,
|
||||||
|
"capacity": weight.capacity,
|
||||||
|
"performance_score": weight.performance_score,
|
||||||
|
"reliability_score": weight.reliability_score
|
||||||
|
},
|
||||||
|
"recent_assignments": [a.to_dict() for a in recent_assignments],
|
||||||
|
"current_load": self._get_agent_load(agent_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
class TaskDistributor:
|
||||||
|
"""Task distributor with advanced load balancing"""
|
||||||
|
|
||||||
|
def __init__(self, load_balancer: LoadBalancer):
|
||||||
|
self.load_balancer = load_balancer
|
||||||
|
self.task_queue = asyncio.Queue()
|
||||||
|
self.priority_queues = {
|
||||||
|
TaskPriority.URGENT: asyncio.Queue(),
|
||||||
|
TaskPriority.CRITICAL: asyncio.Queue(),
|
||||||
|
TaskPriority.HIGH: asyncio.Queue(),
|
||||||
|
TaskPriority.NORMAL: asyncio.Queue(),
|
||||||
|
TaskPriority.LOW: asyncio.Queue()
|
||||||
|
}
|
||||||
|
self.distribution_stats = {
|
||||||
|
"tasks_distributed": 0,
|
||||||
|
"tasks_completed": 0,
|
||||||
|
"tasks_failed": 0,
|
||||||
|
"avg_distribution_time": 0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
async def submit_task(self, task_data: Dict[str, Any], priority: TaskPriority = TaskPriority.NORMAL, requirements: Optional[Dict[str, Any]] = None):
|
||||||
|
"""Submit task for distribution"""
|
||||||
|
task_info = {
|
||||||
|
"task_data": task_data,
|
||||||
|
"priority": priority,
|
||||||
|
"requirements": requirements,
|
||||||
|
"submitted_at": datetime.utcnow()
|
||||||
|
}
|
||||||
|
|
||||||
|
await self.priority_queues[priority].put(task_info)
|
||||||
|
logger.info(f"Task submitted with priority {priority.value}")
|
||||||
|
|
||||||
|
async def start_distribution(self):
|
||||||
|
"""Start task distribution loop"""
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# Check queues in priority order
|
||||||
|
task_info = None
|
||||||
|
|
||||||
|
for priority in [TaskPriority.URGENT, TaskPriority.CRITICAL, TaskPriority.HIGH, TaskPriority.NORMAL, TaskPriority.LOW]:
|
||||||
|
queue = self.priority_queues[priority]
|
||||||
|
try:
|
||||||
|
task_info = queue.get_nowait()
|
||||||
|
break
|
||||||
|
except asyncio.QueueEmpty:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if task_info:
|
||||||
|
await self._distribute_task(task_info)
|
||||||
|
else:
|
||||||
|
await asyncio.sleep(0.01) # Small delay if no tasks
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in distribution loop: {e}")
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
async def _distribute_task(self, task_info: Dict[str, Any]):
|
||||||
|
"""Distribute a single task"""
|
||||||
|
start_time = datetime.utcnow()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Assign task
|
||||||
|
agent_id = await self.load_balancer.assign_task(
|
||||||
|
task_info["task_data"],
|
||||||
|
task_info["requirements"]
|
||||||
|
)
|
||||||
|
|
||||||
|
if agent_id:
|
||||||
|
# Create task message
|
||||||
|
task_message = create_task_message(
|
||||||
|
sender_id="task_distributor",
|
||||||
|
receiver_id=agent_id,
|
||||||
|
task_type=task_info["task_data"].get("task_type", "unknown"),
|
||||||
|
task_data=task_info["task_data"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send task to agent (implementation depends on communication system)
|
||||||
|
# await self._send_task_to_agent(agent_id, task_message)
|
||||||
|
|
||||||
|
self.distribution_stats["tasks_distributed"] += 1
|
||||||
|
|
||||||
|
# Simulate task completion (in real implementation, this would be event-driven)
|
||||||
|
asyncio.create_task(self._simulate_task_completion(task_info, agent_id))
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.warning(f"Failed to distribute task: no suitable agent found")
|
||||||
|
self.distribution_stats["tasks_failed"] += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error distributing task: {e}")
|
||||||
|
self.distribution_stats["tasks_failed"] += 1
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Update distribution time
|
||||||
|
distribution_time = (datetime.utcnow() - start_time).total_seconds()
|
||||||
|
total_distributed = self.distribution_stats["tasks_distributed"]
|
||||||
|
self.distribution_stats["avg_distribution_time"] = (
|
||||||
|
(self.distribution_stats["avg_distribution_time"] * (total_distributed - 1) + distribution_time) / total_distributed
|
||||||
|
if total_distributed > 0 else distribution_time
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _simulate_task_completion(self, task_info: Dict[str, Any], agent_id: str):
|
||||||
|
"""Simulate task completion (for testing)"""
|
||||||
|
# Simulate task processing time
|
||||||
|
processing_time = 1.0 + (hash(task_info["task_data"].get("task_id", "")) % 5)
|
||||||
|
await asyncio.sleep(processing_time)
|
||||||
|
|
||||||
|
# Mark task as completed
|
||||||
|
success = hash(agent_id) % 10 > 1 # 90% success rate
|
||||||
|
await self.load_balancer.complete_task(
|
||||||
|
task_info["task_data"].get("task_id", str(uuid.uuid4())),
|
||||||
|
success,
|
||||||
|
processing_time
|
||||||
|
)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
self.distribution_stats["tasks_completed"] += 1
|
||||||
|
else:
|
||||||
|
self.distribution_stats["tasks_failed"] += 1
|
||||||
|
|
||||||
|
def get_distribution_stats(self) -> Dict[str, Any]:
|
||||||
|
"""Get distribution statistics"""
|
||||||
|
return {
|
||||||
|
**self.distribution_stats,
|
||||||
|
"load_balancer_stats": self.load_balancer.get_load_balancing_stats(),
|
||||||
|
"queue_sizes": {
|
||||||
|
priority.value: queue.qsize()
|
||||||
|
for priority, queue in self.priority_queues.items()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Example usage
|
||||||
|
async def example_usage():
|
||||||
|
"""Example of how to use the load balancer"""
|
||||||
|
|
||||||
|
# Create registry and load balancer
|
||||||
|
registry = AgentRegistry()
|
||||||
|
await registry.start()
|
||||||
|
|
||||||
|
load_balancer = LoadBalancer(registry)
|
||||||
|
load_balancer.set_strategy(LoadBalancingStrategy.LEAST_CONNECTIONS)
|
||||||
|
|
||||||
|
# Create task distributor
|
||||||
|
distributor = TaskDistributor(load_balancer)
|
||||||
|
|
||||||
|
# Submit some tasks
|
||||||
|
for i in range(10):
|
||||||
|
await distributor.submit_task({
|
||||||
|
"task_id": f"task-{i}",
|
||||||
|
"task_type": "data_processing",
|
||||||
|
"data": f"sample_data_{i}"
|
||||||
|
}, TaskPriority.NORMAL)
|
||||||
|
|
||||||
|
# Start distribution (in real implementation, this would run in background)
|
||||||
|
# await distributor.start_distribution()
|
||||||
|
|
||||||
|
await registry.stop()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(example_usage())
|
||||||
326
apps/agent-coordinator/tests/test_communication.py
Normal file
326
apps/agent-coordinator/tests/test_communication.py
Normal file
@@ -0,0 +1,326 @@
|
|||||||
|
"""
|
||||||
|
Tests for Agent Communication Protocols
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from unittest.mock import Mock, AsyncMock
|
||||||
|
|
||||||
|
from src.app.protocols.communication import (
|
||||||
|
AgentMessage, MessageType, Priority, CommunicationProtocol,
|
||||||
|
HierarchicalProtocol, PeerToPeerProtocol, BroadcastProtocol,
|
||||||
|
CommunicationManager, MessageTemplates
|
||||||
|
)
|
||||||
|
|
||||||
|
class TestAgentMessage:
|
||||||
|
"""Test AgentMessage class"""
|
||||||
|
|
||||||
|
def test_message_creation(self):
|
||||||
|
"""Test message creation"""
|
||||||
|
message = AgentMessage(
|
||||||
|
sender_id="agent-001",
|
||||||
|
receiver_id="agent-002",
|
||||||
|
message_type=MessageType.DIRECT,
|
||||||
|
priority=Priority.NORMAL,
|
||||||
|
payload={"data": "test"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert message.sender_id == "agent-001"
|
||||||
|
assert message.receiver_id == "agent-002"
|
||||||
|
assert message.message_type == MessageType.DIRECT
|
||||||
|
assert message.priority == Priority.NORMAL
|
||||||
|
assert message.payload["data"] == "test"
|
||||||
|
assert message.ttl == 300
|
||||||
|
|
||||||
|
def test_message_serialization(self):
|
||||||
|
"""Test message serialization"""
|
||||||
|
message = AgentMessage(
|
||||||
|
sender_id="agent-001",
|
||||||
|
receiver_id="agent-002",
|
||||||
|
message_type=MessageType.DIRECT,
|
||||||
|
priority=Priority.NORMAL,
|
||||||
|
payload={"data": "test"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# To dict
|
||||||
|
message_dict = message.to_dict()
|
||||||
|
assert message_dict["sender_id"] == "agent-001"
|
||||||
|
assert message_dict["message_type"] == "direct"
|
||||||
|
assert message_dict["priority"] == "normal"
|
||||||
|
|
||||||
|
# From dict
|
||||||
|
restored_message = AgentMessage.from_dict(message_dict)
|
||||||
|
assert restored_message.sender_id == message.sender_id
|
||||||
|
assert restored_message.receiver_id == message.receiver_id
|
||||||
|
assert restored_message.message_type == message.message_type
|
||||||
|
assert restored_message.priority == message.priority
|
||||||
|
|
||||||
|
def test_message_expiration(self):
|
||||||
|
"""Test message expiration"""
|
||||||
|
old_message = AgentMessage(
|
||||||
|
sender_id="agent-001",
|
||||||
|
receiver_id="agent-002",
|
||||||
|
message_type=MessageType.DIRECT,
|
||||||
|
timestamp=datetime.utcnow() - timedelta(seconds=400),
|
||||||
|
ttl=300
|
||||||
|
)
|
||||||
|
|
||||||
|
# Message should be expired
|
||||||
|
age = (datetime.utcnow() - old_message.timestamp).total_seconds()
|
||||||
|
assert age > old_message.ttl
|
||||||
|
|
||||||
|
class TestHierarchicalProtocol:
|
||||||
|
"""Test HierarchicalProtocol class"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def master_protocol(self):
|
||||||
|
"""Create master protocol"""
|
||||||
|
return HierarchicalProtocol("master-agent", is_master=True)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sub_protocol(self):
|
||||||
|
"""Create sub-agent protocol"""
|
||||||
|
return HierarchicalProtocol("sub-agent", is_master=False)
|
||||||
|
|
||||||
|
def test_add_sub_agent(self, master_protocol):
|
||||||
|
"""Test adding sub-agent"""
|
||||||
|
master_protocol.add_sub_agent("sub-agent-001")
|
||||||
|
assert "sub-agent-001" in master_protocol.sub_agents
|
||||||
|
|
||||||
|
def test_send_to_sub_agents(self, master_protocol):
|
||||||
|
"""Test sending to sub-agents"""
|
||||||
|
master_protocol.add_sub_agent("sub-agent-001")
|
||||||
|
master_protocol.add_sub_agent("sub-agent-002")
|
||||||
|
|
||||||
|
message = MessageTemplates.create_heartbeat("master-agent")
|
||||||
|
|
||||||
|
# Mock the send_message method
|
||||||
|
master_protocol.send_message = AsyncMock(return_value=True)
|
||||||
|
|
||||||
|
# Should send to both sub-agents
|
||||||
|
asyncio.run(master_protocol.send_to_sub_agents(message))
|
||||||
|
|
||||||
|
# Check that send_message was called twice
|
||||||
|
assert master_protocol.send_message.call_count == 2
|
||||||
|
|
||||||
|
def test_send_to_master(self, sub_protocol):
|
||||||
|
"""Test sending to master"""
|
||||||
|
sub_protocol.master_agent = "master-agent"
|
||||||
|
|
||||||
|
message = MessageTemplates.create_status_update("sub-agent", {"status": "active"})
|
||||||
|
|
||||||
|
# Mock the send_message method
|
||||||
|
sub_protocol.send_message = AsyncMock(return_value=True)
|
||||||
|
|
||||||
|
asyncio.run(sub_protocol.send_to_master(message))
|
||||||
|
|
||||||
|
# Check that send_message was called once
|
||||||
|
assert sub_protocol.send_message.call_count == 1
|
||||||
|
|
||||||
|
class TestPeerToPeerProtocol:
|
||||||
|
"""Test PeerToPeerProtocol class"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def p2p_protocol(self):
|
||||||
|
"""Create P2P protocol"""
|
||||||
|
return PeerToPeerProtocol("agent-001")
|
||||||
|
|
||||||
|
def test_add_peer(self, p2p_protocol):
|
||||||
|
"""Test adding peer"""
|
||||||
|
p2p_protocol.add_peer("agent-002", {"endpoint": "http://localhost:8002"})
|
||||||
|
assert "agent-002" in p2p_protocol.peers
|
||||||
|
assert p2p_protocol.peers["agent-002"]["endpoint"] == "http://localhost:8002"
|
||||||
|
|
||||||
|
def test_remove_peer(self, p2p_protocol):
|
||||||
|
"""Test removing peer"""
|
||||||
|
p2p_protocol.add_peer("agent-002", {"endpoint": "http://localhost:8002"})
|
||||||
|
p2p_protocol.remove_peer("agent-002")
|
||||||
|
assert "agent-002" not in p2p_protocol.peers
|
||||||
|
|
||||||
|
def test_send_to_peer(self, p2p_protocol):
|
||||||
|
"""Test sending to peer"""
|
||||||
|
p2p_protocol.add_peer("agent-002", {"endpoint": "http://localhost:8002"})
|
||||||
|
|
||||||
|
message = MessageTemplates.create_task_assignment(
|
||||||
|
"agent-001", "agent-002", {"task": "test"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock the send_message method
|
||||||
|
p2p_protocol.send_message = AsyncMock(return_value=True)
|
||||||
|
|
||||||
|
result = asyncio.run(p2p_protocol.send_to_peer(message, "agent-002"))
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert p2p_protocol.send_message.call_count == 1
|
||||||
|
|
||||||
|
class TestBroadcastProtocol:
|
||||||
|
"""Test BroadcastProtocol class"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def broadcast_protocol(self):
|
||||||
|
"""Create broadcast protocol"""
|
||||||
|
return BroadcastProtocol("agent-001", "test-channel")
|
||||||
|
|
||||||
|
def test_subscribe_unsubscribe(self, broadcast_protocol):
|
||||||
|
"""Test subscribe and unsubscribe"""
|
||||||
|
broadcast_protocol.subscribe("agent-002")
|
||||||
|
assert "agent-002" in broadcast_protocol.subscribers
|
||||||
|
|
||||||
|
broadcast_protocol.unsubscribe("agent-002")
|
||||||
|
assert "agent-002" not in broadcast_protocol.subscribers
|
||||||
|
|
||||||
|
def test_broadcast(self, broadcast_protocol):
|
||||||
|
"""Test broadcasting"""
|
||||||
|
broadcast_protocol.subscribe("agent-002")
|
||||||
|
broadcast_protocol.subscribe("agent-003")
|
||||||
|
|
||||||
|
message = MessageTemplates.create_discovery("agent-001")
|
||||||
|
|
||||||
|
# Mock the send_message method
|
||||||
|
broadcast_protocol.send_message = AsyncMock(return_value=True)
|
||||||
|
|
||||||
|
asyncio.run(broadcast_protocol.broadcast(message))
|
||||||
|
|
||||||
|
# Should send to 2 subscribers (not including self)
|
||||||
|
assert broadcast_protocol.send_message.call_count == 2
|
||||||
|
|
||||||
|
class TestCommunicationManager:
|
||||||
|
"""Test CommunicationManager class"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def comm_manager(self):
|
||||||
|
"""Create communication manager"""
|
||||||
|
return CommunicationManager("agent-001")
|
||||||
|
|
||||||
|
def test_add_protocol(self, comm_manager):
|
||||||
|
"""Test adding protocol"""
|
||||||
|
protocol = Mock(spec=CommunicationProtocol)
|
||||||
|
comm_manager.add_protocol("test", protocol)
|
||||||
|
|
||||||
|
assert "test" in comm_manager.protocols
|
||||||
|
assert comm_manager.protocols["test"] == protocol
|
||||||
|
|
||||||
|
def test_get_protocol(self, comm_manager):
|
||||||
|
"""Test getting protocol"""
|
||||||
|
protocol = Mock(spec=CommunicationProtocol)
|
||||||
|
comm_manager.add_protocol("test", protocol)
|
||||||
|
|
||||||
|
retrieved_protocol = comm_manager.get_protocol("test")
|
||||||
|
assert retrieved_protocol == protocol
|
||||||
|
|
||||||
|
# Test non-existent protocol
|
||||||
|
assert comm_manager.get_protocol("non-existent") is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_message(self, comm_manager):
|
||||||
|
"""Test sending message"""
|
||||||
|
protocol = Mock(spec=CommunicationProtocol)
|
||||||
|
protocol.send_message = AsyncMock(return_value=True)
|
||||||
|
comm_manager.add_protocol("test", protocol)
|
||||||
|
|
||||||
|
message = MessageTemplates.create_heartbeat("agent-001")
|
||||||
|
result = await comm_manager.send_message("test", message)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
protocol.send_message.assert_called_once_with(message)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_handler(self, comm_manager):
|
||||||
|
"""Test registering handler"""
|
||||||
|
protocol = Mock(spec=CommunicationProtocol)
|
||||||
|
protocol.register_handler = AsyncMock()
|
||||||
|
comm_manager.add_protocol("test", protocol)
|
||||||
|
|
||||||
|
handler = AsyncMock()
|
||||||
|
await comm_manager.register_handler("test", MessageType.HEARTBEAT, handler)
|
||||||
|
|
||||||
|
protocol.register_handler.assert_called_once_with(MessageType.HEARTBEAT, handler)
|
||||||
|
|
||||||
|
class TestMessageTemplates:
|
||||||
|
"""Test MessageTemplates class"""
|
||||||
|
|
||||||
|
def test_create_heartbeat(self):
|
||||||
|
"""Test creating heartbeat message"""
|
||||||
|
message = MessageTemplates.create_heartbeat("agent-001")
|
||||||
|
|
||||||
|
assert message.sender_id == "agent-001"
|
||||||
|
assert message.message_type == MessageType.HEARTBEAT
|
||||||
|
assert message.priority == Priority.LOW
|
||||||
|
assert "timestamp" in message.payload
|
||||||
|
|
||||||
|
def test_create_task_assignment(self):
|
||||||
|
"""Test creating task assignment message"""
|
||||||
|
task_data = {"task_id": "task-001", "task_type": "process_data"}
|
||||||
|
message = MessageTemplates.create_task_assignment("agent-001", "agent-002", task_data)
|
||||||
|
|
||||||
|
assert message.sender_id == "agent-001"
|
||||||
|
assert message.receiver_id == "agent-002"
|
||||||
|
assert message.message_type == MessageType.TASK_ASSIGNMENT
|
||||||
|
assert message.payload == task_data
|
||||||
|
|
||||||
|
def test_create_status_update(self):
|
||||||
|
"""Test creating status update message"""
|
||||||
|
status_data = {"status": "active", "load": 0.5}
|
||||||
|
message = MessageTemplates.create_status_update("agent-001", status_data)
|
||||||
|
|
||||||
|
assert message.sender_id == "agent-001"
|
||||||
|
assert message.message_type == MessageType.STATUS_UPDATE
|
||||||
|
assert message.payload == status_data
|
||||||
|
|
||||||
|
def test_create_discovery(self):
|
||||||
|
"""Test creating discovery message"""
|
||||||
|
message = MessageTemplates.create_discovery("agent-001")
|
||||||
|
|
||||||
|
assert message.sender_id == "agent-001"
|
||||||
|
assert message.message_type == MessageType.DISCOVERY
|
||||||
|
assert message.payload["agent_id"] == "agent-001"
|
||||||
|
|
||||||
|
def test_create_consensus_request(self):
|
||||||
|
"""Test creating consensus request message"""
|
||||||
|
proposal_data = {"proposal": "test_proposal"}
|
||||||
|
message = MessageTemplates.create_consensus_request("agent-001", proposal_data)
|
||||||
|
|
||||||
|
assert message.sender_id == "agent-001"
|
||||||
|
assert message.message_type == MessageType.CONSENSUS
|
||||||
|
assert message.priority == Priority.HIGH
|
||||||
|
assert message.payload == proposal_data
|
||||||
|
|
||||||
|
# Integration tests
|
||||||
|
class TestCommunicationIntegration:
|
||||||
|
"""Integration tests for communication system"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_message_flow(self):
|
||||||
|
"""Test complete message flow"""
|
||||||
|
# Create communication manager
|
||||||
|
comm_manager = CommunicationManager("agent-001")
|
||||||
|
|
||||||
|
# Create protocols
|
||||||
|
hierarchical = HierarchicalProtocol("agent-001", is_master=True)
|
||||||
|
p2p = PeerToPeerProtocol("agent-001")
|
||||||
|
|
||||||
|
# Add protocols
|
||||||
|
comm_manager.add_protocol("hierarchical", hierarchical)
|
||||||
|
comm_manager.add_protocol("p2p", p2p)
|
||||||
|
|
||||||
|
# Mock message sending
|
||||||
|
hierarchical.send_message = AsyncMock(return_value=True)
|
||||||
|
p2p.send_message = AsyncMock(return_value=True)
|
||||||
|
|
||||||
|
# Register handler
|
||||||
|
async def handle_heartbeat(message):
|
||||||
|
assert message.sender_id == "agent-002"
|
||||||
|
assert message.message_type == MessageType.HEARTBEAT
|
||||||
|
|
||||||
|
await comm_manager.register_handler("hierarchical", MessageType.HEARTBEAT, handle_heartbeat)
|
||||||
|
|
||||||
|
# Send heartbeat
|
||||||
|
heartbeat = MessageTemplates.create_heartbeat("agent-001")
|
||||||
|
result = await comm_manager.send_message("hierarchical", heartbeat)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
hierarchical.send_message.assert_called_once()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__])
|
||||||
Reference in New Issue
Block a user