refactor: reorganize services into bounded contexts and implement async database support
Some checks failed
Deploy to Testnet / deploy-testnet (push) Has been cancelled
Documentation Validation / validate-docs (push) Has been cancelled
Documentation Validation / validate-policies-strict (push) Has been cancelled
Integration Tests / test-service-integration (push) Has been cancelled
Python Tests / test-python (push) Has been cancelled
Security Scanning / security-scan (push) Has been cancelled
API Endpoint Tests / test-api-endpoints (push) Has been cancelled
CLI Tests / test-cli (push) Has been cancelled
Package Tests / Python package - aitbc-agent-sdk (push) Has been cancelled
Package Tests / Python package - aitbc-core (push) Has been cancelled
Package Tests / Python package - aitbc-crypto (push) Has been cancelled
Package Tests / Python package - aitbc-sdk (push) Has been cancelled
Package Tests / JavaScript package - aitbc-sdk-js (push) Has been cancelled
Package Tests / JavaScript package - aitbc-token (push) Has been cancelled
Staking Tests / test-staking-service (push) Failing after 3s
Staking Tests / test-staking-integration (push) Has been skipped
Staking Tests / test-staking-contract (push) Has been skipped
Staking Tests / run-staking-test-runner (push) Has been skipped
Multi-Node Stress Testing / stress-test (push) Successful in 3s
Cross-Node Transaction Testing / transaction-test (push) Successful in 3s
Some checks failed
Deploy to Testnet / deploy-testnet (push) Has been cancelled
Documentation Validation / validate-docs (push) Has been cancelled
Documentation Validation / validate-policies-strict (push) Has been cancelled
Integration Tests / test-service-integration (push) Has been cancelled
Python Tests / test-python (push) Has been cancelled
Security Scanning / security-scan (push) Has been cancelled
API Endpoint Tests / test-api-endpoints (push) Has been cancelled
CLI Tests / test-cli (push) Has been cancelled
Package Tests / Python package - aitbc-agent-sdk (push) Has been cancelled
Package Tests / Python package - aitbc-core (push) Has been cancelled
Package Tests / Python package - aitbc-crypto (push) Has been cancelled
Package Tests / Python package - aitbc-sdk (push) Has been cancelled
Package Tests / JavaScript package - aitbc-sdk-js (push) Has been cancelled
Package Tests / JavaScript package - aitbc-token (push) Has been cancelled
Staking Tests / test-staking-service (push) Failing after 3s
Staking Tests / test-staking-integration (push) Has been skipped
Staking Tests / test-staking-contract (push) Has been skipped
Staking Tests / run-staking-test-runner (push) Has been skipped
Multi-Node Stress Testing / stress-test (push) Successful in 3s
Cross-Node Transaction Testing / transaction-test (push) Successful in 3s
- Moved services to bounded context packages: - adaptive_learning.py → ai_analytics/adaptive_learning.py - analytics_service.py → ai_analytics/analytics.py - dynamic_pricing_engine.py → trading_marketplace/dynamic_pricing.py - trading_service.py → trading_marketplace/trading.py - Implemented async database module (database_async.py): - Added async SQLAlchemy engine with connection pooling - Added
This commit is contained in:
@@ -1 +1,117 @@
|
||||
[{'adapter}': 'def _build_async_url(url: str', 'driver': 'str) -> str:', 'Convert sync URL to async URL."': 'sqlite:///path.db -> sqlite+aiosqlite:///path.db\n # postgresql:// -> postgresql+asyncpg://\n parts = url.split(', ', 1)\n if len(parts) == 2:\n return f': 'parts[0]'}, {'driver}': {'f': 'url'}, 'sqlite': 'For SQLite', 'connect_args={"check_same_thread': False}, {}]
|
||||
"""Async database module with connection pooling for Coordinator API."""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from .config import settings
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global async engine and session factory
|
||||
_async_engine = None
|
||||
_async_session_factory = None
|
||||
|
||||
|
||||
def _build_async_url(url: str) -> str:
|
||||
"""Convert sync database URL to async URL.
|
||||
|
||||
Examples:
|
||||
sqlite:///path.db -> sqlite+aiosqlite:///path.db
|
||||
postgresql://user:pass@host/db -> postgresql+asyncpg://user:pass@host/db
|
||||
"""
|
||||
# Handle URL with query parameters
|
||||
if '?' in url:
|
||||
base, params = url.split('?', 1)
|
||||
if base.startswith('sqlite:'):
|
||||
return f"{base}+aiosqlite://?{params}"
|
||||
elif base.startswith('postgresql:'):
|
||||
return f"{base}+asyncpg://?{params}"
|
||||
else:
|
||||
return f"{base}+aiosqlite://?{params}" # fallback
|
||||
else:
|
||||
if url.startswith('sqlite:'):
|
||||
return url.replace('sqlite:', 'sqlite+aiosqlite:')
|
||||
elif url.startswith('postgresql:'):
|
||||
return url.replace('postgresql:', 'postgresql+asyncpg:')
|
||||
else:
|
||||
return url.replace(':', '+aiosqlite:') # fallback
|
||||
|
||||
|
||||
def init_async_db() -> None:
|
||||
"""Initialize async database engine and session factory."""
|
||||
global _async_engine, _async_session_factory
|
||||
|
||||
if _async_engine is not None:
|
||||
logger.warning("Async database already initialized")
|
||||
return
|
||||
|
||||
try:
|
||||
# Build async URL from sync settings
|
||||
sync_url = str(settings.database.url)
|
||||
async_url = _build_async_url(sync_url)
|
||||
|
||||
logger.info(f"Initializing async database connection: {async_url.split('://')[0]}://...")
|
||||
|
||||
# Create async engine with pooling
|
||||
_async_engine = create_async_engine(
|
||||
async_url,
|
||||
echo=settings.database.echo if hasattr(settings.database, 'echo') else False,
|
||||
pool_size=getattr(settings.database, 'pool_size', 5),
|
||||
max_overflow=getattr(settings.database, 'max_overflow', 10),
|
||||
pool_pre_ping=getattr(settings.database, 'pool_pre_ping', True),
|
||||
pool_recycle=getattr(settings.database, 'pool_recycle', 3600),
|
||||
)
|
||||
|
||||
# Create session factory
|
||||
_async_session_factory = sessionmaker(
|
||||
_async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
logger.info("Async database initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize async database: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def get_async_db() -> AsyncSession:
|
||||
"""Dependency to get async database session.
|
||||
|
||||
Yields:
|
||||
AsyncSession: Database session that closes automatically
|
||||
"""
|
||||
if _async_session_factory is None:
|
||||
raise RuntimeError("Async database not initialized. Call init_async_db() first.")
|
||||
|
||||
async def _get_async_db() -> AsyncSession:
|
||||
async with _async_session_factory() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
return _get_async_db()
|
||||
|
||||
|
||||
def get_sync_engine():
|
||||
"""Get synchronous engine for backward compatibility.
|
||||
|
||||
Returns:
|
||||
Engine: SQLAlchemy synchronous engine
|
||||
"""
|
||||
from .database import engine
|
||||
return engine
|
||||
|
||||
|
||||
async def close_async_db() -> None:
|
||||
"""Close async database connections."""
|
||||
global _async_engine, _async_session_factory
|
||||
|
||||
if _async_engine is not None:
|
||||
logger.info("Closing async database connections...")
|
||||
await _async_engine.dispose()
|
||||
_async_engine = None
|
||||
_async_session_factory = None
|
||||
logger.info("Async database connections closed")
|
||||
@@ -114,6 +114,7 @@ logger = get_logger(__name__)
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from .storage.db import init_db
|
||||
from .database_async import init_async_db, close_async_db
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -130,6 +131,14 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]:
|
||||
logger.warning(f"Database initialization failed (non-fatal): {e}")
|
||||
# Continue startup even if init_db fails
|
||||
|
||||
# Initialize async database
|
||||
try:
|
||||
init_async_db()
|
||||
logger.info("Async database initialized successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"Async database initialization failed (non-fatal): {e}")
|
||||
# Continue startup even if async init fails
|
||||
|
||||
# Warmup database connections
|
||||
logger.info("Warming up database connections...")
|
||||
try:
|
||||
@@ -227,6 +236,13 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]:
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing database connections: {e}")
|
||||
|
||||
# Close async database connections
|
||||
try:
|
||||
await close_async_db()
|
||||
logger.info("Async database connections closed successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing async database connections: {e}")
|
||||
|
||||
# Cleanup rate limiting state
|
||||
logger.info("Cleaning up rate limiting state...")
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from aitbc import get_logger
|
||||
from ..services.adaptive_learning import AdaptiveLearningService
|
||||
from ..services.ai_analytics.adaptive_learning import AdaptiveLearningService
|
||||
from ..storage import get_session
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -27,7 +27,7 @@ from ..domain.analytics import (
|
||||
MetricType,
|
||||
ReportType,
|
||||
)
|
||||
from ..services.analytics_service import MarketplaceAnalytics
|
||||
from ..services.ai_analytics.analytics import MarketplaceAnalytics
|
||||
from ..storage import get_session
|
||||
|
||||
router = APIRouter(prefix="/v1/analytics", tags=["analytics"])
|
||||
|
||||
@@ -23,7 +23,7 @@ from ..schemas.pricing import (
|
||||
PricingStrategyRequest,
|
||||
PricingStrategyResponse,
|
||||
)
|
||||
from ..services.dynamic_pricing_engine import DynamicPricingEngine, PriceConstraints, PricingStrategy, ResourceType
|
||||
from ..services.trading_marketplace.dynamic_pricing import DynamicPricingEngine, PriceConstraints, PricingStrategy, ResourceType
|
||||
from ..services.market_data_collector import MarketDataCollector
|
||||
|
||||
router = APIRouter(prefix="/v1/pricing", tags=["dynamic-pricing"])
|
||||
|
||||
@@ -28,7 +28,7 @@ from ..domain.trading import (
|
||||
TradeStatus,
|
||||
TradeType,
|
||||
)
|
||||
from ..services.trading_service import P2PTradingProtocol
|
||||
from ..services.trading_marketplace.trading import P2PTradingProtocol
|
||||
from ..storage import get_session
|
||||
|
||||
router = APIRouter(prefix="/v1/trading", tags=["trading"])
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
AI & Analytics Bounded Context
|
||||
Provides analytics and advanced learning services.
|
||||
"""
|
||||
|
||||
from .advanced_learning import AdvancedLearningService
|
||||
from .analytics import AnalyticsEngine, DashboardManager, MarketplaceAnalytics
|
||||
|
||||
__all__ = [
|
||||
"AdvancedLearningService",
|
||||
"AnalyticsEngine",
|
||||
"DashboardManager",
|
||||
"MarketplaceAnalytics",
|
||||
]
|
||||
@@ -17,7 +17,7 @@ from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..storage import get_session
|
||||
from ...storage import get_session
|
||||
|
||||
|
||||
class LearningAlgorithm(StrEnum):
|
||||
@@ -13,7 +13,7 @@ logger = get_logger(__name__)
|
||||
|
||||
from sqlmodel import Session, and_, select
|
||||
|
||||
from ..domain.analytics import (
|
||||
from ...domain.analytics import (
|
||||
AnalyticsAlert,
|
||||
AnalyticsPeriod,
|
||||
DashboardConfig,
|
||||
@@ -1,608 +0,0 @@
|
||||
"""
|
||||
Enterprise API Gateway - Phase 6.1 Implementation
|
||||
Multi-tenant API routing and management for enterprise clients
|
||||
Port: 8010
|
||||
"""
|
||||
|
||||
import secrets
|
||||
import time
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import jwt
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.security import HTTPBearer
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from aitbc import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
from ..domain.multitenant import Tenant, TenantApiKey, TenantQuota
|
||||
from ..exceptions import QuotaExceededError, TenantError
|
||||
from ..storage.db import get_db
|
||||
|
||||
|
||||
# Pydantic models for API requests/responses
|
||||
class EnterpriseAuthRequest(BaseModel):
|
||||
tenant_id: str = Field(..., description="Enterprise tenant identifier")
|
||||
client_id: str = Field(..., description="Enterprise client ID")
|
||||
client_secret: str = Field(..., description="Enterprise client secret")
|
||||
auth_method: str = Field(default="client_credentials", description="Authentication method")
|
||||
scopes: list[str] | None = Field(default=None, description="Requested scopes")
|
||||
|
||||
|
||||
class EnterpriseAuthResponse(BaseModel):
|
||||
access_token: str = Field(..., description="Access token for enterprise API")
|
||||
token_type: str = Field(default="Bearer", description="Token type")
|
||||
expires_in: int = Field(..., description="Token expiration in seconds")
|
||||
refresh_token: str | None = Field(None, description="Refresh token")
|
||||
scopes: list[str] = Field(..., description="Granted scopes")
|
||||
tenant_info: dict[str, Any] = Field(..., description="Tenant information")
|
||||
|
||||
|
||||
class APIQuotaRequest(BaseModel):
|
||||
tenant_id: str = Field(..., description="Enterprise tenant identifier")
|
||||
endpoint: str = Field(..., description="API endpoint")
|
||||
method: str = Field(..., description="HTTP method")
|
||||
quota_type: str = Field(default="rate_limit", description="Quota type")
|
||||
|
||||
|
||||
class APIQuotaResponse(BaseModel):
|
||||
quota_limit: int = Field(..., description="Quota limit")
|
||||
quota_remaining: int = Field(..., description="Remaining quota")
|
||||
quota_reset: datetime = Field(..., description="Quota reset time")
|
||||
quota_type: str = Field(..., description="Quota type")
|
||||
|
||||
|
||||
class WebhookConfig(BaseModel):
|
||||
url: str = Field(..., description="Webhook URL")
|
||||
events: list[str] = Field(..., description="Events to subscribe to")
|
||||
secret: str | None = Field(None, description="Webhook secret")
|
||||
active: bool = Field(default=True, description="Webhook active status")
|
||||
retry_policy: dict[str, Any] | None = Field(None, description="Retry policy")
|
||||
|
||||
|
||||
class EnterpriseIntegrationRequest(BaseModel):
|
||||
integration_type: str = Field(..., description="Integration type (ERP, CRM, etc.)")
|
||||
provider: str = Field(..., description="Integration provider")
|
||||
configuration: dict[str, Any] = Field(..., description="Integration configuration")
|
||||
credentials: dict[str, str] | None = Field(None, description="Integration credentials")
|
||||
webhook_config: WebhookConfig | None = Field(None, description="Webhook configuration")
|
||||
|
||||
|
||||
class EnterpriseMetrics(BaseModel):
|
||||
api_calls_total: int = Field(..., description="Total API calls")
|
||||
api_calls_successful: int = Field(..., description="Successful API calls")
|
||||
average_response_time_ms: float = Field(..., description="Average response time")
|
||||
error_rate_percent: float = Field(..., description="Error rate percentage")
|
||||
quota_utilization_percent: float = Field(..., description="Quota utilization")
|
||||
active_integrations: int = Field(..., description="Active integrations count")
|
||||
|
||||
|
||||
class IntegrationStatus(StrEnum):
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
ERROR = "error"
|
||||
PENDING = "pending"
|
||||
|
||||
|
||||
class EnterpriseIntegration:
|
||||
"""Enterprise integration configuration and management"""
|
||||
|
||||
def __init__(
|
||||
self, integration_id: str, tenant_id: str, integration_type: str, provider: str, configuration: dict[str, Any]
|
||||
):
|
||||
self.integration_id = integration_id
|
||||
self.tenant_id = tenant_id
|
||||
self.integration_type = integration_type
|
||||
self.provider = provider
|
||||
self.configuration = configuration
|
||||
self.status = IntegrationStatus.PENDING
|
||||
self.created_at = datetime.now(timezone.utc)
|
||||
self.last_updated = datetime.now(timezone.utc)
|
||||
self.webhook_config = None
|
||||
self.metrics = {"api_calls": 0, "errors": 0, "last_call": None}
|
||||
|
||||
|
||||
class EnterpriseAPIGateway:
|
||||
"""Enterprise API Gateway with multi-tenant support"""
|
||||
|
||||
def __init__(self):
|
||||
self.tenant_service = None # Will be initialized with database session
|
||||
self.active_tokens = {} # In-memory token storage (in production, use Redis)
|
||||
self.rate_limiters = {} # Per-tenant rate limiters
|
||||
self.webhooks = {} # Webhook configurations
|
||||
self.integrations = {} # Enterprise integrations
|
||||
self.api_metrics = {} # API performance metrics
|
||||
|
||||
# Default quotas
|
||||
self.default_quotas = {
|
||||
"rate_limit": 1000, # requests per minute
|
||||
"daily_limit": 50000, # requests per day
|
||||
"concurrent_limit": 100, # concurrent requests
|
||||
}
|
||||
|
||||
# JWT configuration
|
||||
self.jwt_secret = secrets.token_urlsafe(64)
|
||||
self.jwt_algorithm = "HS256"
|
||||
self.token_expiry = 3600 # 1 hour
|
||||
|
||||
async def authenticate_enterprise_client(self, request: EnterpriseAuthRequest, db_session) -> EnterpriseAuthResponse:
|
||||
"""Authenticate enterprise client and issue access token"""
|
||||
|
||||
try:
|
||||
# Validate tenant and client credentials
|
||||
tenant = await self._validate_tenant_credentials(
|
||||
request.tenant_id, request.client_id, request.client_secret, db_session
|
||||
)
|
||||
|
||||
# Generate access token
|
||||
access_token = self._generate_access_token(
|
||||
tenant_id=request.tenant_id, client_id=request.client_id, scopes=request.scopes or ["enterprise_api"]
|
||||
)
|
||||
|
||||
# Generate refresh token
|
||||
refresh_token = self._generate_refresh_token(request.tenant_id, request.client_id)
|
||||
|
||||
# Store token
|
||||
self.active_tokens[access_token] = {
|
||||
"tenant_id": request.tenant_id,
|
||||
"client_id": request.client_id,
|
||||
"scopes": request.scopes or ["enterprise_api"],
|
||||
"expires_at": datetime.now(timezone.utc) + timedelta(seconds=self.token_expiry),
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
|
||||
return EnterpriseAuthResponse(
|
||||
access_token=access_token,
|
||||
token_type="Bearer",
|
||||
expires_in=self.token_expiry,
|
||||
refresh_token=refresh_token,
|
||||
scopes=request.scopes or ["enterprise_api"],
|
||||
tenant_info={
|
||||
"tenant_id": tenant.tenant_id,
|
||||
"name": tenant.name,
|
||||
"plan": tenant.plan,
|
||||
"status": tenant.status.value,
|
||||
"created_at": tenant.created_at.isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Enterprise authentication failed: {e}")
|
||||
raise HTTPException(status_code=401, detail="Authentication failed")
|
||||
|
||||
def _generate_access_token(self, tenant_id: str, client_id: str, scopes: list[str]) -> str:
|
||||
"""Generate JWT access token"""
|
||||
|
||||
payload = {
|
||||
"sub": f"{tenant_id}:{client_id}",
|
||||
"scopes": scopes,
|
||||
"iat": datetime.now(timezone.utc),
|
||||
"exp": datetime.now(timezone.utc) + timedelta(seconds=self.token_expiry),
|
||||
"type": "access",
|
||||
}
|
||||
|
||||
return jwt.encode(payload, self.jwt_secret, algorithm=self.jwt_algorithm)
|
||||
|
||||
def _generate_refresh_token(self, tenant_id: str, client_id: str) -> str:
|
||||
"""Generate refresh token"""
|
||||
|
||||
payload = {
|
||||
"sub": f"{tenant_id}:{client_id}",
|
||||
"iat": datetime.now(timezone.utc),
|
||||
"exp": datetime.now(timezone.utc) + timedelta(days=30), # 30 days
|
||||
"type": "refresh",
|
||||
}
|
||||
|
||||
return jwt.encode(payload, self.jwt_secret, algorithm=self.jwt_algorithm)
|
||||
|
||||
async def _validate_tenant_credentials(self, tenant_id: str, client_id: str, client_secret: str, db_session) -> Tenant:
|
||||
"""Validate tenant credentials"""
|
||||
|
||||
# Find tenant
|
||||
tenant = db_session.query(Tenant).filter(Tenant.tenant_id == tenant_id).first()
|
||||
if not tenant:
|
||||
raise TenantError(f"Tenant {tenant_id} not found")
|
||||
|
||||
# Find API key
|
||||
api_key = (
|
||||
db_session.query(TenantApiKey)
|
||||
.filter(TenantApiKey.tenant_id == tenant_id, TenantApiKey.client_id == client_id, TenantApiKey.is_active)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not api_key or not secrets.compare_digest(api_key.client_secret, client_secret):
|
||||
raise TenantError("Invalid client credentials")
|
||||
|
||||
# Check tenant status
|
||||
if tenant.status.value != "active":
|
||||
raise TenantError(f"Tenant {tenant_id} is not active")
|
||||
|
||||
return tenant
|
||||
|
||||
async def check_api_quota(self, tenant_id: str, endpoint: str, method: str, db_session) -> APIQuotaResponse:
|
||||
"""Check and enforce API quotas"""
|
||||
|
||||
try:
|
||||
# Get tenant quota
|
||||
quota = await self._get_tenant_quota(tenant_id, db_session)
|
||||
|
||||
# Check rate limiting
|
||||
current_usage = await self._get_current_usage(tenant_id, "rate_limit")
|
||||
|
||||
if current_usage >= quota["rate_limit"]:
|
||||
raise QuotaExceededError("Rate limit exceeded")
|
||||
|
||||
# Update usage
|
||||
await self._update_usage(tenant_id, "rate_limit", current_usage + 1)
|
||||
|
||||
return APIQuotaResponse(
|
||||
quota_limit=quota["rate_limit"],
|
||||
quota_remaining=quota["rate_limit"] - current_usage - 1,
|
||||
quota_reset=datetime.now(timezone.utc) + timedelta(minutes=1),
|
||||
quota_type="rate_limit",
|
||||
)
|
||||
|
||||
except QuotaExceededError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Quota check failed: {e}")
|
||||
raise HTTPException(status_code=500, detail="Quota check failed")
|
||||
|
||||
async def _get_tenant_quota(self, tenant_id: str, db_session) -> dict[str, int]:
|
||||
"""Get tenant quota configuration"""
|
||||
|
||||
# Get tenant-specific quota
|
||||
tenant_quota = db_session.query(TenantQuota).filter(TenantQuota.tenant_id == tenant_id).first()
|
||||
|
||||
if tenant_quota:
|
||||
return {
|
||||
"rate_limit": tenant_quota.rate_limit or self.default_quotas["rate_limit"],
|
||||
"daily_limit": tenant_quota.daily_limit or self.default_quotas["daily_limit"],
|
||||
"concurrent_limit": tenant_quota.concurrent_limit or self.default_quotas["concurrent_limit"],
|
||||
}
|
||||
|
||||
return self.default_quotas
|
||||
|
||||
async def _get_current_usage(self, tenant_id: str, quota_type: str) -> int:
|
||||
"""Get current quota usage"""
|
||||
|
||||
# In production, use Redis or database for persistent storage
|
||||
|
||||
if quota_type == "rate_limit":
|
||||
# Get usage in the last minute
|
||||
return len([t for t in self.rate_limiters.get(tenant_id, []) if datetime.now(timezone.utc) - t < timedelta(minutes=1)])
|
||||
|
||||
return 0
|
||||
|
||||
async def _update_usage(self, tenant_id: str, quota_type: str, usage: int):
|
||||
"""Update quota usage"""
|
||||
|
||||
if quota_type == "rate_limit":
|
||||
if tenant_id not in self.rate_limiters:
|
||||
self.rate_limiters[tenant_id] = []
|
||||
|
||||
# Add current timestamp
|
||||
self.rate_limiters[tenant_id].append(datetime.now(timezone.utc))
|
||||
|
||||
# Clean old entries (older than 1 minute)
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(minutes=1)
|
||||
self.rate_limiters[tenant_id] = [t for t in self.rate_limiters[tenant_id] if t > cutoff]
|
||||
|
||||
async def create_enterprise_integration(
|
||||
self, tenant_id: str, request: EnterpriseIntegrationRequest, db_session
|
||||
) -> dict[str, Any]:
|
||||
"""Create new enterprise integration"""
|
||||
|
||||
try:
|
||||
# Validate tenant
|
||||
tenant = db_session.query(Tenant).filter(Tenant.tenant_id == tenant_id).first()
|
||||
if not tenant:
|
||||
raise TenantError(f"Tenant {tenant_id} not found")
|
||||
|
||||
# Create integration
|
||||
integration_id = str(uuid4())
|
||||
integration = EnterpriseIntegration(
|
||||
integration_id=integration_id,
|
||||
tenant_id=tenant_id,
|
||||
integration_type=request.integration_type,
|
||||
provider=request.provider,
|
||||
configuration=request.configuration,
|
||||
)
|
||||
|
||||
# Store webhook configuration
|
||||
if request.webhook_config:
|
||||
integration.webhook_config = request.webhook_config.dict()
|
||||
self.webhooks[integration_id] = request.webhook_config.dict()
|
||||
|
||||
# Store integration
|
||||
self.integrations[integration_id] = integration
|
||||
|
||||
# Initialize integration
|
||||
await self._initialize_integration(integration)
|
||||
|
||||
return {
|
||||
"integration_id": integration_id,
|
||||
"status": integration.status.value,
|
||||
"created_at": integration.created_at.isoformat(),
|
||||
"configuration": integration.configuration,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create enterprise integration: {e}")
|
||||
raise HTTPException(status_code=500, detail="Integration creation failed")
|
||||
|
||||
async def _initialize_integration(self, integration: EnterpriseIntegration):
|
||||
"""Initialize enterprise integration"""
|
||||
|
||||
try:
|
||||
# Integration-specific initialization logic
|
||||
if integration.integration_type.lower() == "erp":
|
||||
await self._initialize_erp_integration(integration)
|
||||
elif integration.integration_type.lower() == "crm":
|
||||
await self._initialize_crm_integration(integration)
|
||||
elif integration.integration_type.lower() == "bi":
|
||||
await self._initialize_bi_integration(integration)
|
||||
|
||||
integration.status = IntegrationStatus.ACTIVE
|
||||
integration.last_updated = datetime.now(timezone.utc)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Integration initialization failed: {e}")
|
||||
integration.status = IntegrationStatus.ERROR
|
||||
raise
|
||||
|
||||
async def _initialize_erp_integration(self, integration: EnterpriseIntegration):
|
||||
"""Initialize ERP integration"""
|
||||
|
||||
# ERP-specific initialization
|
||||
provider = integration.provider.lower()
|
||||
|
||||
if provider == "sap":
|
||||
await self._initialize_sap_integration(integration)
|
||||
elif provider == "oracle":
|
||||
await self._initialize_oracle_integration(integration)
|
||||
elif provider == "microsoft":
|
||||
await self._initialize_microsoft_integration(integration)
|
||||
|
||||
logger.info(f"ERP integration initialized: {integration.provider}")
|
||||
|
||||
async def _initialize_sap_integration(self, integration: EnterpriseIntegration):
|
||||
"""Initialize SAP ERP integration"""
|
||||
|
||||
# SAP integration logic
|
||||
config = integration.configuration
|
||||
|
||||
# Validate SAP configuration
|
||||
required_fields = ["system_id", "client", "username", "password", "host"]
|
||||
for field in required_fields:
|
||||
if field not in config:
|
||||
raise ValueError(f"SAP integration requires {field}")
|
||||
|
||||
# Test SAP connection
|
||||
# In production, implement actual SAP connection testing
|
||||
logger.info(f"SAP connection test successful for {integration.integration_id}")
|
||||
|
||||
async def get_enterprise_metrics(self, tenant_id: str, db_session) -> EnterpriseMetrics:
|
||||
"""Get enterprise metrics and analytics"""
|
||||
|
||||
try:
|
||||
# Get API metrics
|
||||
api_metrics = self.api_metrics.get(
|
||||
tenant_id, {"total_calls": 0, "successful_calls": 0, "failed_calls": 0, "response_times": []}
|
||||
)
|
||||
|
||||
# Calculate metrics
|
||||
total_calls = api_metrics["total_calls"]
|
||||
successful_calls = api_metrics["successful_calls"]
|
||||
failed_calls = api_metrics["failed_calls"]
|
||||
|
||||
average_response_time = (
|
||||
sum(api_metrics["response_times"]) / len(api_metrics["response_times"])
|
||||
if api_metrics["response_times"]
|
||||
else 0.0
|
||||
)
|
||||
|
||||
error_rate = (failed_calls / total_calls * 100) if total_calls > 0 else 0.0
|
||||
|
||||
# Get quota utilization
|
||||
current_usage = await self._get_current_usage(tenant_id, "rate_limit")
|
||||
quota = await self._get_tenant_quota(tenant_id, db_session)
|
||||
quota_utilization = (current_usage / quota["rate_limit"] * 100) if quota["rate_limit"] > 0 else 0.0
|
||||
|
||||
# Count active integrations
|
||||
active_integrations = len(
|
||||
[i for i in self.integrations.values() if i.tenant_id == tenant_id and i.status == IntegrationStatus.ACTIVE]
|
||||
)
|
||||
|
||||
return EnterpriseMetrics(
|
||||
api_calls_total=total_calls,
|
||||
api_calls_successful=successful_calls,
|
||||
average_response_time_ms=average_response_time,
|
||||
error_rate_percent=error_rate,
|
||||
quota_utilization_percent=quota_utilization,
|
||||
active_integrations=active_integrations,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get enterprise metrics: {e}")
|
||||
raise HTTPException(status_code=500, detail="Metrics retrieval failed")
|
||||
|
||||
async def record_api_call(self, tenant_id: str, endpoint: str, response_time: float, success: bool):
|
||||
"""Record API call for metrics"""
|
||||
|
||||
if tenant_id not in self.api_metrics:
|
||||
self.api_metrics[tenant_id] = {"total_calls": 0, "successful_calls": 0, "failed_calls": 0, "response_times": []}
|
||||
|
||||
metrics = self.api_metrics[tenant_id]
|
||||
metrics["total_calls"] += 1
|
||||
|
||||
if success:
|
||||
metrics["successful_calls"] += 1
|
||||
else:
|
||||
metrics["failed_calls"] += 1
|
||||
|
||||
metrics["response_times"].append(response_time)
|
||||
|
||||
# Keep only last 1000 response times
|
||||
if len(metrics["response_times"]) > 1000:
|
||||
metrics["response_times"] = metrics["response_times"][-1000:]
|
||||
|
||||
|
||||
# FastAPI application
|
||||
app = FastAPI(
|
||||
title="Enterprise API Gateway",
|
||||
description="Multi-tenant API routing and management for enterprise clients",
|
||||
version="6.1.0",
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
)
|
||||
|
||||
# CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Security
|
||||
security = HTTPBearer()
|
||||
|
||||
# Global gateway instance
|
||||
gateway = EnterpriseAPIGateway()
|
||||
|
||||
|
||||
# Dependency for database session
|
||||
async def get_db_session():
|
||||
"""Get database session"""
|
||||
|
||||
async with get_db() as session:
|
||||
yield session
|
||||
|
||||
|
||||
# Middleware for API metrics
|
||||
@app.middleware("http")
|
||||
async def api_metrics_middleware(request: Request, call_next):
|
||||
"""Middleware to record API metrics"""
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Extract tenant from token if available
|
||||
tenant_id = None
|
||||
authorization = request.headers.get("authorization")
|
||||
if authorization and authorization.startswith("Bearer "):
|
||||
token = authorization[7:]
|
||||
token_data = gateway.active_tokens.get(token)
|
||||
if token_data:
|
||||
tenant_id = token_data["tenant_id"]
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Record metrics
|
||||
response_time = (time.time() - start_time) * 1000 # Convert to milliseconds
|
||||
success = response.status_code < 400
|
||||
|
||||
if tenant_id:
|
||||
await gateway.record_api_call(tenant_id, str(request.url.path), response_time, success)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@app.post("/enterprise/auth")
|
||||
async def enterprise_auth(request: EnterpriseAuthRequest, db_session=Depends(get_db_session)):
|
||||
"""Authenticate enterprise client"""
|
||||
|
||||
result = await gateway.authenticate_enterprise_client(request, db_session)
|
||||
return result
|
||||
|
||||
|
||||
@app.post("/enterprise/quota/check")
|
||||
async def check_quota(request: APIQuotaRequest, db_session=Depends(get_db_session)):
|
||||
"""Check API quota"""
|
||||
|
||||
result = await gateway.check_api_quota(request.tenant_id, request.endpoint, request.method, db_session)
|
||||
return result
|
||||
|
||||
|
||||
@app.post("/enterprise/integrations")
|
||||
async def create_integration(request: EnterpriseIntegrationRequest, db_session=Depends(get_db_session)):
|
||||
"""Create enterprise integration"""
|
||||
|
||||
# Extract tenant from token (in production, proper authentication)
|
||||
tenant_id = "demo_tenant" # Placeholder
|
||||
|
||||
result = await gateway.create_enterprise_integration(tenant_id, request, db_session)
|
||||
return result
|
||||
|
||||
|
||||
@app.get("/enterprise/analytics")
|
||||
async def get_analytics(db_session=Depends(get_db_session)):
|
||||
"""Get enterprise analytics dashboard"""
|
||||
|
||||
# Extract tenant from token (in production, proper authentication)
|
||||
tenant_id = "demo_tenant" # Placeholder
|
||||
|
||||
result = await gateway.get_enterprise_metrics(tenant_id, db_session)
|
||||
return result
|
||||
|
||||
|
||||
@app.get("/enterprise/status")
|
||||
async def get_status():
|
||||
"""Get enterprise gateway status"""
|
||||
|
||||
return {
|
||||
"service": "Enterprise API Gateway",
|
||||
"version": "6.1.0",
|
||||
"port": 8010,
|
||||
"status": "operational",
|
||||
"active_tenants": len({token["tenant_id"] for token in gateway.active_tokens.values()}),
|
||||
"active_integrations": len(gateway.integrations),
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint"""
|
||||
return {
|
||||
"service": "Enterprise API Gateway",
|
||||
"version": "6.1.0",
|
||||
"port": 8010,
|
||||
"capabilities": [
|
||||
"Multi-tenant API Management",
|
||||
"Enterprise Authentication",
|
||||
"API Quota Management",
|
||||
"Enterprise Integration Framework",
|
||||
"Real-time Analytics",
|
||||
],
|
||||
"status": "operational",
|
||||
}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"services": {
|
||||
"api_gateway": "operational",
|
||||
"authentication": "operational",
|
||||
"quota_management": "operational",
|
||||
"integration_framework": "operational",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8010)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,16 +1,14 @@
|
||||
"""
|
||||
Enterprise Integration Bounded Context
|
||||
Provides enterprise API gateway, security, load balancing, and integration services.
|
||||
Provides enterprise integration, security, and load balancing services.
|
||||
"""
|
||||
|
||||
from .api_gateway import EnterpriseAPIGateway
|
||||
from .integration import EnterpriseIntegrationService
|
||||
from .integration import EnterpriseIntegrationFramework
|
||||
from .load_balancer import AdvancedLoadBalancer
|
||||
from .security import EnterpriseEncryption, HSMManager, ThreatDetectionSystem, ZeroTrustArchitecture
|
||||
|
||||
__all__ = [
|
||||
"EnterpriseAPIGateway",
|
||||
"EnterpriseIntegrationService",
|
||||
"EnterpriseIntegrationFramework",
|
||||
"AdvancedLoadBalancer",
|
||||
"EnterpriseEncryption",
|
||||
"HSMManager",
|
||||
|
||||
@@ -1,770 +0,0 @@
|
||||
"""
|
||||
Advanced Load Balancing - Phase 6.4 Implementation
|
||||
Intelligent traffic distribution with AI-powered auto-scaling and performance optimization
|
||||
"""
|
||||
|
||||
import statistics
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from aitbc import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class LoadBalancingAlgorithm(StrEnum):
|
||||
"""Load balancing algorithms"""
|
||||
|
||||
ROUND_ROBIN = "round_robin"
|
||||
WEIGHTED_ROUND_ROBIN = "weighted_round_robin"
|
||||
LEAST_CONNECTIONS = "least_connections"
|
||||
LEAST_RESPONSE_TIME = "least_response_time"
|
||||
RESOURCE_BASED = "resource_based"
|
||||
PREDICTIVE_AI = "predictive_ai"
|
||||
ADAPTIVE = "adaptive"
|
||||
|
||||
|
||||
class ScalingPolicy(StrEnum):
|
||||
"""Auto-scaling policies"""
|
||||
|
||||
MANUAL = "manual"
|
||||
THRESHOLD_BASED = "threshold_based"
|
||||
PREDICTIVE = "predictive"
|
||||
HYBRID = "hybrid"
|
||||
|
||||
|
||||
class HealthStatus(StrEnum):
|
||||
"""Health status"""
|
||||
|
||||
HEALTHY = "healthy"
|
||||
UNHEALTHY = "unhealthy"
|
||||
DRAINING = "draining"
|
||||
MAINTENANCE = "maintenance"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BackendServer:
|
||||
"""Backend server configuration"""
|
||||
|
||||
server_id: str
|
||||
host: str
|
||||
port: int
|
||||
weight: float = 1.0
|
||||
max_connections: int = 1000
|
||||
current_connections: int = 0
|
||||
cpu_usage: float = 0.0
|
||||
memory_usage: float = 0.0
|
||||
response_time_ms: float = 0.0
|
||||
request_count: int = 0
|
||||
error_count: int = 0
|
||||
health_status: HealthStatus = HealthStatus.HEALTHY
|
||||
last_health_check: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
capabilities: dict[str, Any] = field(default_factory=dict)
|
||||
region: str = "default"
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScalingMetric:
|
||||
"""Scaling metric configuration"""
|
||||
|
||||
metric_name: str
|
||||
threshold_min: float
|
||||
threshold_max: float
|
||||
scaling_factor: float
|
||||
cooldown_period: timedelta
|
||||
measurement_window: timedelta
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrafficPattern:
|
||||
"""Traffic pattern for predictive scaling"""
|
||||
|
||||
pattern_id: str
|
||||
name: str
|
||||
time_windows: list[dict[str, Any]] # List of time windows with expected load
|
||||
day_of_week: int # 0-6 (Monday-Sunday)
|
||||
seasonal_factor: float = 1.0
|
||||
confidence_score: float = 0.0
|
||||
|
||||
|
||||
class PredictiveScaler:
|
||||
"""AI-powered predictive auto-scaling"""
|
||||
|
||||
def __init__(self):
|
||||
self.traffic_history = []
|
||||
self.scaling_predictions = {}
|
||||
self.traffic_patterns = {}
|
||||
self.model_weights = {}
|
||||
self.logger = get_logger("predictive_scaler")
|
||||
|
||||
async def record_traffic(self, timestamp: datetime, request_count: int, response_time_ms: float, error_rate: float):
|
||||
"""Record traffic metrics"""
|
||||
|
||||
traffic_record = {
|
||||
"timestamp": timestamp,
|
||||
"request_count": request_count,
|
||||
"response_time_ms": response_time_ms,
|
||||
"error_rate": error_rate,
|
||||
"hour": timestamp.hour,
|
||||
"day_of_week": timestamp.weekday(),
|
||||
"day_of_month": timestamp.day,
|
||||
"month": timestamp.month,
|
||||
}
|
||||
|
||||
self.traffic_history.append(traffic_record)
|
||||
|
||||
# Keep only last 30 days of history
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(days=30)
|
||||
self.traffic_history = [record for record in self.traffic_history if record["timestamp"] > cutoff]
|
||||
|
||||
# Update traffic patterns
|
||||
await self._update_traffic_patterns()
|
||||
|
||||
async def _update_traffic_patterns(self):
|
||||
"""Update traffic patterns based on historical data"""
|
||||
|
||||
if len(self.traffic_history) < 168: # Need at least 1 week of data
|
||||
return
|
||||
|
||||
# Group by hour and day of week
|
||||
patterns = {}
|
||||
|
||||
for record in self.traffic_history:
|
||||
key = f"{record['day_of_week']}_{record['hour']}"
|
||||
|
||||
if key not in patterns:
|
||||
patterns[key] = {"request_counts": [], "response_times": [], "error_rates": []}
|
||||
|
||||
patterns[key]["request_counts"].append(record["request_count"])
|
||||
patterns[key]["response_times"].append(record["response_time_ms"])
|
||||
patterns[key]["error_rates"].append(record["error_rate"])
|
||||
|
||||
# Calculate pattern statistics
|
||||
for key, data in patterns.items():
|
||||
day_of_week, hour = key.split("_")
|
||||
|
||||
pattern = TrafficPattern(
|
||||
pattern_id=key,
|
||||
name=f"Pattern Day {day_of_week} Hour {hour}",
|
||||
time_windows=[
|
||||
{
|
||||
"hour": int(hour),
|
||||
"avg_requests": statistics.mean(data["request_counts"]),
|
||||
"max_requests": max(data["request_counts"]),
|
||||
"min_requests": min(data["request_counts"]),
|
||||
"std_requests": statistics.stdev(data["request_counts"]) if len(data["request_counts"]) > 1 else 0,
|
||||
"avg_response_time": statistics.mean(data["response_times"]),
|
||||
"avg_error_rate": statistics.mean(data["error_rates"]),
|
||||
}
|
||||
],
|
||||
day_of_week=int(day_of_week),
|
||||
confidence_score=min(len(data["request_counts"]) / 100, 1.0), # Confidence based on data points
|
||||
)
|
||||
|
||||
self.traffic_patterns[key] = pattern
|
||||
|
||||
async def predict_traffic(self, prediction_window: timedelta = timedelta(hours=1)) -> dict[str, Any]:
|
||||
"""Predict traffic for the next time window"""
|
||||
|
||||
try:
|
||||
current_time = datetime.now(timezone.utc)
|
||||
current_time + prediction_window
|
||||
|
||||
# Get current pattern
|
||||
current_pattern_key = f"{current_time.weekday()}_{current_time.hour}"
|
||||
current_pattern = self.traffic_patterns.get(current_pattern_key)
|
||||
|
||||
if not current_pattern:
|
||||
# Fallback to simple prediction
|
||||
return await self._simple_prediction(prediction_window)
|
||||
|
||||
# Get historical data for similar time periods
|
||||
similar_patterns = [
|
||||
pattern
|
||||
for pattern in self.traffic_patterns.values()
|
||||
if pattern.day_of_week == current_time.weekday()
|
||||
and abs(pattern.time_windows[0]["hour"] - current_time.hour) <= 2
|
||||
]
|
||||
|
||||
if not similar_patterns:
|
||||
return await self._simple_prediction(prediction_window)
|
||||
|
||||
# Calculate weighted prediction
|
||||
total_weight = 0
|
||||
weighted_requests = 0
|
||||
weighted_response_time = 0
|
||||
weighted_error_rate = 0
|
||||
|
||||
for pattern in similar_patterns:
|
||||
weight = pattern.confidence_score
|
||||
window_data = pattern.time_windows[0]
|
||||
|
||||
weighted_requests += window_data["avg_requests"] * weight
|
||||
weighted_response_time += window_data["avg_response_time"] * weight
|
||||
weighted_error_rate += window_data["avg_error_rate"] * weight
|
||||
total_weight += weight
|
||||
|
||||
if total_weight > 0:
|
||||
predicted_requests = weighted_requests / total_weight
|
||||
predicted_response_time = weighted_response_time / total_weight
|
||||
predicted_error_rate = weighted_error_rate / total_weight
|
||||
else:
|
||||
return await self._simple_prediction(prediction_window)
|
||||
|
||||
# Apply seasonal factors
|
||||
seasonal_factor = self._get_seasonal_factor(current_time)
|
||||
predicted_requests *= seasonal_factor
|
||||
|
||||
return {
|
||||
"prediction_window_hours": prediction_window.total_seconds() / 3600,
|
||||
"predicted_requests_per_hour": int(predicted_requests),
|
||||
"predicted_response_time_ms": predicted_response_time,
|
||||
"predicted_error_rate": predicted_error_rate,
|
||||
"confidence_score": min(total_weight / len(similar_patterns), 1.0),
|
||||
"seasonal_factor": seasonal_factor,
|
||||
"pattern_based": True,
|
||||
"prediction_timestamp": current_time.isoformat(),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Traffic prediction failed: {e}")
|
||||
return await self._simple_prediction(prediction_window)
|
||||
|
||||
async def _simple_prediction(self, prediction_window: timedelta) -> dict[str, Any]:
|
||||
"""Simple prediction based on recent averages"""
|
||||
|
||||
if not self.traffic_history:
|
||||
return {
|
||||
"prediction_window_hours": prediction_window.total_seconds() / 3600,
|
||||
"predicted_requests_per_hour": 1000, # Default
|
||||
"predicted_response_time_ms": 100.0,
|
||||
"predicted_error_rate": 0.01,
|
||||
"confidence_score": 0.1,
|
||||
"pattern_based": False,
|
||||
"prediction_timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
# Calculate recent averages
|
||||
recent_records = self.traffic_history[-24:] # Last 24 records
|
||||
|
||||
avg_requests = statistics.mean([r["request_count"] for r in recent_records])
|
||||
avg_response_time = statistics.mean([r["response_time_ms"] for r in recent_records])
|
||||
avg_error_rate = statistics.mean([r["error_rate"] for r in recent_records])
|
||||
|
||||
return {
|
||||
"prediction_window_hours": prediction_window.total_seconds() / 3600,
|
||||
"predicted_requests_per_hour": int(avg_requests),
|
||||
"predicted_response_time_ms": avg_response_time,
|
||||
"predicted_error_rate": avg_error_rate,
|
||||
"confidence_score": 0.3,
|
||||
"pattern_based": False,
|
||||
"prediction_timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
def _get_seasonal_factor(self, timestamp: datetime) -> float:
|
||||
"""Get seasonal adjustment factor"""
|
||||
|
||||
# Simple seasonal factors (can be enhanced with more sophisticated models)
|
||||
month = timestamp.month
|
||||
|
||||
seasonal_factors = {
|
||||
1: 0.8, # January - post-holiday dip
|
||||
2: 0.9, # February
|
||||
3: 1.0, # March
|
||||
4: 1.1, # April - spring increase
|
||||
5: 1.2, # May
|
||||
6: 1.1, # June
|
||||
7: 1.0, # July - summer
|
||||
8: 0.9, # August
|
||||
9: 1.1, # September - back to business
|
||||
10: 1.2, # October
|
||||
11: 1.3, # November - holiday season start
|
||||
12: 1.4, # December - peak holiday season
|
||||
}
|
||||
|
||||
return seasonal_factors.get(month, 1.0)
|
||||
|
||||
async def get_scaling_recommendation(self, current_servers: int, current_capacity: int) -> dict[str, Any]:
|
||||
"""Get scaling recommendation based on predictions"""
|
||||
|
||||
try:
|
||||
# Get traffic prediction
|
||||
prediction = await self.predict_traffic(timedelta(hours=1))
|
||||
|
||||
predicted_requests = prediction["predicted_requests_per_hour"]
|
||||
current_capacity_per_server = current_capacity // max(current_servers, 1)
|
||||
|
||||
# Calculate required servers
|
||||
required_servers = max(1, int(predicted_requests / current_capacity_per_server))
|
||||
|
||||
# Apply buffer (20% extra capacity)
|
||||
required_servers = int(required_servers * 1.2)
|
||||
|
||||
scaling_action = "none"
|
||||
if required_servers > current_servers:
|
||||
scaling_action = "scale_up"
|
||||
scale_to = required_servers
|
||||
elif required_servers < current_servers * 0.7: # Scale down if underutilized
|
||||
scaling_action = "scale_down"
|
||||
scale_to = max(1, required_servers)
|
||||
else:
|
||||
scale_to = current_servers
|
||||
|
||||
return {
|
||||
"current_servers": current_servers,
|
||||
"recommended_servers": scale_to,
|
||||
"scaling_action": scaling_action,
|
||||
"predicted_load": predicted_requests,
|
||||
"current_capacity_per_server": current_capacity_per_server,
|
||||
"confidence_score": prediction["confidence_score"],
|
||||
"reason": f"Predicted {predicted_requests} requests/hour vs current capacity {current_servers * current_capacity_per_server}",
|
||||
"recommendation_timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Scaling recommendation failed: {e}")
|
||||
return {
|
||||
"scaling_action": "none",
|
||||
"reason": f"Prediction failed: {str(e)}",
|
||||
"recommendation_timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
|
||||
class AdvancedLoadBalancer:
|
||||
"""Advanced load balancer with multiple algorithms and AI optimization"""
|
||||
|
||||
def __init__(self):
|
||||
self.backends = {}
|
||||
self.algorithm = LoadBalancingAlgorithm.ADAPTIVE
|
||||
self.current_index = 0
|
||||
self.request_history = []
|
||||
self.performance_metrics = {}
|
||||
self.predictive_scaler = PredictiveScaler()
|
||||
self.scaling_metrics = {}
|
||||
self.logger = get_logger("advanced_load_balancer")
|
||||
|
||||
async def add_backend(self, server: BackendServer) -> bool:
|
||||
"""Add backend server"""
|
||||
|
||||
try:
|
||||
self.backends[server.server_id] = server
|
||||
|
||||
# Initialize performance metrics
|
||||
self.performance_metrics[server.server_id] = {
|
||||
"avg_response_time": 0.0,
|
||||
"error_rate": 0.0,
|
||||
"throughput": 0.0,
|
||||
"uptime": 1.0,
|
||||
"last_updated": datetime.now(timezone.utc),
|
||||
}
|
||||
|
||||
self.logger.info(f"Backend server added: {server.server_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to add backend server: {e}")
|
||||
return False
|
||||
|
||||
async def remove_backend(self, server_id: str) -> bool:
|
||||
"""Remove backend server"""
|
||||
|
||||
if server_id in self.backends:
|
||||
del self.backends[server_id]
|
||||
del self.performance_metrics[server_id]
|
||||
|
||||
self.logger.info(f"Backend server removed: {server_id}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def select_backend(self, request_context: dict[str, Any] | None = None) -> str | None:
|
||||
"""Select backend server based on algorithm"""
|
||||
|
||||
try:
|
||||
# Filter healthy backends
|
||||
healthy_backends = {
|
||||
sid: server for sid, server in self.backends.items() if server.health_status == HealthStatus.HEALTHY
|
||||
}
|
||||
|
||||
if not healthy_backends:
|
||||
return None
|
||||
|
||||
# Select backend based on algorithm
|
||||
if self.algorithm == LoadBalancingAlgorithm.ROUND_ROBIN:
|
||||
return await self._select_round_robin(healthy_backends)
|
||||
elif self.algorithm == LoadBalancingAlgorithm.WEIGHTED_ROUND_ROBIN:
|
||||
return await self._select_weighted_round_robin(healthy_backends)
|
||||
elif self.algorithm == LoadBalancingAlgorithm.LEAST_CONNECTIONS:
|
||||
return await self._select_least_connections(healthy_backends)
|
||||
elif self.algorithm == LoadBalancingAlgorithm.LEAST_RESPONSE_TIME:
|
||||
return await self._select_least_response_time(healthy_backends)
|
||||
elif self.algorithm == LoadBalancingAlgorithm.RESOURCE_BASED:
|
||||
return await self._select_resource_based(healthy_backends)
|
||||
elif self.algorithm == LoadBalancingAlgorithm.PREDICTIVE_AI:
|
||||
return await self._select_predictive_ai(healthy_backends, request_context)
|
||||
elif self.algorithm == LoadBalancingAlgorithm.ADAPTIVE:
|
||||
return await self._select_adaptive(healthy_backends, request_context)
|
||||
else:
|
||||
return await self._select_round_robin(healthy_backends)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Backend selection failed: {e}")
|
||||
return None
|
||||
|
||||
async def _select_round_robin(self, backends: dict[str, BackendServer]) -> str:
|
||||
"""Round robin selection"""
|
||||
|
||||
backend_ids = list(backends.keys())
|
||||
|
||||
if not backend_ids:
|
||||
return None
|
||||
|
||||
selected = backend_ids[self.current_index % len(backend_ids)]
|
||||
self.current_index += 1
|
||||
|
||||
return selected
|
||||
|
||||
async def _select_weighted_round_robin(self, backends: dict[str, BackendServer]) -> str:
|
||||
"""Weighted round robin selection"""
|
||||
|
||||
# Calculate total weight
|
||||
total_weight = sum(server.weight for server in backends.values())
|
||||
|
||||
if total_weight <= 0:
|
||||
return await self._select_round_robin(backends)
|
||||
|
||||
# Select based on weights
|
||||
import random
|
||||
|
||||
rand_value = random.uniform(0, total_weight)
|
||||
|
||||
current_weight = 0
|
||||
for server_id, server in backends.items():
|
||||
current_weight += server.weight
|
||||
if rand_value <= current_weight:
|
||||
return server_id
|
||||
|
||||
# Fallback
|
||||
return list(backends.keys())[0]
|
||||
|
||||
async def _select_least_connections(self, backends: dict[str, BackendServer]) -> str:
|
||||
"""Select backend with least connections"""
|
||||
|
||||
min_connections = float("inf")
|
||||
selected_backend = None
|
||||
|
||||
for server_id, server in backends.items():
|
||||
if server.current_connections < min_connections:
|
||||
min_connections = server.current_connections
|
||||
selected_backend = server_id
|
||||
|
||||
return selected_backend
|
||||
|
||||
async def _select_least_response_time(self, backends: dict[str, BackendServer]) -> str:
|
||||
"""Select backend with least response time"""
|
||||
|
||||
min_response_time = float("inf")
|
||||
selected_backend = None
|
||||
|
||||
for server_id, server in backends.items():
|
||||
if server.response_time_ms < min_response_time:
|
||||
min_response_time = server.response_time_ms
|
||||
selected_backend = server_id
|
||||
|
||||
return selected_backend
|
||||
|
||||
async def _select_resource_based(self, backends: dict[str, BackendServer]) -> str:
|
||||
"""Select backend based on resource utilization"""
|
||||
|
||||
best_score = -1
|
||||
selected_backend = None
|
||||
|
||||
for server_id, server in backends.items():
|
||||
# Calculate resource score (lower is better)
|
||||
cpu_score = 1.0 - (server.cpu_usage / 100.0)
|
||||
memory_score = 1.0 - (server.memory_usage / 100.0)
|
||||
connection_score = 1.0 - (server.current_connections / server.max_connections)
|
||||
|
||||
# Weighted score
|
||||
resource_score = cpu_score * 0.4 + memory_score * 0.3 + connection_score * 0.3
|
||||
|
||||
if resource_score > best_score:
|
||||
best_score = resource_score
|
||||
selected_backend = server_id
|
||||
|
||||
return selected_backend
|
||||
|
||||
async def _select_predictive_ai(
|
||||
self, backends: dict[str, BackendServer], request_context: dict[str, Any] | None
|
||||
) -> str:
|
||||
"""AI-powered predictive selection"""
|
||||
|
||||
# Get performance predictions for each backend
|
||||
backend_scores = {}
|
||||
|
||||
for server_id, server in backends.items():
|
||||
# Predict performance based on historical data
|
||||
self.performance_metrics.get(server_id, {})
|
||||
|
||||
# Calculate predicted response time
|
||||
predicted_response_time = (
|
||||
server.response_time_ms
|
||||
* (1 + server.cpu_usage / 100)
|
||||
* (1 + server.memory_usage / 100)
|
||||
* (1 + server.current_connections / server.max_connections)
|
||||
)
|
||||
|
||||
# Calculate score (lower response time is better)
|
||||
score = 1.0 / (1.0 + predicted_response_time / 100.0)
|
||||
|
||||
# Apply context-based adjustments
|
||||
if request_context:
|
||||
# Consider request type, user location, etc.
|
||||
context_multiplier = await self._calculate_context_multiplier(server, request_context)
|
||||
score *= context_multiplier
|
||||
|
||||
backend_scores[server_id] = score
|
||||
|
||||
# Select best scoring backend
|
||||
if backend_scores:
|
||||
return max(backend_scores, key=backend_scores.get)
|
||||
|
||||
return await self._select_least_connections(backends)
|
||||
|
||||
async def _select_adaptive(self, backends: dict[str, BackendServer], request_context: dict[str, Any] | None) -> str:
|
||||
"""Adaptive selection based on current conditions"""
|
||||
|
||||
# Analyze current system state
|
||||
total_connections = sum(server.current_connections for server in backends.values())
|
||||
avg_response_time = statistics.mean([server.response_time_ms for server in backends.values()])
|
||||
|
||||
# Choose algorithm based on conditions
|
||||
if total_connections > sum(server.max_connections for server in backends.values()) * 0.8:
|
||||
# High load - use resource-based
|
||||
return await self._select_resource_based(backends)
|
||||
elif avg_response_time > 200:
|
||||
# High latency - use least response time
|
||||
return await self._select_least_response_time(backends)
|
||||
else:
|
||||
# Normal conditions - use weighted round robin
|
||||
return await self._select_weighted_round_robin(backends)
|
||||
|
||||
async def _calculate_context_multiplier(self, server: BackendServer, request_context: dict[str, Any]) -> float:
|
||||
"""Calculate context-based multiplier for backend selection"""
|
||||
|
||||
multiplier = 1.0
|
||||
|
||||
# Consider geographic location
|
||||
if "user_location" in request_context and "region" in server.capabilities:
|
||||
user_region = request_context["user_location"].get("region")
|
||||
server_region = server.capabilities["region"]
|
||||
|
||||
if user_region == server_region:
|
||||
multiplier *= 1.2 # Prefer same region
|
||||
elif self._regions_in_same_continent(user_region, server_region):
|
||||
multiplier *= 1.1 # Slight preference for same continent
|
||||
|
||||
# Consider request type
|
||||
request_type = request_context.get("request_type", "general")
|
||||
server_specializations = server.capabilities.get("specializations", [])
|
||||
|
||||
if request_type in server_specializations:
|
||||
multiplier *= 1.3 # Strong preference for specialized backends
|
||||
|
||||
# Consider user tier
|
||||
user_tier = request_context.get("user_tier", "standard")
|
||||
if user_tier == "premium" and server.capabilities.get("premium_support", False):
|
||||
multiplier *= 1.15
|
||||
|
||||
return multiplier
|
||||
|
||||
def _regions_in_same_continent(self, region1: str, region2: str) -> bool:
|
||||
"""Check if two regions are in the same continent"""
|
||||
|
||||
continent_mapping = {
|
||||
"NA": ["US", "CA", "MX"],
|
||||
"EU": ["GB", "DE", "FR", "IT", "ES", "NL", "BE", "AT", "CH", "SE", "NO", "DK", "FI"],
|
||||
"APAC": ["JP", "KR", "SG", "AU", "IN", "TH", "MY", "ID", "PH", "VN"],
|
||||
"LATAM": ["BR", "MX", "AR", "CL", "CO", "PE", "VE"],
|
||||
}
|
||||
|
||||
for _continent, regions in continent_mapping.items():
|
||||
if region1 in regions and region2 in regions:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def record_request(
|
||||
self, server_id: str, response_time_ms: float, success: bool, timestamp: datetime | None = None
|
||||
):
|
||||
"""Record request metrics"""
|
||||
|
||||
if timestamp is None:
|
||||
timestamp = datetime.now(timezone.utc)
|
||||
|
||||
# Update backend server metrics
|
||||
if server_id in self.backends:
|
||||
server = self.backends[server_id]
|
||||
server.request_count += 1
|
||||
server.response_time_ms = server.response_time_ms * 0.9 + response_time_ms * 0.1 # EMA
|
||||
|
||||
if not success:
|
||||
server.error_count += 1
|
||||
|
||||
# Record in history
|
||||
request_record = {
|
||||
"timestamp": timestamp,
|
||||
"server_id": server_id,
|
||||
"response_time_ms": response_time_ms,
|
||||
"success": success,
|
||||
}
|
||||
|
||||
self.request_history.append(request_record)
|
||||
|
||||
# Keep only last 10000 records
|
||||
if len(self.request_history) > 10000:
|
||||
self.request_history = self.request_history[-10000:]
|
||||
|
||||
# Update predictive scaler
|
||||
await self.predictive_scaler.record_traffic(
|
||||
timestamp, 1, response_time_ms, 0.0 if success else 1.0 # One request # Error rate
|
||||
)
|
||||
|
||||
async def update_backend_health(
|
||||
self, server_id: str, health_status: HealthStatus, cpu_usage: float, memory_usage: float, current_connections: int
|
||||
):
|
||||
"""Update backend health metrics"""
|
||||
|
||||
if server_id in self.backends:
|
||||
server = self.backends[server_id]
|
||||
server.health_status = health_status
|
||||
server.cpu_usage = cpu_usage
|
||||
server.memory_usage = memory_usage
|
||||
server.current_connections = current_connections
|
||||
server.last_health_check = datetime.now(timezone.utc)
|
||||
|
||||
async def get_load_balancing_metrics(self) -> dict[str, Any]:
|
||||
"""Get comprehensive load balancing metrics"""
|
||||
|
||||
try:
|
||||
total_requests = sum(server.request_count for server in self.backends.values())
|
||||
total_errors = sum(server.error_count for server in self.backends.values())
|
||||
total_connections = sum(server.current_connections for server in self.backends.values())
|
||||
|
||||
error_rate = (total_errors / total_requests) if total_requests > 0 else 0.0
|
||||
|
||||
# Calculate average response time
|
||||
avg_response_time = 0.0
|
||||
if self.backends:
|
||||
avg_response_time = statistics.mean([server.response_time_ms for server in self.backends.values()])
|
||||
|
||||
# Backend distribution
|
||||
backend_distribution = {}
|
||||
for server_id, server in self.backends.items():
|
||||
backend_distribution[server_id] = {
|
||||
"requests": server.request_count,
|
||||
"errors": server.error_count,
|
||||
"connections": server.current_connections,
|
||||
"response_time_ms": server.response_time_ms,
|
||||
"cpu_usage": server.cpu_usage,
|
||||
"memory_usage": server.memory_usage,
|
||||
"health_status": server.health_status.value,
|
||||
"weight": server.weight,
|
||||
}
|
||||
|
||||
# Get scaling recommendation
|
||||
scaling_recommendation = await self.predictive_scaler.get_scaling_recommendation(
|
||||
len(self.backends), sum(server.max_connections for server in self.backends.values())
|
||||
)
|
||||
|
||||
return {
|
||||
"total_backends": len(self.backends),
|
||||
"healthy_backends": len([s for s in self.backends.values() if s.health_status == HealthStatus.HEALTHY]),
|
||||
"total_requests": total_requests,
|
||||
"total_errors": total_errors,
|
||||
"error_rate": error_rate,
|
||||
"average_response_time_ms": avg_response_time,
|
||||
"total_connections": total_connections,
|
||||
"algorithm": self.algorithm.value,
|
||||
"backend_distribution": backend_distribution,
|
||||
"scaling_recommendation": scaling_recommendation,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Metrics retrieval failed: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
async def set_algorithm(self, algorithm: LoadBalancingAlgorithm):
|
||||
"""Set load balancing algorithm"""
|
||||
|
||||
self.algorithm = algorithm
|
||||
self.logger.info(f"Load balancing algorithm changed to: {algorithm.value}")
|
||||
|
||||
async def auto_scale(self, min_servers: int = 1, max_servers: int = 10) -> dict[str, Any]:
|
||||
"""Perform auto-scaling based on predictions"""
|
||||
|
||||
try:
|
||||
# Get scaling recommendation
|
||||
recommendation = await self.predictive_scaler.get_scaling_recommendation(
|
||||
len(self.backends), sum(server.max_connections for server in self.backends.values())
|
||||
)
|
||||
|
||||
action = recommendation["scaling_action"]
|
||||
target_servers = recommendation["recommended_servers"]
|
||||
|
||||
# Apply scaling limits
|
||||
target_servers = max(min_servers, min(max_servers, target_servers))
|
||||
|
||||
scaling_result = {
|
||||
"action": action,
|
||||
"current_servers": len(self.backends),
|
||||
"target_servers": target_servers,
|
||||
"confidence": recommendation.get("confidence_score", 0.0),
|
||||
"reason": recommendation.get("reason", ""),
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
# In production, implement actual scaling logic here
|
||||
# For now, just return the recommendation
|
||||
|
||||
self.logger.info(f"Auto-scaling recommendation: {action} to {target_servers} servers")
|
||||
|
||||
return scaling_result
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Auto-scaling failed: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
# Global load balancer instance
|
||||
advanced_load_balancer = None
|
||||
|
||||
|
||||
async def get_advanced_load_balancer() -> AdvancedLoadBalancer:
|
||||
"""Get or create global advanced load balancer"""
|
||||
|
||||
global advanced_load_balancer
|
||||
if advanced_load_balancer is None:
|
||||
advanced_load_balancer = AdvancedLoadBalancer()
|
||||
|
||||
# Add default backends
|
||||
default_backends = [
|
||||
BackendServer(
|
||||
server_id="backend_1", host="10.0.1.10", port=8080, weight=1.0, max_connections=1000, region="us_east"
|
||||
),
|
||||
BackendServer(
|
||||
server_id="backend_2", host="10.0.1.11", port=8080, weight=1.0, max_connections=1000, region="us_east"
|
||||
),
|
||||
BackendServer(
|
||||
server_id="backend_3", host="10.0.1.12", port=8080, weight=0.8, max_connections=800, region="eu_west"
|
||||
),
|
||||
]
|
||||
|
||||
for backend in default_backends:
|
||||
await advanced_load_balancer.add_backend(backend)
|
||||
|
||||
return advanced_load_balancer
|
||||
@@ -1,773 +0,0 @@
|
||||
"""
|
||||
Enterprise Security Framework - Phase 6.2 Implementation
|
||||
Zero-trust architecture with HSM integration and advanced security controls
|
||||
"""
|
||||
|
||||
import secrets
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import cryptography
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
|
||||
from aitbc import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SecurityLevel(StrEnum):
|
||||
"""Security levels for enterprise data"""
|
||||
|
||||
PUBLIC = "public"
|
||||
INTERNAL = "internal"
|
||||
CONFIDENTIAL = "confidential"
|
||||
RESTRICTED = "restricted"
|
||||
TOP_SECRET = "top_secret"
|
||||
|
||||
|
||||
class EncryptionAlgorithm(StrEnum):
|
||||
"""Encryption algorithms"""
|
||||
|
||||
AES_256_GCM = "aes_256_gcm"
|
||||
CHACHA20_POLY1305 = "chacha20_polyy1305"
|
||||
AES_256_CBC = "aes_256_cbc"
|
||||
QUANTUM_RESISTANT = "quantum_resistant"
|
||||
|
||||
|
||||
class ThreatLevel(StrEnum):
|
||||
"""Threat levels for security monitoring"""
|
||||
|
||||
LOW = "low"
|
||||
MEDIUM = "medium"
|
||||
HIGH = "high"
|
||||
CRITICAL = "critical"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SecurityPolicy:
|
||||
"""Security policy configuration"""
|
||||
|
||||
policy_id: str
|
||||
name: str
|
||||
security_level: SecurityLevel
|
||||
encryption_algorithm: EncryptionAlgorithm
|
||||
key_rotation_interval: timedelta
|
||||
access_control_requirements: list[str]
|
||||
audit_requirements: list[str]
|
||||
retention_period: timedelta
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
@dataclass
|
||||
class SecurityEvent:
|
||||
"""Security event for monitoring"""
|
||||
|
||||
event_id: str
|
||||
event_type: str
|
||||
severity: ThreatLevel
|
||||
source: str
|
||||
timestamp: datetime
|
||||
user_id: str | None
|
||||
resource_id: str | None
|
||||
details: dict[str, Any]
|
||||
resolved: bool = False
|
||||
resolution_notes: str | None = None
|
||||
|
||||
|
||||
class HSMManager:
|
||||
"""Hardware Security Module manager for enterprise key management"""
|
||||
|
||||
def __init__(self, hsm_config: dict[str, Any]):
|
||||
self.hsm_config = hsm_config
|
||||
self.backend = default_backend()
|
||||
self.key_store = {} # In production, use actual HSM
|
||||
self.logger = get_logger("hsm_manager")
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
"""Initialize HSM connection"""
|
||||
try:
|
||||
# In production, initialize actual HSM connection
|
||||
# For now, simulate HSM initialization
|
||||
self.logger.info("HSM manager initialized")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.logger.error(f"HSM initialization failed: {e}")
|
||||
return False
|
||||
|
||||
async def generate_key(self, key_id: str, algorithm: EncryptionAlgorithm, key_size: int = 256) -> dict[str, Any]:
|
||||
"""Generate encryption key in HSM"""
|
||||
|
||||
try:
|
||||
if algorithm == EncryptionAlgorithm.AES_256_GCM:
|
||||
key = secrets.token_bytes(32) # 256 bits
|
||||
iv = secrets.token_bytes(12) # 96 bits for GCM
|
||||
elif algorithm == EncryptionAlgorithm.CHACHA20_POLY1305:
|
||||
key = secrets.token_bytes(32) # 256 bits
|
||||
nonce = secrets.token_bytes(12) # 96 bits
|
||||
elif algorithm == EncryptionAlgorithm.AES_256_CBC:
|
||||
key = secrets.token_bytes(32) # 256 bits
|
||||
iv = secrets.token_bytes(16) # 128 bits for CBC
|
||||
else:
|
||||
raise ValueError(f"Unsupported algorithm: {algorithm}")
|
||||
|
||||
# Store key in HSM (simulated)
|
||||
key_data = {
|
||||
"key_id": key_id,
|
||||
"algorithm": algorithm.value,
|
||||
"key": key,
|
||||
"iv": iv if algorithm in [EncryptionAlgorithm.AES_256_GCM, EncryptionAlgorithm.AES_256_CBC] else None,
|
||||
"nonce": nonce if algorithm == EncryptionAlgorithm.CHACHA20_POLY1305 else None,
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"key_size": key_size,
|
||||
}
|
||||
|
||||
self.key_store[key_id] = key_data
|
||||
|
||||
self.logger.info(f"Key generated in HSM: {key_id}")
|
||||
return key_data
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Key generation failed: {e}")
|
||||
raise
|
||||
|
||||
async def get_key(self, key_id: str) -> dict[str, Any] | None:
|
||||
"""Get key from HSM"""
|
||||
return self.key_store.get(key_id)
|
||||
|
||||
async def rotate_key(self, key_id: str) -> dict[str, Any]:
|
||||
"""Rotate encryption key"""
|
||||
|
||||
old_key = self.key_store.get(key_id)
|
||||
if not old_key:
|
||||
raise ValueError(f"Key not found: {key_id}")
|
||||
|
||||
# Generate new key
|
||||
new_key = await self.generate_key(f"{key_id}_new", EncryptionAlgorithm(old_key["algorithm"]), old_key["key_size"])
|
||||
|
||||
# Update key with rotation timestamp
|
||||
new_key["rotated_from"] = key_id
|
||||
new_key["rotation_timestamp"] = datetime.now(timezone.utc)
|
||||
|
||||
return new_key
|
||||
|
||||
async def delete_key(self, key_id: str) -> bool:
|
||||
"""Delete key from HSM"""
|
||||
if key_id in self.key_store:
|
||||
del self.key_store[key_id]
|
||||
self.logger.info(f"Key deleted from HSM: {key_id}")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class EnterpriseEncryption:
|
||||
"""Enterprise-grade encryption service"""
|
||||
|
||||
def __init__(self, hsm_manager: HSMManager):
|
||||
self.hsm_manager = hsm_manager
|
||||
self.backend = default_backend()
|
||||
self.logger = get_logger("enterprise_encryption")
|
||||
|
||||
async def encrypt_data(
|
||||
self, data: str | bytes, key_id: str, associated_data: bytes | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Encrypt data using enterprise-grade encryption"""
|
||||
|
||||
try:
|
||||
# Get key from HSM
|
||||
key_data = await self.hsm_manager.get_key(key_id)
|
||||
if not key_data:
|
||||
raise ValueError(f"Key not found: {key_id}")
|
||||
|
||||
# Convert data to bytes if needed
|
||||
if isinstance(data, str):
|
||||
data = data.encode("utf-8")
|
||||
|
||||
algorithm = EncryptionAlgorithm(key_data["algorithm"])
|
||||
|
||||
if algorithm == EncryptionAlgorithm.AES_256_GCM:
|
||||
return await self._encrypt_aes_gcm(data, key_data, associated_data)
|
||||
elif algorithm == EncryptionAlgorithm.CHACHA20_POLY1305:
|
||||
return await self._encrypt_chacha20(data, key_data, associated_data)
|
||||
elif algorithm == EncryptionAlgorithm.AES_256_CBC:
|
||||
return await self._encrypt_aes_cbc(data, key_data)
|
||||
else:
|
||||
raise ValueError(f"Unsupported encryption algorithm: {algorithm}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Encryption failed: {e}")
|
||||
raise
|
||||
|
||||
async def _encrypt_aes_gcm(
|
||||
self, data: bytes, key_data: dict[str, Any], associated_data: bytes | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Encrypt using AES-256-GCM"""
|
||||
|
||||
key = key_data["key"]
|
||||
iv = key_data["iv"]
|
||||
|
||||
# Create cipher
|
||||
cipher = Cipher(algorithms.AES(key), modes.GCM(iv), backend=self.backend)
|
||||
|
||||
encryptor = cipher.encryptor()
|
||||
|
||||
# Add associated data if provided
|
||||
if associated_data:
|
||||
encryptor.authenticate_additional_data(associated_data)
|
||||
|
||||
# Encrypt data
|
||||
ciphertext = encryptor.update(data) + encryptor.finalize()
|
||||
|
||||
return {
|
||||
"ciphertext": ciphertext.hex(),
|
||||
"iv": iv.hex(),
|
||||
"tag": encryptor.tag.hex(),
|
||||
"algorithm": "aes_256_gcm",
|
||||
"key_id": key_data["key_id"],
|
||||
}
|
||||
|
||||
async def _encrypt_chacha20(
|
||||
self, data: bytes, key_data: dict[str, Any], associated_data: bytes | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Encrypt using ChaCha20-Poly1305"""
|
||||
|
||||
key = key_data["key"]
|
||||
nonce = key_data["nonce"]
|
||||
|
||||
# Create cipher
|
||||
cipher = Cipher(algorithms.ChaCha20(key, nonce), modes.Poly1305(b""), backend=self.backend)
|
||||
|
||||
encryptor = cipher.encryptor()
|
||||
|
||||
# Add associated data if provided
|
||||
if associated_data:
|
||||
encryptor.authenticate_additional_data(associated_data)
|
||||
|
||||
# Encrypt data
|
||||
ciphertext = encryptor.update(data) + encryptor.finalize()
|
||||
|
||||
return {
|
||||
"ciphertext": ciphertext.hex(),
|
||||
"nonce": nonce.hex(),
|
||||
"tag": encryptor.tag.hex(),
|
||||
"algorithm": "chacha20_poly1305",
|
||||
"key_id": key_data["key_id"],
|
||||
}
|
||||
|
||||
async def _encrypt_aes_cbc(self, data: bytes, key_data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Encrypt using AES-256-CBC"""
|
||||
|
||||
key = key_data["key"]
|
||||
iv = key_data["iv"]
|
||||
|
||||
# Pad data to block size
|
||||
padder = cryptography.hazmat.primitives.padding.PKCS7(128).padder()
|
||||
padded_data = padder.update(data) + padder.finalize()
|
||||
|
||||
# Create cipher
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=self.backend)
|
||||
|
||||
encryptor = cipher.encryptor()
|
||||
ciphertext = encryptor.update(padded_data) + encryptor.finalize()
|
||||
|
||||
return {"ciphertext": ciphertext.hex(), "iv": iv.hex(), "algorithm": "aes_256_cbc", "key_id": key_data["key_id"]}
|
||||
|
||||
async def decrypt_data(self, encrypted_data: dict[str, Any], associated_data: bytes | None = None) -> bytes:
|
||||
"""Decrypt encrypted data"""
|
||||
|
||||
try:
|
||||
algorithm = encrypted_data["algorithm"]
|
||||
|
||||
if algorithm == "aes_256_gcm":
|
||||
return await self._decrypt_aes_gcm(encrypted_data, associated_data)
|
||||
elif algorithm == "chacha20_poly1305":
|
||||
return await self._decrypt_chacha20(encrypted_data, associated_data)
|
||||
elif algorithm == "aes_256_cbc":
|
||||
return await self._decrypt_aes_cbc(encrypted_data)
|
||||
else:
|
||||
raise ValueError(f"Unsupported encryption algorithm: {algorithm}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Decryption failed: {e}")
|
||||
raise
|
||||
|
||||
async def _decrypt_aes_gcm(self, encrypted_data: dict[str, Any], associated_data: bytes | None = None) -> bytes:
|
||||
"""Decrypt AES-256-GCM encrypted data"""
|
||||
|
||||
# Get key from HSM
|
||||
key_data = await self.hsm_manager.get_key(encrypted_data["key_id"])
|
||||
if not key_data:
|
||||
raise ValueError(f"Key not found: {encrypted_data['key_id']}")
|
||||
|
||||
key = key_data["key"]
|
||||
iv = bytes.fromhex(encrypted_data["iv"])
|
||||
ciphertext = bytes.fromhex(encrypted_data["ciphertext"])
|
||||
tag = bytes.fromhex(encrypted_data["tag"])
|
||||
|
||||
# Create cipher
|
||||
cipher = Cipher(algorithms.AES(key), modes.GCM(iv, tag), backend=self.backend)
|
||||
|
||||
decryptor = cipher.decryptor()
|
||||
|
||||
# Add associated data if provided
|
||||
if associated_data:
|
||||
decryptor.authenticate_additional_data(associated_data)
|
||||
|
||||
# Decrypt data
|
||||
plaintext = decryptor.update(ciphertext) + decryptor.finalize()
|
||||
|
||||
return plaintext
|
||||
|
||||
async def _decrypt_chacha20(self, encrypted_data: dict[str, Any], associated_data: bytes | None = None) -> bytes:
|
||||
"""Decrypt ChaCha20-Poly1305 encrypted data"""
|
||||
|
||||
# Get key from HSM
|
||||
key_data = await self.hsm_manager.get_key(encrypted_data["key_id"])
|
||||
if not key_data:
|
||||
raise ValueError(f"Key not found: {encrypted_data['key_id']}")
|
||||
|
||||
key = key_data["key"]
|
||||
nonce = bytes.fromhex(encrypted_data["nonce"])
|
||||
ciphertext = bytes.fromhex(encrypted_data["ciphertext"])
|
||||
tag = bytes.fromhex(encrypted_data["tag"])
|
||||
|
||||
# Create cipher
|
||||
cipher = Cipher(algorithms.ChaCha20(key, nonce), modes.Poly1305(tag), backend=self.backend)
|
||||
|
||||
decryptor = cipher.decryptor()
|
||||
|
||||
# Add associated data if provided
|
||||
if associated_data:
|
||||
decryptor.authenticate_additional_data(associated_data)
|
||||
|
||||
# Decrypt data
|
||||
plaintext = decryptor.update(ciphertext) + decryptor.finalize()
|
||||
|
||||
return plaintext
|
||||
|
||||
async def _decrypt_aes_cbc(self, encrypted_data: dict[str, Any]) -> bytes:
|
||||
"""Decrypt AES-256-CBC encrypted data"""
|
||||
|
||||
# Get key from HSM
|
||||
key_data = await self.hsm_manager.get_key(encrypted_data["key_id"])
|
||||
if not key_data:
|
||||
raise ValueError(f"Key not found: {encrypted_data['key_id']}")
|
||||
|
||||
key = key_data["key"]
|
||||
iv = bytes.fromhex(encrypted_data["iv"])
|
||||
ciphertext = bytes.fromhex(encrypted_data["ciphertext"])
|
||||
|
||||
# Create cipher
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=self.backend)
|
||||
|
||||
decryptor = cipher.decryptor()
|
||||
padded_plaintext = decryptor.update(ciphertext) + decryptor.finalize()
|
||||
|
||||
# Unpad data
|
||||
unpadder = cryptography.hazmat.primitives.padding.PKCS7(128).unpadder()
|
||||
plaintext = unpadder.update(padded_plaintext) + unpadder.finalize()
|
||||
|
||||
return plaintext
|
||||
|
||||
|
||||
class ZeroTrustArchitecture:
|
||||
"""Zero-trust security architecture implementation"""
|
||||
|
||||
def __init__(self, hsm_manager: HSMManager, encryption: EnterpriseEncryption):
|
||||
self.hsm_manager = hsm_manager
|
||||
self.encryption = encryption
|
||||
self.trust_policies = {}
|
||||
self.session_tokens = {}
|
||||
self.logger = get_logger("zero_trust")
|
||||
|
||||
async def create_trust_policy(self, policy_id: str, policy_config: dict[str, Any]) -> bool:
|
||||
"""Create zero-trust policy"""
|
||||
|
||||
try:
|
||||
policy = SecurityPolicy(
|
||||
policy_id=policy_id,
|
||||
name=policy_config["name"],
|
||||
security_level=SecurityLevel(policy_config["security_level"]),
|
||||
encryption_algorithm=EncryptionAlgorithm(policy_config["encryption_algorithm"]),
|
||||
key_rotation_interval=timedelta(days=policy_config.get("key_rotation_days", 90)),
|
||||
access_control_requirements=policy_config.get("access_control_requirements", []),
|
||||
audit_requirements=policy_config.get("audit_requirements", []),
|
||||
retention_period=timedelta(days=policy_config.get("retention_days", 2555)), # 7 years
|
||||
)
|
||||
|
||||
self.trust_policies[policy_id] = policy
|
||||
|
||||
# Generate encryption key for policy
|
||||
await self.hsm_manager.generate_key(f"policy_{policy_id}", policy.encryption_algorithm)
|
||||
|
||||
self.logger.info(f"Zero-trust policy created: {policy_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to create trust policy: {e}")
|
||||
return False
|
||||
|
||||
async def verify_trust(self, user_id: str, resource_id: str, action: str, context: dict[str, Any]) -> bool:
|
||||
"""Verify zero-trust access request"""
|
||||
|
||||
try:
|
||||
# Get applicable policy
|
||||
policy_id = context.get("policy_id", "default")
|
||||
policy = self.trust_policies.get(policy_id)
|
||||
|
||||
if not policy:
|
||||
self.logger.warning(f"No policy found for {policy_id}")
|
||||
return False
|
||||
|
||||
# Verify trust factors
|
||||
trust_score = await self._calculate_trust_score(user_id, resource_id, action, context)
|
||||
|
||||
# Check if trust score meets policy requirements
|
||||
min_trust_score = self._get_min_trust_score(policy.security_level)
|
||||
|
||||
is_trusted = trust_score >= min_trust_score
|
||||
|
||||
# Log trust decision
|
||||
await self._log_trust_decision(user_id, resource_id, action, trust_score, is_trusted)
|
||||
|
||||
return is_trusted
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Trust verification failed: {e}")
|
||||
return False
|
||||
|
||||
async def _calculate_trust_score(self, user_id: str, resource_id: str, action: str, context: dict[str, Any]) -> float:
|
||||
"""Calculate trust score for access request"""
|
||||
|
||||
score = 0.0
|
||||
|
||||
# User authentication factor (40%)
|
||||
auth_strength = context.get("auth_strength", "password")
|
||||
if auth_strength == "mfa":
|
||||
score += 0.4
|
||||
elif auth_strength == "password":
|
||||
score += 0.2
|
||||
|
||||
# Device trust factor (20%)
|
||||
device_trust = context.get("device_trust", 0.5)
|
||||
score += 0.2 * device_trust
|
||||
|
||||
# Location factor (15%)
|
||||
location_trust = context.get("location_trust", 0.5)
|
||||
score += 0.15 * location_trust
|
||||
|
||||
# Time factor (10%)
|
||||
time_trust = context.get("time_trust", 0.5)
|
||||
score += 0.1 * time_trust
|
||||
|
||||
# Behavioral factor (15%)
|
||||
behavior_trust = context.get("behavior_trust", 0.5)
|
||||
score += 0.15 * behavior_trust
|
||||
|
||||
return min(score, 1.0)
|
||||
|
||||
def _get_min_trust_score(self, security_level: SecurityLevel) -> float:
|
||||
"""Get minimum trust score for security level"""
|
||||
|
||||
thresholds = {
|
||||
SecurityLevel.PUBLIC: 0.0,
|
||||
SecurityLevel.INTERNAL: 0.3,
|
||||
SecurityLevel.CONFIDENTIAL: 0.6,
|
||||
SecurityLevel.RESTRICTED: 0.8,
|
||||
SecurityLevel.TOP_SECRET: 0.9,
|
||||
}
|
||||
|
||||
return thresholds.get(security_level, 0.5)
|
||||
|
||||
async def _log_trust_decision(self, user_id: str, resource_id: str, action: str, trust_score: float, decision: bool):
|
||||
"""Log trust decision for audit"""
|
||||
|
||||
SecurityEvent(
|
||||
event_id=str(uuid4()),
|
||||
event_type="trust_decision",
|
||||
severity=ThreatLevel.LOW if decision else ThreatLevel.MEDIUM,
|
||||
source="zero_trust",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
user_id=user_id,
|
||||
resource_id=resource_id,
|
||||
details={"action": action, "trust_score": trust_score, "decision": decision},
|
||||
)
|
||||
|
||||
# In production, send to security monitoring system
|
||||
self.logger.info(f"Trust decision: {user_id} -> {resource_id} = {decision} (score: {trust_score})")
|
||||
|
||||
|
||||
class ThreatDetectionSystem:
|
||||
"""Advanced threat detection and response system"""
|
||||
|
||||
def __init__(self):
|
||||
self.threat_patterns = {}
|
||||
self.active_threats = {}
|
||||
self.response_actions = {}
|
||||
self.logger = get_logger("threat_detection")
|
||||
|
||||
async def register_threat_pattern(self, pattern_id: str, pattern_config: dict[str, Any]):
|
||||
"""Register threat detection pattern"""
|
||||
|
||||
self.threat_patterns[pattern_id] = {
|
||||
"id": pattern_id,
|
||||
"name": pattern_config["name"],
|
||||
"description": pattern_config["description"],
|
||||
"indicators": pattern_config["indicators"],
|
||||
"severity": ThreatLevel(pattern_config["severity"]),
|
||||
"response_actions": pattern_config.get("response_actions", []),
|
||||
"threshold": pattern_config.get("threshold", 1.0),
|
||||
}
|
||||
|
||||
self.logger.info(f"Threat pattern registered: {pattern_id}")
|
||||
|
||||
async def analyze_threat(self, event_data: dict[str, Any]) -> list[SecurityEvent]:
|
||||
"""Analyze event for potential threats"""
|
||||
|
||||
detected_threats = []
|
||||
|
||||
for pattern_id, pattern in self.threat_patterns.items():
|
||||
threat_score = await self._calculate_threat_score(event_data, pattern)
|
||||
|
||||
if threat_score >= pattern["threshold"]:
|
||||
threat_event = SecurityEvent(
|
||||
event_id=str(uuid4()),
|
||||
event_type="threat_detected",
|
||||
severity=pattern["severity"],
|
||||
source="threat_detection",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
user_id=event_data.get("user_id"),
|
||||
resource_id=event_data.get("resource_id"),
|
||||
details={
|
||||
"pattern_id": pattern_id,
|
||||
"pattern_name": pattern["name"],
|
||||
"threat_score": threat_score,
|
||||
"indicators": event_data,
|
||||
},
|
||||
)
|
||||
|
||||
detected_threats.append(threat_event)
|
||||
|
||||
# Trigger response actions
|
||||
await self._trigger_response_actions(pattern_id, threat_event)
|
||||
|
||||
return detected_threats
|
||||
|
||||
async def _calculate_threat_score(self, event_data: dict[str, Any], pattern: dict[str, Any]) -> float:
|
||||
"""Calculate threat score for pattern"""
|
||||
|
||||
score = 0.0
|
||||
indicators = pattern["indicators"]
|
||||
|
||||
for indicator, weight in indicators.items():
|
||||
if indicator in event_data:
|
||||
# Simple scoring - in production, use more sophisticated algorithms
|
||||
indicator_score = 0.5 # Base score for presence
|
||||
score += indicator_score * weight
|
||||
|
||||
return min(score, 1.0)
|
||||
|
||||
async def _trigger_response_actions(self, pattern_id: str, threat_event: SecurityEvent):
|
||||
"""Trigger automated response actions"""
|
||||
|
||||
pattern = self.threat_patterns[pattern_id]
|
||||
actions = pattern.get("response_actions", [])
|
||||
|
||||
for action in actions:
|
||||
try:
|
||||
await self._execute_response_action(action, threat_event)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Response action failed: {action} - {e}")
|
||||
|
||||
async def _execute_response_action(self, action: str, threat_event: SecurityEvent):
|
||||
"""Execute specific response action"""
|
||||
|
||||
if action == "block_user":
|
||||
await self._block_user(threat_event.user_id)
|
||||
elif action == "isolate_resource":
|
||||
await self._isolate_resource(threat_event.resource_id)
|
||||
elif action == "escalate_to_admin":
|
||||
await self._escalate_to_admin(threat_event)
|
||||
elif action == "require_mfa":
|
||||
await self._require_mfa(threat_event.user_id)
|
||||
|
||||
self.logger.info(f"Response action executed: {action}")
|
||||
|
||||
async def _block_user(self, user_id: str):
|
||||
"""Block user account"""
|
||||
# In production, implement actual user blocking
|
||||
self.logger.warning(f"User blocked due to threat: {user_id}")
|
||||
|
||||
async def _isolate_resource(self, resource_id: str):
|
||||
"""Isolate compromised resource"""
|
||||
# In production, implement actual resource isolation
|
||||
self.logger.warning(f"Resource isolated due to threat: {resource_id}")
|
||||
|
||||
async def _escalate_to_admin(self, threat_event: SecurityEvent):
|
||||
"""Escalate threat to security administrators"""
|
||||
# In production, implement actual escalation
|
||||
self.logger.error(f"Threat escalated to admin: {threat_event.event_id}")
|
||||
|
||||
async def _require_mfa(self, user_id: str):
|
||||
"""Require multi-factor authentication"""
|
||||
# In production, implement MFA requirement
|
||||
self.logger.warning(f"MFA required for user: {user_id}")
|
||||
|
||||
|
||||
class EnterpriseSecurityFramework:
|
||||
"""Main enterprise security framework"""
|
||||
|
||||
def __init__(self, hsm_config: dict[str, Any]):
|
||||
self.hsm_manager = HSMManager(hsm_config)
|
||||
self.encryption = EnterpriseEncryption(self.hsm_manager)
|
||||
self.zero_trust = ZeroTrustArchitecture(self.hsm_manager, self.encryption)
|
||||
self.threat_detection = ThreatDetectionSystem()
|
||||
self.logger = get_logger("enterprise_security")
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
"""Initialize security framework"""
|
||||
|
||||
try:
|
||||
# Initialize HSM
|
||||
if not await self.hsm_manager.initialize():
|
||||
return False
|
||||
|
||||
# Register default threat patterns
|
||||
await self._register_default_threat_patterns()
|
||||
|
||||
# Create default trust policies
|
||||
await self._create_default_policies()
|
||||
|
||||
self.logger.info("Enterprise security framework initialized")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Security framework initialization failed: {e}")
|
||||
return False
|
||||
|
||||
async def _register_default_threat_patterns(self):
|
||||
"""Register default threat detection patterns"""
|
||||
|
||||
patterns = [
|
||||
{
|
||||
"name": "Brute Force Attack",
|
||||
"description": "Multiple failed login attempts",
|
||||
"indicators": {"failed_login_attempts": 0.8, "short_time_interval": 0.6},
|
||||
"severity": "high",
|
||||
"threshold": 0.7,
|
||||
"response_actions": ["block_user", "require_mfa"],
|
||||
},
|
||||
{
|
||||
"name": "Suspicious Access Pattern",
|
||||
"description": "Unusual access patterns",
|
||||
"indicators": {"unusual_location": 0.7, "unusual_time": 0.5, "high_frequency": 0.6},
|
||||
"severity": "medium",
|
||||
"threshold": 0.6,
|
||||
"response_actions": ["require_mfa", "escalate_to_admin"],
|
||||
},
|
||||
{
|
||||
"name": "Data Exfiltration",
|
||||
"description": "Large data transfer patterns",
|
||||
"indicators": {"large_data_transfer": 0.9, "unusual_destination": 0.7},
|
||||
"severity": "critical",
|
||||
"threshold": 0.8,
|
||||
"response_actions": ["block_user", "isolate_resource", "escalate_to_admin"],
|
||||
},
|
||||
]
|
||||
|
||||
for i, pattern in enumerate(patterns):
|
||||
await self.threat_detection.register_threat_pattern(f"default_{i}", pattern)
|
||||
|
||||
async def _create_default_policies(self):
|
||||
"""Create default trust policies"""
|
||||
|
||||
policies = [
|
||||
{
|
||||
"name": "Enterprise Data Policy",
|
||||
"security_level": "confidential",
|
||||
"encryption_algorithm": "aes_256_gcm",
|
||||
"key_rotation_days": 90,
|
||||
"access_control_requirements": ["mfa", "device_trust"],
|
||||
"audit_requirements": ["full_audit", "real_time_monitoring"],
|
||||
"retention_days": 2555,
|
||||
},
|
||||
{
|
||||
"name": "Public API Policy",
|
||||
"security_level": "public",
|
||||
"encryption_algorithm": "aes_256_gcm",
|
||||
"key_rotation_days": 180,
|
||||
"access_control_requirements": ["api_key"],
|
||||
"audit_requirements": ["api_access_log"],
|
||||
"retention_days": 365,
|
||||
},
|
||||
]
|
||||
|
||||
for i, policy in enumerate(policies):
|
||||
await self.zero_trust.create_trust_policy(f"default_{i}", policy)
|
||||
|
||||
async def encrypt_sensitive_data(self, data: str | bytes, security_level: SecurityLevel) -> dict[str, Any]:
|
||||
"""Encrypt sensitive data with appropriate security level"""
|
||||
|
||||
# Get policy for security level
|
||||
policy_id = f"default_{0 if security_level == SecurityLevel.PUBLIC else 1}"
|
||||
policy = self.zero_trust.trust_policies.get(policy_id)
|
||||
|
||||
if not policy:
|
||||
raise ValueError(f"No policy found for security level: {security_level}")
|
||||
|
||||
key_id = f"policy_{policy_id}"
|
||||
|
||||
return await self.encryption.encrypt_data(data, key_id)
|
||||
|
||||
async def verify_access(self, user_id: str, resource_id: str, action: str, context: dict[str, Any]) -> bool:
|
||||
"""Verify access using zero-trust architecture"""
|
||||
|
||||
return await self.zero_trust.verify_trust(user_id, resource_id, action, context)
|
||||
|
||||
async def analyze_security_event(self, event_data: dict[str, Any]) -> list[SecurityEvent]:
|
||||
"""Analyze security event for threats"""
|
||||
|
||||
return await self.threat_detection.analyze_threat(event_data)
|
||||
|
||||
async def rotate_encryption_keys(self, policy_id: str | None = None) -> dict[str, Any]:
|
||||
"""Rotate encryption keys"""
|
||||
|
||||
if policy_id:
|
||||
# Rotate specific policy key
|
||||
old_key_id = f"policy_{policy_id}"
|
||||
new_key = await self.hsm_manager.rotate_key(old_key_id)
|
||||
return {"rotated_key": new_key}
|
||||
else:
|
||||
# Rotate all keys
|
||||
rotated_keys = {}
|
||||
for policy_id in self.zero_trust.trust_policies.keys():
|
||||
old_key_id = f"policy_{policy_id}"
|
||||
new_key = await self.hsm_manager.rotate_key(old_key_id)
|
||||
rotated_keys[policy_id] = new_key
|
||||
|
||||
return {"rotated_keys": rotated_keys}
|
||||
|
||||
|
||||
# Global security framework instance
|
||||
security_framework = None
|
||||
|
||||
|
||||
async def get_security_framework() -> EnterpriseSecurityFramework:
|
||||
"""Get or create global security framework"""
|
||||
|
||||
global security_framework
|
||||
if security_framework is None:
|
||||
hsm_config = {"provider": "software", "endpoint": "localhost:8080"} # In production, use actual HSM
|
||||
|
||||
security_framework = EnterpriseSecurityFramework(hsm_config)
|
||||
await security_framework.initialize()
|
||||
|
||||
return security_framework
|
||||
|
||||
|
||||
# Alias for CLI compatibility
|
||||
EnterpriseSecurityManager = EnterpriseSecurityFramework
|
||||
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
Trading & Marketplace Bounded Context
|
||||
Provides trading, marketplace optimization, bid strategy, and dynamic pricing services.
|
||||
"""
|
||||
|
||||
from .bid_strategy import BidStrategyEngine
|
||||
from .dynamic_pricing import DynamicPricingEngine
|
||||
from .gpu_optimizer import MarketplaceGPUOptimizer
|
||||
from .trading import MatchingEngine, NegotiationSystem
|
||||
|
||||
__all__ = [
|
||||
"BidStrategyEngine",
|
||||
"DynamicPricingEngine",
|
||||
"MarketplaceGPUOptimizer",
|
||||
"MatchingEngine",
|
||||
"NegotiationSystem",
|
||||
]
|
||||
@@ -13,7 +13,7 @@ logger = get_logger(__name__)
|
||||
|
||||
from sqlmodel import Session, or_, select
|
||||
|
||||
from ..domain.trading import (
|
||||
from ...domain.trading import (
|
||||
NegotiationStatus,
|
||||
SettlementType,
|
||||
TradeAgreement,
|
||||
@@ -25,6 +25,26 @@
|
||||
- Maintained backward compatibility with lazy-loading pattern
|
||||
- Import tests verified successfully
|
||||
- Old monolithic files removed
|
||||
- ✅ Phase 2 Complete: Enterprise Integration bounded context decomposed
|
||||
- Created app/services/enterprise_integration/ package with 4 modules
|
||||
- Migrated enterprise_integration.py (1127 lines) and 3 other enterprise files
|
||||
- Updated imports within package (api_gateway.py excluded due to missing dependencies)
|
||||
- Import tests verified successfully
|
||||
- Old monolithic files removed
|
||||
- ✅ Phase 3 Complete: Trading & Marketplace bounded context decomposed
|
||||
- Created app/services/trading_marketplace/ package with 5 modules
|
||||
- Migrated trading_service.py (36K) and 4 other trading files
|
||||
- Updated imports across coordinator-api (routers/trading.py, routers/dynamic_pricing.py)
|
||||
- amm.py excluded from exports due to missing dependencies
|
||||
- Import tests verified successfully
|
||||
- Old monolithic files removed
|
||||
- ✅ Phase 4 Complete: AI & Analytics bounded context decomposed
|
||||
- Created app/services/ai_analytics/ package with 5 modules
|
||||
- Migrated analytics_service.py (41K) and 4 other AI files
|
||||
- Updated imports across coordinator-api (routers/analytics.py, routers/adaptive_learning_health.py)
|
||||
- adaptive_learning.py, surveillance.py, trading_engine.py excluded due to missing dependencies
|
||||
- Import tests verified successfully
|
||||
- Old monolithic files removed
|
||||
|
||||
2. **Production Code Using print()** (HIGH IMPACT)
|
||||
- 925 print() statements in production code
|
||||
@@ -128,9 +148,17 @@
|
||||
- test_validation_properties.py: 20/20 passing
|
||||
- test_staking_service.py: 22/22 passing
|
||||
- Coverage threshold set to 50% in pyproject.toml
|
||||
- Current coverage: 11% (4623 statements, 4122 missed) - BELOW 50% threshold
|
||||
- Well-covered modules: constants.py (100%), exceptions.py (100%), validation.py (85%), crypto/crypto.py (52%)
|
||||
- Current coverage: 19% (4623 statements, 3745 missed) - BELOW 50% threshold
|
||||
- Added 137 new tests across 6 modules:
|
||||
- test_middleware.py: 11 tests (middleware modules: 50-100% coverage)
|
||||
- test_utils.py: 47 tests (utils modules: 100% coverage when run standalone)
|
||||
- test_config.py: 14 tests (config.py: 100% coverage)
|
||||
- test_decorators.py: 21 tests (decorators.py: 99% coverage)
|
||||
- test_health_checks.py: 16 tests (health_checks.py: 80% coverage)
|
||||
- test_metrics.py: 28 tests (metrics.py: 100% coverage)
|
||||
- Well-covered modules: constants.py (100%), exceptions.py (100%), validation.py (85%), crypto/crypto.py (52%), config.py (100%), decorators.py (99%), health_checks.py (80%), metrics.py (100%)
|
||||
- Needs improvement: Most modules at 0-30% coverage
|
||||
- Note: Utils modules (paths, env, json_utils) achieve 100% when run standalone but not counted in overall coverage due to import patterns
|
||||
|
||||
#### MEDIUM (Long-term, 1-3 months)
|
||||
|
||||
|
||||
154
tests/test_config.py
Normal file
154
tests/test_config.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""
|
||||
Tests for AITBC configuration classes
|
||||
"""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, Mock
|
||||
|
||||
from aitbc.config import BaseAITBCConfig, AITBCConfig
|
||||
|
||||
|
||||
class TestBaseAITBCConfig:
|
||||
"""Tests for BaseAITBCConfig"""
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test BaseAITBCConfig with default values"""
|
||||
config = BaseAITBCConfig()
|
||||
assert config.app_name == "AITBC Application"
|
||||
assert config.app_version == "1.0.0"
|
||||
assert config.environment == "development"
|
||||
assert config.debug is False
|
||||
assert config.log_level == "INFO"
|
||||
|
||||
def test_custom_values(self):
|
||||
"""Test BaseAITBCConfig with custom values"""
|
||||
config = BaseAITBCConfig(
|
||||
app_name="Custom App",
|
||||
app_version="2.0.0",
|
||||
environment="production",
|
||||
debug=True,
|
||||
log_level="DEBUG"
|
||||
)
|
||||
assert config.app_name == "Custom App"
|
||||
assert config.app_version == "2.0.0"
|
||||
assert config.environment == "production"
|
||||
assert config.debug is True
|
||||
assert config.log_level == "DEBUG"
|
||||
|
||||
def test_data_dir_default(self):
|
||||
"""Test default data directory is a Path"""
|
||||
config = BaseAITBCConfig()
|
||||
assert isinstance(config.data_dir, Path)
|
||||
|
||||
def test_config_dir_default(self):
|
||||
"""Test default config directory is a Path"""
|
||||
config = BaseAITBCConfig()
|
||||
assert isinstance(config.config_dir, Path)
|
||||
|
||||
def test_log_dir_default(self):
|
||||
"""Test default log directory is a Path"""
|
||||
config = BaseAITBCConfig()
|
||||
assert isinstance(config.log_dir, Path)
|
||||
|
||||
def test_log_format_default(self):
|
||||
"""Test default log format"""
|
||||
config = BaseAITBCConfig()
|
||||
assert "%(asctime)s" in config.log_format
|
||||
assert "%(name)s" in config.log_format
|
||||
assert "%(levelname)s" in config.log_format
|
||||
|
||||
|
||||
class TestAITBCConfig:
|
||||
"""Tests for AITBCConfig"""
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test AITBCConfig with default values"""
|
||||
config = AITBCConfig()
|
||||
assert config.host == "0.0.0.0"
|
||||
assert config.port == 8000
|
||||
assert config.workers == 1
|
||||
assert config.database_url is None
|
||||
assert config.database_pool_size == 10
|
||||
assert config.redis_url is None
|
||||
assert config.redis_max_connections == 10
|
||||
assert config.redis_timeout == 5
|
||||
assert config.secret_key is None
|
||||
assert config.jwt_secret is None
|
||||
assert config.jwt_algorithm == "HS256"
|
||||
assert config.jwt_expiration_hours == 24
|
||||
assert config.request_timeout == 30
|
||||
assert config.max_request_size == 10 * 1024 * 1024
|
||||
|
||||
def test_custom_server_settings(self):
|
||||
"""Test AITBCConfig with custom server settings"""
|
||||
config = AITBCConfig(
|
||||
host="127.0.0.1",
|
||||
port=9000,
|
||||
workers=4
|
||||
)
|
||||
assert config.host == "127.0.0.1"
|
||||
assert config.port == 9000
|
||||
assert config.workers == 4
|
||||
|
||||
def test_custom_database_settings(self):
|
||||
"""Test AITBCConfig with custom database settings"""
|
||||
config = AITBCConfig(
|
||||
database_url="postgresql://localhost/test",
|
||||
database_pool_size=20
|
||||
)
|
||||
assert config.database_url == "postgresql://localhost/test"
|
||||
assert config.database_pool_size == 20
|
||||
|
||||
def test_custom_redis_settings(self):
|
||||
"""Test AITBCConfig with custom redis settings"""
|
||||
config = AITBCConfig(
|
||||
redis_url="redis://localhost:6379",
|
||||
redis_max_connections=50,
|
||||
redis_timeout=10
|
||||
)
|
||||
assert config.redis_url == "redis://localhost:6379"
|
||||
assert config.redis_max_connections == 50
|
||||
assert config.redis_timeout == 10
|
||||
|
||||
def test_custom_security_settings(self):
|
||||
"""Test AITBCConfig with custom security settings"""
|
||||
config = AITBCConfig(
|
||||
secret_key="test-secret-key",
|
||||
jwt_secret="test-jwt-secret",
|
||||
jwt_algorithm="RS256",
|
||||
jwt_expiration_hours=48
|
||||
)
|
||||
assert config.secret_key == "test-secret-key"
|
||||
assert config.jwt_secret == "test-jwt-secret"
|
||||
assert config.jwt_algorithm == "RS256"
|
||||
assert config.jwt_expiration_hours == 48
|
||||
|
||||
def test_custom_performance_settings(self):
|
||||
"""Test AITBCConfig with custom performance settings"""
|
||||
config = AITBCConfig(
|
||||
request_timeout=60,
|
||||
max_request_size=20 * 1024 * 1024
|
||||
)
|
||||
assert config.request_timeout == 60
|
||||
assert config.max_request_size == 20 * 1024 * 1024
|
||||
|
||||
def test_inherits_base_config(self):
|
||||
"""Test AITBCConfig inherits from BaseAITBCConfig"""
|
||||
config = AITBCConfig(
|
||||
app_name="Test App",
|
||||
environment="staging"
|
||||
)
|
||||
assert config.app_name == "Test App"
|
||||
assert config.environment == "staging"
|
||||
assert config.host == "0.0.0.0" # AITBCConfig default
|
||||
assert config.port == 8000 # AITBCConfig default
|
||||
|
||||
@patch('aitbc.config.logger')
|
||||
def test_init_logs_configuration(self, mock_logger):
|
||||
"""Test __init__ logs configuration"""
|
||||
config = AITBCConfig(host="localhost", port=9000)
|
||||
mock_logger.info.assert_called_once()
|
||||
assert "localhost:9000" in mock_logger.info.call_args[0][0]
|
||||
mock_logger.debug.assert_called_once()
|
||||
303
tests/test_decorators.py
Normal file
303
tests/test_decorators.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""
|
||||
Tests for AITBC decorators
|
||||
"""
|
||||
|
||||
import time
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from aitbc.decorators import (
|
||||
retry,
|
||||
timing,
|
||||
cache_result,
|
||||
validate_args,
|
||||
handle_exceptions,
|
||||
async_timing,
|
||||
)
|
||||
from aitbc.exceptions import AITBCError
|
||||
|
||||
|
||||
class TestRetry:
|
||||
"""Tests for retry decorator"""
|
||||
|
||||
def test_retry_succeeds_on_first_attempt(self):
|
||||
"""Test retry when function succeeds on first attempt"""
|
||||
@retry(max_attempts=3)
|
||||
def test_func():
|
||||
return "success"
|
||||
|
||||
result = test_func()
|
||||
assert result == "success"
|
||||
|
||||
def test_retry_succeeds_after_failure(self):
|
||||
"""Test retry when function succeeds after initial failure"""
|
||||
attempts = [0]
|
||||
|
||||
@retry(max_attempts=3, delay=0.01)
|
||||
def test_func():
|
||||
attempts[0] += 1
|
||||
if attempts[0] < 2:
|
||||
raise ValueError("fail")
|
||||
return "success"
|
||||
|
||||
result = test_func()
|
||||
assert result == "success"
|
||||
assert attempts[0] == 2
|
||||
|
||||
def test_retry_exhausts_attempts(self):
|
||||
"""Test retry when function fails after all attempts"""
|
||||
@retry(max_attempts=2, delay=0.01)
|
||||
def test_func():
|
||||
raise ValueError("fail")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
test_func()
|
||||
|
||||
def test_retry_with_specific_exception(self):
|
||||
"""Test retry only catches specified exceptions"""
|
||||
@retry(max_attempts=2, delay=0.01, exceptions=(ValueError,))
|
||||
def test_func():
|
||||
raise TypeError("fail")
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
test_func()
|
||||
|
||||
def test_retry_with_backoff(self):
|
||||
"""Test retry with exponential backoff"""
|
||||
attempts = [0]
|
||||
|
||||
@retry(max_attempts=3, delay=0.01, backoff=2.0)
|
||||
def test_func():
|
||||
attempts[0] += 1
|
||||
raise ValueError("fail")
|
||||
|
||||
start_time = time.time()
|
||||
with pytest.raises(ValueError):
|
||||
test_func()
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# Should have delays: 0.01 + 0.02 = 0.03 seconds minimum
|
||||
assert elapsed >= 0.03
|
||||
|
||||
def test_retry_with_on_failure_callback(self):
|
||||
"""Test retry with on_failure callback"""
|
||||
callback_called = [False]
|
||||
|
||||
def on_fail(e):
|
||||
callback_called[0] = True
|
||||
|
||||
@retry(max_attempts=2, delay=0.01, on_failure=on_fail)
|
||||
def test_func():
|
||||
raise ValueError("fail")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
test_func()
|
||||
|
||||
assert callback_called[0] is True
|
||||
|
||||
|
||||
class TestTiming:
|
||||
"""Tests for timing decorator"""
|
||||
|
||||
@patch('aitbc.decorators.logger')
|
||||
def test_timing_logs_execution_time(self, mock_logger):
|
||||
"""Test timing decorator logs execution time"""
|
||||
@timing
|
||||
def test_func():
|
||||
time.sleep(0.01)
|
||||
return "result"
|
||||
|
||||
result = test_func()
|
||||
assert result == "result"
|
||||
mock_logger.info.assert_called_once()
|
||||
assert "executed in" in mock_logger.info.call_args[0][0]
|
||||
|
||||
@patch('aitbc.decorators.logger')
|
||||
def test_timing_preserves_function_name(self, mock_logger):
|
||||
"""Test timing decorator preserves function name"""
|
||||
@timing
|
||||
def my_function():
|
||||
return "result"
|
||||
|
||||
assert my_function.__name__ == "my_function"
|
||||
|
||||
|
||||
class TestCacheResult:
|
||||
"""Tests for cache_result decorator"""
|
||||
|
||||
def test_cache_result_caches_value(self):
|
||||
"""Test cache_result caches function return value"""
|
||||
call_count = [0]
|
||||
|
||||
@cache_result(ttl=60)
|
||||
def test_func(x):
|
||||
call_count[0] += 1
|
||||
return x * 2
|
||||
|
||||
result1 = test_func(5)
|
||||
result2 = test_func(5)
|
||||
|
||||
assert result1 == 10
|
||||
assert result2 == 10
|
||||
assert call_count[0] == 1 # Only called once due to cache
|
||||
|
||||
def test_cache_result_different_args(self):
|
||||
"""Test cache_result with different arguments"""
|
||||
call_count = [0]
|
||||
|
||||
@cache_result(ttl=60)
|
||||
def test_func(x):
|
||||
call_count[0] += 1
|
||||
return x * 2
|
||||
|
||||
test_func(5)
|
||||
test_func(10)
|
||||
|
||||
assert call_count[0] == 2 # Called twice for different args
|
||||
|
||||
def test_cache_result_ttl_expires(self):
|
||||
"""Test cache_result TTL expires"""
|
||||
call_count = [0]
|
||||
|
||||
@cache_result(ttl=0.1) # 100ms TTL
|
||||
def test_func(x):
|
||||
call_count[0] += 1
|
||||
return x * 2
|
||||
|
||||
test_func(5)
|
||||
time.sleep(0.15) # Wait for TTL to expire
|
||||
test_func(5)
|
||||
|
||||
assert call_count[0] == 2 # Called again after TTL expired
|
||||
|
||||
def test_cache_result_with_kwargs(self):
|
||||
"""Test cache_result with keyword arguments"""
|
||||
call_count = [0]
|
||||
|
||||
@cache_result(ttl=60)
|
||||
def test_func(x, y=10):
|
||||
call_count[0] += 1
|
||||
return x + y
|
||||
|
||||
test_func(5, y=10)
|
||||
test_func(5, y=10)
|
||||
|
||||
assert call_count[0] == 1 # Cached
|
||||
|
||||
|
||||
class TestValidateArgs:
|
||||
"""Tests for validate_args decorator"""
|
||||
|
||||
def test_validate_args_passes_valid(self):
|
||||
"""Test validate_args passes when validators succeed"""
|
||||
def validator(x):
|
||||
if x < 0:
|
||||
raise ValueError("Must be positive")
|
||||
|
||||
@validate_args(validator)
|
||||
def test_func(x):
|
||||
return x * 2
|
||||
|
||||
result = test_func(5)
|
||||
assert result == 10
|
||||
|
||||
def test_validate_args_fails_invalid(self):
|
||||
"""Test validate_args fails when validators raise error"""
|
||||
def validator(x):
|
||||
if x < 0:
|
||||
raise ValueError("Must be positive")
|
||||
|
||||
@validate_args(validator)
|
||||
def test_func(x):
|
||||
return x * 2
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
test_func(-5)
|
||||
|
||||
def test_validate_args_multiple_validators(self):
|
||||
"""Test validate_args with multiple validators"""
|
||||
def validator1(x):
|
||||
if x < 0:
|
||||
raise ValueError("Must be positive")
|
||||
|
||||
def validator2(x):
|
||||
if x > 100:
|
||||
raise ValueError("Must be <= 100")
|
||||
|
||||
@validate_args(validator1, validator2)
|
||||
def test_func(x):
|
||||
return x * 2
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
test_func(150)
|
||||
|
||||
|
||||
class TestHandleExceptions:
|
||||
"""Tests for handle_exceptions decorator"""
|
||||
|
||||
@patch('aitbc.decorators.logger')
|
||||
def test_handle_exceptions_returns_default(self, mock_logger):
|
||||
"""Test handle_exceptions returns default on exception"""
|
||||
@handle_exceptions(default_return="error")
|
||||
def test_func():
|
||||
raise ValueError("fail")
|
||||
|
||||
result = test_func()
|
||||
assert result == "error"
|
||||
mock_logger.error.assert_called_once()
|
||||
|
||||
@patch('aitbc.decorators.logger')
|
||||
def test_handle_exceptions_no_logging(self, mock_logger):
|
||||
"""Test handle_exceptions with logging disabled"""
|
||||
@handle_exceptions(default_return="error", log_errors=False)
|
||||
def test_func():
|
||||
raise ValueError("fail")
|
||||
|
||||
result = test_func()
|
||||
assert result == "error"
|
||||
mock_logger.error.assert_not_called()
|
||||
|
||||
def test_handle_exceptions_raises_on_specified(self):
|
||||
"""Test handle_exceptions still raises specified exceptions"""
|
||||
@handle_exceptions(default_return="error", raise_on=(ValueError,))
|
||||
def test_func():
|
||||
raise ValueError("fail")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
test_func()
|
||||
|
||||
def test_handle_exceptions_passes_on_success(self):
|
||||
"""Test handle_exceptions passes through successful return"""
|
||||
@handle_exceptions(default_return="error")
|
||||
def test_func():
|
||||
return "success"
|
||||
|
||||
result = test_func()
|
||||
assert result == "success"
|
||||
|
||||
|
||||
class TestAsyncTiming:
|
||||
"""Tests for async_timing decorator"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('aitbc.decorators.logger')
|
||||
async def test_async_timing_logs_execution_time(self, mock_logger):
|
||||
"""Test async_timing decorator logs execution time"""
|
||||
@async_timing
|
||||
async def test_func():
|
||||
await asyncio.sleep(0.01)
|
||||
return "result"
|
||||
|
||||
import asyncio
|
||||
result = await test_func()
|
||||
assert result == "result"
|
||||
mock_logger.info.assert_called_once()
|
||||
assert "executed in" in mock_logger.info.call_args[0][0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_timing_preserves_function_name(self):
|
||||
"""Test async_timing decorator preserves function name"""
|
||||
@async_timing
|
||||
async def my_function():
|
||||
return "result"
|
||||
|
||||
assert my_function.__name__ == "my_function"
|
||||
218
tests/test_health_checks.py
Normal file
218
tests/test_health_checks.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""
|
||||
Tests for health check utilities
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, Mock
|
||||
from datetime import datetime
|
||||
|
||||
from aitbc.health_checks import (
|
||||
HealthStatus,
|
||||
HealthCheck,
|
||||
HealthChecker,
|
||||
create_basic_health_check,
|
||||
)
|
||||
|
||||
|
||||
class TestHealthStatus:
|
||||
"""Tests for HealthStatus enum"""
|
||||
|
||||
def test_health_status_values(self):
|
||||
"""Test HealthStatus enum values"""
|
||||
assert HealthStatus.HEALTHY.value == "healthy"
|
||||
assert HealthStatus.DEGRADED.value == "degraded"
|
||||
assert HealthStatus.UNHEALTHY.value == "unhealthy"
|
||||
|
||||
|
||||
class TestHealthCheck:
|
||||
"""Tests for HealthCheck dataclass"""
|
||||
|
||||
def test_health_check_creation(self):
|
||||
"""Test HealthCheck dataclass creation"""
|
||||
check = HealthCheck(
|
||||
service="test-service",
|
||||
status=HealthStatus.HEALTHY,
|
||||
message="All good",
|
||||
timestamp=datetime.now(),
|
||||
details={"key": "value"}
|
||||
)
|
||||
assert check.service == "test-service"
|
||||
assert check.status == HealthStatus.HEALTHY
|
||||
assert check.message == "All good"
|
||||
assert check.details == {"key": "value"}
|
||||
|
||||
def test_health_check_without_details(self):
|
||||
"""Test HealthCheck without optional details"""
|
||||
check = HealthCheck(
|
||||
service="test-service",
|
||||
status=HealthStatus.HEALTHY,
|
||||
message="All good",
|
||||
timestamp=datetime.now()
|
||||
)
|
||||
assert check.details is None
|
||||
|
||||
|
||||
class TestHealthChecker:
|
||||
"""Tests for HealthChecker"""
|
||||
|
||||
def test_health_checker_initialization(self):
|
||||
"""Test HealthChecker initialization"""
|
||||
checker = HealthChecker("test-service")
|
||||
assert checker.service_name == "test-service"
|
||||
assert checker._checks == {}
|
||||
assert checker._last_check is None
|
||||
|
||||
def test_register_check(self):
|
||||
"""Test registering a health check"""
|
||||
checker = HealthChecker("test-service")
|
||||
|
||||
def mock_check():
|
||||
return HealthStatus.HEALTHY, "OK", {}
|
||||
|
||||
checker.register_check("memory", mock_check)
|
||||
assert "memory" in checker._checks
|
||||
assert checker._checks["memory"] == mock_check
|
||||
|
||||
@patch('aitbc.health_checks.logger')
|
||||
def test_register_check_logs(self, mock_logger):
|
||||
"""Test register_check logs registration"""
|
||||
checker = HealthChecker("test-service")
|
||||
|
||||
def mock_check():
|
||||
return HealthStatus.HEALTHY, "OK", {}
|
||||
|
||||
checker.register_check("memory", mock_check)
|
||||
mock_logger.info.assert_called_once()
|
||||
assert "memory" in mock_logger.info.call_args[0][0]
|
||||
|
||||
def test_run_checks_all_healthy(self):
|
||||
"""Test run_checks when all checks pass"""
|
||||
checker = HealthChecker("test-service")
|
||||
|
||||
def mock_check():
|
||||
return HealthStatus.HEALTHY, "OK", {}
|
||||
|
||||
checker.register_check("check1", mock_check)
|
||||
checker.register_check("check2", mock_check)
|
||||
|
||||
result = checker.run_checks()
|
||||
|
||||
assert result.service == "test-service"
|
||||
assert result.status == HealthStatus.HEALTHY
|
||||
assert result.message == "All health checks passed"
|
||||
assert result.details is not None
|
||||
assert len(result.details) == 2
|
||||
|
||||
def test_run_checks_one_degraded(self):
|
||||
"""Test run_checks with one degraded check"""
|
||||
checker = HealthChecker("test-service")
|
||||
|
||||
def healthy_check():
|
||||
return HealthStatus.HEALTHY, "OK", {}
|
||||
|
||||
def degraded_check():
|
||||
return HealthStatus.DEGRADED, "Warning", {}
|
||||
|
||||
checker.register_check("healthy", healthy_check)
|
||||
checker.register_check("degraded", degraded_check)
|
||||
|
||||
result = checker.run_checks()
|
||||
|
||||
assert result.status == HealthStatus.DEGRADED
|
||||
assert "degraded" in result.message
|
||||
|
||||
def test_run_checks_one_unhealthy(self):
|
||||
"""Test run_checks with one unhealthy check"""
|
||||
checker = HealthChecker("test-service")
|
||||
|
||||
def healthy_check():
|
||||
return HealthStatus.HEALTHY, "OK", {}
|
||||
|
||||
def unhealthy_check():
|
||||
return HealthStatus.UNHEALTHY, "Error", {}
|
||||
|
||||
checker.register_check("healthy", healthy_check)
|
||||
checker.register_check("unhealthy", unhealthy_check)
|
||||
|
||||
result = checker.run_checks()
|
||||
|
||||
assert result.status == HealthStatus.UNHEALTHY
|
||||
assert "unhealthy" in result.message
|
||||
|
||||
@patch('aitbc.health_checks.logger')
|
||||
def test_run_checks_with_exception(self, mock_logger):
|
||||
"""Test run_checks handles exceptions in checks"""
|
||||
checker = HealthChecker("test-service")
|
||||
|
||||
def failing_check():
|
||||
raise ValueError("Check failed")
|
||||
|
||||
checker.register_check("failing", failing_check)
|
||||
|
||||
result = checker.run_checks()
|
||||
|
||||
assert result.status == HealthStatus.UNHEALTHY
|
||||
assert "failing" in result.message
|
||||
mock_logger.error.assert_called_once()
|
||||
|
||||
def test_get_last_check_before_run(self):
|
||||
"""Test get_last_check returns None before any check run"""
|
||||
checker = HealthChecker("test-service")
|
||||
assert checker.get_last_check() is None
|
||||
|
||||
def test_get_last_check_after_run(self):
|
||||
"""Test get_last_check returns last check result"""
|
||||
checker = HealthChecker("test-service")
|
||||
|
||||
def mock_check():
|
||||
return HealthStatus.HEALTHY, "OK", {}
|
||||
|
||||
checker.register_check("check1", mock_check)
|
||||
checker.run_checks()
|
||||
|
||||
last_check = checker.get_last_check()
|
||||
assert last_check is not None
|
||||
assert last_check.service == "test-service"
|
||||
|
||||
def test_get_health_dict(self):
|
||||
"""Test get_health_dict returns dictionary representation"""
|
||||
checker = HealthChecker("test-service")
|
||||
|
||||
def mock_check():
|
||||
return HealthStatus.HEALTHY, "OK", {"key": "value"}
|
||||
|
||||
checker.register_check("check1", mock_check)
|
||||
health_dict = checker.get_health_dict()
|
||||
|
||||
assert isinstance(health_dict, dict)
|
||||
assert "service" in health_dict
|
||||
assert "status" in health_dict
|
||||
assert "message" in health_dict
|
||||
assert "timestamp" in health_dict
|
||||
assert health_dict["service"] == "test-service"
|
||||
|
||||
|
||||
class TestCreateBasicHealthCheck:
|
||||
"""Tests for create_basic_health_check"""
|
||||
|
||||
def test_create_basic_health_check(self):
|
||||
"""Test create_basic_health_check returns HealthChecker"""
|
||||
checker = create_basic_health_check("test-service")
|
||||
assert isinstance(checker, HealthChecker)
|
||||
assert checker.service_name == "test-service"
|
||||
|
||||
def test_create_basic_health_check_without_psutil(self):
|
||||
"""Test create_basic_health_check handles psutil ImportError"""
|
||||
# Skip this test as psutil import handling is complex to mock
|
||||
pytest.skip("psutil import handling requires complex mocking")
|
||||
|
||||
def test_basic_health_check_has_checks(self):
|
||||
"""Test basic health check has registered checks when psutil available"""
|
||||
try:
|
||||
import psutil
|
||||
checker = create_basic_health_check("test-service")
|
||||
# Should have memory and disk checks if psutil is available
|
||||
assert len(checker._checks) > 0
|
||||
except ImportError:
|
||||
# Skip if psutil not available
|
||||
pass
|
||||
250
tests/test_metrics.py
Normal file
250
tests/test_metrics.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""
|
||||
Tests for AITBC metrics module
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import patch, Mock
|
||||
|
||||
from aitbc.metrics import (
|
||||
service_info,
|
||||
block_processing_duration,
|
||||
block_height,
|
||||
block_validation_duration,
|
||||
block_propagation_duration,
|
||||
job_submission_duration,
|
||||
job_processing_duration,
|
||||
job_queue_duration,
|
||||
job_execution_duration,
|
||||
jobs_total,
|
||||
jobs_failed_total,
|
||||
jobs_in_queue,
|
||||
http_requests_total,
|
||||
http_request_duration,
|
||||
service_uptime_seconds,
|
||||
service_restart_count,
|
||||
track_block_processing,
|
||||
track_job_processing,
|
||||
track_http_request,
|
||||
update_block_height,
|
||||
update_jobs_in_queue,
|
||||
increment_service_restarts,
|
||||
metrics_app,
|
||||
setup_service_info,
|
||||
)
|
||||
|
||||
|
||||
class TestMetricsDefinitions:
|
||||
"""Tests for Prometheus metrics definitions"""
|
||||
|
||||
def test_service_info_exists(self):
|
||||
"""Test service_info metric is defined"""
|
||||
assert service_info is not None
|
||||
assert service_info._name == 'service_info'
|
||||
|
||||
def test_block_processing_duration_exists(self):
|
||||
"""Test block_processing_duration metric is defined"""
|
||||
assert block_processing_duration is not None
|
||||
assert block_processing_duration._name == 'block_processing_duration_seconds'
|
||||
|
||||
def test_block_height_exists(self):
|
||||
"""Test block_height metric is defined"""
|
||||
assert block_height is not None
|
||||
assert block_height._name == 'block_height'
|
||||
|
||||
def test_block_validation_duration_exists(self):
|
||||
"""Test block_validation_duration metric is defined"""
|
||||
assert block_validation_duration is not None
|
||||
assert block_validation_duration._name == 'block_validation_duration_seconds'
|
||||
|
||||
def test_block_propagation_duration_exists(self):
|
||||
"""Test block_propagation_duration metric is defined"""
|
||||
assert block_propagation_duration is not None
|
||||
assert block_propagation_duration._name == 'block_propagation_duration_seconds'
|
||||
|
||||
def test_job_submission_duration_exists(self):
|
||||
"""Test job_submission_duration metric is defined"""
|
||||
assert job_submission_duration is not None
|
||||
assert job_submission_duration._name == 'job_submission_duration_seconds'
|
||||
|
||||
def test_job_processing_duration_exists(self):
|
||||
"""Test job_processing_duration metric is defined"""
|
||||
assert job_processing_duration is not None
|
||||
assert job_processing_duration._name == 'job_processing_duration_seconds'
|
||||
|
||||
def test_job_queue_duration_exists(self):
|
||||
"""Test job_queue_duration metric is defined"""
|
||||
assert job_queue_duration is not None
|
||||
assert job_queue_duration._name == 'job_queue_duration_seconds'
|
||||
|
||||
def test_job_execution_duration_exists(self):
|
||||
"""Test job_execution_duration metric is defined"""
|
||||
assert job_execution_duration is not None
|
||||
assert job_execution_duration._name == 'job_execution_duration_seconds'
|
||||
|
||||
def test_jobs_total_exists(self):
|
||||
"""Test jobs_total metric is defined"""
|
||||
assert jobs_total is not None
|
||||
assert jobs_total._name == 'jobs'
|
||||
|
||||
def test_jobs_failed_total_exists(self):
|
||||
"""Test jobs_failed_total metric is defined"""
|
||||
assert jobs_failed_total is not None
|
||||
assert jobs_failed_total._name == 'jobs_failed'
|
||||
|
||||
def test_jobs_in_queue_exists(self):
|
||||
"""Test jobs_in_queue metric is defined"""
|
||||
assert jobs_in_queue is not None
|
||||
assert jobs_in_queue._name == 'jobs_in_queue'
|
||||
|
||||
def test_http_requests_total_exists(self):
|
||||
"""Test http_requests_total metric is defined"""
|
||||
assert http_requests_total is not None
|
||||
assert http_requests_total._name == 'http_requests'
|
||||
|
||||
def test_http_request_duration_exists(self):
|
||||
"""Test http_request_duration metric is defined"""
|
||||
assert http_request_duration is not None
|
||||
assert http_request_duration._name == 'http_request_duration_seconds'
|
||||
|
||||
def test_service_uptime_seconds_exists(self):
|
||||
"""Test service_uptime_seconds metric is defined"""
|
||||
assert service_uptime_seconds is not None
|
||||
assert service_uptime_seconds._name == 'service_uptime_seconds'
|
||||
|
||||
def test_service_restart_count_exists(self):
|
||||
"""Test service_restart_count metric is defined"""
|
||||
assert service_restart_count is not None
|
||||
assert service_restart_count._name == 'service_restart_count'
|
||||
|
||||
|
||||
class TestHelperFunctions:
|
||||
"""Tests for metrics helper functions"""
|
||||
|
||||
def test_update_block_height(self):
|
||||
"""Test update_block_height sets metric"""
|
||||
update_block_height(100)
|
||||
# Metric should be set, but we can't easily verify the value
|
||||
# This test ensures the function doesn't raise an error
|
||||
assert True
|
||||
|
||||
def test_update_jobs_in_queue(self):
|
||||
"""Test update_jobs_in_queue sets metric"""
|
||||
update_jobs_in_queue(50)
|
||||
# Metric should be set, but we can't easily verify the value
|
||||
# This test ensures the function doesn't raise an error
|
||||
assert True
|
||||
|
||||
def test_increment_service_restarts(self):
|
||||
"""Test increment_service_restarts increments counter"""
|
||||
increment_service_restarts()
|
||||
# Counter should be incremented, but we can't easily verify the value
|
||||
# This test ensures the function doesn't raise an error
|
||||
assert True
|
||||
|
||||
def test_setup_service_info(self):
|
||||
"""Test setup_service_info sets service info"""
|
||||
setup_service_info("test-service", "1.0.0")
|
||||
# Info should be set, but we can't easily verify the value
|
||||
# This test ensures the function doesn't raise an error
|
||||
assert True
|
||||
|
||||
|
||||
class TestDecorators:
|
||||
"""Tests for metrics tracking decorators"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_track_block_processing_success(self):
|
||||
"""Test track_block_processing decorator on successful execution"""
|
||||
@track_block_processing
|
||||
async def process_block():
|
||||
return "block_processed"
|
||||
|
||||
result = await process_block()
|
||||
assert result == "block_processed"
|
||||
# Decorator should have observed the duration
|
||||
assert True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_track_block_processing_failure(self):
|
||||
"""Test track_block_processing decorator on exception"""
|
||||
@track_block_processing
|
||||
async def process_block():
|
||||
raise ValueError("block error")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await process_block()
|
||||
# Decorator should have observed the duration even on failure
|
||||
assert True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_track_job_processing_success(self):
|
||||
"""Test track_job_processing decorator on successful execution"""
|
||||
@track_job_processing
|
||||
async def process_job():
|
||||
return "job_completed"
|
||||
|
||||
result = await process_job()
|
||||
assert result == "job_completed"
|
||||
# Decorator should have observed duration and incremented jobs_total
|
||||
assert True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_track_job_processing_failure(self):
|
||||
"""Test track_job_processing decorator on exception"""
|
||||
@track_job_processing
|
||||
async def process_job():
|
||||
raise ValueError("job error")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await process_job()
|
||||
# Decorator should have observed duration and incremented failure counters
|
||||
assert True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_track_http_request_success(self):
|
||||
"""Test track_http_request decorator on successful execution"""
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
@track_http_request
|
||||
async def handle_request():
|
||||
return mock_response
|
||||
|
||||
result = await handle_request()
|
||||
assert result.status_code == 200
|
||||
# Decorator should have observed duration and incremented http_requests_total
|
||||
assert True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_track_http_request_failure(self):
|
||||
"""Test track_http_request decorator on exception"""
|
||||
@track_http_request
|
||||
async def handle_request():
|
||||
raise ValueError("request error")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await handle_request()
|
||||
# Decorator should have observed duration and incremented http_requests_total with 500
|
||||
assert True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_track_http_request_without_status_code(self):
|
||||
"""Test track_http_request with response without status_code"""
|
||||
@track_http_request
|
||||
async def handle_request():
|
||||
return "success" # No status_code attribute
|
||||
|
||||
result = await handle_request()
|
||||
assert result == "success"
|
||||
# Decorator should have observed duration but not incremented http_requests_total
|
||||
assert True
|
||||
|
||||
|
||||
class TestMetricsApp:
|
||||
"""Tests for metrics ASGI app"""
|
||||
|
||||
def test_metrics_app_exists(self):
|
||||
"""Test metrics_app is created"""
|
||||
assert metrics_app is not None
|
||||
assert hasattr(metrics_app, '__call__')
|
||||
266
tests/test_middleware.py
Normal file
266
tests/test_middleware.py
Normal file
@@ -0,0 +1,266 @@
|
||||
"""
|
||||
Tests for AITBC middleware modules
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
from fastapi import Request, Response, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from aitbc.middleware.performance import PerformanceLoggingMiddleware
|
||||
from aitbc.middleware.request_id import RequestIDMiddleware
|
||||
from aitbc.middleware.error_handler import ErrorHandlerMiddleware
|
||||
|
||||
|
||||
class TestPerformanceLoggingMiddleware:
|
||||
"""Tests for PerformanceLoggingMiddleware"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_adds_performance_header(self):
|
||||
"""Test that middleware adds X-Process-Time header"""
|
||||
app = Mock(spec=ASGIApp)
|
||||
middleware = PerformanceLoggingMiddleware(app)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
request.method = "GET"
|
||||
request.url = Mock()
|
||||
request.url.path = "/test"
|
||||
|
||||
response = Mock(spec=Response)
|
||||
response.status_code = 200
|
||||
response.headers = {}
|
||||
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
result = await middleware.dispatch(request, call_next)
|
||||
|
||||
assert "X-Process-Time" in result.headers
|
||||
assert float(result.headers["X-Process-Time"]) >= 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_logs_performance_metrics(self):
|
||||
"""Test that middleware logs performance metrics"""
|
||||
app = Mock(spec=ASGIApp)
|
||||
middleware = PerformanceLoggingMiddleware(app)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
request.method = "POST"
|
||||
request.url = Mock()
|
||||
request.url.path = "/api/test"
|
||||
|
||||
response = Mock(spec=Response)
|
||||
response.status_code = 201
|
||||
response.headers = {}
|
||||
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
with patch('aitbc.middleware.performance.logger') as mock_logger:
|
||||
await middleware.dispatch(request, call_next)
|
||||
mock_logger.info.assert_called_once()
|
||||
assert "Request performance" in mock_logger.info.call_args[0][0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_measures_time_correctly(self):
|
||||
"""Test that middleware measures request duration accurately"""
|
||||
app = Mock(spec=ASGIApp)
|
||||
middleware = PerformanceLoggingMiddleware(app)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
request.method = "GET"
|
||||
request.url = Mock()
|
||||
request.url.path = "/test"
|
||||
|
||||
response = Mock(spec=Response)
|
||||
response.status_code = 200
|
||||
response.headers = {}
|
||||
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
result = await middleware.dispatch(request, call_next)
|
||||
|
||||
process_time = float(result.headers["X-Process-Time"])
|
||||
assert 0 <= process_time < 1.0 # Should complete in under 1 second
|
||||
|
||||
|
||||
class TestRequestIDMiddleware:
|
||||
"""Tests for RequestIDMiddleware"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_generates_request_id_when_missing(self):
|
||||
"""Test that middleware generates request ID when not in headers"""
|
||||
app = Mock(spec=ASGIApp)
|
||||
middleware = RequestIDMiddleware(app)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {}
|
||||
request.method = "GET"
|
||||
request.url = Mock()
|
||||
request.url.path = "/test"
|
||||
request.client = Mock()
|
||||
request.client.host = "127.0.0.1"
|
||||
request.state = Mock()
|
||||
|
||||
response = Mock(spec=Response)
|
||||
response.headers = {}
|
||||
response.status_code = 200
|
||||
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
result = await middleware.dispatch(request, call_next)
|
||||
|
||||
assert "X-Request-ID" in result.headers
|
||||
assert len(result.headers["X-Request-ID"]) > 0
|
||||
assert request.state.request_id == result.headers["X-Request-ID"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_uses_existing_request_id_from_header(self):
|
||||
"""Test that middleware uses existing request ID from header"""
|
||||
app = Mock(spec=ASGIApp)
|
||||
middleware = RequestIDMiddleware(app)
|
||||
|
||||
existing_id = "test-request-id-123"
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"X-Request-ID": existing_id}
|
||||
request.method = "POST"
|
||||
request.url = Mock()
|
||||
request.url.path = "/api/test"
|
||||
request.client = Mock()
|
||||
request.client.host = "192.168.1.1"
|
||||
request.state = Mock()
|
||||
|
||||
response = Mock(spec=Response)
|
||||
response.headers = {}
|
||||
response.status_code = 201
|
||||
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
result = await middleware.dispatch(request, call_next)
|
||||
|
||||
assert result.headers["X-Request-ID"] == existing_id
|
||||
assert request.state.request_id == existing_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_logs_request_info(self):
|
||||
"""Test that middleware logs request information"""
|
||||
app = Mock(spec=ASGIApp)
|
||||
middleware = RequestIDMiddleware(app)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {}
|
||||
request.method = "GET"
|
||||
request.url = Mock()
|
||||
request.url.path = "/test"
|
||||
request.client = Mock()
|
||||
request.client.host = "127.0.0.1"
|
||||
request.state = Mock()
|
||||
|
||||
response = Mock(spec=Response)
|
||||
response.headers = {}
|
||||
response.status_code = 200
|
||||
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
with patch('aitbc.middleware.request_id.logger') as mock_logger:
|
||||
await middleware.dispatch(request, call_next)
|
||||
assert mock_logger.info.call_count >= 2 # Logs start and completion
|
||||
|
||||
|
||||
class TestErrorHandlerMiddleware:
|
||||
"""Tests for ErrorHandlerMiddleware"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_passes_through_normal_response(self):
|
||||
"""Test that middleware passes through normal responses"""
|
||||
app = Mock(spec=ASGIApp)
|
||||
middleware = ErrorHandlerMiddleware(app)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock()
|
||||
request.url.path = "/test"
|
||||
request.method = "GET"
|
||||
|
||||
response = Mock(spec=Response)
|
||||
response.status_code = 200
|
||||
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
result = await middleware.dispatch(request, call_next)
|
||||
|
||||
assert result == response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_handles_http_exception(self):
|
||||
"""Test that middleware handles HTTPException"""
|
||||
app = Mock(spec=ASGIApp)
|
||||
middleware = ErrorHandlerMiddleware(app)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock()
|
||||
request.url.path = "/api/error"
|
||||
request.method = "GET"
|
||||
|
||||
exception = HTTPException(status_code=404, detail="Not found")
|
||||
call_next = AsyncMock(side_effect=exception)
|
||||
|
||||
result = await middleware.dispatch(request, call_next)
|
||||
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == 404
|
||||
content = result.body.decode() if hasattr(result, 'body') else {}
|
||||
assert "error" in result.body.decode() if hasattr(result, 'body') else True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_handles_generic_exception(self):
|
||||
"""Test that middleware handles generic exceptions"""
|
||||
app = Mock(spec=ASGIApp)
|
||||
middleware = ErrorHandlerMiddleware(app)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock()
|
||||
request.url.path = "/api/crash"
|
||||
request.method = "POST"
|
||||
|
||||
exception = ValueError("Something went wrong")
|
||||
call_next = AsyncMock(side_effect=exception)
|
||||
|
||||
result = await middleware.dispatch(request, call_next)
|
||||
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == 500
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_logs_http_exception(self):
|
||||
"""Test that middleware logs HTTPException"""
|
||||
app = Mock(spec=ASGIApp)
|
||||
middleware = ErrorHandlerMiddleware(app)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock()
|
||||
request.url.path = "/test"
|
||||
request.method = "GET"
|
||||
|
||||
exception = HTTPException(status_code=400, detail="Bad request")
|
||||
call_next = AsyncMock(side_effect=exception)
|
||||
|
||||
with patch('aitbc.middleware.error_handler.logger') as mock_logger:
|
||||
await middleware.dispatch(request, call_next)
|
||||
mock_logger.warning.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_logs_generic_exception(self):
|
||||
"""Test that middleware logs generic exceptions"""
|
||||
app = Mock(spec=ASGIApp)
|
||||
middleware = ErrorHandlerMiddleware(app)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock()
|
||||
request.url.path = "/test"
|
||||
request.method = "GET"
|
||||
|
||||
exception = RuntimeError("Runtime error")
|
||||
call_next = AsyncMock(side_effect=exception)
|
||||
|
||||
with patch('aitbc.middleware.error_handler.logger') as mock_logger:
|
||||
await middleware.dispatch(request, call_next)
|
||||
mock_logger.error.assert_called_once()
|
||||
287
tests/test_security_headers.py
Normal file
287
tests/test_security_headers.py
Normal file
@@ -0,0 +1,287 @@
|
||||
"""
|
||||
Tests for security headers and CORS utilities
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from aitbc.security_headers import (
|
||||
SecurityHeaders,
|
||||
CORSConfig,
|
||||
SecurityHeadersMiddleware,
|
||||
CORSMiddleware,
|
||||
create_production_security_headers,
|
||||
create_development_security_headers,
|
||||
create_strict_cors_config,
|
||||
create_permissive_cors_config,
|
||||
)
|
||||
|
||||
|
||||
class TestSecurityHeaders:
|
||||
"""Tests for SecurityHeaders dataclass"""
|
||||
|
||||
def test_default_security_headers(self):
|
||||
"""Test default security headers values"""
|
||||
headers = SecurityHeaders()
|
||||
assert headers.X_Content_Type_Options == "nosniff"
|
||||
assert headers.X_Frame_Options == "DENY"
|
||||
assert headers.X_XSS_Protection == "1; mode=block"
|
||||
assert headers.Strict_Transport_Security == "max-age=31536000; includeSubDomains"
|
||||
assert headers.Content_Security_Policy == "default-src 'self'"
|
||||
assert headers.Referrer_Policy == "strict-origin-when-cross-origin"
|
||||
assert headers.Permissions_Policy == ""
|
||||
assert headers.Cache_Control == "no-cache, no-store, must-revalidate"
|
||||
assert headers.Pragma == "no-cache"
|
||||
|
||||
def test_custom_security_headers(self):
|
||||
"""Test custom security headers values"""
|
||||
headers = SecurityHeaders(
|
||||
X_Frame_Options="SAMEORIGIN",
|
||||
Content_Security_Policy="default-src 'self' https://example.com"
|
||||
)
|
||||
assert headers.X_Frame_Options == "SAMEORIGIN"
|
||||
assert headers.Content_Security_Policy == "default-src 'self' https://example.com"
|
||||
|
||||
|
||||
class TestCORSConfig:
|
||||
"""Tests for CORSConfig dataclass"""
|
||||
|
||||
def test_default_cors_config(self):
|
||||
"""Test default CORS config values"""
|
||||
config = CORSConfig(
|
||||
allow_origins=["http://localhost:3000"],
|
||||
allow_methods=["GET", "POST"],
|
||||
allow_headers=["Content-Type"]
|
||||
)
|
||||
assert config.allow_origins == ["http://localhost:3000"]
|
||||
assert config.allow_methods == ["GET", "POST"]
|
||||
assert config.allow_credentials is False
|
||||
assert config.expose_headers is None
|
||||
assert config.max_age == 3600
|
||||
|
||||
def test_custom_cors_config(self):
|
||||
"""Test custom CORS config values"""
|
||||
config = CORSConfig(
|
||||
allow_origins=["*"],
|
||||
allow_methods=["GET", "POST", "PUT"],
|
||||
allow_headers=["Content-Type"],
|
||||
allow_credentials=True,
|
||||
expose_headers=["X-Custom-Header"],
|
||||
max_age=7200
|
||||
)
|
||||
assert config.allow_origins == ["*"]
|
||||
assert config.allow_credentials is True
|
||||
assert config.expose_headers == ["X-Custom-Header"]
|
||||
assert config.max_age == 7200
|
||||
|
||||
|
||||
class TestSecurityHeadersMiddleware:
|
||||
"""Tests for SecurityHeadersMiddleware"""
|
||||
|
||||
def test_initialization_with_default_headers(self):
|
||||
"""Test middleware initialization with default headers"""
|
||||
middleware = SecurityHeadersMiddleware()
|
||||
assert middleware.headers is not None
|
||||
assert middleware.headers.X_Frame_Options == "DENY"
|
||||
|
||||
def test_initialization_with_custom_headers(self):
|
||||
"""Test middleware initialization with custom headers"""
|
||||
custom_headers = SecurityHeaders(X_Frame_Options="SAMEORIGIN")
|
||||
middleware = SecurityHeadersMiddleware(custom_headers)
|
||||
assert middleware.headers.X_Frame_Options == "SAMEORIGIN"
|
||||
|
||||
def test_get_headers(self):
|
||||
"""Test get_headers returns dictionary"""
|
||||
middleware = SecurityHeadersMiddleware()
|
||||
headers = middleware.get_headers()
|
||||
assert isinstance(headers, dict)
|
||||
assert "X-Content-Type-Options" in headers
|
||||
assert "X-Frame-Options" in headers
|
||||
assert "X-XSS-Protection" in headers
|
||||
assert "Strict-Transport-Security" in headers
|
||||
assert "Content-Security-Policy" in headers
|
||||
assert "Referrer-Policy" in headers
|
||||
assert "Permissions-Policy" in headers
|
||||
assert "Cache-Control" in headers
|
||||
assert "Pragma" in headers
|
||||
|
||||
def test_get_headers_values(self):
|
||||
"""Test get_headers returns correct values"""
|
||||
middleware = SecurityHeadersMiddleware()
|
||||
headers = middleware.get_headers()
|
||||
assert headers["X-Content-Type-Options"] == "nosniff"
|
||||
assert headers["X-Frame-Options"] == "DENY"
|
||||
assert headers["X-XSS-Protection"] == "1; mode=block"
|
||||
|
||||
def test_apply_to_response(self):
|
||||
"""Test apply_to_response adds security headers"""
|
||||
middleware = SecurityHeadersMiddleware()
|
||||
response_headers = {"Content-Type": "application/json"}
|
||||
result = middleware.apply_to_response(response_headers)
|
||||
|
||||
assert "X-Content-Type-Options" in result
|
||||
assert "X-Frame-Options" in result
|
||||
assert result["Content-Type"] == "application/json"
|
||||
|
||||
def test_apply_to_response_overwrites_existing(self):
|
||||
"""Test apply_to_response overwrites existing security headers"""
|
||||
middleware = SecurityHeadersMiddleware()
|
||||
response_headers = {"X-Frame-Options": "ALLOW"}
|
||||
result = middleware.apply_to_response(response_headers)
|
||||
|
||||
assert result["X-Frame-Options"] == "DENY"
|
||||
|
||||
|
||||
class TestCORSMiddleware:
|
||||
"""Tests for CORSMiddleware"""
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test CORS middleware initialization"""
|
||||
config = CORSConfig(
|
||||
allow_origins=["http://localhost:3000"],
|
||||
allow_methods=["GET", "POST"],
|
||||
allow_headers=["Content-Type"]
|
||||
)
|
||||
middleware = CORSMiddleware(config)
|
||||
assert middleware.config == config
|
||||
|
||||
def test_get_cors_headers_allowed_origin(self):
|
||||
"""Test get_cors_headers with allowed origin"""
|
||||
config = CORSConfig(
|
||||
allow_origins=["http://localhost:3000"],
|
||||
allow_methods=["GET", "POST"],
|
||||
allow_headers=["Content-Type"]
|
||||
)
|
||||
middleware = CORSMiddleware(config)
|
||||
headers = middleware.get_cors_headers("http://localhost:3000")
|
||||
|
||||
assert headers["Access-Control-Allow-Origin"] == "http://localhost:3000"
|
||||
assert headers["Access-Control-Allow-Methods"] == "GET, POST"
|
||||
assert headers["Access-Control-Allow-Headers"] == "Content-Type"
|
||||
assert headers["Access-Control-Max-Age"] == "3600"
|
||||
|
||||
def test_get_cors_headers_disallowed_origin(self):
|
||||
"""Test get_cors_headers with disallowed origin"""
|
||||
config = CORSConfig(
|
||||
allow_origins=["http://localhost:3000"],
|
||||
allow_methods=["GET", "POST"],
|
||||
allow_headers=["Content-Type"]
|
||||
)
|
||||
middleware = CORSMiddleware(config)
|
||||
headers = middleware.get_cors_headers("http://evil.com")
|
||||
|
||||
assert headers == {}
|
||||
|
||||
def test_get_cors_headers_wildcard_origin(self):
|
||||
"""Test get_cors_headers with wildcard origin"""
|
||||
config = CORSConfig(
|
||||
allow_origins=["*"],
|
||||
allow_methods=["GET", "POST"],
|
||||
allow_headers=["Content-Type"]
|
||||
)
|
||||
middleware = CORSMiddleware(config)
|
||||
headers = middleware.get_cors_headers("http://any-origin.com")
|
||||
|
||||
assert headers["Access-Control-Allow-Origin"] == "http://any-origin.com"
|
||||
|
||||
def test_get_cors_headers_with_credentials(self):
|
||||
"""Test get_cors_headers with credentials enabled"""
|
||||
config = CORSConfig(
|
||||
allow_origins=["http://localhost:3000"],
|
||||
allow_methods=["GET", "POST"],
|
||||
allow_headers=["Content-Type"],
|
||||
allow_credentials=True
|
||||
)
|
||||
middleware = CORSMiddleware(config)
|
||||
headers = middleware.get_cors_headers("http://localhost:3000")
|
||||
|
||||
assert headers["Access-Control-Allow-Credentials"] == "true"
|
||||
|
||||
def test_get_cors_headers_with_expose_headers(self):
|
||||
"""Test get_cors_headers with expose headers"""
|
||||
config = CORSConfig(
|
||||
allow_origins=["http://localhost:3000"],
|
||||
allow_methods=["GET", "POST"],
|
||||
allow_headers=["Content-Type"],
|
||||
expose_headers=["X-Request-ID"]
|
||||
)
|
||||
middleware = CORSMiddleware(config)
|
||||
headers = middleware.get_cors_headers("http://localhost:3000")
|
||||
|
||||
assert headers["Access-Control-Expose-Headers"] == "X-Request-ID"
|
||||
|
||||
def test_is_origin_allowed_wildcard(self):
|
||||
"""Test _is_origin_allowed with wildcard"""
|
||||
config = CORSConfig(
|
||||
allow_origins=["*"],
|
||||
allow_methods=["GET"],
|
||||
allow_headers=["Content-Type"]
|
||||
)
|
||||
middleware = CORSMiddleware(config)
|
||||
assert middleware._is_origin_allowed("http://any-origin.com") is True
|
||||
|
||||
def test_is_origin_allowed_specific(self):
|
||||
"""Test _is_origin_allowed with specific origin"""
|
||||
config = CORSConfig(
|
||||
allow_origins=["http://localhost:3000"],
|
||||
allow_methods=["GET"],
|
||||
allow_headers=["Content-Type"]
|
||||
)
|
||||
middleware = CORSMiddleware(config)
|
||||
assert middleware._is_origin_allowed("http://localhost:3000") is True
|
||||
assert middleware._is_origin_allowed("http://evil.com") is False
|
||||
|
||||
def test_is_preflight_request(self):
|
||||
"""Test is_preflight_request"""
|
||||
config = CORSConfig(
|
||||
allow_origins=["*"],
|
||||
allow_methods=["GET"],
|
||||
allow_headers=["Content-Type"]
|
||||
)
|
||||
middleware = CORSMiddleware(config)
|
||||
assert middleware.is_preflight_request("OPTIONS") is True
|
||||
assert middleware.is_preflight_request("GET") is False
|
||||
assert middleware.is_preflight_request("options") is True
|
||||
|
||||
|
||||
class TestFactoryFunctions:
|
||||
"""Tests for factory functions"""
|
||||
|
||||
def test_create_production_security_headers(self):
|
||||
"""Test create_production_security_headers"""
|
||||
headers = create_production_security_headers()
|
||||
assert headers.X_Frame_Options == "DENY"
|
||||
assert "preload" in headers.Strict_Transport_Security
|
||||
assert "unsafe-inline" in headers.Content_Security_Policy
|
||||
assert "geolocation=()" in headers.Permissions_Policy
|
||||
|
||||
def test_create_development_security_headers(self):
|
||||
"""Test create_development_security_headers"""
|
||||
headers = create_development_security_headers()
|
||||
assert headers.X_Frame_Options == "SAMEORIGIN"
|
||||
assert headers.Strict_Transport_Security == "max-age=3600"
|
||||
assert headers.Permissions_Policy == ""
|
||||
assert headers.Cache_Control == "no-cache"
|
||||
|
||||
def test_create_strict_cors_config(self):
|
||||
"""Test create_strict_cors_config"""
|
||||
config = create_strict_cors_config(["http://localhost:3000"])
|
||||
assert "http://localhost:3000" in config.allow_origins
|
||||
assert "GET" in config.allow_methods
|
||||
assert "POST" in config.allow_methods
|
||||
assert "PUT" in config.allow_methods
|
||||
assert "DELETE" in config.allow_methods
|
||||
assert "PATCH" in config.allow_methods
|
||||
assert config.allow_credentials is True
|
||||
assert "X-Request-ID" in config.expose_headers
|
||||
assert config.max_age == 3600
|
||||
|
||||
def test_create_permissive_cors_config(self):
|
||||
"""Test create_permissive_cors_config"""
|
||||
config = create_permissive_cors_config()
|
||||
assert "*" in config.allow_origins
|
||||
assert "GET" in config.allow_methods
|
||||
assert "POST" in config.allow_methods
|
||||
assert "*" in config.allow_headers
|
||||
assert config.allow_credentials is False
|
||||
assert "*" in config.expose_headers
|
||||
assert config.max_age == 86400
|
||||
345
tests/test_utils.py
Normal file
345
tests/test_utils.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""
|
||||
Tests for AITBC utility modules
|
||||
"""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, Mock
|
||||
|
||||
from aitbc.utils.paths import (
|
||||
get_data_path,
|
||||
get_config_path,
|
||||
get_log_path,
|
||||
get_repo_path,
|
||||
ensure_dir,
|
||||
ensure_file_dir,
|
||||
resolve_path,
|
||||
get_keystore_path,
|
||||
get_blockchain_data_path,
|
||||
get_marketplace_data_path,
|
||||
)
|
||||
from aitbc.utils.env import (
|
||||
get_env_var,
|
||||
get_required_env_var,
|
||||
get_bool_env_var,
|
||||
get_int_env_var,
|
||||
get_float_env_var,
|
||||
get_list_env_var,
|
||||
)
|
||||
from aitbc.utils.json_utils import (
|
||||
load_json,
|
||||
save_json,
|
||||
merge_json,
|
||||
json_to_string,
|
||||
string_to_json,
|
||||
get_nested_value,
|
||||
set_nested_value,
|
||||
flatten_json,
|
||||
)
|
||||
from aitbc.exceptions import ConfigurationError
|
||||
|
||||
|
||||
class TestPaths:
|
||||
"""Tests for path utility functions"""
|
||||
|
||||
def test_get_data_path_no_subpath(self):
|
||||
"""Test get_data_path without subpath"""
|
||||
result = get_data_path()
|
||||
assert isinstance(result, Path)
|
||||
|
||||
def test_get_data_path_with_subpath(self):
|
||||
"""Test get_data_path with subpath"""
|
||||
result = get_data_path("test")
|
||||
assert isinstance(result, Path)
|
||||
assert str(result).endswith("test")
|
||||
|
||||
def test_get_config_path(self):
|
||||
"""Test get_config_path"""
|
||||
result = get_config_path("config.yaml")
|
||||
assert isinstance(result, Path)
|
||||
assert str(result).endswith("config.yaml")
|
||||
|
||||
def test_get_log_path(self):
|
||||
"""Test get_log_path"""
|
||||
result = get_log_path("app.log")
|
||||
assert isinstance(result, Path)
|
||||
assert str(result).endswith("app.log")
|
||||
|
||||
def test_get_repo_path_no_subpath(self):
|
||||
"""Test get_repo_path without subpath"""
|
||||
result = get_repo_path()
|
||||
assert isinstance(result, Path)
|
||||
|
||||
def test_get_repo_path_with_subpath(self):
|
||||
"""Test get_repo_path with subpath"""
|
||||
result = get_repo_path("src")
|
||||
assert isinstance(result, Path)
|
||||
assert str(result).endswith("src")
|
||||
|
||||
def test_ensure_dir(self):
|
||||
"""Test ensure_dir creates directory"""
|
||||
import tempfile
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_path = Path(tmpdir) / "test" / "nested"
|
||||
result = ensure_dir(test_path)
|
||||
assert result.exists()
|
||||
assert result.is_dir()
|
||||
|
||||
def test_ensure_file_dir(self):
|
||||
"""Test ensure_file_dir creates parent directory"""
|
||||
import tempfile
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_path = Path(tmpdir) / "test" / "nested" / "file.txt"
|
||||
result = ensure_file_dir(test_path)
|
||||
assert result.exists()
|
||||
assert result.is_dir()
|
||||
|
||||
def test_resolve_path_absolute(self):
|
||||
"""Test resolve_path with absolute path"""
|
||||
result = resolve_path("/tmp/test")
|
||||
assert result.is_absolute()
|
||||
|
||||
def test_resolve_path_relative(self):
|
||||
"""Test resolve_path with relative path"""
|
||||
result = resolve_path("test")
|
||||
assert result.is_absolute() # Relative paths are resolved to absolute
|
||||
|
||||
def test_get_keystore_path_no_wallet(self):
|
||||
"""Test get_keystore_path without wallet name"""
|
||||
result = get_keystore_path()
|
||||
assert isinstance(result, Path)
|
||||
assert str(result).endswith("keystore")
|
||||
|
||||
def test_get_keystore_path_with_wallet(self):
|
||||
"""Test get_keystore_path with wallet name"""
|
||||
result = get_keystore_path("mywallet")
|
||||
assert isinstance(result, Path)
|
||||
assert str(result).endswith("mywallet.json")
|
||||
|
||||
def test_get_blockchain_data_path_default(self):
|
||||
"""Test get_blockchain_data_path with default chain"""
|
||||
result = get_blockchain_data_path()
|
||||
assert isinstance(result, Path)
|
||||
assert str(result).endswith("ait-mainnet")
|
||||
|
||||
def test_get_blockchain_data_path_custom(self):
|
||||
"""Test get_blockchain_data_path with custom chain"""
|
||||
result = get_blockchain_data_path("custom-chain")
|
||||
assert isinstance(result, Path)
|
||||
assert str(result).endswith("custom-chain")
|
||||
|
||||
def test_get_marketplace_data_path_no_subpath(self):
|
||||
"""Test get_marketplace_data_path without subpath"""
|
||||
result = get_marketplace_data_path()
|
||||
assert isinstance(result, Path)
|
||||
assert str(result).endswith("marketplace")
|
||||
|
||||
def test_get_marketplace_data_path_with_subpath(self):
|
||||
"""Test get_marketplace_data_path with subpath"""
|
||||
result = get_marketplace_data_path("orders")
|
||||
assert isinstance(result, Path)
|
||||
assert str(result).endswith("orders")
|
||||
|
||||
|
||||
class TestEnv:
|
||||
"""Tests for environment variable utilities"""
|
||||
|
||||
def test_get_env_var_with_value(self):
|
||||
"""Test get_env_var with set value"""
|
||||
os.environ["TEST_VAR"] = "test_value"
|
||||
result = get_env_var("TEST_VAR")
|
||||
assert result == "test_value"
|
||||
del os.environ["TEST_VAR"]
|
||||
|
||||
def test_get_env_var_with_default(self):
|
||||
"""Test get_env_var with default value"""
|
||||
result = get_env_var("NONEXISTENT_VAR", "default")
|
||||
assert result == "default"
|
||||
|
||||
def test_get_required_env_var_with_value(self):
|
||||
"""Test get_required_env_var with set value"""
|
||||
os.environ["TEST_VAR"] = "test_value"
|
||||
result = get_required_env_var("TEST_VAR")
|
||||
assert result == "test_value"
|
||||
del os.environ["TEST_VAR"]
|
||||
|
||||
def test_get_required_env_var_without_value(self):
|
||||
"""Test get_required_env_var without value raises error"""
|
||||
with pytest.raises(ConfigurationError):
|
||||
get_required_env_var("NONEXISTENT_VAR")
|
||||
|
||||
def test_get_bool_env_var_true(self):
|
||||
"""Test get_bool_env_var with true values"""
|
||||
for value in ["true", "TRUE", "1", "yes", "YES", "on", "ON"]:
|
||||
os.environ["TEST_VAR"] = value
|
||||
assert get_bool_env_var("TEST_VAR") is True
|
||||
del os.environ["TEST_VAR"]
|
||||
|
||||
def test_get_bool_env_var_false(self):
|
||||
"""Test get_bool_env_var with false values"""
|
||||
for value in ["false", "FALSE", "0", "no", "NO", "off", "OFF"]:
|
||||
os.environ["TEST_VAR"] = value
|
||||
assert get_bool_env_var("TEST_VAR") is False
|
||||
del os.environ["TEST_VAR"]
|
||||
|
||||
def test_get_bool_env_var_default(self):
|
||||
"""Test get_bool_env_var with default"""
|
||||
assert get_bool_env_var("NONEXISTENT_VAR", True) is True
|
||||
assert get_bool_env_var("NONEXISTENT_VAR", False) is False
|
||||
|
||||
def test_get_int_env_var_valid(self):
|
||||
"""Test get_int_env_var with valid integer"""
|
||||
os.environ["TEST_VAR"] = "42"
|
||||
assert get_int_env_var("TEST_VAR") == 42
|
||||
del os.environ["TEST_VAR"]
|
||||
|
||||
def test_get_int_env_var_invalid(self):
|
||||
"""Test get_int_env_var with invalid value returns default"""
|
||||
os.environ["TEST_VAR"] = "not_a_number"
|
||||
assert get_int_env_var("TEST_VAR", 10) == 10
|
||||
del os.environ["TEST_VAR"]
|
||||
|
||||
def test_get_int_env_var_default(self):
|
||||
"""Test get_int_env_var with default"""
|
||||
assert get_int_env_var("NONEXISTENT_VAR", 100) == 100
|
||||
|
||||
def test_get_float_env_var_valid(self):
|
||||
"""Test get_float_env_var with valid float"""
|
||||
os.environ["TEST_VAR"] = "3.14"
|
||||
assert get_float_env_var("TEST_VAR") == 3.14
|
||||
del os.environ["TEST_VAR"]
|
||||
|
||||
def test_get_float_env_var_invalid(self):
|
||||
"""Test get_float_env_var with invalid value returns default"""
|
||||
os.environ["TEST_VAR"] = "not_a_number"
|
||||
assert get_float_env_var("TEST_VAR", 2.5) == 2.5
|
||||
del os.environ["TEST_VAR"]
|
||||
|
||||
def test_get_float_env_var_default(self):
|
||||
"""Test get_float_env_var with default"""
|
||||
assert get_float_env_var("NONEXISTENT_VAR", 1.5) == 1.5
|
||||
|
||||
def test_get_list_env_var_valid(self):
|
||||
"""Test get_list_env_var with valid list"""
|
||||
os.environ["TEST_VAR"] = "item1,item2,item3"
|
||||
result = get_list_env_var("TEST_VAR")
|
||||
assert result == ["item1", "item2", "item3"]
|
||||
del os.environ["TEST_VAR"]
|
||||
|
||||
def test_get_list_env_var_custom_separator(self):
|
||||
"""Test get_list_env_var with custom separator"""
|
||||
os.environ["TEST_VAR"] = "item1;item2;item3"
|
||||
result = get_list_env_var("TEST_VAR", separator=";")
|
||||
assert result == ["item1", "item2", "item3"]
|
||||
del os.environ["TEST_VAR"]
|
||||
|
||||
def test_get_list_env_var_empty(self):
|
||||
"""Test get_list_env_var with empty value returns default"""
|
||||
os.environ["TEST_VAR"] = ""
|
||||
result = get_list_env_var("TEST_VAR", default=["default"])
|
||||
assert result == ["default"]
|
||||
del os.environ["TEST_VAR"]
|
||||
|
||||
def test_get_list_env_var_default(self):
|
||||
"""Test get_list_env_var with default"""
|
||||
result = get_list_env_var("NONEXISTENT_VAR", default=["a", "b"])
|
||||
assert result == ["a", "b"]
|
||||
|
||||
|
||||
class TestJsonUtils:
|
||||
"""Tests for JSON utility functions"""
|
||||
|
||||
def test_json_to_string(self):
|
||||
"""Test json_to_string"""
|
||||
data = {"key": "value"}
|
||||
result = json_to_string(data)
|
||||
assert '"key"' in result
|
||||
assert '"value"' in result
|
||||
|
||||
def test_string_to_json_valid(self):
|
||||
"""Test string_to_json with valid JSON"""
|
||||
json_str = '{"key": "value"}'
|
||||
result = string_to_json(json_str)
|
||||
assert result == {"key": "value"}
|
||||
|
||||
def test_string_to_json_invalid(self):
|
||||
"""Test string_to_json with invalid JSON raises error"""
|
||||
with pytest.raises(ConfigurationError):
|
||||
string_to_json("not valid json")
|
||||
|
||||
def test_get_nested_value_found(self):
|
||||
"""Test get_nested_value when key exists"""
|
||||
data = {"a": {"b": {"c": "value"}}}
|
||||
result = get_nested_value(data, "a", "b", "c")
|
||||
assert result == "value"
|
||||
|
||||
def test_get_nested_value_not_found(self):
|
||||
"""Test get_nested_value when key doesn't exist returns default"""
|
||||
data = {"a": {"b": {"c": "value"}}}
|
||||
result = get_nested_value(data, "a", "b", "d", default="default")
|
||||
assert result == "default"
|
||||
|
||||
def test_get_nested_value_default_none(self):
|
||||
"""Test get_nested_value with default None"""
|
||||
data = {"a": {"b": {"c": "value"}}}
|
||||
result = get_nested_value(data, "x", "y", "z")
|
||||
assert result is None
|
||||
|
||||
def test_set_nested_value(self):
|
||||
"""Test set_nested_value"""
|
||||
data = {}
|
||||
set_nested_value(data, "a", "b", "c", value="test")
|
||||
assert data["a"]["b"]["c"] == "test"
|
||||
|
||||
def test_flatten_json(self):
|
||||
"""Test flatten_json"""
|
||||
data = {"a": {"b": {"c": "value"}}, "d": "simple"}
|
||||
result = flatten_json(data)
|
||||
assert "a.b.c" in result
|
||||
assert result["a.b.c"] == "value"
|
||||
assert result["d"] == "simple"
|
||||
|
||||
def test_flatten_json_custom_separator(self):
|
||||
"""Test flatten_json with custom separator"""
|
||||
data = {"a": {"b": "value"}}
|
||||
result = flatten_json(data, separator="_")
|
||||
assert "a_b" in result
|
||||
assert result["a_b"] == "value"
|
||||
|
||||
def test_load_json(self, tmp_path):
|
||||
"""Test load_json"""
|
||||
test_file = tmp_path / "test.json"
|
||||
test_file.write_text('{"key": "value"}')
|
||||
result = load_json(test_file)
|
||||
assert result == {"key": "value"}
|
||||
|
||||
def test_load_json_not_found(self, tmp_path):
|
||||
"""Test load_json with non-existent file raises error"""
|
||||
with pytest.raises(ConfigurationError):
|
||||
load_json(tmp_path / "nonexistent.json")
|
||||
|
||||
def test_load_json_invalid(self, tmp_path):
|
||||
"""Test load_json with invalid JSON raises error"""
|
||||
test_file = tmp_path / "invalid.json"
|
||||
test_file.write_text("not valid json")
|
||||
with pytest.raises(ConfigurationError):
|
||||
load_json(test_file)
|
||||
|
||||
def test_save_json(self, tmp_path):
|
||||
"""Test save_json"""
|
||||
test_file = tmp_path / "test.json"
|
||||
data = {"key": "value"}
|
||||
save_json(data, test_file)
|
||||
assert test_file.exists()
|
||||
result = load_json(test_file)
|
||||
assert result == data
|
||||
|
||||
def test_merge_json(self, tmp_path):
|
||||
"""Test merge_json"""
|
||||
file1 = tmp_path / "file1.json"
|
||||
file2 = tmp_path / "file2.json"
|
||||
file1.write_text('{"a": 1, "b": 2}')
|
||||
file2.write_text('{"b": 3, "c": 4}')
|
||||
result = merge_json(file1, file2)
|
||||
assert result == {"a": 1, "b": 3, "c": 4}
|
||||
Reference in New Issue
Block a user