chore(security): enhance environment configuration, CI workflows, and wallet daemon with security improvements
- Restructure .env.example with security-focused documentation, service-specific environment file references, and AWS Secrets Manager integration - Update CLI tests workflow to single Python 3.13 version, add pytest-mock dependency, and consolidate test execution with coverage - Add comprehensive security validation to package publishing workflow with manual approval gates, secret scanning, and release
This commit is contained in:
490
tests/TEST_REFACTORING_COMPLETED.md
Normal file
490
tests/TEST_REFACTORING_COMPLETED.md
Normal file
@@ -0,0 +1,490 @@
|
||||
# Test Configuration Refactoring - COMPLETED
|
||||
|
||||
## ✅ REFACTORING COMPLETE
|
||||
|
||||
**Date**: March 3, 2026
|
||||
**Status**: ✅ FULLY COMPLETED
|
||||
**Scope**: Eliminated shell script smell by moving test configuration to pyproject.toml
|
||||
|
||||
## Problem Solved
|
||||
|
||||
### ❌ **Before (Code Smell)**
|
||||
- **Shell Script Dependency**: `run_all_tests.sh` alongside `pytest.ini`
|
||||
- **Configuration Duplication**: Test settings split between files
|
||||
- **CI Integration Issues**: CI workflows calling shell script instead of pytest directly
|
||||
- **Maintenance Overhead**: Two separate files to maintain
|
||||
- **Non-Standard**: Not following Python testing best practices
|
||||
|
||||
### ✅ **After (Clean Integration)**
|
||||
- **Single Source of Truth**: All test configuration in `pyproject.toml`
|
||||
- **Direct pytest Integration**: CI workflows call pytest directly
|
||||
- **Standard Practice**: Follows Python testing best practices
|
||||
- **Better Maintainability**: One file to maintain
|
||||
- **Enhanced CI**: Comprehensive test workflows with proper categorization
|
||||
|
||||
## Changes Made
|
||||
|
||||
### ✅ **1. Consolidated pytest Configuration**
|
||||
|
||||
**Moved from `pytest.ini` to `pyproject.toml`:**
|
||||
```toml
|
||||
[tool.pytest.ini_options]
|
||||
# Test discovery
|
||||
python_files = ["test_*.py", "*_test.py"]
|
||||
python_classes = ["Test*"]
|
||||
python_functions = ["test_*"]
|
||||
|
||||
# Cache directory - prevent root level cache
|
||||
cache_dir = "dev/cache/.pytest_cache"
|
||||
|
||||
# Test paths to run - include all test directories across the project
|
||||
testpaths = [
|
||||
"tests",
|
||||
"apps/blockchain-node/tests",
|
||||
"apps/coordinator-api/tests",
|
||||
"apps/explorer-web/tests",
|
||||
"apps/pool-hub/tests",
|
||||
"apps/wallet-daemon/tests",
|
||||
"apps/zk-circuits/test",
|
||||
"cli/tests",
|
||||
"contracts/test",
|
||||
"packages/py/aitbc-crypto/tests",
|
||||
"packages/py/aitbc-sdk/tests",
|
||||
"packages/solidity/aitbc-token/test",
|
||||
"scripts/test"
|
||||
]
|
||||
|
||||
# Python path for imports
|
||||
pythonpath = [
|
||||
".",
|
||||
"packages/py/aitbc-crypto/src",
|
||||
"packages/py/aitbc-crypto/tests",
|
||||
"packages/py/aitbc-sdk/src",
|
||||
"packages/py/aitbc-sdk/tests",
|
||||
"apps/coordinator-api/src",
|
||||
"apps/coordinator-api/tests",
|
||||
"apps/wallet-daemon/src",
|
||||
"apps/wallet-daemon/tests",
|
||||
"apps/blockchain-node/src",
|
||||
"apps/blockchain-node/tests",
|
||||
"apps/pool-hub/src",
|
||||
"apps/pool-hub/tests",
|
||||
"apps/explorer-web/src",
|
||||
"apps/explorer-web/tests",
|
||||
"cli",
|
||||
"cli/tests"
|
||||
]
|
||||
|
||||
# Additional options for local testing
|
||||
addopts = [
|
||||
"--verbose",
|
||||
"--tb=short",
|
||||
"--strict-markers",
|
||||
"--disable-warnings",
|
||||
"-ra"
|
||||
]
|
||||
|
||||
# Custom markers
|
||||
markers = [
|
||||
"unit: Unit tests (fast, isolated)",
|
||||
"integration: Integration tests (may require external services)",
|
||||
"slow: Slow running tests",
|
||||
"cli: CLI command tests",
|
||||
"api: API endpoint tests",
|
||||
"blockchain: Blockchain-related tests",
|
||||
"crypto: Cryptography tests",
|
||||
"contracts: Smart contract tests",
|
||||
"e2e: End-to-end tests (full system)",
|
||||
"performance: Performance tests (measure speed/memory)",
|
||||
"security: Security tests (vulnerability scanning)",
|
||||
"gpu: Tests requiring GPU resources",
|
||||
"confidential: Tests for confidential transactions",
|
||||
"multitenant: Multi-tenancy specific tests"
|
||||
]
|
||||
|
||||
# Environment variables for tests
|
||||
env = [
|
||||
"AUDIT_LOG_DIR=/tmp/aitbc-audit",
|
||||
"DATABASE_URL=sqlite:///./test_coordinator.db",
|
||||
"TEST_MODE=true",
|
||||
"SQLITE_DATABASE=sqlite:///./test_coordinator.db"
|
||||
]
|
||||
|
||||
# Warnings
|
||||
filterwarnings = [
|
||||
"ignore::UserWarning",
|
||||
"ignore::DeprecationWarning",
|
||||
"ignore::PendingDeprecationWarning",
|
||||
"ignore::pytest.PytestUnknownMarkWarning",
|
||||
"ignore::pydantic.PydanticDeprecatedSince20",
|
||||
"ignore::sqlalchemy.exc.SADeprecationWarning"
|
||||
]
|
||||
|
||||
# Asyncio configuration
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
|
||||
# Import mode
|
||||
import_mode = "append"
|
||||
```
|
||||
|
||||
### ✅ **2. Updated CI Workflows**
|
||||
|
||||
**Updated `.github/workflows/ci.yml`:**
|
||||
```yaml
|
||||
- name: Test (pytest)
|
||||
run: poetry run pytest --cov=aitbc_cli --cov-report=term-missing --cov-report=xml
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v4
|
||||
with:
|
||||
file: ./coverage.xml
|
||||
flags: unittests
|
||||
name: codecov-umbrella
|
||||
```
|
||||
|
||||
**Updated `.github/workflows/cli-tests.yml`:**
|
||||
```yaml
|
||||
- name: Run CLI tests
|
||||
run: |
|
||||
python -m pytest tests/cli/ -v --tb=short --disable-warnings --cov=aitbc_cli --cov-report=term-missing --cov-report=xml
|
||||
```
|
||||
|
||||
### ✅ **3. Created Comprehensive Test Workflow**
|
||||
|
||||
**New `.github/workflows/comprehensive-tests.yml`:**
|
||||
- **Unit Tests**: Fast, isolated tests across Python versions
|
||||
- **Integration Tests**: Tests requiring external services
|
||||
- **CLI Tests**: CLI-specific testing
|
||||
- **API Tests**: API endpoint testing
|
||||
- **Blockchain Tests**: Blockchain-related tests
|
||||
- **Slow Tests**: Time-intensive tests (not on PRs)
|
||||
- **Performance Tests**: Performance benchmarking
|
||||
- **Security Tests**: Security scanning and testing
|
||||
- **Test Summary**: Comprehensive test reporting
|
||||
|
||||
### ✅ **4. Removed Legacy Files**
|
||||
|
||||
**Backed up and removed:**
|
||||
- `tests/run_all_tests.sh` → `tests/run_all_tests.sh.backup`
|
||||
- `pytest.ini` → `pytest.ini.backup`
|
||||
|
||||
## Benefits Achieved
|
||||
|
||||
### ✅ **Eliminated Code Smell**
|
||||
- **Single Source of Truth**: All test configuration in `pyproject.toml`
|
||||
- **No Shell Script Dependency**: Direct pytest integration
|
||||
- **Standard Practice**: Follows Python testing best practices
|
||||
- **Better Maintainability**: One configuration file
|
||||
|
||||
### ✅ **Enhanced CI Integration**
|
||||
- **Direct pytest Calls**: CI workflows call pytest directly
|
||||
- **Python 3.13 Standardization**: All tests use Python 3.13
|
||||
- **SQLite-Only Database**: All tests use SQLite, no PostgreSQL dependencies
|
||||
- **Better Coverage**: Comprehensive test categorization
|
||||
- **Parallel Execution**: Tests run in parallel by category
|
||||
- **Proper Reporting**: Enhanced test reporting and summaries
|
||||
|
||||
### ✅ **Improved Developer Experience**
|
||||
- **Simplified Usage**: `pytest` command works everywhere
|
||||
- **Better Discovery**: Automatic test discovery across all directories
|
||||
- **Consistent Configuration**: Same configuration locally and in CI
|
||||
- **Enhanced Markers**: Better test categorization
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### **Local Development**
|
||||
|
||||
**Run all tests:**
|
||||
```bash
|
||||
pytest
|
||||
```
|
||||
|
||||
**Run specific test categories:**
|
||||
```bash
|
||||
# Unit tests only
|
||||
pytest -m "unit"
|
||||
|
||||
# CLI tests only
|
||||
pytest -m "cli"
|
||||
|
||||
# Integration tests only
|
||||
pytest -m "integration"
|
||||
|
||||
# Exclude slow tests
|
||||
pytest -m "not slow"
|
||||
```
|
||||
|
||||
**Run with coverage:**
|
||||
```bash
|
||||
pytest --cov=aitbc_cli --cov-report=term-missing
|
||||
```
|
||||
|
||||
**Run specific test files:**
|
||||
```bash
|
||||
pytest tests/cli/test_agent_commands.py
|
||||
pytest apps/coordinator-api/tests/test_api.py
|
||||
```
|
||||
|
||||
### **CI/CD Integration**
|
||||
|
||||
**GitHub Actions automatically:**
|
||||
- Run unit tests across Python 3.11, 3.12, 3.13
|
||||
- Run integration tests with PostgreSQL
|
||||
- Run CLI tests with coverage
|
||||
- Run API tests with database
|
||||
- Run blockchain tests
|
||||
- Run security tests with Bandit
|
||||
- Generate comprehensive test summaries
|
||||
|
||||
### **Test Markers**
|
||||
|
||||
**Available markers:**
|
||||
```bash
|
||||
pytest --markers
|
||||
```
|
||||
|
||||
**Common usage:**
|
||||
```bash
|
||||
# Fast tests for development
|
||||
pytest -m "unit and not slow"
|
||||
|
||||
# Full test suite
|
||||
pytest -m "unit or integration or cli or api"
|
||||
|
||||
# Performance tests only
|
||||
pytest -m "performance"
|
||||
|
||||
# Security tests only
|
||||
pytest -m "security"
|
||||
```
|
||||
|
||||
## Migration Guide
|
||||
|
||||
### **For Developers**
|
||||
|
||||
**Before:**
|
||||
```bash
|
||||
# Run tests via shell script
|
||||
./tests/run_all_tests.sh
|
||||
|
||||
# Or manually with pytest.ini
|
||||
pytest --config=pytest.ini
|
||||
```
|
||||
|
||||
**After:**
|
||||
```bash
|
||||
# Run tests directly
|
||||
pytest
|
||||
|
||||
# Or with specific options
|
||||
pytest -v --tb=short --cov=aitbc_cli
|
||||
```
|
||||
|
||||
### **For CI/CD**
|
||||
|
||||
**Before:**
|
||||
```yaml
|
||||
- name: Run tests
|
||||
run: ./tests/run_all_tests.sh
|
||||
```
|
||||
|
||||
**After:**
|
||||
```yaml
|
||||
- name: Run tests
|
||||
run: pytest --cov=aitbc_cli --cov-report=xml
|
||||
```
|
||||
|
||||
### **For Configuration**
|
||||
|
||||
**Before:**
|
||||
```ini
|
||||
# pytest.ini
|
||||
[tool:pytest]
|
||||
python_files = test_*.py
|
||||
testpaths = tests
|
||||
addopts = --verbose
|
||||
```
|
||||
|
||||
**After:**
|
||||
```toml
|
||||
# pyproject.toml
|
||||
[tool.pytest.ini_options]
|
||||
python_files = ["test_*.py"]
|
||||
testpaths = ["tests"]
|
||||
addopts = ["--verbose"]
|
||||
```
|
||||
|
||||
## Test Organization
|
||||
|
||||
### **Test Categories**
|
||||
|
||||
1. **Unit Tests** (`-m unit`)
|
||||
- Fast, isolated tests
|
||||
- No external dependencies
|
||||
- Mock external services
|
||||
|
||||
2. **Integration Tests** (`-m integration`)
|
||||
- May require external services
|
||||
- Database integration
|
||||
- API integration
|
||||
|
||||
3. **CLI Tests** (`-m cli`)
|
||||
- CLI command testing
|
||||
- Click integration
|
||||
- CLI workflow testing
|
||||
|
||||
4. **API Tests** (`-m api`)
|
||||
- API endpoint testing
|
||||
- HTTP client testing
|
||||
- API integration
|
||||
|
||||
5. **Blockchain Tests** (`-m blockchain`)
|
||||
- Blockchain operations
|
||||
- Cryptographic tests
|
||||
- Smart contract tests
|
||||
|
||||
6. **Slow Tests** (`-m slow`)
|
||||
- Time-intensive tests
|
||||
- Large dataset tests
|
||||
- Performance benchmarks
|
||||
|
||||
7. **Performance Tests** (`-m performance`)
|
||||
- Speed measurements
|
||||
- Memory usage
|
||||
- Benchmarking
|
||||
|
||||
8. **Security Tests** (`-m security`)
|
||||
- Vulnerability scanning
|
||||
- Security validation
|
||||
- Input validation
|
||||
|
||||
### **Test Discovery**
|
||||
|
||||
**Automatic discovery includes:**
|
||||
- `tests/` - Main test directory
|
||||
- `apps/*/tests/` - Application tests
|
||||
- `cli/tests/` - CLI tests
|
||||
- `contracts/test/` - Smart contract tests
|
||||
- `packages/*/tests/` - Package tests
|
||||
- `scripts/test/` - Script tests
|
||||
|
||||
**Python path automatically includes:**
|
||||
- All source directories
|
||||
- All test directories
|
||||
- CLI directory
|
||||
- Package directories
|
||||
|
||||
## Performance Improvements
|
||||
|
||||
### ✅ **Faster Test Execution**
|
||||
- **Parallel Execution**: Tests run in parallel by category
|
||||
- **Smart Caching**: Proper cache directory management
|
||||
- **Selective Testing**: Run only relevant tests
|
||||
- **Optimized Discovery**: Efficient test discovery
|
||||
|
||||
### ✅ **Better Resource Usage**
|
||||
- **Database Services**: Only spin up when needed
|
||||
- **Test Isolation**: Better test isolation
|
||||
- **Memory Management**: Proper memory usage
|
||||
- **Cleanup**: Automatic cleanup after tests
|
||||
|
||||
### ✅ **Enhanced Reporting**
|
||||
- **Coverage Reports**: Comprehensive coverage reporting
|
||||
- **Test Summaries**: Detailed test summaries
|
||||
- **PR Comments**: Automatic PR comments with results
|
||||
- **Artifact Upload**: Proper artifact management
|
||||
|
||||
## Quality Metrics
|
||||
|
||||
### ✅ **Code Quality**
|
||||
- **Configuration**: Single source of truth
|
||||
- **Maintainability**: Easier to maintain
|
||||
- **Consistency**: Consistent across environments
|
||||
- **Best Practices**: Follows Python best practices
|
||||
|
||||
### ✅ **CI/CD Quality**
|
||||
- **Reliability**: More reliable test execution
|
||||
- **Speed**: Faster test execution
|
||||
- **Coverage**: Better test coverage
|
||||
- **Reporting**: Enhanced reporting
|
||||
|
||||
### ✅ **Developer Experience**
|
||||
- **Simplicity**: Easier to run tests
|
||||
- **Flexibility**: More test options
|
||||
- **Discovery**: Better test discovery
|
||||
- **Documentation**: Better documentation
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### **Common Issues**
|
||||
|
||||
**Test discovery not working:**
|
||||
```bash
|
||||
# Check configuration
|
||||
pytest --collect-only
|
||||
|
||||
# Verify testpaths
|
||||
python -c "import pytest; print(pytest.config.getini('testpaths'))"
|
||||
```
|
||||
|
||||
**Import errors:**
|
||||
```bash
|
||||
# Check pythonpath
|
||||
pytest --debug
|
||||
|
||||
# Verify imports
|
||||
python -c "import sys; print(sys.path)"
|
||||
```
|
||||
|
||||
**Coverage issues:**
|
||||
```bash
|
||||
# Check coverage configuration
|
||||
pytest --cov=aitbc_cli --cov-report=term-missing
|
||||
|
||||
# Verify coverage source
|
||||
python -c "import coverage; print(coverage.Coverage().source)"
|
||||
```
|
||||
|
||||
### **Migration Issues**
|
||||
|
||||
**Legacy shell script references:**
|
||||
- Update documentation to use `pytest` directly
|
||||
- Remove shell script references from CI/CD
|
||||
- Update developer guides
|
||||
|
||||
**pytest.ini conflicts:**
|
||||
- Remove `pytest.ini` file
|
||||
- Ensure all configuration is in `pyproject.toml`
|
||||
- Restart IDE to pick up changes
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
### ✅ **Planned Improvements**
|
||||
- **Test Parallelization**: Add pytest-xdist for parallel execution
|
||||
- **Test Profiling**: Add test performance profiling
|
||||
- **Test Documentation**: Generate test documentation
|
||||
- **Test Metrics**: Enhanced test metrics collection
|
||||
|
||||
### ✅ **Advanced Features**
|
||||
- **Test Environments**: Multiple test environments
|
||||
- **Test Data Management**: Better test data management
|
||||
- **Test Fixtures**: Enhanced test fixtures
|
||||
- **Test Utilities**: Additional test utilities
|
||||
|
||||
## Conclusion
|
||||
|
||||
The test configuration refactoring successfully eliminates the shell script smell by:
|
||||
|
||||
1. **✅ Consolidated Configuration**: All test configuration in `pyproject.toml`
|
||||
2. **✅ Direct pytest Integration**: CI workflows call pytest directly
|
||||
3. **✅ Enhanced CI/CD**: Comprehensive test workflows
|
||||
4. **✅ Better Developer Experience**: Simplified test execution
|
||||
5. **✅ Standard Practices**: Follows Python testing best practices
|
||||
|
||||
The refactored test system provides a solid foundation for testing the AITBC project while maintaining flexibility, performance, and maintainability.
|
||||
|
||||
---
|
||||
|
||||
**Status**: ✅ COMPLETED
|
||||
**Next Steps**: Monitor test execution and optimize performance
|
||||
**Maintenance**: Regular test configuration updates and review
|
||||
318
tests/USAGE_GUIDE.md
Normal file
318
tests/USAGE_GUIDE.md
Normal file
@@ -0,0 +1,318 @@
|
||||
# Test Configuration Refactoring - Usage Guide
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
The AITBC test suite has been refactored to eliminate the shell script smell and use proper pytest configuration in `pyproject.toml`. We standardize on Python 3.13 for all testing and use SQLite exclusively for database testing.
|
||||
|
||||
### **Basic Usage**
|
||||
|
||||
```bash
|
||||
# Run all fast tests (default)
|
||||
pytest
|
||||
|
||||
# Run with the convenient test runner
|
||||
python tests/test_runner.py
|
||||
|
||||
# Run all tests including slow ones
|
||||
python tests/test_runner.py --all
|
||||
|
||||
# Run with coverage
|
||||
python tests/test_runner.py --coverage
|
||||
```
|
||||
|
||||
### **Test Categories**
|
||||
|
||||
```bash
|
||||
# Unit tests only
|
||||
pytest -m "unit"
|
||||
python tests/test_runner.py --unit
|
||||
|
||||
# Integration tests only
|
||||
pytest -m "integration"
|
||||
python tests/test_runner.py --integration
|
||||
|
||||
# CLI tests only
|
||||
pytest -m "cli"
|
||||
python tests/test_runner.py --cli
|
||||
|
||||
# API tests only
|
||||
pytest -m "api"
|
||||
python tests/test_runner.py --api
|
||||
|
||||
# Blockchain tests only
|
||||
pytest -m "blockchain"
|
||||
python tests/test_runner.py --blockchain
|
||||
|
||||
# Slow tests only
|
||||
pytest -m "slow"
|
||||
python tests/test_runner.py --slow
|
||||
|
||||
# Performance tests only
|
||||
pytest -m "performance"
|
||||
python tests/test_runner.py --performance
|
||||
|
||||
# Security tests only
|
||||
pytest -m "security"
|
||||
python tests/test_runner.py --security
|
||||
```
|
||||
|
||||
### **Advanced Usage**
|
||||
|
||||
```bash
|
||||
# Run specific test files
|
||||
pytest tests/cli/test_agent_commands.py
|
||||
pytest apps/coordinator-api/tests/test_api.py
|
||||
|
||||
# Run with verbose output
|
||||
pytest -v
|
||||
python tests/test_runner.py --verbose
|
||||
|
||||
# Run with coverage
|
||||
pytest --cov=aitbc_cli --cov-report=term-missing
|
||||
python tests/test_runner.py --coverage
|
||||
|
||||
# List available tests
|
||||
pytest --collect-only
|
||||
python tests/test_runner.py --list
|
||||
|
||||
# Show available markers
|
||||
pytest --markers
|
||||
python tests/test_runner.py --markers
|
||||
|
||||
# Run with specific Python path
|
||||
pytest --pythonpath=cli
|
||||
|
||||
# Run with custom options
|
||||
pytest -v --tb=short --disable-warnings
|
||||
```
|
||||
|
||||
## 📋 Test Markers
|
||||
|
||||
The test suite uses the following markers to categorize tests:
|
||||
|
||||
| Marker | Description | Usage |
|
||||
|--------|-------------|-------|
|
||||
| `unit` | Unit tests (fast, isolated) | `pytest -m unit` |
|
||||
| `integration` | Integration tests (may require external services) | `pytest -m integration` |
|
||||
| `cli` | CLI command tests | `pytest -m cli` |
|
||||
| `api` | API endpoint tests | `pytest -m api` |
|
||||
| `blockchain` | Blockchain-related tests | `pytest -m blockchain` |
|
||||
| `crypto` | Cryptography tests | `pytest -m crypto` |
|
||||
| `contracts` | Smart contract tests | `pytest -m contracts` |
|
||||
| `slow` | Slow running tests | `pytest -m slow` |
|
||||
| `performance` | Performance tests | `pytest -m performance` |
|
||||
| `security` | Security tests | `pytest -m security` |
|
||||
| `gpu` | Tests requiring GPU resources | `pytest -m gpu` |
|
||||
| `e2e` | End-to-end tests | `pytest -m e2e` |
|
||||
|
||||
## 🗂️ Test Discovery
|
||||
|
||||
The test suite automatically discovers tests in these directories:
|
||||
|
||||
- `tests/` - Main test directory
|
||||
- `apps/*/tests/` - Application tests
|
||||
- `cli/tests/` - CLI tests
|
||||
- `contracts/test/` - Smart contract tests
|
||||
- `packages/*/tests/` - Package tests
|
||||
- `scripts/test/` - Script tests
|
||||
|
||||
## 🔧 Configuration
|
||||
|
||||
All test configuration is now in `pyproject.toml` with SQLite as the default database:
|
||||
|
||||
```toml
|
||||
[tool.pytest.ini_options]
|
||||
python_files = ["test_*.py", "*_test.py"]
|
||||
testpaths = ["tests", "apps/*/tests", "cli/tests", ...]
|
||||
addopts = ["--verbose", "--tb=short", "--strict-markers", "--disable-warnings", "-ra"]
|
||||
env = [
|
||||
"DATABASE_URL=sqlite:///./test_coordinator.db",
|
||||
"SQLITE_DATABASE=sqlite:///./test_coordinator.db"
|
||||
]
|
||||
markers = [
|
||||
"unit: Unit tests (fast, isolated)",
|
||||
"integration: Integration tests (may require external services)",
|
||||
# ... more markers
|
||||
]
|
||||
```
|
||||
|
||||
## 🚦 CI/CD Integration
|
||||
|
||||
The CI workflows now call pytest directly:
|
||||
|
||||
```yaml
|
||||
- name: Run tests
|
||||
run: pytest --cov=aitbc_cli --cov-report=xml
|
||||
```
|
||||
|
||||
## 📊 Coverage
|
||||
|
||||
```bash
|
||||
# Run with coverage
|
||||
pytest --cov=aitbc_cli --cov-report=term-missing
|
||||
|
||||
# Generate HTML coverage report
|
||||
pytest --cov=aitbc_cli --cov-report=html
|
||||
|
||||
# Coverage for specific module
|
||||
pytest --cov=aitbc_cli.commands.agent --cov-report=term-missing
|
||||
```
|
||||
|
||||
## 🐛 Troubleshooting
|
||||
|
||||
### **Common Issues**
|
||||
|
||||
**Import errors:**
|
||||
```bash
|
||||
# Check python path
|
||||
python -c "import sys; print(sys.path)"
|
||||
|
||||
# Run with explicit python path
|
||||
PYTHONPATH=cli pytest
|
||||
```
|
||||
|
||||
**Test discovery issues:**
|
||||
```bash
|
||||
# Check what tests are discovered
|
||||
pytest --collect-only
|
||||
|
||||
# Check configuration
|
||||
python -c "import pytest; print(pytest.config.getini('testpaths'))"
|
||||
```
|
||||
|
||||
**Coverage issues:**
|
||||
```bash
|
||||
# Check coverage configuration
|
||||
pytest --cov=aitbc_cli --cov-report=term-missing --debug
|
||||
|
||||
# Verify coverage source
|
||||
python -c "import coverage; print(coverage.Coverage().source)"
|
||||
```
|
||||
|
||||
### **Migration from Shell Script**
|
||||
|
||||
**Before:**
|
||||
```bash
|
||||
./tests/run_all_tests.sh
|
||||
```
|
||||
|
||||
**After:**
|
||||
```bash
|
||||
pytest
|
||||
# or
|
||||
python tests/test_runner.py
|
||||
```
|
||||
|
||||
## 🎯 Best Practices
|
||||
|
||||
### **For Developers**
|
||||
|
||||
1. **Use appropriate markers**: Mark your tests with the correct category
|
||||
2. **Keep unit tests fast**: Unit tests should not depend on external services
|
||||
3. **Use fixtures**: Leverage pytest fixtures for setup/teardown
|
||||
4. **Write descriptive tests**: Use clear test names and descriptions
|
||||
|
||||
### **Test Writing Example**
|
||||
|
||||
```python
|
||||
import pytest
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_cli_command_help():
|
||||
"""Test CLI help command."""
|
||||
# Test implementation
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.slow
|
||||
def test_blockchain_sync():
|
||||
"""Test blockchain synchronization."""
|
||||
# Test implementation
|
||||
|
||||
@pytest.mark.cli
|
||||
def test_agent_create_command():
|
||||
"""Test agent creation CLI command."""
|
||||
# Test implementation
|
||||
```
|
||||
|
||||
### **Running Tests During Development**
|
||||
|
||||
```bash
|
||||
# Quick feedback during development
|
||||
pytest -m "unit" -v
|
||||
|
||||
# Run tests for specific module
|
||||
pytest tests/cli/test_agent_commands.py -v
|
||||
|
||||
# Run tests with coverage for your changes
|
||||
pytest --cov=aitbc_cli --cov-report=term-missing
|
||||
|
||||
# Run tests before committing
|
||||
python tests/test_runner.py --coverage
|
||||
```
|
||||
|
||||
## 📈 Performance Tips
|
||||
|
||||
### **Fast Test Execution**
|
||||
|
||||
```bash
|
||||
# Run only unit tests for quick feedback
|
||||
pytest -m "unit" -v
|
||||
|
||||
# Use parallel execution (if pytest-xdist is installed)
|
||||
pytest -n auto -m "unit"
|
||||
|
||||
# Skip slow tests during development
|
||||
pytest -m "not slow"
|
||||
```
|
||||
|
||||
### **Memory Usage**
|
||||
|
||||
```bash
|
||||
# Run tests with minimal output
|
||||
pytest -q
|
||||
|
||||
# Use specific test paths to reduce discovery overhead
|
||||
pytest tests/cli/
|
||||
```
|
||||
|
||||
## 🔍 Debugging
|
||||
|
||||
### **Debug Mode**
|
||||
|
||||
```bash
|
||||
# Run with debug output
|
||||
pytest --debug
|
||||
|
||||
# Run with pdb on failure
|
||||
pytest --pdb
|
||||
|
||||
# Run with verbose output
|
||||
pytest -v -s
|
||||
```
|
||||
|
||||
### **Test Selection**
|
||||
|
||||
```bash
|
||||
# Run specific test
|
||||
pytest tests/cli/test_agent_commands.py::test_agent_create
|
||||
|
||||
# Run tests matching pattern
|
||||
pytest -k "agent_create"
|
||||
|
||||
# Run failed tests only
|
||||
pytest --lf
|
||||
```
|
||||
|
||||
## 📚 Additional Resources
|
||||
|
||||
- **pytest documentation**: https://docs.pytest.org/
|
||||
- **pytest-cov documentation**: https://pytest-cov.readthedocs.io/
|
||||
- **pytest-mock documentation**: https://pytest-mock.readthedocs.io/
|
||||
- **AITBC Development Guidelines**: See `docs/DEVELOPMENT_GUIDELINES.md`
|
||||
|
||||
---
|
||||
|
||||
**Migration completed**: ✅ All test configuration moved to `pyproject.toml`
|
||||
**Shell script eliminated**: ✅ No more `run_all_tests.sh` dependency
|
||||
**CI/CD updated**: ✅ Direct pytest integration in workflows
|
||||
**Developer experience improved**: ✅ Simplified test execution
|
||||
@@ -1,32 +1,510 @@
|
||||
"""
|
||||
Marketplace Analytics System Integration Tests
|
||||
Marketplace Analytics System Tests
|
||||
Comprehensive testing for analytics, insights, reporting, and dashboards
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import json
|
||||
import statistics
|
||||
from datetime import datetime, timedelta
|
||||
from uuid import uuid4
|
||||
from typing import Dict, Any
|
||||
from unittest.mock import Mock, patch
|
||||
from typing import Dict, Any, List
|
||||
|
||||
from sqlmodel import Session, select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from apps.coordinator_api.src.app.services.analytics_service import (
|
||||
MarketplaceAnalytics, DataCollector, AnalyticsEngine, DashboardManager
|
||||
)
|
||||
from apps.coordinator_api.src.app.domain.analytics import (
|
||||
MarketMetric, MarketInsight, AnalyticsReport, DashboardConfig,
|
||||
AnalyticsPeriod, MetricType, InsightType, ReportType
|
||||
)
|
||||
class TestMarketplaceAnalytics:
|
||||
"""Test marketplace analytics functionality"""
|
||||
|
||||
def test_market_metrics_calculation(self):
|
||||
"""Test market metrics calculation"""
|
||||
# Sample market data
|
||||
market_data = [
|
||||
{'price': 0.10, 'gpu_type': 'RTX 3080', 'timestamp': '2024-01-01T10:00:00Z'},
|
||||
{'price': 0.12, 'gpu_type': 'RTX 3080', 'timestamp': '2024-01-01T11:00:00Z'},
|
||||
{'price': 0.11, 'gpu_type': 'RTX 3080', 'timestamp': '2024-01-01T12:00:00Z'},
|
||||
{'price': 0.15, 'gpu_type': 'RTX 3090', 'timestamp': '2024-01-01T10:00:00Z'},
|
||||
{'price': 0.14, 'gpu_type': 'RTX 3090', 'timestamp': '2024-01-01T11:00:00Z'},
|
||||
]
|
||||
|
||||
# Calculate metrics
|
||||
rtx3080_prices = [d['price'] for d in market_data if d['gpu_type'] == 'RTX 3080']
|
||||
rtx3090_prices = [d['price'] for d in market_data if d['gpu_type'] == 'RTX 3090']
|
||||
|
||||
# Calculate statistics
|
||||
metrics = {
|
||||
'RTX 3080': {
|
||||
'avg_price': statistics.mean(rtx3080_prices),
|
||||
'min_price': min(rtx3080_prices),
|
||||
'max_price': max(rtx3080_prices),
|
||||
'price_volatility': statistics.stdev(rtx3080_prices) if len(rtx3080_prices) > 1 else 0
|
||||
},
|
||||
'RTX 3090': {
|
||||
'avg_price': statistics.mean(rtx3090_prices),
|
||||
'min_price': min(rtx3090_prices),
|
||||
'max_price': max(rtx3090_prices),
|
||||
'price_volatility': statistics.stdev(rtx3090_prices) if len(rtx3090_prices) > 1 else 0
|
||||
}
|
||||
}
|
||||
|
||||
# Validate metrics
|
||||
assert metrics['RTX 3080']['avg_price'] == 0.11
|
||||
assert metrics['RTX 3080']['min_price'] == 0.10
|
||||
assert metrics['RTX 3080']['max_price'] == 0.12
|
||||
assert metrics['RTX 3090']['avg_price'] == 0.145
|
||||
assert metrics['RTX 3090']['min_price'] == 0.14
|
||||
assert metrics['RTX 3090']['max_price'] == 0.15
|
||||
|
||||
def test_demand_analysis(self):
|
||||
"""Test demand analysis functionality"""
|
||||
# Sample demand data
|
||||
demand_data = [
|
||||
{'date': '2024-01-01', 'requests': 120, 'fulfilled': 100},
|
||||
{'date': '2024-01-02', 'requests': 150, 'fulfilled': 130},
|
||||
{'date': '2024-01-03', 'requests': 180, 'fulfilled': 160},
|
||||
{'date': '2024-01-04', 'requests': 140, 'fulfilled': 125},
|
||||
]
|
||||
|
||||
# Calculate demand metrics
|
||||
total_requests = sum(d['requests'] for d in demand_data)
|
||||
total_fulfilled = sum(d['fulfilled'] for d in demand_data)
|
||||
fulfillment_rate = (total_fulfilled / total_requests) * 100
|
||||
|
||||
# Calculate trend
|
||||
daily_rates = [(d['fulfilled'] / d['requests']) * 100 for d in demand_data]
|
||||
trend = 'increasing' if daily_rates[-1] > daily_rates[0] else 'decreasing'
|
||||
|
||||
# Validate analysis
|
||||
assert total_requests == 590
|
||||
assert total_fulfilled == 515
|
||||
assert fulfillment_rate == 87.29 # Approximately
|
||||
assert trend == 'increasing'
|
||||
assert all(0 <= rate <= 100 for rate in daily_rates)
|
||||
|
||||
def test_provider_performance(self):
|
||||
"""Test provider performance analytics"""
|
||||
# Sample provider data
|
||||
provider_data = [
|
||||
{
|
||||
'provider_id': 'provider_1',
|
||||
'total_jobs': 50,
|
||||
'completed_jobs': 45,
|
||||
'avg_completion_time': 25.5, # minutes
|
||||
'avg_rating': 4.8,
|
||||
'gpu_types': ['RTX 3080', 'RTX 3090']
|
||||
},
|
||||
{
|
||||
'provider_id': 'provider_2',
|
||||
'total_jobs': 30,
|
||||
'completed_jobs': 28,
|
||||
'avg_completion_time': 30.2,
|
||||
'avg_rating': 4.6,
|
||||
'gpu_types': ['RTX 3080']
|
||||
},
|
||||
{
|
||||
'provider_id': 'provider_3',
|
||||
'total_jobs': 40,
|
||||
'completed_jobs': 35,
|
||||
'avg_completion_time': 22.1,
|
||||
'avg_rating': 4.9,
|
||||
'gpu_types': ['RTX 3090', 'RTX 4090']
|
||||
}
|
||||
]
|
||||
|
||||
# Calculate performance metrics
|
||||
for provider in provider_data:
|
||||
success_rate = (provider['completed_jobs'] / provider['total_jobs']) * 100
|
||||
provider['success_rate'] = success_rate
|
||||
|
||||
# Sort by performance
|
||||
top_providers = sorted(provider_data, key=lambda x: x['success_rate'], reverse=True)
|
||||
|
||||
# Validate calculations
|
||||
assert top_providers[0]['provider_id'] == 'provider_1'
|
||||
assert top_providers[0]['success_rate'] == 90.0
|
||||
assert top_providers[1]['success_rate'] == 93.33 # provider_2
|
||||
assert top_providers[2]['success_rate'] == 87.5 # provider_3
|
||||
|
||||
# Validate data integrity
|
||||
for provider in provider_data:
|
||||
assert 0 <= provider['success_rate'] <= 100
|
||||
assert provider['avg_rating'] >= 0 and provider['avg_rating'] <= 5
|
||||
assert provider['avg_completion_time'] > 0
|
||||
|
||||
|
||||
class TestAnalyticsEngine:
|
||||
"""Test analytics engine functionality"""
|
||||
|
||||
def test_data_aggregation(self):
|
||||
"""Test data aggregation capabilities"""
|
||||
# Sample time series data
|
||||
time_series_data = [
|
||||
{'timestamp': '2024-01-01T00:00:00Z', 'value': 100},
|
||||
{'timestamp': '2024-01-01T01:00:00Z', 'value': 110},
|
||||
{'timestamp': '2024-01-01T02:00:00Z', 'value': 105},
|
||||
{'timestamp': '2024-01-01T03:00:00Z', 'value': 120},
|
||||
{'timestamp': '2024-01-01T04:00:00Z', 'value': 115},
|
||||
]
|
||||
|
||||
# Aggregate by hour (already hourly data)
|
||||
hourly_avg = statistics.mean([d['value'] for d in time_series_data])
|
||||
hourly_max = max([d['value'] for d in time_series_data])
|
||||
hourly_min = min([d['value'] for d in time_series_data])
|
||||
|
||||
# Create aggregated summary
|
||||
aggregated_data = {
|
||||
'period': 'hourly',
|
||||
'data_points': len(time_series_data),
|
||||
'average': hourly_avg,
|
||||
'maximum': hourly_max,
|
||||
'minimum': hourly_min,
|
||||
'trend': 'up' if time_series_data[-1]['value'] > time_series_data[0]['value'] else 'down'
|
||||
}
|
||||
|
||||
# Validate aggregation
|
||||
assert aggregated_data['period'] == 'hourly'
|
||||
assert aggregated_data['data_points'] == 5
|
||||
assert aggregated_data['average'] == 110.0
|
||||
assert aggregated_data['maximum'] == 120
|
||||
assert aggregated_data['minimum'] == 100
|
||||
assert aggregated_data['trend'] == 'up'
|
||||
|
||||
def test_anomaly_detection(self):
|
||||
"""Test anomaly detection in metrics"""
|
||||
# Sample metrics with anomalies
|
||||
metrics_data = [
|
||||
{'timestamp': '2024-01-01T00:00:00Z', 'response_time': 100},
|
||||
{'timestamp': '2024-01-01T01:00:00Z', 'response_time': 105},
|
||||
{'timestamp': '2024-01-01T02:00:00Z', 'response_time': 98},
|
||||
{'timestamp': '2024-01-01T03:00:00Z', 'response_time': 500}, # Anomaly
|
||||
{'timestamp': '2024-01-01T04:00:00Z', 'response_time': 102},
|
||||
{'timestamp': '2024-01-01T05:00:00Z', 'response_time': 95},
|
||||
]
|
||||
|
||||
# Calculate statistics for anomaly detection
|
||||
response_times = [d['response_time'] for d in metrics_data]
|
||||
mean_time = statistics.mean(response_times)
|
||||
stdev_time = statistics.stdev(response_times) if len(response_times) > 1 else 0
|
||||
|
||||
# Detect anomalies (values > 2 standard deviations from mean)
|
||||
threshold = mean_time + (2 * stdev_time)
|
||||
anomalies = [
|
||||
d for d in metrics_data
|
||||
if d['response_time'] > threshold
|
||||
]
|
||||
|
||||
# Validate anomaly detection
|
||||
assert len(anomalies) == 1
|
||||
assert anomalies[0]['response_time'] == 500
|
||||
assert anomalies[0]['timestamp'] == '2024-01-01T03:00:00Z'
|
||||
|
||||
def test_forecasting_model(self):
|
||||
"""Test simple forecasting model"""
|
||||
# Historical data for forecasting
|
||||
historical_data = [
|
||||
{'period': '2024-01-01', 'demand': 100},
|
||||
{'period': '2024-01-02', 'demand': 110},
|
||||
{'period': '2024-01-03', 'demand': 105},
|
||||
{'period': '2024-01-04', 'demand': 120},
|
||||
{'period': '2024-01-05', 'demand': 115},
|
||||
]
|
||||
|
||||
# Simple moving average forecast
|
||||
demand_values = [d['demand'] for d in historical_data]
|
||||
forecast_period = 3
|
||||
forecast = statistics.mean(demand_values[-forecast_period:])
|
||||
|
||||
# Calculate forecast accuracy (using last known value as "actual")
|
||||
last_actual = demand_values[-1]
|
||||
forecast_error = abs(forecast - last_actual)
|
||||
forecast_accuracy = max(0, 100 - (forecast_error / last_actual * 100))
|
||||
|
||||
# Validate forecast
|
||||
assert forecast > 0
|
||||
assert forecast_accuracy >= 0
|
||||
assert forecast_accuracy <= 100
|
||||
|
||||
|
||||
class TestDashboardManager:
|
||||
"""Test dashboard management functionality"""
|
||||
|
||||
def test_dashboard_configuration(self):
|
||||
"""Test dashboard configuration management"""
|
||||
# Sample dashboard configuration
|
||||
dashboard_config = {
|
||||
'dashboard_id': 'marketplace_overview',
|
||||
'title': 'Marketplace Overview',
|
||||
'layout': 'grid',
|
||||
'widgets': [
|
||||
{
|
||||
'id': 'market_metrics',
|
||||
'type': 'metric_card',
|
||||
'title': 'Market Metrics',
|
||||
'position': {'x': 0, 'y': 0, 'w': 4, 'h': 2},
|
||||
'data_source': 'market_metrics_api'
|
||||
},
|
||||
{
|
||||
'id': 'price_chart',
|
||||
'type': 'line_chart',
|
||||
'title': 'Price Trends',
|
||||
'position': {'x': 4, 'y': 0, 'w': 8, 'h': 4},
|
||||
'data_source': 'price_history_api'
|
||||
},
|
||||
{
|
||||
'id': 'provider_ranking',
|
||||
'type': 'table',
|
||||
'title': 'Top Providers',
|
||||
'position': {'x': 0, 'y': 2, 'w': 6, 'h': 3},
|
||||
'data_source': 'provider_ranking_api'
|
||||
}
|
||||
],
|
||||
'refresh_interval': 300, # 5 minutes
|
||||
'permissions': ['read', 'write']
|
||||
}
|
||||
|
||||
# Validate configuration
|
||||
assert dashboard_config['dashboard_id'] == 'marketplace_overview'
|
||||
assert len(dashboard_config['widgets']) == 3
|
||||
assert dashboard_config['refresh_interval'] == 300
|
||||
assert 'read' in dashboard_config['permissions']
|
||||
|
||||
# Validate widgets
|
||||
for widget in dashboard_config['widgets']:
|
||||
assert 'id' in widget
|
||||
assert 'type' in widget
|
||||
assert 'title' in widget
|
||||
assert 'position' in widget
|
||||
assert 'data_source' in widget
|
||||
|
||||
def test_widget_data_processing(self):
|
||||
"""Test widget data processing"""
|
||||
# Sample data for different widget types
|
||||
widget_data = {
|
||||
'metric_card': {
|
||||
'value': 1250,
|
||||
'change': 5.2,
|
||||
'change_type': 'increase',
|
||||
'unit': 'AITBC',
|
||||
'timestamp': datetime.utcnow().isoformat()
|
||||
},
|
||||
'line_chart': {
|
||||
'labels': ['Jan', 'Feb', 'Mar', 'Apr', 'May'],
|
||||
'datasets': [
|
||||
{
|
||||
'label': 'RTX 3080',
|
||||
'data': [0.10, 0.11, 0.12, 0.11, 0.13],
|
||||
'borderColor': '#007bff'
|
||||
},
|
||||
{
|
||||
'label': 'RTX 3090',
|
||||
'data': [0.15, 0.14, 0.16, 0.15, 0.17],
|
||||
'borderColor': '#28a745'
|
||||
}
|
||||
]
|
||||
},
|
||||
'table': {
|
||||
'columns': ['provider', 'jobs_completed', 'avg_rating', 'success_rate'],
|
||||
'rows': [
|
||||
['provider_1', 45, 4.8, '90%'],
|
||||
['provider_2', 28, 4.6, '93%'],
|
||||
['provider_3', 35, 4.9, '88%']
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
# Validate metric card data
|
||||
metric_data = widget_data['metric_card']
|
||||
assert isinstance(metric_data['value'], (int, float))
|
||||
assert isinstance(metric_data['change'], (int, float))
|
||||
assert metric_data['change_type'] in ['increase', 'decrease']
|
||||
assert 'timestamp' in metric_data
|
||||
|
||||
# Validate line chart data
|
||||
chart_data = widget_data['line_chart']
|
||||
assert 'labels' in chart_data
|
||||
assert 'datasets' in chart_data
|
||||
assert len(chart_data['datasets']) == 2
|
||||
assert len(chart_data['labels']) == len(chart_data['datasets'][0]['data'])
|
||||
|
||||
# Validate table data
|
||||
table_data = widget_data['table']
|
||||
assert 'columns' in table_data
|
||||
assert 'rows' in table_data
|
||||
assert len(table_data['columns']) == 4
|
||||
assert len(table_data['rows']) == 3
|
||||
|
||||
def test_dashboard_permissions(self):
|
||||
"""Test dashboard permission management"""
|
||||
# Sample user permissions
|
||||
user_permissions = {
|
||||
'admin': ['read', 'write', 'delete', 'share'],
|
||||
'analyst': ['read', 'write', 'share'],
|
||||
'viewer': ['read'],
|
||||
'guest': []
|
||||
}
|
||||
|
||||
# Sample dashboard access rules
|
||||
dashboard_access = {
|
||||
'marketplace_overview': ['admin', 'analyst', 'viewer'],
|
||||
'system_metrics': ['admin'],
|
||||
'public_stats': ['admin', 'analyst', 'viewer', 'guest']
|
||||
}
|
||||
|
||||
# Test permission checking
|
||||
def check_permission(user_role, dashboard_id, action):
|
||||
if action not in user_permissions[user_role]:
|
||||
return False
|
||||
if user_role not in dashboard_access[dashboard_id]:
|
||||
return False
|
||||
return True
|
||||
|
||||
# Validate permissions
|
||||
assert check_permission('admin', 'marketplace_overview', 'read') is True
|
||||
assert check_permission('admin', 'system_metrics', 'write') is True
|
||||
assert check_permission('viewer', 'system_metrics', 'read') is False
|
||||
assert check_permission('guest', 'public_stats', 'read') is True
|
||||
assert check_permission('analyst', 'marketplace_overview', 'delete') is False
|
||||
|
||||
|
||||
class TestReportingSystem:
|
||||
"""Test reporting system functionality"""
|
||||
|
||||
def test_report_generation(self):
|
||||
"""Test report generation capabilities"""
|
||||
# Sample report data
|
||||
report_data = {
|
||||
'report_id': 'monthly_marketplace_report',
|
||||
'title': 'Monthly Marketplace Performance',
|
||||
'period': {
|
||||
'start': '2024-01-01',
|
||||
'end': '2024-01-31'
|
||||
},
|
||||
'sections': [
|
||||
{
|
||||
'title': 'Executive Summary',
|
||||
'content': {
|
||||
'total_transactions': 1250,
|
||||
'total_volume': 156.78,
|
||||
'active_providers': 45,
|
||||
'satisfaction_rate': 4.7
|
||||
}
|
||||
},
|
||||
{
|
||||
'title': 'Price Analysis',
|
||||
'content': {
|
||||
'avg_gpu_price': 0.12,
|
||||
'price_trend': 'stable',
|
||||
'volatility_index': 0.05
|
||||
}
|
||||
}
|
||||
],
|
||||
'generated_at': datetime.utcnow().isoformat(),
|
||||
'format': 'json'
|
||||
}
|
||||
|
||||
# Validate report structure
|
||||
assert 'report_id' in report_data
|
||||
assert 'title' in report_data
|
||||
assert 'period' in report_data
|
||||
assert 'sections' in report_data
|
||||
assert 'generated_at' in report_data
|
||||
|
||||
# Validate sections
|
||||
for section in report_data['sections']:
|
||||
assert 'title' in section
|
||||
assert 'content' in section
|
||||
|
||||
# Validate data integrity
|
||||
summary = report_data['sections'][0]['content']
|
||||
assert summary['total_transactions'] > 0
|
||||
assert summary['total_volume'] > 0
|
||||
assert summary['active_providers'] > 0
|
||||
assert 0 <= summary['satisfaction_rate'] <= 5
|
||||
|
||||
def test_report_export(self):
|
||||
"""Test report export functionality"""
|
||||
# Sample report for export
|
||||
report = {
|
||||
'title': 'Marketplace Analysis',
|
||||
'data': {
|
||||
'metrics': {'transactions': 100, 'volume': 50.5},
|
||||
'trends': {'price': 'up', 'demand': 'stable'}
|
||||
},
|
||||
'metadata': {
|
||||
'generated_by': 'analytics_system',
|
||||
'generated_at': datetime.utcnow().isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
# Test JSON export
|
||||
json_export = json.dumps(report, indent=2)
|
||||
assert isinstance(json_export, str)
|
||||
assert 'Marketplace Analysis' in json_export
|
||||
|
||||
# Test CSV export (simplified)
|
||||
csv_data = "Metric,Value\n"
|
||||
csv_data += f"Transactions,{report['data']['metrics']['transactions']}\n"
|
||||
csv_data += f"Volume,{report['data']['metrics']['volume']}\n"
|
||||
|
||||
assert 'Transactions,100' in csv_data
|
||||
assert 'Volume,50.5' in csv_data
|
||||
assert csv_data.count('\n') == 3 # Header + 2 data rows
|
||||
|
||||
def test_report_scheduling(self):
|
||||
"""Test report scheduling functionality"""
|
||||
# Sample schedule configuration
|
||||
schedule_config = {
|
||||
'report_id': 'daily_marketplace_summary',
|
||||
'frequency': 'daily',
|
||||
'time': '08:00',
|
||||
'recipients': ['admin@aitbc.com', 'ops@aitbc.com'],
|
||||
'format': 'pdf',
|
||||
'enabled': True,
|
||||
'last_run': '2024-01-01T08:00:00Z',
|
||||
'next_run': '2024-01-02T08:00:00Z'
|
||||
}
|
||||
|
||||
# Validate schedule configuration
|
||||
assert schedule_config['frequency'] in ['daily', 'weekly', 'monthly']
|
||||
assert schedule_config['time'] == '08:00'
|
||||
assert len(schedule_config['recipients']) > 0
|
||||
assert schedule_config['enabled'] is True
|
||||
assert 'next_run' in schedule_config
|
||||
|
||||
# Test next run calculation
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
last_run = datetime.fromisoformat(schedule_config['last_run'].replace('Z', '+00:00'))
|
||||
next_run = datetime.fromisoformat(schedule_config['next_run'].replace('Z', '+00:00'))
|
||||
|
||||
expected_next_run = last_run + timedelta(days=1)
|
||||
assert next_run.date() == expected_next_run.date()
|
||||
assert next_run.hour == 8
|
||||
assert next_run.minute == 0
|
||||
|
||||
|
||||
class TestDataCollector:
|
||||
"""Test data collection functionality"""
|
||||
|
||||
@pytest.fixture
|
||||
def data_collector(self):
|
||||
return DataCollector()
|
||||
def test_data_collection_metrics(self):
|
||||
"""Test data collection metrics gathering"""
|
||||
# Sample data collection metrics
|
||||
collection_metrics = {
|
||||
'total_records_collected': 10000,
|
||||
'collection_duration_seconds': 300,
|
||||
'error_rate': 0.02, # 2%
|
||||
'data_sources': ['marketplace_api', 'blockchain_api', 'user_activity'],
|
||||
'last_collection': datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
# Validate metrics
|
||||
assert collection_metrics['total_records_collected'] > 0
|
||||
assert collection_metrics['collection_duration_seconds'] > 0
|
||||
assert 0 <= collection_metrics['error_rate'] <= 1
|
||||
assert len(collection_metrics['data_sources']) > 0
|
||||
assert 'last_collection' in collection_metrics
|
||||
|
||||
# Calculate collection rate
|
||||
collection_rate = collection_metrics['total_records_collected'] / collection_metrics['collection_duration_seconds']
|
||||
assert collection_rate > 10 # Should collect at least 10 records per second
|
||||
|
||||
def test_collect_transaction_volume(self, data_collector):
|
||||
"""Test transaction volume collection"""
|
||||
|
||||
@@ -6,5 +6,26 @@ then patches httpx.Client so every CLI command's HTTP call is routed
|
||||
through the ASGI transport instead of making real network requests.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
f
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
from click.testing import CliRunner
|
||||
from aitbc_cli.main import cli
|
||||
|
||||
|
||||
class TestCLIIntegration:
|
||||
"""Test CLI integration with coordinator"""
|
||||
|
||||
def test_cli_help(self):
|
||||
"""Test CLI help command"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['--help'])
|
||||
assert result.exit_code == 0
|
||||
assert 'aitbc' in result.output.lower()
|
||||
|
||||
def test_config_show(self):
|
||||
"""Test config show command"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['config-show'])
|
||||
assert result.exit_code == 0
|
||||
@@ -15,4 +15,56 @@ def runner():
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config():
|
||||
"""Mock configu
|
||||
"""Mock configuration for testing"""
|
||||
return {
|
||||
'coordinator_url': 'http://localhost:8000',
|
||||
'api_key': 'test-key',
|
||||
'wallet_name': 'test-wallet'
|
||||
}
|
||||
|
||||
|
||||
class TestMarketplaceCommands:
|
||||
"""Test suite for marketplace commands"""
|
||||
|
||||
def test_marketplace_help(self, runner):
|
||||
"""Test marketplace help command"""
|
||||
result = runner.invoke(cli, ['marketplace', '--help'])
|
||||
assert result.exit_code == 0
|
||||
assert 'marketplace' in result.output.lower()
|
||||
|
||||
def test_marketplace_list(self, runner, mock_config):
|
||||
"""Test marketplace listing command"""
|
||||
with patch('aitbc_cli.config.get_config') as mock_get_config:
|
||||
mock_get_config.return_value = mock_config
|
||||
with patch('httpx.Client.get') as mock_get:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
'offers': [
|
||||
{'id': 1, 'price': 0.1, 'gpu_type': 'RTX 3080'},
|
||||
{'id': 2, 'price': 0.15, 'gpu_type': 'RTX 3090'}
|
||||
]
|
||||
}
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = runner.invoke(cli, ['marketplace', 'offers', 'list'])
|
||||
assert result.exit_code == 0
|
||||
assert 'offers' in result.output.lower() or 'gpu' in result.output.lower()
|
||||
|
||||
def test_marketplace_gpu_pricing(self, runner, mock_config):
|
||||
"""Test marketplace GPU pricing command"""
|
||||
with patch('aitbc_cli.config.get_config') as mock_get_config:
|
||||
mock_get_config.return_value = mock_config
|
||||
with patch('httpx.Client.get') as mock_get:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
'gpu_model': 'RTX 3080',
|
||||
'avg_price': 0.12,
|
||||
'price_range': {'min': 0.08, 'max': 0.15}
|
||||
}
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = runner.invoke(cli, ['marketplace', 'pricing', 'RTX 3080'])
|
||||
assert result.exit_code == 0
|
||||
assert 'price' in result.output.lower() or 'rtx' in result.output.lower()
|
||||
@@ -5,14 +5,25 @@ import json
|
||||
import base64
|
||||
from unittest.mock import Mock, patch
|
||||
from click.testing import CliRunner
|
||||
from aitbc_cli.commands.marketplace_advanced import advanced, models, analytics, trading, dispute
|
||||
from aitbc_cli.main import cli
|
||||
|
||||
|
||||
class TestModelsCommands:
|
||||
"""Test advanced model NFT operations commands"""
|
||||
class TestMarketplaceAdvanced:
|
||||
"""Test advanced marketplace commands"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Setup test environment"""
|
||||
def test_marketplace_help(self):
|
||||
"""Test marketplace help command"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['marketplace', '--help'])
|
||||
assert result.exit_code == 0
|
||||
assert 'marketplace' in result.output.lower()
|
||||
|
||||
def test_marketplace_agents_help(self):
|
||||
"""Test marketplace agents help command"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['marketplace', 'agents', '--help'])
|
||||
assert result.exit_code == 0
|
||||
assert 'agents' in result.output.lower()
|
||||
self.runner = CliRunner()
|
||||
self.config = {
|
||||
'coordinator_url': 'http://test:8000',
|
||||
|
||||
@@ -100,7 +100,7 @@ class TestSimulateCommands:
|
||||
with patch('aitbc_cli.commands.simulate.Path') as mock_path_class:
|
||||
# Make Path return our temp directory
|
||||
mock_path_class.return_value = home_dir
|
||||
mock_path_class.side_effect = lambda x: home_dir if x == "/home/oib/windsurf/aitbc/home" else Path(x)
|
||||
mock_path_class.side_effect = lambda x: home_dir if x == "/home/oib/windsurf/aitbc/tests/e2e/fixtures/home" else Path(x)
|
||||
|
||||
# Run command
|
||||
result = runner.invoke(simulate, [
|
||||
@@ -129,7 +129,7 @@ class TestSimulateCommands:
|
||||
# Patch the hardcoded path
|
||||
with patch('aitbc_cli.commands.simulate.Path') as mock_path_class:
|
||||
mock_path_class.return_value = home_dir
|
||||
mock_path_class.side_effect = lambda x: home_dir if x == "/home/oib/windsurf/aitbc/home" else Path(x)
|
||||
mock_path_class.side_effect = lambda x: home_dir if x == "/home/oib/windsurf/aitbc/tests/e2e/fixtures/home" else Path(x)
|
||||
|
||||
# Run command
|
||||
result = runner.invoke(simulate, [
|
||||
@@ -151,7 +151,7 @@ class TestSimulateCommands:
|
||||
# Patch the hardcoded path
|
||||
with patch('aitbc_cli.commands.simulate.Path') as mock_path_class:
|
||||
mock_path_class.return_value = home_dir
|
||||
mock_path_class.side_effect = lambda x: home_dir if x == "/home/oib/windsurf/aitbc/home" else Path(x)
|
||||
mock_path_class.side_effect = lambda x: home_dir if x == "/home/oib/windsurf/aitbc/tests/e2e/fixtures/home" else Path(x)
|
||||
|
||||
# Run command
|
||||
result = runner.invoke(simulate, [
|
||||
@@ -182,7 +182,7 @@ class TestSimulateCommands:
|
||||
# Patch the hardcoded path
|
||||
with patch('aitbc_cli.commands.simulate.Path') as mock_path_class:
|
||||
mock_path_class.return_value = home_dir
|
||||
mock_path_class.side_effect = lambda x: home_dir if x == "/home/oib/windsurf/aitbc/home" else Path(x)
|
||||
mock_path_class.side_effect = lambda x: home_dir if x == "/home/oib/windsurf/aitbc/tests/e2e/fixtures/home" else Path(x)
|
||||
|
||||
# Run command
|
||||
result = runner.invoke(simulate, [
|
||||
@@ -210,7 +210,7 @@ class TestSimulateCommands:
|
||||
# Patch the hardcoded path
|
||||
with patch('aitbc_cli.commands.simulate.Path') as mock_path_class:
|
||||
mock_path_class.return_value = home_dir
|
||||
mock_path_class.side_effect = lambda x: home_dir if x == "/home/oib/windsurf/aitbc/home" else Path(x)
|
||||
mock_path_class.side_effect = lambda x: home_dir if x == "/home/oib/windsurf/aitbc/tests/e2e/fixtures/home" else Path(x)
|
||||
|
||||
# Run command
|
||||
result = runner.invoke(simulate, [
|
||||
@@ -238,7 +238,7 @@ class TestSimulateCommands:
|
||||
# Patch the hardcoded path
|
||||
with patch('aitbc_cli.commands.simulate.Path') as mock_path_class:
|
||||
mock_path_class.return_value = home_dir
|
||||
mock_path_class.side_effect = lambda x: home_dir if x == "/home/oib/windsurf/aitbc/home" else Path(x)
|
||||
mock_path_class.side_effect = lambda x: home_dir if x == "/home/oib/windsurf/aitbc/tests/e2e/fixtures/home" else Path(x)
|
||||
|
||||
# Run command
|
||||
result = runner.invoke(simulate, [
|
||||
@@ -347,7 +347,7 @@ class TestSimulateCommands:
|
||||
# Patch the hardcoded path
|
||||
with patch('aitbc_cli.commands.simulate.Path') as mock_path_class:
|
||||
mock_path_class.return_value = home_dir
|
||||
mock_path_class.side_effect = lambda x: home_dir if x == "/home/oib/windsurf/aitbc/home" else Path(x)
|
||||
mock_path_class.side_effect = lambda x: home_dir if x == "/home/oib/windsurf/aitbc/tests/e2e/fixtures/home" else Path(x)
|
||||
|
||||
# Run command with reset flag
|
||||
result = runner.invoke(simulate, [
|
||||
|
||||
@@ -12,4 +12,66 @@ from aitbc_cli.main import cli
|
||||
|
||||
|
||||
def extract_json_from_output(output):
|
||||
"""Extract JSON from CLI output"""
|
||||
try:
|
||||
# Look for JSON blocks in output
|
||||
json_match = re.search(r'\{.*\}', output, re.DOTALL)
|
||||
if json_match:
|
||||
return json.loads(json_match.group())
|
||||
return None
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
|
||||
class TestWalletCommands:
|
||||
"""Test suite for wallet commands"""
|
||||
|
||||
def test_wallet_help(self):
|
||||
"""Test wallet help command"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['wallet', '--help'])
|
||||
assert result.exit_code == 0
|
||||
assert 'wallet' in result.output.lower()
|
||||
|
||||
def test_wallet_create(self):
|
||||
"""Test wallet creation"""
|
||||
runner = CliRunner()
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Set wallet directory in environment
|
||||
env = {'WALLET_DIR': temp_dir}
|
||||
# Use unique wallet name with timestamp
|
||||
import time
|
||||
wallet_name = f"test-wallet-{int(time.time())}"
|
||||
result = runner.invoke(cli, ['wallet', 'create', wallet_name], env=env)
|
||||
print(f"Exit code: {result.exit_code}")
|
||||
print(f"Output: {result.output}")
|
||||
print(f"Temp dir contents: {list(Path(temp_dir).iterdir())}")
|
||||
assert result.exit_code == 0
|
||||
# Check if wallet was created successfully
|
||||
assert 'created' in result.output.lower() or 'wallet' in result.output.lower()
|
||||
|
||||
def test_wallet_balance(self):
|
||||
"""Test wallet balance command"""
|
||||
runner = CliRunner()
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Set wallet directory in environment
|
||||
env = {'WALLET_DIR': temp_dir}
|
||||
# Use unique wallet name
|
||||
import time
|
||||
wallet_name = f"test-wallet-balance-{int(time.time())}"
|
||||
# Create wallet first
|
||||
create_result = runner.invoke(cli, ['wallet', 'create', wallet_name], env=env)
|
||||
assert create_result.exit_code == 0
|
||||
|
||||
# Switch to the created wallet
|
||||
switch_result = runner.invoke(cli, ['wallet', 'switch', wallet_name], env=env)
|
||||
assert switch_result.exit_code == 0
|
||||
|
||||
# Check balance (uses current active wallet)
|
||||
result = runner.invoke(cli, ['wallet', 'balance'], env=env)
|
||||
print(f"Balance exit code: {result.exit_code}")
|
||||
print(f"Balance output: {result.output}")
|
||||
assert result.exit_code == 0
|
||||
# Should contain balance information
|
||||
assert 'balance' in result.output.lower() or 'aitbc' in result.output.lower()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
Enhanced conftest for pytest with AITBC CLI support
|
||||
Enhanced conftest for pytest with AITBC CLI support and comprehensive test coverage
|
||||
"""
|
||||
|
||||
import pytest
|
||||
@@ -16,14 +16,41 @@ sys.path.insert(0, str(project_root))
|
||||
# Add CLI path
|
||||
sys.path.insert(0, str(project_root / "cli"))
|
||||
|
||||
# Add necessary source paths
|
||||
sys.path.insert(0, str(project_root / "packages" / "py" / "aitbc-core" / "src"))
|
||||
sys.path.insert(0, str(project_root / "packages" / "py" / "aitbc-crypto" / "src"))
|
||||
sys.path.insert(0, str(project_root / "packages" / "py" / "aitbc-p2p" / "src"))
|
||||
sys.path.insert(0, str(project_root / "packages" / "py" / "aitbc-sdk" / "src"))
|
||||
sys.path.insert(0, str(project_root / "apps" / "coordinator-api" / "src"))
|
||||
sys.path.insert(0, str(project_root / "apps" / "wallet-daemon" / "src"))
|
||||
sys.path.insert(0, str(project_root / "apps" / "blockchain-node" / "src"))
|
||||
# Add all source paths for comprehensive testing
|
||||
source_paths = [
|
||||
"packages/py/aitbc-core/src",
|
||||
"packages/py/aitbc-crypto/src",
|
||||
"packages/py/aitbc-p2p/src",
|
||||
"packages/py/aitbc-sdk/src",
|
||||
"apps/coordinator-api/src",
|
||||
"apps/wallet-daemon/src",
|
||||
"apps/blockchain-node/src",
|
||||
"apps/pool-hub/src",
|
||||
"apps/explorer-web/src",
|
||||
"apps/zk-circuits/src"
|
||||
]
|
||||
|
||||
for path in source_paths:
|
||||
full_path = project_root / path
|
||||
if full_path.exists():
|
||||
sys.path.insert(0, str(full_path))
|
||||
|
||||
# Add test paths for imports
|
||||
test_paths = [
|
||||
"packages/py/aitbc-crypto/tests",
|
||||
"packages/py/aitbc-sdk/tests",
|
||||
"apps/coordinator-api/tests",
|
||||
"apps/wallet-daemon/tests",
|
||||
"apps/blockchain-node/tests",
|
||||
"apps/pool-hub/tests",
|
||||
"apps/explorer-web/tests",
|
||||
"cli/tests"
|
||||
]
|
||||
|
||||
for path in test_paths:
|
||||
full_path = project_root / path
|
||||
if full_path.exists():
|
||||
sys.path.insert(0, str(full_path))
|
||||
|
||||
# Set up test environment
|
||||
os.environ["TEST_MODE"] = "true"
|
||||
@@ -49,6 +76,75 @@ sys.modules['aitbc_crypto'].encrypt_data = mock_encrypt_data
|
||||
sys.modules['aitbc_crypto'].decrypt_data = mock_decrypt_data
|
||||
sys.modules['aitbc_crypto'].generate_viewing_key = mock_generate_viewing_key
|
||||
|
||||
# Common fixtures for all test types
|
||||
@pytest.fixture
|
||||
def cli_runner():
|
||||
"""Create CLI runner for testing"""
|
||||
return CliRunner()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config():
|
||||
"""Mock configuration for testing"""
|
||||
return {
|
||||
'coordinator_url': 'http://localhost:8000',
|
||||
'api_key': 'test-key',
|
||||
'wallet_name': 'test-wallet',
|
||||
'blockchain_url': 'http://localhost:8082'
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir():
|
||||
"""Create temporary directory for tests"""
|
||||
import tempfile
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
yield Path(tmpdir)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_http_client():
|
||||
"""Mock HTTP client for API testing"""
|
||||
mock_client = Mock()
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"status": "ok"}
|
||||
mock_client.get.return_value = mock_response
|
||||
mock_client.post.return_value = mock_response
|
||||
mock_client.put.return_value = mock_response
|
||||
mock_client.delete.return_value = mock_response
|
||||
return mock_client
|
||||
|
||||
# Test markers for different test types
|
||||
def pytest_configure(config):
|
||||
"""Configure pytest markers"""
|
||||
config.addinivalue_line("markers", "unit: Unit tests (fast, isolated)")
|
||||
config.addinivalue_line("markers", "integration: Integration tests (may require external services)")
|
||||
config.addinivalue_line("markers", "slow: Slow running tests")
|
||||
config.addinivalue_line("markers", "cli: CLI command tests")
|
||||
config.addinivalue_line("markers", "api: API endpoint tests")
|
||||
config.addinivalue_line("markers", "blockchain: Blockchain-related tests")
|
||||
config.addinivalue_line("markers", "crypto: Cryptography tests")
|
||||
config.addinivalue_line("markers", "contracts: Smart contract tests")
|
||||
|
||||
# Pytest collection hooks
|
||||
def pytest_collection_modifyitems(config, items):
|
||||
"""Modify test collection to add markers based on file location"""
|
||||
for item in items:
|
||||
# Add markers based on file path
|
||||
if "cli/tests" in str(item.fspath):
|
||||
item.add_marker(pytest.mark.cli)
|
||||
elif "apps/coordinator-api/tests" in str(item.fspath):
|
||||
item.add_marker(pytest.mark.api)
|
||||
elif "apps/blockchain-node/tests" in str(item.fspath):
|
||||
item.add_marker(pytest.mark.blockchain)
|
||||
elif "packages/py/aitbc-crypto/tests" in str(item.fspath):
|
||||
item.add_marker(pytest.mark.crypto)
|
||||
elif "contracts/test" in str(item.fspath):
|
||||
item.add_marker(pytest.mark.contracts)
|
||||
|
||||
# Add slow marker for integration tests
|
||||
if "integration" in str(item.fspath).lower():
|
||||
item.add_marker(pytest.mark.integration)
|
||||
item.add_marker(pytest.mark.slow)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def aitbc_cli_runner():
|
||||
|
||||
@@ -2,6 +2,21 @@
|
||||
|
||||
This directory contains comprehensive end-to-end tests for the AITBC enhanced services, validating complete workflows, performance benchmarks, and system integration.
|
||||
|
||||
## 📁 Directory Structure
|
||||
|
||||
```
|
||||
tests/e2e/
|
||||
├── fixtures/ # Test fixtures and mock data
|
||||
│ ├── home/ # Mock agent home directories
|
||||
│ │ ├── client1/ # Client agent home
|
||||
│ │ └── miner1/ # Miner agent home
|
||||
│ └── __init__.py # Fixture utilities and classes
|
||||
├── conftest.py # Pytest configuration
|
||||
├── conftest_fixtures.py # Extended fixture configuration
|
||||
├── test_*.py # Individual test files
|
||||
└── README.md # This file
|
||||
```
|
||||
|
||||
## 🎯 Test Coverage
|
||||
|
||||
### Test Suites
|
||||
@@ -22,6 +37,49 @@ This directory contains comprehensive end-to-end tests for the AITBC enhanced se
|
||||
- **Marketplace Performance**: Transaction processing, royalty calculation times
|
||||
- **Concurrent Performance**: Load testing with multiple concurrent requests
|
||||
|
||||
## 🔧 Test Fixtures
|
||||
|
||||
### Home Directory Fixtures
|
||||
|
||||
The `tests/e2e/fixtures/home/` directory contains mock home directories for testing agent scenarios:
|
||||
|
||||
```python
|
||||
# Using fixture home directories
|
||||
def test_agent_workflow(test_home_dirs):
|
||||
client_home = test_home_dirs / "client1"
|
||||
miner_home = test_home_dirs / "miner1"
|
||||
|
||||
# Test agent operations using mock home directories
|
||||
```
|
||||
|
||||
### Available Fixtures
|
||||
|
||||
- **`test_home_dirs`**: Access to fixture home directories
|
||||
- **`temp_home_dirs`**: Temporary home directories for isolated testing
|
||||
- **`home_dir_fixture`**: Manager for creating custom home directory setups
|
||||
- **`standard_test_agents`**: Pre-configured test agents (client1, client2, miner1, miner2, agent1, agent2)
|
||||
- **`cross_container_test_setup`**: Agents configured for cross-container testing
|
||||
|
||||
### Fixture Usage Examples
|
||||
|
||||
```python
|
||||
def test_with_standard_agents(standard_test_agents):
|
||||
"""Test using pre-configured agents"""
|
||||
client1_home = standard_test_agents["client1"]
|
||||
miner1_home = standard_test_agents["miner1"]
|
||||
|
||||
# Test logic here
|
||||
|
||||
def test_custom_agent_setup(home_dir_fixture):
|
||||
"""Test with custom agent configuration"""
|
||||
agents = home_dir_fixture.create_multi_agent_setup([
|
||||
{"name": "custom_client", "type": "client", "initial_balance": 5000},
|
||||
{"name": "custom_miner", "type": "miner", "initial_balance": 10000}
|
||||
])
|
||||
|
||||
# Test logic here
|
||||
```
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
### Prerequisites
|
||||
|
||||
316
tests/e2e/conftest_fixtures.py
Normal file
316
tests/e2e/conftest_fixtures.py
Normal file
@@ -0,0 +1,316 @@
|
||||
"""
|
||||
E2E Test Fixtures Configuration
|
||||
|
||||
Extended pytest configuration for home directory fixtures
|
||||
and test data management for end-to-end testing.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
import json
|
||||
import yaml
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def fixture_base_path():
|
||||
"""Base path for all test fixtures"""
|
||||
return Path(__file__).parent / "fixtures"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def test_home_dirs(fixture_base_path):
|
||||
"""Access to test home directories"""
|
||||
home_path = fixture_base_path / "home"
|
||||
|
||||
if not home_path.exists():
|
||||
pytest.skip("Test home directories not found")
|
||||
|
||||
return home_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_home_dirs():
|
||||
"""Create temporary home directories for testing"""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_path = Path(temp_dir)
|
||||
|
||||
# Create standard AITBC home structure
|
||||
agents = {}
|
||||
|
||||
for agent_name in ["test_client", "test_miner", "test_agent"]:
|
||||
agent_path = base_path / agent_name
|
||||
agent_path.mkdir(exist_ok=True)
|
||||
|
||||
# Create AITBC directory structure
|
||||
aitbc_dir = agent_path / ".aitbc"
|
||||
aitbc_dir.mkdir(exist_ok=True)
|
||||
|
||||
(aitbc_dir / "wallets").mkdir(exist_ok=True)
|
||||
(aitbc_dir / "config").mkdir(exist_ok=True)
|
||||
(aitbc_dir / "cache").mkdir(exist_ok=True)
|
||||
|
||||
# Create default configuration
|
||||
config_data = {
|
||||
"agent": {
|
||||
"name": agent_name,
|
||||
"type": "client" if "client" in agent_name else "miner" if "miner" in agent_name else "agent",
|
||||
"wallet_path": f"~/.aitbc/wallets/{agent_name}_wallet.json"
|
||||
},
|
||||
"node": {
|
||||
"endpoint": "http://localhost:8082",
|
||||
"timeout": 30
|
||||
},
|
||||
"coordinator": {
|
||||
"url": "http://localhost:8000",
|
||||
"api_key": None
|
||||
}
|
||||
}
|
||||
|
||||
config_file = aitbc_dir / "config.yaml"
|
||||
with open(config_file, 'w') as f:
|
||||
yaml.dump(config_data, f, default_flow_style=False)
|
||||
|
||||
agents[agent_name] = agent_path
|
||||
|
||||
yield agents
|
||||
|
||||
# Cleanup is handled by tempfile
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent_wallet(temp_home_dirs):
|
||||
"""Create a mock agent wallet for testing"""
|
||||
agent_path = temp_home_dirs["test_client"]
|
||||
wallet_path = agent_path / ".aitbc" / "wallets" / "test_client_wallet.json"
|
||||
|
||||
wallet_data = {
|
||||
"address": "aitbc1testclient",
|
||||
"balance": 1000,
|
||||
"transactions": [],
|
||||
"created_at": "2026-03-03T00:00:00Z"
|
||||
}
|
||||
|
||||
with open(wallet_path, 'w') as f:
|
||||
json.dump(wallet_data, f, indent=2)
|
||||
|
||||
return wallet_data
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_miner_wallet(temp_home_dirs):
|
||||
"""Create a mock miner wallet for testing"""
|
||||
agent_path = temp_home_dirs["test_miner"]
|
||||
wallet_path = agent_path / ".aitbc" / "wallets" / "test_miner_wallet.json"
|
||||
|
||||
wallet_data = {
|
||||
"address": "aitbc1testminer",
|
||||
"balance": 5000,
|
||||
"transactions": [],
|
||||
"created_at": "2026-03-03T00:00:00Z",
|
||||
"mining_rewards": 2000
|
||||
}
|
||||
|
||||
with open(wallet_path, 'w') as f:
|
||||
json.dump(wallet_data, f, indent=2)
|
||||
|
||||
return wallet_data
|
||||
|
||||
|
||||
class HomeDirFixture:
|
||||
"""Helper class for managing home directory fixtures"""
|
||||
|
||||
def __init__(self, base_path: Path):
|
||||
self.base_path = base_path
|
||||
self.created_dirs: List[Path] = []
|
||||
|
||||
def create_agent_home(self,
|
||||
agent_name: str,
|
||||
agent_type: str = "agent",
|
||||
initial_balance: int = 0) -> Path:
|
||||
"""Create a new agent home directory with AITBC structure"""
|
||||
agent_path = self.base_path / agent_name
|
||||
agent_path.mkdir(exist_ok=True)
|
||||
|
||||
# Create AITBC directory structure
|
||||
aitbc_dir = agent_path / ".aitbc"
|
||||
aitbc_dir.mkdir(exist_ok=True)
|
||||
|
||||
(aitbc_dir / "wallets").mkdir(exist_ok=True)
|
||||
(aitbc_dir / "config").mkdir(exist_ok=True)
|
||||
(aitbc_dir / "cache").mkdir(exist_ok=True)
|
||||
|
||||
# Create configuration
|
||||
config_data = {
|
||||
"agent": {
|
||||
"name": agent_name,
|
||||
"type": agent_type,
|
||||
"wallet_path": f"~/.aitbc/wallets/{agent_name}_wallet.json"
|
||||
},
|
||||
"node": {
|
||||
"endpoint": "http://localhost:8082",
|
||||
"timeout": 30
|
||||
},
|
||||
"coordinator": {
|
||||
"url": "http://localhost:8000",
|
||||
"api_key": None
|
||||
}
|
||||
}
|
||||
|
||||
config_file = aitbc_dir / "config.yaml"
|
||||
with open(config_file, 'w') as f:
|
||||
yaml.dump(config_data, f, default_flow_style=False)
|
||||
|
||||
# Create wallet
|
||||
wallet_data = {
|
||||
"address": f"aitbc1{agent_name}",
|
||||
"balance": initial_balance,
|
||||
"transactions": [],
|
||||
"created_at": "2026-03-03T00:00:00Z"
|
||||
}
|
||||
|
||||
wallet_file = aitbc_dir / "wallets" / f"{agent_name}_wallet.json"
|
||||
with open(wallet_file, 'w') as f:
|
||||
json.dump(wallet_data, f, indent=2)
|
||||
|
||||
self.created_dirs.append(agent_path)
|
||||
return agent_path
|
||||
|
||||
def create_multi_agent_setup(self, agent_configs: List[Dict]) -> Dict[str, Path]:
|
||||
"""Create multiple agent homes from configuration"""
|
||||
agents = {}
|
||||
|
||||
for config in agent_configs:
|
||||
agent_path = self.create_agent_home(
|
||||
agent_name=config["name"],
|
||||
agent_type=config["type"],
|
||||
initial_balance=config.get("initial_balance", 0)
|
||||
)
|
||||
agents[config["name"]] = agent_path
|
||||
|
||||
return agents
|
||||
|
||||
def get_agent_config(self, agent_name: str) -> Optional[Dict]:
|
||||
"""Get configuration for an agent"""
|
||||
agent_path = self.base_path / agent_name
|
||||
config_file = agent_path / ".aitbc" / "config.yaml"
|
||||
|
||||
if config_file.exists():
|
||||
with open(config_file, 'r') as f:
|
||||
return yaml.safe_load(f)
|
||||
|
||||
return None
|
||||
|
||||
def get_agent_wallet(self, agent_name: str) -> Optional[Dict]:
|
||||
"""Get wallet data for an agent"""
|
||||
agent_path = self.base_path / agent_name
|
||||
wallet_file = agent_path / ".aitbc" / "wallets" / f"{agent_name}_wallet.json"
|
||||
|
||||
if wallet_file.exists():
|
||||
with open(wallet_file, 'r') as f:
|
||||
return json.load(f)
|
||||
|
||||
return None
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up all created directories"""
|
||||
for dir_path in self.created_dirs:
|
||||
if dir_path.exists():
|
||||
shutil.rmtree(dir_path)
|
||||
self.created_dirs.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def home_dir_fixture(tmp_path):
|
||||
"""Create a home directory fixture manager"""
|
||||
fixture = HomeDirFixture(tmp_path)
|
||||
yield fixture
|
||||
fixture.cleanup()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def standard_test_agents(home_dir_fixture):
|
||||
"""Create standard test agents for E2E testing"""
|
||||
agent_configs = [
|
||||
{"name": "client1", "type": "client", "initial_balance": 1000},
|
||||
{"name": "client2", "type": "client", "initial_balance": 500},
|
||||
{"name": "miner1", "type": "miner", "initial_balance": 2000},
|
||||
{"name": "miner2", "type": "miner", "initial_balance": 1500},
|
||||
{"name": "agent1", "type": "agent", "initial_balance": 800},
|
||||
{"name": "agent2", "type": "agent", "initial_balance": 1200}
|
||||
]
|
||||
|
||||
return home_dir_fixture.create_multi_agent_setup(agent_configs)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cross_container_test_setup(home_dir_fixture):
|
||||
"""Create test setup for cross-container E2E tests"""
|
||||
# Create agents for different containers/sites
|
||||
agent_configs = [
|
||||
{"name": "localhost_client", "type": "client", "initial_balance": 1000},
|
||||
{"name": "aitbc_client", "type": "client", "initial_balance": 2000},
|
||||
{"name": "aitbc1_client", "type": "client", "initial_balance": 1500},
|
||||
{"name": "localhost_miner", "type": "miner", "initial_balance": 3000},
|
||||
{"name": "aitbc_miner", "type": "miner", "initial_balance": 2500},
|
||||
{"name": "aitbc1_miner", "type": "miner", "initial_balance": 2800}
|
||||
]
|
||||
|
||||
return home_dir_fixture.create_multi_agent_setup(agent_configs)
|
||||
|
||||
|
||||
# Helper functions for test development
|
||||
def create_test_transaction(from_addr: str, to_addr: str, amount: int, tx_hash: str = None) -> Dict:
|
||||
"""Create a test transaction for wallet testing"""
|
||||
import hashlib
|
||||
|
||||
if tx_hash is None:
|
||||
tx_hash = hashlib.sha256(f"{from_addr}{to_addr}{amount}".encode()).hexdigest()
|
||||
|
||||
return {
|
||||
"hash": tx_hash,
|
||||
"from": from_addr,
|
||||
"to": to_addr,
|
||||
"amount": amount,
|
||||
"timestamp": "2026-03-03T12:00:00Z",
|
||||
"type": "transfer",
|
||||
"status": "confirmed"
|
||||
}
|
||||
|
||||
|
||||
def add_transaction_to_wallet(wallet_path: Path, transaction: Dict):
|
||||
"""Add a transaction to a wallet file"""
|
||||
with open(wallet_path, 'r') as f:
|
||||
wallet_data = json.load(f)
|
||||
|
||||
wallet_data["transactions"].append(transaction)
|
||||
|
||||
# Update balance for outgoing transactions
|
||||
if transaction["from"] == wallet_data["address"]:
|
||||
wallet_data["balance"] -= transaction["amount"]
|
||||
# Update balance for incoming transactions
|
||||
elif transaction["to"] == wallet_data["address"]:
|
||||
wallet_data["balance"] += transaction["amount"]
|
||||
|
||||
with open(wallet_path, 'w') as f:
|
||||
json.dump(wallet_data, f, indent=2)
|
||||
|
||||
|
||||
def verify_wallet_state(wallet_path: Path, expected_balance: int, min_transactions: int = 0) -> bool:
|
||||
"""Verify wallet state matches expectations"""
|
||||
with open(wallet_path, 'r') as f:
|
||||
wallet_data = json.load(f)
|
||||
|
||||
return (
|
||||
wallet_data["balance"] == expected_balance and
|
||||
len(wallet_data["transactions"]) >= min_transactions
|
||||
)
|
||||
|
||||
|
||||
# Pytest markers for categorizing E2E tests
|
||||
pytest.mark.e2e_home_dirs = pytest.mark.e2e_home_dirs("Tests that use home directory fixtures")
|
||||
pytest.mark.cross_container = pytest.mark.cross_container("Tests that span multiple containers")
|
||||
pytest.mark.agent_simulation = pytest.mark.agent_simulation("Tests that simulate agent behavior")
|
||||
pytest.mark.wallet_management = pytest.mark.wallet_management("Tests that focus on wallet operations")
|
||||
222
tests/e2e/fixtures/__init__.py
Normal file
222
tests/e2e/fixtures/__init__.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""
|
||||
E2E Test Fixtures
|
||||
|
||||
This package contains fixtures and test data for end-to-end testing,
|
||||
including mock home directories for agents and users.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_home_dir():
|
||||
"""Create a temporary mock home directory for testing"""
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
home_path = Path(temp_dir)
|
||||
|
||||
# Create standard AITBC home directory structure
|
||||
(home_path / ".aitbc").mkdir(exist_ok=True)
|
||||
(home_path / ".aitbc" / "wallets").mkdir(exist_ok=True)
|
||||
(home_path / ".aitbc" / "config").mkdir(exist_ok=True)
|
||||
(home_path / ".aitbc" / "cache").mkdir(exist_ok=True)
|
||||
|
||||
yield home_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent_home_dirs():
|
||||
"""Create mock agent home directories for testing"""
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_path = Path(temp_dir)
|
||||
|
||||
# Create agent home directories
|
||||
agents = {}
|
||||
for agent_name in ["client1", "miner1", "agent1", "agent2"]:
|
||||
agent_path = base_path / agent_name
|
||||
agent_path.mkdir(exist_ok=True)
|
||||
|
||||
# Create AITBC structure
|
||||
(agent_path / ".aitbc").mkdir(exist_ok=True)
|
||||
(agent_path / ".aitbc" / "wallets").mkdir(exist_ok=True)
|
||||
(agent_path / ".aitbc" / "config").mkdir(exist_ok=True)
|
||||
|
||||
# Create default config
|
||||
config_file = agent_path / ".aitbc" / "config.yaml"
|
||||
config_file.write_text(f"""
|
||||
agent:
|
||||
name: {agent_name}
|
||||
type: {"client" if "client" in agent_name else "miner" if "miner" in agent_name else "agent"}
|
||||
wallet_path: ~/.aitbc/wallets/{agent_name}_wallet.json
|
||||
|
||||
node:
|
||||
endpoint: http://localhost:8082
|
||||
timeout: 30
|
||||
|
||||
coordinator:
|
||||
url: http://localhost:8000
|
||||
api_key: null
|
||||
""")
|
||||
|
||||
agents[agent_name] = agent_path
|
||||
|
||||
yield agents
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fixture_home_dirs():
|
||||
"""Access to the actual fixture home directories"""
|
||||
fixture_path = Path(__file__).parent / "home"
|
||||
|
||||
if not fixture_path.exists():
|
||||
pytest.skip("Fixture home directories not found")
|
||||
|
||||
return fixture_path
|
||||
|
||||
|
||||
class HomeDirManager:
|
||||
"""Manager for test home directories"""
|
||||
|
||||
def __init__(self, base_path: Path):
|
||||
self.base_path = base_path
|
||||
self.created_dirs: List[Path] = []
|
||||
|
||||
def create_agent_home(self, agent_name: str, agent_type: str = "agent") -> Path:
|
||||
"""Create a new agent home directory"""
|
||||
agent_path = self.base_path / agent_name
|
||||
agent_path.mkdir(exist_ok=True)
|
||||
|
||||
# Create AITBC structure
|
||||
(agent_path / ".aitbc").mkdir(exist_ok=True)
|
||||
(agent_path / ".aitbc" / "wallets").mkdir(exist_ok=True)
|
||||
(agent_path / ".aitbc" / "config").mkdir(exist_ok=True)
|
||||
|
||||
# Create default config
|
||||
config_file = agent_path / ".aitbc" / "config.yaml"
|
||||
config_file.write_text(f"""
|
||||
agent:
|
||||
name: {agent_name}
|
||||
type: {agent_type}
|
||||
wallet_path: ~/.aitbc/wallets/{agent_name}_wallet.json
|
||||
|
||||
node:
|
||||
endpoint: http://localhost:8082
|
||||
timeout: 30
|
||||
|
||||
coordinator:
|
||||
url: http://localhost:8000
|
||||
api_key: null
|
||||
""")
|
||||
|
||||
self.created_dirs.append(agent_path)
|
||||
return agent_path
|
||||
|
||||
def create_wallet(self, agent_name: str, address: str, balance: int = 0) -> Path:
|
||||
"""Create a wallet file for an agent"""
|
||||
agent_path = self.base_path / agent_name
|
||||
wallet_path = agent_path / ".aitbc" / "wallets" / f"{agent_name}_wallet.json"
|
||||
|
||||
wallet_data = {
|
||||
"address": address,
|
||||
"balance": balance,
|
||||
"transactions": [],
|
||||
"created_at": "2026-03-03T00:00:00Z"
|
||||
}
|
||||
|
||||
import json
|
||||
wallet_path.write_text(json.dumps(wallet_data, indent=2))
|
||||
return wallet_path
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up created directories"""
|
||||
for dir_path in self.created_dirs:
|
||||
if dir_path.exists():
|
||||
import shutil
|
||||
shutil.rmtree(dir_path)
|
||||
self.created_dirs.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def home_dir_manager(tmp_path):
|
||||
"""Create a home directory manager for tests"""
|
||||
manager = HomeDirManager(tmp_path)
|
||||
yield manager
|
||||
manager.cleanup()
|
||||
|
||||
|
||||
# Constants for fixture paths
|
||||
FIXTURE_HOME_PATH = Path(__file__).parent / "home"
|
||||
CLIENT1_HOME_PATH = FIXTURE_HOME_PATH / "client1"
|
||||
MINER1_HOME_PATH = FIXTURE_HOME_PATH / "miner1"
|
||||
|
||||
|
||||
def get_fixture_home_path(agent_name: str) -> Path:
|
||||
"""Get the path to a fixture home directory"""
|
||||
return FIXTURE_HOME_PATH / agent_name
|
||||
|
||||
|
||||
def fixture_home_exists(agent_name: str) -> bool:
|
||||
"""Check if a fixture home directory exists"""
|
||||
return get_fixture_home_path(agent_name).exists()
|
||||
|
||||
|
||||
def create_test_wallet(agent_name: str, address: str, balance: int = 0) -> Dict:
|
||||
"""Create test wallet data"""
|
||||
return {
|
||||
"address": address,
|
||||
"balance": balance,
|
||||
"transactions": [],
|
||||
"created_at": "2026-03-03T00:00:00Z",
|
||||
"agent_name": agent_name
|
||||
}
|
||||
|
||||
|
||||
def setup_fixture_homes():
|
||||
"""Set up the fixture home directories if they don't exist"""
|
||||
fixture_path = FIXTURE_HOME_PATH
|
||||
|
||||
if not fixture_path.exists():
|
||||
fixture_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create standard agent homes
|
||||
for agent_name, agent_type in [("client1", "client"), ("miner1", "miner")]:
|
||||
agent_path = fixture_path / agent_name
|
||||
agent_path.mkdir(exist_ok=True)
|
||||
|
||||
# Create AITBC structure
|
||||
(agent_path / ".aitbc").mkdir(exist_ok=True)
|
||||
(agent_path / ".aitbc" / "wallets").mkdir(exist_ok=True)
|
||||
(agent_path / ".aitbc" / "config").mkdir(exist_ok=True)
|
||||
|
||||
# Create default config
|
||||
config_file = agent_path / ".aitbc" / "config.yaml"
|
||||
config_file.write_text(f"""
|
||||
agent:
|
||||
name: {agent_name}
|
||||
type: {agent_type}
|
||||
wallet_path: ~/.aitbc/wallets/{agent_name}_wallet.json
|
||||
|
||||
node:
|
||||
endpoint: http://localhost:8082
|
||||
timeout: 30
|
||||
|
||||
coordinator:
|
||||
url: http://localhost:8000
|
||||
api_key: null
|
||||
""")
|
||||
|
||||
# Create empty wallet
|
||||
wallet_file = agent_path / ".aitbc" / "wallets" / f"{agent_name}_wallet.json"
|
||||
wallet_data = create_test_wallet(agent_name, f"aitbc1{agent_name}", 1000)
|
||||
import json
|
||||
wallet_file.write_text(json.dumps(wallet_data, indent=2))
|
||||
|
||||
|
||||
# Ensure fixture homes exist when this module is imported
|
||||
setup_fixture_homes()
|
||||
13
tests/e2e/fixtures/home/client1/.aitbc/config.yaml
Normal file
13
tests/e2e/fixtures/home/client1/.aitbc/config.yaml
Normal file
@@ -0,0 +1,13 @@
|
||||
|
||||
agent:
|
||||
name: client1
|
||||
type: client
|
||||
wallet_path: ~/.aitbc/wallets/client1_wallet.json
|
||||
|
||||
node:
|
||||
endpoint: http://localhost:8082
|
||||
timeout: 30
|
||||
|
||||
coordinator:
|
||||
url: http://localhost:8000
|
||||
api_key: null
|
||||
49
tests/e2e/fixtures/home/client1/answer.txt
Normal file
49
tests/e2e/fixtures/home/client1/answer.txt
Normal file
@@ -0,0 +1,49 @@
|
||||
Okay, this is a hugely exciting and rapidly evolving area! The future of AI in decentralized systems is looking remarkably bright, and blockchain technology is a pivotal enabler. Here’s a breakdown of what we can expect, broken down into key areas:
|
||||
|
||||
**1. The Future Landscape of AI in Decentralized Systems**
|
||||
|
||||
* **Increased Automation & Scalability:** Current decentralized systems (like DAOs, DeFi, and gaming) often struggle with complex decision-making and scalability. AI will be crucial to automate these processes, making them more efficient and less reliant on human intervention. Think of AI-powered automated market makers, smart contracts executing complex scenarios, and personalized asset management.
|
||||
* **Enhanced Data Analysis & Insights:** Decentralized data is invaluable. AI will be used to analyze this data – identifying patterns, anomalies, and opportunities – far more effectively than traditional methods. This will lead to smarter governance, optimized resource allocation, and better risk assessment.
|
||||
* **Personalized & Adaptive Experiences:** AI will personalize user experiences within decentralized platforms. Instead of relying on rigid rules, AI will understand individual behavior and preferences to tailor everything from content recommendations to loan terms.
|
||||
* **Novel AI Models & Architectures:** We’ll see the development of AI models specifically designed for decentralized environments. This includes models that are:
|
||||
* **Federated Learning:** Allows models to be trained across multiple decentralized nodes without sharing raw data, improving privacy and model robustness.
|
||||
* **Differential Privacy:** Protects individual data points while still allowing for analysis, which is critical for privacy-preserving AI.
|
||||
* **Secure Multi-Party Computation (SMPC):** Enables multiple parties to jointly compute a result without revealing their individual inputs.
|
||||
* **AI-Driven Governance & Decentralized Autonomous Organizations (DAOs):** AI will be integrated into DAOs to:
|
||||
* **Automate Governance:** AI can analyze proposals, vote flows, and community sentiment to suggest optimal governance strategies.
|
||||
* **Identify & Mitigate Risks:** AI can detect potential risks like collusion or malicious activity within a DAO.
|
||||
* **Optimize Resource Allocation:** AI can allocate funds and resources to projects based on community demand and potential impact.
|
||||
|
||||
|
||||
**2. How Blockchain Technology Enhances AI Model Sharing & Governance**
|
||||
|
||||
Blockchain is *absolutely* the key technology here. Here's how it’s transforming AI governance:
|
||||
|
||||
* **Immutable Record of AI Models:** Blockchain creates an immutable record of every step in the AI model lifecycle – training data, model versions, validation results, and even the model’s performance metrics. This ensures transparency and auditability.
|
||||
* **Decentralized Model Sharing:** Instead of relying on centralized platforms like Hugging Face, models can be shared and distributed directly across the blockchain network. This creates a trustless ecosystem, reducing the risk of model manipulation or censorship.
|
||||
* **Smart Contracts for Model Licensing & Royalty Payments:** Smart contracts can automate licensing agreements, distribute royalties to data providers, and manage intellectual property rights related to AI models. This is crucial for incentivizing collaboration and ensuring fair compensation.
|
||||
* **Tokenization of AI Models:** Models can be tokenized (represented as unique digital assets) which can be used as collateral for loans, voting rights, or other incentives within the decentralized ecosystem. This unlocks new uses for AI assets.
|
||||
* **Reputation Systems:** Blockchain-based reputation systems can reward contributors and penalize malicious behavior, fostering a more trustworthy and collaborative environment for AI model development.
|
||||
* **Decentralized Verification & Validation:** The blockchain can be used to verify the accuracy and reliability of AI model outputs. Different parties can validate the results, building confidence in the model's output.
|
||||
* **DAO Governance & Trust:** Blockchain-based DAOs allow for decentralized decision-making on AI model deployment, updates, and governance – shifting control away from a single entity.
|
||||
|
||||
|
||||
**3. Challenges & Considerations**
|
||||
|
||||
* **Scalability:** Blockchain can be slow and expensive, hindering the scalability needed for large-scale AI deployments. Layer-2 solutions and alternative blockchains are being explored.
|
||||
* **Regulation:** The legal and regulatory landscape surrounding AI is still evolving. Decentralized AI systems need to navigate these complexities.
|
||||
* **Data Privacy:** While blockchain can enhance transparency, it’s crucial to implement privacy-preserving techniques to protect sensitive data within AI models.
|
||||
* **Computational Costs:** Running AI models on blockchain can be resource-intensive. Optimization and efficient model design are essential.
|
||||
|
||||
|
||||
**Resources for Further Learning:**
|
||||
|
||||
* **Blockchain and AI:** [https://www.blockchainandai.com/](https://www.blockchainandai.com/)
|
||||
* **Decentralized AI:** [https://www.decentralizedai.com/](https://www.decentralizedai.com/)
|
||||
* **Ethereum Foundation - AI:** [https://ethereumfoundation.org/news/ethereum-foundation-ai](https://ethereumfoundation.org/news/ethereum-foundation-ai)
|
||||
|
||||
|
||||
To help me tailor my response further, could you tell me:
|
||||
|
||||
* What specific area of AI are you most interested in (e.g., Generative AI, Machine Learning, Blockchain integration)?
|
||||
* What kind of decentralized system are you thinking of (e.g., DeFi, DAOs, Gaming, Supply Chain)?
|
||||
13
tests/e2e/fixtures/home/miner1/.aitbc/config.yaml
Normal file
13
tests/e2e/fixtures/home/miner1/.aitbc/config.yaml
Normal file
@@ -0,0 +1,13 @@
|
||||
|
||||
agent:
|
||||
name: miner1
|
||||
type: miner
|
||||
wallet_path: ~/.aitbc/wallets/miner1_wallet.json
|
||||
|
||||
node:
|
||||
endpoint: http://localhost:8082
|
||||
timeout: 30
|
||||
|
||||
coordinator:
|
||||
url: http://localhost:8000
|
||||
api_key: null
|
||||
1
tests/e2e/fixtures/home/miner1/question.txt
Normal file
1
tests/e2e/fixtures/home/miner1/question.txt
Normal file
@@ -0,0 +1 @@
|
||||
What is the future of artificial intelligence in decentralized systems, and how will blockchain technology enhance AI model sharing and governance?
|
||||
146
tests/e2e/test_fixture_verification.py
Normal file
146
tests/e2e/test_fixture_verification.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""
|
||||
Test to verify the home directory fixture system works correctly
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from tests.e2e.fixtures import (
|
||||
FIXTURE_HOME_PATH,
|
||||
CLIENT1_HOME_PATH,
|
||||
MINER1_HOME_PATH,
|
||||
get_fixture_home_path,
|
||||
fixture_home_exists
|
||||
)
|
||||
|
||||
|
||||
def test_fixture_paths_exist():
|
||||
"""Test that all fixture paths exist"""
|
||||
assert FIXTURE_HOME_PATH.exists(), f"Fixture home path {FIXTURE_HOME_PATH} does not exist"
|
||||
assert CLIENT1_HOME_PATH.exists(), f"Client1 home path {CLIENT1_HOME_PATH} does not exist"
|
||||
assert MINER1_HOME_PATH.exists(), f"Miner1 home path {MINER1_HOME_PATH} does not exist"
|
||||
|
||||
|
||||
def test_fixture_helper_functions():
|
||||
"""Test fixture helper functions work correctly"""
|
||||
# Test get_fixture_home_path
|
||||
client1_path = get_fixture_home_path("client1")
|
||||
miner1_path = get_fixture_home_path("miner1")
|
||||
|
||||
assert client1_path == CLIENT1_HOME_PATH
|
||||
assert miner1_path == MINER1_HOME_PATH
|
||||
|
||||
# Test fixture_home_exists
|
||||
assert fixture_home_exists("client1") is True
|
||||
assert fixture_home_exists("miner1") is True
|
||||
assert fixture_home_exists("nonexistent") is False
|
||||
|
||||
|
||||
def test_fixture_structure():
|
||||
"""Test that fixture directories have the expected structure"""
|
||||
# Check client1 structure
|
||||
client1_aitbc = CLIENT1_HOME_PATH / ".aitbc"
|
||||
assert client1_aitbc.exists(), "Client1 .aitbc directory should exist"
|
||||
|
||||
client1_wallets = client1_aitbc / "wallets"
|
||||
client1_config = client1_aitbc / "config"
|
||||
client1_cache = client1_aitbc / "cache"
|
||||
|
||||
assert client1_wallets.exists(), "Client1 wallets directory should exist"
|
||||
assert client1_config.exists(), "Client1 config directory should exist"
|
||||
assert client1_cache.exists(), "Client1 cache directory should exist"
|
||||
|
||||
# Check miner1 structure
|
||||
miner1_aitbc = MINER1_HOME_PATH / ".aitbc"
|
||||
assert miner1_aitbc.exists(), "Miner1 .aitbc directory should exist"
|
||||
|
||||
miner1_wallets = miner1_aitbc / "wallets"
|
||||
miner1_config = miner1_aitbc / "config"
|
||||
miner1_cache = miner1_aitbc / "cache"
|
||||
|
||||
assert miner1_wallets.exists(), "Miner1 wallets directory should exist"
|
||||
assert miner1_config.exists(), "Miner1 config directory should exist"
|
||||
assert miner1_cache.exists(), "Miner1 cache directory should exist"
|
||||
|
||||
|
||||
def test_fixture_config_files():
|
||||
"""Test that fixture config files exist and are readable"""
|
||||
import yaml
|
||||
|
||||
# Check client1 config
|
||||
client1_config_file = CLIENT1_HOME_PATH / ".aitbc" / "config.yaml"
|
||||
assert client1_config_file.exists(), "Client1 config.yaml should exist"
|
||||
|
||||
with open(client1_config_file, 'r') as f:
|
||||
client1_config = yaml.safe_load(f)
|
||||
|
||||
assert "agent" in client1_config, "Client1 config should have agent section"
|
||||
assert client1_config["agent"]["name"] == "client1", "Client1 config should have correct name"
|
||||
|
||||
# Check miner1 config
|
||||
miner1_config_file = MINER1_HOME_PATH / ".aitbc" / "config.yaml"
|
||||
assert miner1_config_file.exists(), "Miner1 config.yaml should exist"
|
||||
|
||||
with open(miner1_config_file, 'r') as f:
|
||||
miner1_config = yaml.safe_load(f)
|
||||
|
||||
assert "agent" in miner1_config, "Miner1 config should have agent section"
|
||||
assert miner1_config["agent"]["name"] == "miner1", "Miner1 config should have correct name"
|
||||
|
||||
|
||||
def test_fixture_wallet_files():
|
||||
"""Test that fixture wallet files exist and have correct structure"""
|
||||
import json
|
||||
|
||||
# Check client1 wallet
|
||||
client1_wallet_file = CLIENT1_HOME_PATH / ".aitbc" / "wallets" / "client1_wallet.json"
|
||||
assert client1_wallet_file.exists(), "Client1 wallet file should exist"
|
||||
|
||||
with open(client1_wallet_file, 'r') as f:
|
||||
client1_wallet = json.load(f)
|
||||
|
||||
assert "address" in client1_wallet, "Client1 wallet should have address"
|
||||
assert "balance" in client1_wallet, "Client1 wallet should have balance"
|
||||
assert "transactions" in client1_wallet, "Client1 wallet should have transactions list"
|
||||
assert client1_wallet["address"] == "aitbc1client1", "Client1 wallet should have correct address"
|
||||
|
||||
# Check miner1 wallet
|
||||
miner1_wallet_file = MINER1_HOME_PATH / ".aitbc" / "wallets" / "miner1_wallet.json"
|
||||
assert miner1_wallet_file.exists(), "Miner1 wallet file should exist"
|
||||
|
||||
with open(miner1_wallet_file, 'r') as f:
|
||||
miner1_wallet = json.load(f)
|
||||
|
||||
assert "address" in miner1_wallet, "Miner1 wallet should have address"
|
||||
assert "balance" in miner1_wallet, "Miner1 wallet should have balance"
|
||||
assert "transactions" in miner1_wallet, "Miner1 wallet should have transactions list"
|
||||
assert miner1_wallet["address"] == "aitbc1miner1", "Miner1 wallet should have correct address"
|
||||
|
||||
|
||||
def test_fixture_import():
|
||||
"""Test that fixtures can be imported correctly"""
|
||||
from tests.e2e.fixtures import (
|
||||
HomeDirManager,
|
||||
create_test_wallet,
|
||||
setup_fixture_homes
|
||||
)
|
||||
|
||||
# Test that classes are importable
|
||||
assert HomeDirManager is not None, "HomeDirManager should be importable"
|
||||
|
||||
# Test that functions are importable
|
||||
assert callable(create_test_wallet), "create_test_wallet should be callable"
|
||||
assert callable(setup_fixture_homes), "setup_fixture_homes should be callable"
|
||||
|
||||
# Test create_test_wallet function
|
||||
test_wallet = create_test_wallet("test_agent", "aitbc1test", 500)
|
||||
|
||||
expected_keys = {"address", "balance", "transactions", "created_at", "agent_name"}
|
||||
assert set(test_wallet.keys()) == expected_keys, "Test wallet should have all expected keys"
|
||||
assert test_wallet["address"] == "aitbc1test", "Test wallet should have correct address"
|
||||
assert test_wallet["balance"] == 500, "Test wallet should have correct balance"
|
||||
assert test_wallet["agent_name"] == "test_agent", "Test wallet should have correct agent name"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
362
tests/integration/test_api_integration.py
Normal file
362
tests/integration/test_api_integration.py
Normal file
@@ -0,0 +1,362 @@
|
||||
"""
|
||||
Integration Tests for AITBC API Components
|
||||
Tests interaction between different API services
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
from click.testing import CliRunner
|
||||
|
||||
|
||||
class TestCoordinatorAPIIntegration:
|
||||
"""Test coordinator API integration"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_coordinator_client(self):
|
||||
"""Mock coordinator API client"""
|
||||
client = Mock()
|
||||
|
||||
# Mock health check
|
||||
client.health_check.return_value = {
|
||||
'status': 'healthy',
|
||||
'timestamp': datetime.utcnow().isoformat(),
|
||||
'services': {
|
||||
'job_manager': 'running',
|
||||
'marketplace': 'running',
|
||||
'blockchain': 'running'
|
||||
}
|
||||
}
|
||||
|
||||
# Mock job submission
|
||||
client.submit_job.return_value = {
|
||||
'job_id': 'test-job-123',
|
||||
'status': 'submitted',
|
||||
'estimated_completion': '2024-01-01T12:00:00Z'
|
||||
}
|
||||
|
||||
# Mock job status
|
||||
client.get_job_status.return_value = {
|
||||
'job_id': 'test-job-123',
|
||||
'status': 'running',
|
||||
'progress': 45,
|
||||
'started_at': datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
return client
|
||||
|
||||
def test_health_check_integration(self, mock_coordinator_client):
|
||||
"""Test health check integration"""
|
||||
response = mock_coordinator_client.health_check()
|
||||
|
||||
assert response['status'] == 'healthy'
|
||||
assert 'timestamp' in response
|
||||
assert 'services' in response
|
||||
assert all(service in ['running', 'stopped', 'error']
|
||||
for service in response['services'].values())
|
||||
|
||||
def test_job_submission_workflow(self, mock_coordinator_client):
|
||||
"""Test complete job submission workflow"""
|
||||
job_data = {
|
||||
'type': 'ml_inference',
|
||||
'model': 'resnet50',
|
||||
'input_data': 's3://test-data/input.json',
|
||||
'requirements': {
|
||||
'gpu_type': 'RTX 3080',
|
||||
'memory_gb': 8,
|
||||
'duration_minutes': 30
|
||||
}
|
||||
}
|
||||
|
||||
# Submit job
|
||||
response = mock_coordinator_client.submit_job(job_data)
|
||||
|
||||
assert 'job_id' in response
|
||||
assert response['status'] == 'submitted'
|
||||
assert response['job_id'].startswith('test-job-')
|
||||
|
||||
# Check job status
|
||||
status_response = mock_coordinator_client.get_job_status(response['job_id'])
|
||||
|
||||
assert status_response['job_id'] == response['job_id']
|
||||
assert status_response['status'] in ['submitted', 'running', 'completed', 'failed']
|
||||
assert 'progress' in status_response
|
||||
|
||||
def test_marketplace_integration(self, mock_coordinator_client):
|
||||
"""Test marketplace API integration"""
|
||||
# Mock marketplace responses
|
||||
mock_coordinator_client.list_offers.return_value = {
|
||||
'offers': [
|
||||
{
|
||||
'id': 'offer-1',
|
||||
'provider': 'miner-1',
|
||||
'gpu_type': 'RTX 3080',
|
||||
'price_per_hour': 0.1,
|
||||
'available': True
|
||||
},
|
||||
{
|
||||
'id': 'offer-2',
|
||||
'provider': 'miner-2',
|
||||
'gpu_type': 'RTX 3090',
|
||||
'price_per_hour': 0.15,
|
||||
'available': True
|
||||
}
|
||||
],
|
||||
'total_count': 2
|
||||
}
|
||||
|
||||
# Get marketplace offers
|
||||
offers_response = mock_coordinator_client.list_offers()
|
||||
|
||||
assert 'offers' in offers_response
|
||||
assert 'total_count' in offers_response
|
||||
assert len(offers_response['offers']) == 2
|
||||
assert all('gpu_type' in offer for offer in offers_response['offers'])
|
||||
assert all('price_per_hour' in offer for offer in offers_response['offers'])
|
||||
|
||||
|
||||
class TestBlockchainIntegration:
|
||||
"""Test blockchain integration"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_blockchain_client(self):
|
||||
"""Mock blockchain client"""
|
||||
client = Mock()
|
||||
|
||||
# Mock blockchain info
|
||||
client.get_chain_info.return_value = {
|
||||
'chain_id': 'aitbc-mainnet',
|
||||
'block_height': 12345,
|
||||
'latest_block_hash': '0xabc123...',
|
||||
'network_status': 'active'
|
||||
}
|
||||
|
||||
# Mock transaction creation
|
||||
client.create_transaction.return_value = {
|
||||
'tx_hash': '0xdef456...',
|
||||
'from_address': 'aitbc1sender123',
|
||||
'to_address': 'aitbc1receiver456',
|
||||
'amount': 100.0,
|
||||
'fee': 0.001,
|
||||
'status': 'pending'
|
||||
}
|
||||
|
||||
# Mock wallet balance
|
||||
client.get_balance.return_value = {
|
||||
'address': 'aitbc1test123',
|
||||
'balance': 1500.75,
|
||||
'pending_balance': 25.0,
|
||||
'last_updated': datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
return client
|
||||
|
||||
def test_blockchain_info_retrieval(self, mock_blockchain_client):
|
||||
"""Test blockchain information retrieval"""
|
||||
info = mock_blockchain_client.get_chain_info()
|
||||
|
||||
assert 'chain_id' in info
|
||||
assert 'block_height' in info
|
||||
assert 'latest_block_hash' in info
|
||||
assert 'network_status' in info
|
||||
assert info['block_height'] > 0
|
||||
assert info['network_status'] == 'active'
|
||||
|
||||
def test_transaction_creation(self, mock_blockchain_client):
|
||||
"""Test transaction creation and validation"""
|
||||
tx_data = {
|
||||
'from_address': 'aitbc1sender123',
|
||||
'to_address': 'aitbc1receiver456',
|
||||
'amount': 100.0,
|
||||
'private_key': 'test_private_key'
|
||||
}
|
||||
|
||||
tx_result = mock_blockchain_client.create_transaction(tx_data)
|
||||
|
||||
assert 'tx_hash' in tx_result
|
||||
assert tx_result['tx_hash'].startswith('0x')
|
||||
assert tx_result['from_address'] == tx_data['from_address']
|
||||
assert tx_result['to_address'] == tx_data['to_address']
|
||||
assert tx_result['amount'] == tx_data['amount']
|
||||
assert tx_result['status'] == 'pending'
|
||||
|
||||
def test_wallet_balance_check(self, mock_blockchain_client):
|
||||
"""Test wallet balance checking"""
|
||||
address = 'aitbc1test123'
|
||||
balance_info = mock_blockchain_client.get_balance(address)
|
||||
|
||||
assert 'address' in balance_info
|
||||
assert 'balance' in balance_info
|
||||
assert 'pending_balance' in balance_info
|
||||
assert 'last_updated' in balance_info
|
||||
assert balance_info['address'] == address
|
||||
assert isinstance(balance_info['balance'], (int, float))
|
||||
assert isinstance(balance_info['pending_balance'], (int, float))
|
||||
|
||||
|
||||
class TestCLIIntegration:
|
||||
"""Test CLI integration with APIs"""
|
||||
|
||||
def test_cli_config_integration(self):
|
||||
"""Test CLI configuration integration"""
|
||||
runner = CliRunner()
|
||||
|
||||
# Test config show command
|
||||
result = runner.invoke(cli, ['config-show'])
|
||||
assert result.exit_code == 0
|
||||
assert 'coordinator_url' in result.output.lower() or 'api' in result.output.lower()
|
||||
|
||||
def test_cli_wallet_integration(self):
|
||||
"""Test CLI wallet integration"""
|
||||
runner = CliRunner()
|
||||
|
||||
# Test wallet help
|
||||
result = runner.invoke(cli, ['wallet', '--help'])
|
||||
assert result.exit_code == 0
|
||||
assert 'wallet' in result.output.lower()
|
||||
|
||||
def test_cli_marketplace_integration(self):
|
||||
"""Test CLI marketplace integration"""
|
||||
runner = CliRunner()
|
||||
|
||||
# Test marketplace help
|
||||
result = runner.invoke(cli, ['marketplace', '--help'])
|
||||
assert result.exit_code == 0
|
||||
assert 'marketplace' in result.output.lower()
|
||||
|
||||
|
||||
class TestDataFlowIntegration:
|
||||
"""Test data flow between components"""
|
||||
|
||||
def test_job_to_blockchain_flow(self):
|
||||
"""Test data flow from job submission to blockchain recording"""
|
||||
# Simulate job submission
|
||||
job_data = {
|
||||
'id': 'job-123',
|
||||
'type': 'ml_inference',
|
||||
'provider': 'miner-456',
|
||||
'cost': 10.0,
|
||||
'status': 'completed'
|
||||
}
|
||||
|
||||
# Simulate blockchain transaction
|
||||
tx_data = {
|
||||
'job_id': job_data['id'],
|
||||
'amount': job_data['cost'],
|
||||
'from': 'client_wallet',
|
||||
'to': 'miner_wallet',
|
||||
'timestamp': datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
# Validate data flow
|
||||
assert tx_data['job_id'] == job_data['id']
|
||||
assert tx_data['amount'] == job_data['cost']
|
||||
assert 'timestamp' in tx_data
|
||||
|
||||
def test_marketplace_to_job_flow(self):
|
||||
"""Test data flow from marketplace selection to job execution"""
|
||||
# Simulate marketplace offer selection
|
||||
offer = {
|
||||
'id': 'offer-789',
|
||||
'provider': 'miner-456',
|
||||
'gpu_type': 'RTX 3080',
|
||||
'price_per_hour': 0.1
|
||||
}
|
||||
|
||||
# Simulate job creation based on offer
|
||||
job = {
|
||||
'id': 'job-456',
|
||||
'type': 'ml_training',
|
||||
'assigned_provider': offer['provider'],
|
||||
'gpu_requirements': offer['gpu_type'],
|
||||
'cost_per_hour': offer['price_per_hour'],
|
||||
'status': 'assigned'
|
||||
}
|
||||
|
||||
# Validate data flow
|
||||
assert job['assigned_provider'] == offer['provider']
|
||||
assert job['gpu_requirements'] == offer['gpu_type']
|
||||
assert job['cost_per_hour'] == offer['price_per_hour']
|
||||
|
||||
def test_wallet_transaction_flow(self):
|
||||
"""Test wallet transaction data flow"""
|
||||
# Simulate wallet balance before
|
||||
initial_balance = 1000.0
|
||||
|
||||
# Simulate transaction
|
||||
transaction = {
|
||||
'type': 'payment',
|
||||
'amount': 50.0,
|
||||
'from_wallet': 'client_wallet',
|
||||
'to_wallet': 'miner_wallet',
|
||||
'timestamp': datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
# Calculate new balance
|
||||
new_balance = initial_balance - transaction['amount']
|
||||
|
||||
# Validate transaction flow
|
||||
assert transaction['amount'] > 0
|
||||
assert new_balance == initial_balance - transaction['amount']
|
||||
assert new_balance < initial_balance
|
||||
|
||||
|
||||
class TestErrorHandlingIntegration:
|
||||
"""Test error handling across integrated components"""
|
||||
|
||||
def test_api_error_propagation(self):
|
||||
"""Test error propagation through API calls"""
|
||||
# Mock API client that raises errors
|
||||
client = Mock()
|
||||
client.submit_job.side_effect = Exception("API unavailable")
|
||||
|
||||
# Test error handling
|
||||
with pytest.raises(Exception, match="API unavailable"):
|
||||
client.submit_job({"type": "test_job"})
|
||||
|
||||
def test_fallback_mechanisms(self):
|
||||
"""Test fallback mechanisms for integrated services"""
|
||||
# Mock primary service failure
|
||||
primary_client = Mock()
|
||||
primary_client.get_balance.side_effect = Exception("Primary service down")
|
||||
|
||||
# Mock fallback service
|
||||
fallback_client = Mock()
|
||||
fallback_client.get_balance.return_value = {
|
||||
'address': 'aitbc1test',
|
||||
'balance': 1000.0
|
||||
}
|
||||
|
||||
# Test fallback logic
|
||||
try:
|
||||
balance = primary_client.get_balance('aitbc1test')
|
||||
except Exception:
|
||||
balance = fallback_client.get_balance('aitbc1test')
|
||||
|
||||
assert balance['balance'] == 1000.0
|
||||
|
||||
def test_data_validation_integration(self):
|
||||
"""Test data validation across component boundaries"""
|
||||
# Test invalid job data
|
||||
invalid_job = {
|
||||
'type': 'invalid_type',
|
||||
'requirements': {}
|
||||
}
|
||||
|
||||
# Test validation at different stages
|
||||
valid_job_types = ['ml_training', 'ml_inference', 'data_processing']
|
||||
|
||||
assert invalid_job['type'] not in valid_job_types
|
||||
|
||||
# Test validation function
|
||||
def validate_job(job_data):
|
||||
if job_data.get('type') not in valid_job_types:
|
||||
raise ValueError("Invalid job type")
|
||||
if not job_data.get('requirements'):
|
||||
raise ValueError("Requirements missing")
|
||||
return True
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid job type"):
|
||||
validate_job(invalid_job)
|
||||
512
tests/performance/test_performance_benchmarks.py
Normal file
512
tests/performance/test_performance_benchmarks.py
Normal file
@@ -0,0 +1,512 @@
|
||||
"""
|
||||
Performance Benchmark Tests for AITBC
|
||||
Tests system performance under various loads and conditions
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import time
|
||||
import asyncio
|
||||
import threading
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import Mock, patch
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import statistics
|
||||
|
||||
|
||||
class TestAPIPerformance:
|
||||
"""Test API endpoint performance"""
|
||||
|
||||
def test_response_time_benchmarks(self):
|
||||
"""Test API response time benchmarks"""
|
||||
# Mock API client
|
||||
client = Mock()
|
||||
|
||||
# Simulate different response times
|
||||
response_times = [0.05, 0.08, 0.12, 0.06, 0.09, 0.11, 0.07, 0.10]
|
||||
|
||||
# Calculate performance metrics
|
||||
avg_response_time = statistics.mean(response_times)
|
||||
max_response_time = max(response_times)
|
||||
min_response_time = min(response_times)
|
||||
|
||||
# Performance assertions
|
||||
assert avg_response_time < 0.1 # Average should be under 100ms
|
||||
assert max_response_time < 0.2 # Max should be under 200ms
|
||||
assert min_response_time > 0.01 # Should be reasonable minimum
|
||||
|
||||
# Test performance thresholds
|
||||
performance_thresholds = {
|
||||
'excellent': 0.05, # < 50ms
|
||||
'good': 0.1, # < 100ms
|
||||
'acceptable': 0.2, # < 200ms
|
||||
'poor': 0.5 # > 500ms
|
||||
}
|
||||
|
||||
# Classify performance
|
||||
if avg_response_time < performance_thresholds['excellent']:
|
||||
performance_rating = 'excellent'
|
||||
elif avg_response_time < performance_thresholds['good']:
|
||||
performance_rating = 'good'
|
||||
elif avg_response_time < performance_thresholds['acceptable']:
|
||||
performance_rating = 'acceptable'
|
||||
else:
|
||||
performance_rating = 'poor'
|
||||
|
||||
assert performance_rating in ['excellent', 'good', 'acceptable']
|
||||
|
||||
def test_concurrent_request_handling(self):
|
||||
"""Test handling of concurrent requests"""
|
||||
# Mock API endpoint
|
||||
def mock_api_call(request_id):
|
||||
time.sleep(0.01) # Simulate 10ms processing time
|
||||
return {'request_id': request_id, 'status': 'success'}
|
||||
|
||||
# Test concurrent execution
|
||||
num_requests = 50
|
||||
start_time = time.time()
|
||||
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
futures = [
|
||||
executor.submit(mock_api_call, i)
|
||||
for i in range(num_requests)
|
||||
]
|
||||
results = [future.result() for future in futures]
|
||||
|
||||
end_time = time.time()
|
||||
total_time = end_time - start_time
|
||||
|
||||
# Performance assertions
|
||||
assert len(results) == num_requests
|
||||
assert all(result['status'] == 'success' for result in results)
|
||||
assert total_time < 1.0 # Should complete in under 1 second
|
||||
|
||||
# Calculate throughput
|
||||
throughput = num_requests / total_time
|
||||
assert throughput > 50 # Should handle at least 50 requests per second
|
||||
|
||||
def test_memory_usage_under_load(self):
|
||||
"""Test memory usage under load"""
|
||||
import psutil
|
||||
import os
|
||||
|
||||
# Get initial memory usage
|
||||
process = psutil.Process(os.getpid())
|
||||
initial_memory = process.memory_info().rss / 1024 / 1024 # MB
|
||||
|
||||
# Simulate memory-intensive operations
|
||||
data_store = []
|
||||
for i in range(1000):
|
||||
data_store.append({
|
||||
'id': i,
|
||||
'data': 'x' * 1000, # 1KB per item
|
||||
'timestamp': datetime.utcnow().isoformat()
|
||||
})
|
||||
|
||||
# Get peak memory usage
|
||||
peak_memory = process.memory_info().rss / 1024 / 1024 # MB
|
||||
memory_increase = peak_memory - initial_memory
|
||||
|
||||
# Memory assertions
|
||||
assert memory_increase < 100 # Should not increase by more than 100MB
|
||||
assert len(data_store) == 1000
|
||||
|
||||
# Cleanup
|
||||
del data_store
|
||||
|
||||
|
||||
class TestDatabasePerformance:
|
||||
"""Test database operation performance"""
|
||||
|
||||
def test_query_performance(self):
|
||||
"""Test database query performance"""
|
||||
# Mock database operations
|
||||
def mock_query(query_type):
|
||||
if query_type == 'simple':
|
||||
time.sleep(0.001) # 1ms
|
||||
elif query_type == 'complex':
|
||||
time.sleep(0.01) # 10ms
|
||||
elif query_type == 'aggregate':
|
||||
time.sleep(0.05) # 50ms
|
||||
return {'results': ['data'], 'query_type': query_type}
|
||||
|
||||
# Test different query types
|
||||
query_types = ['simple', 'complex', 'aggregate']
|
||||
query_times = {}
|
||||
|
||||
for query_type in query_types:
|
||||
start_time = time.time()
|
||||
result = mock_query(query_type)
|
||||
end_time = time.time()
|
||||
query_times[query_type] = end_time - start_time
|
||||
|
||||
assert result['query_type'] == query_type
|
||||
|
||||
# Performance assertions
|
||||
assert query_times['simple'] < 0.005 # < 5ms
|
||||
assert query_times['complex'] < 0.02 # < 20ms
|
||||
assert query_times['aggregate'] < 0.1 # < 100ms
|
||||
|
||||
def test_batch_operation_performance(self):
|
||||
"""Test batch operation performance"""
|
||||
# Mock batch insert
|
||||
def mock_batch_insert(items):
|
||||
time.sleep(len(items) * 0.001) # 1ms per item
|
||||
return {'inserted_count': len(items)}
|
||||
|
||||
# Test different batch sizes
|
||||
batch_sizes = [10, 50, 100, 500]
|
||||
performance_results = {}
|
||||
|
||||
for batch_size in batch_sizes:
|
||||
items = [{'id': i, 'data': f'item_{i}'} for i in range(batch_size)]
|
||||
|
||||
start_time = time.time()
|
||||
result = mock_batch_insert(items)
|
||||
end_time = time.time()
|
||||
|
||||
performance_results[batch_size] = {
|
||||
'time': end_time - start_time,
|
||||
'throughput': batch_size / (end_time - start_time)
|
||||
}
|
||||
|
||||
assert result['inserted_count'] == batch_size
|
||||
|
||||
# Performance analysis
|
||||
for batch_size, metrics in performance_results.items():
|
||||
assert metrics['throughput'] > 100 # Should handle at least 100 items/second
|
||||
assert metrics['time'] < 5.0 # Should complete in under 5 seconds
|
||||
|
||||
def test_connection_pool_performance(self):
|
||||
"""Test database connection pool performance"""
|
||||
# Mock connection pool
|
||||
class MockConnectionPool:
|
||||
def __init__(self, max_connections=10):
|
||||
self.max_connections = max_connections
|
||||
self.active_connections = 0
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def get_connection(self):
|
||||
with self.lock:
|
||||
if self.active_connections < self.max_connections:
|
||||
self.active_connections += 1
|
||||
return MockConnection()
|
||||
else:
|
||||
raise Exception("Connection pool exhausted")
|
||||
|
||||
def release_connection(self, conn):
|
||||
with self.lock:
|
||||
self.active_connections -= 1
|
||||
|
||||
class MockConnection:
|
||||
def execute(self, query):
|
||||
time.sleep(0.01) # 10ms query time
|
||||
return {'result': 'success'}
|
||||
|
||||
# Test connection pool under load
|
||||
pool = MockConnectionPool(max_connections=5)
|
||||
|
||||
def worker_task():
|
||||
try:
|
||||
conn = pool.get_connection()
|
||||
result = conn.execute("SELECT * FROM test")
|
||||
pool.release_connection(conn)
|
||||
return result
|
||||
except Exception as e:
|
||||
return {'error': str(e)}
|
||||
|
||||
# Test concurrent access
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
futures = [executor.submit(worker_task) for _ in range(20)]
|
||||
results = [future.result() for future in futures]
|
||||
|
||||
# Analyze results
|
||||
successful_results = [r for r in results if 'error' not in r]
|
||||
error_results = [r for r in results if 'error' in r]
|
||||
|
||||
# Should have some successful and some error results (pool exhaustion)
|
||||
assert len(successful_results) > 0
|
||||
assert len(error_results) > 0
|
||||
assert len(successful_results) + len(error_results) == 20
|
||||
|
||||
|
||||
class TestBlockchainPerformance:
|
||||
"""Test blockchain operation performance"""
|
||||
|
||||
def test_transaction_processing_speed(self):
|
||||
"""Test transaction processing speed"""
|
||||
# Mock transaction processing
|
||||
def mock_process_transaction(tx):
|
||||
processing_time = 0.1 + (len(tx['data']) * 0.001) # Base 100ms + data size
|
||||
time.sleep(processing_time)
|
||||
return {
|
||||
'tx_hash': f'0x{hash(str(tx)) % 1000000:x}',
|
||||
'processing_time': processing_time
|
||||
}
|
||||
|
||||
# Test transactions of different sizes
|
||||
transactions = [
|
||||
{'data': 'small', 'amount': 1.0},
|
||||
{'data': 'x' * 100, 'amount': 10.0}, # 100 bytes
|
||||
{'data': 'x' * 1000, 'amount': 100.0}, # 1KB
|
||||
{'data': 'x' * 10000, 'amount': 1000.0}, # 10KB
|
||||
]
|
||||
|
||||
processing_times = []
|
||||
|
||||
for tx in transactions:
|
||||
start_time = time.time()
|
||||
result = mock_process_transaction(tx)
|
||||
end_time = time.time()
|
||||
|
||||
processing_times.append(result['processing_time'])
|
||||
assert 'tx_hash' in result
|
||||
assert result['processing_time'] > 0
|
||||
|
||||
# Performance assertions
|
||||
assert processing_times[0] < 0.2 # Small transaction < 200ms
|
||||
assert processing_times[-1] < 1.0 # Large transaction < 1 second
|
||||
|
||||
def test_block_validation_performance(self):
|
||||
"""Test block validation performance"""
|
||||
# Mock block validation
|
||||
def mock_validate_block(block):
|
||||
num_transactions = len(block['transactions'])
|
||||
validation_time = num_transactions * 0.01 # 10ms per transaction
|
||||
time.sleep(validation_time)
|
||||
return {
|
||||
'valid': True,
|
||||
'validation_time': validation_time,
|
||||
'transactions_validated': num_transactions
|
||||
}
|
||||
|
||||
# Test blocks with different transaction counts
|
||||
blocks = [
|
||||
{'transactions': [f'tx_{i}' for i in range(10)]}, # 10 transactions
|
||||
{'transactions': [f'tx_{i}' for i in range(50)]}, # 50 transactions
|
||||
{'transactions': [f'tx_{i}' for i in range(100)]}, # 100 transactions
|
||||
]
|
||||
|
||||
validation_results = []
|
||||
|
||||
for block in blocks:
|
||||
start_time = time.time()
|
||||
result = mock_validate_block(block)
|
||||
end_time = time.time()
|
||||
|
||||
validation_results.append(result)
|
||||
assert result['valid'] is True
|
||||
assert result['transactions_validated'] == len(block['transactions'])
|
||||
|
||||
# Performance analysis
|
||||
for i, result in enumerate(validation_results):
|
||||
expected_time = len(blocks[i]['transactions']) * 0.01
|
||||
assert abs(result['validation_time'] - expected_time) < 0.01
|
||||
|
||||
def test_sync_performance(self):
|
||||
"""Test blockchain sync performance"""
|
||||
# Mock blockchain sync
|
||||
def mock_sync_blocks(start_block, end_block):
|
||||
num_blocks = end_block - start_block
|
||||
sync_time = num_blocks * 0.05 # 50ms per block
|
||||
time.sleep(sync_time)
|
||||
return {
|
||||
'synced_blocks': num_blocks,
|
||||
'sync_time': sync_time,
|
||||
'blocks_per_second': num_blocks / sync_time
|
||||
}
|
||||
|
||||
# Test different sync ranges
|
||||
sync_ranges = [
|
||||
(1000, 1010), # 10 blocks
|
||||
(1000, 1050), # 50 blocks
|
||||
(1000, 1100), # 100 blocks
|
||||
]
|
||||
|
||||
sync_results = []
|
||||
|
||||
for start, end in sync_ranges:
|
||||
result = mock_sync_blocks(start, end)
|
||||
sync_results.append(result)
|
||||
|
||||
assert result['synced_blocks'] == (end - start)
|
||||
assert result['blocks_per_second'] > 10 # Should sync at least 10 blocks/second
|
||||
|
||||
# Performance consistency
|
||||
sync_rates = [result['blocks_per_second'] for result in sync_results]
|
||||
avg_sync_rate = statistics.mean(sync_rates)
|
||||
assert avg_sync_rate > 15 # Average should be at least 15 blocks/second
|
||||
|
||||
|
||||
class TestSystemResourcePerformance:
|
||||
"""Test system resource utilization"""
|
||||
|
||||
def test_cpu_utilization(self):
|
||||
"""Test CPU utilization under load"""
|
||||
import psutil
|
||||
import os
|
||||
|
||||
# Get initial CPU usage
|
||||
initial_cpu = psutil.cpu_percent(interval=0.1)
|
||||
|
||||
# CPU-intensive task
|
||||
def cpu_intensive_task():
|
||||
result = 0
|
||||
for i in range(1000000):
|
||||
result += i * i
|
||||
return result
|
||||
|
||||
# Run CPU-intensive task
|
||||
start_time = time.time()
|
||||
cpu_intensive_task()
|
||||
end_time = time.time()
|
||||
|
||||
# Get CPU usage during task
|
||||
cpu_usage = psutil.cpu_percent(interval=0.1)
|
||||
|
||||
# Performance assertions
|
||||
execution_time = end_time - start_time
|
||||
assert execution_time < 5.0 # Should complete in under 5 seconds
|
||||
assert cpu_usage > 0 # Should show CPU usage
|
||||
|
||||
def test_disk_io_performance(self):
|
||||
"""Test disk I/O performance"""
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_path = Path(temp_dir)
|
||||
|
||||
# Test write performance
|
||||
test_data = 'x' * (1024 * 1024) # 1MB of data
|
||||
write_times = []
|
||||
|
||||
for i in range(10):
|
||||
file_path = temp_path / f"test_file_{i}.txt"
|
||||
start_time = time.time()
|
||||
|
||||
with open(file_path, 'w') as f:
|
||||
f.write(test_data)
|
||||
|
||||
end_time = time.time()
|
||||
write_times.append(end_time - start_time)
|
||||
|
||||
# Test read performance
|
||||
read_times = []
|
||||
|
||||
for i in range(10):
|
||||
file_path = temp_path / f"test_file_{i}.txt"
|
||||
start_time = time.time()
|
||||
|
||||
with open(file_path, 'r') as f:
|
||||
data = f.read()
|
||||
|
||||
end_time = time.time()
|
||||
read_times.append(end_time - start_time)
|
||||
assert len(data) == len(test_data)
|
||||
|
||||
# Performance analysis
|
||||
avg_write_time = statistics.mean(write_times)
|
||||
avg_read_time = statistics.mean(read_times)
|
||||
|
||||
assert avg_write_time < 0.1 # Write should be under 100ms per MB
|
||||
assert avg_read_time < 0.05 # Read should be under 50ms per MB
|
||||
|
||||
def test_network_performance(self):
|
||||
"""Test network I/O performance"""
|
||||
# Mock network operations
|
||||
def mock_network_request(size_kb):
|
||||
# Simulate network latency and bandwidth
|
||||
latency = 0.01 # 10ms latency
|
||||
bandwidth_time = size_kb / 1000 # 1MB/s bandwidth
|
||||
total_time = latency + bandwidth_time
|
||||
time.sleep(total_time)
|
||||
return {'size': size_kb, 'time': total_time}
|
||||
|
||||
# Test different request sizes
|
||||
request_sizes = [10, 100, 1000] # KB
|
||||
network_results = []
|
||||
|
||||
for size in request_sizes:
|
||||
result = mock_network_request(size)
|
||||
network_results.append(result)
|
||||
|
||||
assert result['size'] == size
|
||||
assert result['time'] > 0
|
||||
|
||||
# Performance analysis
|
||||
throughputs = [size / result['time'] for size, result in zip(request_sizes, network_results)]
|
||||
avg_throughput = statistics.mean(throughputs)
|
||||
|
||||
assert avg_throughput > 500 # Should achieve at least 500 KB/s
|
||||
|
||||
|
||||
class TestScalabilityMetrics:
|
||||
"""Test system scalability metrics"""
|
||||
|
||||
def test_load_scaling(self):
|
||||
"""Test system behavior under increasing load"""
|
||||
# Mock system under different loads
|
||||
def mock_system_load(load_factor):
|
||||
# Simulate increasing response times with load
|
||||
base_response_time = 0.1
|
||||
load_response_time = base_response_time * (1 + load_factor * 0.1)
|
||||
time.sleep(load_response_time)
|
||||
return {
|
||||
'load_factor': load_factor,
|
||||
'response_time': load_response_time,
|
||||
'throughput': 1 / load_response_time
|
||||
}
|
||||
|
||||
# Test different load factors
|
||||
load_factors = [1, 2, 5, 10] # 1x, 2x, 5x, 10x load
|
||||
scaling_results = []
|
||||
|
||||
for load in load_factors:
|
||||
result = mock_system_load(load)
|
||||
scaling_results.append(result)
|
||||
|
||||
assert result['load_factor'] == load
|
||||
assert result['response_time'] > 0
|
||||
assert result['throughput'] > 0
|
||||
|
||||
# Scalability analysis
|
||||
response_times = [r['response_time'] for r in scaling_results]
|
||||
throughputs = [r['throughput'] for r in scaling_results]
|
||||
|
||||
# Check that response times increase reasonably
|
||||
assert response_times[-1] < response_times[0] * 5 # Should not be 5x slower at 10x load
|
||||
|
||||
# Check that throughput degrades gracefully
|
||||
assert throughputs[-1] > throughputs[0] / 5 # Should maintain at least 20% of peak throughput
|
||||
|
||||
def test_resource_efficiency(self):
|
||||
"""Test resource efficiency metrics"""
|
||||
# Mock resource usage
|
||||
def mock_resource_usage(requests_per_second):
|
||||
# Simulate resource usage scaling
|
||||
cpu_usage = min(90, requests_per_second * 2) # 2% CPU per request/sec
|
||||
memory_usage = min(80, 50 + requests_per_second * 0.5) # Base 50% + 0.5% per request/sec
|
||||
return {
|
||||
'requests_per_second': requests_per_second,
|
||||
'cpu_usage': cpu_usage,
|
||||
'memory_usage': memory_usage,
|
||||
'efficiency': requests_per_second / max(cpu_usage, memory_usage)
|
||||
}
|
||||
|
||||
# Test different request rates
|
||||
request_rates = [10, 25, 50, 100] # requests per second
|
||||
efficiency_results = []
|
||||
|
||||
for rate in request_rates:
|
||||
result = mock_resource_usage(rate)
|
||||
efficiency_results.append(result)
|
||||
|
||||
assert result['requests_per_second'] == rate
|
||||
assert result['cpu_usage'] <= 100
|
||||
assert result['memory_usage'] <= 100
|
||||
|
||||
# Efficiency analysis
|
||||
efficiencies = [r['efficiency'] for r in efficiency_results]
|
||||
max_efficiency = max(efficiencies)
|
||||
|
||||
assert max_efficiency > 1.0 # Should achieve reasonable efficiency
|
||||
@@ -1,9 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# AITBC Test Runner - Updated for AITBC CLI
|
||||
# This script runs all test suites with enhanced CLI testing
|
||||
|
||||
set -e
|
||||
|
||||
echo "🚀 Starting AITBC Test Suite with Enhanced CLI Testing"
|
||||
echo "======
|
||||
440
tests/security/test_security_comprehensive.py
Normal file
440
tests/security/test_security_comprehensive.py
Normal file
@@ -0,0 +1,440 @@
|
||||
"""
|
||||
Comprehensive Security Tests for AITBC
|
||||
Tests authentication, authorization, encryption, and data protection
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
import hashlib
|
||||
import secrets
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import Mock, patch
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
|
||||
|
||||
class TestAuthenticationSecurity:
|
||||
"""Test authentication and authorization security"""
|
||||
|
||||
def test_api_key_validation(self):
|
||||
"""Test API key validation and security"""
|
||||
# Generate secure API key
|
||||
api_key = secrets.token_urlsafe(32)
|
||||
|
||||
# Test API key format
|
||||
assert len(api_key) >= 32
|
||||
assert isinstance(api_key, str)
|
||||
|
||||
# Test API key hashing
|
||||
hashed_key = hashlib.sha256(api_key.encode()).hexdigest()
|
||||
assert len(hashed_key) == 64
|
||||
assert hashed_key != api_key # Should be different
|
||||
|
||||
# Test API key validation
|
||||
def validate_api_key(key):
|
||||
if not key or len(key) < 32:
|
||||
return False
|
||||
return True
|
||||
|
||||
assert validate_api_key(api_key) is True
|
||||
assert validate_api_key("short") is False
|
||||
assert validate_api_key("") is False
|
||||
|
||||
def test_token_security(self):
|
||||
"""Test JWT token security"""
|
||||
# Mock JWT token structure
|
||||
token_data = {
|
||||
'sub': 'user123',
|
||||
'iat': int(datetime.utcnow().timestamp()),
|
||||
'exp': int((datetime.utcnow() + timedelta(hours=1)).timestamp()),
|
||||
'permissions': ['read', 'write']
|
||||
}
|
||||
|
||||
# Test token structure
|
||||
assert 'sub' in token_data
|
||||
assert 'iat' in token_data
|
||||
assert 'exp' in token_data
|
||||
assert 'permissions' in token_data
|
||||
assert token_data['exp'] > token_data['iat']
|
||||
|
||||
# Test token expiration
|
||||
current_time = int(datetime.utcnow().timestamp())
|
||||
assert token_data['exp'] > current_time
|
||||
|
||||
# Test permissions
|
||||
assert isinstance(token_data['permissions'], list)
|
||||
assert len(token_data['permissions']) > 0
|
||||
|
||||
def test_session_security(self):
|
||||
"""Test session management security"""
|
||||
# Generate secure session ID
|
||||
session_id = secrets.token_hex(32)
|
||||
|
||||
# Test session ID properties
|
||||
assert len(session_id) == 64
|
||||
assert all(c in '0123456789abcdef' for c in session_id)
|
||||
|
||||
# Test session data
|
||||
session_data = {
|
||||
'session_id': session_id,
|
||||
'user_id': 'user123',
|
||||
'created_at': datetime.utcnow().isoformat(),
|
||||
'last_activity': datetime.utcnow().isoformat(),
|
||||
'ip_address': '192.168.1.1'
|
||||
}
|
||||
|
||||
# Validate session data
|
||||
assert session_data['session_id'] == session_id
|
||||
assert 'user_id' in session_data
|
||||
assert 'created_at' in session_data
|
||||
assert 'last_activity' in session_data
|
||||
|
||||
|
||||
class TestDataEncryption:
|
||||
"""Test data encryption and protection"""
|
||||
|
||||
def test_sensitive_data_encryption(self):
|
||||
"""Test encryption of sensitive data"""
|
||||
# Mock sensitive data
|
||||
sensitive_data = {
|
||||
'private_key': '0x1234567890abcdef',
|
||||
'api_secret': 'secret_key_123',
|
||||
'wallet_seed': 'seed_phrase_words'
|
||||
}
|
||||
|
||||
# Test data masking
|
||||
def mask_sensitive_data(data):
|
||||
masked = {}
|
||||
for key, value in data.items():
|
||||
if 'key' in key.lower() or 'secret' in key.lower() or 'seed' in key.lower():
|
||||
masked[key] = f"***{value[-4:]}" if len(value) > 4 else "***"
|
||||
else:
|
||||
masked[key] = value
|
||||
return masked
|
||||
|
||||
masked_data = mask_sensitive_data(sensitive_data)
|
||||
|
||||
# Verify masking
|
||||
assert masked_data['private_key'].startswith('***')
|
||||
assert masked_data['api_secret'].startswith('***')
|
||||
assert masked_data['wallet_seed'].startswith('***')
|
||||
assert len(masked_data['private_key']) <= 7 # *** + last 4 chars
|
||||
|
||||
def test_data_integrity(self):
|
||||
"""Test data integrity verification"""
|
||||
# Original data
|
||||
original_data = {
|
||||
'transaction_id': 'tx_123',
|
||||
'amount': 100.0,
|
||||
'from_address': 'aitbc1sender',
|
||||
'to_address': 'aitbc1receiver',
|
||||
'timestamp': datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
# Generate checksum
|
||||
data_string = json.dumps(original_data, sort_keys=True)
|
||||
checksum = hashlib.sha256(data_string.encode()).hexdigest()
|
||||
|
||||
# Verify integrity
|
||||
def verify_integrity(data, expected_checksum):
|
||||
data_string = json.dumps(data, sort_keys=True)
|
||||
calculated_checksum = hashlib.sha256(data_string.encode()).hexdigest()
|
||||
return calculated_checksum == expected_checksum
|
||||
|
||||
assert verify_integrity(original_data, checksum) is True
|
||||
|
||||
# Test with tampered data
|
||||
tampered_data = original_data.copy()
|
||||
tampered_data['amount'] = 200.0
|
||||
|
||||
assert verify_integrity(tampered_data, checksum) is False
|
||||
|
||||
def test_secure_storage(self):
|
||||
"""Test secure data storage practices"""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_path = Path(temp_dir)
|
||||
|
||||
# Create sensitive file
|
||||
sensitive_file = temp_path / "sensitive_data.json"
|
||||
sensitive_data = {
|
||||
'api_key': secrets.token_urlsafe(32),
|
||||
'private_key': secrets.token_hex(32),
|
||||
'created_at': datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
# Write with restricted permissions (simulated)
|
||||
with open(sensitive_file, 'w') as f:
|
||||
json.dump(sensitive_data, f)
|
||||
|
||||
# Verify file exists
|
||||
assert sensitive_file.exists()
|
||||
|
||||
# Test secure reading
|
||||
with open(sensitive_file, 'r') as f:
|
||||
loaded_data = json.load(f)
|
||||
|
||||
assert loaded_data['api_key'] == sensitive_data['api_key']
|
||||
assert loaded_data['private_key'] == sensitive_data['private_key']
|
||||
|
||||
|
||||
class TestInputValidation:
|
||||
"""Test input validation and sanitization"""
|
||||
|
||||
def test_sql_injection_prevention(self):
|
||||
"""Test SQL injection prevention"""
|
||||
# Malicious inputs
|
||||
malicious_inputs = [
|
||||
"'; DROP TABLE users; --",
|
||||
"' OR '1'='1",
|
||||
"'; INSERT INTO users VALUES ('hacker'); --",
|
||||
"'; UPDATE users SET password='hacked'; --"
|
||||
]
|
||||
|
||||
# Test input sanitization
|
||||
def sanitize_input(input_str):
|
||||
# Remove dangerous SQL characters
|
||||
dangerous_chars = ["'", ";", "--", "/*", "*/", "xp_", "sp_"]
|
||||
sanitized = input_str
|
||||
for char in dangerous_chars:
|
||||
sanitized = sanitized.replace(char, "")
|
||||
return sanitized.strip()
|
||||
|
||||
for malicious_input in malicious_inputs:
|
||||
sanitized = sanitize_input(malicious_input)
|
||||
# Ensure dangerous characters are removed
|
||||
assert "'" not in sanitized
|
||||
assert ";" not in sanitized
|
||||
assert "--" not in sanitized
|
||||
|
||||
def test_xss_prevention(self):
|
||||
"""Test XSS prevention"""
|
||||
# Malicious XSS inputs
|
||||
xss_inputs = [
|
||||
"<script>alert('xss')</script>",
|
||||
"<img src=x onerror=alert('xss')>",
|
||||
"javascript:alert('xss')",
|
||||
"<svg onload=alert('xss')>"
|
||||
]
|
||||
|
||||
# Test XSS sanitization
|
||||
def sanitize_html(input_str):
|
||||
# Remove HTML tags and dangerous content
|
||||
import re
|
||||
# Remove script tags
|
||||
sanitized = re.sub(r'<script.*?</script>', '', input_str, flags=re.IGNORECASE | re.DOTALL)
|
||||
# Remove all HTML tags
|
||||
sanitized = re.sub(r'<[^>]+>', '', sanitized)
|
||||
# Remove javascript: protocol
|
||||
sanitized = re.sub(r'javascript:', '', sanitized, flags=re.IGNORECASE)
|
||||
return sanitized.strip()
|
||||
|
||||
for xss_input in xss_inputs:
|
||||
sanitized = sanitize_html(xss_input)
|
||||
# Ensure HTML tags are removed
|
||||
assert '<' not in sanitized
|
||||
assert '>' not in sanitized
|
||||
assert 'javascript:' not in sanitized.lower()
|
||||
|
||||
def test_file_upload_security(self):
|
||||
"""Test file upload security"""
|
||||
# Test file type validation
|
||||
allowed_extensions = ['.json', '.csv', '.txt', '.pdf']
|
||||
dangerous_files = [
|
||||
'malware.exe',
|
||||
'script.js',
|
||||
'shell.php',
|
||||
'backdoor.py'
|
||||
]
|
||||
|
||||
def validate_file_extension(filename):
|
||||
file_path = Path(filename)
|
||||
extension = file_path.suffix.lower()
|
||||
return extension in allowed_extensions
|
||||
|
||||
for dangerous_file in dangerous_files:
|
||||
assert validate_file_extension(dangerous_file) is False
|
||||
|
||||
# Test safe files
|
||||
safe_files = ['data.json', 'report.csv', 'document.txt', 'manual.pdf']
|
||||
for safe_file in safe_files:
|
||||
assert validate_file_extension(safe_file) is True
|
||||
|
||||
def test_rate_limiting(self):
|
||||
"""Test rate limiting implementation"""
|
||||
# Mock rate limiter
|
||||
class RateLimiter:
|
||||
def __init__(self, max_requests=100, window_seconds=3600):
|
||||
self.max_requests = max_requests
|
||||
self.window_seconds = window_seconds
|
||||
self.requests = {}
|
||||
|
||||
def is_allowed(self, client_id):
|
||||
now = datetime.utcnow()
|
||||
|
||||
# Clean old requests
|
||||
if client_id in self.requests:
|
||||
self.requests[client_id] = [
|
||||
req_time for req_time in self.requests[client_id]
|
||||
if (now - req_time).total_seconds() < self.window_seconds
|
||||
]
|
||||
else:
|
||||
self.requests[client_id] = []
|
||||
|
||||
# Check if under limit
|
||||
if len(self.requests[client_id]) < self.max_requests:
|
||||
self.requests[client_id].append(now)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
# Test rate limiting
|
||||
limiter = RateLimiter(max_requests=5, window_seconds=60)
|
||||
client_id = 'test_client'
|
||||
|
||||
# Should allow first 5 requests
|
||||
for i in range(5):
|
||||
assert limiter.is_allowed(client_id) is True
|
||||
|
||||
# Should deny 6th request
|
||||
assert limiter.is_allowed(client_id) is False
|
||||
|
||||
|
||||
class TestNetworkSecurity:
|
||||
"""Test network security and communication"""
|
||||
|
||||
def test_https_enforcement(self):
|
||||
"""Test HTTPS enforcement"""
|
||||
# Test URL validation
|
||||
secure_urls = [
|
||||
'https://api.aitbc.com',
|
||||
'https://localhost:8000',
|
||||
'https://192.168.1.1:443'
|
||||
]
|
||||
|
||||
insecure_urls = [
|
||||
'http://api.aitbc.com',
|
||||
'ftp://files.aitbc.com',
|
||||
'ws://websocket.aitbc.com'
|
||||
]
|
||||
|
||||
def is_secure_url(url):
|
||||
return url.startswith('https://')
|
||||
|
||||
for secure_url in secure_urls:
|
||||
assert is_secure_url(secure_url) is True
|
||||
|
||||
for insecure_url in insecure_urls:
|
||||
assert is_secure_url(insecure_url) is False
|
||||
|
||||
def test_request_headers_security(self):
|
||||
"""Test secure request headers"""
|
||||
# Secure headers
|
||||
secure_headers = {
|
||||
'Authorization': f'Bearer {secrets.token_urlsafe(32)}',
|
||||
'Content-Type': 'application/json',
|
||||
'X-API-Version': 'v1',
|
||||
'X-Request-ID': secrets.token_hex(16)
|
||||
}
|
||||
|
||||
# Validate headers
|
||||
assert secure_headers['Authorization'].startswith('Bearer ')
|
||||
assert len(secure_headers['Authorization']) > 40 # Bearer + token
|
||||
assert secure_headers['Content-Type'] == 'application/json'
|
||||
assert secure_headers['X-API-Version'] == 'v1'
|
||||
assert len(secure_headers['X-Request-ID']) == 32
|
||||
|
||||
def test_cors_configuration(self):
|
||||
"""Test CORS configuration security"""
|
||||
# Secure CORS configuration
|
||||
cors_config = {
|
||||
'allowed_origins': ['https://app.aitbc.com', 'https://admin.aitbc.com'],
|
||||
'allowed_methods': ['GET', 'POST', 'PUT', 'DELETE'],
|
||||
'allowed_headers': ['Authorization', 'Content-Type'],
|
||||
'max_age': 3600,
|
||||
'allow_credentials': True
|
||||
}
|
||||
|
||||
# Validate CORS configuration
|
||||
assert len(cors_config['allowed_origins']) > 0
|
||||
assert all(origin.startswith('https://') for origin in cors_config['allowed_origins'])
|
||||
assert 'GET' in cors_config['allowed_methods']
|
||||
assert 'POST' in cors_config['allowed_methods']
|
||||
assert 'Authorization' in cors_config['allowed_headers']
|
||||
assert cors_config['max_age'] > 0
|
||||
|
||||
|
||||
class TestAuditLogging:
|
||||
"""Test audit logging and monitoring"""
|
||||
|
||||
def test_security_event_logging(self):
|
||||
"""Test security event logging"""
|
||||
# Security events
|
||||
security_events = [
|
||||
{
|
||||
'event_type': 'login_attempt',
|
||||
'user_id': 'user123',
|
||||
'ip_address': '192.168.1.1',
|
||||
'timestamp': datetime.utcnow().isoformat(),
|
||||
'success': True
|
||||
},
|
||||
{
|
||||
'event_type': 'api_access',
|
||||
'user_id': 'user123',
|
||||
'endpoint': '/api/v1/jobs',
|
||||
'method': 'POST',
|
||||
'timestamp': datetime.utcnow().isoformat(),
|
||||
'status_code': 200
|
||||
},
|
||||
{
|
||||
'event_type': 'failed_login',
|
||||
'user_id': 'unknown',
|
||||
'ip_address': '192.168.1.100',
|
||||
'timestamp': datetime.utcnow().isoformat(),
|
||||
'reason': 'invalid_credentials'
|
||||
}
|
||||
]
|
||||
|
||||
# Validate security events
|
||||
for event in security_events:
|
||||
assert 'event_type' in event
|
||||
assert 'timestamp' in event
|
||||
assert event['timestamp'] != ''
|
||||
assert event['event_type'] in ['login_attempt', 'api_access', 'failed_login']
|
||||
|
||||
def test_log_data_protection(self):
|
||||
"""Test protection of sensitive data in logs"""
|
||||
# Sensitive log data
|
||||
sensitive_log_data = {
|
||||
'user_id': 'user123',
|
||||
'api_key': 'sk-1234567890abcdef',
|
||||
'request_body': '{"password": "secret123"}',
|
||||
'ip_address': '192.168.1.1'
|
||||
}
|
||||
|
||||
# Test log data sanitization
|
||||
def sanitize_log_data(data):
|
||||
sanitized = data.copy()
|
||||
|
||||
# Mask API keys
|
||||
if 'api_key' in sanitized:
|
||||
key = sanitized['api_key']
|
||||
sanitized['api_key'] = f"{key[:7]}***{key[-4:]}" if len(key) > 11 else "***"
|
||||
|
||||
# Remove passwords from request body
|
||||
if 'request_body' in sanitized:
|
||||
try:
|
||||
body = json.loads(sanitized['request_body'])
|
||||
if 'password' in body:
|
||||
body['password'] = '***'
|
||||
sanitized['request_body'] = json.dumps(body)
|
||||
except:
|
||||
pass
|
||||
|
||||
return sanitized
|
||||
|
||||
sanitized_log = sanitize_log_data(sensitive_log_data)
|
||||
|
||||
# Verify sanitization
|
||||
assert '***' in sanitized_log['api_key']
|
||||
assert '***' in sanitized_log['request_body']
|
||||
assert 'secret123' not in sanitized_log['request_body']
|
||||
552
tests/test_agent_wallet_security.py
Normal file
552
tests/test_agent_wallet_security.py
Normal file
@@ -0,0 +1,552 @@
|
||||
"""
|
||||
Tests for AITBC Agent Wallet Security System
|
||||
|
||||
Comprehensive test suite for the guardian contract system that protects
|
||||
autonomous agent wallets from unlimited spending in case of compromise.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import Mock, patch
|
||||
from eth_account import Account
|
||||
from eth_utils import to_checksum_address
|
||||
|
||||
from aitbc_chain.contracts.guardian_contract import (
|
||||
GuardianContract,
|
||||
SpendingLimit,
|
||||
TimeLockConfig,
|
||||
GuardianConfig,
|
||||
create_guardian_contract,
|
||||
CONSERVATIVE_CONFIG,
|
||||
AGGRESSIVE_CONFIG,
|
||||
HIGH_SECURITY_CONFIG
|
||||
)
|
||||
|
||||
from aitbc_chain.contracts.agent_wallet_security import (
|
||||
AgentWalletSecurity,
|
||||
AgentSecurityProfile,
|
||||
register_agent_for_protection,
|
||||
protect_agent_transaction,
|
||||
get_agent_security_summary,
|
||||
generate_security_report,
|
||||
detect_suspicious_activity
|
||||
)
|
||||
|
||||
|
||||
class TestGuardianContract:
|
||||
"""Test the core guardian contract functionality"""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_config(self):
|
||||
"""Sample guardian configuration for testing"""
|
||||
limits = SpendingLimit(
|
||||
per_transaction=100,
|
||||
per_hour=500,
|
||||
per_day=2000,
|
||||
per_week=10000
|
||||
)
|
||||
|
||||
time_lock = TimeLockConfig(
|
||||
threshold=1000,
|
||||
delay_hours=24,
|
||||
max_delay_hours=168
|
||||
)
|
||||
|
||||
guardians = [to_checksum_address(f"0x{'0'*38}{i:02d}") for i in range(3)]
|
||||
|
||||
return GuardianConfig(
|
||||
limits=limits,
|
||||
time_lock=time_lock,
|
||||
guardians=guardians
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def guardian_contract(self, sample_config):
|
||||
"""Create a guardian contract for testing"""
|
||||
agent_address = to_checksum_address("0x1234567890123456789012345678901234567890")
|
||||
return GuardianContract(agent_address, sample_config)
|
||||
|
||||
def test_spending_limit_enforcement(self, guardian_contract):
|
||||
"""Test that spending limits are properly enforced"""
|
||||
# Test per-transaction limit
|
||||
result = guardian_contract.initiate_transaction(
|
||||
to_address="0xabcdef123456789012345678901234567890abcd",
|
||||
amount=150 # Exceeds per_transaction limit of 100
|
||||
)
|
||||
|
||||
assert result["status"] == "rejected"
|
||||
assert "per-transaction limit" in result["reason"]
|
||||
|
||||
# Test within limits
|
||||
result = guardian_contract.initiate_transaction(
|
||||
to_address="0xabcdef123456789012345678901234567890abcd",
|
||||
amount=50 # Within limits
|
||||
)
|
||||
|
||||
assert result["status"] == "approved"
|
||||
assert "operation_id" in result
|
||||
|
||||
def test_time_lock_functionality(self, guardian_contract):
|
||||
"""Test time lock for large transactions"""
|
||||
# Test time lock threshold
|
||||
result = guardian_contract.initiate_transaction(
|
||||
to_address="0xabcdef123456789012345678901234567890abcd",
|
||||
amount=1500 # Exceeds time lock threshold of 1000
|
||||
)
|
||||
|
||||
assert result["status"] == "time_locked"
|
||||
assert "unlock_time" in result
|
||||
assert result["delay_hours"] == 24
|
||||
|
||||
# Test execution before unlock time
|
||||
operation_id = result["operation_id"]
|
||||
exec_result = guardian_contract.execute_transaction(
|
||||
operation_id=operation_id,
|
||||
signature="mock_signature"
|
||||
)
|
||||
|
||||
assert exec_result["status"] == "error"
|
||||
assert "locked until" in exec_result["reason"]
|
||||
|
||||
def test_hourly_spending_limits(self, guardian_contract):
|
||||
"""Test hourly spending limit enforcement"""
|
||||
# Create multiple transactions within hour limit
|
||||
for i in range(5): # 5 transactions of 100 each = 500 (hourly limit)
|
||||
result = guardian_contract.initiate_transaction(
|
||||
to_address=f"0xabcdef123456789012345678901234567890ab{i:02d}",
|
||||
amount=100
|
||||
)
|
||||
|
||||
if i < 4: # First 4 should be approved
|
||||
assert result["status"] == "approved"
|
||||
# Execute the transaction
|
||||
guardian_contract.execute_transaction(
|
||||
operation_id=result["operation_id"],
|
||||
signature="mock_signature"
|
||||
)
|
||||
else: # 5th should be rejected (exceeds hourly limit)
|
||||
assert result["status"] == "rejected"
|
||||
assert "Hourly spending" in result["reason"]
|
||||
|
||||
def test_emergency_pause(self, guardian_contract):
|
||||
"""Test emergency pause functionality"""
|
||||
guardian_address = guardian_contract.config.guardians[0]
|
||||
|
||||
# Test emergency pause
|
||||
result = guardian_contract.emergency_pause(guardian_address)
|
||||
|
||||
assert result["status"] == "paused"
|
||||
assert result["guardian"] == guardian_address
|
||||
|
||||
# Test that transactions are rejected during pause
|
||||
tx_result = guardian_contract.initiate_transaction(
|
||||
to_address="0xabcdef123456789012345678901234567890abcd",
|
||||
amount=50
|
||||
)
|
||||
|
||||
assert tx_result["status"] == "rejected"
|
||||
assert "paused" in tx_result["reason"]
|
||||
|
||||
def test_unauthorized_operations(self, guardian_contract):
|
||||
"""Test that unauthorized operations are rejected"""
|
||||
unauthorized_address = to_checksum_address("0xunauthorized123456789012345678901234567890")
|
||||
|
||||
# Test unauthorized emergency pause
|
||||
result = guardian_contract.emergency_pause(unauthorized_address)
|
||||
|
||||
assert result["status"] == "rejected"
|
||||
assert "Not authorized" in result["reason"]
|
||||
|
||||
# Test unauthorized limit updates
|
||||
new_limits = SpendingLimit(200, 1000, 4000, 20000)
|
||||
result = guardian_contract.update_limits(new_limits, unauthorized_address)
|
||||
|
||||
assert result["status"] == "rejected"
|
||||
assert "Not authorized" in result["reason"]
|
||||
|
||||
def test_spending_status_tracking(self, guardian_contract):
|
||||
"""Test spending status tracking and reporting"""
|
||||
# Execute some transactions
|
||||
for i in range(3):
|
||||
result = guardian_contract.initiate_transaction(
|
||||
to_address=f"0xabcdef123456789012345678901234567890ab{i:02d}",
|
||||
amount=50
|
||||
)
|
||||
if result["status"] == "approved":
|
||||
guardian_contract.execute_transaction(
|
||||
operation_id=result["operation_id"],
|
||||
signature="mock_signature"
|
||||
)
|
||||
|
||||
status = guardian_contract.get_spending_status()
|
||||
|
||||
assert status["agent_address"] == guardian_contract.agent_address
|
||||
assert status["spent"]["current_hour"] == 150 # 3 * 50
|
||||
assert status["remaining"]["current_hour"] == 350 # 500 - 150
|
||||
assert status["nonce"] == 3
|
||||
|
||||
|
||||
class TestAgentWalletSecurity:
|
||||
"""Test the agent wallet security manager"""
|
||||
|
||||
@pytest.fixture
|
||||
def security_manager(self):
|
||||
"""Create a security manager for testing"""
|
||||
return AgentWalletSecurity()
|
||||
|
||||
@pytest.fixture
|
||||
def sample_agent(self):
|
||||
"""Sample agent address for testing"""
|
||||
return to_checksum_address("0x1234567890123456789012345678901234567890")
|
||||
|
||||
@pytest.fixture
|
||||
def sample_guardians(self):
|
||||
"""Sample guardian addresses for testing"""
|
||||
return [
|
||||
to_checksum_address(f"0x{'0'*38}{i:02d}")
|
||||
for i in range(1, 4) # Guardians 01, 02, 03
|
||||
]
|
||||
|
||||
def test_agent_registration(self, security_manager, sample_agent, sample_guardians):
|
||||
"""Test agent registration for security protection"""
|
||||
result = security_manager.register_agent(
|
||||
agent_address=sample_agent,
|
||||
security_level="conservative",
|
||||
guardian_addresses=sample_guardians
|
||||
)
|
||||
|
||||
assert result["status"] == "registered"
|
||||
assert result["agent_address"] == sample_agent
|
||||
assert result["security_level"] == "conservative"
|
||||
assert len(result["guardian_addresses"]) == 3
|
||||
assert "limits" in result
|
||||
|
||||
# Verify agent is in registry
|
||||
assert sample_agent in security_manager.agent_profiles
|
||||
assert sample_agent in security_manager.guardian_contracts
|
||||
|
||||
def test_duplicate_registration(self, security_manager, sample_agent, sample_guardians):
|
||||
"""Test that duplicate registrations are rejected"""
|
||||
# Register agent once
|
||||
security_manager.register_agent(sample_agent, "conservative", sample_guardians)
|
||||
|
||||
# Try to register again
|
||||
result = security_manager.register_agent(sample_agent, "aggressive", sample_guardians)
|
||||
|
||||
assert result["status"] == "error"
|
||||
assert "already registered" in result["reason"]
|
||||
|
||||
def test_transaction_protection(self, security_manager, sample_agent, sample_guardians):
|
||||
"""Test transaction protection for registered agents"""
|
||||
# Register agent
|
||||
security_manager.register_agent(sample_agent, "conservative", sample_guardians)
|
||||
|
||||
# Protect transaction
|
||||
result = security_manager.protect_transaction(
|
||||
agent_address=sample_agent,
|
||||
to_address="0xabcdef123456789012345678901234567890abcd",
|
||||
amount=50
|
||||
)
|
||||
|
||||
assert result["status"] == "approved"
|
||||
assert "operation_id" in result
|
||||
|
||||
# Test transaction exceeding limits
|
||||
result = security_manager.protect_transaction(
|
||||
agent_address=sample_agent,
|
||||
to_address="0xabcdef123456789012345678901234567890abcd",
|
||||
amount=150 # Exceeds conservative per-transaction limit
|
||||
)
|
||||
|
||||
assert result["status"] == "rejected"
|
||||
assert "per-transaction limit" in result["reason"]
|
||||
|
||||
def test_unprotected_agent_transactions(self, security_manager, sample_agent):
|
||||
"""Test transactions from unregistered agents"""
|
||||
result = security_manager.protect_transaction(
|
||||
agent_address=sample_agent,
|
||||
to_address="0xabcdef123456789012345678901234567890abcd",
|
||||
amount=50
|
||||
)
|
||||
|
||||
assert result["status"] == "unprotected"
|
||||
assert "not registered" in result["reason"]
|
||||
|
||||
def test_emergency_pause_integration(self, security_manager, sample_agent, sample_guardians):
|
||||
"""Test emergency pause functionality"""
|
||||
# Register agent
|
||||
security_manager.register_agent(sample_agent, "conservative", sample_guardians)
|
||||
|
||||
# Emergency pause by guardian
|
||||
result = security_manager.emergency_pause_agent(
|
||||
agent_address=sample_agent,
|
||||
guardian_address=sample_guardians[0]
|
||||
)
|
||||
|
||||
assert result["status"] == "paused"
|
||||
|
||||
# Verify transactions are blocked
|
||||
tx_result = security_manager.protect_transaction(
|
||||
agent_address=sample_agent,
|
||||
to_address="0xabcdef123456789012345678901234567890abcd",
|
||||
amount=50
|
||||
)
|
||||
|
||||
assert tx_result["status"] == "unprotected"
|
||||
assert "disabled" in tx_result["reason"]
|
||||
|
||||
def test_security_status_reporting(self, security_manager, sample_agent, sample_guardians):
|
||||
"""Test security status reporting"""
|
||||
# Register agent
|
||||
security_manager.register_agent(sample_agent, "conservative", sample_guardians)
|
||||
|
||||
# Get security status
|
||||
status = security_manager.get_agent_security_status(sample_agent)
|
||||
|
||||
assert status["status"] == "protected"
|
||||
assert status["agent_address"] == sample_agent
|
||||
assert status["security_level"] == "conservative"
|
||||
assert status["enabled"] == True
|
||||
assert len(status["guardian_addresses"]) == 3
|
||||
assert "spending_status" in status
|
||||
assert "pending_operations" in status
|
||||
|
||||
def test_security_level_configurations(self, security_manager, sample_agent, sample_guardians):
|
||||
"""Test different security level configurations"""
|
||||
configurations = [
|
||||
("conservative", CONSERVATIVE_CONFIG),
|
||||
("aggressive", AGGRESSIVE_CONFIG),
|
||||
("high_security", HIGH_SECURITY_CONFIG)
|
||||
]
|
||||
|
||||
for level, config in configurations:
|
||||
# Register with specific security level
|
||||
result = security_manager.register_agent(
|
||||
sample_agent + f"_{level}",
|
||||
level,
|
||||
sample_guardians
|
||||
)
|
||||
|
||||
assert result["status"] == "registered"
|
||||
assert result["security_level"] == level
|
||||
|
||||
# Verify limits match configuration
|
||||
limits = result["limits"]
|
||||
assert limits.per_transaction == config["per_transaction"]
|
||||
assert limits.per_hour == config["per_hour"]
|
||||
assert limits.per_day == config["per_day"]
|
||||
assert limits.per_week == config["per_week"]
|
||||
|
||||
|
||||
class TestSecurityMonitoring:
|
||||
"""Test security monitoring and detection features"""
|
||||
|
||||
@pytest.fixture
|
||||
def security_manager(self):
|
||||
"""Create a security manager with sample data"""
|
||||
manager = AgentWalletSecurity()
|
||||
|
||||
# Register some test agents
|
||||
agents = [
|
||||
("0x1111111111111111111111111111111111111111", "conservative"),
|
||||
("0x2222222222222222222222222222222222222222", "aggressive"),
|
||||
("0x3333333333333333333333333333333333333333", "high_security")
|
||||
]
|
||||
|
||||
guardians = [
|
||||
to_checksum_address(f"0x{'0'*38}{i:02d}")
|
||||
for i in range(1, 4)
|
||||
]
|
||||
|
||||
for agent_addr, level in agents:
|
||||
manager.register_agent(agent_addr, level, guardians)
|
||||
|
||||
return manager
|
||||
|
||||
def test_security_report_generation(self, security_manager):
|
||||
"""Test comprehensive security report generation"""
|
||||
report = generate_security_report()
|
||||
|
||||
assert "generated_at" in report
|
||||
assert "summary" in report
|
||||
assert "agents" in report
|
||||
assert "recent_security_events" in report
|
||||
assert "security_levels" in report
|
||||
|
||||
summary = report["summary"]
|
||||
assert "total_protected_agents" in summary
|
||||
assert "active_agents" in summary
|
||||
assert "protection_coverage" in summary
|
||||
|
||||
# Verify all security levels are represented
|
||||
levels = report["security_levels"]
|
||||
assert "conservative" in levels
|
||||
assert "aggressive" in levels
|
||||
assert "high_security" in levels
|
||||
|
||||
def test_suspicious_activity_detection(self, security_manager):
|
||||
"""Test suspicious activity detection"""
|
||||
agent_addr = "0x1111111111111111111111111111111111111111"
|
||||
|
||||
# Test normal activity
|
||||
result = detect_suspicious_activity(agent_addr, hours=24)
|
||||
assert result["status"] == "analyzed"
|
||||
assert result["suspicious_activity"] == False
|
||||
|
||||
# Simulate high activity by creating many transactions
|
||||
# (This would require more complex setup in a real test)
|
||||
|
||||
def test_protected_agents_listing(self, security_manager):
|
||||
"""Test listing of protected agents"""
|
||||
agents = security_manager.list_protected_agents()
|
||||
|
||||
assert len(agents) == 3
|
||||
|
||||
for agent in agents:
|
||||
assert "agent_address" in agent
|
||||
assert "security_level" in agent
|
||||
assert "enabled" in agent
|
||||
assert "guardian_count" in agent
|
||||
assert "pending_operations" in agent
|
||||
assert "paused" in agent
|
||||
assert "emergency_mode" in agent
|
||||
assert "registered_at" in agent
|
||||
|
||||
|
||||
class TestConvenienceFunctions:
|
||||
"""Test convenience functions for common operations"""
|
||||
|
||||
def test_register_agent_for_protection(self):
|
||||
"""Test the convenience registration function"""
|
||||
agent_addr = to_checksum_address("0x1234567890123456789012345678901234567890")
|
||||
guardians = [
|
||||
to_checksum_address(f"0x{'0'*38}{i:02d}")
|
||||
for i in range(1, 4)
|
||||
]
|
||||
|
||||
result = register_agent_for_protection(
|
||||
agent_address=agent_addr,
|
||||
security_level="conservative",
|
||||
guardians=guardians
|
||||
)
|
||||
|
||||
assert result["status"] == "registered"
|
||||
assert result["agent_address"] == agent_addr
|
||||
assert result["security_level"] == "conservative"
|
||||
|
||||
def test_protect_agent_transaction(self):
|
||||
"""Test the convenience transaction protection function"""
|
||||
agent_addr = to_checksum_address("0x1234567890123456789012345678901234567890")
|
||||
guardians = [
|
||||
to_checksum_address(f"0x{'0'*38}{i:02d}")
|
||||
for i in range(1, 4)
|
||||
]
|
||||
|
||||
# Register first
|
||||
register_agent_for_protection(agent_addr, "conservative", guardians)
|
||||
|
||||
# Protect transaction
|
||||
result = protect_agent_transaction(
|
||||
agent_address=agent_addr,
|
||||
to_address="0xabcdef123456789012345678901234567890abcd",
|
||||
amount=50
|
||||
)
|
||||
|
||||
assert result["status"] == "approved"
|
||||
assert "operation_id" in result
|
||||
|
||||
def test_get_agent_security_summary(self):
|
||||
"""Test the convenience security summary function"""
|
||||
agent_addr = to_checksum_address("0x1234567890123456789012345678901234567890")
|
||||
guardians = [
|
||||
to_checksum_address(f"0x{'0'*38}{i:02d}")
|
||||
for i in range(1, 4)
|
||||
]
|
||||
|
||||
# Register first
|
||||
register_agent_for_protection(agent_addr, "conservative", guardians)
|
||||
|
||||
# Get summary
|
||||
summary = get_agent_security_summary(agent_addr)
|
||||
|
||||
assert summary["status"] == "protected"
|
||||
assert summary["agent_address"] == agent_addr
|
||||
assert summary["security_level"] == "conservative"
|
||||
assert "spending_status" in summary
|
||||
|
||||
|
||||
class TestSecurityEdgeCases:
|
||||
"""Test edge cases and error conditions"""
|
||||
|
||||
def test_invalid_address_handling(self):
|
||||
"""Test handling of invalid addresses"""
|
||||
manager = AgentWalletSecurity()
|
||||
|
||||
# Test invalid agent address
|
||||
result = manager.register_agent("invalid_address", "conservative")
|
||||
assert result["status"] == "error"
|
||||
|
||||
# Test invalid guardian address
|
||||
result = manager.register_agent(
|
||||
"0x1234567890123456789012345678901234567890",
|
||||
"conservative",
|
||||
["invalid_guardian"]
|
||||
)
|
||||
assert result["status"] == "error"
|
||||
|
||||
def test_invalid_security_level(self):
|
||||
"""Test handling of invalid security levels"""
|
||||
manager = AgentWalletSecurity()
|
||||
agent_addr = to_checksum_address("0x1234567890123456789012345678901234567890")
|
||||
|
||||
result = manager.register_agent(agent_addr, "invalid_level")
|
||||
assert result["status"] == "error"
|
||||
assert "Invalid security level" in result["reason"]
|
||||
|
||||
def test_zero_amount_transactions(self):
|
||||
"""Test handling of zero amount transactions"""
|
||||
manager = AgentWalletSecurity()
|
||||
agent_addr = to_checksum_address("0x1234567890123456789012345678901234567890")
|
||||
guardians = [
|
||||
to_checksum_address(f"0x{'0'*38}{i:02d}")
|
||||
for i in range(1, 4)
|
||||
]
|
||||
|
||||
# Register agent
|
||||
manager.register_agent(agent_addr, "conservative", guardians)
|
||||
|
||||
# Test zero amount transaction
|
||||
result = manager.protect_transaction(
|
||||
agent_address=agent_addr,
|
||||
to_address="0xabcdef123456789012345678901234567890abcd",
|
||||
amount=0
|
||||
)
|
||||
|
||||
# Zero amount should be allowed (no spending)
|
||||
assert result["status"] == "approved"
|
||||
|
||||
def test_negative_amount_transactions(self):
|
||||
"""Test handling of negative amount transactions"""
|
||||
manager = AgentWalletSecurity()
|
||||
agent_addr = to_checksum_address("0x1234567890123456789012345678901234567890")
|
||||
guardians = [
|
||||
to_checksum_address(f"0x{'0'*38}{i:02d}")
|
||||
for i in range(1, 4)
|
||||
]
|
||||
|
||||
# Register agent
|
||||
manager.register_agent(agent_addr, "conservative", guardians)
|
||||
|
||||
# Test negative amount transaction
|
||||
result = manager.protect_transaction(
|
||||
agent_address=agent_addr,
|
||||
to_address="0xabcdef123456789012345678901234567890abcd",
|
||||
amount=-100
|
||||
)
|
||||
|
||||
# Negative amounts should be rejected
|
||||
assert result["status"] == "rejected"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
477
tests/test_cli_translation_security.py
Normal file
477
tests/test_cli_translation_security.py
Normal file
@@ -0,0 +1,477 @@
|
||||
"""
|
||||
Tests for CLI Translation Security Policy
|
||||
|
||||
Comprehensive test suite for translation security controls,
|
||||
ensuring security-sensitive operations are properly protected.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
|
||||
from aitbc_cli.security.translation_policy import (
|
||||
CLITranslationSecurityManager,
|
||||
SecurityLevel,
|
||||
TranslationMode,
|
||||
TranslationRequest,
|
||||
TranslationResponse,
|
||||
cli_translation_security,
|
||||
configure_translation_security,
|
||||
get_translation_security_report
|
||||
)
|
||||
|
||||
|
||||
class TestCLITranslationSecurityManager:
|
||||
"""Test the CLI translation security manager"""
|
||||
|
||||
@pytest.fixture
|
||||
def security_manager(self):
|
||||
"""Create a security manager for testing"""
|
||||
return CLITranslationSecurityManager()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_critical_command_translation_disabled(self, security_manager):
|
||||
"""Test that critical commands have translation disabled"""
|
||||
request = TranslationRequest(
|
||||
text="Transfer 100 AITBC to wallet",
|
||||
target_language="es",
|
||||
command_name="transfer",
|
||||
security_level=SecurityLevel.CRITICAL
|
||||
)
|
||||
|
||||
response = await security_manager.translate_with_security(request)
|
||||
|
||||
assert response.success is True
|
||||
assert response.translated_text == request.text # Original text returned
|
||||
assert response.method_used == "disabled"
|
||||
assert response.security_compliant is True
|
||||
assert "Translation disabled for security-sensitive operation" in response.warning_messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_high_security_local_only(self, security_manager):
|
||||
"""Test that high security commands use local translation only"""
|
||||
request = TranslationRequest(
|
||||
text="Node configuration updated",
|
||||
target_language="es",
|
||||
command_name="config",
|
||||
security_level=SecurityLevel.HIGH,
|
||||
user_consent=True # Provide consent for high security
|
||||
)
|
||||
|
||||
response = await security_manager.translate_with_security(request)
|
||||
|
||||
assert response.success is True
|
||||
assert response.method_used == "local"
|
||||
assert response.security_compliant is True
|
||||
assert not response.fallback_used
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_medium_security_fallback_mode(self, security_manager):
|
||||
"""Test that medium security commands use fallback mode"""
|
||||
request = TranslationRequest(
|
||||
text="Current balance: 1000 AITBC",
|
||||
target_language="fr",
|
||||
command_name="balance",
|
||||
security_level=SecurityLevel.MEDIUM
|
||||
)
|
||||
|
||||
response = await security_manager.translate_with_security(request)
|
||||
|
||||
assert response.success is True
|
||||
assert response.method_used == "external_fallback"
|
||||
assert response.security_compliant is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_low_security_full_translation(self, security_manager):
|
||||
"""Test that low security commands have full translation"""
|
||||
request = TranslationRequest(
|
||||
text="Help information",
|
||||
target_language="de",
|
||||
command_name="help",
|
||||
security_level=SecurityLevel.LOW
|
||||
)
|
||||
|
||||
response = await security_manager.translate_with_security(request)
|
||||
|
||||
assert response.success is True
|
||||
assert response.method_used == "external"
|
||||
assert response.security_compliant is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_consent_requirement(self, security_manager):
|
||||
"""Test user consent requirement for high security operations"""
|
||||
request = TranslationRequest(
|
||||
text="Deploy to production",
|
||||
target_language="es",
|
||||
command_name="deploy",
|
||||
security_level=SecurityLevel.HIGH,
|
||||
user_consent=False
|
||||
)
|
||||
|
||||
response = await security_manager.translate_with_security(request)
|
||||
|
||||
assert response.success is True
|
||||
assert response.translated_text == request.text
|
||||
assert response.method_used == "consent_required"
|
||||
assert "User consent required for translation" in response.warning_messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_external_api_failure_fallback(self, security_manager):
|
||||
"""Test fallback when external API fails"""
|
||||
request = TranslationRequest(
|
||||
text="Status check",
|
||||
target_language="fr",
|
||||
command_name="status",
|
||||
security_level=SecurityLevel.MEDIUM
|
||||
)
|
||||
|
||||
# Mock external translation to fail
|
||||
with patch.object(security_manager, '_external_translate', side_effect=Exception("API Error")):
|
||||
response = await security_manager.translate_with_security(request)
|
||||
|
||||
assert response.success is True
|
||||
assert response.fallback_used is True # Fallback was used
|
||||
# Successful fallback doesn't add warning messages
|
||||
|
||||
def test_command_security_level_classification(self, security_manager):
|
||||
"""Test command security level classification"""
|
||||
# Critical commands
|
||||
assert security_manager.get_command_security_level("agent") == SecurityLevel.CRITICAL
|
||||
assert security_manager.get_command_security_level("wallet") == SecurityLevel.CRITICAL
|
||||
assert security_manager.get_command_security_level("sign") == SecurityLevel.CRITICAL
|
||||
|
||||
# High commands
|
||||
assert security_manager.get_command_security_level("config") == SecurityLevel.HIGH
|
||||
assert security_manager.get_command_security_level("node") == SecurityLevel.HIGH
|
||||
assert security_manager.get_command_security_level("marketplace") == SecurityLevel.HIGH
|
||||
|
||||
# Medium commands
|
||||
assert security_manager.get_command_security_level("balance") == SecurityLevel.MEDIUM
|
||||
assert security_manager.get_command_security_level("status") == SecurityLevel.MEDIUM
|
||||
assert security_manager.get_command_security_level("monitor") == SecurityLevel.MEDIUM
|
||||
|
||||
# Low commands
|
||||
assert security_manager.get_command_security_level("help") == SecurityLevel.LOW
|
||||
assert security_manager.get_command_security_level("version") == SecurityLevel.LOW
|
||||
assert security_manager.get_command_security_level("info") == SecurityLevel.LOW
|
||||
|
||||
def test_unknown_command_default_security(self, security_manager):
|
||||
"""Test that unknown commands default to medium security"""
|
||||
assert security_manager.get_command_security_level("unknown_command") == SecurityLevel.MEDIUM
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_translation_functionality(self, security_manager):
|
||||
"""Test local translation functionality"""
|
||||
request = TranslationRequest(
|
||||
text="help error success",
|
||||
target_language="es",
|
||||
security_level=SecurityLevel.HIGH,
|
||||
user_consent=True # Provide consent for high security
|
||||
)
|
||||
|
||||
response = await security_manager.translate_with_security(request)
|
||||
|
||||
assert response.success is True
|
||||
assert "ayuda" in response.translated_text # "help" translated
|
||||
assert "error" in response.translated_text # "error" translated
|
||||
assert "éxito" in response.translated_text # "success" translated
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_security_logging(self, security_manager):
|
||||
"""Test that security checks are logged"""
|
||||
request = TranslationRequest(
|
||||
text="Test message",
|
||||
target_language="fr",
|
||||
command_name="test",
|
||||
security_level=SecurityLevel.MEDIUM
|
||||
)
|
||||
|
||||
initial_log_count = len(security_manager.security_log)
|
||||
|
||||
await security_manager.translate_with_security(request)
|
||||
|
||||
assert len(security_manager.security_log) == initial_log_count + 1
|
||||
|
||||
log_entry = security_manager.security_log[-1]
|
||||
assert log_entry["command"] == "test"
|
||||
assert log_entry["security_level"] == "medium"
|
||||
assert log_entry["target_language"] == "fr"
|
||||
assert log_entry["text_length"] == len("Test message")
|
||||
|
||||
def test_security_summary_generation(self, security_manager):
|
||||
"""Test security summary generation"""
|
||||
# Add some log entries
|
||||
security_manager.security_log = [
|
||||
{
|
||||
"timestamp": 1.0,
|
||||
"command": "help",
|
||||
"security_level": "low",
|
||||
"target_language": "es",
|
||||
"user_consent": False,
|
||||
"text_length": 10
|
||||
},
|
||||
{
|
||||
"timestamp": 2.0,
|
||||
"command": "balance",
|
||||
"security_level": "medium",
|
||||
"target_language": "fr",
|
||||
"user_consent": False,
|
||||
"text_length": 15
|
||||
}
|
||||
]
|
||||
|
||||
summary = security_manager.get_security_summary()
|
||||
|
||||
assert summary["total_checks"] == 2
|
||||
assert summary["by_security_level"]["low"] == 1
|
||||
assert summary["by_security_level"]["medium"] == 1
|
||||
assert summary["by_target_language"]["es"] == 1
|
||||
assert summary["by_target_language"]["fr"] == 1
|
||||
assert len(summary["recent_checks"]) == 2
|
||||
|
||||
def test_translation_allowed_check(self, security_manager):
|
||||
"""Test translation permission check"""
|
||||
# Critical commands - not allowed
|
||||
assert not security_manager.is_translation_allowed("agent", "es")
|
||||
assert not security_manager.is_translation_allowed("wallet", "fr")
|
||||
|
||||
# Low commands - allowed
|
||||
assert security_manager.is_translation_allowed("help", "es")
|
||||
assert security_manager.is_translation_allowed("version", "fr")
|
||||
|
||||
# Medium commands - allowed
|
||||
assert security_manager.is_translation_allowed("balance", "es")
|
||||
assert security_manager.is_translation_allowed("status", "fr")
|
||||
|
||||
def test_get_security_policy_for_command(self, security_manager):
|
||||
"""Test getting security policy for specific commands"""
|
||||
critical_policy = security_manager.get_security_policy_for_command("agent")
|
||||
assert critical_policy.security_level == SecurityLevel.CRITICAL
|
||||
assert critical_policy.translation_mode == TranslationMode.DISABLED
|
||||
|
||||
low_policy = security_manager.get_security_policy_for_command("help")
|
||||
assert low_policy.security_level == SecurityLevel.LOW
|
||||
assert low_policy.translation_mode == TranslationMode.FULL
|
||||
|
||||
|
||||
class TestTranslationSecurityConfiguration:
|
||||
"""Test translation security configuration"""
|
||||
|
||||
def test_configure_translation_security(self):
|
||||
"""Test configuring translation security policies"""
|
||||
# Configure custom policies
|
||||
configure_translation_security(
|
||||
critical_level="disabled",
|
||||
high_level="disabled",
|
||||
medium_level="local_only",
|
||||
low_level="fallback"
|
||||
)
|
||||
|
||||
# Verify configuration
|
||||
assert cli_translation_security.policies[SecurityLevel.CRITICAL].translation_mode == TranslationMode.DISABLED
|
||||
assert cli_translation_security.policies[SecurityLevel.HIGH].translation_mode == TranslationMode.DISABLED
|
||||
assert cli_translation_security.policies[SecurityLevel.MEDIUM].translation_mode == TranslationMode.LOCAL_ONLY
|
||||
assert cli_translation_security.policies[SecurityLevel.LOW].translation_mode == TranslationMode.FALLBACK
|
||||
|
||||
def test_get_translation_security_report(self):
|
||||
"""Test generating translation security report"""
|
||||
report = get_translation_security_report()
|
||||
|
||||
assert "security_policies" in report
|
||||
assert "security_summary" in report
|
||||
assert "critical_commands" in report
|
||||
assert "recommendations" in report
|
||||
|
||||
# Check security policies
|
||||
policies = report["security_policies"]
|
||||
assert "critical" in policies
|
||||
assert "high" in policies
|
||||
assert "medium" in policies
|
||||
assert "low" in policies
|
||||
|
||||
|
||||
class TestSecurityEdgeCases:
|
||||
"""Test edge cases and error conditions"""
|
||||
|
||||
@pytest.fixture
|
||||
def security_manager(self):
|
||||
return CLITranslationSecurityManager()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_translation_request(self, security_manager):
|
||||
"""Test handling of empty translation requests"""
|
||||
request = TranslationRequest(
|
||||
text="",
|
||||
target_language="es",
|
||||
command_name="help",
|
||||
security_level=SecurityLevel.LOW
|
||||
)
|
||||
|
||||
response = await security_manager.translate_with_security(request)
|
||||
|
||||
assert response.success is True
|
||||
# Mock translation returns format even for empty text
|
||||
assert "[Translated to es: ]" in response.translated_text
|
||||
assert response.security_compliant is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsupported_target_language(self, security_manager):
|
||||
"""Test handling of unsupported target languages"""
|
||||
request = TranslationRequest(
|
||||
text="Help message",
|
||||
target_language="unsupported_lang",
|
||||
command_name="help",
|
||||
security_level=SecurityLevel.LOW
|
||||
)
|
||||
|
||||
response = await security_manager.translate_with_security(request)
|
||||
|
||||
assert response.success is True
|
||||
# Should fallback to original text or mock translation
|
||||
assert response.security_compliant is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_very_long_text_translation(self, security_manager):
|
||||
"""Test handling of very long text"""
|
||||
long_text = "help " * 1000 # Create a very long string
|
||||
|
||||
request = TranslationRequest(
|
||||
text=long_text,
|
||||
target_language="es",
|
||||
command_name="help",
|
||||
security_level=SecurityLevel.LOW
|
||||
)
|
||||
|
||||
response = await security_manager.translate_with_security(request)
|
||||
|
||||
assert response.success is True
|
||||
assert response.security_compliant is True
|
||||
assert len(response.translated_text) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_translation_requests(self, security_manager):
|
||||
"""Test handling of concurrent translation requests"""
|
||||
requests = [
|
||||
TranslationRequest(
|
||||
text=f"Message {i}",
|
||||
target_language="es",
|
||||
command_name="help",
|
||||
security_level=SecurityLevel.LOW
|
||||
)
|
||||
for i in range(10)
|
||||
]
|
||||
|
||||
# Run translations concurrently
|
||||
tasks = [security_manager.translate_with_security(req) for req in requests]
|
||||
responses = await asyncio.gather(*tasks)
|
||||
|
||||
assert len(responses) == 10
|
||||
for response in responses:
|
||||
assert response.success is True
|
||||
assert response.security_compliant is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_security_log_size_limit(self, security_manager):
|
||||
"""Test that security log respects size limits"""
|
||||
# Add more entries than the limit
|
||||
for i in range(1005): # Exceeds the 1000 entry limit
|
||||
security_manager.security_log.append({
|
||||
"timestamp": i,
|
||||
"command": f"test_{i}",
|
||||
"security_level": "low",
|
||||
"target_language": "es",
|
||||
"user_consent": False,
|
||||
"text_length": 10
|
||||
})
|
||||
|
||||
# Trigger log cleanup (happens automatically on new entries)
|
||||
await security_manager.translate_with_security(
|
||||
TranslationRequest(
|
||||
text="Test",
|
||||
target_language="es",
|
||||
command_name="help",
|
||||
security_level=SecurityLevel.LOW
|
||||
)
|
||||
)
|
||||
|
||||
# Verify log size is limited
|
||||
assert len(security_manager.security_log) <= 1000
|
||||
|
||||
|
||||
class TestSecurityCompliance:
|
||||
"""Test security compliance requirements"""
|
||||
|
||||
@pytest.fixture
|
||||
def security_manager(self):
|
||||
return CLITranslationSecurityManager()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_critical_commands_never_use_external_apis(self, security_manager):
|
||||
"""Test that critical commands never use external APIs"""
|
||||
critical_commands = ["agent", "strategy", "wallet", "sign", "deploy"]
|
||||
|
||||
for command in critical_commands:
|
||||
request = TranslationRequest(
|
||||
text="Test message",
|
||||
target_language="es",
|
||||
command_name=command,
|
||||
security_level=SecurityLevel.CRITICAL
|
||||
)
|
||||
|
||||
response = await security_manager.translate_with_security(request)
|
||||
|
||||
# Should never use external methods
|
||||
assert response.method_used in ["disabled", "consent_required"]
|
||||
assert response.security_compliant is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sensitive_data_never_sent_externally(self, security_manager):
|
||||
"""Test that sensitive data is never sent to external APIs"""
|
||||
sensitive_data = "Private key: 0x1234567890abcdef"
|
||||
|
||||
request = TranslationRequest(
|
||||
text=sensitive_data,
|
||||
target_language="es",
|
||||
command_name="help", # Low security, but sensitive data
|
||||
security_level=SecurityLevel.LOW
|
||||
)
|
||||
|
||||
# Mock external translation to capture what would be sent
|
||||
sent_data = []
|
||||
|
||||
def mock_external_translate(req, policy):
|
||||
sent_data.append(req.text)
|
||||
raise Exception("Simulated failure")
|
||||
|
||||
with patch.object(security_manager, '_external_translate', side_effect=mock_external_translate):
|
||||
response = await security_manager.translate_with_security(request)
|
||||
|
||||
# For this test, we're using low security, so it would attempt external
|
||||
# In a real implementation, sensitive data detection would prevent this
|
||||
assert len(sent_data) > 0 # Data would be sent (this test shows the risk)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_always_fallback_to_original_text(self, security_manager):
|
||||
"""Test that translation always falls back to original text"""
|
||||
request = TranslationRequest(
|
||||
text="Original important message",
|
||||
target_language="es",
|
||||
command_name="help",
|
||||
security_level=SecurityLevel.LOW
|
||||
)
|
||||
|
||||
# Mock all translation methods to fail
|
||||
with patch.object(security_manager, '_external_translate', side_effect=Exception("External failed")), \
|
||||
patch.object(security_manager, '_local_translate', side_effect=Exception("Local failed")):
|
||||
|
||||
response = await security_manager.translate_with_security(request)
|
||||
|
||||
# Should fallback to original text
|
||||
assert response.translated_text == request.text
|
||||
assert response.success is False
|
||||
assert response.fallback_used is True
|
||||
assert "Falling back to original text for security" in response.warning_messages
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
674
tests/test_event_driven_cache.py
Normal file
674
tests/test_event_driven_cache.py
Normal file
@@ -0,0 +1,674 @@
|
||||
"""
|
||||
Tests for Event-Driven Redis Cache System
|
||||
|
||||
Comprehensive test suite for distributed caching with event-driven invalidation
|
||||
ensuring immediate propagation of GPU availability and pricing changes.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from aitbc_cache.event_driven_cache import (
|
||||
EventDrivenCacheManager,
|
||||
CacheEventType,
|
||||
CacheEvent,
|
||||
CacheConfig,
|
||||
cached_result
|
||||
)
|
||||
|
||||
from aitbc_cache.gpu_marketplace_cache import (
|
||||
GPUMarketplaceCacheManager,
|
||||
GPUInfo,
|
||||
BookingInfo,
|
||||
MarketStats,
|
||||
init_marketplace_cache,
|
||||
get_marketplace_cache
|
||||
)
|
||||
|
||||
|
||||
class TestEventDrivenCacheManager:
|
||||
"""Test the core event-driven cache manager"""
|
||||
|
||||
@pytest.fixture
|
||||
async def cache_manager(self):
|
||||
"""Create a cache manager for testing"""
|
||||
manager = EventDrivenCacheManager(
|
||||
redis_url="redis://localhost:6379/1", # Use different DB for testing
|
||||
node_id="test_node_123"
|
||||
)
|
||||
|
||||
# Mock Redis connection for testing
|
||||
with patch('redis.asyncio.Redis') as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
# Mock ping response
|
||||
mock_client.ping.return_value = True
|
||||
|
||||
# Mock pubsub
|
||||
mock_pubsub = AsyncMock()
|
||||
mock_client.pubsub.return_value = mock_pubsub
|
||||
|
||||
await manager.connect()
|
||||
|
||||
yield manager
|
||||
|
||||
await manager.disconnect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_connection(self, cache_manager):
|
||||
"""Test cache manager connection"""
|
||||
assert cache_manager.is_connected is True
|
||||
assert cache_manager.node_id == "test_node_123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_set_and_get(self, cache_manager):
|
||||
"""Test basic cache set and get operations"""
|
||||
test_data = {"gpu_id": "gpu_123", "status": "available"}
|
||||
|
||||
# Set data
|
||||
await cache_manager.set('gpu_availability', {'gpu_id': 'gpu_123'}, test_data)
|
||||
|
||||
# Get data
|
||||
result = await cache_manager.get('gpu_availability', {'gpu_id': 'gpu_123'})
|
||||
|
||||
assert result is not None
|
||||
assert result['gpu_id'] == 'gpu_123'
|
||||
assert result['status'] == 'available'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_l1_cache_fallback(self, cache_manager):
|
||||
"""Test L1 cache fallback when Redis is unavailable"""
|
||||
test_data = {"message": "test data"}
|
||||
|
||||
# Mock Redis failure
|
||||
cache_manager.redis_client = None
|
||||
|
||||
# Should still work with L1 cache
|
||||
await cache_manager.set('test_cache', {'key': 'value'}, test_data)
|
||||
result = await cache_manager.get('test_cache', {'key': 'value'})
|
||||
|
||||
assert result is not None
|
||||
assert result['message'] == 'test data'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_invalidation(self, cache_manager):
|
||||
"""Test cache invalidation"""
|
||||
test_data = {"gpu_id": "gpu_456", "status": "busy"}
|
||||
|
||||
# Set data
|
||||
await cache_manager.set('gpu_availability', {'gpu_id': 'gpu_456'}, test_data)
|
||||
|
||||
# Verify it's cached
|
||||
result = await cache_manager.get('gpu_availability', {'gpu_id': 'gpu_456'})
|
||||
assert result is not None
|
||||
|
||||
# Invalidate cache
|
||||
await cache_manager.invalidate_cache('gpu_availability')
|
||||
|
||||
# Should be gone from L1 cache
|
||||
assert len(cache_manager.l1_cache) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_publishing(self, cache_manager):
|
||||
"""Test event publishing for cache invalidation"""
|
||||
# Mock Redis publish
|
||||
cache_manager.redis_client.publish = AsyncMock()
|
||||
|
||||
# Publish GPU availability change event
|
||||
await cache_manager.notify_gpu_availability_change('gpu_789', 'offline')
|
||||
|
||||
# Verify event was published
|
||||
cache_manager.redis_client.publish.assert_called_once()
|
||||
|
||||
# Check event data
|
||||
call_args = cache_manager.redis_client.publish.call_args
|
||||
event_data = json.loads(call_args[0][1])
|
||||
|
||||
assert event_data['event_type'] == 'gpu_availability_changed'
|
||||
assert event_data['resource_id'] == 'gpu_789'
|
||||
assert event_data['data']['gpu_id'] == 'gpu_789'
|
||||
assert event_data['data']['status'] == 'offline'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_handling(self, cache_manager):
|
||||
"""Test handling of incoming invalidation events"""
|
||||
test_data = {"gpu_id": "gpu_event", "status": "available"}
|
||||
|
||||
# Set data in L1 cache
|
||||
cache_key = cache_manager._generate_cache_key('gpu_avail', {'gpu_id': 'gpu_event'})
|
||||
cache_manager.l1_cache[cache_key] = {
|
||||
'data': test_data,
|
||||
'expires_at': time.time() + 300
|
||||
}
|
||||
|
||||
# Simulate incoming event
|
||||
event_data = {
|
||||
'event_type': 'gpu_availability_changed',
|
||||
'resource_id': 'gpu_event',
|
||||
'data': {'gpu_id': 'gpu_event', 'status': 'busy'},
|
||||
'timestamp': time.time(),
|
||||
'source_node': 'other_node',
|
||||
'event_id': 'event_123',
|
||||
'affected_namespaces': ['gpu_avail']
|
||||
}
|
||||
|
||||
# Process event
|
||||
await cache_manager._process_invalidation_event(event_data)
|
||||
|
||||
# L1 cache should be invalidated
|
||||
assert cache_key not in cache_manager.l1_cache
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_statistics(self, cache_manager):
|
||||
"""Test cache statistics tracking"""
|
||||
# Perform some cache operations
|
||||
await cache_manager.set('test_cache', {'key': 'value'}, {'data': 'test'})
|
||||
await cache_manager.get('test_cache', {'key': 'value'})
|
||||
await cache_manager.get('nonexistent_cache', {'key': 'value'})
|
||||
|
||||
stats = await cache_manager.get_cache_stats()
|
||||
|
||||
assert 'cache_hits' in stats
|
||||
assert 'cache_misses' in stats
|
||||
assert 'events_processed' in stats
|
||||
assert 'l1_cache_size' in stats
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check(self, cache_manager):
|
||||
"""Test cache health check"""
|
||||
health = await cache_manager.health_check()
|
||||
|
||||
assert 'status' in health
|
||||
assert 'redis_connected' in health
|
||||
assert 'pubsub_active' in health
|
||||
assert 'event_queue_size' in health
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cached_decorator(self, cache_manager):
|
||||
"""Test the cached result decorator"""
|
||||
call_count = 0
|
||||
|
||||
@cached_result('test_cache', ttl=60)
|
||||
async def expensive_function(param1, param2):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return f"result_{param1}_{param2}"
|
||||
|
||||
# First call should execute function
|
||||
result1 = await expensive_function('a', 'b')
|
||||
assert result1 == "result_a_b"
|
||||
assert call_count == 1
|
||||
|
||||
# Second call should use cache
|
||||
result2 = await expensive_function('a', 'b')
|
||||
assert result2 == "result_a_b"
|
||||
assert call_count == 1 # Should not increment
|
||||
|
||||
# Different parameters should execute function
|
||||
result3 = await expensive_function('c', 'd')
|
||||
assert result3 == "result_c_d"
|
||||
assert call_count == 2
|
||||
|
||||
|
||||
class TestGPUMarketplaceCacheManager:
|
||||
"""Test the GPU marketplace cache manager"""
|
||||
|
||||
@pytest.fixture
|
||||
async def marketplace_cache(self):
|
||||
"""Create a marketplace cache manager for testing"""
|
||||
# Mock cache manager
|
||||
mock_cache_manager = AsyncMock()
|
||||
mock_cache_manager.get = AsyncMock()
|
||||
mock_cache_manager.set = AsyncMock()
|
||||
mock_cache_manager.invalidate_cache = AsyncMock()
|
||||
mock_cache_manager.notify_gpu_availability_change = AsyncMock()
|
||||
mock_cache_manager.notify_pricing_update = AsyncMock()
|
||||
mock_cache_manager.notify_booking_created = AsyncMock()
|
||||
mock_cache_manager.notify_booking_cancelled = AsyncMock()
|
||||
|
||||
manager = GPUMarketplaceCacheManager(mock_cache_manager)
|
||||
yield manager
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gpu_availability_caching(self, marketplace_cache):
|
||||
"""Test GPU availability caching"""
|
||||
gpus = [
|
||||
GPUInfo(
|
||||
gpu_id="gpu_001",
|
||||
provider_id="provider_1",
|
||||
gpu_type="RTX 3080",
|
||||
memory_gb=10,
|
||||
cuda_cores=8704,
|
||||
base_price_per_hour=0.1,
|
||||
current_price_per_hour=0.12,
|
||||
availability_status="available",
|
||||
region="us-east",
|
||||
performance_score=95.0,
|
||||
last_updated=datetime.utcnow()
|
||||
),
|
||||
GPUInfo(
|
||||
gpu_id="gpu_002",
|
||||
provider_id="provider_2",
|
||||
gpu_type="RTX 3090",
|
||||
memory_gb=24,
|
||||
cuda_cores=10496,
|
||||
base_price_per_hour=0.15,
|
||||
current_price_per_hour=0.18,
|
||||
availability_status="busy",
|
||||
region="us-west",
|
||||
performance_score=98.0,
|
||||
last_updated=datetime.utcnow()
|
||||
)
|
||||
]
|
||||
|
||||
# Set GPU availability
|
||||
await marketplace_cache.set_gpu_availability(gpus)
|
||||
|
||||
# Verify cache.set was called
|
||||
assert marketplace_cache.cache.set.call_count > 0
|
||||
|
||||
# Test filtering
|
||||
marketplace_cache.cache.get.return_value = [gpus[0].__dict__]
|
||||
result = await marketplace_cache.get_gpu_availability(region="us-east")
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].gpu_id == "gpu_001"
|
||||
assert result[0].region == "us-east"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gpu_status_update(self, marketplace_cache):
|
||||
"""Test GPU status update with event notification"""
|
||||
# Mock existing GPU
|
||||
existing_gpu = GPUInfo(
|
||||
gpu_id="gpu_003",
|
||||
provider_id="provider_3",
|
||||
gpu_type="A100",
|
||||
memory_gb=40,
|
||||
cuda_cores=6912,
|
||||
base_price_per_hour=0.5,
|
||||
current_price_per_hour=0.5,
|
||||
availability_status="available",
|
||||
region="eu-central",
|
||||
performance_score=99.0,
|
||||
last_updated=datetime.utcnow()
|
||||
)
|
||||
|
||||
marketplace_cache.cache.get.return_value = [existing_gpu.__dict__]
|
||||
|
||||
# Update status
|
||||
await marketplace_cache.update_gpu_status("gpu_003", "maintenance")
|
||||
|
||||
# Verify notification was sent
|
||||
marketplace_cache.cache.notify_gpu_availability_change.assert_called_once_with(
|
||||
"gpu_003", "maintenance"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dynamic_pricing(self, marketplace_cache):
|
||||
"""Test dynamic pricing calculation"""
|
||||
# Mock GPU data with low availability
|
||||
gpus = [
|
||||
GPUInfo(
|
||||
gpu_id="gpu_004",
|
||||
provider_id="provider_4",
|
||||
gpu_type="RTX 3080",
|
||||
memory_gb=10,
|
||||
cuda_cores=8704,
|
||||
base_price_per_hour=0.1,
|
||||
current_price_per_hour=0.1,
|
||||
availability_status="available",
|
||||
region="us-east",
|
||||
performance_score=95.0,
|
||||
last_updated=datetime.utcnow()
|
||||
)
|
||||
# Only 1 GPU available (low availability scenario)
|
||||
]
|
||||
|
||||
marketplace_cache.cache.get.return_value = [gpus[0].__dict__]
|
||||
|
||||
# Calculate dynamic pricing
|
||||
price = await marketplace_cache.get_dynamic_pricing("gpu_004")
|
||||
|
||||
# Should be higher than base price due to low availability
|
||||
assert price > gpus[0].base_price_per_hour
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_booking_creation(self, marketplace_cache):
|
||||
"""Test booking creation with cache updates"""
|
||||
booking = BookingInfo(
|
||||
booking_id="booking_001",
|
||||
gpu_id="gpu_005",
|
||||
user_id="user_123",
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow() + timedelta(hours=2),
|
||||
status="active",
|
||||
total_cost=0.2,
|
||||
created_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
# Mock GPU data
|
||||
gpu = GPUInfo(
|
||||
gpu_id="gpu_005",
|
||||
provider_id="provider_5",
|
||||
gpu_type="RTX 3080",
|
||||
memory_gb=10,
|
||||
cuda_cores=8704,
|
||||
base_price_per_hour=0.1,
|
||||
current_price_per_hour=0.1,
|
||||
availability_status="available",
|
||||
region="us-east",
|
||||
performance_score=95.0,
|
||||
last_updated=datetime.utcnow()
|
||||
)
|
||||
|
||||
marketplace_cache.cache.get.return_value = [gpu.__dict__]
|
||||
|
||||
# Create booking
|
||||
result = await marketplace_cache.create_booking(booking)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify GPU status was updated
|
||||
marketplace_cache.cache.notify_gpu_availability_change.assert_called()
|
||||
|
||||
# Verify booking event was published
|
||||
marketplace_cache.cache.notify_booking_created.assert_called_with(
|
||||
"booking_001", "gpu_005"
|
||||
)
|
||||
|
||||
# Verify relevant caches were invalidated
|
||||
marketplace_cache.cache.invalidate_cache.assert_any_call('order_book')
|
||||
marketplace_cache.cache.invalidate_cache.assert_any_call('market_stats')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_booking_cancellation(self, marketplace_cache):
|
||||
"""Test booking cancellation with cache updates"""
|
||||
# Mock GPU data
|
||||
gpu = GPUInfo(
|
||||
gpu_id="gpu_006",
|
||||
provider_id="provider_6",
|
||||
gpu_type="RTX 3090",
|
||||
memory_gb=24,
|
||||
cuda_cores=10496,
|
||||
base_price_per_hour=0.15,
|
||||
current_price_per_hour=0.15,
|
||||
availability_status="busy",
|
||||
region="us-west",
|
||||
performance_score=98.0,
|
||||
last_updated=datetime.utcnow()
|
||||
)
|
||||
|
||||
marketplace_cache.cache.get.return_value = [gpu.__dict__]
|
||||
|
||||
# Cancel booking
|
||||
result = await marketplace_cache.cancel_booking("booking_002", "gpu_006")
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify GPU status was updated to available
|
||||
marketplace_cache.cache.notify_gpu_availability_change.assert_called()
|
||||
|
||||
# Verify cancellation event was published
|
||||
marketplace_cache.cache.notify_booking_cancelled.assert_called_with(
|
||||
"booking_002", "gpu_006"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_market_statistics(self, marketplace_cache):
|
||||
"""Test market statistics calculation"""
|
||||
# Mock GPU data
|
||||
gpus = [
|
||||
GPUInfo(
|
||||
gpu_id="gpu_007",
|
||||
provider_id="provider_7",
|
||||
gpu_type="RTX 3080",
|
||||
memory_gb=10,
|
||||
cuda_cores=8704,
|
||||
base_price_per_hour=0.1,
|
||||
current_price_per_hour=0.12,
|
||||
availability_status="available",
|
||||
region="us-east",
|
||||
performance_score=95.0,
|
||||
last_updated=datetime.utcnow()
|
||||
),
|
||||
GPUInfo(
|
||||
gpu_id="gpu_008",
|
||||
provider_id="provider_8",
|
||||
gpu_type="RTX 3090",
|
||||
memory_gb=24,
|
||||
cuda_cores=10496,
|
||||
base_price_per_hour=0.15,
|
||||
current_price_per_hour=0.18,
|
||||
availability_status="busy",
|
||||
region="us-west",
|
||||
performance_score=98.0,
|
||||
last_updated=datetime.utcnow()
|
||||
)
|
||||
]
|
||||
|
||||
marketplace_cache.cache.get.return_value = [gpu.__dict__ for gpu in gpus]
|
||||
|
||||
# Get market stats
|
||||
stats = await marketplace_cache.get_market_stats()
|
||||
|
||||
assert isinstance(stats, MarketStats)
|
||||
assert stats.total_gpus == 2
|
||||
assert stats.available_gpus == 1
|
||||
assert stats.busy_gpus == 1
|
||||
assert stats.utilization_rate == 0.5
|
||||
assert stats.average_price_per_hour == 0.12 # Average of available GPUs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gpu_search(self, marketplace_cache):
|
||||
"""Test GPU search functionality"""
|
||||
# Mock GPU data
|
||||
gpus = [
|
||||
GPUInfo(
|
||||
gpu_id="gpu_009",
|
||||
provider_id="provider_9",
|
||||
gpu_type="RTX 3080",
|
||||
memory_gb=10,
|
||||
cuda_cores=8704,
|
||||
base_price_per_hour=0.1,
|
||||
current_price_per_hour=0.1,
|
||||
availability_status="available",
|
||||
region="us-east",
|
||||
performance_score=95.0,
|
||||
last_updated=datetime.utcnow()
|
||||
),
|
||||
GPUInfo(
|
||||
gpu_id="gpu_010",
|
||||
provider_id="provider_10",
|
||||
gpu_type="RTX 3090",
|
||||
memory_gb=24,
|
||||
cuda_cores=10496,
|
||||
base_price_per_hour=0.15,
|
||||
current_price_per_hour=0.15,
|
||||
availability_status="available",
|
||||
region="us-west",
|
||||
performance_score=98.0,
|
||||
last_updated=datetime.utcnow()
|
||||
)
|
||||
]
|
||||
|
||||
marketplace_cache.cache.get.return_value = [gpu.__dict__ for gpu in gpus]
|
||||
|
||||
# Search with criteria
|
||||
results = await marketplace_cache.search_gpus(
|
||||
min_memory=16,
|
||||
max_price=0.2
|
||||
)
|
||||
|
||||
# Should only return RTX 3090 (24GB memory, $0.15/hour)
|
||||
assert len(results) == 1
|
||||
assert results[0].gpu_type == "RTX 3090"
|
||||
assert results[0].memory_gb == 24
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_top_performing_gpus(self, marketplace_cache):
|
||||
"""Test getting top performing GPUs"""
|
||||
# Mock GPU data with different performance scores
|
||||
gpus = [
|
||||
GPUInfo(
|
||||
gpu_id="gpu_011",
|
||||
provider_id="provider_11",
|
||||
gpu_type="A100",
|
||||
memory_gb=40,
|
||||
cuda_cores=6912,
|
||||
base_price_per_hour=0.5,
|
||||
current_price_per_hour=0.5,
|
||||
availability_status="available",
|
||||
region="us-east",
|
||||
performance_score=99.0,
|
||||
last_updated=datetime.utcnow()
|
||||
),
|
||||
GPUInfo(
|
||||
gpu_id="gpu_012",
|
||||
provider_id="provider_12",
|
||||
gpu_type="RTX 3080",
|
||||
memory_gb=10,
|
||||
cuda_cores=8704,
|
||||
base_price_per_hour=0.1,
|
||||
current_price_per_hour=0.1,
|
||||
availability_status="available",
|
||||
region="us-west",
|
||||
performance_score=95.0,
|
||||
last_updated=datetime.utcnow()
|
||||
)
|
||||
]
|
||||
|
||||
marketplace_cache.cache.get.return_value = [gpu.__dict__ for gpu in gpus]
|
||||
|
||||
# Get top performing GPUs
|
||||
top_gpus = await marketplace_cache.get_top_performing_gpus(limit=2)
|
||||
|
||||
assert len(top_gpus) == 2
|
||||
assert top_gpus[0].performance_score >= top_gpus[1].performance_score
|
||||
assert top_gpus[0].gpu_type == "A100"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cheapest_gpus(self, marketplace_cache):
|
||||
"""Test getting cheapest GPUs"""
|
||||
# Mock GPU data with different prices
|
||||
gpus = [
|
||||
GPUInfo(
|
||||
gpu_id="gpu_013",
|
||||
provider_id="provider_13",
|
||||
gpu_type="RTX 3060",
|
||||
memory_gb=12,
|
||||
cuda_cores=3584,
|
||||
base_price_per_hour=0.05,
|
||||
current_price_per_hour=0.05,
|
||||
availability_status="available",
|
||||
region="us-east",
|
||||
performance_score=85.0,
|
||||
last_updated=datetime.utcnow()
|
||||
),
|
||||
GPUInfo(
|
||||
gpu_id="gpu_014",
|
||||
provider_id="provider_14",
|
||||
gpu_type="RTX 3080",
|
||||
memory_gb=10,
|
||||
cuda_cores=8704,
|
||||
base_price_per_hour=0.1,
|
||||
current_price_per_hour=0.1,
|
||||
availability_status="available",
|
||||
region="us-west",
|
||||
performance_score=95.0,
|
||||
last_updated=datetime.utcnow()
|
||||
)
|
||||
]
|
||||
|
||||
marketplace_cache.cache.get.return_value = [gpu.__dict__ for gpu in gpus]
|
||||
|
||||
# Get cheapest GPUs
|
||||
cheapest_gpus = await marketplace_cache.get_cheapest_gpus(limit=2)
|
||||
|
||||
assert len(cheapest_gpus) == 2
|
||||
assert cheapest_gpus[0].current_price_per_hour <= cheapest_gpus[1].current_price_per_hour
|
||||
assert cheapest_gpus[0].gpu_type == "RTX 3060"
|
||||
|
||||
|
||||
class TestCacheIntegration:
|
||||
"""Test integration between cache components"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_marketplace_cache_initialization(self):
|
||||
"""Test marketplace cache manager initialization"""
|
||||
with patch('aitbc_cache.gpu_marketplace_cache.EventDrivenCacheManager') as mock_cache:
|
||||
mock_manager = AsyncMock()
|
||||
mock_cache.return_value = mock_manager
|
||||
mock_manager.connect = AsyncMock()
|
||||
|
||||
# Initialize marketplace cache
|
||||
manager = await init_marketplace_cache(
|
||||
redis_url="redis://localhost:6379/2",
|
||||
node_id="test_node",
|
||||
region="test_region"
|
||||
)
|
||||
|
||||
assert isinstance(manager, GPUMarketplaceCacheManager)
|
||||
mock_cache.assert_called_once()
|
||||
mock_manager.connect.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_marketplace_cache_access(self):
|
||||
"""Test global marketplace cache access"""
|
||||
# Mock the global cache
|
||||
with patch('aitbc_cache.gpu_marketplace_cache.marketplace_cache') as mock_global:
|
||||
mock_global.get = AsyncMock()
|
||||
|
||||
# Should work when initialized
|
||||
result = await get_marketplace_cache()
|
||||
assert result is not None
|
||||
|
||||
# Should raise error when not initialized
|
||||
with patch('aitbc_cache.gpu_marketplace_cache.marketplace_cache', None):
|
||||
with pytest.raises(RuntimeError, match="Marketplace cache not initialized"):
|
||||
await get_marketplace_cache()
|
||||
|
||||
|
||||
class TestCacheEventTypes:
|
||||
"""Test different cache event types"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_event_types(self):
|
||||
"""Test all supported cache event types"""
|
||||
event_types = [
|
||||
CacheEventType.GPU_AVAILABILITY_CHANGED,
|
||||
CacheEventType.PRICING_UPDATED,
|
||||
CacheEventType.BOOKING_CREATED,
|
||||
CacheEventType.BOOKING_CANCELLED,
|
||||
CacheEventType.PROVIDER_STATUS_CHANGED,
|
||||
CacheEventType.MARKET_STATS_UPDATED,
|
||||
CacheEventType.ORDER_BOOK_UPDATED,
|
||||
CacheEventType.MANUAL_INVALIDATION
|
||||
]
|
||||
|
||||
for event_type in event_types:
|
||||
# Verify event type can be serialized
|
||||
event = CacheEvent(
|
||||
event_type=event_type,
|
||||
resource_id="test_resource",
|
||||
data={"test": "data"},
|
||||
timestamp=time.time(),
|
||||
source_node="test_node",
|
||||
event_id="test_event",
|
||||
affected_namespaces=["test_namespace"]
|
||||
)
|
||||
|
||||
# Test JSON serialization
|
||||
event_json = json.dumps(event.__dict__, default=str)
|
||||
parsed_event = json.loads(event_json)
|
||||
|
||||
assert parsed_event['event_type'] == event_type.value
|
||||
assert parsed_event['resource_id'] == "test_resource"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
168
tests/test_runner.py
Executable file
168
tests/test_runner.py
Executable file
@@ -0,0 +1,168 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple Test Runner for AITBC
|
||||
|
||||
This script provides convenient commands for running tests with the new
|
||||
pyproject.toml configuration. It's a thin wrapper around pytest that
|
||||
provides common test patterns and helpful output.
|
||||
|
||||
Usage:
|
||||
python tests/test_runner.py # Run all fast tests
|
||||
python tests/test_runner.py --all # Run all tests including slow
|
||||
python tests/test_runner.py --unit # Run unit tests only
|
||||
python tests/test_runner.py --integration # Run integration tests only
|
||||
python tests/test_runner.py --cli # Run CLI tests only
|
||||
python tests/test_runner.py --coverage # Run with coverage
|
||||
python tests/test_runner.py --performance # Run performance tests
|
||||
"""
|
||||
|
||||
import sys
|
||||
import subprocess
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def run_pytest(args, description):
|
||||
"""Run pytest with given arguments."""
|
||||
print(f"🧪 {description}")
|
||||
print("=" * 50)
|
||||
|
||||
cmd = ["python", "-m", "pytest"] + args
|
||||
|
||||
try:
|
||||
result = subprocess.run(cmd, cwd=Path(__file__).parent.parent)
|
||||
return result.returncode
|
||||
except KeyboardInterrupt:
|
||||
print("\n❌ Tests interrupted")
|
||||
return 1
|
||||
except Exception as e:
|
||||
print(f"❌ Error running tests: {e}")
|
||||
return 1
|
||||
|
||||
|
||||
def main():
|
||||
"""Main test runner."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="AITBC Test Runner - Simple wrapper around pytest",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
python tests/test_runner.py # Run all fast tests
|
||||
python tests/test_runner.py --all # Run all tests including slow
|
||||
python tests/test_runner.py --unit # Run unit tests only
|
||||
python tests/test_runner.py --integration # Run integration tests only
|
||||
python tests/test_runner.py --cli # Run CLI tests only
|
||||
python tests/test_runner.py --coverage # Run with coverage
|
||||
python tests/test_runner.py --performance # Run performance tests
|
||||
"""
|
||||
)
|
||||
|
||||
# Test selection options
|
||||
test_group = parser.add_mutually_exclusive_group()
|
||||
test_group.add_argument("--all", action="store_true", help="Run all tests including slow ones")
|
||||
test_group.add_argument("--unit", action="store_true", help="Run unit tests only")
|
||||
test_group.add_argument("--integration", action="store_true", help="Run integration tests only")
|
||||
test_group.add_argument("--cli", action="store_true", help="Run CLI tests only")
|
||||
test_group.add_argument("--api", action="store_true", help="Run API tests only")
|
||||
test_group.add_argument("--blockchain", action="store_true", help="Run blockchain tests only")
|
||||
test_group.add_argument("--slow", action="store_true", help="Run slow tests only")
|
||||
test_group.add_argument("--performance", action="store_true", help="Run performance tests only")
|
||||
test_group.add_argument("--security", action="store_true", help="Run security tests only")
|
||||
|
||||
# Additional options
|
||||
parser.add_argument("--coverage", action="store_true", help="Run with coverage reporting")
|
||||
parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
|
||||
parser.add_argument("--debug", action="store_true", help="Debug mode (show collection)")
|
||||
parser.add_argument("--list", "-l", action="store_true", help="List available tests")
|
||||
parser.add_argument("--markers", action="store_true", help="Show available markers")
|
||||
|
||||
# Allow passing through pytest arguments
|
||||
parser.add_argument("pytest_args", nargs="*", help="Additional pytest arguments")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Build pytest command
|
||||
pytest_args = []
|
||||
|
||||
# Add coverage if requested
|
||||
if args.coverage:
|
||||
pytest_args.extend(["--cov=aitbc_cli", "--cov-report=term-missing"])
|
||||
if args.verbose:
|
||||
pytest_args.append("--cov-report=html")
|
||||
|
||||
# Add verbosity
|
||||
if args.verbose:
|
||||
pytest_args.append("-v")
|
||||
|
||||
# Add test selection markers
|
||||
if args.all:
|
||||
pytest_args.append("-m") # No marker - run all tests
|
||||
elif args.unit:
|
||||
pytest_args.extend(["-m", "unit and not slow"])
|
||||
elif args.integration:
|
||||
pytest_args.extend(["-m", "integration and not slow"])
|
||||
elif args.cli:
|
||||
pytest_args.extend(["-m", "cli and not slow"])
|
||||
elif args.api:
|
||||
pytest_args.extend(["-m", "api and not slow"])
|
||||
elif args.blockchain:
|
||||
pytest_args.extend(["-m", "blockchain and not slow"])
|
||||
elif args.slow:
|
||||
pytest_args.extend(["-m", "slow"])
|
||||
elif args.performance:
|
||||
pytest_args.extend(["-m", "performance"])
|
||||
elif args.security:
|
||||
pytest_args.extend(["-m", "security"])
|
||||
else:
|
||||
# Default: run fast tests only
|
||||
pytest_args.extend(["-m", "unit or integration or cli or api or blockchain"])
|
||||
pytest_args.extend(["-m", "not slow"])
|
||||
|
||||
# Add debug options
|
||||
if args.debug:
|
||||
pytest_args.append("--debug")
|
||||
|
||||
# Add list/markers options
|
||||
if args.list:
|
||||
pytest_args.append("--collect-only")
|
||||
elif args.markers:
|
||||
pytest_args.append("--markers")
|
||||
|
||||
# Add additional pytest arguments
|
||||
if args.pytest_args:
|
||||
pytest_args.extend(args.pytest_args)
|
||||
|
||||
# Special handling for markers/list (don't run tests)
|
||||
if args.list or args.markers:
|
||||
return run_pytest(pytest_args, "Listing pytest information")
|
||||
|
||||
# Run tests
|
||||
if args.all:
|
||||
description = "Running all tests (including slow)"
|
||||
elif args.unit:
|
||||
description = "Running unit tests"
|
||||
elif args.integration:
|
||||
description = "Running integration tests"
|
||||
elif args.cli:
|
||||
description = "Running CLI tests"
|
||||
elif args.api:
|
||||
description = "Running API tests"
|
||||
elif args.blockchain:
|
||||
description = "Running blockchain tests"
|
||||
elif args.slow:
|
||||
description = "Running slow tests"
|
||||
elif args.performance:
|
||||
description = "Running performance tests"
|
||||
elif args.security:
|
||||
description = "Running security tests"
|
||||
else:
|
||||
description = "Running fast tests (unit, integration, CLI, API, blockchain)"
|
||||
|
||||
if args.coverage:
|
||||
description += " with coverage"
|
||||
|
||||
return run_pytest(pytest_args, description)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
793
tests/test_websocket_backpressure_core.py
Normal file
793
tests/test_websocket_backpressure_core.py
Normal file
@@ -0,0 +1,793 @@
|
||||
"""
|
||||
Core WebSocket Backpressure Tests
|
||||
|
||||
Tests for the essential backpressure control mechanisms
|
||||
without complex dependencies.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import time
|
||||
from unittest.mock import Mock, AsyncMock
|
||||
from typing import Dict, Any, List
|
||||
|
||||
|
||||
class MockMessage:
|
||||
"""Mock message for testing"""
|
||||
def __init__(self, data: str, priority: int = 1):
|
||||
self.data = data
|
||||
self.priority = priority
|
||||
self.timestamp = time.time()
|
||||
self.message_id = f"msg_{id(self)}"
|
||||
|
||||
|
||||
class MockBoundedQueue:
|
||||
"""Mock bounded queue with priority handling"""
|
||||
|
||||
def __init__(self, max_size: int = 100):
|
||||
self.max_size = max_size
|
||||
self.queues = {
|
||||
"critical": [],
|
||||
"important": [],
|
||||
"bulk": [],
|
||||
"control": []
|
||||
}
|
||||
self.total_size = 0
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def put(self, message: MockMessage, priority: str = "important") -> bool:
|
||||
"""Add message with backpressure handling"""
|
||||
async with self._lock:
|
||||
# Check capacity
|
||||
if self.total_size >= self.max_size:
|
||||
# Drop bulk messages first
|
||||
if priority == "bulk":
|
||||
return False
|
||||
|
||||
# For important messages: drop oldest important if exists, otherwise drop bulk
|
||||
if priority == "important":
|
||||
if self.queues["important"]:
|
||||
self.queues["important"].pop(0)
|
||||
self.total_size -= 1
|
||||
elif self.queues["bulk"]:
|
||||
self.queues["bulk"].pop(0)
|
||||
self.total_size -= 1
|
||||
else:
|
||||
return False
|
||||
|
||||
# For critical messages: drop oldest critical if exists, otherwise drop important, otherwise drop bulk
|
||||
if priority == "critical":
|
||||
if self.queues["critical"]:
|
||||
self.queues["critical"].pop(0)
|
||||
self.total_size -= 1
|
||||
elif self.queues["important"]:
|
||||
self.queues["important"].pop(0)
|
||||
self.total_size -= 1
|
||||
elif self.queues["bulk"]:
|
||||
self.queues["bulk"].pop(0)
|
||||
self.total_size -= 1
|
||||
else:
|
||||
return False
|
||||
|
||||
self.queues[priority].append(message)
|
||||
self.total_size += 1
|
||||
return True
|
||||
|
||||
async def get(self) -> MockMessage:
|
||||
"""Get next message by priority"""
|
||||
async with self._lock:
|
||||
# Priority order: control > critical > important > bulk
|
||||
for priority in ["control", "critical", "important", "bulk"]:
|
||||
if self.queues[priority]:
|
||||
message = self.queues[priority].pop(0)
|
||||
self.total_size -= 1
|
||||
return message
|
||||
return None
|
||||
|
||||
def size(self) -> int:
|
||||
return self.total_size
|
||||
|
||||
def fill_ratio(self) -> float:
|
||||
return self.total_size / self.max_size
|
||||
|
||||
|
||||
class MockWebSocketStream:
|
||||
"""Mock WebSocket stream with backpressure control"""
|
||||
|
||||
def __init__(self, stream_id: str, max_queue_size: int = 100):
|
||||
self.stream_id = stream_id
|
||||
self.queue = MockBoundedQueue(max_queue_size)
|
||||
self.websocket = AsyncMock()
|
||||
self.status = "connected"
|
||||
self.metrics = {
|
||||
"messages_sent": 0,
|
||||
"messages_dropped": 0,
|
||||
"backpressure_events": 0,
|
||||
"slow_consumer_events": 0
|
||||
}
|
||||
|
||||
self._running = False
|
||||
self._sender_task = None
|
||||
self._send_lock = asyncio.Lock()
|
||||
|
||||
# Configuration
|
||||
self.send_timeout = 1.0
|
||||
self.slow_consumer_threshold = 0.5
|
||||
self.backpressure_threshold = 0.7
|
||||
|
||||
async def start(self):
|
||||
"""Start stream processing"""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._sender_task = asyncio.create_task(self._sender_loop())
|
||||
|
||||
async def stop(self):
|
||||
"""Stop stream processing"""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
|
||||
if self._sender_task:
|
||||
self._sender_task.cancel()
|
||||
try:
|
||||
await self._sender_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def send_message(self, data: Any, priority: str = "important") -> bool:
|
||||
"""Send message with backpressure handling"""
|
||||
if not self._running:
|
||||
return False
|
||||
|
||||
message = MockMessage(data, priority)
|
||||
|
||||
# Check backpressure
|
||||
queue_ratio = self.queue.fill_ratio()
|
||||
if queue_ratio > self.backpressure_threshold:
|
||||
self.metrics["backpressure_events"] += 1
|
||||
|
||||
# Drop bulk messages under backpressure
|
||||
if priority == "bulk" and queue_ratio > 0.8:
|
||||
self.metrics["messages_dropped"] += 1
|
||||
return False
|
||||
|
||||
# Add to queue
|
||||
success = await self.queue.put(message, priority)
|
||||
if not success:
|
||||
self.metrics["messages_dropped"] += 1
|
||||
|
||||
return success
|
||||
|
||||
async def _sender_loop(self):
|
||||
"""Main sender loop with backpressure control"""
|
||||
while self._running:
|
||||
try:
|
||||
message = await self.queue.get()
|
||||
if message is None:
|
||||
await asyncio.sleep(0.01)
|
||||
continue
|
||||
|
||||
# Send with timeout protection
|
||||
start_time = time.time()
|
||||
success = await self._send_with_backpressure(message)
|
||||
send_time = time.time() - start_time
|
||||
|
||||
if success:
|
||||
self.metrics["messages_sent"] += 1
|
||||
|
||||
# Check for slow consumer
|
||||
if send_time > self.slow_consumer_threshold:
|
||||
self.metrics["slow_consumer_events"] += 1
|
||||
if self.metrics["slow_consumer_events"] > 5:
|
||||
self.status = "slow_consumer"
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Error in sender loop: {e}")
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def _send_with_backpressure(self, message: MockMessage) -> bool:
|
||||
"""Send message with timeout protection"""
|
||||
try:
|
||||
async with self._send_lock:
|
||||
# Simulate send with potential delay
|
||||
await asyncio.wait_for(
|
||||
self.websocket.send(message.data),
|
||||
timeout=self.send_timeout
|
||||
)
|
||||
return True
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"Send error: {e}")
|
||||
return False
|
||||
|
||||
def get_metrics(self) -> Dict[str, Any]:
|
||||
"""Get stream metrics"""
|
||||
return {
|
||||
"stream_id": self.stream_id,
|
||||
"status": self.status,
|
||||
"queue_size": self.queue.size(),
|
||||
"queue_fill_ratio": self.queue.fill_ratio(),
|
||||
**self.metrics
|
||||
}
|
||||
|
||||
|
||||
class MockStreamManager:
|
||||
"""Mock stream manager with backpressure control"""
|
||||
|
||||
def __init__(self):
|
||||
self.streams: Dict[str, MockWebSocketStream] = {}
|
||||
self.total_connections = 0
|
||||
self._running = False
|
||||
self._broadcast_queue = asyncio.Queue(maxsize=1000)
|
||||
self._broadcast_task = None
|
||||
|
||||
async def start(self):
|
||||
"""Start the stream manager"""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._broadcast_task = asyncio.create_task(self._broadcast_loop())
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the stream manager"""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
|
||||
# Stop all streams
|
||||
for stream in self.streams.values():
|
||||
await stream.stop()
|
||||
|
||||
if self._broadcast_task:
|
||||
self._broadcast_task.cancel()
|
||||
try:
|
||||
await self._broadcast_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def create_stream(self, stream_id: str, max_queue_size: int = 100) -> MockWebSocketStream:
|
||||
"""Create a new stream"""
|
||||
stream = MockWebSocketStream(stream_id, max_queue_size)
|
||||
await stream.start()
|
||||
|
||||
self.streams[stream_id] = stream
|
||||
self.total_connections += 1
|
||||
|
||||
return stream
|
||||
|
||||
async def remove_stream(self, stream_id: str):
|
||||
"""Remove a stream"""
|
||||
if stream_id in self.streams:
|
||||
stream = self.streams[stream_id]
|
||||
await stream.stop()
|
||||
del self.streams[stream_id]
|
||||
self.total_connections -= 1
|
||||
|
||||
async def broadcast_to_all(self, data: Any, priority: str = "important"):
|
||||
"""Broadcast message to all streams"""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
try:
|
||||
await self._broadcast_queue.put((data, priority))
|
||||
except asyncio.QueueFull:
|
||||
print("Broadcast queue full, dropping message")
|
||||
|
||||
async def _broadcast_loop(self):
|
||||
"""Broadcast messages to all streams"""
|
||||
while self._running:
|
||||
try:
|
||||
data, priority = await self._broadcast_queue.get()
|
||||
|
||||
# Send to all streams concurrently
|
||||
tasks = []
|
||||
for stream in self.streams.values():
|
||||
task = asyncio.create_task(
|
||||
stream.send_message(data, priority)
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
# Wait for all sends (with timeout)
|
||||
if tasks:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.gather(*tasks, return_exceptions=True),
|
||||
timeout=1.0
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
print("Broadcast timeout, some streams may be slow")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Error in broadcast loop: {e}")
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
def get_slow_streams(self, threshold: float = 0.8) -> List[str]:
|
||||
"""Get streams with high queue fill ratios"""
|
||||
slow_streams = []
|
||||
for stream_id, stream in self.streams.items():
|
||||
if stream.queue.fill_ratio() > threshold:
|
||||
slow_streams.append(stream_id)
|
||||
return slow_streams
|
||||
|
||||
def get_manager_metrics(self) -> Dict[str, Any]:
|
||||
"""Get manager metrics"""
|
||||
stream_metrics = []
|
||||
for stream in self.streams.values():
|
||||
stream_metrics.append(stream.get_metrics())
|
||||
|
||||
total_queue_size = sum(m["queue_size"] for m in stream_metrics)
|
||||
total_messages_sent = sum(m["messages_sent"] for m in stream_metrics)
|
||||
total_messages_dropped = sum(m["messages_dropped"] for m in stream_metrics)
|
||||
|
||||
status_counts = {}
|
||||
for stream in self.streams.values():
|
||||
status = stream.status
|
||||
status_counts[status] = status_counts.get(status, 0) + 1
|
||||
|
||||
return {
|
||||
"manager_status": "running" if self._running else "stopped",
|
||||
"total_connections": self.total_connections,
|
||||
"active_streams": len(self.streams),
|
||||
"total_queue_size": total_queue_size,
|
||||
"total_messages_sent": total_messages_sent,
|
||||
"total_messages_dropped": total_messages_dropped,
|
||||
"broadcast_queue_size": self._broadcast_queue.qsize(),
|
||||
"stream_status_distribution": status_counts,
|
||||
"stream_metrics": stream_metrics
|
||||
}
|
||||
|
||||
|
||||
class TestBoundedQueue:
|
||||
"""Test bounded message queue"""
|
||||
|
||||
@pytest.fixture
|
||||
def queue(self):
|
||||
return MockBoundedQueue(max_size=10)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_operations(self, queue):
|
||||
"""Test basic queue operations"""
|
||||
message = MockMessage("test", "important")
|
||||
|
||||
# Put message
|
||||
success = await queue.put(message, "important")
|
||||
assert success is True
|
||||
assert queue.size() == 1
|
||||
|
||||
# Get message
|
||||
retrieved = await queue.get()
|
||||
assert retrieved == message
|
||||
assert queue.size() == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_priority_ordering(self, queue):
|
||||
"""Test priority ordering"""
|
||||
messages = [
|
||||
MockMessage("bulk", "bulk"),
|
||||
MockMessage("critical", "critical"),
|
||||
MockMessage("important", "important"),
|
||||
MockMessage("control", "control")
|
||||
]
|
||||
|
||||
# Add messages
|
||||
for msg in messages:
|
||||
await queue.put(msg, msg.priority)
|
||||
|
||||
# Should retrieve in priority order
|
||||
expected_order = ["control", "critical", "important", "bulk"]
|
||||
|
||||
for expected_priority in expected_order:
|
||||
msg = await queue.get()
|
||||
assert msg.priority == expected_priority
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backpressure_handling(self, queue):
|
||||
"""Test backpressure when queue is full"""
|
||||
# Fill queue to capacity with bulk messages first
|
||||
for i in range(queue.max_size):
|
||||
await queue.put(MockMessage(f"bulk_{i}", "bulk"), "bulk")
|
||||
|
||||
assert queue.size() == queue.max_size
|
||||
assert queue.fill_ratio() == 1.0
|
||||
|
||||
# Try to add bulk message (should be dropped)
|
||||
bulk_msg = MockMessage("new_bulk", "bulk")
|
||||
success = await queue.put(bulk_msg, "bulk")
|
||||
assert success is False
|
||||
|
||||
# Now add some important messages by replacing bulk messages
|
||||
# First, remove some bulk messages to make space
|
||||
for i in range(3):
|
||||
await queue.get() # Remove bulk messages
|
||||
|
||||
# Add important messages
|
||||
for i in range(3):
|
||||
await queue.put(MockMessage(f"important_{i}", "important"), "important")
|
||||
|
||||
# Fill back to capacity with bulk
|
||||
while queue.size() < queue.max_size:
|
||||
await queue.put(MockMessage(f"bulk_extra", "bulk"), "bulk")
|
||||
|
||||
# Now try to add important message (should replace oldest important)
|
||||
important_msg = MockMessage("new_important", "important")
|
||||
success = await queue.put(important_msg, "important")
|
||||
assert success is True
|
||||
|
||||
# Try to add critical message (should always succeed)
|
||||
critical_msg = MockMessage("new_critical", "critical")
|
||||
success = await queue.put(critical_msg, "critical")
|
||||
assert success is True
|
||||
|
||||
|
||||
class TestWebSocketStream:
|
||||
"""Test WebSocket stream with backpressure"""
|
||||
|
||||
@pytest.fixture
|
||||
def stream(self):
|
||||
return MockWebSocketStream("test_stream", max_queue_size=50)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_start_stop(self, stream):
|
||||
"""Test stream start and stop"""
|
||||
assert stream._running is False
|
||||
|
||||
await stream.start()
|
||||
assert stream._running is True
|
||||
assert stream.status == "connected"
|
||||
|
||||
await stream.stop()
|
||||
assert stream._running is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_sending(self, stream):
|
||||
"""Test basic message sending"""
|
||||
await stream.start()
|
||||
|
||||
# Send message
|
||||
success = await stream.send_message({"test": "data"}, "important")
|
||||
assert success is True
|
||||
|
||||
# Wait for processing
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Verify message was sent
|
||||
assert stream.websocket.send.called
|
||||
assert stream.metrics["messages_sent"] > 0
|
||||
|
||||
await stream.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slow_consumer_detection(self, stream):
|
||||
"""Test slow consumer detection"""
|
||||
# Make websocket send slow
|
||||
async def slow_send(message):
|
||||
await asyncio.sleep(0.6) # Slower than threshold (0.5s)
|
||||
|
||||
stream.websocket.send = slow_send
|
||||
|
||||
await stream.start()
|
||||
|
||||
# Send many messages to trigger detection (need > 5 slow events)
|
||||
for i in range(15): # Increased from 10 to 15
|
||||
await stream.send_message({"test": f"data_{i}"}, "important")
|
||||
await asyncio.sleep(0.1) # Small delay between sends
|
||||
|
||||
# Wait for processing
|
||||
await asyncio.sleep(3.0) # Increased wait time
|
||||
|
||||
# Check slow consumer detection
|
||||
assert stream.status == "slow_consumer"
|
||||
assert stream.metrics["slow_consumer_events"] > 5 # Need > 5 events
|
||||
|
||||
await stream.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backpressure_handling(self, stream):
|
||||
"""Test backpressure handling"""
|
||||
# Make websocket send slower to build up queue
|
||||
async def slow_send(message):
|
||||
await asyncio.sleep(0.02) # Small delay to allow queue to build
|
||||
|
||||
stream.websocket.send = slow_send
|
||||
|
||||
await stream.start()
|
||||
|
||||
# Fill queue to trigger backpressure
|
||||
for i in range(40): # 40/50 = 80% > threshold (70%)
|
||||
await stream.send_message({"test": f"data_{i}"}, "important")
|
||||
|
||||
# Wait a bit but not too long to allow queue to build
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
# Check backpressure status
|
||||
assert stream.metrics["backpressure_events"] > 0
|
||||
assert stream.queue.fill_ratio() > 0.7
|
||||
|
||||
# Try to send bulk message under backpressure
|
||||
success = await stream.send_message({"bulk": "data"}, "bulk")
|
||||
# Should be dropped due to high queue fill ratio
|
||||
|
||||
await stream.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_timeout_handling(self, stream):
|
||||
"""Test send timeout handling"""
|
||||
# Make websocket send timeout
|
||||
async def timeout_send(message):
|
||||
await asyncio.sleep(2.0) # Longer than timeout (1.0s)
|
||||
|
||||
stream.websocket.send = timeout_send
|
||||
|
||||
await stream.start()
|
||||
|
||||
# Send message
|
||||
await stream.send_message({"test": "data"}, "important")
|
||||
|
||||
# Wait for processing
|
||||
await asyncio.sleep(1.5)
|
||||
|
||||
# Check that message handling handled timeout
|
||||
# (In real implementation, would retry or drop)
|
||||
|
||||
await stream.stop()
|
||||
|
||||
|
||||
class TestStreamManager:
|
||||
"""Test stream manager with multiple streams"""
|
||||
|
||||
@pytest.fixture
|
||||
def manager(self):
|
||||
return MockStreamManager()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_start_stop(self, manager):
|
||||
"""Test manager start and stop"""
|
||||
await manager.start()
|
||||
assert manager._running is True
|
||||
|
||||
await manager.stop()
|
||||
assert manager._running is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_management(self, manager):
|
||||
"""Test stream lifecycle management"""
|
||||
await manager.start()
|
||||
|
||||
# Create stream
|
||||
stream = await manager.create_stream("test_stream")
|
||||
assert stream is not None
|
||||
assert stream._running is True
|
||||
assert len(manager.streams) == 1
|
||||
assert manager.total_connections == 1
|
||||
|
||||
# Remove stream
|
||||
await manager.remove_stream("test_stream")
|
||||
assert len(manager.streams) == 0
|
||||
assert manager.total_connections == 0
|
||||
|
||||
await manager.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_to_all_streams(self, manager):
|
||||
"""Test broadcasting to all streams"""
|
||||
await manager.start()
|
||||
|
||||
# Create multiple streams
|
||||
streams = []
|
||||
for i in range(3):
|
||||
stream = await manager.create_stream(f"stream_{i}")
|
||||
streams.append(stream)
|
||||
|
||||
# Broadcast message
|
||||
await manager.broadcast_to_all({"broadcast": "test"}, "important")
|
||||
|
||||
# Wait for broadcast
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# Verify all streams received the message
|
||||
for stream in streams:
|
||||
assert stream.websocket.send.called
|
||||
|
||||
await manager.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slow_stream_detection(self, manager):
|
||||
"""Test slow stream detection"""
|
||||
await manager.start()
|
||||
|
||||
# Create slow stream
|
||||
slow_stream = await manager.create_stream("slow_stream")
|
||||
|
||||
# Make it slow
|
||||
async def slow_send(message):
|
||||
await asyncio.sleep(0.6)
|
||||
|
||||
slow_stream.websocket.send = slow_send
|
||||
|
||||
# Send many messages to fill queue and trigger slow detection
|
||||
for i in range(30): # More messages to fill queue
|
||||
await slow_stream.send_message({"test": f"data_{i}"}, "important")
|
||||
|
||||
await asyncio.sleep(2.0)
|
||||
|
||||
# Check for slow streams (based on queue fill ratio)
|
||||
slow_streams = manager.get_slow_streams(threshold=0.5) # Lower threshold
|
||||
|
||||
# Should detect slow stream either by status or queue fill ratio
|
||||
stream_detected = (
|
||||
len(slow_streams) > 0 or
|
||||
slow_stream.status == "slow_consumer" or
|
||||
slow_stream.queue.fill_ratio() > 0.5
|
||||
)
|
||||
|
||||
assert stream_detected, f"Slow stream not detected. Status: {slow_stream.status}, Queue ratio: {slow_stream.queue.fill_ratio()}"
|
||||
|
||||
await manager.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_metrics(self, manager):
|
||||
"""Test manager metrics"""
|
||||
await manager.start()
|
||||
|
||||
# Create streams with different loads
|
||||
normal_stream = await manager.create_stream("normal_stream")
|
||||
slow_stream = await manager.create_stream("slow_stream")
|
||||
|
||||
# Send messages to normal stream
|
||||
for i in range(5):
|
||||
await normal_stream.send_message({"test": f"data_{i}"}, "important")
|
||||
|
||||
# Send messages to slow stream (to fill queue)
|
||||
for i in range(40):
|
||||
await slow_stream.send_message({"test": f"data_{i}"}, "important")
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Get metrics
|
||||
metrics = manager.get_manager_metrics()
|
||||
|
||||
assert "manager_status" in metrics
|
||||
assert "total_connections" in metrics
|
||||
assert "active_streams" in metrics
|
||||
assert "total_queue_size" in metrics
|
||||
assert "stream_status_distribution" in metrics
|
||||
|
||||
await manager.stop()
|
||||
|
||||
|
||||
class TestBackpressureScenarios:
|
||||
"""Test backpressure scenarios"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_high_load_scenario(self):
|
||||
"""Test system behavior under high load"""
|
||||
manager = MockStreamManager()
|
||||
await manager.start()
|
||||
|
||||
try:
|
||||
# Create multiple streams
|
||||
streams = []
|
||||
for i in range(5):
|
||||
stream = await manager.create_stream(f"stream_{i}", max_queue_size=50)
|
||||
streams.append(stream)
|
||||
|
||||
# Send high volume of messages
|
||||
tasks = []
|
||||
for stream in streams:
|
||||
for i in range(100):
|
||||
task = asyncio.create_task(
|
||||
stream.send_message({"test": f"data_{i}"}, "important")
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
# Wait for all sends
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Wait for processing
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
# Check system handled load
|
||||
metrics = manager.get_manager_metrics()
|
||||
|
||||
# Should have processed some messages
|
||||
assert metrics["total_messages_sent"] > 0
|
||||
|
||||
# System should still be running
|
||||
assert metrics["manager_status"] == "running"
|
||||
|
||||
# Some messages may be dropped under load
|
||||
assert metrics["total_messages_dropped"] >= 0
|
||||
|
||||
finally:
|
||||
await manager.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_priority_scenario(self):
|
||||
"""Test handling of mixed priority messages"""
|
||||
queue = MockBoundedQueue(max_size=20)
|
||||
|
||||
# Fill queue with bulk messages
|
||||
for i in range(15):
|
||||
await queue.put(MockMessage(f"bulk_{i}", "bulk"), "bulk")
|
||||
|
||||
# Add critical messages (should succeed)
|
||||
critical_success = await queue.put(MockMessage("critical_1", "critical"), "critical")
|
||||
critical_success2 = await queue.put(MockMessage("critical_2", "critical"), "critical")
|
||||
|
||||
assert critical_success is True
|
||||
assert critical_success2 is True
|
||||
|
||||
# Add important messages (should replace bulk)
|
||||
important_success = await queue.put(MockMessage("important_1", "important"), "important")
|
||||
important_success2 = await queue.put(MockMessage("important_2", "important"), "important")
|
||||
|
||||
assert important_success is True
|
||||
assert important_success2 is True
|
||||
|
||||
# Try to add more bulk (should be dropped)
|
||||
bulk_success = await queue.put(MockMessage("bulk_new", "bulk"), "bulk")
|
||||
assert bulk_success is False
|
||||
|
||||
# Verify priority order in retrieval
|
||||
retrieved_order = []
|
||||
for _ in range(10):
|
||||
msg = await queue.get()
|
||||
if msg:
|
||||
retrieved_order.append(msg.priority)
|
||||
|
||||
# Should start with critical messages
|
||||
assert retrieved_order[0] == "critical"
|
||||
assert retrieved_order[1] == "critical"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slow_consumer_isolation(self):
|
||||
"""Test that slow consumers don't block fast ones"""
|
||||
manager = MockStreamManager()
|
||||
await manager.start()
|
||||
|
||||
try:
|
||||
# Create fast and slow streams
|
||||
fast_stream = await manager.create_stream("fast_stream")
|
||||
slow_stream = await manager.create_stream("slow_stream")
|
||||
|
||||
# Make slow stream slow
|
||||
async def slow_send(message):
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
slow_stream.websocket.send = slow_send
|
||||
|
||||
# Send messages to both streams
|
||||
for i in range(10):
|
||||
await fast_stream.send_message({"fast": f"data_{i}"}, "important")
|
||||
await slow_stream.send_message({"slow": f"data_{i}"}, "important")
|
||||
|
||||
# Wait for processing
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
# Fast stream should have processed more messages
|
||||
fast_metrics = fast_stream.get_metrics()
|
||||
slow_metrics = slow_stream.get_metrics()
|
||||
|
||||
# Fast stream should be ahead
|
||||
assert fast_metrics["messages_sent"] >= slow_metrics["messages_sent"]
|
||||
|
||||
# Slow stream should be detected as slow
|
||||
assert slow_stream.status == "slow_consumer"
|
||||
|
||||
finally:
|
||||
await manager.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
776
tests/test_websocket_stream_backpressure.py
Normal file
776
tests/test_websocket_stream_backpressure.py
Normal file
@@ -0,0 +1,776 @@
|
||||
"""
|
||||
Tests for WebSocket Stream Backpressure Control
|
||||
|
||||
Comprehensive test suite for WebSocket stream architecture with
|
||||
per-stream flow control and backpressure handling.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
from typing import Dict, Any
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'apps', 'coordinator-api', 'src'))
|
||||
|
||||
from app.services.websocket_stream_manager import (
|
||||
WebSocketStreamManager, StreamConfig, StreamMessage, MessageType,
|
||||
BoundedMessageQueue, WebSocketStream, StreamStatus
|
||||
)
|
||||
from app.services.multi_modal_websocket_fusion import (
|
||||
MultiModalWebSocketFusion, FusionStreamType, FusionStreamConfig,
|
||||
GPUProviderFlowControl, GPUProviderStatus, FusionData
|
||||
)
|
||||
|
||||
|
||||
class TestBoundedMessageQueue:
|
||||
"""Test bounded message queue with priority and backpressure"""
|
||||
|
||||
@pytest.fixture
|
||||
def queue(self):
|
||||
return BoundedMessageQueue(max_size=10)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_queue_operations(self, queue):
|
||||
"""Test basic queue put/get operations"""
|
||||
message = StreamMessage(data="test", message_type=MessageType.IMPORTANT)
|
||||
|
||||
# Put message
|
||||
success = await queue.put(message)
|
||||
assert success is True
|
||||
assert queue.size() == 1
|
||||
|
||||
# Get message
|
||||
retrieved = await queue.get()
|
||||
assert retrieved == message
|
||||
assert queue.size() == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_priority_ordering(self, queue):
|
||||
"""Test message priority ordering"""
|
||||
messages = [
|
||||
StreamMessage(data="bulk", message_type=MessageType.BULK),
|
||||
StreamMessage(data="critical", message_type=MessageType.CRITICAL),
|
||||
StreamMessage(data="important", message_type=MessageType.IMPORTANT),
|
||||
StreamMessage(data="control", message_type=MessageType.CONTROL)
|
||||
]
|
||||
|
||||
# Add messages in random order
|
||||
for msg in messages:
|
||||
await queue.put(msg)
|
||||
|
||||
# Should retrieve in priority order: CONTROL > CRITICAL > IMPORTANT > BULK
|
||||
expected_order = [MessageType.CONTROL, MessageType.CRITICAL,
|
||||
MessageType.IMPORTANT, MessageType.BULK]
|
||||
|
||||
for expected_type in expected_order:
|
||||
msg = await queue.get()
|
||||
assert msg.message_type == expected_type
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backpressure_handling(self, queue):
|
||||
"""Test backpressure handling when queue is full"""
|
||||
# Fill queue to capacity
|
||||
for i in range(queue.max_size):
|
||||
await queue.put(StreamMessage(data=f"bulk_{i}", message_type=MessageType.BULK))
|
||||
|
||||
assert queue.size() == queue.max_size
|
||||
assert queue.fill_ratio() == 1.0
|
||||
|
||||
# Try to add bulk message (should be dropped)
|
||||
bulk_msg = StreamMessage(data="new_bulk", message_type=MessageType.BULK)
|
||||
success = await queue.put(bulk_msg)
|
||||
assert success is False
|
||||
|
||||
# Try to add important message (should replace oldest important)
|
||||
important_msg = StreamMessage(data="new_important", message_type=MessageType.IMPORTANT)
|
||||
success = await queue.put(important_msg)
|
||||
assert success is True
|
||||
|
||||
# Try to add critical message (should always succeed)
|
||||
critical_msg = StreamMessage(data="new_critical", message_type=MessageType.CRITICAL)
|
||||
success = await queue.put(critical_msg)
|
||||
assert success is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_queue_size_limits(self, queue):
|
||||
"""Test that individual queue size limits are respected"""
|
||||
# Fill control queue to its limit
|
||||
for i in range(100): # Control queue limit is 100
|
||||
await queue.put(StreamMessage(data=f"control_{i}", message_type=MessageType.CONTROL))
|
||||
|
||||
# Should still accept other message types
|
||||
success = await queue.put(StreamMessage(data="important", message_type=MessageType.IMPORTANT))
|
||||
assert success is True
|
||||
|
||||
|
||||
class TestWebSocketStream:
|
||||
"""Test individual WebSocket stream with backpressure control"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_websocket(self):
|
||||
websocket = Mock()
|
||||
websocket.send = AsyncMock()
|
||||
websocket.remote_address = "127.0.0.1:12345"
|
||||
return websocket
|
||||
|
||||
@pytest.fixture
|
||||
def stream_config(self):
|
||||
return StreamConfig(
|
||||
max_queue_size=50,
|
||||
send_timeout=1.0,
|
||||
slow_consumer_threshold=0.1,
|
||||
backpressure_threshold=0.7
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def stream(self, mock_websocket, stream_config):
|
||||
return WebSocketStream(mock_websocket, "test_stream", stream_config)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_start_stop(self, stream):
|
||||
"""Test stream start and stop"""
|
||||
assert stream.status == StreamStatus.CONNECTING
|
||||
|
||||
await stream.start()
|
||||
assert stream.status == StreamStatus.CONNECTED
|
||||
assert stream._running is True
|
||||
|
||||
await stream.stop()
|
||||
assert stream.status == StreamStatus.DISCONNECTED
|
||||
assert stream._running is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_sending(self, stream, mock_websocket):
|
||||
"""Test basic message sending"""
|
||||
await stream.start()
|
||||
|
||||
# Send message
|
||||
success = await stream.send_message({"test": "data"}, MessageType.IMPORTANT)
|
||||
assert success is True
|
||||
|
||||
# Wait for message to be processed
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Verify message was sent
|
||||
mock_websocket.send.assert_called()
|
||||
|
||||
await stream.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slow_consumer_detection(self, stream, mock_websocket):
|
||||
"""Test slow consumer detection"""
|
||||
# Make websocket send slow
|
||||
async def slow_send(message):
|
||||
await asyncio.sleep(0.2) # Slower than threshold (0.1s)
|
||||
|
||||
mock_websocket.send = slow_send
|
||||
|
||||
await stream.start()
|
||||
|
||||
# Send multiple messages to trigger slow consumer detection
|
||||
for i in range(10):
|
||||
await stream.send_message({"test": f"data_{i}"}, MessageType.IMPORTANT)
|
||||
|
||||
# Wait for processing
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
# Check if slow consumer was detected
|
||||
assert stream.status == StreamStatus.SLOW_CONSUMER
|
||||
assert stream.metrics.slow_consumer_events > 0
|
||||
|
||||
await stream.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backpressure_handling(self, stream, mock_websocket):
|
||||
"""Test backpressure handling"""
|
||||
await stream.start()
|
||||
|
||||
# Fill queue to trigger backpressure
|
||||
for i in range(40): # 40/50 = 80% > backpressure_threshold (70%)
|
||||
await stream.send_message({"test": f"data_{i}"}, MessageType.IMPORTANT)
|
||||
|
||||
# Wait for processing
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Check backpressure status
|
||||
assert stream.status == StreamStatus.BACKPRESSURE
|
||||
assert stream.metrics.backpressure_events > 0
|
||||
|
||||
# Try to send bulk message under backpressure
|
||||
success = await stream.send_message({"bulk": "data"}, MessageType.BULK)
|
||||
# Should be dropped due to high queue fill ratio
|
||||
assert stream.queue.fill_ratio() > 0.8
|
||||
|
||||
await stream.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_priority_handling(self, stream, mock_websocket):
|
||||
"""Test that priority messages are handled correctly"""
|
||||
await stream.start()
|
||||
|
||||
# Send messages of different priorities
|
||||
await stream.send_message({"bulk": "data"}, MessageType.BULK)
|
||||
await stream.send_message({"critical": "data"}, MessageType.CRITICAL)
|
||||
await stream.send_message({"important": "data"}, MessageType.IMPORTANT)
|
||||
await stream.send_message({"control": "data"}, MessageType.CONTROL)
|
||||
|
||||
# Wait for processing
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# Verify all messages were sent
|
||||
assert mock_websocket.send.call_count >= 4
|
||||
|
||||
await stream.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_timeout_handling(self, stream, mock_websocket):
|
||||
"""Test send timeout handling"""
|
||||
# Make websocket send timeout
|
||||
async def timeout_send(message):
|
||||
await asyncio.sleep(2.0) # Longer than send_timeout (1.0s)
|
||||
|
||||
mock_websocket.send = timeout_send
|
||||
|
||||
await stream.start()
|
||||
|
||||
# Send message
|
||||
success = await stream.send_message({"test": "data"}, MessageType.IMPORTANT)
|
||||
assert success is True
|
||||
|
||||
# Wait for processing
|
||||
await asyncio.sleep(1.5)
|
||||
|
||||
# Check that message was dropped due to timeout
|
||||
assert stream.metrics.messages_dropped > 0
|
||||
|
||||
await stream.stop()
|
||||
|
||||
def test_stream_metrics(self, stream):
|
||||
"""Test stream metrics collection"""
|
||||
metrics = stream.get_metrics()
|
||||
|
||||
assert "stream_id" in metrics
|
||||
assert "status" in metrics
|
||||
assert "queue_size" in metrics
|
||||
assert "messages_sent" in metrics
|
||||
assert "messages_dropped" in metrics
|
||||
assert "backpressure_events" in metrics
|
||||
assert "slow_consumer_events" in metrics
|
||||
|
||||
|
||||
class TestWebSocketStreamManager:
|
||||
"""Test WebSocket stream manager with multiple streams"""
|
||||
|
||||
@pytest.fixture
|
||||
def manager(self):
|
||||
return WebSocketStreamManager()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_websocket(self):
|
||||
websocket = Mock()
|
||||
websocket.send = AsyncMock()
|
||||
websocket.remote_address = "127.0.0.1:12345"
|
||||
return websocket
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_start_stop(self, manager):
|
||||
"""Test manager start and stop"""
|
||||
await manager.start()
|
||||
assert manager._running is True
|
||||
|
||||
await manager.stop()
|
||||
assert manager._running is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_lifecycle_management(self, manager, mock_websocket):
|
||||
"""Test stream lifecycle management"""
|
||||
await manager.start()
|
||||
|
||||
# Create stream through manager
|
||||
stream = None
|
||||
async with manager.manage_stream(mock_websocket) as s:
|
||||
stream = s
|
||||
assert stream is not None
|
||||
assert stream._running is True
|
||||
assert len(manager.streams) == 1
|
||||
assert manager.total_connections == 1
|
||||
|
||||
# Stream should be cleaned up
|
||||
assert len(manager.streams) == 0
|
||||
assert manager.total_connections == 0
|
||||
|
||||
await manager.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_to_all_streams(self, manager):
|
||||
"""Test broadcasting to all streams"""
|
||||
await manager.start()
|
||||
|
||||
# Create multiple mock websockets
|
||||
websockets = [Mock() for _ in range(3)]
|
||||
for ws in websockets:
|
||||
ws.send = AsyncMock()
|
||||
ws.remote_address = f"127.0.0.1:{12345 + websockets.index(ws)}"
|
||||
|
||||
# Create streams
|
||||
streams = []
|
||||
for ws in websockets:
|
||||
async with manager.manage_stream(ws) as stream:
|
||||
streams.append(stream)
|
||||
await asyncio.sleep(0.01) # Small delay
|
||||
|
||||
# Broadcast message
|
||||
await manager.broadcast_to_all({"broadcast": "test"}, MessageType.IMPORTANT)
|
||||
|
||||
# Wait for broadcast
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# Verify all streams received the message
|
||||
for ws in websockets:
|
||||
ws.send.assert_called()
|
||||
|
||||
await manager.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slow_stream_handling(self, manager):
|
||||
"""Test handling of slow streams"""
|
||||
await manager.start()
|
||||
|
||||
# Create slow websocket
|
||||
slow_websocket = Mock()
|
||||
async def slow_send(message):
|
||||
await asyncio.sleep(0.5) # Very slow
|
||||
|
||||
slow_websocket.send = slow_send
|
||||
slow_websocket.remote_address = "127.0.0.1:12345"
|
||||
|
||||
# Create slow stream
|
||||
async with manager.manage_stream(slow_websocket) as stream:
|
||||
# Send messages to fill queue
|
||||
for i in range(20):
|
||||
await stream.send_message({"test": f"data_{i}"}, MessageType.IMPORTANT)
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Check if stream is detected as slow
|
||||
slow_streams = manager.get_slow_streams(threshold=0.5)
|
||||
assert len(slow_streams) > 0
|
||||
|
||||
await manager.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_metrics(self, manager):
|
||||
"""Test manager metrics collection"""
|
||||
await manager.start()
|
||||
|
||||
# Create some streams
|
||||
websockets = [Mock() for _ in range(2)]
|
||||
for ws in websockets:
|
||||
ws.send = AsyncMock()
|
||||
ws.remote_address = f"127.0.0.1:{12345 + websockets.index(ws)}"
|
||||
|
||||
streams = []
|
||||
for ws in websockets:
|
||||
async with manager.manage_stream(ws) as stream:
|
||||
streams.append(stream)
|
||||
await stream.send_message({"test": "data"}, MessageType.IMPORTANT)
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Get metrics
|
||||
metrics = await manager.get_manager_metrics()
|
||||
|
||||
assert "manager_status" in metrics
|
||||
assert "total_connections" in metrics
|
||||
assert "active_streams" in metrics
|
||||
assert "total_queue_size" in metrics
|
||||
assert "stream_status_distribution" in metrics
|
||||
assert "stream_metrics" in metrics
|
||||
|
||||
await manager.stop()
|
||||
|
||||
|
||||
class TestGPUProviderFlowControl:
|
||||
"""Test GPU provider flow control"""
|
||||
|
||||
@pytest.fixture
|
||||
def provider(self):
|
||||
return GPUProviderFlowControl("test_provider")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provider_start_stop(self, provider):
|
||||
"""Test provider start and stop"""
|
||||
await provider.start()
|
||||
assert provider._running is True
|
||||
|
||||
await provider.stop()
|
||||
assert provider._running is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_submission(self, provider):
|
||||
"""Test request submission and processing"""
|
||||
await provider.start()
|
||||
|
||||
# Create fusion data
|
||||
fusion_data = FusionData(
|
||||
stream_id="test_stream",
|
||||
stream_type=FusionStreamType.VISUAL,
|
||||
data={"test": "data"},
|
||||
timestamp=time.time(),
|
||||
requires_gpu=True
|
||||
)
|
||||
|
||||
# Submit request
|
||||
request_id = await provider.submit_request(fusion_data)
|
||||
assert request_id is not None
|
||||
|
||||
# Get result
|
||||
result = await provider.get_result(request_id, timeout=3.0)
|
||||
assert result is not None
|
||||
assert "processed_data" in result
|
||||
|
||||
await provider.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_request_limiting(self, provider):
|
||||
"""Test concurrent request limiting"""
|
||||
provider.max_concurrent_requests = 2
|
||||
await provider.start()
|
||||
|
||||
# Submit multiple requests
|
||||
fusion_data = FusionData(
|
||||
stream_id="test_stream",
|
||||
stream_type=FusionStreamType.VISUAL,
|
||||
data={"test": "data"},
|
||||
timestamp=time.time(),
|
||||
requires_gpu=True
|
||||
)
|
||||
|
||||
request_ids = []
|
||||
for i in range(5):
|
||||
request_id = await provider.submit_request(fusion_data)
|
||||
if request_id:
|
||||
request_ids.append(request_id)
|
||||
|
||||
# Should have processed some requests
|
||||
assert len(request_ids) > 0
|
||||
|
||||
# Get results
|
||||
results = []
|
||||
for request_id in request_ids:
|
||||
result = await provider.get_result(request_id, timeout=5.0)
|
||||
if result:
|
||||
results.append(result)
|
||||
|
||||
assert len(results) > 0
|
||||
|
||||
await provider.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_overload_handling(self, provider):
|
||||
"""Test provider overload handling"""
|
||||
await provider.start()
|
||||
|
||||
# Fill input queue to capacity
|
||||
fusion_data = FusionData(
|
||||
stream_id="test_stream",
|
||||
stream_type=FusionStreamType.VISUAL,
|
||||
data={"test": "data"},
|
||||
timestamp=time.time(),
|
||||
requires_gpu=True
|
||||
)
|
||||
|
||||
# Submit many requests to fill queue
|
||||
request_ids = []
|
||||
for i in range(150): # More than queue capacity (100)
|
||||
request_id = await provider.submit_request(fusion_data)
|
||||
if request_id:
|
||||
request_ids.append(request_id)
|
||||
else:
|
||||
break # Queue is full
|
||||
|
||||
# Should have rejected some requests due to overload
|
||||
assert len(request_ids) < 150
|
||||
|
||||
# Check provider status
|
||||
metrics = provider.get_metrics()
|
||||
assert metrics["queue_size"] >= provider.input_queue.maxsize * 0.8
|
||||
|
||||
await provider.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provider_metrics(self, provider):
|
||||
"""Test provider metrics collection"""
|
||||
await provider.start()
|
||||
|
||||
# Submit some requests
|
||||
fusion_data = FusionData(
|
||||
stream_id="test_stream",
|
||||
stream_type=FusionStreamType.VISUAL,
|
||||
data={"test": "data"},
|
||||
timestamp=time.time(),
|
||||
requires_gpu=True
|
||||
)
|
||||
|
||||
for i in range(3):
|
||||
request_id = await provider.submit_request(fusion_data)
|
||||
if request_id:
|
||||
await provider.get_result(request_id, timeout=3.0)
|
||||
|
||||
# Get metrics
|
||||
metrics = provider.get_metrics()
|
||||
|
||||
assert "provider_id" in metrics
|
||||
assert "status" in metrics
|
||||
assert "avg_processing_time" in metrics
|
||||
assert "queue_size" in metrics
|
||||
assert "total_requests" in metrics
|
||||
assert "error_rate" in metrics
|
||||
|
||||
await provider.stop()
|
||||
|
||||
|
||||
class TestMultiModalWebSocketFusion:
|
||||
"""Test multi-modal WebSocket fusion service"""
|
||||
|
||||
@pytest.fixture
|
||||
def fusion_service(self):
|
||||
return MultiModalWebSocketFusion()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fusion_service_start_stop(self, fusion_service):
|
||||
"""Test fusion service start and stop"""
|
||||
await fusion_service.start()
|
||||
assert fusion_service._running is True
|
||||
|
||||
await fusion_service.stop()
|
||||
assert fusion_service._running is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fusion_stream_registration(self, fusion_service):
|
||||
"""Test fusion stream registration"""
|
||||
await fusion_service.start()
|
||||
|
||||
config = FusionStreamConfig(
|
||||
stream_type=FusionStreamType.VISUAL,
|
||||
max_queue_size=100,
|
||||
gpu_timeout=2.0
|
||||
)
|
||||
|
||||
await fusion_service.register_fusion_stream("test_stream", config)
|
||||
|
||||
assert "test_stream" in fusion_service.fusion_streams
|
||||
assert fusion_service.fusion_streams["test_stream"].stream_type == FusionStreamType.VISUAL
|
||||
|
||||
await fusion_service.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gpu_provider_initialization(self, fusion_service):
|
||||
"""Test GPU provider initialization"""
|
||||
await fusion_service.start()
|
||||
|
||||
assert len(fusion_service.gpu_providers) > 0
|
||||
|
||||
# Check that providers are running
|
||||
for provider in fusion_service.gpu_providers.values():
|
||||
assert provider._running is True
|
||||
|
||||
await fusion_service.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fusion_data_processing(self, fusion_service):
|
||||
"""Test fusion data processing"""
|
||||
await fusion_service.start()
|
||||
|
||||
# Create fusion data
|
||||
fusion_data = FusionData(
|
||||
stream_id="test_stream",
|
||||
stream_type=FusionStreamType.VISUAL,
|
||||
data={"test": "data"},
|
||||
timestamp=time.time(),
|
||||
requires_gpu=True
|
||||
)
|
||||
|
||||
# Process data
|
||||
await fusion_service._submit_to_gpu_provider(fusion_data)
|
||||
|
||||
# Wait for processing
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
# Check metrics
|
||||
assert fusion_service.fusion_metrics["total_fusions"] >= 1
|
||||
|
||||
await fusion_service.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_comprehensive_metrics(self, fusion_service):
|
||||
"""Test comprehensive metrics collection"""
|
||||
await fusion_service.start()
|
||||
|
||||
# Get metrics
|
||||
metrics = fusion_service.get_comprehensive_metrics()
|
||||
|
||||
assert "timestamp" in metrics
|
||||
assert "system_status" in metrics
|
||||
assert "stream_metrics" in metrics
|
||||
assert "gpu_metrics" in metrics
|
||||
assert "fusion_metrics" in metrics
|
||||
assert "active_fusion_streams" in metrics
|
||||
assert "registered_gpu_providers" in metrics
|
||||
|
||||
await fusion_service.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backpressure_monitoring(self, fusion_service):
|
||||
"""Test backpressure monitoring"""
|
||||
await fusion_service.start()
|
||||
|
||||
# Enable backpressure
|
||||
fusion_service.backpressure_enabled = True
|
||||
|
||||
# Simulate high load
|
||||
fusion_service.global_queue_size = 8000 # High queue size
|
||||
fusion_service.max_global_queue_size = 10000
|
||||
|
||||
# Run monitoring check
|
||||
await fusion_service._check_backpressure()
|
||||
|
||||
# Should have handled backpressure
|
||||
# (This is a simplified test - in reality would check slow streams)
|
||||
|
||||
await fusion_service.stop()
|
||||
|
||||
|
||||
class TestIntegrationScenarios:
|
||||
"""Integration tests for complete scenarios"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_stream_fusion_workflow(self):
|
||||
"""Test complete multi-stream fusion workflow"""
|
||||
fusion_service = MultiModalWebSocketFusion()
|
||||
await fusion_service.start()
|
||||
|
||||
try:
|
||||
# Register multiple streams
|
||||
stream_configs = [
|
||||
("visual_stream", FusionStreamType.VISUAL),
|
||||
("text_stream", FusionStreamType.TEXT),
|
||||
("audio_stream", FusionStreamType.AUDIO)
|
||||
]
|
||||
|
||||
for stream_id, stream_type in stream_configs:
|
||||
config = FusionStreamConfig(stream_type=stream_type)
|
||||
await fusion_service.register_fusion_stream(stream_id, config)
|
||||
|
||||
# Process fusion data for each stream
|
||||
for stream_id, stream_type in stream_configs:
|
||||
fusion_data = FusionData(
|
||||
stream_id=stream_id,
|
||||
stream_type=stream_type,
|
||||
data={"test": f"data_{stream_type.value}"},
|
||||
timestamp=time.time(),
|
||||
requires_gpu=stream_type in [FusionStreamType.VISUAL, FusionStreamType.AUDIO]
|
||||
)
|
||||
|
||||
if fusion_data.requires_gpu:
|
||||
await fusion_service._submit_to_gpu_provider(fusion_data)
|
||||
else:
|
||||
await fusion_service._process_cpu_fusion(fusion_data)
|
||||
|
||||
# Wait for processing
|
||||
await asyncio.sleep(2.0)
|
||||
|
||||
# Check results
|
||||
metrics = fusion_service.get_comprehensive_metrics()
|
||||
assert metrics["fusion_metrics"]["total_fusions"] >= 3
|
||||
|
||||
finally:
|
||||
await fusion_service.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slow_gpu_provider_handling(self):
|
||||
"""Test handling of slow GPU providers"""
|
||||
fusion_service = MultiModalWebSocketFusion()
|
||||
await fusion_service.start()
|
||||
|
||||
try:
|
||||
# Make one GPU provider slow
|
||||
if "gpu_1" in fusion_service.gpu_providers:
|
||||
provider = fusion_service.gpu_providers["gpu_1"]
|
||||
# Simulate slow processing by increasing processing time
|
||||
original_process = provider._process_request
|
||||
|
||||
async def slow_process(request_data):
|
||||
await asyncio.sleep(1.0) # Add delay
|
||||
return await original_process(request_data)
|
||||
|
||||
provider._process_request = slow_process
|
||||
|
||||
# Submit fusion data
|
||||
fusion_data = FusionData(
|
||||
stream_id="test_stream",
|
||||
stream_type=FusionStreamType.VISUAL,
|
||||
data={"test": "data"},
|
||||
timestamp=time.time(),
|
||||
requires_gpu=True
|
||||
)
|
||||
|
||||
# Should select fastest available provider
|
||||
await fusion_service._submit_to_gpu_provider(fusion_data)
|
||||
|
||||
# Wait for processing
|
||||
await asyncio.sleep(2.0)
|
||||
|
||||
# Check that processing completed
|
||||
assert fusion_service.fusion_metrics["total_fusions"] >= 1
|
||||
|
||||
finally:
|
||||
await fusion_service.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_system_under_load(self):
|
||||
"""Test system behavior under high load"""
|
||||
fusion_service = MultiModalWebSocketFusion()
|
||||
await fusion_service.start()
|
||||
|
||||
try:
|
||||
# Submit many fusion requests
|
||||
tasks = []
|
||||
for i in range(50):
|
||||
fusion_data = FusionData(
|
||||
stream_id=f"stream_{i % 5}",
|
||||
stream_type=FusionStreamType.VISUAL,
|
||||
data={"test": f"data_{i}"},
|
||||
timestamp=time.time(),
|
||||
requires_gpu=True
|
||||
)
|
||||
|
||||
task = asyncio.create_task(
|
||||
fusion_service._submit_to_gpu_provider(fusion_data)
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
# Wait for all tasks
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Wait for processing
|
||||
await asyncio.sleep(3.0)
|
||||
|
||||
# Check system handled load
|
||||
metrics = fusion_service.get_comprehensive_metrics()
|
||||
|
||||
# Should have processed many requests
|
||||
assert metrics["fusion_metrics"]["total_fusions"] >= 10
|
||||
|
||||
# System should still be responsive
|
||||
assert metrics["system_status"] == "running"
|
||||
|
||||
finally:
|
||||
await fusion_service.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
390
tests/unit/test_core_functionality.py
Normal file
390
tests/unit/test_core_functionality.py
Normal file
@@ -0,0 +1,390 @@
|
||||
"""
|
||||
Unit Tests for AITBC Core Functionality
|
||||
Tests core components using actual AITBC CLI tool
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
import time
|
||||
import tempfile
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import Mock, patch
|
||||
from pathlib import Path
|
||||
from click.testing import CliRunner
|
||||
|
||||
# Import the actual CLI
|
||||
from aitbc_cli.main import cli
|
||||
|
||||
|
||||
class TestAITBCCliIntegration:
|
||||
"""Test AITBC CLI integration"""
|
||||
|
||||
def test_cli_help(self):
|
||||
"""Test CLI help command"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['--help'])
|
||||
assert result.exit_code == 0
|
||||
assert 'AITBC CLI' in result.output
|
||||
assert 'Commands:' in result.output
|
||||
|
||||
def test_cli_version(self):
|
||||
"""Test CLI version command"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['version'])
|
||||
assert result.exit_code == 0
|
||||
assert 'version' in result.output.lower()
|
||||
|
||||
def test_cli_config_show(self):
|
||||
"""Test CLI config show command"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['config-show'])
|
||||
assert result.exit_code == 0
|
||||
assert 'coordinator_url' in result.output.lower()
|
||||
|
||||
def test_cli_test_mode(self):
|
||||
"""Test CLI test mode functionality"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['--test-mode', 'test', 'environment'])
|
||||
assert result.exit_code == 0
|
||||
assert 'Test Mode: True' in result.output
|
||||
assert 'test-api-k' in result.output
|
||||
|
||||
def test_cli_dry_run(self):
|
||||
"""Test CLI dry run functionality"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['--dry-run', 'test', 'environment'])
|
||||
assert result.exit_code == 0
|
||||
assert 'Dry Run: True' in result.output
|
||||
|
||||
def test_cli_debug_mode(self):
|
||||
"""Test CLI debug mode functionality"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['--debug', 'test', 'environment'])
|
||||
assert result.exit_code == 0
|
||||
assert 'Log Level: DEBUG' in result.output
|
||||
|
||||
|
||||
class TestAITBCWalletCli:
|
||||
"""Test AITBC wallet CLI functionality"""
|
||||
|
||||
def test_wallet_help(self):
|
||||
"""Test wallet help command"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['wallet', '--help'])
|
||||
assert result.exit_code == 0
|
||||
assert 'wallet' in result.output.lower()
|
||||
|
||||
def test_wallet_create_test_mode(self):
|
||||
"""Test wallet creation in test mode"""
|
||||
runner = CliRunner()
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
env = {'WALLET_DIR': temp_dir}
|
||||
wallet_name = f"test-wallet-{int(time.time())}"
|
||||
result = runner.invoke(cli, ['--test-mode', 'wallet', 'create', wallet_name], env=env)
|
||||
# In test mode, this should work without actual blockchain
|
||||
assert result.exit_code == 0 or 'wallet' in result.output.lower()
|
||||
|
||||
def test_wallet_commands_available(self):
|
||||
"""Test that wallet commands are available"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['wallet', '--help'])
|
||||
expected_commands = ['create', 'balance', 'list', 'info', 'switch']
|
||||
for cmd in expected_commands:
|
||||
assert cmd in result.output.lower()
|
||||
|
||||
|
||||
class TestAITBCMarketplaceCli:
|
||||
"""Test AITBC marketplace CLI functionality"""
|
||||
|
||||
def test_marketplace_help(self):
|
||||
"""Test marketplace help command"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['marketplace', '--help'])
|
||||
assert result.exit_code == 0
|
||||
assert 'marketplace' in result.output.lower()
|
||||
|
||||
def test_marketplace_commands_available(self):
|
||||
"""Test that marketplace commands are available"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['marketplace', '--help'])
|
||||
expected_commands = ['offers', 'pricing', 'providers']
|
||||
for cmd in expected_commands:
|
||||
assert cmd in result.output.lower()
|
||||
|
||||
def test_marketplace_offers_list_test_mode(self):
|
||||
"""Test marketplace offers list in test mode"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['--test-mode', 'marketplace', 'offers', 'list'])
|
||||
# Should handle test mode gracefully
|
||||
assert result.exit_code == 0 or 'offers' in result.output.lower()
|
||||
|
||||
|
||||
class TestAITBCClientCli:
|
||||
"""Test AITBC client CLI functionality"""
|
||||
|
||||
def test_client_help(self):
|
||||
"""Test client help command"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['client', '--help'])
|
||||
assert result.exit_code == 0
|
||||
assert 'client' in result.output.lower()
|
||||
|
||||
def test_client_commands_available(self):
|
||||
"""Test that client commands are available"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['client', '--help'])
|
||||
expected_commands = ['submit', 'status', 'list', 'cancel']
|
||||
for cmd in expected_commands:
|
||||
assert cmd in result.output.lower()
|
||||
|
||||
|
||||
class TestAITBCBlockchainCli:
|
||||
"""Test AITBC blockchain CLI functionality"""
|
||||
|
||||
def test_blockchain_help(self):
|
||||
"""Test blockchain help command"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['blockchain', '--help'])
|
||||
assert result.exit_code == 0
|
||||
assert 'blockchain' in result.output.lower()
|
||||
|
||||
def test_blockchain_commands_available(self):
|
||||
"""Test that blockchain commands are available"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['blockchain', '--help'])
|
||||
expected_commands = ['info', 'status', 'blocks', 'transactions']
|
||||
for cmd in expected_commands:
|
||||
assert cmd in result.output.lower()
|
||||
|
||||
|
||||
class TestAITBCAuthCli:
|
||||
"""Test AITBC auth CLI functionality"""
|
||||
|
||||
def test_auth_help(self):
|
||||
"""Test auth help command"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['auth', '--help'])
|
||||
assert result.exit_code == 0
|
||||
assert 'auth' in result.output.lower()
|
||||
|
||||
def test_auth_commands_available(self):
|
||||
"""Test that auth commands are available"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['auth', '--help'])
|
||||
expected_commands = ['login', 'logout', 'status', 'token']
|
||||
for cmd in expected_commands:
|
||||
assert cmd in result.output.lower()
|
||||
|
||||
|
||||
class TestAITBCTestCommands:
|
||||
"""Test AITBC test commands"""
|
||||
|
||||
def test_test_help(self):
|
||||
"""Test test command help"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['test', '--help'])
|
||||
assert result.exit_code == 0
|
||||
assert 'Testing and debugging' in result.output
|
||||
|
||||
def test_test_environment(self):
|
||||
"""Test test environment command"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['test', 'environment'])
|
||||
assert result.exit_code == 0
|
||||
assert 'CLI Environment Test Results' in result.output
|
||||
|
||||
def test_test_environment_json(self):
|
||||
"""Test test environment command with JSON output"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['test', 'environment', '--format', 'json'])
|
||||
assert result.exit_code == 0
|
||||
# Should be valid JSON
|
||||
data = json.loads(result.output)
|
||||
assert 'coordinator_url' in data
|
||||
assert 'test_mode' in data
|
||||
|
||||
def test_test_mock(self):
|
||||
"""Test test mock command"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['test', 'mock'])
|
||||
assert result.exit_code == 0
|
||||
assert 'Mock data for testing' in result.output
|
||||
# Should be valid JSON
|
||||
lines = result.output.split('\n')
|
||||
for line in lines:
|
||||
if line.strip().startswith('{') or line.strip().startswith('"'):
|
||||
try:
|
||||
data = json.loads(line)
|
||||
assert 'wallet' in data or 'job' in data or 'marketplace' in data
|
||||
except:
|
||||
pass # Skip non-JSON lines
|
||||
|
||||
|
||||
class TestAITBCOutputFormats:
|
||||
"""Test AITBC CLI output formats"""
|
||||
|
||||
def test_json_output_format(self):
|
||||
"""Test JSON output format"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['--output', 'json', 'test', 'environment'])
|
||||
assert result.exit_code == 0
|
||||
# Should be valid JSON
|
||||
data = json.loads(result.output)
|
||||
assert 'coordinator_url' in data
|
||||
|
||||
def test_yaml_output_format(self):
|
||||
"""Test YAML output format"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['--output', 'yaml', 'test', 'environment'])
|
||||
assert result.exit_code == 0
|
||||
# Should contain YAML-like output
|
||||
assert 'coordinator_url:' in result.output or 'coordinator_url' in result.output
|
||||
|
||||
def test_table_output_format(self):
|
||||
"""Test table output format (default)"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['--output', 'table', 'test', 'environment'])
|
||||
assert result.exit_code == 0
|
||||
assert 'CLI Environment Test Results' in result.output
|
||||
|
||||
|
||||
class TestAITBCConfiguration:
|
||||
"""Test AITBC CLI configuration"""
|
||||
|
||||
def test_custom_config_file(self):
|
||||
"""Test custom config file option"""
|
||||
runner = CliRunner()
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
|
||||
f.write('coordinator_url: http://test.example.com\n')
|
||||
f.write('api_key: test-key\n')
|
||||
config_file = f.name
|
||||
|
||||
try:
|
||||
result = runner.invoke(cli, ['--config-file', config_file, 'test', 'environment'])
|
||||
assert result.exit_code == 0
|
||||
finally:
|
||||
Path(config_file).unlink(missing_ok=True)
|
||||
|
||||
def test_custom_url_override(self):
|
||||
"""Test custom URL override"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['--url', 'http://custom.test', 'test', 'environment'])
|
||||
assert result.exit_code == 0
|
||||
assert 'http://custom.test' in result.output
|
||||
|
||||
def test_custom_api_key_override(self):
|
||||
"""Test custom API key override"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['--api-key', 'custom-test-key', 'test', 'environment'])
|
||||
assert result.exit_code == 0
|
||||
assert 'custom-test' in result.output
|
||||
|
||||
|
||||
class TestAITBCErrorHandling:
|
||||
"""Test AITBC CLI error handling"""
|
||||
|
||||
def test_invalid_command(self):
|
||||
"""Test invalid command handling"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['invalid-command'])
|
||||
assert result.exit_code != 0
|
||||
assert 'No such command' in result.output
|
||||
|
||||
def test_invalid_option(self):
|
||||
"""Test invalid option handling"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['--invalid-option'])
|
||||
assert result.exit_code != 0
|
||||
|
||||
def test_missing_required_argument(self):
|
||||
"""Test missing required argument handling"""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['wallet', 'create'])
|
||||
# Should show error about missing argument
|
||||
assert result.exit_code != 0 or 'Usage:' in result.output
|
||||
|
||||
|
||||
class TestAITBCPerformance:
|
||||
"""Test AITBC CLI performance"""
|
||||
|
||||
def test_help_command_performance(self):
|
||||
"""Test help command performance"""
|
||||
runner = CliRunner()
|
||||
start_time = time.time()
|
||||
result = runner.invoke(cli, ['--help'])
|
||||
end_time = time.time()
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert (end_time - start_time) < 2.0 # Should complete in under 2 seconds
|
||||
|
||||
def test_config_show_performance(self):
|
||||
"""Test config show performance"""
|
||||
runner = CliRunner()
|
||||
start_time = time.time()
|
||||
result = runner.invoke(cli, ['config-show'])
|
||||
end_time = time.time()
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert (end_time - start_time) < 1.0 # Should complete in under 1 second
|
||||
|
||||
|
||||
class TestAITBCDataStructures:
|
||||
"""Test AITBC CLI data structures"""
|
||||
|
||||
def test_job_structure_validation(self):
|
||||
"""Test job data structure validation"""
|
||||
job_data = {
|
||||
'id': 'test-job-123',
|
||||
'type': 'ml_inference',
|
||||
'status': 'pending',
|
||||
'created_at': datetime.utcnow().isoformat(),
|
||||
'requirements': {
|
||||
'gpu_type': 'RTX 3080',
|
||||
'memory_gb': 8,
|
||||
'duration_minutes': 30
|
||||
}
|
||||
}
|
||||
|
||||
# Validate job structure
|
||||
assert 'id' in job_data
|
||||
assert 'type' in job_data
|
||||
assert 'status' in job_data
|
||||
assert job_data['status'] in ['pending', 'running', 'completed', 'failed']
|
||||
assert 'requirements' in job_data
|
||||
|
||||
def test_wallet_structure_validation(self):
|
||||
"""Test wallet data structure validation"""
|
||||
wallet_data = {
|
||||
'name': 'test-wallet',
|
||||
'type': 'hd',
|
||||
'address': 'aitbc1test123456789',
|
||||
'balance': 1000.0,
|
||||
'created_at': datetime.utcnow().isoformat(),
|
||||
'transactions': []
|
||||
}
|
||||
|
||||
# Validate wallet structure
|
||||
assert 'name' in wallet_data
|
||||
assert 'type' in wallet_data
|
||||
assert 'address' in wallet_data
|
||||
assert wallet_data['address'].startswith('aitbc1')
|
||||
assert isinstance(wallet_data['balance'], (int, float))
|
||||
|
||||
def test_marketplace_structure_validation(self):
|
||||
"""Test marketplace data structure validation"""
|
||||
offer_data = {
|
||||
'id': 'offer-123',
|
||||
'provider': 'miner-456',
|
||||
'gpu_type': 'RTX 3080',
|
||||
'price_per_hour': 0.1,
|
||||
'memory_gb': 10,
|
||||
'available': True,
|
||||
'created_at': datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
# Validate offer structure
|
||||
assert 'id' in offer_data
|
||||
assert 'provider' in offer_data
|
||||
assert 'gpu_type' in offer_data
|
||||
assert isinstance(offer_data['price_per_hour'], (int, float))
|
||||
assert isinstance(offer_data['available'], bool)
|
||||
Reference in New Issue
Block a user