chore(security): harden environment configuration, CI workflows, and the wallet daemon

- Restructure .env.example with security-focused documentation, service-specific environment file references, and AWS Secrets Manager integration
- Update CLI tests workflow to a single Python version (3.13), add the pytest-mock dependency, and consolidate test execution with coverage
- Add comprehensive security validation to package publishing workflow with manual approval gates, secret scanning, and release
oib
2026-03-03 10:33:46 +01:00
parent 00d00cb964
commit f353e00172
220 changed files with 42506 additions and 921 deletions


@@ -0,0 +1,706 @@
"""
Multi-Modal WebSocket Fusion Service
Advanced WebSocket stream architecture for multi-modal fusion with
per-stream backpressure handling and GPU provider flow control.
"""
import asyncio
import json
import time
import numpy as np
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
from enum import Enum
from uuid import uuid4
from aitbc.logging import get_logger
from .websocket_stream_manager import (
WebSocketStreamManager, StreamConfig, MessageType,
stream_manager, WebSocketStream
)
from .gpu_multimodal import GPUMultimodalProcessor
from .multi_modal_fusion import MultiModalFusionService
logger = get_logger(__name__)
class FusionStreamType(Enum):
"""Types of fusion streams"""
VISUAL = "visual"
TEXT = "text"
AUDIO = "audio"
SENSOR = "sensor"
CONTROL = "control"
METRICS = "metrics"
class GPUProviderStatus(Enum):
"""GPU provider status"""
AVAILABLE = "available"
BUSY = "busy"
SLOW = "slow"
OVERLOADED = "overloaded"
OFFLINE = "offline"
@dataclass
class FusionStreamConfig:
"""Configuration for fusion streams"""
stream_type: FusionStreamType
max_queue_size: int = 500
gpu_timeout: float = 2.0
fusion_timeout: float = 5.0
batch_size: int = 8
enable_gpu_acceleration: bool = True
priority: int = 1 # Higher number = higher priority
def to_stream_config(self) -> StreamConfig:
"""Convert to WebSocket stream config"""
return StreamConfig(
max_queue_size=self.max_queue_size,
send_timeout=self.fusion_timeout,
heartbeat_interval=30.0,
slow_consumer_threshold=0.5,
backpressure_threshold=0.7,
drop_bulk_threshold=0.85,
enable_compression=True,
priority_send=True
)
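# For instance (values are arbitrary, for illustration only), a visual stream
# with a smaller queue and its WebSocket-level counterpart:
#
#   visual = FusionStreamConfig(stream_type=FusionStreamType.VISUAL,
#                               max_queue_size=256, batch_size=4)
#   ws_config = visual.to_stream_config()  # StreamConfig for the manager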
@dataclass
class FusionData:
"""Multi-modal fusion data"""
stream_id: str
stream_type: FusionStreamType
data: Any
timestamp: float
metadata: Dict[str, Any] = field(default_factory=dict)
requires_gpu: bool = False
processing_priority: int = 1
@dataclass
class GPUProviderMetrics:
"""GPU provider performance metrics"""
provider_id: str
status: GPUProviderStatus
avg_processing_time: float
queue_size: int
gpu_utilization: float
memory_usage: float
error_rate: float
last_update: float
class GPUProviderFlowControl:
"""Flow control for GPU providers"""
def __init__(self, provider_id: str):
self.provider_id = provider_id
self.metrics = GPUProviderMetrics(
provider_id=provider_id,
status=GPUProviderStatus.AVAILABLE,
avg_processing_time=0.0,
queue_size=0,
gpu_utilization=0.0,
memory_usage=0.0,
error_rate=0.0,
last_update=time.time()
)
# Flow control queues
self.input_queue = asyncio.Queue(maxsize=100)
self.output_queue = asyncio.Queue(maxsize=100)
self.control_queue = asyncio.Queue(maxsize=50)
# Flow control parameters
self.max_concurrent_requests = 4
self.current_requests = 0
self.slow_threshold = 2.0 # seconds
self.overload_threshold = 0.8 # queue fill ratio
# Performance tracking
self.request_times = []
self.error_count = 0
self.total_requests = 0
# Flow control task
self._flow_control_task = None
self._running = False
async def start(self):
"""Start flow control"""
if self._running:
return
self._running = True
self._flow_control_task = asyncio.create_task(self._flow_control_loop())
logger.info(f"GPU provider flow control started: {self.provider_id}")
async def stop(self):
"""Stop flow control"""
if not self._running:
return
self._running = False
if self._flow_control_task:
self._flow_control_task.cancel()
try:
await self._flow_control_task
except asyncio.CancelledError:
pass
logger.info(f"GPU provider flow control stopped: {self.provider_id}")
async def submit_request(self, data: FusionData) -> Optional[str]:
"""Submit request with flow control"""
if not self._running:
return None
# Check provider status
if self.metrics.status == GPUProviderStatus.OFFLINE:
logger.warning(f"GPU provider {self.provider_id} is offline")
return None
# Check backpressure
if self.input_queue.qsize() / self.input_queue.maxsize > self.overload_threshold:
self.metrics.status = GPUProviderStatus.OVERLOADED
logger.warning(f"GPU provider {self.provider_id} is overloaded")
return None
# Submit request
request_id = str(uuid4())
request_data = {
"request_id": request_id,
"data": data,
"timestamp": time.time()
}
try:
await asyncio.wait_for(
self.input_queue.put(request_data),
timeout=1.0
)
return request_id
except asyncio.TimeoutError:
logger.warning(f"Request timeout for GPU provider {self.provider_id}")
return None
    async def get_result(self, request_id: str, timeout: float = 5.0) -> Optional[Dict[str, Any]]:
        """Get the full result entry for a request (carries "data" on success or "error" on failure)"""
        start_time = time.time()
        while time.time() - start_time < timeout:
            try:
                # Check output queue
                result = await asyncio.wait_for(
                    self.output_queue.get(),
                    timeout=0.1
                )
                if result.get("request_id") == request_id:
                    return result
                # Put back if not our result, and yield control so the
                # rightful waiter can pick it up
                await self.output_queue.put(result)
                await asyncio.sleep(0)
            except asyncio.TimeoutError:
                continue
        return None
async def _flow_control_loop(self):
"""Main flow control loop"""
while self._running:
try:
# Get next request
request_data = await asyncio.wait_for(
self.input_queue.get(),
timeout=1.0
)
# Check concurrent request limit
if self.current_requests >= self.max_concurrent_requests:
# Re-queue request
await self.input_queue.put(request_data)
await asyncio.sleep(0.1)
continue
# Process request
self.current_requests += 1
self.total_requests += 1
asyncio.create_task(self._process_request(request_data))
except asyncio.TimeoutError:
continue
except Exception as e:
logger.error(f"Flow control error for {self.provider_id}: {e}")
await asyncio.sleep(0.1)
async def _process_request(self, request_data: Dict[str, Any]):
"""Process individual request"""
request_id = request_data["request_id"]
data: FusionData = request_data["data"]
start_time = time.time()
try:
# Simulate GPU processing
if data.requires_gpu:
# Simulate GPU processing time
processing_time = np.random.uniform(0.5, 3.0)
await asyncio.sleep(processing_time)
                # Simulate GPU result and record utilization on the metrics
                gpu_utilization = np.random.uniform(0.3, 0.9)
                memory_usage = np.random.uniform(0.4, 0.8)
                self.metrics.gpu_utilization = gpu_utilization
                self.metrics.memory_usage = memory_usage
                result = {
                    "processed_data": f"gpu_processed_{data.stream_type.value}",
                    "processing_time": processing_time,
                    "gpu_utilization": gpu_utilization,
                    "memory_usage": memory_usage
                }
else:
# CPU processing
processing_time = np.random.uniform(0.1, 0.5)
await asyncio.sleep(processing_time)
result = {
"processed_data": f"cpu_processed_{data.stream_type}",
"processing_time": processing_time
}
# Update metrics
actual_time = time.time() - start_time
self._update_metrics(actual_time, success=True)
# Send result
await self.output_queue.put({
"request_id": request_id,
"data": result,
"timestamp": time.time()
})
except Exception as e:
logger.error(f"Request processing error for {self.provider_id}: {e}")
self._update_metrics(time.time() - start_time, success=False)
# Send error result
await self.output_queue.put({
"request_id": request_id,
"error": str(e),
"timestamp": time.time()
})
finally:
self.current_requests -= 1
def _update_metrics(self, processing_time: float, success: bool):
"""Update provider metrics"""
# Update processing time
self.request_times.append(processing_time)
if len(self.request_times) > 100:
self.request_times.pop(0)
self.metrics.avg_processing_time = np.mean(self.request_times)
# Update error rate
if not success:
self.error_count += 1
self.metrics.error_rate = self.error_count / max(self.total_requests, 1)
# Update queue sizes
self.metrics.queue_size = self.input_queue.qsize()
# Update status
if self.metrics.error_rate > 0.1:
self.metrics.status = GPUProviderStatus.OFFLINE
elif self.metrics.avg_processing_time > self.slow_threshold:
self.metrics.status = GPUProviderStatus.SLOW
elif self.metrics.queue_size > self.input_queue.maxsize * 0.8:
self.metrics.status = GPUProviderStatus.OVERLOADED
elif self.current_requests >= self.max_concurrent_requests:
self.metrics.status = GPUProviderStatus.BUSY
else:
self.metrics.status = GPUProviderStatus.AVAILABLE
self.metrics.last_update = time.time()
def get_metrics(self) -> Dict[str, Any]:
"""Get provider metrics"""
        return {
            "provider_id": self.provider_id,
            "status": self.metrics.status.value,
            "avg_processing_time": self.metrics.avg_processing_time,
            "queue_size": self.metrics.queue_size,
            "current_requests": self.current_requests,
            "max_concurrent_requests": self.max_concurrent_requests,
            "gpu_utilization": self.metrics.gpu_utilization,
            "memory_usage": self.metrics.memory_usage,
            "error_rate": self.metrics.error_rate,
            "total_requests": self.total_requests,
            "last_update": self.metrics.last_update
        }
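# Illustrative round trip through a provider's flow-control queues. The
# provider id and payload are arbitrary, and this demo is not invoked by
# the service itself.
async def _demo_flow_control() -> None:
    provider = GPUProviderFlowControl("gpu_demo")
    await provider.start()
    data = FusionData(
        stream_id="demo",
        stream_type=FusionStreamType.VISUAL,
        data={"frame": 1},
        timestamp=time.time(),
        requires_gpu=True,
    )
    request_id = await provider.submit_request(data)
    if request_id:
        # The returned entry carries either "data" or "error"
        result = await provider.get_result(request_id, timeout=10.0)
        print(result)
    await provider.stop()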
class MultiModalWebSocketFusion:
"""Multi-modal fusion service with WebSocket streaming and backpressure control"""
def __init__(self):
self.stream_manager = stream_manager
self.fusion_service = None # Will be injected
self.gpu_providers: Dict[str, GPUProviderFlowControl] = {}
# Fusion streams
self.fusion_streams: Dict[str, FusionStreamConfig] = {}
self.active_fusions: Dict[str, Dict[str, Any]] = {}
# Performance metrics
self.fusion_metrics = {
"total_fusions": 0,
"successful_fusions": 0,
"failed_fusions": 0,
"avg_fusion_time": 0.0,
"gpu_utilization": 0.0,
"memory_usage": 0.0
}
# Backpressure control
self.backpressure_enabled = True
self.global_queue_size = 0
self.max_global_queue_size = 10000
# Running state
self._running = False
self._monitor_task = None
async def start(self):
"""Start the fusion service"""
if self._running:
return
self._running = True
# Start stream manager
await self.stream_manager.start()
# Initialize GPU providers
await self._initialize_gpu_providers()
# Start monitoring
self._monitor_task = asyncio.create_task(self._monitor_loop())
logger.info("Multi-Modal WebSocket Fusion started")
async def stop(self):
"""Stop the fusion service"""
if not self._running:
return
self._running = False
# Stop GPU providers
for provider in self.gpu_providers.values():
await provider.stop()
# Stop stream manager
await self.stream_manager.stop()
# Stop monitoring
if self._monitor_task:
self._monitor_task.cancel()
try:
await self._monitor_task
except asyncio.CancelledError:
pass
logger.info("Multi-Modal WebSocket Fusion stopped")
async def register_fusion_stream(self, stream_id: str, config: FusionStreamConfig):
"""Register a fusion stream"""
self.fusion_streams[stream_id] = config
logger.info(f"Registered fusion stream: {stream_id} ({config.stream_type.value})")
async def handle_websocket_connection(self, websocket, stream_id: str,
stream_type: FusionStreamType):
"""Handle WebSocket connection for fusion stream"""
        # Prefer a config registered for this stream id; fall back to defaults
        config = self.fusion_streams.get(stream_id) or FusionStreamConfig(
            stream_type=stream_type,
            max_queue_size=500,
            gpu_timeout=2.0,
            fusion_timeout=5.0
        )
async with self.stream_manager.manage_stream(websocket, config.to_stream_config()) as stream:
logger.info(f"Fusion stream connected: {stream_id} ({stream_type.value})")
try:
# Handle incoming messages
async for message in websocket:
await self._handle_stream_message(stream_id, stream_type, message)
except Exception as e:
logger.error(f"Error in fusion stream {stream_id}: {e}")
async def _handle_stream_message(self, stream_id: str, stream_type: FusionStreamType,
message: str):
"""Handle incoming stream message"""
try:
data = json.loads(message)
# Create fusion data
fusion_data = FusionData(
stream_id=stream_id,
stream_type=stream_type,
data=data.get("data"),
timestamp=time.time(),
metadata=data.get("metadata", {}),
requires_gpu=data.get("requires_gpu", False),
processing_priority=data.get("priority", 1)
)
# Submit to GPU provider if needed
if fusion_data.requires_gpu:
await self._submit_to_gpu_provider(fusion_data)
else:
await self._process_cpu_fusion(fusion_data)
except Exception as e:
logger.error(f"Error handling stream message: {e}")
async def _submit_to_gpu_provider(self, fusion_data: FusionData):
"""Submit fusion data to GPU provider"""
# Select best GPU provider
provider_id = await self._select_gpu_provider(fusion_data)
if not provider_id:
logger.warning("No available GPU providers")
await self._handle_fusion_error(fusion_data, "No GPU providers available")
return
provider = self.gpu_providers[provider_id]
# Submit request
request_id = await provider.submit_request(fusion_data)
if not request_id:
await self._handle_fusion_error(fusion_data, "GPU provider overloaded")
return
        # Wait for the full result entry (carries "data" or "error")
        result = await provider.get_result(request_id, timeout=5.0)
        if result and "error" not in result:
            await self._handle_fusion_result(fusion_data, result["data"])
else:
error = result.get("error", "Unknown error") if result else "Timeout"
await self._handle_fusion_error(fusion_data, error)
async def _process_cpu_fusion(self, fusion_data: FusionData):
"""Process fusion data on CPU"""
try:
# Simulate CPU fusion processing
processing_time = np.random.uniform(0.1, 0.5)
await asyncio.sleep(processing_time)
result = {
"processed_data": f"cpu_fused_{fusion_data.stream_type}",
"processing_time": processing_time,
"fusion_type": "cpu"
}
await self._handle_fusion_result(fusion_data, result)
except Exception as e:
logger.error(f"CPU fusion error: {e}")
await self._handle_fusion_error(fusion_data, str(e))
async def _handle_fusion_result(self, fusion_data: FusionData, result: Dict[str, Any]):
"""Handle successful fusion result"""
# Update metrics
self.fusion_metrics["total_fusions"] += 1
self.fusion_metrics["successful_fusions"] += 1
# Broadcast result
broadcast_data = {
"type": "fusion_result",
"stream_id": fusion_data.stream_id,
"stream_type": fusion_data.stream_type.value,
"result": result,
"timestamp": time.time()
}
await self.stream_manager.broadcast_to_all(broadcast_data, MessageType.IMPORTANT)
logger.info(f"Fusion completed for {fusion_data.stream_id}")
async def _handle_fusion_error(self, fusion_data: FusionData, error: str):
"""Handle fusion error"""
# Update metrics
self.fusion_metrics["total_fusions"] += 1
self.fusion_metrics["failed_fusions"] += 1
# Broadcast error
error_data = {
"type": "fusion_error",
"stream_id": fusion_data.stream_id,
"stream_type": fusion_data.stream_type.value,
"error": error,
"timestamp": time.time()
}
await self.stream_manager.broadcast_to_all(error_data, MessageType.CRITICAL)
logger.error(f"Fusion error for {fusion_data.stream_id}: {error}")
async def _select_gpu_provider(self, fusion_data: FusionData) -> Optional[str]:
"""Select best GPU provider based on load and performance"""
available_providers = []
for provider_id, provider in self.gpu_providers.items():
metrics = provider.get_metrics()
# Check if provider is available
if metrics["status"] == GPUProviderStatus.AVAILABLE.value:
available_providers.append((provider_id, metrics))
if not available_providers:
return None
# Select provider with lowest queue size and processing time
best_provider = min(
available_providers,
key=lambda x: (x[1]["queue_size"], x[1]["avg_processing_time"])
)
return best_provider[0]
async def _initialize_gpu_providers(self):
"""Initialize GPU providers"""
# Create mock GPU providers
provider_configs = [
{"provider_id": "gpu_1", "max_concurrent": 4},
{"provider_id": "gpu_2", "max_concurrent": 2},
{"provider_id": "gpu_3", "max_concurrent": 6}
]
for config in provider_configs:
provider = GPUProviderFlowControl(config["provider_id"])
provider.max_concurrent_requests = config["max_concurrent"]
await provider.start()
self.gpu_providers[config["provider_id"]] = provider
logger.info(f"Initialized {len(self.gpu_providers)} GPU providers")
async def _monitor_loop(self):
"""Monitor system performance and backpressure"""
while self._running:
try:
# Update global metrics
await self._update_global_metrics()
# Check backpressure
if self.backpressure_enabled:
await self._check_backpressure()
# Monitor GPU providers
await self._monitor_gpu_providers()
# Sleep
await asyncio.sleep(10) # Monitor every 10 seconds
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Monitor loop error: {e}")
await asyncio.sleep(1)
async def _update_global_metrics(self):
"""Update global performance metrics"""
# Get stream manager metrics
        manager_metrics = await self.stream_manager.get_manager_metrics()
# Update global queue size
self.global_queue_size = manager_metrics["total_queue_size"]
# Calculate GPU utilization
total_gpu_util = 0
total_memory = 0
active_providers = 0
for provider in self.gpu_providers.values():
metrics = provider.get_metrics()
if metrics["status"] != GPUProviderStatus.OFFLINE.value:
total_gpu_util += metrics.get("gpu_utilization", 0)
total_memory += metrics.get("memory_usage", 0)
active_providers += 1
if active_providers > 0:
self.fusion_metrics["gpu_utilization"] = total_gpu_util / active_providers
self.fusion_metrics["memory_usage"] = total_memory / active_providers
async def _check_backpressure(self):
"""Check and handle backpressure"""
if self.global_queue_size > self.max_global_queue_size * 0.8:
logger.warning("High backpressure detected, applying flow control")
# Get slow streams
slow_streams = self.stream_manager.get_slow_streams(threshold=0.8)
# Handle slow streams
for stream_id in slow_streams:
await self.stream_manager.handle_slow_consumer(stream_id, "throttle")
async def _monitor_gpu_providers(self):
"""Monitor GPU provider health"""
for provider_id, provider in self.gpu_providers.items():
metrics = provider.get_metrics()
# Check for unhealthy providers
if metrics["status"] == GPUProviderStatus.OFFLINE.value:
logger.warning(f"GPU provider {provider_id} is offline")
elif metrics["error_rate"] > 0.1:
logger.warning(f"GPU provider {provider_id} has high error rate: {metrics['error_rate']}")
elif metrics["avg_processing_time"] > 5.0:
logger.warning(f"GPU provider {provider_id} is slow: {metrics['avg_processing_time']}s")
    async def get_comprehensive_metrics(self) -> Dict[str, Any]:
        """Get comprehensive system metrics"""
        # Get stream manager metrics (get_manager_metrics is a coroutine)
        stream_metrics = await self.stream_manager.get_manager_metrics()
# Get GPU provider metrics
gpu_metrics = {}
for provider_id, provider in self.gpu_providers.items():
gpu_metrics[provider_id] = provider.get_metrics()
# Get fusion metrics
fusion_metrics = self.fusion_metrics.copy()
# Calculate success rate
if fusion_metrics["total_fusions"] > 0:
fusion_metrics["success_rate"] = (
fusion_metrics["successful_fusions"] / fusion_metrics["total_fusions"]
)
else:
fusion_metrics["success_rate"] = 0.0
return {
"timestamp": time.time(),
"system_status": "running" if self._running else "stopped",
"backpressure_enabled": self.backpressure_enabled,
"global_queue_size": self.global_queue_size,
"max_global_queue_size": self.max_global_queue_size,
"stream_metrics": stream_metrics,
"gpu_metrics": gpu_metrics,
"fusion_metrics": fusion_metrics,
"active_fusion_streams": len(self.fusion_streams),
"registered_gpu_providers": len(self.gpu_providers)
}
# Global fusion service instance
multimodal_fusion_service = MultiModalWebSocketFusion()
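# Minimal wiring sketch (assumptions flagged): exposing the fusion service
# over the `websockets` library. The import, handler name, and port are
# illustrative and not part of this module.
#
#   import websockets
#
#   async def main() -> None:
#       await multimodal_fusion_service.start()
#       async def handler(websocket):
#           await multimodal_fusion_service.handle_websocket_connection(
#               websocket, stream_id="visual-1",
#               stream_type=FusionStreamType.VISUAL)
#       async with websockets.serve(handler, "0.0.0.0", 8765):
#           await asyncio.Future()  # run until cancelled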


@@ -0,0 +1,420 @@
"""
Secure Wallet Service - Fixed Version
Implements proper Ethereum cryptography and secure key storage
"""
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional
from sqlmodel import Session, select
from datetime import datetime
from ..domain.wallet import (
AgentWallet, NetworkConfig, TokenBalance, WalletTransaction,
WalletType, TransactionStatus
)
from ..schemas.wallet import WalletCreate, TransactionRequest
from ..blockchain.contract_interactions import ContractInteractionService
# Import our fixed crypto utilities
from .wallet_crypto import (
generate_ethereum_keypair,
verify_keypair_consistency,
encrypt_private_key,
decrypt_private_key,
validate_private_key_format,
create_secure_wallet,
recover_wallet
)
logger = logging.getLogger(__name__)
class SecureWalletService:
"""Secure wallet service with proper cryptography and key management"""
def __init__(
self,
session: Session,
contract_service: ContractInteractionService
):
self.session = session
self.contract_service = contract_service
async def create_wallet(self, request: WalletCreate, encryption_password: str) -> AgentWallet:
"""
Create a new wallet with proper security
Args:
request: Wallet creation request
encryption_password: Strong password for private key encryption
Returns:
Created wallet record
Raises:
ValueError: If password is weak or wallet already exists
"""
# Validate password strength
from ..utils.security import validate_password_strength
password_validation = validate_password_strength(encryption_password)
if not password_validation["is_acceptable"]:
raise ValueError(
f"Password too weak: {', '.join(password_validation['issues'])}"
)
# Check if agent already has an active wallet of this type
existing = self.session.exec(
select(AgentWallet).where(
AgentWallet.agent_id == request.agent_id,
AgentWallet.wallet_type == request.wallet_type,
AgentWallet.is_active == True
)
).first()
if existing:
raise ValueError(f"Agent {request.agent_id} already has an active {request.wallet_type} wallet")
try:
# Generate proper Ethereum keypair
private_key, public_key, address = generate_ethereum_keypair()
# Verify keypair consistency
if not verify_keypair_consistency(private_key, address):
raise RuntimeError("Keypair generation failed consistency check")
# Encrypt private key securely
encrypted_data = encrypt_private_key(private_key, encryption_password)
# Create wallet record
wallet = AgentWallet(
agent_id=request.agent_id,
address=address,
public_key=public_key,
wallet_type=request.wallet_type,
metadata=request.metadata,
encrypted_private_key=encrypted_data,
encryption_version="1.0",
created_at=datetime.utcnow()
)
self.session.add(wallet)
self.session.commit()
self.session.refresh(wallet)
logger.info(f"Created secure wallet {wallet.address} for agent {request.agent_id}")
return wallet
except Exception as e:
logger.error(f"Failed to create secure wallet: {e}")
self.session.rollback()
raise
async def get_wallet_by_agent(self, agent_id: str) -> List[AgentWallet]:
"""Retrieve all active wallets for an agent"""
return self.session.exec(
select(AgentWallet).where(
AgentWallet.agent_id == agent_id,
AgentWallet.is_active == True
)
).all()
async def get_wallet_with_private_key(
self,
wallet_id: int,
encryption_password: str
    ) -> Dict[str, Any]:
"""
Get wallet with decrypted private key (for signing operations)
Args:
wallet_id: Wallet ID
encryption_password: Password for decryption
Returns:
Wallet keys including private key
Raises:
ValueError: If decryption fails or wallet not found
"""
wallet = self.session.get(AgentWallet, wallet_id)
if not wallet:
raise ValueError("Wallet not found")
if not wallet.is_active:
raise ValueError("Wallet is not active")
try:
# Decrypt private key
if isinstance(wallet.encrypted_private_key, dict):
# New format
keys = recover_wallet(wallet.encrypted_private_key, encryption_password)
else:
# Legacy format - cannot decrypt securely
raise ValueError(
"Wallet uses legacy encryption format. "
"Please migrate to secure encryption."
)
return {
"wallet_id": wallet_id,
"address": wallet.address,
"private_key": keys["private_key"],
"public_key": keys["public_key"],
"agent_id": wallet.agent_id
}
except Exception as e:
logger.error(f"Failed to decrypt wallet {wallet_id}: {e}")
raise ValueError(f"Failed to access wallet: {str(e)}")
async def verify_wallet_integrity(self, wallet_id: int) -> Dict[str, bool]:
"""
Verify wallet cryptographic integrity
Args:
wallet_id: Wallet ID
Returns:
Integrity check results
"""
wallet = self.session.get(AgentWallet, wallet_id)
if not wallet:
return {"exists": False}
results = {
"exists": True,
"active": wallet.is_active,
"has_encrypted_key": bool(wallet.encrypted_private_key),
"address_format_valid": False,
"public_key_present": bool(wallet.public_key)
}
# Validate address format
try:
from eth_utils import to_checksum_address
to_checksum_address(wallet.address)
results["address_format_valid"] = True
        except Exception:
pass
# Check if we can verify the keypair consistency
# (We can't do this without the password, but we can check the format)
if wallet.public_key and wallet.encrypted_private_key:
results["has_keypair_data"] = True
return results
async def migrate_wallet_encryption(
self,
wallet_id: int,
old_password: str,
new_password: str
) -> AgentWallet:
"""
Migrate wallet from old encryption to new secure encryption
Args:
wallet_id: Wallet ID
old_password: Current password
new_password: New strong password
Returns:
Updated wallet
"""
wallet = self.session.get(AgentWallet, wallet_id)
if not wallet:
raise ValueError("Wallet not found")
try:
# Get current private key
current_keys = await self.get_wallet_with_private_key(wallet_id, old_password)
# Validate new password
from ..utils.security import validate_password_strength
password_validation = validate_password_strength(new_password)
if not password_validation["is_acceptable"]:
raise ValueError(
f"New password too weak: {', '.join(password_validation['issues'])}"
)
# Re-encrypt with new password
new_encrypted_data = encrypt_private_key(current_keys["private_key"], new_password)
# Update wallet
wallet.encrypted_private_key = new_encrypted_data
wallet.encryption_version = "1.0"
wallet.updated_at = datetime.utcnow()
self.session.commit()
self.session.refresh(wallet)
logger.info(f"Migrated wallet {wallet_id} to secure encryption")
return wallet
except Exception as e:
logger.error(f"Failed to migrate wallet {wallet_id}: {e}")
self.session.rollback()
raise
async def get_balances(self, wallet_id: int) -> List[TokenBalance]:
"""Get all tracked balances for a wallet"""
return self.session.exec(
select(TokenBalance).where(TokenBalance.wallet_id == wallet_id)
).all()
async def update_balance(self, wallet_id: int, chain_id: int, token_address: str, balance: float) -> TokenBalance:
"""Update a specific token balance for a wallet"""
record = self.session.exec(
select(TokenBalance).where(
TokenBalance.wallet_id == wallet_id,
TokenBalance.chain_id == chain_id,
TokenBalance.token_address == token_address
)
).first()
if record:
record.balance = balance
record.updated_at = datetime.utcnow()
else:
record = TokenBalance(
wallet_id=wallet_id,
chain_id=chain_id,
token_address=token_address,
balance=balance,
updated_at=datetime.utcnow()
)
self.session.add(record)
self.session.commit()
self.session.refresh(record)
return record
async def create_transaction(
self,
wallet_id: int,
request: TransactionRequest,
encryption_password: str
) -> WalletTransaction:
"""
Create a transaction with proper signing
Args:
wallet_id: Wallet ID
request: Transaction request
encryption_password: Password for private key access
Returns:
Created transaction record
"""
# Get wallet keys
wallet_keys = await self.get_wallet_with_private_key(wallet_id, encryption_password)
# Create transaction record
transaction = WalletTransaction(
wallet_id=wallet_id,
to_address=request.to_address,
amount=request.amount,
token_address=request.token_address,
chain_id=request.chain_id,
data=request.data or "",
status=TransactionStatus.PENDING,
created_at=datetime.utcnow()
)
self.session.add(transaction)
self.session.commit()
self.session.refresh(transaction)
# TODO: Implement actual blockchain transaction signing and submission
# This would use the private_key to sign the transaction
logger.info(f"Created transaction {transaction.id} for wallet {wallet_id}")
return transaction
async def deactivate_wallet(self, wallet_id: int, reason: str = "User request") -> bool:
"""Deactivate a wallet"""
wallet = self.session.get(AgentWallet, wallet_id)
if not wallet:
return False
wallet.is_active = False
wallet.updated_at = datetime.utcnow()
wallet.deactivation_reason = reason
self.session.commit()
logger.info(f"Deactivated wallet {wallet_id}: {reason}")
return True
async def get_wallet_security_audit(self, wallet_id: int) -> Dict[str, Any]:
"""
Get comprehensive security audit for a wallet
Args:
wallet_id: Wallet ID
Returns:
Security audit results
"""
wallet = self.session.get(AgentWallet, wallet_id)
if not wallet:
return {"error": "Wallet not found"}
audit = {
"wallet_id": wallet_id,
"agent_id": wallet.agent_id,
"address": wallet.address,
"is_active": wallet.is_active,
"encryption_version": getattr(wallet, 'encryption_version', 'unknown'),
"created_at": wallet.created_at.isoformat() if wallet.created_at else None,
"updated_at": wallet.updated_at.isoformat() if wallet.updated_at else None
}
# Check encryption security
if isinstance(wallet.encrypted_private_key, dict):
audit["encryption_secure"] = True
audit["encryption_algorithm"] = wallet.encrypted_private_key.get("algorithm")
audit["encryption_iterations"] = wallet.encrypted_private_key.get("iterations")
else:
audit["encryption_secure"] = False
audit["encryption_issues"] = ["Uses legacy or broken encryption"]
# Check address format
try:
from eth_utils import to_checksum_address
to_checksum_address(wallet.address)
audit["address_valid"] = True
        except Exception:
audit["address_valid"] = False
audit["address_issues"] = ["Invalid Ethereum address format"]
# Check keypair data
audit["has_public_key"] = bool(wallet.public_key)
audit["has_encrypted_private_key"] = bool(wallet.encrypted_private_key)
# Overall security score
security_score = 0
if audit["encryption_secure"]:
security_score += 40
if audit["address_valid"]:
security_score += 30
if audit["has_public_key"]:
security_score += 15
if audit["has_encrypted_private_key"]:
security_score += 15
audit["security_score"] = security_score
audit["security_level"] = (
"Excellent" if security_score >= 90 else
"Good" if security_score >= 70 else
"Fair" if security_score >= 50 else
"Poor"
)
return audit
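# Usage sketch (assumptions flagged): `session` and `contract_service` are
# supplied by the application; the WalletCreate field values, the WalletType
# member name, and the password are illustrative, not confirmed by this module.
async def _demo_secure_wallet(session: Session, contract_service: ContractInteractionService) -> None:
    service = SecureWalletService(session, contract_service)
    request = WalletCreate(
        agent_id="agent-123",
        wallet_type=WalletType.STANDARD,  # hypothetical enum member
        metadata={},
    )
    wallet = await service.create_wallet(request, "a-long-unique-passphrase-7!")
    audit = await service.get_wallet_security_audit(wallet.id)  # primary-key attribute assumed
    print(audit["security_level"])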


@@ -0,0 +1,238 @@
"""
Secure Cryptographic Operations for Agent Wallets
Fixed implementation using proper Ethereum cryptography
"""
import secrets
import base64
from datetime import datetime
from typing import Tuple, Dict, Any, Optional
from eth_account import Account
from eth_keys import keys
from eth_utils import to_checksum_address
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.primitives import hashes
def generate_ethereum_keypair() -> Tuple[str, str, str]:
    """
    Generate proper Ethereum keypair using secp256k1
    Returns:
        Tuple of (private_key, public_key, address)
    """
    # Use eth_account, which properly implements secp256k1
    account = Account.create()
    private_key = account.key.hex()
    # Derive the public key through eth_keys instead of reaching into
    # eth_account's private attributes
    public_key = keys.PrivateKey(bytes(account.key)).public_key.to_hex()
    address = account.address
    return private_key, public_key, address
def verify_keypair_consistency(private_key: str, expected_address: str) -> bool:
"""
Verify that a private key generates the expected address
Args:
private_key: 32-byte private key hex
expected_address: Expected Ethereum address
Returns:
True if keypair is consistent
"""
try:
account = Account.from_key(private_key)
return to_checksum_address(account.address) == to_checksum_address(expected_address)
except Exception:
return False
def derive_secure_key(password: str, salt: Optional[bytes] = None) -> Tuple[bytes, bytes]:
"""
Derive secure encryption key using PBKDF2
Args:
password: User password
salt: Optional salt (generated if not provided)
Returns:
Tuple of (key, salt) for storage
"""
if salt is None:
salt = secrets.token_bytes(32)
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=salt,
iterations=600_000, # OWASP recommended minimum
)
key = kdf.derive(password.encode())
return base64.urlsafe_b64encode(key), salt
def encrypt_private_key(private_key: str, password: str) -> Dict[str, str]:
"""
Encrypt private key with proper KDF and Fernet
Args:
private_key: 32-byte private key hex
password: User password
Returns:
Dict with encrypted data and salt
"""
# Derive encryption key
fernet_key, salt = derive_secure_key(password)
# Encrypt
f = Fernet(fernet_key)
encrypted = f.encrypt(private_key.encode())
return {
"encrypted_key": encrypted.decode(),
"salt": base64.b64encode(salt).decode(),
"algorithm": "PBKDF2-SHA256-Fernet",
"iterations": 600_000
}
def decrypt_private_key(encrypted_data: Dict[str, str], password: str) -> str:
"""
Decrypt private key with proper verification
Args:
encrypted_data: Dict with encrypted key and salt
password: User password
Returns:
Decrypted private key
Raises:
ValueError: If decryption fails
"""
try:
# Extract salt and encrypted key
salt = base64.b64decode(encrypted_data["salt"])
encrypted_key = encrypted_data["encrypted_key"].encode()
# Derive same key
fernet_key, _ = derive_secure_key(password, salt)
# Decrypt
f = Fernet(fernet_key)
decrypted = f.decrypt(encrypted_key)
return decrypted.decode()
except Exception as e:
raise ValueError(f"Failed to decrypt private key: {str(e)}")
def validate_private_key_format(private_key: str) -> bool:
"""
Validate private key format
Args:
private_key: Private key to validate
Returns:
True if format is valid
"""
try:
# Remove 0x prefix if present
if private_key.startswith("0x"):
private_key = private_key[2:]
# Check length (32 bytes = 64 hex chars)
if len(private_key) != 64:
return False
# Check if valid hex
int(private_key, 16)
# Try to create account to verify it's a valid secp256k1 key
Account.from_key("0x" + private_key)
return True
except Exception:
return False
# Security configuration constants
class SecurityConfig:
"""Security configuration constants"""
# PBKDF2 settings
PBKDF2_ITERATIONS = 600_000
PBKDF2_ALGORITHM = hashes.SHA256
SALT_LENGTH = 32
# Fernet settings
FERNET_KEY_LENGTH = 32
# Validation
PRIVATE_KEY_LENGTH = 64 # 32 bytes in hex
ADDRESS_LENGTH = 40 # 20 bytes in hex (without 0x)
# Backward compatibility wrapper for existing code
def create_secure_wallet(agent_id: str, password: str) -> Dict[str, Any]:
"""
Create a wallet with proper security
Args:
agent_id: Agent identifier
password: Strong password for encryption
Returns:
Wallet data with encrypted private key
"""
# Generate proper keypair
private_key, public_key, address = generate_ethereum_keypair()
# Validate consistency
if not verify_keypair_consistency(private_key, address):
raise RuntimeError("Keypair generation failed consistency check")
# Encrypt private key
encrypted_data = encrypt_private_key(private_key, password)
return {
"agent_id": agent_id,
"address": address,
"public_key": public_key,
"encrypted_private_key": encrypted_data,
"created_at": secrets.token_hex(16), # For tracking
"version": "1.0"
}
def recover_wallet(encrypted_data: Dict[str, str], password: str) -> Dict[str, str]:
"""
Recover wallet from encrypted data
Args:
encrypted_data: Encrypted wallet data
password: Password for decryption
Returns:
Wallet keys
"""
    # Decrypt private key
    private_key = decrypt_private_key(encrypted_data, password)
    # Validate format
    if not validate_private_key_format(private_key):
        raise ValueError("Decrypted private key has invalid format")
    # Normalize to a 0x-prefixed hex string before key derivation
    if not private_key.startswith("0x"):
        private_key = "0x" + private_key
    # Derive address and public key to verify
    account = Account.from_key(private_key)
    return {
        "private_key": private_key,
        "public_key": keys.PrivateKey(bytes(account.key)).public_key.to_hex(),
        "address": account.address
    }
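# Round-trip sanity check (illustrative): create a wallet, recover it with
# the same password, and confirm the derived address matches. The password
# is a placeholder; production callers should source it from a secrets
# manager rather than hard-coding it.
if __name__ == "__main__":
    _password = "correct-horse-battery-staple-42!"
    _wallet = create_secure_wallet("agent-demo", _password)
    _keys = recover_wallet(_wallet["encrypted_private_key"], _password)
    assert _keys["address"] == _wallet["address"]
    print("wallet round trip OK:", _keys["address"])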


@@ -0,0 +1,641 @@
"""
WebSocket Stream Manager with Backpressure Control
Advanced WebSocket stream architecture with per-stream flow control,
bounded queues, and event loop protection for multi-modal fusion.
"""
import asyncio
import json
import time
import weakref
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
from enum import Enum
from collections import deque
import uuid
from contextlib import asynccontextmanager
import websockets
from websockets.server import WebSocketServerProtocol
from websockets.exceptions import ConnectionClosed
from aitbc.logging import get_logger
logger = get_logger(__name__)
class StreamStatus(Enum):
"""Stream connection status"""
CONNECTING = "connecting"
CONNECTED = "connected"
SLOW_CONSUMER = "slow_consumer"
BACKPRESSURE = "backpressure"
DISCONNECTED = "disconnected"
ERROR = "error"
class MessageType(Enum):
"""Message types for stream classification"""
CRITICAL = "critical" # High priority, must deliver
IMPORTANT = "important" # Normal priority
BULK = "bulk" # Low priority, can be dropped
CONTROL = "control" # Stream control messages
@dataclass
class StreamMessage:
"""Message with priority and metadata"""
data: Any
message_type: MessageType
timestamp: float = field(default_factory=time.time)
message_id: str = field(default_factory=lambda: str(uuid.uuid4()))
retry_count: int = 0
max_retries: int = 3
def to_dict(self) -> Dict[str, Any]:
return {
"id": self.message_id,
"type": self.message_type.value,
"timestamp": self.timestamp,
"data": self.data
}
@dataclass
class StreamMetrics:
"""Metrics for stream performance monitoring"""
messages_sent: int = 0
messages_dropped: int = 0
bytes_sent: int = 0
last_send_time: float = 0
avg_send_time: float = 0
queue_size: int = 0
backpressure_events: int = 0
slow_consumer_events: int = 0
def update_send_metrics(self, send_time: float, message_size: int):
"""Update send performance metrics"""
self.messages_sent += 1
self.bytes_sent += message_size
self.last_send_time = time.time()
# Update average send time
if self.messages_sent == 1:
self.avg_send_time = send_time
else:
self.avg_send_time = (self.avg_send_time * (self.messages_sent - 1) + send_time) / self.messages_sent
@dataclass
class StreamConfig:
"""Configuration for individual streams"""
max_queue_size: int = 1000
send_timeout: float = 5.0
heartbeat_interval: float = 30.0
slow_consumer_threshold: float = 0.5 # seconds
backpressure_threshold: float = 0.8 # queue fill ratio
drop_bulk_threshold: float = 0.9 # queue fill ratio for bulk messages
enable_compression: bool = True
priority_send: bool = True
class BoundedMessageQueue:
"""Bounded queue with priority and backpressure handling"""
def __init__(self, max_size: int = 1000):
self.max_size = max_size
self.queues = {
MessageType.CRITICAL: deque(maxlen=max_size // 4),
MessageType.IMPORTANT: deque(maxlen=max_size // 2),
MessageType.BULK: deque(maxlen=max_size // 4),
MessageType.CONTROL: deque(maxlen=100) # Small control queue
}
        self._lock = asyncio.Lock()
    async def put(self, message: StreamMessage) -> bool:
        """Add message to queue with backpressure handling"""
        async with self._lock:
            # Check if we're at capacity
            if self.size() >= self.max_size:
                # Drop bulk messages first
                if message.message_type == MessageType.BULK:
                    return False
                # Evict the oldest message of the same class to make room:
                # IMPORTANT messages are refused once their queue is empty,
                # CRITICAL messages are always admitted
                if message.message_type == MessageType.IMPORTANT:
                    if self.queues[MessageType.IMPORTANT]:
                        self.queues[MessageType.IMPORTANT].popleft()
                    else:
                        return False
                if message.message_type == MessageType.CRITICAL:
                    if self.queues[MessageType.CRITICAL]:
                        self.queues[MessageType.CRITICAL].popleft()
            self.queues[message.message_type].append(message)
            return True
    async def get(self) -> Optional[StreamMessage]:
        """Get next message by priority"""
        async with self._lock:
            # Priority order: CONTROL > CRITICAL > IMPORTANT > BULK
            for message_type in [MessageType.CONTROL, MessageType.CRITICAL,
                                 MessageType.IMPORTANT, MessageType.BULK]:
                if self.queues[message_type]:
                    return self.queues[message_type].popleft()
            return None
    def size(self) -> int:
        """Total queue size, computed from the per-type deques so that
        silent maxlen evictions cannot desynchronize a separate counter"""
        return sum(len(q) for q in self.queues.values())
    def fill_ratio(self) -> float:
        """Get queue fill ratio"""
        return self.size() / self.max_size
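# Illustrative behavior of the bounded priority queue (sizes are arbitrary):
# once full, BULK messages are refused while CRITICAL messages displace the
# oldest CRITICAL entry. Not invoked by the manager itself.
async def _demo_bounded_queue() -> None:
    q = BoundedMessageQueue(max_size=4)
    await q.put(StreamMessage(data=0, message_type=MessageType.IMPORTANT))
    await q.put(StreamMessage(data=1, message_type=MessageType.IMPORTANT))
    await q.put(StreamMessage(data="c0", message_type=MessageType.CRITICAL))
    await q.put(StreamMessage(data="b0", message_type=MessageType.BULK))
    assert not await q.put(StreamMessage(data="b1", message_type=MessageType.BULK))
    assert await q.put(StreamMessage(data="c1", message_type=MessageType.CRITICAL))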
class WebSocketStream:
"""Individual WebSocket stream with backpressure control"""
def __init__(self, websocket: WebSocketServerProtocol,
stream_id: str, config: StreamConfig):
self.websocket = websocket
self.stream_id = stream_id
self.config = config
self.status = StreamStatus.CONNECTING
self.queue = BoundedMessageQueue(config.max_queue_size)
self.metrics = StreamMetrics()
self.last_heartbeat = time.time()
self.slow_consumer_count = 0
# Event loop protection
self._send_lock = asyncio.Lock()
self._sender_task = None
self._heartbeat_task = None
self._running = False
# Weak reference for cleanup
self._finalizer = weakref.finalize(self, self._cleanup)
async def start(self):
"""Start stream processing"""
if self._running:
return
self._running = True
self.status = StreamStatus.CONNECTED
# Start sender task
self._sender_task = asyncio.create_task(self._sender_loop())
# Start heartbeat task
self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
logger.info(f"Stream {self.stream_id} started")
async def stop(self):
"""Stop stream processing"""
if not self._running:
return
self._running = False
self.status = StreamStatus.DISCONNECTED
# Cancel tasks
if self._sender_task:
self._sender_task.cancel()
try:
await self._sender_task
except asyncio.CancelledError:
pass
if self._heartbeat_task:
self._heartbeat_task.cancel()
try:
await self._heartbeat_task
except asyncio.CancelledError:
pass
logger.info(f"Stream {self.stream_id} stopped")
async def send_message(self, data: Any, message_type: MessageType = MessageType.IMPORTANT) -> bool:
"""Send message with backpressure handling"""
if not self._running:
return False
message = StreamMessage(data=data, message_type=message_type)
# Check backpressure
queue_ratio = self.queue.fill_ratio()
if queue_ratio > self.config.backpressure_threshold:
self.status = StreamStatus.BACKPRESSURE
self.metrics.backpressure_events += 1
# Drop bulk messages under backpressure
if message_type == MessageType.BULK and queue_ratio > self.config.drop_bulk_threshold:
self.metrics.messages_dropped += 1
return False
# Add to queue
success = await self.queue.put(message)
if not success:
self.metrics.messages_dropped += 1
return success
async def _sender_loop(self):
"""Main sender loop with backpressure control"""
while self._running:
try:
# Get next message
message = await self.queue.get()
if message is None:
await asyncio.sleep(0.01)
continue
# Send with timeout and backpressure protection
start_time = time.time()
success = await self._send_with_backpressure(message)
send_time = time.time() - start_time
if success:
message_size = len(json.dumps(message.to_dict()).encode())
self.metrics.update_send_metrics(send_time, message_size)
else:
# Retry logic
message.retry_count += 1
if message.retry_count < message.max_retries:
await self.queue.put(message)
else:
self.metrics.messages_dropped += 1
logger.warning(f"Message {message.message_id} dropped after max retries")
# Check for slow consumer
if send_time > self.config.slow_consumer_threshold:
self.slow_consumer_count += 1
self.metrics.slow_consumer_events += 1
if self.slow_consumer_count > 5: # Threshold for slow consumer detection
self.status = StreamStatus.SLOW_CONSUMER
logger.warning(f"Stream {self.stream_id} detected as slow consumer")
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Error in sender loop for stream {self.stream_id}: {e}")
await asyncio.sleep(0.1)
async def _send_with_backpressure(self, message: StreamMessage) -> bool:
"""Send message with backpressure and timeout protection"""
try:
async with self._send_lock:
# Use asyncio.wait_for for timeout protection
message_data = message.to_dict()
                if self.config.enable_compression:
                    # Compact JSON encoding; payloads over 1KB are flagged so a
                    # receiver or transport layer can apply compression. No
                    # compression is performed here.
                    message_str = json.dumps(message_data, separators=(',', ':'))
                    if len(message_str) > 1024:
                        message_data['_compressed'] = True
                        message_str = json.dumps(message_data, separators=(',', ':'))
                else:
                    message_str = json.dumps(message_data)
# Send with timeout
await asyncio.wait_for(
self.websocket.send(message_str),
timeout=self.config.send_timeout
)
return True
except asyncio.TimeoutError:
logger.warning(f"Send timeout for stream {self.stream_id}")
return False
except ConnectionClosed:
logger.info(f"Connection closed for stream {self.stream_id}")
await self.stop()
return False
except Exception as e:
logger.error(f"Send error for stream {self.stream_id}: {e}")
return False
async def _heartbeat_loop(self):
"""Heartbeat loop for connection health monitoring"""
while self._running:
try:
await asyncio.sleep(self.config.heartbeat_interval)
if not self._running:
break
# Send heartbeat
heartbeat_msg = {
"type": "heartbeat",
"timestamp": time.time(),
"stream_id": self.stream_id,
"queue_size": self.queue.size(),
"status": self.status.value
}
await self.send_message(heartbeat_msg, MessageType.CONTROL)
self.last_heartbeat = time.time()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Heartbeat error for stream {self.stream_id}: {e}")
def get_metrics(self) -> Dict[str, Any]:
"""Get stream metrics"""
return {
"stream_id": self.stream_id,
"status": self.status.value,
"queue_size": self.queue.size(),
"queue_fill_ratio": self.queue.fill_ratio(),
"messages_sent": self.metrics.messages_sent,
"messages_dropped": self.metrics.messages_dropped,
"bytes_sent": self.metrics.bytes_sent,
"avg_send_time": self.metrics.avg_send_time,
"backpressure_events": self.metrics.backpressure_events,
"slow_consumer_events": self.metrics.slow_consumer_events,
"last_heartbeat": self.last_heartbeat
}
def _cleanup(self):
"""Cleanup resources"""
if self._running:
# This should be called by garbage collector
logger.warning(f"Stream {self.stream_id} cleanup called while running")
class WebSocketStreamManager:
"""Manages multiple WebSocket streams with backpressure control"""
def __init__(self, default_config: Optional[StreamConfig] = None):
self.default_config = default_config or StreamConfig()
self.streams: Dict[str, WebSocketStream] = {}
self.stream_configs: Dict[str, StreamConfig] = {}
# Global metrics
self.total_connections = 0
self.total_messages_sent = 0
self.total_messages_dropped = 0
# Event loop protection
self._manager_lock = asyncio.Lock()
self._cleanup_task = None
self._running = False
# Message broadcasting
self._broadcast_queue = asyncio.Queue(maxsize=10000)
self._broadcast_task = None
async def start(self):
"""Start the stream manager"""
if self._running:
return
self._running = True
# Start cleanup task
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
# Start broadcast task
self._broadcast_task = asyncio.create_task(self._broadcast_loop())
logger.info("WebSocket Stream Manager started")
async def stop(self):
"""Stop the stream manager"""
if not self._running:
return
self._running = False
# Stop all streams
streams_to_stop = list(self.streams.values())
for stream in streams_to_stop:
await stream.stop()
# Cancel tasks
if self._cleanup_task:
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
if self._broadcast_task:
self._broadcast_task.cancel()
try:
await self._broadcast_task
except asyncio.CancelledError:
pass
logger.info("WebSocket Stream Manager stopped")
    @asynccontextmanager
    async def manage_stream(self, websocket: WebSocketServerProtocol,
                            config: Optional[StreamConfig] = None):
"""Context manager for stream lifecycle"""
stream_id = str(uuid.uuid4())
stream_config = config or self.default_config
stream = None
try:
# Create and start stream
stream = WebSocketStream(websocket, stream_id, stream_config)
await stream.start()
async with self._manager_lock:
self.streams[stream_id] = stream
self.stream_configs[stream_id] = stream_config
self.total_connections += 1
logger.info(f"Stream {stream_id} added to manager")
yield stream
except Exception as e:
logger.error(f"Error managing stream {stream_id}: {e}")
raise
finally:
# Cleanup stream
if stream and stream_id in self.streams:
await stream.stop()
async with self._manager_lock:
del self.streams[stream_id]
if stream_id in self.stream_configs:
del self.stream_configs[stream_id]
self.total_connections -= 1
logger.info(f"Stream {stream_id} removed from manager")
async def broadcast_to_all(self, data: Any, message_type: MessageType = MessageType.IMPORTANT):
"""Broadcast message to all streams"""
if not self._running:
return
        try:
            # put_nowait raises QueueFull; the awaitable put() would block instead
            self._broadcast_queue.put_nowait((data, message_type))
        except asyncio.QueueFull:
            logger.warning("Broadcast queue full, dropping message")
            self.total_messages_dropped += 1
async def broadcast_to_stream(self, stream_id: str, data: Any,
message_type: MessageType = MessageType.IMPORTANT):
"""Send message to specific stream"""
async with self._manager_lock:
stream = self.streams.get(stream_id)
if stream:
await stream.send_message(data, message_type)
async def _broadcast_loop(self):
"""Broadcast messages to all streams"""
while self._running:
try:
# Get broadcast message
data, message_type = await self._broadcast_queue.get()
# Send to all streams concurrently
tasks = []
async with self._manager_lock:
streams = list(self.streams.values())
for stream in streams:
task = asyncio.create_task(
stream.send_message(data, message_type)
)
tasks.append(task)
# Wait for all sends (with timeout)
if tasks:
try:
await asyncio.wait_for(
asyncio.gather(*tasks, return_exceptions=True),
timeout=1.0
)
except asyncio.TimeoutError:
logger.warning("Broadcast timeout, some streams may be slow")
self.total_messages_sent += 1
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Error in broadcast loop: {e}")
await asyncio.sleep(0.1)
async def _cleanup_loop(self):
"""Cleanup disconnected streams"""
while self._running:
try:
await asyncio.sleep(60) # Cleanup every minute
disconnected_streams = []
async with self._manager_lock:
for stream_id, stream in self.streams.items():
if stream.status == StreamStatus.DISCONNECTED:
disconnected_streams.append(stream_id)
# Remove disconnected streams
for stream_id in disconnected_streams:
if stream_id in self.streams:
stream = self.streams[stream_id]
await stream.stop()
del self.streams[stream_id]
if stream_id in self.stream_configs:
del self.stream_configs[stream_id]
self.total_connections -= 1
logger.info(f"Cleaned up disconnected stream {stream_id}")
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Error in cleanup loop: {e}")
async def get_manager_metrics(self) -> Dict[str, Any]:
"""Get comprehensive manager metrics"""
async with self._manager_lock:
stream_metrics = []
for stream in self.streams.values():
stream_metrics.append(stream.get_metrics())
# Calculate aggregate metrics
total_queue_size = sum(m["queue_size"] for m in stream_metrics)
total_messages_sent = sum(m["messages_sent"] for m in stream_metrics)
total_messages_dropped = sum(m["messages_dropped"] for m in stream_metrics)
total_bytes_sent = sum(m["bytes_sent"] for m in stream_metrics)
# Status distribution
status_counts = {}
for stream in self.streams.values():
status = stream.status.value
status_counts[status] = status_counts.get(status, 0) + 1
return {
"manager_status": "running" if self._running else "stopped",
"total_connections": self.total_connections,
"active_streams": len(self.streams),
"total_queue_size": total_queue_size,
"total_messages_sent": total_messages_sent,
"total_messages_dropped": total_messages_dropped,
"total_bytes_sent": total_bytes_sent,
"broadcast_queue_size": self._broadcast_queue.qsize(),
"stream_status_distribution": status_counts,
"stream_metrics": stream_metrics
}
    async def update_stream_config(self, stream_id: str, config: StreamConfig):
        """Update configuration for specific stream"""
        async with self._manager_lock:
            if stream_id in self.streams:
                self.stream_configs[stream_id] = config
                # Apply to the live stream so timeout and threshold checks
                # pick up the new values immediately
                self.streams[stream_id].config = config
                logger.info(f"Updated config for stream {stream_id}")
def get_slow_streams(self, threshold: float = 0.8) -> List[str]:
"""Get streams with high queue fill ratios"""
slow_streams = []
for stream_id, stream in self.streams.items():
if stream.queue.fill_ratio() > threshold:
slow_streams.append(stream_id)
return slow_streams
async def handle_slow_consumer(self, stream_id: str, action: str = "warn"):
"""Handle slow consumer streams"""
async with self._manager_lock:
stream = self.streams.get(stream_id)
if not stream:
return
if action == "warn":
logger.warning(f"Slow consumer detected: {stream_id}")
await stream.send_message(
{"warning": "Slow consumer detected", "stream_id": stream_id},
MessageType.CONTROL
)
elif action == "throttle":
# Reduce queue size for slow consumer
new_config = StreamConfig(
max_queue_size=stream.config.max_queue_size // 2,
send_timeout=stream.config.send_timeout * 2
)
await self.update_stream_config(stream_id, new_config)
logger.info(f"Throttled slow consumer: {stream_id}")
elif action == "disconnect":
logger.warning(f"Disconnecting slow consumer: {stream_id}")
await stream.stop()
# Global stream manager instance
stream_manager = WebSocketStreamManager()
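# Serving sketch (illustrative): accept connections with the `websockets`
# library and manage each one through the global manager. The handler name
# and port are arbitrary; this function is not started by the module.
async def _serve_streams(host: str = "localhost", port: int = 8765) -> None:
    await stream_manager.start()
    async def handler(websocket):
        async with stream_manager.manage_stream(websocket) as stream:
            # Echo incoming frames back as low-priority messages
            async for raw in websocket:
                await stream.send_message({"echo": raw}, MessageType.BULK)
    async with websockets.serve(handler, host, port):
        await asyncio.Future()  # run until cancelled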