Files
aitbc/tests/test_websocket_stream_backpressure.py
oib f353e00172 chore(security): enhance environment configuration, CI workflows, and wallet daemon with security improvements
- Restructure .env.example with security-focused documentation, service-specific environment file references, and AWS Secrets Manager integration
- Update CLI tests workflow to single Python 3.13 version, add pytest-mock dependency, and consolidate test execution with coverage
- Add comprehensive security validation to package publishing workflow with manual approval gates, secret scanning, and release
2026-03-03 10:33:46 +01:00

777 lines
26 KiB
Python

"""
Tests for WebSocket Stream Backpressure Control
Comprehensive test suite for WebSocket stream architecture with
per-stream flow control and backpressure handling.
"""
import pytest
import asyncio
import json
import time
from unittest.mock import Mock, AsyncMock, patch
from typing import Dict, Any
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'apps', 'coordinator-api', 'src'))
from app.services.websocket_stream_manager import (
WebSocketStreamManager, StreamConfig, StreamMessage, MessageType,
BoundedMessageQueue, WebSocketStream, StreamStatus
)
from app.services.multi_modal_websocket_fusion import (
MultiModalWebSocketFusion, FusionStreamType, FusionStreamConfig,
GPUProviderFlowControl, GPUProviderStatus, FusionData
)
class TestBoundedMessageQueue:
"""Test bounded message queue with priority and backpressure"""
@pytest.fixture
def queue(self):
return BoundedMessageQueue(max_size=10)
@pytest.mark.asyncio
async def test_basic_queue_operations(self, queue):
"""Test basic queue put/get operations"""
message = StreamMessage(data="test", message_type=MessageType.IMPORTANT)
# Put message
success = await queue.put(message)
assert success is True
assert queue.size() == 1
# Get message
retrieved = await queue.get()
assert retrieved == message
assert queue.size() == 0
@pytest.mark.asyncio
async def test_priority_ordering(self, queue):
"""Test message priority ordering"""
messages = [
StreamMessage(data="bulk", message_type=MessageType.BULK),
StreamMessage(data="critical", message_type=MessageType.CRITICAL),
StreamMessage(data="important", message_type=MessageType.IMPORTANT),
StreamMessage(data="control", message_type=MessageType.CONTROL)
]
# Add messages in random order
for msg in messages:
await queue.put(msg)
# Should retrieve in priority order: CONTROL > CRITICAL > IMPORTANT > BULK
expected_order = [MessageType.CONTROL, MessageType.CRITICAL,
MessageType.IMPORTANT, MessageType.BULK]
for expected_type in expected_order:
msg = await queue.get()
assert msg.message_type == expected_type
@pytest.mark.asyncio
async def test_backpressure_handling(self, queue):
"""Test backpressure handling when queue is full"""
# Fill queue to capacity
for i in range(queue.max_size):
await queue.put(StreamMessage(data=f"bulk_{i}", message_type=MessageType.BULK))
assert queue.size() == queue.max_size
assert queue.fill_ratio() == 1.0
# Try to add bulk message (should be dropped)
bulk_msg = StreamMessage(data="new_bulk", message_type=MessageType.BULK)
success = await queue.put(bulk_msg)
assert success is False
# Try to add important message (should replace oldest important)
important_msg = StreamMessage(data="new_important", message_type=MessageType.IMPORTANT)
success = await queue.put(important_msg)
assert success is True
# Try to add critical message (should always succeed)
critical_msg = StreamMessage(data="new_critical", message_type=MessageType.CRITICAL)
success = await queue.put(critical_msg)
assert success is True
@pytest.mark.asyncio
async def test_queue_size_limits(self, queue):
"""Test that individual queue size limits are respected"""
# Fill control queue to its limit
for i in range(100): # Control queue limit is 100
await queue.put(StreamMessage(data=f"control_{i}", message_type=MessageType.CONTROL))
# Should still accept other message types
success = await queue.put(StreamMessage(data="important", message_type=MessageType.IMPORTANT))
assert success is True
class TestWebSocketStream:
"""Test individual WebSocket stream with backpressure control"""
@pytest.fixture
def mock_websocket(self):
websocket = Mock()
websocket.send = AsyncMock()
websocket.remote_address = "127.0.0.1:12345"
return websocket
@pytest.fixture
def stream_config(self):
return StreamConfig(
max_queue_size=50,
send_timeout=1.0,
slow_consumer_threshold=0.1,
backpressure_threshold=0.7
)
@pytest.fixture
def stream(self, mock_websocket, stream_config):
return WebSocketStream(mock_websocket, "test_stream", stream_config)
@pytest.mark.asyncio
async def test_stream_start_stop(self, stream):
"""Test stream start and stop"""
assert stream.status == StreamStatus.CONNECTING
await stream.start()
assert stream.status == StreamStatus.CONNECTED
assert stream._running is True
await stream.stop()
assert stream.status == StreamStatus.DISCONNECTED
assert stream._running is False
@pytest.mark.asyncio
async def test_message_sending(self, stream, mock_websocket):
"""Test basic message sending"""
await stream.start()
# Send message
success = await stream.send_message({"test": "data"}, MessageType.IMPORTANT)
assert success is True
# Wait for message to be processed
await asyncio.sleep(0.1)
# Verify message was sent
mock_websocket.send.assert_called()
await stream.stop()
@pytest.mark.asyncio
async def test_slow_consumer_detection(self, stream, mock_websocket):
"""Test slow consumer detection"""
# Make websocket send slow
async def slow_send(message):
await asyncio.sleep(0.2) # Slower than threshold (0.1s)
mock_websocket.send = slow_send
await stream.start()
# Send multiple messages to trigger slow consumer detection
for i in range(10):
await stream.send_message({"test": f"data_{i}"}, MessageType.IMPORTANT)
# Wait for processing
await asyncio.sleep(1.0)
# Check if slow consumer was detected
assert stream.status == StreamStatus.SLOW_CONSUMER
assert stream.metrics.slow_consumer_events > 0
await stream.stop()
@pytest.mark.asyncio
async def test_backpressure_handling(self, stream, mock_websocket):
"""Test backpressure handling"""
await stream.start()
# Fill queue to trigger backpressure
for i in range(40): # 40/50 = 80% > backpressure_threshold (70%)
await stream.send_message({"test": f"data_{i}"}, MessageType.IMPORTANT)
# Wait for processing
await asyncio.sleep(0.1)
# Check backpressure status
assert stream.status == StreamStatus.BACKPRESSURE
assert stream.metrics.backpressure_events > 0
# Try to send bulk message under backpressure
success = await stream.send_message({"bulk": "data"}, MessageType.BULK)
# Should be dropped due to high queue fill ratio
assert stream.queue.fill_ratio() > 0.8
await stream.stop()
@pytest.mark.asyncio
async def test_message_priority_handling(self, stream, mock_websocket):
"""Test that priority messages are handled correctly"""
await stream.start()
# Send messages of different priorities
await stream.send_message({"bulk": "data"}, MessageType.BULK)
await stream.send_message({"critical": "data"}, MessageType.CRITICAL)
await stream.send_message({"important": "data"}, MessageType.IMPORTANT)
await stream.send_message({"control": "data"}, MessageType.CONTROL)
# Wait for processing
await asyncio.sleep(0.2)
# Verify all messages were sent
assert mock_websocket.send.call_count >= 4
await stream.stop()
@pytest.mark.asyncio
async def test_send_timeout_handling(self, stream, mock_websocket):
"""Test send timeout handling"""
# Make websocket send timeout
async def timeout_send(message):
await asyncio.sleep(2.0) # Longer than send_timeout (1.0s)
mock_websocket.send = timeout_send
await stream.start()
# Send message
success = await stream.send_message({"test": "data"}, MessageType.IMPORTANT)
assert success is True
# Wait for processing
await asyncio.sleep(1.5)
# Check that message was dropped due to timeout
assert stream.metrics.messages_dropped > 0
await stream.stop()
def test_stream_metrics(self, stream):
"""Test stream metrics collection"""
metrics = stream.get_metrics()
assert "stream_id" in metrics
assert "status" in metrics
assert "queue_size" in metrics
assert "messages_sent" in metrics
assert "messages_dropped" in metrics
assert "backpressure_events" in metrics
assert "slow_consumer_events" in metrics
class TestWebSocketStreamManager:
"""Test WebSocket stream manager with multiple streams"""
@pytest.fixture
def manager(self):
return WebSocketStreamManager()
@pytest.fixture
def mock_websocket(self):
websocket = Mock()
websocket.send = AsyncMock()
websocket.remote_address = "127.0.0.1:12345"
return websocket
@pytest.mark.asyncio
async def test_manager_start_stop(self, manager):
"""Test manager start and stop"""
await manager.start()
assert manager._running is True
await manager.stop()
assert manager._running is False
@pytest.mark.asyncio
async def test_stream_lifecycle_management(self, manager, mock_websocket):
"""Test stream lifecycle management"""
await manager.start()
# Create stream through manager
stream = None
async with manager.manage_stream(mock_websocket) as s:
stream = s
assert stream is not None
assert stream._running is True
assert len(manager.streams) == 1
assert manager.total_connections == 1
# Stream should be cleaned up
assert len(manager.streams) == 0
assert manager.total_connections == 0
await manager.stop()
@pytest.mark.asyncio
async def test_broadcast_to_all_streams(self, manager):
"""Test broadcasting to all streams"""
await manager.start()
# Create multiple mock websockets
websockets = [Mock() for _ in range(3)]
for ws in websockets:
ws.send = AsyncMock()
ws.remote_address = f"127.0.0.1:{12345 + websockets.index(ws)}"
# Create streams
streams = []
for ws in websockets:
async with manager.manage_stream(ws) as stream:
streams.append(stream)
await asyncio.sleep(0.01) # Small delay
# Broadcast message
await manager.broadcast_to_all({"broadcast": "test"}, MessageType.IMPORTANT)
# Wait for broadcast
await asyncio.sleep(0.2)
# Verify all streams received the message
for ws in websockets:
ws.send.assert_called()
await manager.stop()
@pytest.mark.asyncio
async def test_slow_stream_handling(self, manager):
"""Test handling of slow streams"""
await manager.start()
# Create slow websocket
slow_websocket = Mock()
async def slow_send(message):
await asyncio.sleep(0.5) # Very slow
slow_websocket.send = slow_send
slow_websocket.remote_address = "127.0.0.1:12345"
# Create slow stream
async with manager.manage_stream(slow_websocket) as stream:
# Send messages to fill queue
for i in range(20):
await stream.send_message({"test": f"data_{i}"}, MessageType.IMPORTANT)
await asyncio.sleep(0.5)
# Check if stream is detected as slow
slow_streams = manager.get_slow_streams(threshold=0.5)
assert len(slow_streams) > 0
await manager.stop()
@pytest.mark.asyncio
async def test_manager_metrics(self, manager):
"""Test manager metrics collection"""
await manager.start()
# Create some streams
websockets = [Mock() for _ in range(2)]
for ws in websockets:
ws.send = AsyncMock()
ws.remote_address = f"127.0.0.1:{12345 + websockets.index(ws)}"
streams = []
for ws in websockets:
async with manager.manage_stream(ws) as stream:
streams.append(stream)
await stream.send_message({"test": "data"}, MessageType.IMPORTANT)
await asyncio.sleep(0.01)
# Get metrics
metrics = await manager.get_manager_metrics()
assert "manager_status" in metrics
assert "total_connections" in metrics
assert "active_streams" in metrics
assert "total_queue_size" in metrics
assert "stream_status_distribution" in metrics
assert "stream_metrics" in metrics
await manager.stop()
class TestGPUProviderFlowControl:
"""Test GPU provider flow control"""
@pytest.fixture
def provider(self):
return GPUProviderFlowControl("test_provider")
@pytest.mark.asyncio
async def test_provider_start_stop(self, provider):
"""Test provider start and stop"""
await provider.start()
assert provider._running is True
await provider.stop()
assert provider._running is False
@pytest.mark.asyncio
async def test_request_submission(self, provider):
"""Test request submission and processing"""
await provider.start()
# Create fusion data
fusion_data = FusionData(
stream_id="test_stream",
stream_type=FusionStreamType.VISUAL,
data={"test": "data"},
timestamp=time.time(),
requires_gpu=True
)
# Submit request
request_id = await provider.submit_request(fusion_data)
assert request_id is not None
# Get result
result = await provider.get_result(request_id, timeout=3.0)
assert result is not None
assert "processed_data" in result
await provider.stop()
@pytest.mark.asyncio
async def test_concurrent_request_limiting(self, provider):
"""Test concurrent request limiting"""
provider.max_concurrent_requests = 2
await provider.start()
# Submit multiple requests
fusion_data = FusionData(
stream_id="test_stream",
stream_type=FusionStreamType.VISUAL,
data={"test": "data"},
timestamp=time.time(),
requires_gpu=True
)
request_ids = []
for i in range(5):
request_id = await provider.submit_request(fusion_data)
if request_id:
request_ids.append(request_id)
# Should have processed some requests
assert len(request_ids) > 0
# Get results
results = []
for request_id in request_ids:
result = await provider.get_result(request_id, timeout=5.0)
if result:
results.append(result)
assert len(results) > 0
await provider.stop()
@pytest.mark.asyncio
async def test_overload_handling(self, provider):
"""Test provider overload handling"""
await provider.start()
# Fill input queue to capacity
fusion_data = FusionData(
stream_id="test_stream",
stream_type=FusionStreamType.VISUAL,
data={"test": "data"},
timestamp=time.time(),
requires_gpu=True
)
# Submit many requests to fill queue
request_ids = []
for i in range(150): # More than queue capacity (100)
request_id = await provider.submit_request(fusion_data)
if request_id:
request_ids.append(request_id)
else:
break # Queue is full
# Should have rejected some requests due to overload
assert len(request_ids) < 150
# Check provider status
metrics = provider.get_metrics()
assert metrics["queue_size"] >= provider.input_queue.maxsize * 0.8
await provider.stop()
@pytest.mark.asyncio
async def test_provider_metrics(self, provider):
"""Test provider metrics collection"""
await provider.start()
# Submit some requests
fusion_data = FusionData(
stream_id="test_stream",
stream_type=FusionStreamType.VISUAL,
data={"test": "data"},
timestamp=time.time(),
requires_gpu=True
)
for i in range(3):
request_id = await provider.submit_request(fusion_data)
if request_id:
await provider.get_result(request_id, timeout=3.0)
# Get metrics
metrics = provider.get_metrics()
assert "provider_id" in metrics
assert "status" in metrics
assert "avg_processing_time" in metrics
assert "queue_size" in metrics
assert "total_requests" in metrics
assert "error_rate" in metrics
await provider.stop()
class TestMultiModalWebSocketFusion:
"""Test multi-modal WebSocket fusion service"""
@pytest.fixture
def fusion_service(self):
return MultiModalWebSocketFusion()
@pytest.mark.asyncio
async def test_fusion_service_start_stop(self, fusion_service):
"""Test fusion service start and stop"""
await fusion_service.start()
assert fusion_service._running is True
await fusion_service.stop()
assert fusion_service._running is False
@pytest.mark.asyncio
async def test_fusion_stream_registration(self, fusion_service):
"""Test fusion stream registration"""
await fusion_service.start()
config = FusionStreamConfig(
stream_type=FusionStreamType.VISUAL,
max_queue_size=100,
gpu_timeout=2.0
)
await fusion_service.register_fusion_stream("test_stream", config)
assert "test_stream" in fusion_service.fusion_streams
assert fusion_service.fusion_streams["test_stream"].stream_type == FusionStreamType.VISUAL
await fusion_service.stop()
@pytest.mark.asyncio
async def test_gpu_provider_initialization(self, fusion_service):
"""Test GPU provider initialization"""
await fusion_service.start()
assert len(fusion_service.gpu_providers) > 0
# Check that providers are running
for provider in fusion_service.gpu_providers.values():
assert provider._running is True
await fusion_service.stop()
@pytest.mark.asyncio
async def test_fusion_data_processing(self, fusion_service):
"""Test fusion data processing"""
await fusion_service.start()
# Create fusion data
fusion_data = FusionData(
stream_id="test_stream",
stream_type=FusionStreamType.VISUAL,
data={"test": "data"},
timestamp=time.time(),
requires_gpu=True
)
# Process data
await fusion_service._submit_to_gpu_provider(fusion_data)
# Wait for processing
await asyncio.sleep(1.0)
# Check metrics
assert fusion_service.fusion_metrics["total_fusions"] >= 1
await fusion_service.stop()
@pytest.mark.asyncio
async def test_comprehensive_metrics(self, fusion_service):
"""Test comprehensive metrics collection"""
await fusion_service.start()
# Get metrics
metrics = fusion_service.get_comprehensive_metrics()
assert "timestamp" in metrics
assert "system_status" in metrics
assert "stream_metrics" in metrics
assert "gpu_metrics" in metrics
assert "fusion_metrics" in metrics
assert "active_fusion_streams" in metrics
assert "registered_gpu_providers" in metrics
await fusion_service.stop()
@pytest.mark.asyncio
async def test_backpressure_monitoring(self, fusion_service):
"""Test backpressure monitoring"""
await fusion_service.start()
# Enable backpressure
fusion_service.backpressure_enabled = True
# Simulate high load
fusion_service.global_queue_size = 8000 # High queue size
fusion_service.max_global_queue_size = 10000
# Run monitoring check
await fusion_service._check_backpressure()
# Should have handled backpressure
# (This is a simplified test - in reality would check slow streams)
await fusion_service.stop()
class TestIntegrationScenarios:
"""Integration tests for complete scenarios"""
@pytest.mark.asyncio
async def test_multi_stream_fusion_workflow(self):
"""Test complete multi-stream fusion workflow"""
fusion_service = MultiModalWebSocketFusion()
await fusion_service.start()
try:
# Register multiple streams
stream_configs = [
("visual_stream", FusionStreamType.VISUAL),
("text_stream", FusionStreamType.TEXT),
("audio_stream", FusionStreamType.AUDIO)
]
for stream_id, stream_type in stream_configs:
config = FusionStreamConfig(stream_type=stream_type)
await fusion_service.register_fusion_stream(stream_id, config)
# Process fusion data for each stream
for stream_id, stream_type in stream_configs:
fusion_data = FusionData(
stream_id=stream_id,
stream_type=stream_type,
data={"test": f"data_{stream_type.value}"},
timestamp=time.time(),
requires_gpu=stream_type in [FusionStreamType.VISUAL, FusionStreamType.AUDIO]
)
if fusion_data.requires_gpu:
await fusion_service._submit_to_gpu_provider(fusion_data)
else:
await fusion_service._process_cpu_fusion(fusion_data)
# Wait for processing
await asyncio.sleep(2.0)
# Check results
metrics = fusion_service.get_comprehensive_metrics()
assert metrics["fusion_metrics"]["total_fusions"] >= 3
finally:
await fusion_service.stop()
@pytest.mark.asyncio
async def test_slow_gpu_provider_handling(self):
"""Test handling of slow GPU providers"""
fusion_service = MultiModalWebSocketFusion()
await fusion_service.start()
try:
# Make one GPU provider slow
if "gpu_1" in fusion_service.gpu_providers:
provider = fusion_service.gpu_providers["gpu_1"]
# Simulate slow processing by increasing processing time
original_process = provider._process_request
async def slow_process(request_data):
await asyncio.sleep(1.0) # Add delay
return await original_process(request_data)
provider._process_request = slow_process
# Submit fusion data
fusion_data = FusionData(
stream_id="test_stream",
stream_type=FusionStreamType.VISUAL,
data={"test": "data"},
timestamp=time.time(),
requires_gpu=True
)
# Should select fastest available provider
await fusion_service._submit_to_gpu_provider(fusion_data)
# Wait for processing
await asyncio.sleep(2.0)
# Check that processing completed
assert fusion_service.fusion_metrics["total_fusions"] >= 1
finally:
await fusion_service.stop()
@pytest.mark.asyncio
async def test_system_under_load(self):
"""Test system behavior under high load"""
fusion_service = MultiModalWebSocketFusion()
await fusion_service.start()
try:
# Submit many fusion requests
tasks = []
for i in range(50):
fusion_data = FusionData(
stream_id=f"stream_{i % 5}",
stream_type=FusionStreamType.VISUAL,
data={"test": f"data_{i}"},
timestamp=time.time(),
requires_gpu=True
)
task = asyncio.create_task(
fusion_service._submit_to_gpu_provider(fusion_data)
)
tasks.append(task)
# Wait for all tasks
await asyncio.gather(*tasks, return_exceptions=True)
# Wait for processing
await asyncio.sleep(3.0)
# Check system handled load
metrics = fusion_service.get_comprehensive_metrics()
# Should have processed many requests
assert metrics["fusion_metrics"]["total_fusions"] >= 10
# System should still be responsive
assert metrics["system_status"] == "running"
finally:
await fusion_service.stop()
if __name__ == "__main__":
pytest.main([__file__, "-v"])