Some checks failed
Cross-Node Transaction Testing / transaction-test (push) Has been cancelled
Deploy to Testnet / deploy-testnet (push) Has been cancelled
Documentation Validation / validate-docs (push) Has been cancelled
Documentation Validation / validate-policies-strict (push) Has been cancelled
Integration Tests / test-service-integration (push) Has been cancelled
Multi-Node Stress Testing / stress-test (push) Has been cancelled
Python Tests / test-python (push) Has been cancelled
Security Scanning / security-scan (push) Has been cancelled
- Added logger initialization to EventRouter in events.py - Fixed datetime.timedelta references to use timedelta directly in security_hardening.py - Fixed StateTransition timestamp default_factory to use lambda in state.py - Fixed StateValidator.validate_transitions to only check source states exist - Moved cross_chain_bridge_enhanced.py to cross_chain/bridge_enhanced.py - Updated import paths in global_marketplace
618 lines
20 KiB
Python
618 lines
20 KiB
Python
"""
|
|
Tests for state management utilities
|
|
"""
|
|
|
|
import pytest
|
|
import asyncio
|
|
import json
|
|
import tempfile
|
|
import os
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import Mock, patch, MagicMock
|
|
|
|
from aitbc.state import (
|
|
StateTransitionError,
|
|
StatePersistenceError,
|
|
StateTransition,
|
|
StateMachine,
|
|
ConfigurableStateMachine,
|
|
StatePersistence,
|
|
AsyncStateMachine,
|
|
StateMonitor,
|
|
StateValidator,
|
|
StateSnapshot,
|
|
)
|
|
|
|
|
|
class TestExceptions:
    """Checks that the state-module exception types can be raised and caught."""

    def test_state_transition_error(self):
        """StateTransitionError raised inside the block is seen by pytest.raises."""
        with pytest.raises(StateTransitionError):
            error = StateTransitionError("Invalid transition")
            raise error

    def test_state_persistence_error(self):
        """StatePersistenceError raised inside the block is seen by pytest.raises."""
        with pytest.raises(StatePersistenceError):
            error = StatePersistenceError("Persistence failed")
            raise error
|
|
|
|
|
|
class TestStateTransition:
    """Checks construction of the StateTransition dataclass."""

    def test_state_transition_creation(self):
        """All explicitly supplied fields are stored and a timestamp is set."""
        payload = {"key": "value"}
        edge = StateTransition(from_state="state1", to_state="state2", data=payload)

        assert edge.from_state == "state1"
        assert edge.to_state == "state2"
        # Compare against a fresh literal rather than the object passed in.
        assert edge.data == {"key": "value"}
        assert edge.timestamp is not None

    def test_state_transition_defaults(self):
        """Omitting ``data`` yields an empty dict; a timestamp is still set."""
        edge = StateTransition(from_state="state1", to_state="state2")

        assert edge.data == {}
        assert edge.timestamp is not None
|
|
|
|
|
|
class TestStateMachine:
    """Exercises StateMachine behaviour through the TestableStateMachine helper."""

    def test_initialization(self):
        """A new machine reports its initial state, no history and empty data."""
        sm = TestableStateMachine("initial")

        assert sm.current_state == "initial"
        assert sm.transitions == []
        assert sm.state_data == {"initial": {}}

    def test_can_transition_valid(self):
        """can_transition returns True for a target the helper allows."""
        sm = TestableStateMachine("state1")
        allowed = sm.can_transition("state2")
        assert allowed is True

    def test_can_transition_invalid(self):
        """can_transition returns False for a target the helper does not allow."""
        sm = TestableStateMachine("state1")
        allowed = sm.can_transition("invalid")
        assert allowed is False

    def test_transition_success(self):
        """A valid transition updates the current state and records one entry."""
        sm = TestableStateMachine("state1")
        sm.transition("state2")

        assert sm.current_state == "state2"
        assert len(sm.transitions) == 1
        assert sm.transitions[0].from_state == "state1"
        assert sm.transitions[0].to_state == "state2"

    def test_transition_with_data(self):
        """Data passed to transition() is stored on the recorded transition."""
        sm = TestableStateMachine("state1")
        sm.transition("state2", data={"key": "value"})

        recorded = sm.transitions[0]
        assert recorded.data == {"key": "value"}

    def test_transition_invalid(self):
        """Requesting a disallowed target raises StateTransitionError."""
        sm = TestableStateMachine("state1")

        with pytest.raises(StateTransitionError):
            sm.transition("invalid")

    def test_get_state_data_current(self):
        """get_state_data() without arguments returns the current state's data."""
        sm = TestableStateMachine("state1")
        sm.set_state_data({"key": "value"})

        assert sm.get_state_data() == {"key": "value"}

    def test_get_state_data_specific(self):
        """get_state_data(name) returns the data stored for that named state."""
        sm = TestableStateMachine("state1")
        sm.set_state_data({"key": "value1"}, state="state1")
        sm.transition("state2")
        sm.set_state_data({"key": "value2"}, state="state2")

        assert sm.get_state_data("state1") == {"key": "value1"}

    def test_set_state_data(self):
        """set_state_data stores the mapping under the current state."""
        sm = TestableStateMachine("state1")
        sm.set_state_data({"key": "value"})

        assert sm.state_data["state1"] == {"key": "value"}

    def test_set_state_data_merge(self):
        """Successive set_state_data calls accumulate keys for the same state."""
        sm = TestableStateMachine("state1")
        for fragment in ({"key1": "value1"}, {"key2": "value2"}):
            sm.set_state_data(fragment)

        assert sm.state_data["state1"] == {"key1": "value1", "key2": "value2"}

    def test_get_transition_history(self):
        """Two transitions produce a history of length two."""
        sm = TestableStateMachine("state1")
        for target in ("state2", "state3"):
            sm.transition(target)

        assert len(sm.get_transition_history()) == 2

    def test_get_transition_history_with_limit(self):
        """limit=2 returns the two most recent transitions, oldest first."""
        sm = TestableStateMachine("state1")
        for target in ("state2", "state3", "state4"):
            sm.transition(target)

        recent = sm.get_transition_history(limit=2)

        assert len(recent) == 2
        assert recent[0].from_state == "state2"
        assert recent[1].from_state == "state3"

    def test_reset(self):
        """reset() switches to the given state and clears history and data."""
        sm = TestableStateMachine("state1")
        sm.transition("state2")
        sm.set_state_data({"key": "value"})

        sm.reset("initial")

        assert sm.current_state == "initial"
        assert sm.transitions == []
        assert sm.state_data == {"initial": {}}
|
|
|
|
|
|
class TestConfigurableStateMachine:
    """Checks the dictionary-driven ConfigurableStateMachine."""

    def test_initialization(self):
        """The initial state and the supplied transition table are exposed."""
        config = {
            "state1": ["state2", "state3"],
            "state2": ["state3"],
        }

        sm = ConfigurableStateMachine("state1", config)

        assert sm.current_state == "state1"
        assert sm.transitions_config == config

    def test_get_valid_transitions(self):
        """get_valid_transitions returns the targets configured for a state."""
        sm = ConfigurableStateMachine("state1", {"state1": ["state2", "state3"]})

        targets = sm.get_valid_transitions("state1")

        assert targets == ["state2", "state3"]

    def test_get_valid_transitions_empty(self):
        """A state configured with no targets yields an empty list."""
        sm = ConfigurableStateMachine("state1", {"state1": []})

        targets = sm.get_valid_transitions("state1")

        assert targets == []

    def test_add_transition(self):
        """add_transition appends a target to an already configured state."""
        sm = ConfigurableStateMachine("state1", {"state1": ["state2"]})

        sm.add_transition("state1", "state3")

        assert "state3" in sm.transitions_config["state1"]

    def test_add_transition_new_from_state(self):
        """add_transition creates the source entry when it is not configured yet."""
        sm = ConfigurableStateMachine("state1", {})

        sm.add_transition("state1", "state2")

        assert "state1" in sm.transitions_config
        assert "state2" in sm.transitions_config["state1"]
|
|
|
|
|
|
class TestStatePersistence:
    """Covers saving, loading and deleting machine state via StatePersistence."""

    @staticmethod
    def _persistence_at(root, *parts):
        """Join ``parts`` under ``root`` and return ``(path, StatePersistence(path))``."""
        target = os.path.join(root, *parts)
        return target, StatePersistence(target)

    def test_initialization(self):
        """The storage path given to the constructor is exposed unchanged."""
        with tempfile.TemporaryDirectory() as workdir:
            target, store = self._persistence_at(workdir, "state.json")

            assert store.storage_path == target

    def test_save_state(self):
        """save_state leaves a file at the storage path."""
        with tempfile.TemporaryDirectory() as workdir:
            target, store = self._persistence_at(workdir, "state.json")

            sm = TestableStateMachine("state1")
            sm.transition("state2")

            store.save_state(sm)

            assert os.path.exists(target)

    def test_load_state(self):
        """load_state returns the saved mapping, including the current state."""
        with tempfile.TemporaryDirectory() as workdir:
            _, store = self._persistence_at(workdir, "state.json")

            sm = TestableStateMachine("state1")
            sm.transition("state2")
            store.save_state(sm)

            snapshot = store.load_state()

            assert snapshot is not None
            assert snapshot["current_state"] == "state2"

    def test_load_state_not_exists(self):
        """load_state returns None when nothing has been saved."""
        with tempfile.TemporaryDirectory() as workdir:
            _, store = self._persistence_at(workdir, "nonexistent.json")

            snapshot = store.load_state()

            assert snapshot is None

    def test_delete_state(self):
        """delete_state removes a previously saved file."""
        with tempfile.TemporaryDirectory() as workdir:
            target, store = self._persistence_at(workdir, "state.json")

            sm = TestableStateMachine("state1")
            store.save_state(sm)

            store.delete_state()

            assert not os.path.exists(target)

    def test_delete_state_not_exists(self):
        """delete_state is expected to be a no-op when the file is absent."""
        with tempfile.TemporaryDirectory() as workdir:
            _, store = self._persistence_at(workdir, "nonexistent.json")

            # Passing means no exception escaped.
            store.delete_state()

    def test_save_state_error(self):
        """save_state raises StatePersistenceError when the parent dir is gone."""
        with tempfile.TemporaryDirectory() as workdir:
            # Target a file inside a sub-directory of the temp dir.
            target, store = self._persistence_at(workdir, "subdir", "state.json")

            sm = TestableStateMachine("state1")

            # The original author notes the parent directory may have been
            # auto-created; remove it if present so the write cannot succeed.
            import shutil

            parent = os.path.dirname(target)
            if os.path.exists(parent):
                shutil.rmtree(parent)

            with pytest.raises(StatePersistenceError):
                store.save_state(sm)
|
|
|
|
|
|
class TestAsyncStateMachine:
    """Exercises AsyncStateMachine through the AsyncTestableStateMachine helper."""

    @pytest.mark.asyncio
    async def test_initialization(self):
        """A new async machine has its initial state and no handlers."""
        sm = AsyncTestableStateMachine("initial")

        assert sm.current_state == "initial"
        assert sm.transition_handlers == {}

    @pytest.mark.asyncio
    async def test_on_transition(self):
        """Registering a handler creates an entry for the target state."""
        sm = AsyncTestableStateMachine("state1")
        callback = Mock()

        sm.on_transition("state2", callback)

        assert "state2" in sm.transition_handlers

    @pytest.mark.asyncio
    async def test_transition_async(self):
        """transition_async moves to the target and records one transition."""
        sm = AsyncTestableStateMachine("state1")

        await sm.transition_async("state2")

        assert sm.current_state == "state2"
        assert len(sm.transitions) == 1

    @pytest.mark.asyncio
    async def test_transition_async_invalid(self):
        """transition_async raises StateTransitionError for a disallowed target."""
        sm = AsyncTestableStateMachine("state1")

        with pytest.raises(StateTransitionError):
            await sm.transition_async("invalid")

    @pytest.mark.asyncio
    async def test_transition_async_with_sync_handler(self):
        """A plain callable registered for the target is called exactly once."""
        sm = AsyncTestableStateMachine("state1")
        callback = Mock()
        sm.on_transition("state2", callback)

        await sm.transition_async("state2")

        callback.assert_called_once()

    @pytest.mark.asyncio
    async def test_transition_async_with_async_handler(self):
        """A coroutine function registered for the target is run."""
        sm = AsyncTestableStateMachine("state1")

        invoked = False

        async def async_handler(transition):
            # Flip the enclosing flag so the test can observe the call.
            nonlocal invoked
            invoked = True

        sm.on_transition("state2", async_handler)

        await sm.transition_async("state2")

        assert invoked is True
|
|
|
|
|
|
class TestStateMonitor:
    """Checks observer registration and notification on StateMonitor."""

    @staticmethod
    def _monitored():
        """Return a ``(machine, monitor)`` pair with the machine in "state1"."""
        sm = TestableStateMachine("state1")
        return sm, StateMonitor(sm)

    def test_initialization(self):
        """The monitor keeps the machine it was given and starts with no observers."""
        sm, watcher = self._monitored()

        assert watcher.state_machine == sm
        assert watcher.observers == []

    def test_add_observer(self):
        """add_observer puts the callable into the observers collection."""
        _, watcher = self._monitored()
        listener = Mock()

        watcher.add_observer(listener)

        assert listener in watcher.observers

    def test_remove_observer(self):
        """remove_observer returns True and drops a registered observer."""
        _, watcher = self._monitored()
        listener = Mock()
        watcher.add_observer(listener)

        removed = watcher.remove_observer(listener)

        assert removed is True
        assert listener not in watcher.observers

    def test_remove_observer_not_found(self):
        """remove_observer returns False for an observer that was never added."""
        _, watcher = self._monitored()
        listener = Mock()

        removed = watcher.remove_observer(listener)

        assert removed is False

    def test_notify_observers(self):
        """Every registered observer is called once with the transition."""
        _, watcher = self._monitored()
        listeners = (Mock(), Mock())
        for listener in listeners:
            watcher.add_observer(listener)

        edge = StateTransition("state1", "state2")
        watcher.notify_observers(edge)

        for listener in listeners:
            listener.assert_called_once_with(edge)

    @patch('aitbc.state.logger')
    def test_notify_observers_error(self, mock_logger):
        """An observer that raises results in exactly one logger.error call."""
        _, watcher = self._monitored()

        def failing_observer(transition):
            raise Exception("Observer error")

        watcher.add_observer(failing_observer)

        edge = StateTransition("state1", "state2")
        watcher.notify_observers(edge)

        mock_logger.error.assert_called_once()

    def test_wrap_transition(self):
        """Calling the wrapped transition function notifies the observer once."""
        sm, watcher = self._monitored()
        listener = Mock()
        watcher.add_observer(listener)

        monitored_transition = watcher.wrap_transition(sm.transition)
        monitored_transition("state2")

        listener.assert_called_once()
|
|
|
|
|
|
class TestStateValidator:
    """Tests for the StateValidator helpers.

    Each test passes a transition table (``{state: [target, ...]}``) to a
    StateValidator function called on the class itself and checks the result.
    """

    def test_validate_transitions_valid(self):
        """validate_transitions returns True when every target is also a key."""
        transitions = {
            "state1": ["state2", "state3"],
            "state2": ["state3"],
            "state3": [],
        }

        result = StateValidator.validate_transitions(transitions)

        assert result is True

    def test_validate_transitions_invalid(self):
        """validate_transitions returns False when targets are not defined states."""
        transitions = {
            "state1": ["state2", "nonexistent"],
        }

        result = StateValidator.validate_transitions(transitions)

        # "nonexistent" is not a valid state since it's not in
        # transitions.keys(). Note that "state2" is not a key here either, so
        # this input contains two undefined targets.
        assert result is False

    def test_check_for_deadlocks(self):
        """A state with no outgoing transitions is reported as a deadlock."""
        transitions = {
            "state1": ["state2"],
            "state2": [],  # No outgoing transitions
        }

        deadlocks = StateValidator.check_for_deadlocks(transitions)

        assert "state2" in deadlocks

    def test_check_for_deadlocks_none(self):
        """No deadlocks are reported when every state has an outgoing transition."""
        transitions = {
            "state1": ["state2"],
            "state2": ["state1"],
        }

        deadlocks = StateValidator.check_for_deadlocks(transitions)

        assert deadlocks == []

    def test_check_for_orphans(self):
        """A state that no other state targets is reported as an orphan."""
        # "orphan" never appears in any target list. "state1" has no incoming
        # transition either, so only membership of "orphan" is asserted.
        transitions = {
            "state1": ["state2"],
            "state2": [],
            "orphan": [],  # No incoming transitions
        }

        orphans = StateValidator.check_for_orphans(transitions)

        assert "orphan" in orphans

    def test_check_for_orphans_none(self):
        """No orphans are reported when every state is targeted by another."""
        transitions = {
            "state1": ["state2"],
            "state2": ["state1"],
        }

        orphans = StateValidator.check_for_orphans(transitions)

        assert orphans == []
|
|
|
|
|
|
class TestStateSnapshot:
    """Checks capturing, restoring and (de)serialising StateSnapshot objects."""

    def test_initialization(self):
        """A snapshot mirrors the machine's state, data and history."""
        sm = TestableStateMachine("state1")
        sm.transition("state2")

        snap = StateSnapshot(sm)

        assert snap.current_state == "state2"
        assert snap.state_data == sm.state_data
        assert snap.transitions == sm.transitions
        assert snap.timestamp is not None

    def test_restore(self):
        """restore() copies the captured state and data onto another machine."""
        source = TestableStateMachine("state1")
        source.transition("state2")
        source.set_state_data({"key": "value"})

        snap = StateSnapshot(source)

        target = TestableStateMachine("initial")
        snap.restore(target)

        assert target.current_state == "state2"
        assert target.state_data == source.state_data

    def test_to_dict(self):
        """to_dict() exposes the four expected top-level keys."""
        sm = TestableStateMachine("state1")
        snap = StateSnapshot(sm)

        exported = snap.to_dict()

        for field_name in ("current_state", "state_data", "transitions", "timestamp"):
            assert field_name in exported

    def test_from_dict(self):
        """from_dict(to_dict()) round-trips the current state and state data."""
        sm = TestableStateMachine("state1")
        sm.transition("state2")
        snap = StateSnapshot(sm)

        exported = snap.to_dict()
        rebuilt = StateSnapshot.from_dict(exported)

        assert rebuilt.current_state == snap.current_state
        assert rebuilt.state_data == snap.state_data
|
|
|
|
|
|
# Helper classes for testing
|
|
class TestableStateMachine(StateMachine):
    """Concrete StateMachine used as a fixture by the tests in this module.

    It hard-codes the cycle state1 -> state2 -> state3 -> state4 -> state1,
    plus a direct state1 -> state3 edge. Any other state has no valid targets.
    """

    # The class name starts with "Test", which matches pytest's default class
    # discovery prefix. Opt out explicitly so pytest does not try to collect
    # this helper (and warn about its constructor).
    __test__ = False

    # Allowed targets per source state; tuples keep the table immutable.
    _TESTABLE_TRANSITIONS = {
        "state1": ("state2", "state3"),
        "state2": ("state3",),
        "state3": ("state4",),
        "state4": ("state1",),
    }

    def get_valid_transitions(self, state: str):
        """Return the states reachable from ``state``.

        Args:
            state: Name of the source state.

        Returns:
            A new list of target state names; empty when ``state`` is not one
            of the four known states.
        """
        # Build a fresh list on every call so callers may mutate the result.
        return list(self._TESTABLE_TRANSITIONS.get(state, ()))
|
|
|
|
|
|
class AsyncTestableStateMachine(AsyncStateMachine):
    """Minimal AsyncStateMachine subclass used by the async tests."""

    def get_valid_transitions(self, state: str):
        """Return ``["state2"]`` for "state1" and an empty list otherwise."""
        return ["state2"] if state == "state1" else []
|