chore(security): enhance environment configuration, CI workflows, and wallet daemon with security improvements
- Restructure .env.example with security-focused documentation, service-specific environment file references, and AWS Secrets Manager integration
- Update CLI tests workflow to a single Python 3.13 version, add the pytest-mock dependency, and consolidate test execution with coverage
- Add comprehensive security validation to the package publishing workflow with manual approval gates, secret scanning, and release
@@ -0,0 +1,706 @@
"""
Multi-Modal WebSocket Fusion Service

Advanced WebSocket stream architecture for multi-modal fusion with
per-stream backpressure handling and GPU provider flow control.
"""

import asyncio
import json
import time
import numpy as np
import torch
from typing import Dict, List, Optional, Any, Tuple, Union
from dataclasses import dataclass, field
from enum import Enum
from uuid import uuid4

from aitbc.logging import get_logger
from .websocket_stream_manager import (
    WebSocketStreamManager, StreamConfig, MessageType,
    stream_manager, WebSocketStream
)
from .gpu_multimodal import GPUMultimodalProcessor
from .multi_modal_fusion import MultiModalFusionService

logger = get_logger(__name__)


class FusionStreamType(Enum):
    """Types of fusion streams"""
    VISUAL = "visual"
    TEXT = "text"
    AUDIO = "audio"
    SENSOR = "sensor"
    CONTROL = "control"
    METRICS = "metrics"


class GPUProviderStatus(Enum):
    """GPU provider status"""
    AVAILABLE = "available"
    BUSY = "busy"
    SLOW = "slow"
    OVERLOADED = "overloaded"
    OFFLINE = "offline"


@dataclass
class FusionStreamConfig:
    """Configuration for fusion streams"""
    stream_type: FusionStreamType
    max_queue_size: int = 500
    gpu_timeout: float = 2.0
    fusion_timeout: float = 5.0
    batch_size: int = 8
    enable_gpu_acceleration: bool = True
    priority: int = 1  # Higher number = higher priority

    def to_stream_config(self) -> StreamConfig:
        """Convert to WebSocket stream config"""
        return StreamConfig(
            max_queue_size=self.max_queue_size,
            send_timeout=self.fusion_timeout,
            heartbeat_interval=30.0,
            slow_consumer_threshold=0.5,
            backpressure_threshold=0.7,
            drop_bulk_threshold=0.85,
            enable_compression=True,
            priority_send=True
        )


@dataclass
class FusionData:
    """Multi-modal fusion data"""
    stream_id: str
    stream_type: FusionStreamType
    data: Any
    timestamp: float
    metadata: Dict[str, Any] = field(default_factory=dict)
    requires_gpu: bool = False
    processing_priority: int = 1


@dataclass
class GPUProviderMetrics:
    """GPU provider performance metrics"""
    provider_id: str
    status: GPUProviderStatus
    avg_processing_time: float
    queue_size: int
    gpu_utilization: float
    memory_usage: float
    error_rate: float
    last_update: float


class GPUProviderFlowControl:
    """Flow control for GPU providers"""

    def __init__(self, provider_id: str):
        self.provider_id = provider_id
        self.metrics = GPUProviderMetrics(
            provider_id=provider_id,
            status=GPUProviderStatus.AVAILABLE,
            avg_processing_time=0.0,
            queue_size=0,
            gpu_utilization=0.0,
            memory_usage=0.0,
            error_rate=0.0,
            last_update=time.time()
        )

        # Flow control queues
        self.input_queue = asyncio.Queue(maxsize=100)
        self.output_queue = asyncio.Queue(maxsize=100)
        self.control_queue = asyncio.Queue(maxsize=50)

        # Flow control parameters
        self.max_concurrent_requests = 4
        self.current_requests = 0
        self.slow_threshold = 2.0  # seconds
        self.overload_threshold = 0.8  # queue fill ratio

        # Performance tracking
        self.request_times = []
        self.error_count = 0
        self.total_requests = 0

        # Flow control task
        self._flow_control_task = None
        self._running = False

    async def start(self):
        """Start flow control"""
        if self._running:
            return

        self._running = True
        self._flow_control_task = asyncio.create_task(self._flow_control_loop())
        logger.info(f"GPU provider flow control started: {self.provider_id}")

    async def stop(self):
        """Stop flow control"""
        if not self._running:
            return

        self._running = False

        if self._flow_control_task:
            self._flow_control_task.cancel()
            try:
                await self._flow_control_task
            except asyncio.CancelledError:
                pass

        logger.info(f"GPU provider flow control stopped: {self.provider_id}")

    async def submit_request(self, data: FusionData) -> Optional[str]:
        """Submit request with flow control"""
        if not self._running:
            return None

        # Check provider status
        if self.metrics.status == GPUProviderStatus.OFFLINE:
            logger.warning(f"GPU provider {self.provider_id} is offline")
            return None

        # Check backpressure
        if self.input_queue.qsize() / self.input_queue.maxsize > self.overload_threshold:
            self.metrics.status = GPUProviderStatus.OVERLOADED
            logger.warning(f"GPU provider {self.provider_id} is overloaded")
            return None

        # Submit request
        request_id = str(uuid4())
        request_data = {
            "request_id": request_id,
            "data": data,
            "timestamp": time.time()
        }

        try:
            await asyncio.wait_for(
                self.input_queue.put(request_data),
                timeout=1.0
            )
            return request_id
        except asyncio.TimeoutError:
            logger.warning(f"Request timeout for GPU provider {self.provider_id}")
            return None

    async def get_result(self, request_id: str, timeout: float = 5.0) -> Optional[Any]:
        """Get processing result"""
        start_time = time.time()

        while time.time() - start_time < timeout:
            try:
                # Check output queue
                result = await asyncio.wait_for(
                    self.output_queue.get(),
                    timeout=0.1
                )

                if result.get("request_id") == request_id:
                    return result.get("data")

                # Put back if not our result
                await self.output_queue.put(result)

            except asyncio.TimeoutError:
                continue

        return None

    async def _flow_control_loop(self):
        """Main flow control loop"""
        while self._running:
            try:
                # Get next request
                request_data = await asyncio.wait_for(
                    self.input_queue.get(),
                    timeout=1.0
                )

                # Check concurrent request limit
                if self.current_requests >= self.max_concurrent_requests:
                    # Re-queue request
                    await self.input_queue.put(request_data)
                    await asyncio.sleep(0.1)
                    continue

                # Process request
                self.current_requests += 1
                self.total_requests += 1

                asyncio.create_task(self._process_request(request_data))

            except asyncio.TimeoutError:
                continue
            except Exception as e:
                logger.error(f"Flow control error for {self.provider_id}: {e}")
                await asyncio.sleep(0.1)

    async def _process_request(self, request_data: Dict[str, Any]):
        """Process individual request"""
        request_id = request_data["request_id"]
        data: FusionData = request_data["data"]
        start_time = time.time()

        try:
            # Simulate GPU processing
            if data.requires_gpu:
                # Simulate GPU processing time
                processing_time = np.random.uniform(0.5, 3.0)
                await asyncio.sleep(processing_time)

                # Simulate GPU result
                result = {
                    "processed_data": f"gpu_processed_{data.stream_type.value}",
                    "processing_time": processing_time,
                    "gpu_utilization": np.random.uniform(0.3, 0.9),
                    "memory_usage": np.random.uniform(0.4, 0.8)
                }
            else:
                # CPU processing
                processing_time = np.random.uniform(0.1, 0.5)
                await asyncio.sleep(processing_time)

                result = {
                    "processed_data": f"cpu_processed_{data.stream_type.value}",
                    "processing_time": processing_time
                }

            # Update metrics
            actual_time = time.time() - start_time
            self._update_metrics(actual_time, success=True)

            # Send result
            await self.output_queue.put({
                "request_id": request_id,
                "data": result,
                "timestamp": time.time()
            })

        except Exception as e:
            logger.error(f"Request processing error for {self.provider_id}: {e}")
            self._update_metrics(time.time() - start_time, success=False)

            # Send error result
            await self.output_queue.put({
                "request_id": request_id,
                "error": str(e),
                "timestamp": time.time()
            })

        finally:
            self.current_requests -= 1

    def _update_metrics(self, processing_time: float, success: bool):
        """Update provider metrics"""
        # Update processing time
        self.request_times.append(processing_time)
        if len(self.request_times) > 100:
            self.request_times.pop(0)

        self.metrics.avg_processing_time = np.mean(self.request_times)

        # Update error rate
        if not success:
            self.error_count += 1

        self.metrics.error_rate = self.error_count / max(self.total_requests, 1)

        # Update queue sizes
        self.metrics.queue_size = self.input_queue.qsize()

        # Update status
        if self.metrics.error_rate > 0.1:
            self.metrics.status = GPUProviderStatus.OFFLINE
        elif self.metrics.avg_processing_time > self.slow_threshold:
            self.metrics.status = GPUProviderStatus.SLOW
        elif self.metrics.queue_size > self.input_queue.maxsize * 0.8:
            self.metrics.status = GPUProviderStatus.OVERLOADED
        elif self.current_requests >= self.max_concurrent_requests:
            self.metrics.status = GPUProviderStatus.BUSY
        else:
            self.metrics.status = GPUProviderStatus.AVAILABLE

        self.metrics.last_update = time.time()

    def get_metrics(self) -> Dict[str, Any]:
        """Get provider metrics"""
        return {
            "provider_id": self.provider_id,
            "status": self.metrics.status.value,
            "avg_processing_time": self.metrics.avg_processing_time,
            "queue_size": self.metrics.queue_size,
            "current_requests": self.current_requests,
            "max_concurrent_requests": self.max_concurrent_requests,
            "error_rate": self.metrics.error_rate,
            "total_requests": self.total_requests,
            "last_update": self.metrics.last_update
        }


class MultiModalWebSocketFusion:
    """Multi-modal fusion service with WebSocket streaming and backpressure control"""

    def __init__(self):
        self.stream_manager = stream_manager
        self.fusion_service = None  # Will be injected
        self.gpu_providers: Dict[str, GPUProviderFlowControl] = {}

        # Fusion streams
        self.fusion_streams: Dict[str, FusionStreamConfig] = {}
        self.active_fusions: Dict[str, Dict[str, Any]] = {}

        # Performance metrics
        self.fusion_metrics = {
            "total_fusions": 0,
            "successful_fusions": 0,
            "failed_fusions": 0,
            "avg_fusion_time": 0.0,
            "gpu_utilization": 0.0,
            "memory_usage": 0.0
        }

        # Backpressure control
        self.backpressure_enabled = True
        self.global_queue_size = 0
        self.max_global_queue_size = 10000

        # Running state
        self._running = False
        self._monitor_task = None

    async def start(self):
        """Start the fusion service"""
        if self._running:
            return

        self._running = True

        # Start stream manager
        await self.stream_manager.start()

        # Initialize GPU providers
        await self._initialize_gpu_providers()

        # Start monitoring
        self._monitor_task = asyncio.create_task(self._monitor_loop())

        logger.info("Multi-Modal WebSocket Fusion started")

    async def stop(self):
        """Stop the fusion service"""
        if not self._running:
            return

        self._running = False

        # Stop GPU providers
        for provider in self.gpu_providers.values():
            await provider.stop()

        # Stop stream manager
        await self.stream_manager.stop()

        # Stop monitoring
        if self._monitor_task:
            self._monitor_task.cancel()
            try:
                await self._monitor_task
            except asyncio.CancelledError:
                pass

        logger.info("Multi-Modal WebSocket Fusion stopped")

    async def register_fusion_stream(self, stream_id: str, config: FusionStreamConfig):
        """Register a fusion stream"""
        self.fusion_streams[stream_id] = config
        logger.info(f"Registered fusion stream: {stream_id} ({config.stream_type.value})")

    async def handle_websocket_connection(self, websocket, stream_id: str,
                                          stream_type: FusionStreamType):
        """Handle WebSocket connection for fusion stream"""
        config = FusionStreamConfig(
            stream_type=stream_type,
            max_queue_size=500,
            gpu_timeout=2.0,
            fusion_timeout=5.0
        )

        async with self.stream_manager.manage_stream(websocket, config.to_stream_config()) as stream:
            logger.info(f"Fusion stream connected: {stream_id} ({stream_type.value})")

            try:
                # Handle incoming messages
                async for message in websocket:
                    await self._handle_stream_message(stream_id, stream_type, message)

            except Exception as e:
                logger.error(f"Error in fusion stream {stream_id}: {e}")

    async def _handle_stream_message(self, stream_id: str, stream_type: FusionStreamType,
                                     message: str):
        """Handle incoming stream message"""
        try:
            data = json.loads(message)

            # Create fusion data
            fusion_data = FusionData(
                stream_id=stream_id,
                stream_type=stream_type,
                data=data.get("data"),
                timestamp=time.time(),
                metadata=data.get("metadata", {}),
                requires_gpu=data.get("requires_gpu", False),
                processing_priority=data.get("priority", 1)
            )

            # Submit to GPU provider if needed
            if fusion_data.requires_gpu:
                await self._submit_to_gpu_provider(fusion_data)
            else:
                await self._process_cpu_fusion(fusion_data)

        except Exception as e:
            logger.error(f"Error handling stream message: {e}")

    async def _submit_to_gpu_provider(self, fusion_data: FusionData):
        """Submit fusion data to GPU provider"""
        # Select best GPU provider
        provider_id = await self._select_gpu_provider(fusion_data)

        if not provider_id:
            logger.warning("No available GPU providers")
            await self._handle_fusion_error(fusion_data, "No GPU providers available")
            return

        provider = self.gpu_providers[provider_id]

        # Submit request
        request_id = await provider.submit_request(fusion_data)

        if not request_id:
            await self._handle_fusion_error(fusion_data, "GPU provider overloaded")
            return

        # Wait for result
        result = await provider.get_result(request_id, timeout=5.0)

        if result and "error" not in result:
            await self._handle_fusion_result(fusion_data, result)
        else:
            error = result.get("error", "Unknown error") if result else "Timeout"
            await self._handle_fusion_error(fusion_data, error)

    async def _process_cpu_fusion(self, fusion_data: FusionData):
        """Process fusion data on CPU"""
        try:
            # Simulate CPU fusion processing
            processing_time = np.random.uniform(0.1, 0.5)
            await asyncio.sleep(processing_time)

            result = {
                "processed_data": f"cpu_fused_{fusion_data.stream_type.value}",
                "processing_time": processing_time,
                "fusion_type": "cpu"
            }

            await self._handle_fusion_result(fusion_data, result)

        except Exception as e:
            logger.error(f"CPU fusion error: {e}")
            await self._handle_fusion_error(fusion_data, str(e))

    async def _handle_fusion_result(self, fusion_data: FusionData, result: Dict[str, Any]):
        """Handle successful fusion result"""
        # Update metrics
        self.fusion_metrics["total_fusions"] += 1
        self.fusion_metrics["successful_fusions"] += 1

        # Broadcast result
        broadcast_data = {
            "type": "fusion_result",
            "stream_id": fusion_data.stream_id,
            "stream_type": fusion_data.stream_type.value,
            "result": result,
            "timestamp": time.time()
        }

        await self.stream_manager.broadcast_to_all(broadcast_data, MessageType.IMPORTANT)

        logger.info(f"Fusion completed for {fusion_data.stream_id}")

    async def _handle_fusion_error(self, fusion_data: FusionData, error: str):
        """Handle fusion error"""
        # Update metrics
        self.fusion_metrics["total_fusions"] += 1
        self.fusion_metrics["failed_fusions"] += 1

        # Broadcast error
        error_data = {
            "type": "fusion_error",
            "stream_id": fusion_data.stream_id,
            "stream_type": fusion_data.stream_type.value,
            "error": error,
            "timestamp": time.time()
        }

        await self.stream_manager.broadcast_to_all(error_data, MessageType.CRITICAL)

        logger.error(f"Fusion error for {fusion_data.stream_id}: {error}")

    async def _select_gpu_provider(self, fusion_data: FusionData) -> Optional[str]:
        """Select best GPU provider based on load and performance"""
        available_providers = []

        for provider_id, provider in self.gpu_providers.items():
            metrics = provider.get_metrics()

            # Check if provider is available
            if metrics["status"] == GPUProviderStatus.AVAILABLE.value:
                available_providers.append((provider_id, metrics))

        if not available_providers:
            return None

        # Select provider with lowest queue size and processing time
        best_provider = min(
            available_providers,
            key=lambda x: (x[1]["queue_size"], x[1]["avg_processing_time"])
        )

        return best_provider[0]

    async def _initialize_gpu_providers(self):
        """Initialize GPU providers"""
        # Create mock GPU providers
        provider_configs = [
            {"provider_id": "gpu_1", "max_concurrent": 4},
            {"provider_id": "gpu_2", "max_concurrent": 2},
            {"provider_id": "gpu_3", "max_concurrent": 6}
        ]

        for config in provider_configs:
            provider = GPUProviderFlowControl(config["provider_id"])
            provider.max_concurrent_requests = config["max_concurrent"]
            await provider.start()
            self.gpu_providers[config["provider_id"]] = provider

        logger.info(f"Initialized {len(self.gpu_providers)} GPU providers")

    async def _monitor_loop(self):
        """Monitor system performance and backpressure"""
        while self._running:
            try:
                # Update global metrics
                await self._update_global_metrics()

                # Check backpressure
                if self.backpressure_enabled:
                    await self._check_backpressure()

                # Monitor GPU providers
                await self._monitor_gpu_providers()

                # Sleep
                await asyncio.sleep(10)  # Monitor every 10 seconds

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Monitor loop error: {e}")
                await asyncio.sleep(1)

    async def _update_global_metrics(self):
        """Update global performance metrics"""
        # Get stream manager metrics
        manager_metrics = self.stream_manager.get_manager_metrics()

        # Update global queue size
        self.global_queue_size = manager_metrics["total_queue_size"]

        # Calculate GPU utilization
        total_gpu_util = 0
        total_memory = 0
        active_providers = 0

        for provider in self.gpu_providers.values():
            metrics = provider.get_metrics()
            if metrics["status"] != GPUProviderStatus.OFFLINE.value:
                total_gpu_util += metrics.get("gpu_utilization", 0)
                total_memory += metrics.get("memory_usage", 0)
                active_providers += 1

        if active_providers > 0:
            self.fusion_metrics["gpu_utilization"] = total_gpu_util / active_providers
            self.fusion_metrics["memory_usage"] = total_memory / active_providers

    async def _check_backpressure(self):
        """Check and handle backpressure"""
        if self.global_queue_size > self.max_global_queue_size * 0.8:
            logger.warning("High backpressure detected, applying flow control")

            # Get slow streams
            slow_streams = self.stream_manager.get_slow_streams(threshold=0.8)

            # Handle slow streams
            for stream_id in slow_streams:
                await self.stream_manager.handle_slow_consumer(stream_id, "throttle")

    async def _monitor_gpu_providers(self):
        """Monitor GPU provider health"""
        for provider_id, provider in self.gpu_providers.items():
            metrics = provider.get_metrics()

            # Check for unhealthy providers
            if metrics["status"] == GPUProviderStatus.OFFLINE.value:
                logger.warning(f"GPU provider {provider_id} is offline")

            elif metrics["error_rate"] > 0.1:
                logger.warning(f"GPU provider {provider_id} has high error rate: {metrics['error_rate']}")

            elif metrics["avg_processing_time"] > 5.0:
                logger.warning(f"GPU provider {provider_id} is slow: {metrics['avg_processing_time']}s")

    def get_comprehensive_metrics(self) -> Dict[str, Any]:
        """Get comprehensive system metrics"""
        # Get stream manager metrics
        stream_metrics = self.stream_manager.get_manager_metrics()

        # Get GPU provider metrics
        gpu_metrics = {}
        for provider_id, provider in self.gpu_providers.items():
            gpu_metrics[provider_id] = provider.get_metrics()

        # Get fusion metrics
        fusion_metrics = self.fusion_metrics.copy()

        # Calculate success rate
        if fusion_metrics["total_fusions"] > 0:
            fusion_metrics["success_rate"] = (
                fusion_metrics["successful_fusions"] / fusion_metrics["total_fusions"]
            )
        else:
            fusion_metrics["success_rate"] = 0.0

        return {
            "timestamp": time.time(),
            "system_status": "running" if self._running else "stopped",
            "backpressure_enabled": self.backpressure_enabled,
            "global_queue_size": self.global_queue_size,
            "max_global_queue_size": self.max_global_queue_size,
            "stream_metrics": stream_metrics,
            "gpu_metrics": gpu_metrics,
            "fusion_metrics": fusion_metrics,
            "active_fusion_streams": len(self.fusion_streams),
            "registered_gpu_providers": len(self.gpu_providers)
        }


# Global fusion service instance
multimodal_fusion_service = MultiModalWebSocketFusion()
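A minimal driver sketch for the service above. The diff does not show this hunk's file name, so the import path below is an assumption; adjust it to wherever the module lives.

import asyncio

# Hypothetical import path; the hunk above does not name its file.
from app.services.multimodal_websocket_fusion import (
    FusionStreamConfig,
    FusionStreamType,
    multimodal_fusion_service,
)


async def main():
    # Boot the service: stream manager, the three mock GPU providers, monitor loop.
    await multimodal_fusion_service.start()

    # Register a visual stream with a smaller queue than the 500-message default.
    await multimodal_fusion_service.register_fusion_stream(
        "camera-front",
        FusionStreamConfig(stream_type=FusionStreamType.VISUAL, max_queue_size=256),
    )

    # Give the 10-second monitor loop one pass, then dump system-wide metrics.
    await asyncio.sleep(11)
    print(multimodal_fusion_service.get_comprehensive_metrics())

    await multimodal_fusion_service.stop()


asyncio.run(main())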
apps/coordinator-api/src/app/services/secure_wallet_service.py (new file, 420 lines)
@@ -0,0 +1,420 @@
"""
Secure Wallet Service - Fixed Version
Implements proper Ethereum cryptography and secure key storage
"""

from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional
from sqlalchemy import select
from sqlmodel import Session
from datetime import datetime
import secrets

from ..domain.wallet import (
    AgentWallet, NetworkConfig, TokenBalance, WalletTransaction,
    WalletType, TransactionStatus
)
from ..schemas.wallet import WalletCreate, TransactionRequest
from ..blockchain.contract_interactions import ContractInteractionService

# Import our fixed crypto utilities
from .wallet_crypto import (
    generate_ethereum_keypair,
    verify_keypair_consistency,
    encrypt_private_key,
    decrypt_private_key,
    validate_private_key_format,
    create_secure_wallet,
    recover_wallet
)

logger = logging.getLogger(__name__)


class SecureWalletService:
    """Secure wallet service with proper cryptography and key management"""

    def __init__(
        self,
        session: Session,
        contract_service: ContractInteractionService
    ):
        self.session = session
        self.contract_service = contract_service

    async def create_wallet(self, request: WalletCreate, encryption_password: str) -> AgentWallet:
        """
        Create a new wallet with proper security

        Args:
            request: Wallet creation request
            encryption_password: Strong password for private key encryption

        Returns:
            Created wallet record

        Raises:
            ValueError: If password is weak or wallet already exists
        """
        # Validate password strength
        from ..utils.security import validate_password_strength
        password_validation = validate_password_strength(encryption_password)

        if not password_validation["is_acceptable"]:
            raise ValueError(
                f"Password too weak: {', '.join(password_validation['issues'])}"
            )

        # Check if agent already has an active wallet of this type
        existing = self.session.exec(
            select(AgentWallet).where(
                AgentWallet.agent_id == request.agent_id,
                AgentWallet.wallet_type == request.wallet_type,
                AgentWallet.is_active == True
            )
        ).first()

        if existing:
            raise ValueError(f"Agent {request.agent_id} already has an active {request.wallet_type} wallet")

        try:
            # Generate proper Ethereum keypair
            private_key, public_key, address = generate_ethereum_keypair()

            # Verify keypair consistency
            if not verify_keypair_consistency(private_key, address):
                raise RuntimeError("Keypair generation failed consistency check")

            # Encrypt private key securely
            encrypted_data = encrypt_private_key(private_key, encryption_password)

            # Create wallet record
            wallet = AgentWallet(
                agent_id=request.agent_id,
                address=address,
                public_key=public_key,
                wallet_type=request.wallet_type,
                metadata=request.metadata,
                encrypted_private_key=encrypted_data,
                encryption_version="1.0",
                created_at=datetime.utcnow()
            )

            self.session.add(wallet)
            self.session.commit()
            self.session.refresh(wallet)

            logger.info(f"Created secure wallet {wallet.address} for agent {request.agent_id}")
            return wallet

        except Exception as e:
            logger.error(f"Failed to create secure wallet: {e}")
            self.session.rollback()
            raise

    async def get_wallet_by_agent(self, agent_id: str) -> List[AgentWallet]:
        """Retrieve all active wallets for an agent"""
        return self.session.exec(
            select(AgentWallet).where(
                AgentWallet.agent_id == agent_id,
                AgentWallet.is_active == True
            )
        ).all()

    async def get_wallet_with_private_key(
        self,
        wallet_id: int,
        encryption_password: str
    ) -> Dict[str, str]:
        """
        Get wallet with decrypted private key (for signing operations)

        Args:
            wallet_id: Wallet ID
            encryption_password: Password for decryption

        Returns:
            Wallet keys including private key

        Raises:
            ValueError: If decryption fails or wallet not found
        """
        wallet = self.session.get(AgentWallet, wallet_id)
        if not wallet:
            raise ValueError("Wallet not found")

        if not wallet.is_active:
            raise ValueError("Wallet is not active")

        try:
            # Decrypt private key
            if isinstance(wallet.encrypted_private_key, dict):
                # New format
                keys = recover_wallet(wallet.encrypted_private_key, encryption_password)
            else:
                # Legacy format - cannot decrypt securely
                raise ValueError(
                    "Wallet uses legacy encryption format. "
                    "Please migrate to secure encryption."
                )

            return {
                "wallet_id": wallet_id,
                "address": wallet.address,
                "private_key": keys["private_key"],
                "public_key": keys["public_key"],
                "agent_id": wallet.agent_id
            }

        except Exception as e:
            logger.error(f"Failed to decrypt wallet {wallet_id}: {e}")
            raise ValueError(f"Failed to access wallet: {str(e)}")

    async def verify_wallet_integrity(self, wallet_id: int) -> Dict[str, bool]:
        """
        Verify wallet cryptographic integrity

        Args:
            wallet_id: Wallet ID

        Returns:
            Integrity check results
        """
        wallet = self.session.get(AgentWallet, wallet_id)
        if not wallet:
            return {"exists": False}

        results = {
            "exists": True,
            "active": wallet.is_active,
            "has_encrypted_key": bool(wallet.encrypted_private_key),
            "address_format_valid": False,
            "public_key_present": bool(wallet.public_key)
        }

        # Validate address format
        try:
            from eth_utils import to_checksum_address
            to_checksum_address(wallet.address)
            results["address_format_valid"] = True
        except Exception:
            pass

        # Check if we can verify the keypair consistency
        # (We can't do this without the password, but we can check the format)
        if wallet.public_key and wallet.encrypted_private_key:
            results["has_keypair_data"] = True

        return results

    async def migrate_wallet_encryption(
        self,
        wallet_id: int,
        old_password: str,
        new_password: str
    ) -> AgentWallet:
        """
        Migrate wallet from old encryption to new secure encryption

        Args:
            wallet_id: Wallet ID
            old_password: Current password
            new_password: New strong password

        Returns:
            Updated wallet
        """
        wallet = self.session.get(AgentWallet, wallet_id)
        if not wallet:
            raise ValueError("Wallet not found")

        try:
            # Get current private key
            current_keys = await self.get_wallet_with_private_key(wallet_id, old_password)

            # Validate new password
            from ..utils.security import validate_password_strength
            password_validation = validate_password_strength(new_password)

            if not password_validation["is_acceptable"]:
                raise ValueError(
                    f"New password too weak: {', '.join(password_validation['issues'])}"
                )

            # Re-encrypt with new password
            new_encrypted_data = encrypt_private_key(current_keys["private_key"], new_password)

            # Update wallet
            wallet.encrypted_private_key = new_encrypted_data
            wallet.encryption_version = "1.0"
            wallet.updated_at = datetime.utcnow()

            self.session.commit()
            self.session.refresh(wallet)

            logger.info(f"Migrated wallet {wallet_id} to secure encryption")
            return wallet

        except Exception as e:
            logger.error(f"Failed to migrate wallet {wallet_id}: {e}")
            self.session.rollback()
            raise

    async def get_balances(self, wallet_id: int) -> List[TokenBalance]:
        """Get all tracked balances for a wallet"""
        return self.session.exec(
            select(TokenBalance).where(TokenBalance.wallet_id == wallet_id)
        ).all()

    async def update_balance(self, wallet_id: int, chain_id: int, token_address: str, balance: float) -> TokenBalance:
        """Update a specific token balance for a wallet"""
        record = self.session.exec(
            select(TokenBalance).where(
                TokenBalance.wallet_id == wallet_id,
                TokenBalance.chain_id == chain_id,
                TokenBalance.token_address == token_address
            )
        ).first()

        if record:
            record.balance = balance
            record.updated_at = datetime.utcnow()
        else:
            record = TokenBalance(
                wallet_id=wallet_id,
                chain_id=chain_id,
                token_address=token_address,
                balance=balance,
                updated_at=datetime.utcnow()
            )
            self.session.add(record)

        self.session.commit()
        self.session.refresh(record)
        return record

    async def create_transaction(
        self,
        wallet_id: int,
        request: TransactionRequest,
        encryption_password: str
    ) -> WalletTransaction:
        """
        Create a transaction with proper signing

        Args:
            wallet_id: Wallet ID
            request: Transaction request
            encryption_password: Password for private key access

        Returns:
            Created transaction record
        """
        # Get wallet keys
        wallet_keys = await self.get_wallet_with_private_key(wallet_id, encryption_password)

        # Create transaction record
        transaction = WalletTransaction(
            wallet_id=wallet_id,
            to_address=request.to_address,
            amount=request.amount,
            token_address=request.token_address,
            chain_id=request.chain_id,
            data=request.data or "",
            status=TransactionStatus.PENDING,
            created_at=datetime.utcnow()
        )

        self.session.add(transaction)
        self.session.commit()
        self.session.refresh(transaction)

        # TODO: Implement actual blockchain transaction signing and submission
        # This would use the private_key to sign the transaction

        logger.info(f"Created transaction {transaction.id} for wallet {wallet_id}")
        return transaction

    async def deactivate_wallet(self, wallet_id: int, reason: str = "User request") -> bool:
        """Deactivate a wallet"""
        wallet = self.session.get(AgentWallet, wallet_id)
        if not wallet:
            return False

        wallet.is_active = False
        wallet.updated_at = datetime.utcnow()
        wallet.deactivation_reason = reason

        self.session.commit()

        logger.info(f"Deactivated wallet {wallet_id}: {reason}")
        return True

    async def get_wallet_security_audit(self, wallet_id: int) -> Dict[str, Any]:
        """
        Get comprehensive security audit for a wallet

        Args:
            wallet_id: Wallet ID

        Returns:
            Security audit results
        """
        wallet = self.session.get(AgentWallet, wallet_id)
        if not wallet:
            return {"error": "Wallet not found"}

        audit = {
            "wallet_id": wallet_id,
            "agent_id": wallet.agent_id,
            "address": wallet.address,
            "is_active": wallet.is_active,
            "encryption_version": getattr(wallet, 'encryption_version', 'unknown'),
            "created_at": wallet.created_at.isoformat() if wallet.created_at else None,
            "updated_at": wallet.updated_at.isoformat() if wallet.updated_at else None
        }

        # Check encryption security
        if isinstance(wallet.encrypted_private_key, dict):
            audit["encryption_secure"] = True
            audit["encryption_algorithm"] = wallet.encrypted_private_key.get("algorithm")
            audit["encryption_iterations"] = wallet.encrypted_private_key.get("iterations")
        else:
            audit["encryption_secure"] = False
            audit["encryption_issues"] = ["Uses legacy or broken encryption"]

        # Check address format
        try:
            from eth_utils import to_checksum_address
            to_checksum_address(wallet.address)
            audit["address_valid"] = True
        except Exception:
            audit["address_valid"] = False
            audit["address_issues"] = ["Invalid Ethereum address format"]

        # Check keypair data
        audit["has_public_key"] = bool(wallet.public_key)
        audit["has_encrypted_private_key"] = bool(wallet.encrypted_private_key)

        # Overall security score
        security_score = 0
        if audit["encryption_secure"]:
            security_score += 40
        if audit["address_valid"]:
            security_score += 30
        if audit["has_public_key"]:
            security_score += 15
        if audit["has_encrypted_private_key"]:
            security_score += 15

        audit["security_score"] = security_score
        audit["security_level"] = (
            "Excellent" if security_score >= 90 else
            "Good" if security_score >= 70 else
            "Fair" if security_score >= 50 else
            "Poor"
        )

        return audit
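SecureWalletService gates wallet creation and migration on ..utils.security.validate_password_strength, which is not part of this diff. From the call sites above it returns a dict with an is_acceptable flag and an issues list; a minimal stand-in with that contract, for local testing (the specific rules here are illustrative, not the project's actual policy):

import re
from typing import Dict, List, Union


def validate_password_strength(password: str) -> Dict[str, Union[bool, List[str]]]:
    """Stand-in validator matching the contract SecureWalletService relies on."""
    issues: List[str] = []
    if len(password) < 12:
        issues.append("shorter than 12 characters")
    if not re.search(r"[A-Z]", password):
        issues.append("no uppercase letter")
    if not re.search(r"[a-z]", password):
        issues.append("no lowercase letter")
    if not re.search(r"\d", password):
        issues.append("no digit")
    return {"is_acceptable": not issues, "issues": issues}


print(validate_password_strength("hunter2"))         # rejected, with reasons
print(validate_password_strength("Tr0ub4dorXyz99"))  # acceptable under these rules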
apps/coordinator-api/src/app/services/wallet_crypto.py (new file, 238 lines)
@@ -0,0 +1,238 @@
"""
Secure Cryptographic Operations for Agent Wallets
Fixed implementation using proper Ethereum cryptography
"""

import secrets
from typing import Tuple, Dict, Any
from eth_account import Account
from eth_keys import keys
from eth_utils import to_checksum_address
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.primitives import hashes
import base64


def generate_ethereum_keypair() -> Tuple[str, str, str]:
    """
    Generate proper Ethereum keypair using secp256k1

    Returns:
        Tuple of (private_key, public_key, address)
    """
    # Use eth_account which properly implements secp256k1
    account = Account.create()

    private_key = account.key.hex()
    # Derive the public key via eth_keys rather than a private attribute
    public_key = keys.PrivateKey(account.key).public_key.to_hex()
    address = account.address

    return private_key, public_key, address


def verify_keypair_consistency(private_key: str, expected_address: str) -> bool:
    """
    Verify that a private key generates the expected address

    Args:
        private_key: 32-byte private key hex
        expected_address: Expected Ethereum address

    Returns:
        True if keypair is consistent
    """
    try:
        account = Account.from_key(private_key)
        return to_checksum_address(account.address) == to_checksum_address(expected_address)
    except Exception:
        return False


def derive_secure_key(password: str, salt: bytes = None) -> Tuple[bytes, bytes]:
    """
    Derive secure encryption key using PBKDF2

    Args:
        password: User password
        salt: Optional salt (generated if not provided)

    Returns:
        Tuple of (key, salt) for storage
    """
    if salt is None:
        salt = secrets.token_bytes(32)

    kdf = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=salt,
        iterations=600_000,  # OWASP recommended minimum
    )

    key = kdf.derive(password.encode())
    return base64.urlsafe_b64encode(key), salt


def encrypt_private_key(private_key: str, password: str) -> Dict[str, str]:
    """
    Encrypt private key with proper KDF and Fernet

    Args:
        private_key: 32-byte private key hex
        password: User password

    Returns:
        Dict with encrypted data and salt
    """
    # Derive encryption key
    fernet_key, salt = derive_secure_key(password)

    # Encrypt
    f = Fernet(fernet_key)
    encrypted = f.encrypt(private_key.encode())

    return {
        "encrypted_key": encrypted.decode(),
        "salt": base64.b64encode(salt).decode(),
        "algorithm": "PBKDF2-SHA256-Fernet",
        "iterations": 600_000
    }


def decrypt_private_key(encrypted_data: Dict[str, str], password: str) -> str:
    """
    Decrypt private key with proper verification

    Args:
        encrypted_data: Dict with encrypted key and salt
        password: User password

    Returns:
        Decrypted private key

    Raises:
        ValueError: If decryption fails
    """
    try:
        # Extract salt and encrypted key
        salt = base64.b64decode(encrypted_data["salt"])
        encrypted_key = encrypted_data["encrypted_key"].encode()

        # Derive same key
        fernet_key, _ = derive_secure_key(password, salt)

        # Decrypt
        f = Fernet(fernet_key)
        decrypted = f.decrypt(encrypted_key)

        return decrypted.decode()
    except Exception as e:
        raise ValueError(f"Failed to decrypt private key: {str(e)}")


def validate_private_key_format(private_key: str) -> bool:
    """
    Validate private key format

    Args:
        private_key: Private key to validate

    Returns:
        True if format is valid
    """
    try:
        # Remove 0x prefix if present
        if private_key.startswith("0x"):
            private_key = private_key[2:]

        # Check length (32 bytes = 64 hex chars)
        if len(private_key) != 64:
            return False

        # Check if valid hex
        int(private_key, 16)

        # Try to create account to verify it's a valid secp256k1 key
        Account.from_key("0x" + private_key)

        return True
    except Exception:
        return False


# Security configuration constants
class SecurityConfig:
    """Security configuration constants"""

    # PBKDF2 settings
    PBKDF2_ITERATIONS = 600_000
    PBKDF2_ALGORITHM = hashes.SHA256
    SALT_LENGTH = 32

    # Fernet settings
    FERNET_KEY_LENGTH = 32

    # Validation
    PRIVATE_KEY_LENGTH = 64  # 32 bytes in hex
    ADDRESS_LENGTH = 40      # 20 bytes in hex (without 0x)


# Backward compatibility wrapper for existing code
def create_secure_wallet(agent_id: str, password: str) -> Dict[str, Any]:
    """
    Create a wallet with proper security

    Args:
        agent_id: Agent identifier
        password: Strong password for encryption

    Returns:
        Wallet data with encrypted private key
    """
    # Generate proper keypair
    private_key, public_key, address = generate_ethereum_keypair()

    # Validate consistency
    if not verify_keypair_consistency(private_key, address):
        raise RuntimeError("Keypair generation failed consistency check")

    # Encrypt private key
    encrypted_data = encrypt_private_key(private_key, password)

    return {
        "agent_id": agent_id,
        "address": address,
        "public_key": public_key,
        "encrypted_private_key": encrypted_data,
        "created_at": secrets.token_hex(16),  # For tracking
        "version": "1.0"
    }


def recover_wallet(encrypted_data: Dict[str, str], password: str) -> Dict[str, str]:
    """
    Recover wallet from encrypted data

    Args:
        encrypted_data: Encrypted wallet data
        password: Password for decryption

    Returns:
        Wallet keys
    """
    # Decrypt private key
    private_key = decrypt_private_key(encrypted_data, password)

    # Validate format
    if not validate_private_key_format(private_key):
        raise ValueError("Decrypted private key has invalid format")

    # Derive address and public key to verify; normalize the 0x prefix first
    # so keys stored with a leading 0x are not double-prefixed
    hex_key = private_key if private_key.startswith("0x") else "0x" + private_key
    account = Account.from_key(hex_key)

    return {
        "private_key": private_key,
        "public_key": keys.PrivateKey(account.key).public_key.to_hex(),
        "address": account.address
    }
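Because wallet_crypto.py has no database or network dependencies, its round trip can be sanity-checked in isolation. A quick sketch using only the functions above (requires the eth-account and cryptography packages):

# Generate, verify, encrypt, and decrypt a key end to end.
private_key, public_key, address = generate_ethereum_keypair()
assert verify_keypair_consistency(private_key, address)

blob = encrypt_private_key(private_key, "a long, unique passphrase")
assert decrypt_private_key(blob, "a long, unique passphrase") == private_key

# A wrong password raises ValueError instead of returning garbage,
# because Fernet authenticates the ciphertext before decrypting.
try:
    decrypt_private_key(blob, "wrong passphrase")
except ValueError as exc:
    print(exc)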
@@ -0,0 +1,641 @@
|
||||
"""
|
||||
WebSocket Stream Manager with Backpressure Control
|
||||
|
||||
Advanced WebSocket stream architecture with per-stream flow control,
|
||||
bounded queues, and event loop protection for multi-modal fusion.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import weakref
|
||||
from typing import Dict, List, Optional, Any, Callable, Set, Union
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from collections import deque
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import websockets
|
||||
from websockets.server import WebSocketServerProtocol
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
|
||||
from aitbc.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class StreamStatus(Enum):
|
||||
"""Stream connection status"""
|
||||
CONNECTING = "connecting"
|
||||
CONNECTED = "connected"
|
||||
SLOW_CONSUMER = "slow_consumer"
|
||||
BACKPRESSURE = "backpressure"
|
||||
DISCONNECTED = "disconnected"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class MessageType(Enum):
|
||||
"""Message types for stream classification"""
|
||||
CRITICAL = "critical" # High priority, must deliver
|
||||
IMPORTANT = "important" # Normal priority
|
||||
BULK = "bulk" # Low priority, can be dropped
|
||||
CONTROL = "control" # Stream control messages
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamMessage:
|
||||
"""Message with priority and metadata"""
|
||||
data: Any
|
||||
message_type: MessageType
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
message_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
retry_count: int = 0
|
||||
max_retries: int = 3
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": self.message_id,
|
||||
"type": self.message_type.value,
|
||||
"timestamp": self.timestamp,
|
||||
"data": self.data
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamMetrics:
|
||||
"""Metrics for stream performance monitoring"""
|
||||
messages_sent: int = 0
|
||||
messages_dropped: int = 0
|
||||
bytes_sent: int = 0
|
||||
last_send_time: float = 0
|
||||
avg_send_time: float = 0
|
||||
queue_size: int = 0
|
||||
backpressure_events: int = 0
|
||||
slow_consumer_events: int = 0
|
||||
|
||||
def update_send_metrics(self, send_time: float, message_size: int):
|
||||
"""Update send performance metrics"""
|
||||
self.messages_sent += 1
|
||||
self.bytes_sent += message_size
|
||||
self.last_send_time = time.time()
|
||||
|
||||
# Update average send time
|
||||
if self.messages_sent == 1:
|
||||
self.avg_send_time = send_time
|
||||
else:
|
||||
self.avg_send_time = (self.avg_send_time * (self.messages_sent - 1) + send_time) / self.messages_sent
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamConfig:
|
||||
"""Configuration for individual streams"""
|
||||
max_queue_size: int = 1000
|
||||
send_timeout: float = 5.0
|
||||
heartbeat_interval: float = 30.0
|
||||
slow_consumer_threshold: float = 0.5 # seconds
|
||||
backpressure_threshold: float = 0.8 # queue fill ratio
|
||||
drop_bulk_threshold: float = 0.9 # queue fill ratio for bulk messages
|
||||
enable_compression: bool = True
|
||||
priority_send: bool = True
|
||||
|
||||
|
||||
class BoundedMessageQueue:
|
||||
"""Bounded queue with priority and backpressure handling"""
|
||||
|
||||
def __init__(self, max_size: int = 1000):
|
||||
self.max_size = max_size
|
||||
self.queues = {
|
||||
MessageType.CRITICAL: deque(maxlen=max_size // 4),
|
||||
MessageType.IMPORTANT: deque(maxlen=max_size // 2),
|
||||
MessageType.BULK: deque(maxlen=max_size // 4),
|
||||
MessageType.CONTROL: deque(maxlen=100) # Small control queue
|
||||
}
|
||||
self.total_size = 0
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def put(self, message: StreamMessage) -> bool:
|
||||
"""Add message to queue with backpressure handling"""
|
||||
async with self._lock:
|
||||
# Check if we're at capacity
|
||||
if self.total_size >= self.max_size:
|
||||
# Drop bulk messages first
|
||||
if message.message_type == MessageType.BULK:
|
||||
return False
|
||||
|
||||
# Drop oldest important messages if critical
|
||||
if message.message_type == MessageType.IMPORTANT:
|
||||
if self.queues[MessageType.IMPORTANT]:
|
||||
self.queues[MessageType.IMPORTANT].popleft()
|
||||
self.total_size -= 1
|
||||
else:
|
||||
return False
|
||||
|
||||
# Always allow critical messages (drop oldest if needed)
|
||||
if message.message_type == MessageType.CRITICAL:
|
||||
if self.queues[MessageType.CRITICAL]:
|
||||
self.queues[MessageType.CRITICAL].popleft()
|
||||
self.total_size -= 1
|
||||
|
||||
self.queues[message.message_type].append(message)
|
||||
self.total_size += 1
|
||||
return True
|
||||
|
||||
async def get(self) -> Optional[StreamMessage]:
|
||||
"""Get next message by priority"""
|
||||
async with self._lock:
|
||||
# Priority order: CONTROL > CRITICAL > IMPORTANT > BULK
|
||||
for message_type in [MessageType.CONTROL, MessageType.CRITICAL,
|
||||
MessageType.IMPORTANT, MessageType.BULK]:
|
||||
if self.queues[message_type]:
|
||||
message = self.queues[message_type].popleft()
|
||||
self.total_size -= 1
|
||||
return message
|
||||
return None
|
||||
|
||||
def size(self) -> int:
|
||||
"""Get total queue size"""
|
||||
return self.total_size
|
||||
|
||||
def fill_ratio(self) -> float:
|
||||
"""Get queue fill ratio"""
|
||||
return self.total_size / self.max_size
|
||||
|
||||
|
||||
class WebSocketStream:
|
||||
"""Individual WebSocket stream with backpressure control"""
|
||||
|
||||
def __init__(self, websocket: WebSocketServerProtocol,
|
||||
stream_id: str, config: StreamConfig):
|
||||
self.websocket = websocket
|
||||
self.stream_id = stream_id
|
||||
self.config = config
|
||||
self.status = StreamStatus.CONNECTING
|
||||
self.queue = BoundedMessageQueue(config.max_queue_size)
|
||||
self.metrics = StreamMetrics()
|
||||
self.last_heartbeat = time.time()
|
||||
self.slow_consumer_count = 0
|
||||
|
||||
# Event loop protection
|
||||
self._send_lock = asyncio.Lock()
|
||||
self._sender_task = None
|
||||
self._heartbeat_task = None
|
||||
self._running = False
|
||||
|
||||
# Weak reference for cleanup
|
||||
self._finalizer = weakref.finalize(self, self._cleanup)
|
||||
|
||||
async def start(self):
|
||||
"""Start stream processing"""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self.status = StreamStatus.CONNECTED
|
||||
|
||||
# Start sender task
|
||||
self._sender_task = asyncio.create_task(self._sender_loop())
|
||||
|
||||
# Start heartbeat task
|
||||
self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
|
||||
|
||||
logger.info(f"Stream {self.stream_id} started")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop stream processing"""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
self.status = StreamStatus.DISCONNECTED
|
||||
|
||||
# Cancel tasks
|
||||
if self._sender_task:
|
||||
self._sender_task.cancel()
|
||||
try:
|
||||
await self._sender_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self._heartbeat_task:
|
||||
self._heartbeat_task.cancel()
|
||||
try:
|
||||
await self._heartbeat_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
logger.info(f"Stream {self.stream_id} stopped")
|
||||
|
||||
async def send_message(self, data: Any, message_type: MessageType = MessageType.IMPORTANT) -> bool:
|
||||
"""Send message with backpressure handling"""
|
||||
if not self._running:
|
||||
return False
|
||||
|
||||
message = StreamMessage(data=data, message_type=message_type)
|
||||
|
||||
# Check backpressure
|
||||
queue_ratio = self.queue.fill_ratio()
|
||||
if queue_ratio > self.config.backpressure_threshold:
|
||||
self.status = StreamStatus.BACKPRESSURE
|
||||
self.metrics.backpressure_events += 1
|
||||
|
||||
# Drop bulk messages under backpressure
|
||||
if message_type == MessageType.BULK and queue_ratio > self.config.drop_bulk_threshold:
|
||||
self.metrics.messages_dropped += 1
|
||||
return False
|
||||
|
||||
# Add to queue
|
||||
success = await self.queue.put(message)
|
||||
if not success:
|
||||
self.metrics.messages_dropped += 1
|
||||
|
||||
return success
|
||||
|
||||
async def _sender_loop(self):
|
||||
"""Main sender loop with backpressure control"""
|
||||
while self._running:
|
||||
try:
|
||||
# Get next message
|
||||
message = await self.queue.get()
|
||||
if message is None:
|
||||
await asyncio.sleep(0.01)
|
||||
continue
|
||||
|
||||
# Send with timeout and backpressure protection
|
||||
start_time = time.time()
|
||||
success = await self._send_with_backpressure(message)
|
||||
send_time = time.time() - start_time
|
||||
|
||||
if success:
|
||||
message_size = len(json.dumps(message.to_dict()).encode())
|
||||
self.metrics.update_send_metrics(send_time, message_size)
|
||||
else:
|
||||
# Retry logic
|
||||
message.retry_count += 1
|
||||
if message.retry_count < message.max_retries:
|
||||
await self.queue.put(message)
|
||||
else:
|
||||
self.metrics.messages_dropped += 1
|
||||
logger.warning(f"Message {message.message_id} dropped after max retries")
|
||||
|
||||
# Check for slow consumer
|
||||
if send_time > self.config.slow_consumer_threshold:
|
||||
self.slow_consumer_count += 1
|
||||
self.metrics.slow_consumer_events += 1
|
||||
|
||||
if self.slow_consumer_count > 5: # Threshold for slow consumer detection
|
||||
self.status = StreamStatus.SLOW_CONSUMER
|
||||
logger.warning(f"Stream {self.stream_id} detected as slow consumer")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in sender loop for stream {self.stream_id}: {e}")
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def _send_with_backpressure(self, message: StreamMessage) -> bool:
|
||||
"""Send message with backpressure and timeout protection"""
|
||||
try:
|
||||
async with self._send_lock:
|
||||
# Use asyncio.wait_for for timeout protection
|
||||
message_data = message.to_dict()
|
||||
|
||||
if self.config.enable_compression:
|
||||
# Compress large messages
|
||||
message_str = json.dumps(message_data, separators=(',', ':'))
|
||||
if len(message_str) > 1024: # Compress messages > 1KB
|
||||
message_data['_compressed'] = True
|
||||
message_str = json.dumps(message_data, separators=(',', ':'))
|
||||
else:
|
||||
message_str = json.dumps(message_data)
|
||||
|
||||
# Send with timeout
|
||||
await asyncio.wait_for(
|
||||
self.websocket.send(message_str),
|
||||
timeout=self.config.send_timeout
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Send timeout for stream {self.stream_id}")
|
||||
return False
|
||||
except ConnectionClosed:
|
||||
logger.info(f"Connection closed for stream {self.stream_id}")
|
||||
await self.stop()
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Send error for stream {self.stream_id}: {e}")
|
||||
return False
|
||||
|
||||
async def _heartbeat_loop(self):
|
||||
"""Heartbeat loop for connection health monitoring"""
|
||||
while self._running:
|
||||
try:
|
||||
await asyncio.sleep(self.config.heartbeat_interval)
|
||||
|
||||
if not self._running:
|
||||
break
|
||||
|
||||
# Send heartbeat
|
||||
heartbeat_msg = {
|
||||
"type": "heartbeat",
|
||||
"timestamp": time.time(),
|
||||
"stream_id": self.stream_id,
|
||||
"queue_size": self.queue.size(),
|
||||
"status": self.status.value
|
||||
}
|
||||
|
||||
await self.send_message(heartbeat_msg, MessageType.CONTROL)
|
||||
self.last_heartbeat = time.time()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Heartbeat error for stream {self.stream_id}: {e}")

    def get_metrics(self) -> Dict[str, Any]:
        """Get stream metrics"""
        return {
            "stream_id": self.stream_id,
            "status": self.status.value,
            "queue_size": self.queue.size(),
            "queue_fill_ratio": self.queue.fill_ratio(),
            "messages_sent": self.metrics.messages_sent,
            "messages_dropped": self.metrics.messages_dropped,
            "bytes_sent": self.metrics.bytes_sent,
            "avg_send_time": self.metrics.avg_send_time,
            "backpressure_events": self.metrics.backpressure_events,
            "slow_consumer_events": self.metrics.slow_consumer_events,
            "last_heartbeat": self.last_heartbeat
        }

    def _cleanup(self):
        """Cleanup resources"""
        if self._running:
            # stop() should normally run before the stream is reclaimed;
            # warn if cleanup happens while the stream is still active.
            logger.warning(f"Stream {self.stream_id} cleanup called while running")


class WebSocketStreamManager:
    """Manages multiple WebSocket streams with backpressure control"""

    def __init__(self, default_config: Optional[StreamConfig] = None):
        self.default_config = default_config or StreamConfig()
        self.streams: Dict[str, WebSocketStream] = {}
        self.stream_configs: Dict[str, StreamConfig] = {}

        # Global metrics
        self.total_connections = 0
        self.total_messages_sent = 0
        self.total_messages_dropped = 0

        # Event loop protection
        self._manager_lock = asyncio.Lock()
        self._cleanup_task = None
        self._running = False

        # Message broadcasting
        self._broadcast_queue = asyncio.Queue(maxsize=10000)
        self._broadcast_task = None

    async def start(self):
        """Start the stream manager"""
        if self._running:
            return

        self._running = True

        # Start cleanup task
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())

        # Start broadcast task
        self._broadcast_task = asyncio.create_task(self._broadcast_loop())

        logger.info("WebSocket Stream Manager started")

    async def stop(self):
        """Stop the stream manager"""
        if not self._running:
            return

        self._running = False

        # Stop all streams
        streams_to_stop = list(self.streams.values())
        for stream in streams_to_stop:
            await stream.stop()

        # Cancel background tasks
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass

        if self._broadcast_task:
            self._broadcast_task.cancel()
            try:
                await self._broadcast_task
            except asyncio.CancelledError:
                pass

        logger.info("WebSocket Stream Manager stopped")

    @asynccontextmanager  # assumes `from contextlib import asynccontextmanager` at module top
    async def manage_stream(self, websocket: WebSocketServerProtocol,
                            config: Optional[StreamConfig] = None):
        """Async context manager for stream lifecycle"""
        stream_id = str(uuid.uuid4())
        stream_config = config or self.default_config

        stream = None
        try:
            # Create and start stream
            stream = WebSocketStream(websocket, stream_id, stream_config)
            await stream.start()

            async with self._manager_lock:
                self.streams[stream_id] = stream
                self.stream_configs[stream_id] = stream_config
                self.total_connections += 1

            logger.info(f"Stream {stream_id} added to manager")

            yield stream

        except Exception as e:
            logger.error(f"Error managing stream {stream_id}: {e}")
            raise
        finally:
            # Cleanup stream
            if stream and stream_id in self.streams:
                await stream.stop()

                async with self._manager_lock:
                    del self.streams[stream_id]
                    if stream_id in self.stream_configs:
                        del self.stream_configs[stream_id]
                    self.total_connections -= 1

                logger.info(f"Stream {stream_id} removed from manager")

    async def broadcast_to_all(self, data: Any, message_type: MessageType = MessageType.IMPORTANT):
        """Broadcast message to all streams"""
        if not self._running:
            return

        try:
            # put_nowait raises QueueFull when the queue is at capacity;
            # a plain `await put()` would block instead of raising.
            self._broadcast_queue.put_nowait((data, message_type))
        except asyncio.QueueFull:
            logger.warning("Broadcast queue full, dropping message")
            self.total_messages_dropped += 1

    async def broadcast_to_stream(self, stream_id: str, data: Any,
                                  message_type: MessageType = MessageType.IMPORTANT):
        """Send message to specific stream"""
        async with self._manager_lock:
            stream = self.streams.get(stream_id)
            if stream:
                await stream.send_message(data, message_type)

    async def _broadcast_loop(self):
        """Broadcast messages to all streams"""
        while self._running:
            try:
                # Get broadcast message
                data, message_type = await self._broadcast_queue.get()

                # Snapshot streams under the lock, then send concurrently
                tasks = []
                async with self._manager_lock:
                    streams = list(self.streams.values())

                for stream in streams:
                    task = asyncio.create_task(
                        stream.send_message(data, message_type)
                    )
                    tasks.append(task)

                # Wait for all sends (with timeout)
                if tasks:
                    try:
                        await asyncio.wait_for(
                            asyncio.gather(*tasks, return_exceptions=True),
                            timeout=1.0
                        )
                    except asyncio.TimeoutError:
                        logger.warning("Broadcast timeout, some streams may be slow")

                # Counts broadcast operations, not per-stream deliveries
                self.total_messages_sent += 1

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error in broadcast loop: {e}")
                await asyncio.sleep(0.1)

    async def _cleanup_loop(self):
        """Cleanup disconnected streams"""
        while self._running:
            try:
                await asyncio.sleep(60)  # Cleanup every minute

                disconnected_streams = []
                async with self._manager_lock:
                    for stream_id, stream in self.streams.items():
                        if stream.status == StreamStatus.DISCONNECTED:
                            disconnected_streams.append(stream_id)

                    # Remove disconnected streams
                    for stream_id in disconnected_streams:
                        if stream_id in self.streams:
                            stream = self.streams[stream_id]
                            await stream.stop()
                            del self.streams[stream_id]
                            if stream_id in self.stream_configs:
                                del self.stream_configs[stream_id]
                            self.total_connections -= 1
                            logger.info(f"Cleaned up disconnected stream {stream_id}")

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error in cleanup loop: {e}")

    async def get_manager_metrics(self) -> Dict[str, Any]:
        """Get comprehensive manager metrics"""
        async with self._manager_lock:
            stream_metrics = []
            for stream in self.streams.values():
                stream_metrics.append(stream.get_metrics())

            # Calculate aggregate metrics
            total_queue_size = sum(m["queue_size"] for m in stream_metrics)
            total_messages_sent = sum(m["messages_sent"] for m in stream_metrics)
            total_messages_dropped = sum(m["messages_dropped"] for m in stream_metrics)
            total_bytes_sent = sum(m["bytes_sent"] for m in stream_metrics)

            # Status distribution
            status_counts = {}
            for stream in self.streams.values():
                status = stream.status.value
                status_counts[status] = status_counts.get(status, 0) + 1

            return {
                "manager_status": "running" if self._running else "stopped",
                "total_connections": self.total_connections,
                "active_streams": len(self.streams),
                "total_queue_size": total_queue_size,
                "total_messages_sent": total_messages_sent,
                "total_messages_dropped": total_messages_dropped,
                "total_bytes_sent": total_bytes_sent,
                "broadcast_queue_size": self._broadcast_queue.qsize(),
                "stream_status_distribution": status_counts,
                "stream_metrics": stream_metrics
            }

    async def update_stream_config(self, stream_id: str, config: StreamConfig):
        """Update configuration for specific stream"""
        async with self._manager_lock:
            if stream_id in self.streams:
                self.stream_configs[stream_id] = config
                # Apply to the live stream too; updating only the manager's
                # config map would leave stream.config (and thus the next
                # send) unchanged.
                self.streams[stream_id].config = config
                logger.info(f"Updated config for stream {stream_id}")

    def get_slow_streams(self, threshold: float = 0.8) -> List[str]:
        """Get streams with high queue fill ratios"""
        slow_streams = []
        for stream_id, stream in self.streams.items():
            if stream.queue.fill_ratio() > threshold:
                slow_streams.append(stream_id)
        return slow_streams

    async def handle_slow_consumer(self, stream_id: str, action: str = "warn"):
        """Handle slow consumer streams"""
        # Look up the stream under the lock, then release it before acting:
        # the "throttle" path calls update_stream_config(), which takes the
        # (non-reentrant) manager lock itself and would otherwise deadlock.
        async with self._manager_lock:
            stream = self.streams.get(stream_id)
        if not stream:
            return

        if action == "warn":
            logger.warning(f"Slow consumer detected: {stream_id}")
            await stream.send_message(
                {"warning": "Slow consumer detected", "stream_id": stream_id},
                MessageType.CONTROL
            )
        elif action == "throttle":
            # Halve the queue and double the timeout; other fields fall
            # back to StreamConfig defaults.
            new_config = StreamConfig(
                max_queue_size=stream.config.max_queue_size // 2,
                send_timeout=stream.config.send_timeout * 2
            )
            await self.update_stream_config(stream_id, new_config)
            logger.info(f"Throttled slow consumer: {stream_id}")
        elif action == "disconnect":
            logger.warning(f"Disconnecting slow consumer: {stream_id}")
            await stream.stop()
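

# Hedged usage sketch: a supervision loop pairing get_slow_streams() with
# handle_slow_consumer(). The interval, threshold, and warn-then-throttle
# escalation are illustrative assumptions, not part of the original design.
async def supervise_slow_streams(manager: "WebSocketStreamManager",
                                 interval: float = 30.0,
                                 threshold: float = 0.8):
    """Warn on the first strike, throttle on repeat offenses (sketch)."""
    strikes: Dict[str, int] = {}
    while True:
        await asyncio.sleep(interval)
        slow = set(manager.get_slow_streams(threshold))
        for stream_id in slow:
            strikes[stream_id] = strikes.get(stream_id, 0) + 1
            action = "warn" if strikes[stream_id] == 1 else "throttle"
            await manager.handle_slow_consumer(stream_id, action)
        # Forget streams that recovered below the threshold
        for stream_id in list(strikes):
            if stream_id not in slow:
                del strikes[stream_id]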


# Global stream manager instance
stream_manager = WebSocketStreamManager()
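
# Hedged end-to-end sketch (assumes the `websockets` library; the handler
# signature varies by version, so this is illustrative, not prescriptive):
#
#     import websockets
#
#     async def handler(websocket):
#         async with stream_manager.manage_stream(websocket) as stream:
#             async for raw in websocket:
#                 # Replies flow through the managed stream so backpressure
#                 # and metrics apply to the send path.
#                 await stream.send_message(json.loads(raw), MessageType.IMPORTANT)
#
#     async def main():
#         await stream_manager.start()
#         async with websockets.serve(handler, "0.0.0.0", 8765):
#             await asyncio.Future()  # run until cancelled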