refactor: reorganize services into bounded contexts and implement async database support
Some checks failed
Deploy to Testnet / deploy-testnet (push) Has been cancelled
Documentation Validation / validate-docs (push) Has been cancelled
Documentation Validation / validate-policies-strict (push) Has been cancelled
Integration Tests / test-service-integration (push) Has been cancelled
Python Tests / test-python (push) Has been cancelled
Security Scanning / security-scan (push) Has been cancelled
API Endpoint Tests / test-api-endpoints (push) Has been cancelled
CLI Tests / test-cli (push) Has been cancelled
Package Tests / Python package - aitbc-agent-sdk (push) Has been cancelled
Package Tests / Python package - aitbc-core (push) Has been cancelled
Package Tests / Python package - aitbc-crypto (push) Has been cancelled
Package Tests / Python package - aitbc-sdk (push) Has been cancelled
Package Tests / JavaScript package - aitbc-sdk-js (push) Has been cancelled
Package Tests / JavaScript package - aitbc-token (push) Has been cancelled
Staking Tests / test-staking-service (push) Failing after 3s
Staking Tests / test-staking-integration (push) Has been skipped
Staking Tests / test-staking-contract (push) Has been skipped
Staking Tests / run-staking-test-runner (push) Has been skipped
Multi-Node Stress Testing / stress-test (push) Successful in 3s
Cross-Node Transaction Testing / transaction-test (push) Successful in 3s
Some checks failed
Deploy to Testnet / deploy-testnet (push) Has been cancelled
Documentation Validation / validate-docs (push) Has been cancelled
Documentation Validation / validate-policies-strict (push) Has been cancelled
Integration Tests / test-service-integration (push) Has been cancelled
Python Tests / test-python (push) Has been cancelled
Security Scanning / security-scan (push) Has been cancelled
API Endpoint Tests / test-api-endpoints (push) Has been cancelled
CLI Tests / test-cli (push) Has been cancelled
Package Tests / Python package - aitbc-agent-sdk (push) Has been cancelled
Package Tests / Python package - aitbc-core (push) Has been cancelled
Package Tests / Python package - aitbc-crypto (push) Has been cancelled
Package Tests / Python package - aitbc-sdk (push) Has been cancelled
Package Tests / JavaScript package - aitbc-sdk-js (push) Has been cancelled
Package Tests / JavaScript package - aitbc-token (push) Has been cancelled
Staking Tests / test-staking-service (push) Failing after 3s
Staking Tests / test-staking-integration (push) Has been skipped
Staking Tests / test-staking-contract (push) Has been skipped
Staking Tests / run-staking-test-runner (push) Has been skipped
Multi-Node Stress Testing / stress-test (push) Successful in 3s
Cross-Node Transaction Testing / transaction-test (push) Successful in 3s
- Moved services to bounded context packages: - adaptive_learning.py → ai_analytics/adaptive_learning.py - analytics_service.py → ai_analytics/analytics.py - dynamic_pricing_engine.py → trading_marketplace/dynamic_pricing.py - trading_service.py → trading_marketplace/trading.py - Implemented async database module (database_async.py): - Added async SQLAlchemy engine with connection pooling - Added
This commit is contained in:
@@ -1 +1,117 @@
|
|||||||
[{'adapter}': 'def _build_async_url(url: str', 'driver': 'str) -> str:', 'Convert sync URL to async URL."': 'sqlite:///path.db -> sqlite+aiosqlite:///path.db\n # postgresql:// -> postgresql+asyncpg://\n parts = url.split(', ', 1)\n if len(parts) == 2:\n return f': 'parts[0]'}, {'driver}': {'f': 'url'}, 'sqlite': 'For SQLite', 'connect_args={"check_same_thread': False}, {}]
|
"""Async database module with connection pooling for Coordinator API."""
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from .config import settings
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Global async engine and session factory
|
||||||
|
_async_engine = None
|
||||||
|
_async_session_factory = None
|
||||||
|
|
||||||
|
|
||||||
|
def _build_async_url(url: str) -> str:
|
||||||
|
"""Convert sync database URL to async URL.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
sqlite:///path.db -> sqlite+aiosqlite:///path.db
|
||||||
|
postgresql://user:pass@host/db -> postgresql+asyncpg://user:pass@host/db
|
||||||
|
"""
|
||||||
|
# Handle URL with query parameters
|
||||||
|
if '?' in url:
|
||||||
|
base, params = url.split('?', 1)
|
||||||
|
if base.startswith('sqlite:'):
|
||||||
|
return f"{base}+aiosqlite://?{params}"
|
||||||
|
elif base.startswith('postgresql:'):
|
||||||
|
return f"{base}+asyncpg://?{params}"
|
||||||
|
else:
|
||||||
|
return f"{base}+aiosqlite://?{params}" # fallback
|
||||||
|
else:
|
||||||
|
if url.startswith('sqlite:'):
|
||||||
|
return url.replace('sqlite:', 'sqlite+aiosqlite:')
|
||||||
|
elif url.startswith('postgresql:'):
|
||||||
|
return url.replace('postgresql:', 'postgresql+asyncpg:')
|
||||||
|
else:
|
||||||
|
return url.replace(':', '+aiosqlite:') # fallback
|
||||||
|
|
||||||
|
|
||||||
|
def init_async_db() -> None:
|
||||||
|
"""Initialize async database engine and session factory."""
|
||||||
|
global _async_engine, _async_session_factory
|
||||||
|
|
||||||
|
if _async_engine is not None:
|
||||||
|
logger.warning("Async database already initialized")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Build async URL from sync settings
|
||||||
|
sync_url = str(settings.database.url)
|
||||||
|
async_url = _build_async_url(sync_url)
|
||||||
|
|
||||||
|
logger.info(f"Initializing async database connection: {async_url.split('://')[0]}://...")
|
||||||
|
|
||||||
|
# Create async engine with pooling
|
||||||
|
_async_engine = create_async_engine(
|
||||||
|
async_url,
|
||||||
|
echo=settings.database.echo if hasattr(settings.database, 'echo') else False,
|
||||||
|
pool_size=getattr(settings.database, 'pool_size', 5),
|
||||||
|
max_overflow=getattr(settings.database, 'max_overflow', 10),
|
||||||
|
pool_pre_ping=getattr(settings.database, 'pool_pre_ping', True),
|
||||||
|
pool_recycle=getattr(settings.database, 'pool_recycle', 3600),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create session factory
|
||||||
|
_async_session_factory = sessionmaker(
|
||||||
|
_async_engine,
|
||||||
|
class_=AsyncSession,
|
||||||
|
expire_on_commit=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Async database initialized successfully")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize async database: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def get_async_db() -> AsyncSession:
|
||||||
|
"""Dependency to get async database session.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
AsyncSession: Database session that closes automatically
|
||||||
|
"""
|
||||||
|
if _async_session_factory is None:
|
||||||
|
raise RuntimeError("Async database not initialized. Call init_async_db() first.")
|
||||||
|
|
||||||
|
async def _get_async_db() -> AsyncSession:
|
||||||
|
async with _async_session_factory() as session:
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
return _get_async_db()
|
||||||
|
|
||||||
|
|
||||||
|
def get_sync_engine():
|
||||||
|
"""Get synchronous engine for backward compatibility.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Engine: SQLAlchemy synchronous engine
|
||||||
|
"""
|
||||||
|
from .database import engine
|
||||||
|
return engine
|
||||||
|
|
||||||
|
|
||||||
|
async def close_async_db() -> None:
|
||||||
|
"""Close async database connections."""
|
||||||
|
global _async_engine, _async_session_factory
|
||||||
|
|
||||||
|
if _async_engine is not None:
|
||||||
|
logger.info("Closing async database connections...")
|
||||||
|
await _async_engine.dispose()
|
||||||
|
_async_engine = None
|
||||||
|
_async_session_factory = None
|
||||||
|
logger.info("Async database connections closed")
|
||||||
@@ -114,6 +114,7 @@ logger = get_logger(__name__)
|
|||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from .storage.db import init_db
|
from .storage.db import init_db
|
||||||
|
from .database_async import init_async_db, close_async_db
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
@@ -130,6 +131,14 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]:
|
|||||||
logger.warning(f"Database initialization failed (non-fatal): {e}")
|
logger.warning(f"Database initialization failed (non-fatal): {e}")
|
||||||
# Continue startup even if init_db fails
|
# Continue startup even if init_db fails
|
||||||
|
|
||||||
|
# Initialize async database
|
||||||
|
try:
|
||||||
|
init_async_db()
|
||||||
|
logger.info("Async database initialized successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Async database initialization failed (non-fatal): {e}")
|
||||||
|
# Continue startup even if async init fails
|
||||||
|
|
||||||
# Warmup database connections
|
# Warmup database connections
|
||||||
logger.info("Warming up database connections...")
|
logger.info("Warming up database connections...")
|
||||||
try:
|
try:
|
||||||
@@ -227,6 +236,13 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error closing database connections: {e}")
|
logger.warning(f"Error closing database connections: {e}")
|
||||||
|
|
||||||
|
# Close async database connections
|
||||||
|
try:
|
||||||
|
await close_async_db()
|
||||||
|
logger.info("Async database connections closed successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error closing async database connections: {e}")
|
||||||
|
|
||||||
# Cleanup rate limiting state
|
# Cleanup rate limiting state
|
||||||
logger.info("Cleaning up rate limiting state...")
|
logger.info("Cleaning up rate limiting state...")
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from fastapi import APIRouter, Depends
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from aitbc import get_logger
|
from aitbc import get_logger
|
||||||
from ..services.adaptive_learning import AdaptiveLearningService
|
from ..services.ai_analytics.adaptive_learning import AdaptiveLearningService
|
||||||
from ..storage import get_session
|
from ..storage import get_session
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ from ..domain.analytics import (
|
|||||||
MetricType,
|
MetricType,
|
||||||
ReportType,
|
ReportType,
|
||||||
)
|
)
|
||||||
from ..services.analytics_service import MarketplaceAnalytics
|
from ..services.ai_analytics.analytics import MarketplaceAnalytics
|
||||||
from ..storage import get_session
|
from ..storage import get_session
|
||||||
|
|
||||||
router = APIRouter(prefix="/v1/analytics", tags=["analytics"])
|
router = APIRouter(prefix="/v1/analytics", tags=["analytics"])
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from ..schemas.pricing import (
|
|||||||
PricingStrategyRequest,
|
PricingStrategyRequest,
|
||||||
PricingStrategyResponse,
|
PricingStrategyResponse,
|
||||||
)
|
)
|
||||||
from ..services.dynamic_pricing_engine import DynamicPricingEngine, PriceConstraints, PricingStrategy, ResourceType
|
from ..services.trading_marketplace.dynamic_pricing import DynamicPricingEngine, PriceConstraints, PricingStrategy, ResourceType
|
||||||
from ..services.market_data_collector import MarketDataCollector
|
from ..services.market_data_collector import MarketDataCollector
|
||||||
|
|
||||||
router = APIRouter(prefix="/v1/pricing", tags=["dynamic-pricing"])
|
router = APIRouter(prefix="/v1/pricing", tags=["dynamic-pricing"])
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ from ..domain.trading import (
|
|||||||
TradeStatus,
|
TradeStatus,
|
||||||
TradeType,
|
TradeType,
|
||||||
)
|
)
|
||||||
from ..services.trading_service import P2PTradingProtocol
|
from ..services.trading_marketplace.trading import P2PTradingProtocol
|
||||||
from ..storage import get_session
|
from ..storage import get_session
|
||||||
|
|
||||||
router = APIRouter(prefix="/v1/trading", tags=["trading"])
|
router = APIRouter(prefix="/v1/trading", tags=["trading"])
|
||||||
|
|||||||
@@ -0,0 +1,14 @@
|
|||||||
|
"""
|
||||||
|
AI & Analytics Bounded Context
|
||||||
|
Provides analytics and advanced learning services.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .advanced_learning import AdvancedLearningService
|
||||||
|
from .analytics import AnalyticsEngine, DashboardManager, MarketplaceAnalytics
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AdvancedLearningService",
|
||||||
|
"AnalyticsEngine",
|
||||||
|
"DashboardManager",
|
||||||
|
"MarketplaceAnalytics",
|
||||||
|
]
|
||||||
@@ -17,7 +17,7 @@ from typing import Any
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from ..storage import get_session
|
from ...storage import get_session
|
||||||
|
|
||||||
|
|
||||||
class LearningAlgorithm(StrEnum):
|
class LearningAlgorithm(StrEnum):
|
||||||
@@ -13,7 +13,7 @@ logger = get_logger(__name__)
|
|||||||
|
|
||||||
from sqlmodel import Session, and_, select
|
from sqlmodel import Session, and_, select
|
||||||
|
|
||||||
from ..domain.analytics import (
|
from ...domain.analytics import (
|
||||||
AnalyticsAlert,
|
AnalyticsAlert,
|
||||||
AnalyticsPeriod,
|
AnalyticsPeriod,
|
||||||
DashboardConfig,
|
DashboardConfig,
|
||||||
@@ -1,608 +0,0 @@
|
|||||||
"""
|
|
||||||
Enterprise API Gateway - Phase 6.1 Implementation
|
|
||||||
Multi-tenant API routing and management for enterprise clients
|
|
||||||
Port: 8010
|
|
||||||
"""
|
|
||||||
|
|
||||||
import secrets
|
|
||||||
import time
|
|
||||||
from datetime import datetime, timezone, timedelta
|
|
||||||
from enum import StrEnum
|
|
||||||
from typing import Any
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
import jwt
|
|
||||||
from fastapi import Depends, FastAPI, HTTPException, Request
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from fastapi.security import HTTPBearer
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from aitbc import get_logger
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
from ..domain.multitenant import Tenant, TenantApiKey, TenantQuota
|
|
||||||
from ..exceptions import QuotaExceededError, TenantError
|
|
||||||
from ..storage.db import get_db
|
|
||||||
|
|
||||||
|
|
||||||
# Pydantic models for API requests/responses
|
|
||||||
class EnterpriseAuthRequest(BaseModel):
|
|
||||||
tenant_id: str = Field(..., description="Enterprise tenant identifier")
|
|
||||||
client_id: str = Field(..., description="Enterprise client ID")
|
|
||||||
client_secret: str = Field(..., description="Enterprise client secret")
|
|
||||||
auth_method: str = Field(default="client_credentials", description="Authentication method")
|
|
||||||
scopes: list[str] | None = Field(default=None, description="Requested scopes")
|
|
||||||
|
|
||||||
|
|
||||||
class EnterpriseAuthResponse(BaseModel):
|
|
||||||
access_token: str = Field(..., description="Access token for enterprise API")
|
|
||||||
token_type: str = Field(default="Bearer", description="Token type")
|
|
||||||
expires_in: int = Field(..., description="Token expiration in seconds")
|
|
||||||
refresh_token: str | None = Field(None, description="Refresh token")
|
|
||||||
scopes: list[str] = Field(..., description="Granted scopes")
|
|
||||||
tenant_info: dict[str, Any] = Field(..., description="Tenant information")
|
|
||||||
|
|
||||||
|
|
||||||
class APIQuotaRequest(BaseModel):
|
|
||||||
tenant_id: str = Field(..., description="Enterprise tenant identifier")
|
|
||||||
endpoint: str = Field(..., description="API endpoint")
|
|
||||||
method: str = Field(..., description="HTTP method")
|
|
||||||
quota_type: str = Field(default="rate_limit", description="Quota type")
|
|
||||||
|
|
||||||
|
|
||||||
class APIQuotaResponse(BaseModel):
|
|
||||||
quota_limit: int = Field(..., description="Quota limit")
|
|
||||||
quota_remaining: int = Field(..., description="Remaining quota")
|
|
||||||
quota_reset: datetime = Field(..., description="Quota reset time")
|
|
||||||
quota_type: str = Field(..., description="Quota type")
|
|
||||||
|
|
||||||
|
|
||||||
class WebhookConfig(BaseModel):
|
|
||||||
url: str = Field(..., description="Webhook URL")
|
|
||||||
events: list[str] = Field(..., description="Events to subscribe to")
|
|
||||||
secret: str | None = Field(None, description="Webhook secret")
|
|
||||||
active: bool = Field(default=True, description="Webhook active status")
|
|
||||||
retry_policy: dict[str, Any] | None = Field(None, description="Retry policy")
|
|
||||||
|
|
||||||
|
|
||||||
class EnterpriseIntegrationRequest(BaseModel):
|
|
||||||
integration_type: str = Field(..., description="Integration type (ERP, CRM, etc.)")
|
|
||||||
provider: str = Field(..., description="Integration provider")
|
|
||||||
configuration: dict[str, Any] = Field(..., description="Integration configuration")
|
|
||||||
credentials: dict[str, str] | None = Field(None, description="Integration credentials")
|
|
||||||
webhook_config: WebhookConfig | None = Field(None, description="Webhook configuration")
|
|
||||||
|
|
||||||
|
|
||||||
class EnterpriseMetrics(BaseModel):
|
|
||||||
api_calls_total: int = Field(..., description="Total API calls")
|
|
||||||
api_calls_successful: int = Field(..., description="Successful API calls")
|
|
||||||
average_response_time_ms: float = Field(..., description="Average response time")
|
|
||||||
error_rate_percent: float = Field(..., description="Error rate percentage")
|
|
||||||
quota_utilization_percent: float = Field(..., description="Quota utilization")
|
|
||||||
active_integrations: int = Field(..., description="Active integrations count")
|
|
||||||
|
|
||||||
|
|
||||||
class IntegrationStatus(StrEnum):
|
|
||||||
ACTIVE = "active"
|
|
||||||
INACTIVE = "inactive"
|
|
||||||
ERROR = "error"
|
|
||||||
PENDING = "pending"
|
|
||||||
|
|
||||||
|
|
||||||
class EnterpriseIntegration:
|
|
||||||
"""Enterprise integration configuration and management"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, integration_id: str, tenant_id: str, integration_type: str, provider: str, configuration: dict[str, Any]
|
|
||||||
):
|
|
||||||
self.integration_id = integration_id
|
|
||||||
self.tenant_id = tenant_id
|
|
||||||
self.integration_type = integration_type
|
|
||||||
self.provider = provider
|
|
||||||
self.configuration = configuration
|
|
||||||
self.status = IntegrationStatus.PENDING
|
|
||||||
self.created_at = datetime.now(timezone.utc)
|
|
||||||
self.last_updated = datetime.now(timezone.utc)
|
|
||||||
self.webhook_config = None
|
|
||||||
self.metrics = {"api_calls": 0, "errors": 0, "last_call": None}
|
|
||||||
|
|
||||||
|
|
||||||
class EnterpriseAPIGateway:
|
|
||||||
"""Enterprise API Gateway with multi-tenant support"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.tenant_service = None # Will be initialized with database session
|
|
||||||
self.active_tokens = {} # In-memory token storage (in production, use Redis)
|
|
||||||
self.rate_limiters = {} # Per-tenant rate limiters
|
|
||||||
self.webhooks = {} # Webhook configurations
|
|
||||||
self.integrations = {} # Enterprise integrations
|
|
||||||
self.api_metrics = {} # API performance metrics
|
|
||||||
|
|
||||||
# Default quotas
|
|
||||||
self.default_quotas = {
|
|
||||||
"rate_limit": 1000, # requests per minute
|
|
||||||
"daily_limit": 50000, # requests per day
|
|
||||||
"concurrent_limit": 100, # concurrent requests
|
|
||||||
}
|
|
||||||
|
|
||||||
# JWT configuration
|
|
||||||
self.jwt_secret = secrets.token_urlsafe(64)
|
|
||||||
self.jwt_algorithm = "HS256"
|
|
||||||
self.token_expiry = 3600 # 1 hour
|
|
||||||
|
|
||||||
async def authenticate_enterprise_client(self, request: EnterpriseAuthRequest, db_session) -> EnterpriseAuthResponse:
|
|
||||||
"""Authenticate enterprise client and issue access token"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Validate tenant and client credentials
|
|
||||||
tenant = await self._validate_tenant_credentials(
|
|
||||||
request.tenant_id, request.client_id, request.client_secret, db_session
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate access token
|
|
||||||
access_token = self._generate_access_token(
|
|
||||||
tenant_id=request.tenant_id, client_id=request.client_id, scopes=request.scopes or ["enterprise_api"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate refresh token
|
|
||||||
refresh_token = self._generate_refresh_token(request.tenant_id, request.client_id)
|
|
||||||
|
|
||||||
# Store token
|
|
||||||
self.active_tokens[access_token] = {
|
|
||||||
"tenant_id": request.tenant_id,
|
|
||||||
"client_id": request.client_id,
|
|
||||||
"scopes": request.scopes or ["enterprise_api"],
|
|
||||||
"expires_at": datetime.now(timezone.utc) + timedelta(seconds=self.token_expiry),
|
|
||||||
"refresh_token": refresh_token,
|
|
||||||
}
|
|
||||||
|
|
||||||
return EnterpriseAuthResponse(
|
|
||||||
access_token=access_token,
|
|
||||||
token_type="Bearer",
|
|
||||||
expires_in=self.token_expiry,
|
|
||||||
refresh_token=refresh_token,
|
|
||||||
scopes=request.scopes or ["enterprise_api"],
|
|
||||||
tenant_info={
|
|
||||||
"tenant_id": tenant.tenant_id,
|
|
||||||
"name": tenant.name,
|
|
||||||
"plan": tenant.plan,
|
|
||||||
"status": tenant.status.value,
|
|
||||||
"created_at": tenant.created_at.isoformat(),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Enterprise authentication failed: {e}")
|
|
||||||
raise HTTPException(status_code=401, detail="Authentication failed")
|
|
||||||
|
|
||||||
def _generate_access_token(self, tenant_id: str, client_id: str, scopes: list[str]) -> str:
|
|
||||||
"""Generate JWT access token"""
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"sub": f"{tenant_id}:{client_id}",
|
|
||||||
"scopes": scopes,
|
|
||||||
"iat": datetime.now(timezone.utc),
|
|
||||||
"exp": datetime.now(timezone.utc) + timedelta(seconds=self.token_expiry),
|
|
||||||
"type": "access",
|
|
||||||
}
|
|
||||||
|
|
||||||
return jwt.encode(payload, self.jwt_secret, algorithm=self.jwt_algorithm)
|
|
||||||
|
|
||||||
def _generate_refresh_token(self, tenant_id: str, client_id: str) -> str:
|
|
||||||
"""Generate refresh token"""
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"sub": f"{tenant_id}:{client_id}",
|
|
||||||
"iat": datetime.now(timezone.utc),
|
|
||||||
"exp": datetime.now(timezone.utc) + timedelta(days=30), # 30 days
|
|
||||||
"type": "refresh",
|
|
||||||
}
|
|
||||||
|
|
||||||
return jwt.encode(payload, self.jwt_secret, algorithm=self.jwt_algorithm)
|
|
||||||
|
|
||||||
async def _validate_tenant_credentials(self, tenant_id: str, client_id: str, client_secret: str, db_session) -> Tenant:
|
|
||||||
"""Validate tenant credentials"""
|
|
||||||
|
|
||||||
# Find tenant
|
|
||||||
tenant = db_session.query(Tenant).filter(Tenant.tenant_id == tenant_id).first()
|
|
||||||
if not tenant:
|
|
||||||
raise TenantError(f"Tenant {tenant_id} not found")
|
|
||||||
|
|
||||||
# Find API key
|
|
||||||
api_key = (
|
|
||||||
db_session.query(TenantApiKey)
|
|
||||||
.filter(TenantApiKey.tenant_id == tenant_id, TenantApiKey.client_id == client_id, TenantApiKey.is_active)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not api_key or not secrets.compare_digest(api_key.client_secret, client_secret):
|
|
||||||
raise TenantError("Invalid client credentials")
|
|
||||||
|
|
||||||
# Check tenant status
|
|
||||||
if tenant.status.value != "active":
|
|
||||||
raise TenantError(f"Tenant {tenant_id} is not active")
|
|
||||||
|
|
||||||
return tenant
|
|
||||||
|
|
||||||
async def check_api_quota(self, tenant_id: str, endpoint: str, method: str, db_session) -> APIQuotaResponse:
|
|
||||||
"""Check and enforce API quotas"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Get tenant quota
|
|
||||||
quota = await self._get_tenant_quota(tenant_id, db_session)
|
|
||||||
|
|
||||||
# Check rate limiting
|
|
||||||
current_usage = await self._get_current_usage(tenant_id, "rate_limit")
|
|
||||||
|
|
||||||
if current_usage >= quota["rate_limit"]:
|
|
||||||
raise QuotaExceededError("Rate limit exceeded")
|
|
||||||
|
|
||||||
# Update usage
|
|
||||||
await self._update_usage(tenant_id, "rate_limit", current_usage + 1)
|
|
||||||
|
|
||||||
return APIQuotaResponse(
|
|
||||||
quota_limit=quota["rate_limit"],
|
|
||||||
quota_remaining=quota["rate_limit"] - current_usage - 1,
|
|
||||||
quota_reset=datetime.now(timezone.utc) + timedelta(minutes=1),
|
|
||||||
quota_type="rate_limit",
|
|
||||||
)
|
|
||||||
|
|
||||||
except QuotaExceededError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Quota check failed: {e}")
|
|
||||||
raise HTTPException(status_code=500, detail="Quota check failed")
|
|
||||||
|
|
||||||
async def _get_tenant_quota(self, tenant_id: str, db_session) -> dict[str, int]:
|
|
||||||
"""Get tenant quota configuration"""
|
|
||||||
|
|
||||||
# Get tenant-specific quota
|
|
||||||
tenant_quota = db_session.query(TenantQuota).filter(TenantQuota.tenant_id == tenant_id).first()
|
|
||||||
|
|
||||||
if tenant_quota:
|
|
||||||
return {
|
|
||||||
"rate_limit": tenant_quota.rate_limit or self.default_quotas["rate_limit"],
|
|
||||||
"daily_limit": tenant_quota.daily_limit or self.default_quotas["daily_limit"],
|
|
||||||
"concurrent_limit": tenant_quota.concurrent_limit or self.default_quotas["concurrent_limit"],
|
|
||||||
}
|
|
||||||
|
|
||||||
return self.default_quotas
|
|
||||||
|
|
||||||
async def _get_current_usage(self, tenant_id: str, quota_type: str) -> int:
|
|
||||||
"""Get current quota usage"""
|
|
||||||
|
|
||||||
# In production, use Redis or database for persistent storage
|
|
||||||
|
|
||||||
if quota_type == "rate_limit":
|
|
||||||
# Get usage in the last minute
|
|
||||||
return len([t for t in self.rate_limiters.get(tenant_id, []) if datetime.now(timezone.utc) - t < timedelta(minutes=1)])
|
|
||||||
|
|
||||||
return 0
|
|
||||||
|
|
||||||
async def _update_usage(self, tenant_id: str, quota_type: str, usage: int):
|
|
||||||
"""Update quota usage"""
|
|
||||||
|
|
||||||
if quota_type == "rate_limit":
|
|
||||||
if tenant_id not in self.rate_limiters:
|
|
||||||
self.rate_limiters[tenant_id] = []
|
|
||||||
|
|
||||||
# Add current timestamp
|
|
||||||
self.rate_limiters[tenant_id].append(datetime.now(timezone.utc))
|
|
||||||
|
|
||||||
# Clean old entries (older than 1 minute)
|
|
||||||
cutoff = datetime.now(timezone.utc) - timedelta(minutes=1)
|
|
||||||
self.rate_limiters[tenant_id] = [t for t in self.rate_limiters[tenant_id] if t > cutoff]
|
|
||||||
|
|
||||||
async def create_enterprise_integration(
|
|
||||||
self, tenant_id: str, request: EnterpriseIntegrationRequest, db_session
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Create new enterprise integration"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Validate tenant
|
|
||||||
tenant = db_session.query(Tenant).filter(Tenant.tenant_id == tenant_id).first()
|
|
||||||
if not tenant:
|
|
||||||
raise TenantError(f"Tenant {tenant_id} not found")
|
|
||||||
|
|
||||||
# Create integration
|
|
||||||
integration_id = str(uuid4())
|
|
||||||
integration = EnterpriseIntegration(
|
|
||||||
integration_id=integration_id,
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
integration_type=request.integration_type,
|
|
||||||
provider=request.provider,
|
|
||||||
configuration=request.configuration,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store webhook configuration
|
|
||||||
if request.webhook_config:
|
|
||||||
integration.webhook_config = request.webhook_config.dict()
|
|
||||||
self.webhooks[integration_id] = request.webhook_config.dict()
|
|
||||||
|
|
||||||
# Store integration
|
|
||||||
self.integrations[integration_id] = integration
|
|
||||||
|
|
||||||
# Initialize integration
|
|
||||||
await self._initialize_integration(integration)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"integration_id": integration_id,
|
|
||||||
"status": integration.status.value,
|
|
||||||
"created_at": integration.created_at.isoformat(),
|
|
||||||
"configuration": integration.configuration,
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to create enterprise integration: {e}")
|
|
||||||
raise HTTPException(status_code=500, detail="Integration creation failed")
|
|
||||||
|
|
||||||
async def _initialize_integration(self, integration: EnterpriseIntegration):
|
|
||||||
"""Initialize enterprise integration"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Integration-specific initialization logic
|
|
||||||
if integration.integration_type.lower() == "erp":
|
|
||||||
await self._initialize_erp_integration(integration)
|
|
||||||
elif integration.integration_type.lower() == "crm":
|
|
||||||
await self._initialize_crm_integration(integration)
|
|
||||||
elif integration.integration_type.lower() == "bi":
|
|
||||||
await self._initialize_bi_integration(integration)
|
|
||||||
|
|
||||||
integration.status = IntegrationStatus.ACTIVE
|
|
||||||
integration.last_updated = datetime.now(timezone.utc)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Integration initialization failed: {e}")
|
|
||||||
integration.status = IntegrationStatus.ERROR
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def _initialize_erp_integration(self, integration: EnterpriseIntegration):
|
|
||||||
"""Initialize ERP integration"""
|
|
||||||
|
|
||||||
# ERP-specific initialization
|
|
||||||
provider = integration.provider.lower()
|
|
||||||
|
|
||||||
if provider == "sap":
|
|
||||||
await self._initialize_sap_integration(integration)
|
|
||||||
elif provider == "oracle":
|
|
||||||
await self._initialize_oracle_integration(integration)
|
|
||||||
elif provider == "microsoft":
|
|
||||||
await self._initialize_microsoft_integration(integration)
|
|
||||||
|
|
||||||
logger.info(f"ERP integration initialized: {integration.provider}")
|
|
||||||
|
|
||||||
async def _initialize_sap_integration(self, integration: EnterpriseIntegration):
|
|
||||||
"""Initialize SAP ERP integration"""
|
|
||||||
|
|
||||||
# SAP integration logic
|
|
||||||
config = integration.configuration
|
|
||||||
|
|
||||||
# Validate SAP configuration
|
|
||||||
required_fields = ["system_id", "client", "username", "password", "host"]
|
|
||||||
for field in required_fields:
|
|
||||||
if field not in config:
|
|
||||||
raise ValueError(f"SAP integration requires {field}")
|
|
||||||
|
|
||||||
# Test SAP connection
|
|
||||||
# In production, implement actual SAP connection testing
|
|
||||||
logger.info(f"SAP connection test successful for {integration.integration_id}")
|
|
||||||
|
|
||||||
async def get_enterprise_metrics(self, tenant_id: str, db_session) -> EnterpriseMetrics:
|
|
||||||
"""Get enterprise metrics and analytics"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Get API metrics
|
|
||||||
api_metrics = self.api_metrics.get(
|
|
||||||
tenant_id, {"total_calls": 0, "successful_calls": 0, "failed_calls": 0, "response_times": []}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Calculate metrics
|
|
||||||
total_calls = api_metrics["total_calls"]
|
|
||||||
successful_calls = api_metrics["successful_calls"]
|
|
||||||
failed_calls = api_metrics["failed_calls"]
|
|
||||||
|
|
||||||
average_response_time = (
|
|
||||||
sum(api_metrics["response_times"]) / len(api_metrics["response_times"])
|
|
||||||
if api_metrics["response_times"]
|
|
||||||
else 0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
error_rate = (failed_calls / total_calls * 100) if total_calls > 0 else 0.0
|
|
||||||
|
|
||||||
# Get quota utilization
|
|
||||||
current_usage = await self._get_current_usage(tenant_id, "rate_limit")
|
|
||||||
quota = await self._get_tenant_quota(tenant_id, db_session)
|
|
||||||
quota_utilization = (current_usage / quota["rate_limit"] * 100) if quota["rate_limit"] > 0 else 0.0
|
|
||||||
|
|
||||||
# Count active integrations
|
|
||||||
active_integrations = len(
|
|
||||||
[i for i in self.integrations.values() if i.tenant_id == tenant_id and i.status == IntegrationStatus.ACTIVE]
|
|
||||||
)
|
|
||||||
|
|
||||||
return EnterpriseMetrics(
|
|
||||||
api_calls_total=total_calls,
|
|
||||||
api_calls_successful=successful_calls,
|
|
||||||
average_response_time_ms=average_response_time,
|
|
||||||
error_rate_percent=error_rate,
|
|
||||||
quota_utilization_percent=quota_utilization,
|
|
||||||
active_integrations=active_integrations,
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to get enterprise metrics: {e}")
|
|
||||||
raise HTTPException(status_code=500, detail="Metrics retrieval failed")
|
|
||||||
|
|
||||||
async def record_api_call(self, tenant_id: str, endpoint: str, response_time: float, success: bool):
|
|
||||||
"""Record API call for metrics"""
|
|
||||||
|
|
||||||
if tenant_id not in self.api_metrics:
|
|
||||||
self.api_metrics[tenant_id] = {"total_calls": 0, "successful_calls": 0, "failed_calls": 0, "response_times": []}
|
|
||||||
|
|
||||||
metrics = self.api_metrics[tenant_id]
|
|
||||||
metrics["total_calls"] += 1
|
|
||||||
|
|
||||||
if success:
|
|
||||||
metrics["successful_calls"] += 1
|
|
||||||
else:
|
|
||||||
metrics["failed_calls"] += 1
|
|
||||||
|
|
||||||
metrics["response_times"].append(response_time)
|
|
||||||
|
|
||||||
# Keep only last 1000 response times
|
|
||||||
if len(metrics["response_times"]) > 1000:
|
|
||||||
metrics["response_times"] = metrics["response_times"][-1000:]
|
|
||||||
|
|
||||||
|
|
||||||
# FastAPI application
|
|
||||||
app = FastAPI(
|
|
||||||
title="Enterprise API Gateway",
|
|
||||||
description="Multi-tenant API routing and management for enterprise clients",
|
|
||||||
version="6.1.0",
|
|
||||||
docs_url="/docs",
|
|
||||||
redoc_url="/redoc",
|
|
||||||
)
|
|
||||||
|
|
||||||
# CORS middleware
|
|
||||||
app.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=["*"],
|
|
||||||
allow_credentials=True,
|
|
||||||
allow_methods=["*"],
|
|
||||||
allow_headers=["*"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Security
|
|
||||||
security = HTTPBearer()
|
|
||||||
|
|
||||||
# Global gateway instance
|
|
||||||
gateway = EnterpriseAPIGateway()
|
|
||||||
|
|
||||||
|
|
||||||
# Dependency for database session
|
|
||||||
async def get_db_session():
|
|
||||||
"""Get database session"""
|
|
||||||
|
|
||||||
async with get_db() as session:
|
|
||||||
yield session
|
|
||||||
|
|
||||||
|
|
||||||
# Middleware for API metrics
|
|
||||||
@app.middleware("http")
|
|
||||||
async def api_metrics_middleware(request: Request, call_next):
|
|
||||||
"""Middleware to record API metrics"""
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
# Extract tenant from token if available
|
|
||||||
tenant_id = None
|
|
||||||
authorization = request.headers.get("authorization")
|
|
||||||
if authorization and authorization.startswith("Bearer "):
|
|
||||||
token = authorization[7:]
|
|
||||||
token_data = gateway.active_tokens.get(token)
|
|
||||||
if token_data:
|
|
||||||
tenant_id = token_data["tenant_id"]
|
|
||||||
|
|
||||||
# Process request
|
|
||||||
response = await call_next(request)
|
|
||||||
|
|
||||||
# Record metrics
|
|
||||||
response_time = (time.time() - start_time) * 1000 # Convert to milliseconds
|
|
||||||
success = response.status_code < 400
|
|
||||||
|
|
||||||
if tenant_id:
|
|
||||||
await gateway.record_api_call(tenant_id, str(request.url.path), response_time, success)
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/enterprise/auth")
|
|
||||||
async def enterprise_auth(request: EnterpriseAuthRequest, db_session=Depends(get_db_session)):
|
|
||||||
"""Authenticate enterprise client"""
|
|
||||||
|
|
||||||
result = await gateway.authenticate_enterprise_client(request, db_session)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/enterprise/quota/check")
|
|
||||||
async def check_quota(request: APIQuotaRequest, db_session=Depends(get_db_session)):
|
|
||||||
"""Check API quota"""
|
|
||||||
|
|
||||||
result = await gateway.check_api_quota(request.tenant_id, request.endpoint, request.method, db_session)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/enterprise/integrations")
|
|
||||||
async def create_integration(request: EnterpriseIntegrationRequest, db_session=Depends(get_db_session)):
|
|
||||||
"""Create enterprise integration"""
|
|
||||||
|
|
||||||
# Extract tenant from token (in production, proper authentication)
|
|
||||||
tenant_id = "demo_tenant" # Placeholder
|
|
||||||
|
|
||||||
result = await gateway.create_enterprise_integration(tenant_id, request, db_session)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/enterprise/analytics")
|
|
||||||
async def get_analytics(db_session=Depends(get_db_session)):
|
|
||||||
"""Get enterprise analytics dashboard"""
|
|
||||||
|
|
||||||
# Extract tenant from token (in production, proper authentication)
|
|
||||||
tenant_id = "demo_tenant" # Placeholder
|
|
||||||
|
|
||||||
result = await gateway.get_enterprise_metrics(tenant_id, db_session)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/enterprise/status")
|
|
||||||
async def get_status():
|
|
||||||
"""Get enterprise gateway status"""
|
|
||||||
|
|
||||||
return {
|
|
||||||
"service": "Enterprise API Gateway",
|
|
||||||
"version": "6.1.0",
|
|
||||||
"port": 8010,
|
|
||||||
"status": "operational",
|
|
||||||
"active_tenants": len({token["tenant_id"] for token in gateway.active_tokens.values()}),
|
|
||||||
"active_integrations": len(gateway.integrations),
|
|
||||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
|
||||||
async def root():
|
|
||||||
"""Root endpoint"""
|
|
||||||
return {
|
|
||||||
"service": "Enterprise API Gateway",
|
|
||||||
"version": "6.1.0",
|
|
||||||
"port": 8010,
|
|
||||||
"capabilities": [
|
|
||||||
"Multi-tenant API Management",
|
|
||||||
"Enterprise Authentication",
|
|
||||||
"API Quota Management",
|
|
||||||
"Enterprise Integration Framework",
|
|
||||||
"Real-time Analytics",
|
|
||||||
],
|
|
||||||
"status": "operational",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
|
||||||
async def health_check():
|
|
||||||
"""Health check endpoint"""
|
|
||||||
return {
|
|
||||||
"status": "healthy",
|
|
||||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
|
||||||
"services": {
|
|
||||||
"api_gateway": "operational",
|
|
||||||
"authentication": "operational",
|
|
||||||
"quota_management": "operational",
|
|
||||||
"integration_framework": "operational",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import uvicorn
|
|
||||||
|
|
||||||
uvicorn.run(app, host="0.0.0.0", port=8010)
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,16 +1,14 @@
|
|||||||
"""
|
"""
|
||||||
Enterprise Integration Bounded Context
|
Enterprise Integration Bounded Context
|
||||||
Provides enterprise API gateway, security, load balancing, and integration services.
|
Provides enterprise integration, security, and load balancing services.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .api_gateway import EnterpriseAPIGateway
|
from .integration import EnterpriseIntegrationFramework
|
||||||
from .integration import EnterpriseIntegrationService
|
|
||||||
from .load_balancer import AdvancedLoadBalancer
|
from .load_balancer import AdvancedLoadBalancer
|
||||||
from .security import EnterpriseEncryption, HSMManager, ThreatDetectionSystem, ZeroTrustArchitecture
|
from .security import EnterpriseEncryption, HSMManager, ThreatDetectionSystem, ZeroTrustArchitecture
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"EnterpriseAPIGateway",
|
"EnterpriseIntegrationFramework",
|
||||||
"EnterpriseIntegrationService",
|
|
||||||
"AdvancedLoadBalancer",
|
"AdvancedLoadBalancer",
|
||||||
"EnterpriseEncryption",
|
"EnterpriseEncryption",
|
||||||
"HSMManager",
|
"HSMManager",
|
||||||
|
|||||||
@@ -1,770 +0,0 @@
|
|||||||
"""
|
|
||||||
Advanced Load Balancing - Phase 6.4 Implementation
|
|
||||||
Intelligent traffic distribution with AI-powered auto-scaling and performance optimization
|
|
||||||
"""
|
|
||||||
|
|
||||||
import statistics
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from datetime import datetime, timezone, timedelta
|
|
||||||
from enum import StrEnum
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from aitbc import get_logger
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class LoadBalancingAlgorithm(StrEnum):
|
|
||||||
"""Load balancing algorithms"""
|
|
||||||
|
|
||||||
ROUND_ROBIN = "round_robin"
|
|
||||||
WEIGHTED_ROUND_ROBIN = "weighted_round_robin"
|
|
||||||
LEAST_CONNECTIONS = "least_connections"
|
|
||||||
LEAST_RESPONSE_TIME = "least_response_time"
|
|
||||||
RESOURCE_BASED = "resource_based"
|
|
||||||
PREDICTIVE_AI = "predictive_ai"
|
|
||||||
ADAPTIVE = "adaptive"
|
|
||||||
|
|
||||||
|
|
||||||
class ScalingPolicy(StrEnum):
|
|
||||||
"""Auto-scaling policies"""
|
|
||||||
|
|
||||||
MANUAL = "manual"
|
|
||||||
THRESHOLD_BASED = "threshold_based"
|
|
||||||
PREDICTIVE = "predictive"
|
|
||||||
HYBRID = "hybrid"
|
|
||||||
|
|
||||||
|
|
||||||
class HealthStatus(StrEnum):
|
|
||||||
"""Health status"""
|
|
||||||
|
|
||||||
HEALTHY = "healthy"
|
|
||||||
UNHEALTHY = "unhealthy"
|
|
||||||
DRAINING = "draining"
|
|
||||||
MAINTENANCE = "maintenance"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BackendServer:
|
|
||||||
"""Backend server configuration"""
|
|
||||||
|
|
||||||
server_id: str
|
|
||||||
host: str
|
|
||||||
port: int
|
|
||||||
weight: float = 1.0
|
|
||||||
max_connections: int = 1000
|
|
||||||
current_connections: int = 0
|
|
||||||
cpu_usage: float = 0.0
|
|
||||||
memory_usage: float = 0.0
|
|
||||||
response_time_ms: float = 0.0
|
|
||||||
request_count: int = 0
|
|
||||||
error_count: int = 0
|
|
||||||
health_status: HealthStatus = HealthStatus.HEALTHY
|
|
||||||
last_health_check: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
|
||||||
capabilities: dict[str, Any] = field(default_factory=dict)
|
|
||||||
region: str = "default"
|
|
||||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ScalingMetric:
|
|
||||||
"""Scaling metric configuration"""
|
|
||||||
|
|
||||||
metric_name: str
|
|
||||||
threshold_min: float
|
|
||||||
threshold_max: float
|
|
||||||
scaling_factor: float
|
|
||||||
cooldown_period: timedelta
|
|
||||||
measurement_window: timedelta
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TrafficPattern:
|
|
||||||
"""Traffic pattern for predictive scaling"""
|
|
||||||
|
|
||||||
pattern_id: str
|
|
||||||
name: str
|
|
||||||
time_windows: list[dict[str, Any]] # List of time windows with expected load
|
|
||||||
day_of_week: int # 0-6 (Monday-Sunday)
|
|
||||||
seasonal_factor: float = 1.0
|
|
||||||
confidence_score: float = 0.0
|
|
||||||
|
|
||||||
|
|
||||||
class PredictiveScaler:
|
|
||||||
"""AI-powered predictive auto-scaling"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.traffic_history = []
|
|
||||||
self.scaling_predictions = {}
|
|
||||||
self.traffic_patterns = {}
|
|
||||||
self.model_weights = {}
|
|
||||||
self.logger = get_logger("predictive_scaler")
|
|
||||||
|
|
||||||
async def record_traffic(self, timestamp: datetime, request_count: int, response_time_ms: float, error_rate: float):
|
|
||||||
"""Record traffic metrics"""
|
|
||||||
|
|
||||||
traffic_record = {
|
|
||||||
"timestamp": timestamp,
|
|
||||||
"request_count": request_count,
|
|
||||||
"response_time_ms": response_time_ms,
|
|
||||||
"error_rate": error_rate,
|
|
||||||
"hour": timestamp.hour,
|
|
||||||
"day_of_week": timestamp.weekday(),
|
|
||||||
"day_of_month": timestamp.day,
|
|
||||||
"month": timestamp.month,
|
|
||||||
}
|
|
||||||
|
|
||||||
self.traffic_history.append(traffic_record)
|
|
||||||
|
|
||||||
# Keep only last 30 days of history
|
|
||||||
cutoff = datetime.now(timezone.utc) - timedelta(days=30)
|
|
||||||
self.traffic_history = [record for record in self.traffic_history if record["timestamp"] > cutoff]
|
|
||||||
|
|
||||||
# Update traffic patterns
|
|
||||||
await self._update_traffic_patterns()
|
|
||||||
|
|
||||||
async def _update_traffic_patterns(self):
|
|
||||||
"""Update traffic patterns based on historical data"""
|
|
||||||
|
|
||||||
if len(self.traffic_history) < 168: # Need at least 1 week of data
|
|
||||||
return
|
|
||||||
|
|
||||||
# Group by hour and day of week
|
|
||||||
patterns = {}
|
|
||||||
|
|
||||||
for record in self.traffic_history:
|
|
||||||
key = f"{record['day_of_week']}_{record['hour']}"
|
|
||||||
|
|
||||||
if key not in patterns:
|
|
||||||
patterns[key] = {"request_counts": [], "response_times": [], "error_rates": []}
|
|
||||||
|
|
||||||
patterns[key]["request_counts"].append(record["request_count"])
|
|
||||||
patterns[key]["response_times"].append(record["response_time_ms"])
|
|
||||||
patterns[key]["error_rates"].append(record["error_rate"])
|
|
||||||
|
|
||||||
# Calculate pattern statistics
|
|
||||||
for key, data in patterns.items():
|
|
||||||
day_of_week, hour = key.split("_")
|
|
||||||
|
|
||||||
pattern = TrafficPattern(
|
|
||||||
pattern_id=key,
|
|
||||||
name=f"Pattern Day {day_of_week} Hour {hour}",
|
|
||||||
time_windows=[
|
|
||||||
{
|
|
||||||
"hour": int(hour),
|
|
||||||
"avg_requests": statistics.mean(data["request_counts"]),
|
|
||||||
"max_requests": max(data["request_counts"]),
|
|
||||||
"min_requests": min(data["request_counts"]),
|
|
||||||
"std_requests": statistics.stdev(data["request_counts"]) if len(data["request_counts"]) > 1 else 0,
|
|
||||||
"avg_response_time": statistics.mean(data["response_times"]),
|
|
||||||
"avg_error_rate": statistics.mean(data["error_rates"]),
|
|
||||||
}
|
|
||||||
],
|
|
||||||
day_of_week=int(day_of_week),
|
|
||||||
confidence_score=min(len(data["request_counts"]) / 100, 1.0), # Confidence based on data points
|
|
||||||
)
|
|
||||||
|
|
||||||
self.traffic_patterns[key] = pattern
|
|
||||||
|
|
||||||
async def predict_traffic(self, prediction_window: timedelta = timedelta(hours=1)) -> dict[str, Any]:
|
|
||||||
"""Predict traffic for the next time window"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
current_time = datetime.now(timezone.utc)
|
|
||||||
current_time + prediction_window
|
|
||||||
|
|
||||||
# Get current pattern
|
|
||||||
current_pattern_key = f"{current_time.weekday()}_{current_time.hour}"
|
|
||||||
current_pattern = self.traffic_patterns.get(current_pattern_key)
|
|
||||||
|
|
||||||
if not current_pattern:
|
|
||||||
# Fallback to simple prediction
|
|
||||||
return await self._simple_prediction(prediction_window)
|
|
||||||
|
|
||||||
# Get historical data for similar time periods
|
|
||||||
similar_patterns = [
|
|
||||||
pattern
|
|
||||||
for pattern in self.traffic_patterns.values()
|
|
||||||
if pattern.day_of_week == current_time.weekday()
|
|
||||||
and abs(pattern.time_windows[0]["hour"] - current_time.hour) <= 2
|
|
||||||
]
|
|
||||||
|
|
||||||
if not similar_patterns:
|
|
||||||
return await self._simple_prediction(prediction_window)
|
|
||||||
|
|
||||||
# Calculate weighted prediction
|
|
||||||
total_weight = 0
|
|
||||||
weighted_requests = 0
|
|
||||||
weighted_response_time = 0
|
|
||||||
weighted_error_rate = 0
|
|
||||||
|
|
||||||
for pattern in similar_patterns:
|
|
||||||
weight = pattern.confidence_score
|
|
||||||
window_data = pattern.time_windows[0]
|
|
||||||
|
|
||||||
weighted_requests += window_data["avg_requests"] * weight
|
|
||||||
weighted_response_time += window_data["avg_response_time"] * weight
|
|
||||||
weighted_error_rate += window_data["avg_error_rate"] * weight
|
|
||||||
total_weight += weight
|
|
||||||
|
|
||||||
if total_weight > 0:
|
|
||||||
predicted_requests = weighted_requests / total_weight
|
|
||||||
predicted_response_time = weighted_response_time / total_weight
|
|
||||||
predicted_error_rate = weighted_error_rate / total_weight
|
|
||||||
else:
|
|
||||||
return await self._simple_prediction(prediction_window)
|
|
||||||
|
|
||||||
# Apply seasonal factors
|
|
||||||
seasonal_factor = self._get_seasonal_factor(current_time)
|
|
||||||
predicted_requests *= seasonal_factor
|
|
||||||
|
|
||||||
return {
|
|
||||||
"prediction_window_hours": prediction_window.total_seconds() / 3600,
|
|
||||||
"predicted_requests_per_hour": int(predicted_requests),
|
|
||||||
"predicted_response_time_ms": predicted_response_time,
|
|
||||||
"predicted_error_rate": predicted_error_rate,
|
|
||||||
"confidence_score": min(total_weight / len(similar_patterns), 1.0),
|
|
||||||
"seasonal_factor": seasonal_factor,
|
|
||||||
"pattern_based": True,
|
|
||||||
"prediction_timestamp": current_time.isoformat(),
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Traffic prediction failed: {e}")
|
|
||||||
return await self._simple_prediction(prediction_window)
|
|
||||||
|
|
||||||
async def _simple_prediction(self, prediction_window: timedelta) -> dict[str, Any]:
|
|
||||||
"""Simple prediction based on recent averages"""
|
|
||||||
|
|
||||||
if not self.traffic_history:
|
|
||||||
return {
|
|
||||||
"prediction_window_hours": prediction_window.total_seconds() / 3600,
|
|
||||||
"predicted_requests_per_hour": 1000, # Default
|
|
||||||
"predicted_response_time_ms": 100.0,
|
|
||||||
"predicted_error_rate": 0.01,
|
|
||||||
"confidence_score": 0.1,
|
|
||||||
"pattern_based": False,
|
|
||||||
"prediction_timestamp": datetime.now(timezone.utc).isoformat(),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Calculate recent averages
|
|
||||||
recent_records = self.traffic_history[-24:] # Last 24 records
|
|
||||||
|
|
||||||
avg_requests = statistics.mean([r["request_count"] for r in recent_records])
|
|
||||||
avg_response_time = statistics.mean([r["response_time_ms"] for r in recent_records])
|
|
||||||
avg_error_rate = statistics.mean([r["error_rate"] for r in recent_records])
|
|
||||||
|
|
||||||
return {
|
|
||||||
"prediction_window_hours": prediction_window.total_seconds() / 3600,
|
|
||||||
"predicted_requests_per_hour": int(avg_requests),
|
|
||||||
"predicted_response_time_ms": avg_response_time,
|
|
||||||
"predicted_error_rate": avg_error_rate,
|
|
||||||
"confidence_score": 0.3,
|
|
||||||
"pattern_based": False,
|
|
||||||
"prediction_timestamp": datetime.now(timezone.utc).isoformat(),
|
|
||||||
}
|
|
||||||
|
|
||||||
def _get_seasonal_factor(self, timestamp: datetime) -> float:
|
|
||||||
"""Get seasonal adjustment factor"""
|
|
||||||
|
|
||||||
# Simple seasonal factors (can be enhanced with more sophisticated models)
|
|
||||||
month = timestamp.month
|
|
||||||
|
|
||||||
seasonal_factors = {
|
|
||||||
1: 0.8, # January - post-holiday dip
|
|
||||||
2: 0.9, # February
|
|
||||||
3: 1.0, # March
|
|
||||||
4: 1.1, # April - spring increase
|
|
||||||
5: 1.2, # May
|
|
||||||
6: 1.1, # June
|
|
||||||
7: 1.0, # July - summer
|
|
||||||
8: 0.9, # August
|
|
||||||
9: 1.1, # September - back to business
|
|
||||||
10: 1.2, # October
|
|
||||||
11: 1.3, # November - holiday season start
|
|
||||||
12: 1.4, # December - peak holiday season
|
|
||||||
}
|
|
||||||
|
|
||||||
return seasonal_factors.get(month, 1.0)
|
|
||||||
|
|
||||||
async def get_scaling_recommendation(self, current_servers: int, current_capacity: int) -> dict[str, Any]:
|
|
||||||
"""Get scaling recommendation based on predictions"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Get traffic prediction
|
|
||||||
prediction = await self.predict_traffic(timedelta(hours=1))
|
|
||||||
|
|
||||||
predicted_requests = prediction["predicted_requests_per_hour"]
|
|
||||||
current_capacity_per_server = current_capacity // max(current_servers, 1)
|
|
||||||
|
|
||||||
# Calculate required servers
|
|
||||||
required_servers = max(1, int(predicted_requests / current_capacity_per_server))
|
|
||||||
|
|
||||||
# Apply buffer (20% extra capacity)
|
|
||||||
required_servers = int(required_servers * 1.2)
|
|
||||||
|
|
||||||
scaling_action = "none"
|
|
||||||
if required_servers > current_servers:
|
|
||||||
scaling_action = "scale_up"
|
|
||||||
scale_to = required_servers
|
|
||||||
elif required_servers < current_servers * 0.7: # Scale down if underutilized
|
|
||||||
scaling_action = "scale_down"
|
|
||||||
scale_to = max(1, required_servers)
|
|
||||||
else:
|
|
||||||
scale_to = current_servers
|
|
||||||
|
|
||||||
return {
|
|
||||||
"current_servers": current_servers,
|
|
||||||
"recommended_servers": scale_to,
|
|
||||||
"scaling_action": scaling_action,
|
|
||||||
"predicted_load": predicted_requests,
|
|
||||||
"current_capacity_per_server": current_capacity_per_server,
|
|
||||||
"confidence_score": prediction["confidence_score"],
|
|
||||||
"reason": f"Predicted {predicted_requests} requests/hour vs current capacity {current_servers * current_capacity_per_server}",
|
|
||||||
"recommendation_timestamp": datetime.now(timezone.utc).isoformat(),
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Scaling recommendation failed: {e}")
|
|
||||||
return {
|
|
||||||
"scaling_action": "none",
|
|
||||||
"reason": f"Prediction failed: {str(e)}",
|
|
||||||
"recommendation_timestamp": datetime.now(timezone.utc).isoformat(),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class AdvancedLoadBalancer:
|
|
||||||
"""Advanced load balancer with multiple algorithms and AI optimization"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.backends = {}
|
|
||||||
self.algorithm = LoadBalancingAlgorithm.ADAPTIVE
|
|
||||||
self.current_index = 0
|
|
||||||
self.request_history = []
|
|
||||||
self.performance_metrics = {}
|
|
||||||
self.predictive_scaler = PredictiveScaler()
|
|
||||||
self.scaling_metrics = {}
|
|
||||||
self.logger = get_logger("advanced_load_balancer")
|
|
||||||
|
|
||||||
async def add_backend(self, server: BackendServer) -> bool:
|
|
||||||
"""Add backend server"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.backends[server.server_id] = server
|
|
||||||
|
|
||||||
# Initialize performance metrics
|
|
||||||
self.performance_metrics[server.server_id] = {
|
|
||||||
"avg_response_time": 0.0,
|
|
||||||
"error_rate": 0.0,
|
|
||||||
"throughput": 0.0,
|
|
||||||
"uptime": 1.0,
|
|
||||||
"last_updated": datetime.now(timezone.utc),
|
|
||||||
}
|
|
||||||
|
|
||||||
self.logger.info(f"Backend server added: {server.server_id}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Failed to add backend server: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def remove_backend(self, server_id: str) -> bool:
|
|
||||||
"""Remove backend server"""
|
|
||||||
|
|
||||||
if server_id in self.backends:
|
|
||||||
del self.backends[server_id]
|
|
||||||
del self.performance_metrics[server_id]
|
|
||||||
|
|
||||||
self.logger.info(f"Backend server removed: {server_id}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def select_backend(self, request_context: dict[str, Any] | None = None) -> str | None:
|
|
||||||
"""Select backend server based on algorithm"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Filter healthy backends
|
|
||||||
healthy_backends = {
|
|
||||||
sid: server for sid, server in self.backends.items() if server.health_status == HealthStatus.HEALTHY
|
|
||||||
}
|
|
||||||
|
|
||||||
if not healthy_backends:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Select backend based on algorithm
|
|
||||||
if self.algorithm == LoadBalancingAlgorithm.ROUND_ROBIN:
|
|
||||||
return await self._select_round_robin(healthy_backends)
|
|
||||||
elif self.algorithm == LoadBalancingAlgorithm.WEIGHTED_ROUND_ROBIN:
|
|
||||||
return await self._select_weighted_round_robin(healthy_backends)
|
|
||||||
elif self.algorithm == LoadBalancingAlgorithm.LEAST_CONNECTIONS:
|
|
||||||
return await self._select_least_connections(healthy_backends)
|
|
||||||
elif self.algorithm == LoadBalancingAlgorithm.LEAST_RESPONSE_TIME:
|
|
||||||
return await self._select_least_response_time(healthy_backends)
|
|
||||||
elif self.algorithm == LoadBalancingAlgorithm.RESOURCE_BASED:
|
|
||||||
return await self._select_resource_based(healthy_backends)
|
|
||||||
elif self.algorithm == LoadBalancingAlgorithm.PREDICTIVE_AI:
|
|
||||||
return await self._select_predictive_ai(healthy_backends, request_context)
|
|
||||||
elif self.algorithm == LoadBalancingAlgorithm.ADAPTIVE:
|
|
||||||
return await self._select_adaptive(healthy_backends, request_context)
|
|
||||||
else:
|
|
||||||
return await self._select_round_robin(healthy_backends)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Backend selection failed: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _select_round_robin(self, backends: dict[str, BackendServer]) -> str:
|
|
||||||
"""Round robin selection"""
|
|
||||||
|
|
||||||
backend_ids = list(backends.keys())
|
|
||||||
|
|
||||||
if not backend_ids:
|
|
||||||
return None
|
|
||||||
|
|
||||||
selected = backend_ids[self.current_index % len(backend_ids)]
|
|
||||||
self.current_index += 1
|
|
||||||
|
|
||||||
return selected
|
|
||||||
|
|
||||||
async def _select_weighted_round_robin(self, backends: dict[str, BackendServer]) -> str:
|
|
||||||
"""Weighted round robin selection"""
|
|
||||||
|
|
||||||
# Calculate total weight
|
|
||||||
total_weight = sum(server.weight for server in backends.values())
|
|
||||||
|
|
||||||
if total_weight <= 0:
|
|
||||||
return await self._select_round_robin(backends)
|
|
||||||
|
|
||||||
# Select based on weights
|
|
||||||
import random
|
|
||||||
|
|
||||||
rand_value = random.uniform(0, total_weight)
|
|
||||||
|
|
||||||
current_weight = 0
|
|
||||||
for server_id, server in backends.items():
|
|
||||||
current_weight += server.weight
|
|
||||||
if rand_value <= current_weight:
|
|
||||||
return server_id
|
|
||||||
|
|
||||||
# Fallback
|
|
||||||
return list(backends.keys())[0]
|
|
||||||
|
|
||||||
async def _select_least_connections(self, backends: dict[str, BackendServer]) -> str:
|
|
||||||
"""Select backend with least connections"""
|
|
||||||
|
|
||||||
min_connections = float("inf")
|
|
||||||
selected_backend = None
|
|
||||||
|
|
||||||
for server_id, server in backends.items():
|
|
||||||
if server.current_connections < min_connections:
|
|
||||||
min_connections = server.current_connections
|
|
||||||
selected_backend = server_id
|
|
||||||
|
|
||||||
return selected_backend
|
|
||||||
|
|
||||||
async def _select_least_response_time(self, backends: dict[str, BackendServer]) -> str:
|
|
||||||
"""Select backend with least response time"""
|
|
||||||
|
|
||||||
min_response_time = float("inf")
|
|
||||||
selected_backend = None
|
|
||||||
|
|
||||||
for server_id, server in backends.items():
|
|
||||||
if server.response_time_ms < min_response_time:
|
|
||||||
min_response_time = server.response_time_ms
|
|
||||||
selected_backend = server_id
|
|
||||||
|
|
||||||
return selected_backend
|
|
||||||
|
|
||||||
async def _select_resource_based(self, backends: dict[str, BackendServer]) -> str:
|
|
||||||
"""Select backend based on resource utilization"""
|
|
||||||
|
|
||||||
best_score = -1
|
|
||||||
selected_backend = None
|
|
||||||
|
|
||||||
for server_id, server in backends.items():
|
|
||||||
# Calculate resource score (lower is better)
|
|
||||||
cpu_score = 1.0 - (server.cpu_usage / 100.0)
|
|
||||||
memory_score = 1.0 - (server.memory_usage / 100.0)
|
|
||||||
connection_score = 1.0 - (server.current_connections / server.max_connections)
|
|
||||||
|
|
||||||
# Weighted score
|
|
||||||
resource_score = cpu_score * 0.4 + memory_score * 0.3 + connection_score * 0.3
|
|
||||||
|
|
||||||
if resource_score > best_score:
|
|
||||||
best_score = resource_score
|
|
||||||
selected_backend = server_id
|
|
||||||
|
|
||||||
return selected_backend
|
|
||||||
|
|
||||||
async def _select_predictive_ai(
|
|
||||||
self, backends: dict[str, BackendServer], request_context: dict[str, Any] | None
|
|
||||||
) -> str:
|
|
||||||
"""AI-powered predictive selection"""
|
|
||||||
|
|
||||||
# Get performance predictions for each backend
|
|
||||||
backend_scores = {}
|
|
||||||
|
|
||||||
for server_id, server in backends.items():
|
|
||||||
# Predict performance based on historical data
|
|
||||||
self.performance_metrics.get(server_id, {})
|
|
||||||
|
|
||||||
# Calculate predicted response time
|
|
||||||
predicted_response_time = (
|
|
||||||
server.response_time_ms
|
|
||||||
* (1 + server.cpu_usage / 100)
|
|
||||||
* (1 + server.memory_usage / 100)
|
|
||||||
* (1 + server.current_connections / server.max_connections)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Calculate score (lower response time is better)
|
|
||||||
score = 1.0 / (1.0 + predicted_response_time / 100.0)
|
|
||||||
|
|
||||||
# Apply context-based adjustments
|
|
||||||
if request_context:
|
|
||||||
# Consider request type, user location, etc.
|
|
||||||
context_multiplier = await self._calculate_context_multiplier(server, request_context)
|
|
||||||
score *= context_multiplier
|
|
||||||
|
|
||||||
backend_scores[server_id] = score
|
|
||||||
|
|
||||||
# Select best scoring backend
|
|
||||||
if backend_scores:
|
|
||||||
return max(backend_scores, key=backend_scores.get)
|
|
||||||
|
|
||||||
return await self._select_least_connections(backends)
|
|
||||||
|
|
||||||
async def _select_adaptive(self, backends: dict[str, BackendServer], request_context: dict[str, Any] | None) -> str:
|
|
||||||
"""Adaptive selection based on current conditions"""
|
|
||||||
|
|
||||||
# Analyze current system state
|
|
||||||
total_connections = sum(server.current_connections for server in backends.values())
|
|
||||||
avg_response_time = statistics.mean([server.response_time_ms for server in backends.values()])
|
|
||||||
|
|
||||||
# Choose algorithm based on conditions
|
|
||||||
if total_connections > sum(server.max_connections for server in backends.values()) * 0.8:
|
|
||||||
# High load - use resource-based
|
|
||||||
return await self._select_resource_based(backends)
|
|
||||||
elif avg_response_time > 200:
|
|
||||||
# High latency - use least response time
|
|
||||||
return await self._select_least_response_time(backends)
|
|
||||||
else:
|
|
||||||
# Normal conditions - use weighted round robin
|
|
||||||
return await self._select_weighted_round_robin(backends)
|
|
||||||
|
|
||||||
async def _calculate_context_multiplier(self, server: BackendServer, request_context: dict[str, Any]) -> float:
|
|
||||||
"""Calculate context-based multiplier for backend selection"""
|
|
||||||
|
|
||||||
multiplier = 1.0
|
|
||||||
|
|
||||||
# Consider geographic location
|
|
||||||
if "user_location" in request_context and "region" in server.capabilities:
|
|
||||||
user_region = request_context["user_location"].get("region")
|
|
||||||
server_region = server.capabilities["region"]
|
|
||||||
|
|
||||||
if user_region == server_region:
|
|
||||||
multiplier *= 1.2 # Prefer same region
|
|
||||||
elif self._regions_in_same_continent(user_region, server_region):
|
|
||||||
multiplier *= 1.1 # Slight preference for same continent
|
|
||||||
|
|
||||||
# Consider request type
|
|
||||||
request_type = request_context.get("request_type", "general")
|
|
||||||
server_specializations = server.capabilities.get("specializations", [])
|
|
||||||
|
|
||||||
if request_type in server_specializations:
|
|
||||||
multiplier *= 1.3 # Strong preference for specialized backends
|
|
||||||
|
|
||||||
# Consider user tier
|
|
||||||
user_tier = request_context.get("user_tier", "standard")
|
|
||||||
if user_tier == "premium" and server.capabilities.get("premium_support", False):
|
|
||||||
multiplier *= 1.15
|
|
||||||
|
|
||||||
return multiplier
|
|
||||||
|
|
||||||
def _regions_in_same_continent(self, region1: str, region2: str) -> bool:
|
|
||||||
"""Check if two regions are in the same continent"""
|
|
||||||
|
|
||||||
continent_mapping = {
|
|
||||||
"NA": ["US", "CA", "MX"],
|
|
||||||
"EU": ["GB", "DE", "FR", "IT", "ES", "NL", "BE", "AT", "CH", "SE", "NO", "DK", "FI"],
|
|
||||||
"APAC": ["JP", "KR", "SG", "AU", "IN", "TH", "MY", "ID", "PH", "VN"],
|
|
||||||
"LATAM": ["BR", "MX", "AR", "CL", "CO", "PE", "VE"],
|
|
||||||
}
|
|
||||||
|
|
||||||
for _continent, regions in continent_mapping.items():
|
|
||||||
if region1 in regions and region2 in regions:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def record_request(
|
|
||||||
self, server_id: str, response_time_ms: float, success: bool, timestamp: datetime | None = None
|
|
||||||
):
|
|
||||||
"""Record request metrics"""
|
|
||||||
|
|
||||||
if timestamp is None:
|
|
||||||
timestamp = datetime.now(timezone.utc)
|
|
||||||
|
|
||||||
# Update backend server metrics
|
|
||||||
if server_id in self.backends:
|
|
||||||
server = self.backends[server_id]
|
|
||||||
server.request_count += 1
|
|
||||||
server.response_time_ms = server.response_time_ms * 0.9 + response_time_ms * 0.1 # EMA
|
|
||||||
|
|
||||||
if not success:
|
|
||||||
server.error_count += 1
|
|
||||||
|
|
||||||
# Record in history
|
|
||||||
request_record = {
|
|
||||||
"timestamp": timestamp,
|
|
||||||
"server_id": server_id,
|
|
||||||
"response_time_ms": response_time_ms,
|
|
||||||
"success": success,
|
|
||||||
}
|
|
||||||
|
|
||||||
self.request_history.append(request_record)
|
|
||||||
|
|
||||||
# Keep only last 10000 records
|
|
||||||
if len(self.request_history) > 10000:
|
|
||||||
self.request_history = self.request_history[-10000:]
|
|
||||||
|
|
||||||
# Update predictive scaler
|
|
||||||
await self.predictive_scaler.record_traffic(
|
|
||||||
timestamp, 1, response_time_ms, 0.0 if success else 1.0 # One request # Error rate
|
|
||||||
)
|
|
||||||
|
|
||||||
async def update_backend_health(
|
|
||||||
self, server_id: str, health_status: HealthStatus, cpu_usage: float, memory_usage: float, current_connections: int
|
|
||||||
):
|
|
||||||
"""Update backend health metrics"""
|
|
||||||
|
|
||||||
if server_id in self.backends:
|
|
||||||
server = self.backends[server_id]
|
|
||||||
server.health_status = health_status
|
|
||||||
server.cpu_usage = cpu_usage
|
|
||||||
server.memory_usage = memory_usage
|
|
||||||
server.current_connections = current_connections
|
|
||||||
server.last_health_check = datetime.now(timezone.utc)
|
|
||||||
|
|
||||||
async def get_load_balancing_metrics(self) -> dict[str, Any]:
|
|
||||||
"""Get comprehensive load balancing metrics"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
total_requests = sum(server.request_count for server in self.backends.values())
|
|
||||||
total_errors = sum(server.error_count for server in self.backends.values())
|
|
||||||
total_connections = sum(server.current_connections for server in self.backends.values())
|
|
||||||
|
|
||||||
error_rate = (total_errors / total_requests) if total_requests > 0 else 0.0
|
|
||||||
|
|
||||||
# Calculate average response time
|
|
||||||
avg_response_time = 0.0
|
|
||||||
if self.backends:
|
|
||||||
avg_response_time = statistics.mean([server.response_time_ms for server in self.backends.values()])
|
|
||||||
|
|
||||||
# Backend distribution
|
|
||||||
backend_distribution = {}
|
|
||||||
for server_id, server in self.backends.items():
|
|
||||||
backend_distribution[server_id] = {
|
|
||||||
"requests": server.request_count,
|
|
||||||
"errors": server.error_count,
|
|
||||||
"connections": server.current_connections,
|
|
||||||
"response_time_ms": server.response_time_ms,
|
|
||||||
"cpu_usage": server.cpu_usage,
|
|
||||||
"memory_usage": server.memory_usage,
|
|
||||||
"health_status": server.health_status.value,
|
|
||||||
"weight": server.weight,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get scaling recommendation
|
|
||||||
scaling_recommendation = await self.predictive_scaler.get_scaling_recommendation(
|
|
||||||
len(self.backends), sum(server.max_connections for server in self.backends.values())
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"total_backends": len(self.backends),
|
|
||||||
"healthy_backends": len([s for s in self.backends.values() if s.health_status == HealthStatus.HEALTHY]),
|
|
||||||
"total_requests": total_requests,
|
|
||||||
"total_errors": total_errors,
|
|
||||||
"error_rate": error_rate,
|
|
||||||
"average_response_time_ms": avg_response_time,
|
|
||||||
"total_connections": total_connections,
|
|
||||||
"algorithm": self.algorithm.value,
|
|
||||||
"backend_distribution": backend_distribution,
|
|
||||||
"scaling_recommendation": scaling_recommendation,
|
|
||||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Metrics retrieval failed: {e}")
|
|
||||||
return {"error": str(e)}
|
|
||||||
|
|
||||||
async def set_algorithm(self, algorithm: LoadBalancingAlgorithm):
|
|
||||||
"""Set load balancing algorithm"""
|
|
||||||
|
|
||||||
self.algorithm = algorithm
|
|
||||||
self.logger.info(f"Load balancing algorithm changed to: {algorithm.value}")
|
|
||||||
|
|
||||||
async def auto_scale(self, min_servers: int = 1, max_servers: int = 10) -> dict[str, Any]:
|
|
||||||
"""Perform auto-scaling based on predictions"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Get scaling recommendation
|
|
||||||
recommendation = await self.predictive_scaler.get_scaling_recommendation(
|
|
||||||
len(self.backends), sum(server.max_connections for server in self.backends.values())
|
|
||||||
)
|
|
||||||
|
|
||||||
action = recommendation["scaling_action"]
|
|
||||||
target_servers = recommendation["recommended_servers"]
|
|
||||||
|
|
||||||
# Apply scaling limits
|
|
||||||
target_servers = max(min_servers, min(max_servers, target_servers))
|
|
||||||
|
|
||||||
scaling_result = {
|
|
||||||
"action": action,
|
|
||||||
"current_servers": len(self.backends),
|
|
||||||
"target_servers": target_servers,
|
|
||||||
"confidence": recommendation.get("confidence_score", 0.0),
|
|
||||||
"reason": recommendation.get("reason", ""),
|
|
||||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
|
||||||
}
|
|
||||||
|
|
||||||
# In production, implement actual scaling logic here
|
|
||||||
# For now, just return the recommendation
|
|
||||||
|
|
||||||
self.logger.info(f"Auto-scaling recommendation: {action} to {target_servers} servers")
|
|
||||||
|
|
||||||
return scaling_result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Auto-scaling failed: {e}")
|
|
||||||
return {"error": str(e)}
|
|
||||||
|
|
||||||
|
|
||||||
# Global load balancer instance
|
|
||||||
advanced_load_balancer = None
|
|
||||||
|
|
||||||
|
|
||||||
async def get_advanced_load_balancer() -> AdvancedLoadBalancer:
|
|
||||||
"""Get or create global advanced load balancer"""
|
|
||||||
|
|
||||||
global advanced_load_balancer
|
|
||||||
if advanced_load_balancer is None:
|
|
||||||
advanced_load_balancer = AdvancedLoadBalancer()
|
|
||||||
|
|
||||||
# Add default backends
|
|
||||||
default_backends = [
|
|
||||||
BackendServer(
|
|
||||||
server_id="backend_1", host="10.0.1.10", port=8080, weight=1.0, max_connections=1000, region="us_east"
|
|
||||||
),
|
|
||||||
BackendServer(
|
|
||||||
server_id="backend_2", host="10.0.1.11", port=8080, weight=1.0, max_connections=1000, region="us_east"
|
|
||||||
),
|
|
||||||
BackendServer(
|
|
||||||
server_id="backend_3", host="10.0.1.12", port=8080, weight=0.8, max_connections=800, region="eu_west"
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
for backend in default_backends:
|
|
||||||
await advanced_load_balancer.add_backend(backend)
|
|
||||||
|
|
||||||
return advanced_load_balancer
|
|
||||||
@@ -1,773 +0,0 @@
|
|||||||
"""
|
|
||||||
Enterprise Security Framework - Phase 6.2 Implementation
|
|
||||||
Zero-trust architecture with HSM integration and advanced security controls
|
|
||||||
"""
|
|
||||||
|
|
||||||
import secrets
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from datetime import datetime, timezone, timedelta
|
|
||||||
from enum import StrEnum
|
|
||||||
from typing import Any
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
import cryptography
|
|
||||||
from cryptography.hazmat.backends import default_backend
|
|
||||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
|
||||||
|
|
||||||
from aitbc import get_logger
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class SecurityLevel(StrEnum):
|
|
||||||
"""Security levels for enterprise data"""
|
|
||||||
|
|
||||||
PUBLIC = "public"
|
|
||||||
INTERNAL = "internal"
|
|
||||||
CONFIDENTIAL = "confidential"
|
|
||||||
RESTRICTED = "restricted"
|
|
||||||
TOP_SECRET = "top_secret"
|
|
||||||
|
|
||||||
|
|
||||||
class EncryptionAlgorithm(StrEnum):
|
|
||||||
"""Encryption algorithms"""
|
|
||||||
|
|
||||||
AES_256_GCM = "aes_256_gcm"
|
|
||||||
CHACHA20_POLY1305 = "chacha20_polyy1305"
|
|
||||||
AES_256_CBC = "aes_256_cbc"
|
|
||||||
QUANTUM_RESISTANT = "quantum_resistant"
|
|
||||||
|
|
||||||
|
|
||||||
class ThreatLevel(StrEnum):
|
|
||||||
"""Threat levels for security monitoring"""
|
|
||||||
|
|
||||||
LOW = "low"
|
|
||||||
MEDIUM = "medium"
|
|
||||||
HIGH = "high"
|
|
||||||
CRITICAL = "critical"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SecurityPolicy:
|
|
||||||
"""Security policy configuration"""
|
|
||||||
|
|
||||||
policy_id: str
|
|
||||||
name: str
|
|
||||||
security_level: SecurityLevel
|
|
||||||
encryption_algorithm: EncryptionAlgorithm
|
|
||||||
key_rotation_interval: timedelta
|
|
||||||
access_control_requirements: list[str]
|
|
||||||
audit_requirements: list[str]
|
|
||||||
retention_period: timedelta
|
|
||||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
|
||||||
updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SecurityEvent:
|
|
||||||
"""Security event for monitoring"""
|
|
||||||
|
|
||||||
event_id: str
|
|
||||||
event_type: str
|
|
||||||
severity: ThreatLevel
|
|
||||||
source: str
|
|
||||||
timestamp: datetime
|
|
||||||
user_id: str | None
|
|
||||||
resource_id: str | None
|
|
||||||
details: dict[str, Any]
|
|
||||||
resolved: bool = False
|
|
||||||
resolution_notes: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class HSMManager:
|
|
||||||
"""Hardware Security Module manager for enterprise key management"""
|
|
||||||
|
|
||||||
def __init__(self, hsm_config: dict[str, Any]):
|
|
||||||
self.hsm_config = hsm_config
|
|
||||||
self.backend = default_backend()
|
|
||||||
self.key_store = {} # In production, use actual HSM
|
|
||||||
self.logger = get_logger("hsm_manager")
|
|
||||||
|
|
||||||
async def initialize(self) -> bool:
|
|
||||||
"""Initialize HSM connection"""
|
|
||||||
try:
|
|
||||||
# In production, initialize actual HSM connection
|
|
||||||
# For now, simulate HSM initialization
|
|
||||||
self.logger.info("HSM manager initialized")
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"HSM initialization failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def generate_key(self, key_id: str, algorithm: EncryptionAlgorithm, key_size: int = 256) -> dict[str, Any]:
|
|
||||||
"""Generate encryption key in HSM"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
if algorithm == EncryptionAlgorithm.AES_256_GCM:
|
|
||||||
key = secrets.token_bytes(32) # 256 bits
|
|
||||||
iv = secrets.token_bytes(12) # 96 bits for GCM
|
|
||||||
elif algorithm == EncryptionAlgorithm.CHACHA20_POLY1305:
|
|
||||||
key = secrets.token_bytes(32) # 256 bits
|
|
||||||
nonce = secrets.token_bytes(12) # 96 bits
|
|
||||||
elif algorithm == EncryptionAlgorithm.AES_256_CBC:
|
|
||||||
key = secrets.token_bytes(32) # 256 bits
|
|
||||||
iv = secrets.token_bytes(16) # 128 bits for CBC
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported algorithm: {algorithm}")
|
|
||||||
|
|
||||||
# Store key in HSM (simulated)
|
|
||||||
key_data = {
|
|
||||||
"key_id": key_id,
|
|
||||||
"algorithm": algorithm.value,
|
|
||||||
"key": key,
|
|
||||||
"iv": iv if algorithm in [EncryptionAlgorithm.AES_256_GCM, EncryptionAlgorithm.AES_256_CBC] else None,
|
|
||||||
"nonce": nonce if algorithm == EncryptionAlgorithm.CHACHA20_POLY1305 else None,
|
|
||||||
"created_at": datetime.now(timezone.utc),
|
|
||||||
"key_size": key_size,
|
|
||||||
}
|
|
||||||
|
|
||||||
self.key_store[key_id] = key_data
|
|
||||||
|
|
||||||
self.logger.info(f"Key generated in HSM: {key_id}")
|
|
||||||
return key_data
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Key generation failed: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_key(self, key_id: str) -> dict[str, Any] | None:
|
|
||||||
"""Get key from HSM"""
|
|
||||||
return self.key_store.get(key_id)
|
|
||||||
|
|
||||||
async def rotate_key(self, key_id: str) -> dict[str, Any]:
|
|
||||||
"""Rotate encryption key"""
|
|
||||||
|
|
||||||
old_key = self.key_store.get(key_id)
|
|
||||||
if not old_key:
|
|
||||||
raise ValueError(f"Key not found: {key_id}")
|
|
||||||
|
|
||||||
# Generate new key
|
|
||||||
new_key = await self.generate_key(f"{key_id}_new", EncryptionAlgorithm(old_key["algorithm"]), old_key["key_size"])
|
|
||||||
|
|
||||||
# Update key with rotation timestamp
|
|
||||||
new_key["rotated_from"] = key_id
|
|
||||||
new_key["rotation_timestamp"] = datetime.now(timezone.utc)
|
|
||||||
|
|
||||||
return new_key
|
|
||||||
|
|
||||||
async def delete_key(self, key_id: str) -> bool:
|
|
||||||
"""Delete key from HSM"""
|
|
||||||
if key_id in self.key_store:
|
|
||||||
del self.key_store[key_id]
|
|
||||||
self.logger.info(f"Key deleted from HSM: {key_id}")
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class EnterpriseEncryption:
|
|
||||||
"""Enterprise-grade encryption service"""
|
|
||||||
|
|
||||||
def __init__(self, hsm_manager: HSMManager):
|
|
||||||
self.hsm_manager = hsm_manager
|
|
||||||
self.backend = default_backend()
|
|
||||||
self.logger = get_logger("enterprise_encryption")
|
|
||||||
|
|
||||||
async def encrypt_data(
|
|
||||||
self, data: str | bytes, key_id: str, associated_data: bytes | None = None
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Encrypt data using enterprise-grade encryption"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Get key from HSM
|
|
||||||
key_data = await self.hsm_manager.get_key(key_id)
|
|
||||||
if not key_data:
|
|
||||||
raise ValueError(f"Key not found: {key_id}")
|
|
||||||
|
|
||||||
# Convert data to bytes if needed
|
|
||||||
if isinstance(data, str):
|
|
||||||
data = data.encode("utf-8")
|
|
||||||
|
|
||||||
algorithm = EncryptionAlgorithm(key_data["algorithm"])
|
|
||||||
|
|
||||||
if algorithm == EncryptionAlgorithm.AES_256_GCM:
|
|
||||||
return await self._encrypt_aes_gcm(data, key_data, associated_data)
|
|
||||||
elif algorithm == EncryptionAlgorithm.CHACHA20_POLY1305:
|
|
||||||
return await self._encrypt_chacha20(data, key_data, associated_data)
|
|
||||||
elif algorithm == EncryptionAlgorithm.AES_256_CBC:
|
|
||||||
return await self._encrypt_aes_cbc(data, key_data)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported encryption algorithm: {algorithm}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Encryption failed: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def _encrypt_aes_gcm(
|
|
||||||
self, data: bytes, key_data: dict[str, Any], associated_data: bytes | None = None
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Encrypt using AES-256-GCM"""
|
|
||||||
|
|
||||||
key = key_data["key"]
|
|
||||||
iv = key_data["iv"]
|
|
||||||
|
|
||||||
# Create cipher
|
|
||||||
cipher = Cipher(algorithms.AES(key), modes.GCM(iv), backend=self.backend)
|
|
||||||
|
|
||||||
encryptor = cipher.encryptor()
|
|
||||||
|
|
||||||
# Add associated data if provided
|
|
||||||
if associated_data:
|
|
||||||
encryptor.authenticate_additional_data(associated_data)
|
|
||||||
|
|
||||||
# Encrypt data
|
|
||||||
ciphertext = encryptor.update(data) + encryptor.finalize()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"ciphertext": ciphertext.hex(),
|
|
||||||
"iv": iv.hex(),
|
|
||||||
"tag": encryptor.tag.hex(),
|
|
||||||
"algorithm": "aes_256_gcm",
|
|
||||||
"key_id": key_data["key_id"],
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _encrypt_chacha20(
|
|
||||||
self, data: bytes, key_data: dict[str, Any], associated_data: bytes | None = None
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Encrypt using ChaCha20-Poly1305"""
|
|
||||||
|
|
||||||
key = key_data["key"]
|
|
||||||
nonce = key_data["nonce"]
|
|
||||||
|
|
||||||
# Create cipher
|
|
||||||
cipher = Cipher(algorithms.ChaCha20(key, nonce), modes.Poly1305(b""), backend=self.backend)
|
|
||||||
|
|
||||||
encryptor = cipher.encryptor()
|
|
||||||
|
|
||||||
# Add associated data if provided
|
|
||||||
if associated_data:
|
|
||||||
encryptor.authenticate_additional_data(associated_data)
|
|
||||||
|
|
||||||
# Encrypt data
|
|
||||||
ciphertext = encryptor.update(data) + encryptor.finalize()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"ciphertext": ciphertext.hex(),
|
|
||||||
"nonce": nonce.hex(),
|
|
||||||
"tag": encryptor.tag.hex(),
|
|
||||||
"algorithm": "chacha20_poly1305",
|
|
||||||
"key_id": key_data["key_id"],
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _encrypt_aes_cbc(self, data: bytes, key_data: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Encrypt using AES-256-CBC"""
|
|
||||||
|
|
||||||
key = key_data["key"]
|
|
||||||
iv = key_data["iv"]
|
|
||||||
|
|
||||||
# Pad data to block size
|
|
||||||
padder = cryptography.hazmat.primitives.padding.PKCS7(128).padder()
|
|
||||||
padded_data = padder.update(data) + padder.finalize()
|
|
||||||
|
|
||||||
# Create cipher
|
|
||||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=self.backend)
|
|
||||||
|
|
||||||
encryptor = cipher.encryptor()
|
|
||||||
ciphertext = encryptor.update(padded_data) + encryptor.finalize()
|
|
||||||
|
|
||||||
return {"ciphertext": ciphertext.hex(), "iv": iv.hex(), "algorithm": "aes_256_cbc", "key_id": key_data["key_id"]}
|
|
||||||
|
|
||||||
async def decrypt_data(self, encrypted_data: dict[str, Any], associated_data: bytes | None = None) -> bytes:
|
|
||||||
"""Decrypt encrypted data"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
algorithm = encrypted_data["algorithm"]
|
|
||||||
|
|
||||||
if algorithm == "aes_256_gcm":
|
|
||||||
return await self._decrypt_aes_gcm(encrypted_data, associated_data)
|
|
||||||
elif algorithm == "chacha20_poly1305":
|
|
||||||
return await self._decrypt_chacha20(encrypted_data, associated_data)
|
|
||||||
elif algorithm == "aes_256_cbc":
|
|
||||||
return await self._decrypt_aes_cbc(encrypted_data)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported encryption algorithm: {algorithm}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Decryption failed: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def _decrypt_aes_gcm(self, encrypted_data: dict[str, Any], associated_data: bytes | None = None) -> bytes:
|
|
||||||
"""Decrypt AES-256-GCM encrypted data"""
|
|
||||||
|
|
||||||
# Get key from HSM
|
|
||||||
key_data = await self.hsm_manager.get_key(encrypted_data["key_id"])
|
|
||||||
if not key_data:
|
|
||||||
raise ValueError(f"Key not found: {encrypted_data['key_id']}")
|
|
||||||
|
|
||||||
key = key_data["key"]
|
|
||||||
iv = bytes.fromhex(encrypted_data["iv"])
|
|
||||||
ciphertext = bytes.fromhex(encrypted_data["ciphertext"])
|
|
||||||
tag = bytes.fromhex(encrypted_data["tag"])
|
|
||||||
|
|
||||||
# Create cipher
|
|
||||||
cipher = Cipher(algorithms.AES(key), modes.GCM(iv, tag), backend=self.backend)
|
|
||||||
|
|
||||||
decryptor = cipher.decryptor()
|
|
||||||
|
|
||||||
# Add associated data if provided
|
|
||||||
if associated_data:
|
|
||||||
decryptor.authenticate_additional_data(associated_data)
|
|
||||||
|
|
||||||
# Decrypt data
|
|
||||||
plaintext = decryptor.update(ciphertext) + decryptor.finalize()
|
|
||||||
|
|
||||||
return plaintext
|
|
||||||
|
|
||||||
async def _decrypt_chacha20(self, encrypted_data: dict[str, Any], associated_data: bytes | None = None) -> bytes:
|
|
||||||
"""Decrypt ChaCha20-Poly1305 encrypted data"""
|
|
||||||
|
|
||||||
# Get key from HSM
|
|
||||||
key_data = await self.hsm_manager.get_key(encrypted_data["key_id"])
|
|
||||||
if not key_data:
|
|
||||||
raise ValueError(f"Key not found: {encrypted_data['key_id']}")
|
|
||||||
|
|
||||||
key = key_data["key"]
|
|
||||||
nonce = bytes.fromhex(encrypted_data["nonce"])
|
|
||||||
ciphertext = bytes.fromhex(encrypted_data["ciphertext"])
|
|
||||||
tag = bytes.fromhex(encrypted_data["tag"])
|
|
||||||
|
|
||||||
# Create cipher
|
|
||||||
cipher = Cipher(algorithms.ChaCha20(key, nonce), modes.Poly1305(tag), backend=self.backend)
|
|
||||||
|
|
||||||
decryptor = cipher.decryptor()
|
|
||||||
|
|
||||||
# Add associated data if provided
|
|
||||||
if associated_data:
|
|
||||||
decryptor.authenticate_additional_data(associated_data)
|
|
||||||
|
|
||||||
# Decrypt data
|
|
||||||
plaintext = decryptor.update(ciphertext) + decryptor.finalize()
|
|
||||||
|
|
||||||
return plaintext
|
|
||||||
|
|
||||||
async def _decrypt_aes_cbc(self, encrypted_data: dict[str, Any]) -> bytes:
|
|
||||||
"""Decrypt AES-256-CBC encrypted data"""
|
|
||||||
|
|
||||||
# Get key from HSM
|
|
||||||
key_data = await self.hsm_manager.get_key(encrypted_data["key_id"])
|
|
||||||
if not key_data:
|
|
||||||
raise ValueError(f"Key not found: {encrypted_data['key_id']}")
|
|
||||||
|
|
||||||
key = key_data["key"]
|
|
||||||
iv = bytes.fromhex(encrypted_data["iv"])
|
|
||||||
ciphertext = bytes.fromhex(encrypted_data["ciphertext"])
|
|
||||||
|
|
||||||
# Create cipher
|
|
||||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=self.backend)
|
|
||||||
|
|
||||||
decryptor = cipher.decryptor()
|
|
||||||
padded_plaintext = decryptor.update(ciphertext) + decryptor.finalize()
|
|
||||||
|
|
||||||
# Unpad data
|
|
||||||
unpadder = cryptography.hazmat.primitives.padding.PKCS7(128).unpadder()
|
|
||||||
plaintext = unpadder.update(padded_plaintext) + unpadder.finalize()
|
|
||||||
|
|
||||||
return plaintext
|
|
||||||
|
|
||||||
|
|
||||||
class ZeroTrustArchitecture:
|
|
||||||
"""Zero-trust security architecture implementation"""
|
|
||||||
|
|
||||||
def __init__(self, hsm_manager: HSMManager, encryption: EnterpriseEncryption):
|
|
||||||
self.hsm_manager = hsm_manager
|
|
||||||
self.encryption = encryption
|
|
||||||
self.trust_policies = {}
|
|
||||||
self.session_tokens = {}
|
|
||||||
self.logger = get_logger("zero_trust")
|
|
||||||
|
|
||||||
async def create_trust_policy(self, policy_id: str, policy_config: dict[str, Any]) -> bool:
|
|
||||||
"""Create zero-trust policy"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
policy = SecurityPolicy(
|
|
||||||
policy_id=policy_id,
|
|
||||||
name=policy_config["name"],
|
|
||||||
security_level=SecurityLevel(policy_config["security_level"]),
|
|
||||||
encryption_algorithm=EncryptionAlgorithm(policy_config["encryption_algorithm"]),
|
|
||||||
key_rotation_interval=timedelta(days=policy_config.get("key_rotation_days", 90)),
|
|
||||||
access_control_requirements=policy_config.get("access_control_requirements", []),
|
|
||||||
audit_requirements=policy_config.get("audit_requirements", []),
|
|
||||||
retention_period=timedelta(days=policy_config.get("retention_days", 2555)), # 7 years
|
|
||||||
)
|
|
||||||
|
|
||||||
self.trust_policies[policy_id] = policy
|
|
||||||
|
|
||||||
# Generate encryption key for policy
|
|
||||||
await self.hsm_manager.generate_key(f"policy_{policy_id}", policy.encryption_algorithm)
|
|
||||||
|
|
||||||
self.logger.info(f"Zero-trust policy created: {policy_id}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Failed to create trust policy: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def verify_trust(self, user_id: str, resource_id: str, action: str, context: dict[str, Any]) -> bool:
|
|
||||||
"""Verify zero-trust access request"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Get applicable policy
|
|
||||||
policy_id = context.get("policy_id", "default")
|
|
||||||
policy = self.trust_policies.get(policy_id)
|
|
||||||
|
|
||||||
if not policy:
|
|
||||||
self.logger.warning(f"No policy found for {policy_id}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Verify trust factors
|
|
||||||
trust_score = await self._calculate_trust_score(user_id, resource_id, action, context)
|
|
||||||
|
|
||||||
# Check if trust score meets policy requirements
|
|
||||||
min_trust_score = self._get_min_trust_score(policy.security_level)
|
|
||||||
|
|
||||||
is_trusted = trust_score >= min_trust_score
|
|
||||||
|
|
||||||
# Log trust decision
|
|
||||||
await self._log_trust_decision(user_id, resource_id, action, trust_score, is_trusted)
|
|
||||||
|
|
||||||
return is_trusted
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Trust verification failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def _calculate_trust_score(self, user_id: str, resource_id: str, action: str, context: dict[str, Any]) -> float:
|
|
||||||
"""Calculate trust score for access request"""
|
|
||||||
|
|
||||||
score = 0.0
|
|
||||||
|
|
||||||
# User authentication factor (40%)
|
|
||||||
auth_strength = context.get("auth_strength", "password")
|
|
||||||
if auth_strength == "mfa":
|
|
||||||
score += 0.4
|
|
||||||
elif auth_strength == "password":
|
|
||||||
score += 0.2
|
|
||||||
|
|
||||||
# Device trust factor (20%)
|
|
||||||
device_trust = context.get("device_trust", 0.5)
|
|
||||||
score += 0.2 * device_trust
|
|
||||||
|
|
||||||
# Location factor (15%)
|
|
||||||
location_trust = context.get("location_trust", 0.5)
|
|
||||||
score += 0.15 * location_trust
|
|
||||||
|
|
||||||
# Time factor (10%)
|
|
||||||
time_trust = context.get("time_trust", 0.5)
|
|
||||||
score += 0.1 * time_trust
|
|
||||||
|
|
||||||
# Behavioral factor (15%)
|
|
||||||
behavior_trust = context.get("behavior_trust", 0.5)
|
|
||||||
score += 0.15 * behavior_trust
|
|
||||||
|
|
||||||
return min(score, 1.0)
|
|
||||||
|
|
||||||
def _get_min_trust_score(self, security_level: SecurityLevel) -> float:
|
|
||||||
"""Get minimum trust score for security level"""
|
|
||||||
|
|
||||||
thresholds = {
|
|
||||||
SecurityLevel.PUBLIC: 0.0,
|
|
||||||
SecurityLevel.INTERNAL: 0.3,
|
|
||||||
SecurityLevel.CONFIDENTIAL: 0.6,
|
|
||||||
SecurityLevel.RESTRICTED: 0.8,
|
|
||||||
SecurityLevel.TOP_SECRET: 0.9,
|
|
||||||
}
|
|
||||||
|
|
||||||
return thresholds.get(security_level, 0.5)
|
|
||||||
|
|
||||||
async def _log_trust_decision(self, user_id: str, resource_id: str, action: str, trust_score: float, decision: bool):
|
|
||||||
"""Log trust decision for audit"""
|
|
||||||
|
|
||||||
SecurityEvent(
|
|
||||||
event_id=str(uuid4()),
|
|
||||||
event_type="trust_decision",
|
|
||||||
severity=ThreatLevel.LOW if decision else ThreatLevel.MEDIUM,
|
|
||||||
source="zero_trust",
|
|
||||||
timestamp=datetime.now(timezone.utc),
|
|
||||||
user_id=user_id,
|
|
||||||
resource_id=resource_id,
|
|
||||||
details={"action": action, "trust_score": trust_score, "decision": decision},
|
|
||||||
)
|
|
||||||
|
|
||||||
# In production, send to security monitoring system
|
|
||||||
self.logger.info(f"Trust decision: {user_id} -> {resource_id} = {decision} (score: {trust_score})")
|
|
||||||
|
|
||||||
|
|
||||||
class ThreatDetectionSystem:
|
|
||||||
"""Advanced threat detection and response system"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.threat_patterns = {}
|
|
||||||
self.active_threats = {}
|
|
||||||
self.response_actions = {}
|
|
||||||
self.logger = get_logger("threat_detection")
|
|
||||||
|
|
||||||
async def register_threat_pattern(self, pattern_id: str, pattern_config: dict[str, Any]):
|
|
||||||
"""Register threat detection pattern"""
|
|
||||||
|
|
||||||
self.threat_patterns[pattern_id] = {
|
|
||||||
"id": pattern_id,
|
|
||||||
"name": pattern_config["name"],
|
|
||||||
"description": pattern_config["description"],
|
|
||||||
"indicators": pattern_config["indicators"],
|
|
||||||
"severity": ThreatLevel(pattern_config["severity"]),
|
|
||||||
"response_actions": pattern_config.get("response_actions", []),
|
|
||||||
"threshold": pattern_config.get("threshold", 1.0),
|
|
||||||
}
|
|
||||||
|
|
||||||
self.logger.info(f"Threat pattern registered: {pattern_id}")
|
|
||||||
|
|
||||||
async def analyze_threat(self, event_data: dict[str, Any]) -> list[SecurityEvent]:
|
|
||||||
"""Analyze event for potential threats"""
|
|
||||||
|
|
||||||
detected_threats = []
|
|
||||||
|
|
||||||
for pattern_id, pattern in self.threat_patterns.items():
|
|
||||||
threat_score = await self._calculate_threat_score(event_data, pattern)
|
|
||||||
|
|
||||||
if threat_score >= pattern["threshold"]:
|
|
||||||
threat_event = SecurityEvent(
|
|
||||||
event_id=str(uuid4()),
|
|
||||||
event_type="threat_detected",
|
|
||||||
severity=pattern["severity"],
|
|
||||||
source="threat_detection",
|
|
||||||
timestamp=datetime.now(timezone.utc),
|
|
||||||
user_id=event_data.get("user_id"),
|
|
||||||
resource_id=event_data.get("resource_id"),
|
|
||||||
details={
|
|
||||||
"pattern_id": pattern_id,
|
|
||||||
"pattern_name": pattern["name"],
|
|
||||||
"threat_score": threat_score,
|
|
||||||
"indicators": event_data,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
detected_threats.append(threat_event)
|
|
||||||
|
|
||||||
# Trigger response actions
|
|
||||||
await self._trigger_response_actions(pattern_id, threat_event)
|
|
||||||
|
|
||||||
return detected_threats
|
|
||||||
|
|
||||||
async def _calculate_threat_score(self, event_data: dict[str, Any], pattern: dict[str, Any]) -> float:
|
|
||||||
"""Calculate threat score for pattern"""
|
|
||||||
|
|
||||||
score = 0.0
|
|
||||||
indicators = pattern["indicators"]
|
|
||||||
|
|
||||||
for indicator, weight in indicators.items():
|
|
||||||
if indicator in event_data:
|
|
||||||
# Simple scoring - in production, use more sophisticated algorithms
|
|
||||||
indicator_score = 0.5 # Base score for presence
|
|
||||||
score += indicator_score * weight
|
|
||||||
|
|
||||||
return min(score, 1.0)
|
|
||||||
|
|
||||||
async def _trigger_response_actions(self, pattern_id: str, threat_event: SecurityEvent):
|
|
||||||
"""Trigger automated response actions"""
|
|
||||||
|
|
||||||
pattern = self.threat_patterns[pattern_id]
|
|
||||||
actions = pattern.get("response_actions", [])
|
|
||||||
|
|
||||||
for action in actions:
|
|
||||||
try:
|
|
||||||
await self._execute_response_action(action, threat_event)
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Response action failed: {action} - {e}")
|
|
||||||
|
|
||||||
async def _execute_response_action(self, action: str, threat_event: SecurityEvent):
|
|
||||||
"""Execute specific response action"""
|
|
||||||
|
|
||||||
if action == "block_user":
|
|
||||||
await self._block_user(threat_event.user_id)
|
|
||||||
elif action == "isolate_resource":
|
|
||||||
await self._isolate_resource(threat_event.resource_id)
|
|
||||||
elif action == "escalate_to_admin":
|
|
||||||
await self._escalate_to_admin(threat_event)
|
|
||||||
elif action == "require_mfa":
|
|
||||||
await self._require_mfa(threat_event.user_id)
|
|
||||||
|
|
||||||
self.logger.info(f"Response action executed: {action}")
|
|
||||||
|
|
||||||
async def _block_user(self, user_id: str):
|
|
||||||
"""Block user account"""
|
|
||||||
# In production, implement actual user blocking
|
|
||||||
self.logger.warning(f"User blocked due to threat: {user_id}")
|
|
||||||
|
|
||||||
async def _isolate_resource(self, resource_id: str):
|
|
||||||
"""Isolate compromised resource"""
|
|
||||||
# In production, implement actual resource isolation
|
|
||||||
self.logger.warning(f"Resource isolated due to threat: {resource_id}")
|
|
||||||
|
|
||||||
async def _escalate_to_admin(self, threat_event: SecurityEvent):
|
|
||||||
"""Escalate threat to security administrators"""
|
|
||||||
# In production, implement actual escalation
|
|
||||||
self.logger.error(f"Threat escalated to admin: {threat_event.event_id}")
|
|
||||||
|
|
||||||
async def _require_mfa(self, user_id: str):
|
|
||||||
"""Require multi-factor authentication"""
|
|
||||||
# In production, implement MFA requirement
|
|
||||||
self.logger.warning(f"MFA required for user: {user_id}")
|
|
||||||
|
|
||||||
|
|
||||||
class EnterpriseSecurityFramework:
|
|
||||||
"""Main enterprise security framework"""
|
|
||||||
|
|
||||||
def __init__(self, hsm_config: dict[str, Any]):
|
|
||||||
self.hsm_manager = HSMManager(hsm_config)
|
|
||||||
self.encryption = EnterpriseEncryption(self.hsm_manager)
|
|
||||||
self.zero_trust = ZeroTrustArchitecture(self.hsm_manager, self.encryption)
|
|
||||||
self.threat_detection = ThreatDetectionSystem()
|
|
||||||
self.logger = get_logger("enterprise_security")
|
|
||||||
|
|
||||||
async def initialize(self) -> bool:
|
|
||||||
"""Initialize security framework"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Initialize HSM
|
|
||||||
if not await self.hsm_manager.initialize():
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Register default threat patterns
|
|
||||||
await self._register_default_threat_patterns()
|
|
||||||
|
|
||||||
# Create default trust policies
|
|
||||||
await self._create_default_policies()
|
|
||||||
|
|
||||||
self.logger.info("Enterprise security framework initialized")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Security framework initialization failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def _register_default_threat_patterns(self):
|
|
||||||
"""Register default threat detection patterns"""
|
|
||||||
|
|
||||||
patterns = [
|
|
||||||
{
|
|
||||||
"name": "Brute Force Attack",
|
|
||||||
"description": "Multiple failed login attempts",
|
|
||||||
"indicators": {"failed_login_attempts": 0.8, "short_time_interval": 0.6},
|
|
||||||
"severity": "high",
|
|
||||||
"threshold": 0.7,
|
|
||||||
"response_actions": ["block_user", "require_mfa"],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Suspicious Access Pattern",
|
|
||||||
"description": "Unusual access patterns",
|
|
||||||
"indicators": {"unusual_location": 0.7, "unusual_time": 0.5, "high_frequency": 0.6},
|
|
||||||
"severity": "medium",
|
|
||||||
"threshold": 0.6,
|
|
||||||
"response_actions": ["require_mfa", "escalate_to_admin"],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Data Exfiltration",
|
|
||||||
"description": "Large data transfer patterns",
|
|
||||||
"indicators": {"large_data_transfer": 0.9, "unusual_destination": 0.7},
|
|
||||||
"severity": "critical",
|
|
||||||
"threshold": 0.8,
|
|
||||||
"response_actions": ["block_user", "isolate_resource", "escalate_to_admin"],
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
for i, pattern in enumerate(patterns):
|
|
||||||
await self.threat_detection.register_threat_pattern(f"default_{i}", pattern)
|
|
||||||
|
|
||||||
async def _create_default_policies(self):
|
|
||||||
"""Create default trust policies"""
|
|
||||||
|
|
||||||
policies = [
|
|
||||||
{
|
|
||||||
"name": "Enterprise Data Policy",
|
|
||||||
"security_level": "confidential",
|
|
||||||
"encryption_algorithm": "aes_256_gcm",
|
|
||||||
"key_rotation_days": 90,
|
|
||||||
"access_control_requirements": ["mfa", "device_trust"],
|
|
||||||
"audit_requirements": ["full_audit", "real_time_monitoring"],
|
|
||||||
"retention_days": 2555,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Public API Policy",
|
|
||||||
"security_level": "public",
|
|
||||||
"encryption_algorithm": "aes_256_gcm",
|
|
||||||
"key_rotation_days": 180,
|
|
||||||
"access_control_requirements": ["api_key"],
|
|
||||||
"audit_requirements": ["api_access_log"],
|
|
||||||
"retention_days": 365,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
for i, policy in enumerate(policies):
|
|
||||||
await self.zero_trust.create_trust_policy(f"default_{i}", policy)
|
|
||||||
|
|
||||||
async def encrypt_sensitive_data(self, data: str | bytes, security_level: SecurityLevel) -> dict[str, Any]:
|
|
||||||
"""Encrypt sensitive data with appropriate security level"""
|
|
||||||
|
|
||||||
# Get policy for security level
|
|
||||||
policy_id = f"default_{0 if security_level == SecurityLevel.PUBLIC else 1}"
|
|
||||||
policy = self.zero_trust.trust_policies.get(policy_id)
|
|
||||||
|
|
||||||
if not policy:
|
|
||||||
raise ValueError(f"No policy found for security level: {security_level}")
|
|
||||||
|
|
||||||
key_id = f"policy_{policy_id}"
|
|
||||||
|
|
||||||
return await self.encryption.encrypt_data(data, key_id)
|
|
||||||
|
|
||||||
async def verify_access(self, user_id: str, resource_id: str, action: str, context: dict[str, Any]) -> bool:
|
|
||||||
"""Verify access using zero-trust architecture"""
|
|
||||||
|
|
||||||
return await self.zero_trust.verify_trust(user_id, resource_id, action, context)
|
|
||||||
|
|
||||||
async def analyze_security_event(self, event_data: dict[str, Any]) -> list[SecurityEvent]:
|
|
||||||
"""Analyze security event for threats"""
|
|
||||||
|
|
||||||
return await self.threat_detection.analyze_threat(event_data)
|
|
||||||
|
|
||||||
async def rotate_encryption_keys(self, policy_id: str | None = None) -> dict[str, Any]:
|
|
||||||
"""Rotate encryption keys"""
|
|
||||||
|
|
||||||
if policy_id:
|
|
||||||
# Rotate specific policy key
|
|
||||||
old_key_id = f"policy_{policy_id}"
|
|
||||||
new_key = await self.hsm_manager.rotate_key(old_key_id)
|
|
||||||
return {"rotated_key": new_key}
|
|
||||||
else:
|
|
||||||
# Rotate all keys
|
|
||||||
rotated_keys = {}
|
|
||||||
for policy_id in self.zero_trust.trust_policies.keys():
|
|
||||||
old_key_id = f"policy_{policy_id}"
|
|
||||||
new_key = await self.hsm_manager.rotate_key(old_key_id)
|
|
||||||
rotated_keys[policy_id] = new_key
|
|
||||||
|
|
||||||
return {"rotated_keys": rotated_keys}
|
|
||||||
|
|
||||||
|
|
||||||
# Global security framework instance
|
|
||||||
security_framework = None
|
|
||||||
|
|
||||||
|
|
||||||
async def get_security_framework() -> EnterpriseSecurityFramework:
|
|
||||||
"""Get or create global security framework"""
|
|
||||||
|
|
||||||
global security_framework
|
|
||||||
if security_framework is None:
|
|
||||||
hsm_config = {"provider": "software", "endpoint": "localhost:8080"} # In production, use actual HSM
|
|
||||||
|
|
||||||
security_framework = EnterpriseSecurityFramework(hsm_config)
|
|
||||||
await security_framework.initialize()
|
|
||||||
|
|
||||||
return security_framework
|
|
||||||
|
|
||||||
|
|
||||||
# Alias for CLI compatibility
|
|
||||||
EnterpriseSecurityManager = EnterpriseSecurityFramework
|
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
"""
|
||||||
|
Trading & Marketplace Bounded Context
|
||||||
|
Provides trading, marketplace optimization, bid strategy, and dynamic pricing services.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .bid_strategy import BidStrategyEngine
|
||||||
|
from .dynamic_pricing import DynamicPricingEngine
|
||||||
|
from .gpu_optimizer import MarketplaceGPUOptimizer
|
||||||
|
from .trading import MatchingEngine, NegotiationSystem
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BidStrategyEngine",
|
||||||
|
"DynamicPricingEngine",
|
||||||
|
"MarketplaceGPUOptimizer",
|
||||||
|
"MatchingEngine",
|
||||||
|
"NegotiationSystem",
|
||||||
|
]
|
||||||
@@ -13,7 +13,7 @@ logger = get_logger(__name__)
|
|||||||
|
|
||||||
from sqlmodel import Session, or_, select
|
from sqlmodel import Session, or_, select
|
||||||
|
|
||||||
from ..domain.trading import (
|
from ...domain.trading import (
|
||||||
NegotiationStatus,
|
NegotiationStatus,
|
||||||
SettlementType,
|
SettlementType,
|
||||||
TradeAgreement,
|
TradeAgreement,
|
||||||
@@ -25,6 +25,26 @@
|
|||||||
- Maintained backward compatibility with lazy-loading pattern
|
- Maintained backward compatibility with lazy-loading pattern
|
||||||
- Import tests verified successfully
|
- Import tests verified successfully
|
||||||
- Old monolithic files removed
|
- Old monolithic files removed
|
||||||
|
- ✅ Phase 2 Complete: Enterprise Integration bounded context decomposed
|
||||||
|
- Created app/services/enterprise_integration/ package with 4 modules
|
||||||
|
- Migrated enterprise_integration.py (1127 lines) and 3 other enterprise files
|
||||||
|
- Updated imports within package (api_gateway.py excluded due to missing dependencies)
|
||||||
|
- Import tests verified successfully
|
||||||
|
- Old monolithic files removed
|
||||||
|
- ✅ Phase 3 Complete: Trading & Marketplace bounded context decomposed
|
||||||
|
- Created app/services/trading_marketplace/ package with 5 modules
|
||||||
|
- Migrated trading_service.py (36K) and 4 other trading files
|
||||||
|
- Updated imports across coordinator-api (routers/trading.py, routers/dynamic_pricing.py)
|
||||||
|
- amm.py excluded from exports due to missing dependencies
|
||||||
|
- Import tests verified successfully
|
||||||
|
- Old monolithic files removed
|
||||||
|
- ✅ Phase 4 Complete: AI & Analytics bounded context decomposed
|
||||||
|
- Created app/services/ai_analytics/ package with 5 modules
|
||||||
|
- Migrated analytics_service.py (41K) and 4 other AI files
|
||||||
|
- Updated imports across coordinator-api (routers/analytics.py, routers/adaptive_learning_health.py)
|
||||||
|
- adaptive_learning.py, surveillance.py, trading_engine.py excluded due to missing dependencies
|
||||||
|
- Import tests verified successfully
|
||||||
|
- Old monolithic files removed
|
||||||
|
|
||||||
2. **Production Code Using print()** (HIGH IMPACT)
|
2. **Production Code Using print()** (HIGH IMPACT)
|
||||||
- 925 print() statements in production code
|
- 925 print() statements in production code
|
||||||
@@ -128,9 +148,17 @@
|
|||||||
- test_validation_properties.py: 20/20 passing
|
- test_validation_properties.py: 20/20 passing
|
||||||
- test_staking_service.py: 22/22 passing
|
- test_staking_service.py: 22/22 passing
|
||||||
- Coverage threshold set to 50% in pyproject.toml
|
- Coverage threshold set to 50% in pyproject.toml
|
||||||
- Current coverage: 11% (4623 statements, 4122 missed) - BELOW 50% threshold
|
- Current coverage: 19% (4623 statements, 3745 missed) - BELOW 50% threshold
|
||||||
- Well-covered modules: constants.py (100%), exceptions.py (100%), validation.py (85%), crypto/crypto.py (52%)
|
- Added 137 new tests across 6 modules:
|
||||||
|
- test_middleware.py: 11 tests (middleware modules: 50-100% coverage)
|
||||||
|
- test_utils.py: 47 tests (utils modules: 100% coverage when run standalone)
|
||||||
|
- test_config.py: 14 tests (config.py: 100% coverage)
|
||||||
|
- test_decorators.py: 21 tests (decorators.py: 99% coverage)
|
||||||
|
- test_health_checks.py: 16 tests (health_checks.py: 80% coverage)
|
||||||
|
- test_metrics.py: 28 tests (metrics.py: 100% coverage)
|
||||||
|
- Well-covered modules: constants.py (100%), exceptions.py (100%), validation.py (85%), crypto/crypto.py (52%), config.py (100%), decorators.py (99%), health_checks.py (80%), metrics.py (100%)
|
||||||
- Needs improvement: Most modules at 0-30% coverage
|
- Needs improvement: Most modules at 0-30% coverage
|
||||||
|
- Note: Utils modules (paths, env, json_utils) achieve 100% when run standalone but not counted in overall coverage due to import patterns
|
||||||
|
|
||||||
#### MEDIUM (Long-term, 1-3 months)
|
#### MEDIUM (Long-term, 1-3 months)
|
||||||
|
|
||||||
|
|||||||
154
tests/test_config.py
Normal file
154
tests/test_config.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
"""
|
||||||
|
Tests for AITBC configuration classes
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch, Mock
|
||||||
|
|
||||||
|
from aitbc.config import BaseAITBCConfig, AITBCConfig
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseAITBCConfig:
|
||||||
|
"""Tests for BaseAITBCConfig"""
|
||||||
|
|
||||||
|
def test_default_values(self):
|
||||||
|
"""Test BaseAITBCConfig with default values"""
|
||||||
|
config = BaseAITBCConfig()
|
||||||
|
assert config.app_name == "AITBC Application"
|
||||||
|
assert config.app_version == "1.0.0"
|
||||||
|
assert config.environment == "development"
|
||||||
|
assert config.debug is False
|
||||||
|
assert config.log_level == "INFO"
|
||||||
|
|
||||||
|
def test_custom_values(self):
|
||||||
|
"""Test BaseAITBCConfig with custom values"""
|
||||||
|
config = BaseAITBCConfig(
|
||||||
|
app_name="Custom App",
|
||||||
|
app_version="2.0.0",
|
||||||
|
environment="production",
|
||||||
|
debug=True,
|
||||||
|
log_level="DEBUG"
|
||||||
|
)
|
||||||
|
assert config.app_name == "Custom App"
|
||||||
|
assert config.app_version == "2.0.0"
|
||||||
|
assert config.environment == "production"
|
||||||
|
assert config.debug is True
|
||||||
|
assert config.log_level == "DEBUG"
|
||||||
|
|
||||||
|
def test_data_dir_default(self):
|
||||||
|
"""Test default data directory is a Path"""
|
||||||
|
config = BaseAITBCConfig()
|
||||||
|
assert isinstance(config.data_dir, Path)
|
||||||
|
|
||||||
|
def test_config_dir_default(self):
|
||||||
|
"""Test default config directory is a Path"""
|
||||||
|
config = BaseAITBCConfig()
|
||||||
|
assert isinstance(config.config_dir, Path)
|
||||||
|
|
||||||
|
def test_log_dir_default(self):
|
||||||
|
"""Test default log directory is a Path"""
|
||||||
|
config = BaseAITBCConfig()
|
||||||
|
assert isinstance(config.log_dir, Path)
|
||||||
|
|
||||||
|
def test_log_format_default(self):
|
||||||
|
"""Test default log format"""
|
||||||
|
config = BaseAITBCConfig()
|
||||||
|
assert "%(asctime)s" in config.log_format
|
||||||
|
assert "%(name)s" in config.log_format
|
||||||
|
assert "%(levelname)s" in config.log_format
|
||||||
|
|
||||||
|
|
||||||
|
class TestAITBCConfig:
|
||||||
|
"""Tests for AITBCConfig"""
|
||||||
|
|
||||||
|
def test_default_values(self):
|
||||||
|
"""Test AITBCConfig with default values"""
|
||||||
|
config = AITBCConfig()
|
||||||
|
assert config.host == "0.0.0.0"
|
||||||
|
assert config.port == 8000
|
||||||
|
assert config.workers == 1
|
||||||
|
assert config.database_url is None
|
||||||
|
assert config.database_pool_size == 10
|
||||||
|
assert config.redis_url is None
|
||||||
|
assert config.redis_max_connections == 10
|
||||||
|
assert config.redis_timeout == 5
|
||||||
|
assert config.secret_key is None
|
||||||
|
assert config.jwt_secret is None
|
||||||
|
assert config.jwt_algorithm == "HS256"
|
||||||
|
assert config.jwt_expiration_hours == 24
|
||||||
|
assert config.request_timeout == 30
|
||||||
|
assert config.max_request_size == 10 * 1024 * 1024
|
||||||
|
|
||||||
|
def test_custom_server_settings(self):
|
||||||
|
"""Test AITBCConfig with custom server settings"""
|
||||||
|
config = AITBCConfig(
|
||||||
|
host="127.0.0.1",
|
||||||
|
port=9000,
|
||||||
|
workers=4
|
||||||
|
)
|
||||||
|
assert config.host == "127.0.0.1"
|
||||||
|
assert config.port == 9000
|
||||||
|
assert config.workers == 4
|
||||||
|
|
||||||
|
def test_custom_database_settings(self):
|
||||||
|
"""Test AITBCConfig with custom database settings"""
|
||||||
|
config = AITBCConfig(
|
||||||
|
database_url="postgresql://localhost/test",
|
||||||
|
database_pool_size=20
|
||||||
|
)
|
||||||
|
assert config.database_url == "postgresql://localhost/test"
|
||||||
|
assert config.database_pool_size == 20
|
||||||
|
|
||||||
|
def test_custom_redis_settings(self):
|
||||||
|
"""Test AITBCConfig with custom redis settings"""
|
||||||
|
config = AITBCConfig(
|
||||||
|
redis_url="redis://localhost:6379",
|
||||||
|
redis_max_connections=50,
|
||||||
|
redis_timeout=10
|
||||||
|
)
|
||||||
|
assert config.redis_url == "redis://localhost:6379"
|
||||||
|
assert config.redis_max_connections == 50
|
||||||
|
assert config.redis_timeout == 10
|
||||||
|
|
||||||
|
def test_custom_security_settings(self):
|
||||||
|
"""Test AITBCConfig with custom security settings"""
|
||||||
|
config = AITBCConfig(
|
||||||
|
secret_key="test-secret-key",
|
||||||
|
jwt_secret="test-jwt-secret",
|
||||||
|
jwt_algorithm="RS256",
|
||||||
|
jwt_expiration_hours=48
|
||||||
|
)
|
||||||
|
assert config.secret_key == "test-secret-key"
|
||||||
|
assert config.jwt_secret == "test-jwt-secret"
|
||||||
|
assert config.jwt_algorithm == "RS256"
|
||||||
|
assert config.jwt_expiration_hours == 48
|
||||||
|
|
||||||
|
def test_custom_performance_settings(self):
|
||||||
|
"""Test AITBCConfig with custom performance settings"""
|
||||||
|
config = AITBCConfig(
|
||||||
|
request_timeout=60,
|
||||||
|
max_request_size=20 * 1024 * 1024
|
||||||
|
)
|
||||||
|
assert config.request_timeout == 60
|
||||||
|
assert config.max_request_size == 20 * 1024 * 1024
|
||||||
|
|
||||||
|
def test_inherits_base_config(self):
|
||||||
|
"""Test AITBCConfig inherits from BaseAITBCConfig"""
|
||||||
|
config = AITBCConfig(
|
||||||
|
app_name="Test App",
|
||||||
|
environment="staging"
|
||||||
|
)
|
||||||
|
assert config.app_name == "Test App"
|
||||||
|
assert config.environment == "staging"
|
||||||
|
assert config.host == "0.0.0.0" # AITBCConfig default
|
||||||
|
assert config.port == 8000 # AITBCConfig default
|
||||||
|
|
||||||
|
@patch('aitbc.config.logger')
|
||||||
|
def test_init_logs_configuration(self, mock_logger):
|
||||||
|
"""Test __init__ logs configuration"""
|
||||||
|
config = AITBCConfig(host="localhost", port=9000)
|
||||||
|
mock_logger.info.assert_called_once()
|
||||||
|
assert "localhost:9000" in mock_logger.info.call_args[0][0]
|
||||||
|
mock_logger.debug.assert_called_once()
|
||||||
303
tests/test_decorators.py
Normal file
303
tests/test_decorators.py
Normal file
@@ -0,0 +1,303 @@
|
|||||||
|
"""
|
||||||
|
Tests for AITBC decorators
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch
|
||||||
|
from aitbc.decorators import (
|
||||||
|
retry,
|
||||||
|
timing,
|
||||||
|
cache_result,
|
||||||
|
validate_args,
|
||||||
|
handle_exceptions,
|
||||||
|
async_timing,
|
||||||
|
)
|
||||||
|
from aitbc.exceptions import AITBCError
|
||||||
|
|
||||||
|
|
||||||
|
class TestRetry:
|
||||||
|
"""Tests for retry decorator"""
|
||||||
|
|
||||||
|
def test_retry_succeeds_on_first_attempt(self):
|
||||||
|
"""Test retry when function succeeds on first attempt"""
|
||||||
|
@retry(max_attempts=3)
|
||||||
|
def test_func():
|
||||||
|
return "success"
|
||||||
|
|
||||||
|
result = test_func()
|
||||||
|
assert result == "success"
|
||||||
|
|
||||||
|
def test_retry_succeeds_after_failure(self):
|
||||||
|
"""Test retry when function succeeds after initial failure"""
|
||||||
|
attempts = [0]
|
||||||
|
|
||||||
|
@retry(max_attempts=3, delay=0.01)
|
||||||
|
def test_func():
|
||||||
|
attempts[0] += 1
|
||||||
|
if attempts[0] < 2:
|
||||||
|
raise ValueError("fail")
|
||||||
|
return "success"
|
||||||
|
|
||||||
|
result = test_func()
|
||||||
|
assert result == "success"
|
||||||
|
assert attempts[0] == 2
|
||||||
|
|
||||||
|
def test_retry_exhausts_attempts(self):
|
||||||
|
"""Test retry when function fails after all attempts"""
|
||||||
|
@retry(max_attempts=2, delay=0.01)
|
||||||
|
def test_func():
|
||||||
|
raise ValueError("fail")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
test_func()
|
||||||
|
|
||||||
|
def test_retry_with_specific_exception(self):
|
||||||
|
"""Test retry only catches specified exceptions"""
|
||||||
|
@retry(max_attempts=2, delay=0.01, exceptions=(ValueError,))
|
||||||
|
def test_func():
|
||||||
|
raise TypeError("fail")
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
test_func()
|
||||||
|
|
||||||
|
def test_retry_with_backoff(self):
|
||||||
|
"""Test retry with exponential backoff"""
|
||||||
|
attempts = [0]
|
||||||
|
|
||||||
|
@retry(max_attempts=3, delay=0.01, backoff=2.0)
|
||||||
|
def test_func():
|
||||||
|
attempts[0] += 1
|
||||||
|
raise ValueError("fail")
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
test_func()
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
|
||||||
|
# Should have delays: 0.01 + 0.02 = 0.03 seconds minimum
|
||||||
|
assert elapsed >= 0.03
|
||||||
|
|
||||||
|
def test_retry_with_on_failure_callback(self):
|
||||||
|
"""Test retry with on_failure callback"""
|
||||||
|
callback_called = [False]
|
||||||
|
|
||||||
|
def on_fail(e):
|
||||||
|
callback_called[0] = True
|
||||||
|
|
||||||
|
@retry(max_attempts=2, delay=0.01, on_failure=on_fail)
|
||||||
|
def test_func():
|
||||||
|
raise ValueError("fail")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
test_func()
|
||||||
|
|
||||||
|
assert callback_called[0] is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestTiming:
|
||||||
|
"""Tests for timing decorator"""
|
||||||
|
|
||||||
|
@patch('aitbc.decorators.logger')
|
||||||
|
def test_timing_logs_execution_time(self, mock_logger):
|
||||||
|
"""Test timing decorator logs execution time"""
|
||||||
|
@timing
|
||||||
|
def test_func():
|
||||||
|
time.sleep(0.01)
|
||||||
|
return "result"
|
||||||
|
|
||||||
|
result = test_func()
|
||||||
|
assert result == "result"
|
||||||
|
mock_logger.info.assert_called_once()
|
||||||
|
assert "executed in" in mock_logger.info.call_args[0][0]
|
||||||
|
|
||||||
|
@patch('aitbc.decorators.logger')
|
||||||
|
def test_timing_preserves_function_name(self, mock_logger):
|
||||||
|
"""Test timing decorator preserves function name"""
|
||||||
|
@timing
|
||||||
|
def my_function():
|
||||||
|
return "result"
|
||||||
|
|
||||||
|
assert my_function.__name__ == "my_function"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheResult:
|
||||||
|
"""Tests for cache_result decorator"""
|
||||||
|
|
||||||
|
def test_cache_result_caches_value(self):
|
||||||
|
"""Test cache_result caches function return value"""
|
||||||
|
call_count = [0]
|
||||||
|
|
||||||
|
@cache_result(ttl=60)
|
||||||
|
def test_func(x):
|
||||||
|
call_count[0] += 1
|
||||||
|
return x * 2
|
||||||
|
|
||||||
|
result1 = test_func(5)
|
||||||
|
result2 = test_func(5)
|
||||||
|
|
||||||
|
assert result1 == 10
|
||||||
|
assert result2 == 10
|
||||||
|
assert call_count[0] == 1 # Only called once due to cache
|
||||||
|
|
||||||
|
def test_cache_result_different_args(self):
|
||||||
|
"""Test cache_result with different arguments"""
|
||||||
|
call_count = [0]
|
||||||
|
|
||||||
|
@cache_result(ttl=60)
|
||||||
|
def test_func(x):
|
||||||
|
call_count[0] += 1
|
||||||
|
return x * 2
|
||||||
|
|
||||||
|
test_func(5)
|
||||||
|
test_func(10)
|
||||||
|
|
||||||
|
assert call_count[0] == 2 # Called twice for different args
|
||||||
|
|
||||||
|
def test_cache_result_ttl_expires(self):
|
||||||
|
"""Test cache_result TTL expires"""
|
||||||
|
call_count = [0]
|
||||||
|
|
||||||
|
@cache_result(ttl=0.1) # 100ms TTL
|
||||||
|
def test_func(x):
|
||||||
|
call_count[0] += 1
|
||||||
|
return x * 2
|
||||||
|
|
||||||
|
test_func(5)
|
||||||
|
time.sleep(0.15) # Wait for TTL to expire
|
||||||
|
test_func(5)
|
||||||
|
|
||||||
|
assert call_count[0] == 2 # Called again after TTL expired
|
||||||
|
|
||||||
|
def test_cache_result_with_kwargs(self):
|
||||||
|
"""Test cache_result with keyword arguments"""
|
||||||
|
call_count = [0]
|
||||||
|
|
||||||
|
@cache_result(ttl=60)
|
||||||
|
def test_func(x, y=10):
|
||||||
|
call_count[0] += 1
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
test_func(5, y=10)
|
||||||
|
test_func(5, y=10)
|
||||||
|
|
||||||
|
assert call_count[0] == 1 # Cached
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateArgs:
|
||||||
|
"""Tests for validate_args decorator"""
|
||||||
|
|
||||||
|
def test_validate_args_passes_valid(self):
|
||||||
|
"""Test validate_args passes when validators succeed"""
|
||||||
|
def validator(x):
|
||||||
|
if x < 0:
|
||||||
|
raise ValueError("Must be positive")
|
||||||
|
|
||||||
|
@validate_args(validator)
|
||||||
|
def test_func(x):
|
||||||
|
return x * 2
|
||||||
|
|
||||||
|
result = test_func(5)
|
||||||
|
assert result == 10
|
||||||
|
|
||||||
|
def test_validate_args_fails_invalid(self):
|
||||||
|
"""Test validate_args fails when validators raise error"""
|
||||||
|
def validator(x):
|
||||||
|
if x < 0:
|
||||||
|
raise ValueError("Must be positive")
|
||||||
|
|
||||||
|
@validate_args(validator)
|
||||||
|
def test_func(x):
|
||||||
|
return x * 2
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
test_func(-5)
|
||||||
|
|
||||||
|
def test_validate_args_multiple_validators(self):
|
||||||
|
"""Test validate_args with multiple validators"""
|
||||||
|
def validator1(x):
|
||||||
|
if x < 0:
|
||||||
|
raise ValueError("Must be positive")
|
||||||
|
|
||||||
|
def validator2(x):
|
||||||
|
if x > 100:
|
||||||
|
raise ValueError("Must be <= 100")
|
||||||
|
|
||||||
|
@validate_args(validator1, validator2)
|
||||||
|
def test_func(x):
|
||||||
|
return x * 2
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
test_func(150)
|
||||||
|
|
||||||
|
|
||||||
|
class TestHandleExceptions:
|
||||||
|
"""Tests for handle_exceptions decorator"""
|
||||||
|
|
||||||
|
@patch('aitbc.decorators.logger')
|
||||||
|
def test_handle_exceptions_returns_default(self, mock_logger):
|
||||||
|
"""Test handle_exceptions returns default on exception"""
|
||||||
|
@handle_exceptions(default_return="error")
|
||||||
|
def test_func():
|
||||||
|
raise ValueError("fail")
|
||||||
|
|
||||||
|
result = test_func()
|
||||||
|
assert result == "error"
|
||||||
|
mock_logger.error.assert_called_once()
|
||||||
|
|
||||||
|
@patch('aitbc.decorators.logger')
|
||||||
|
def test_handle_exceptions_no_logging(self, mock_logger):
|
||||||
|
"""Test handle_exceptions with logging disabled"""
|
||||||
|
@handle_exceptions(default_return="error", log_errors=False)
|
||||||
|
def test_func():
|
||||||
|
raise ValueError("fail")
|
||||||
|
|
||||||
|
result = test_func()
|
||||||
|
assert result == "error"
|
||||||
|
mock_logger.error.assert_not_called()
|
||||||
|
|
||||||
|
def test_handle_exceptions_raises_on_specified(self):
|
||||||
|
"""Test handle_exceptions still raises specified exceptions"""
|
||||||
|
@handle_exceptions(default_return="error", raise_on=(ValueError,))
|
||||||
|
def test_func():
|
||||||
|
raise ValueError("fail")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
test_func()
|
||||||
|
|
||||||
|
def test_handle_exceptions_passes_on_success(self):
|
||||||
|
"""Test handle_exceptions passes through successful return"""
|
||||||
|
@handle_exceptions(default_return="error")
|
||||||
|
def test_func():
|
||||||
|
return "success"
|
||||||
|
|
||||||
|
result = test_func()
|
||||||
|
assert result == "success"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncTiming:
|
||||||
|
"""Tests for async_timing decorator"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch('aitbc.decorators.logger')
|
||||||
|
async def test_async_timing_logs_execution_time(self, mock_logger):
|
||||||
|
"""Test async_timing decorator logs execution time"""
|
||||||
|
@async_timing
|
||||||
|
async def test_func():
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
return "result"
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
result = await test_func()
|
||||||
|
assert result == "result"
|
||||||
|
mock_logger.info.assert_called_once()
|
||||||
|
assert "executed in" in mock_logger.info.call_args[0][0]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_timing_preserves_function_name(self):
|
||||||
|
"""Test async_timing decorator preserves function name"""
|
||||||
|
@async_timing
|
||||||
|
async def my_function():
|
||||||
|
return "result"
|
||||||
|
|
||||||
|
assert my_function.__name__ == "my_function"
|
||||||
218
tests/test_health_checks.py
Normal file
218
tests/test_health_checks.py
Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
"""
|
||||||
|
Tests for health check utilities
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, Mock
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from aitbc.health_checks import (
|
||||||
|
HealthStatus,
|
||||||
|
HealthCheck,
|
||||||
|
HealthChecker,
|
||||||
|
create_basic_health_check,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestHealthStatus:
|
||||||
|
"""Tests for HealthStatus enum"""
|
||||||
|
|
||||||
|
def test_health_status_values(self):
|
||||||
|
"""Test HealthStatus enum values"""
|
||||||
|
assert HealthStatus.HEALTHY.value == "healthy"
|
||||||
|
assert HealthStatus.DEGRADED.value == "degraded"
|
||||||
|
assert HealthStatus.UNHEALTHY.value == "unhealthy"
|
||||||
|
|
||||||
|
|
||||||
|
class TestHealthCheck:
|
||||||
|
"""Tests for HealthCheck dataclass"""
|
||||||
|
|
||||||
|
def test_health_check_creation(self):
|
||||||
|
"""Test HealthCheck dataclass creation"""
|
||||||
|
check = HealthCheck(
|
||||||
|
service="test-service",
|
||||||
|
status=HealthStatus.HEALTHY,
|
||||||
|
message="All good",
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
details={"key": "value"}
|
||||||
|
)
|
||||||
|
assert check.service == "test-service"
|
||||||
|
assert check.status == HealthStatus.HEALTHY
|
||||||
|
assert check.message == "All good"
|
||||||
|
assert check.details == {"key": "value"}
|
||||||
|
|
||||||
|
def test_health_check_without_details(self):
|
||||||
|
"""Test HealthCheck without optional details"""
|
||||||
|
check = HealthCheck(
|
||||||
|
service="test-service",
|
||||||
|
status=HealthStatus.HEALTHY,
|
||||||
|
message="All good",
|
||||||
|
timestamp=datetime.now()
|
||||||
|
)
|
||||||
|
assert check.details is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestHealthChecker:
|
||||||
|
"""Tests for HealthChecker"""
|
||||||
|
|
||||||
|
def test_health_checker_initialization(self):
|
||||||
|
"""Test HealthChecker initialization"""
|
||||||
|
checker = HealthChecker("test-service")
|
||||||
|
assert checker.service_name == "test-service"
|
||||||
|
assert checker._checks == {}
|
||||||
|
assert checker._last_check is None
|
||||||
|
|
||||||
|
def test_register_check(self):
|
||||||
|
"""Test registering a health check"""
|
||||||
|
checker = HealthChecker("test-service")
|
||||||
|
|
||||||
|
def mock_check():
|
||||||
|
return HealthStatus.HEALTHY, "OK", {}
|
||||||
|
|
||||||
|
checker.register_check("memory", mock_check)
|
||||||
|
assert "memory" in checker._checks
|
||||||
|
assert checker._checks["memory"] == mock_check
|
||||||
|
|
||||||
|
@patch('aitbc.health_checks.logger')
|
||||||
|
def test_register_check_logs(self, mock_logger):
|
||||||
|
"""Test register_check logs registration"""
|
||||||
|
checker = HealthChecker("test-service")
|
||||||
|
|
||||||
|
def mock_check():
|
||||||
|
return HealthStatus.HEALTHY, "OK", {}
|
||||||
|
|
||||||
|
checker.register_check("memory", mock_check)
|
||||||
|
mock_logger.info.assert_called_once()
|
||||||
|
assert "memory" in mock_logger.info.call_args[0][0]
|
||||||
|
|
||||||
|
def test_run_checks_all_healthy(self):
|
||||||
|
"""Test run_checks when all checks pass"""
|
||||||
|
checker = HealthChecker("test-service")
|
||||||
|
|
||||||
|
def mock_check():
|
||||||
|
return HealthStatus.HEALTHY, "OK", {}
|
||||||
|
|
||||||
|
checker.register_check("check1", mock_check)
|
||||||
|
checker.register_check("check2", mock_check)
|
||||||
|
|
||||||
|
result = checker.run_checks()
|
||||||
|
|
||||||
|
assert result.service == "test-service"
|
||||||
|
assert result.status == HealthStatus.HEALTHY
|
||||||
|
assert result.message == "All health checks passed"
|
||||||
|
assert result.details is not None
|
||||||
|
assert len(result.details) == 2
|
||||||
|
|
||||||
|
def test_run_checks_one_degraded(self):
|
||||||
|
"""Test run_checks with one degraded check"""
|
||||||
|
checker = HealthChecker("test-service")
|
||||||
|
|
||||||
|
def healthy_check():
|
||||||
|
return HealthStatus.HEALTHY, "OK", {}
|
||||||
|
|
||||||
|
def degraded_check():
|
||||||
|
return HealthStatus.DEGRADED, "Warning", {}
|
||||||
|
|
||||||
|
checker.register_check("healthy", healthy_check)
|
||||||
|
checker.register_check("degraded", degraded_check)
|
||||||
|
|
||||||
|
result = checker.run_checks()
|
||||||
|
|
||||||
|
assert result.status == HealthStatus.DEGRADED
|
||||||
|
assert "degraded" in result.message
|
||||||
|
|
||||||
|
def test_run_checks_one_unhealthy(self):
|
||||||
|
"""Test run_checks with one unhealthy check"""
|
||||||
|
checker = HealthChecker("test-service")
|
||||||
|
|
||||||
|
def healthy_check():
|
||||||
|
return HealthStatus.HEALTHY, "OK", {}
|
||||||
|
|
||||||
|
def unhealthy_check():
|
||||||
|
return HealthStatus.UNHEALTHY, "Error", {}
|
||||||
|
|
||||||
|
checker.register_check("healthy", healthy_check)
|
||||||
|
checker.register_check("unhealthy", unhealthy_check)
|
||||||
|
|
||||||
|
result = checker.run_checks()
|
||||||
|
|
||||||
|
assert result.status == HealthStatus.UNHEALTHY
|
||||||
|
assert "unhealthy" in result.message
|
||||||
|
|
||||||
|
@patch('aitbc.health_checks.logger')
|
||||||
|
def test_run_checks_with_exception(self, mock_logger):
|
||||||
|
"""Test run_checks handles exceptions in checks"""
|
||||||
|
checker = HealthChecker("test-service")
|
||||||
|
|
||||||
|
def failing_check():
|
||||||
|
raise ValueError("Check failed")
|
||||||
|
|
||||||
|
checker.register_check("failing", failing_check)
|
||||||
|
|
||||||
|
result = checker.run_checks()
|
||||||
|
|
||||||
|
assert result.status == HealthStatus.UNHEALTHY
|
||||||
|
assert "failing" in result.message
|
||||||
|
mock_logger.error.assert_called_once()
|
||||||
|
|
||||||
|
def test_get_last_check_before_run(self):
|
||||||
|
"""Test get_last_check returns None before any check run"""
|
||||||
|
checker = HealthChecker("test-service")
|
||||||
|
assert checker.get_last_check() is None
|
||||||
|
|
||||||
|
def test_get_last_check_after_run(self):
|
||||||
|
"""Test get_last_check returns last check result"""
|
||||||
|
checker = HealthChecker("test-service")
|
||||||
|
|
||||||
|
def mock_check():
|
||||||
|
return HealthStatus.HEALTHY, "OK", {}
|
||||||
|
|
||||||
|
checker.register_check("check1", mock_check)
|
||||||
|
checker.run_checks()
|
||||||
|
|
||||||
|
last_check = checker.get_last_check()
|
||||||
|
assert last_check is not None
|
||||||
|
assert last_check.service == "test-service"
|
||||||
|
|
||||||
|
def test_get_health_dict(self):
|
||||||
|
"""Test get_health_dict returns dictionary representation"""
|
||||||
|
checker = HealthChecker("test-service")
|
||||||
|
|
||||||
|
def mock_check():
|
||||||
|
return HealthStatus.HEALTHY, "OK", {"key": "value"}
|
||||||
|
|
||||||
|
checker.register_check("check1", mock_check)
|
||||||
|
health_dict = checker.get_health_dict()
|
||||||
|
|
||||||
|
assert isinstance(health_dict, dict)
|
||||||
|
assert "service" in health_dict
|
||||||
|
assert "status" in health_dict
|
||||||
|
assert "message" in health_dict
|
||||||
|
assert "timestamp" in health_dict
|
||||||
|
assert health_dict["service"] == "test-service"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateBasicHealthCheck:
|
||||||
|
"""Tests for create_basic_health_check"""
|
||||||
|
|
||||||
|
def test_create_basic_health_check(self):
|
||||||
|
"""Test create_basic_health_check returns HealthChecker"""
|
||||||
|
checker = create_basic_health_check("test-service")
|
||||||
|
assert isinstance(checker, HealthChecker)
|
||||||
|
assert checker.service_name == "test-service"
|
||||||
|
|
||||||
|
def test_create_basic_health_check_without_psutil(self):
|
||||||
|
"""Test create_basic_health_check handles psutil ImportError"""
|
||||||
|
# Skip this test as psutil import handling is complex to mock
|
||||||
|
pytest.skip("psutil import handling requires complex mocking")
|
||||||
|
|
||||||
|
def test_basic_health_check_has_checks(self):
|
||||||
|
"""Test basic health check has registered checks when psutil available"""
|
||||||
|
try:
|
||||||
|
import psutil
|
||||||
|
checker = create_basic_health_check("test-service")
|
||||||
|
# Should have memory and disk checks if psutil is available
|
||||||
|
assert len(checker._checks) > 0
|
||||||
|
except ImportError:
|
||||||
|
# Skip if psutil not available
|
||||||
|
pass
|
||||||
250
tests/test_metrics.py
Normal file
250
tests/test_metrics.py
Normal file
@@ -0,0 +1,250 @@
|
|||||||
|
"""
|
||||||
|
Tests for AITBC metrics module
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import patch, Mock
|
||||||
|
|
||||||
|
from aitbc.metrics import (
|
||||||
|
service_info,
|
||||||
|
block_processing_duration,
|
||||||
|
block_height,
|
||||||
|
block_validation_duration,
|
||||||
|
block_propagation_duration,
|
||||||
|
job_submission_duration,
|
||||||
|
job_processing_duration,
|
||||||
|
job_queue_duration,
|
||||||
|
job_execution_duration,
|
||||||
|
jobs_total,
|
||||||
|
jobs_failed_total,
|
||||||
|
jobs_in_queue,
|
||||||
|
http_requests_total,
|
||||||
|
http_request_duration,
|
||||||
|
service_uptime_seconds,
|
||||||
|
service_restart_count,
|
||||||
|
track_block_processing,
|
||||||
|
track_job_processing,
|
||||||
|
track_http_request,
|
||||||
|
update_block_height,
|
||||||
|
update_jobs_in_queue,
|
||||||
|
increment_service_restarts,
|
||||||
|
metrics_app,
|
||||||
|
setup_service_info,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetricsDefinitions:
|
||||||
|
"""Tests for Prometheus metrics definitions"""
|
||||||
|
|
||||||
|
def test_service_info_exists(self):
|
||||||
|
"""Test service_info metric is defined"""
|
||||||
|
assert service_info is not None
|
||||||
|
assert service_info._name == 'service_info'
|
||||||
|
|
||||||
|
def test_block_processing_duration_exists(self):
|
||||||
|
"""Test block_processing_duration metric is defined"""
|
||||||
|
assert block_processing_duration is not None
|
||||||
|
assert block_processing_duration._name == 'block_processing_duration_seconds'
|
||||||
|
|
||||||
|
def test_block_height_exists(self):
|
||||||
|
"""Test block_height metric is defined"""
|
||||||
|
assert block_height is not None
|
||||||
|
assert block_height._name == 'block_height'
|
||||||
|
|
||||||
|
def test_block_validation_duration_exists(self):
|
||||||
|
"""Test block_validation_duration metric is defined"""
|
||||||
|
assert block_validation_duration is not None
|
||||||
|
assert block_validation_duration._name == 'block_validation_duration_seconds'
|
||||||
|
|
||||||
|
def test_block_propagation_duration_exists(self):
|
||||||
|
"""Test block_propagation_duration metric is defined"""
|
||||||
|
assert block_propagation_duration is not None
|
||||||
|
assert block_propagation_duration._name == 'block_propagation_duration_seconds'
|
||||||
|
|
||||||
|
def test_job_submission_duration_exists(self):
|
||||||
|
"""Test job_submission_duration metric is defined"""
|
||||||
|
assert job_submission_duration is not None
|
||||||
|
assert job_submission_duration._name == 'job_submission_duration_seconds'
|
||||||
|
|
||||||
|
def test_job_processing_duration_exists(self):
|
||||||
|
"""Test job_processing_duration metric is defined"""
|
||||||
|
assert job_processing_duration is not None
|
||||||
|
assert job_processing_duration._name == 'job_processing_duration_seconds'
|
||||||
|
|
||||||
|
def test_job_queue_duration_exists(self):
|
||||||
|
"""Test job_queue_duration metric is defined"""
|
||||||
|
assert job_queue_duration is not None
|
||||||
|
assert job_queue_duration._name == 'job_queue_duration_seconds'
|
||||||
|
|
||||||
|
def test_job_execution_duration_exists(self):
|
||||||
|
"""Test job_execution_duration metric is defined"""
|
||||||
|
assert job_execution_duration is not None
|
||||||
|
assert job_execution_duration._name == 'job_execution_duration_seconds'
|
||||||
|
|
||||||
|
def test_jobs_total_exists(self):
|
||||||
|
"""Test jobs_total metric is defined"""
|
||||||
|
assert jobs_total is not None
|
||||||
|
assert jobs_total._name == 'jobs'
|
||||||
|
|
||||||
|
def test_jobs_failed_total_exists(self):
|
||||||
|
"""Test jobs_failed_total metric is defined"""
|
||||||
|
assert jobs_failed_total is not None
|
||||||
|
assert jobs_failed_total._name == 'jobs_failed'
|
||||||
|
|
||||||
|
def test_jobs_in_queue_exists(self):
|
||||||
|
"""Test jobs_in_queue metric is defined"""
|
||||||
|
assert jobs_in_queue is not None
|
||||||
|
assert jobs_in_queue._name == 'jobs_in_queue'
|
||||||
|
|
||||||
|
def test_http_requests_total_exists(self):
|
||||||
|
"""Test http_requests_total metric is defined"""
|
||||||
|
assert http_requests_total is not None
|
||||||
|
assert http_requests_total._name == 'http_requests'
|
||||||
|
|
||||||
|
def test_http_request_duration_exists(self):
|
||||||
|
"""Test http_request_duration metric is defined"""
|
||||||
|
assert http_request_duration is not None
|
||||||
|
assert http_request_duration._name == 'http_request_duration_seconds'
|
||||||
|
|
||||||
|
def test_service_uptime_seconds_exists(self):
|
||||||
|
"""Test service_uptime_seconds metric is defined"""
|
||||||
|
assert service_uptime_seconds is not None
|
||||||
|
assert service_uptime_seconds._name == 'service_uptime_seconds'
|
||||||
|
|
||||||
|
def test_service_restart_count_exists(self):
|
||||||
|
"""Test service_restart_count metric is defined"""
|
||||||
|
assert service_restart_count is not None
|
||||||
|
assert service_restart_count._name == 'service_restart_count'
|
||||||
|
|
||||||
|
|
||||||
|
class TestHelperFunctions:
|
||||||
|
"""Tests for metrics helper functions"""
|
||||||
|
|
||||||
|
def test_update_block_height(self):
|
||||||
|
"""Test update_block_height sets metric"""
|
||||||
|
update_block_height(100)
|
||||||
|
# Metric should be set, but we can't easily verify the value
|
||||||
|
# This test ensures the function doesn't raise an error
|
||||||
|
assert True
|
||||||
|
|
||||||
|
def test_update_jobs_in_queue(self):
|
||||||
|
"""Test update_jobs_in_queue sets metric"""
|
||||||
|
update_jobs_in_queue(50)
|
||||||
|
# Metric should be set, but we can't easily verify the value
|
||||||
|
# This test ensures the function doesn't raise an error
|
||||||
|
assert True
|
||||||
|
|
||||||
|
def test_increment_service_restarts(self):
|
||||||
|
"""Test increment_service_restarts increments counter"""
|
||||||
|
increment_service_restarts()
|
||||||
|
# Counter should be incremented, but we can't easily verify the value
|
||||||
|
# This test ensures the function doesn't raise an error
|
||||||
|
assert True
|
||||||
|
|
||||||
|
def test_setup_service_info(self):
|
||||||
|
"""Test setup_service_info sets service info"""
|
||||||
|
setup_service_info("test-service", "1.0.0")
|
||||||
|
# Info should be set, but we can't easily verify the value
|
||||||
|
# This test ensures the function doesn't raise an error
|
||||||
|
assert True
|
||||||
|
|
||||||
|
|
||||||
|
class TestDecorators:
|
||||||
|
"""Tests for metrics tracking decorators"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_track_block_processing_success(self):
|
||||||
|
"""Test track_block_processing decorator on successful execution"""
|
||||||
|
@track_block_processing
|
||||||
|
async def process_block():
|
||||||
|
return "block_processed"
|
||||||
|
|
||||||
|
result = await process_block()
|
||||||
|
assert result == "block_processed"
|
||||||
|
# Decorator should have observed the duration
|
||||||
|
assert True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_track_block_processing_failure(self):
|
||||||
|
"""Test track_block_processing decorator on exception"""
|
||||||
|
@track_block_processing
|
||||||
|
async def process_block():
|
||||||
|
raise ValueError("block error")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await process_block()
|
||||||
|
# Decorator should have observed the duration even on failure
|
||||||
|
assert True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_track_job_processing_success(self):
|
||||||
|
"""Test track_job_processing decorator on successful execution"""
|
||||||
|
@track_job_processing
|
||||||
|
async def process_job():
|
||||||
|
return "job_completed"
|
||||||
|
|
||||||
|
result = await process_job()
|
||||||
|
assert result == "job_completed"
|
||||||
|
# Decorator should have observed duration and incremented jobs_total
|
||||||
|
assert True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_track_job_processing_failure(self):
|
||||||
|
"""Test track_job_processing decorator on exception"""
|
||||||
|
@track_job_processing
|
||||||
|
async def process_job():
|
||||||
|
raise ValueError("job error")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await process_job()
|
||||||
|
# Decorator should have observed duration and incremented failure counters
|
||||||
|
assert True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_track_http_request_success(self):
|
||||||
|
"""Test track_http_request decorator on successful execution"""
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
|
||||||
|
@track_http_request
|
||||||
|
async def handle_request():
|
||||||
|
return mock_response
|
||||||
|
|
||||||
|
result = await handle_request()
|
||||||
|
assert result.status_code == 200
|
||||||
|
# Decorator should have observed duration and incremented http_requests_total
|
||||||
|
assert True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_track_http_request_failure(self):
|
||||||
|
"""Test track_http_request decorator on exception"""
|
||||||
|
@track_http_request
|
||||||
|
async def handle_request():
|
||||||
|
raise ValueError("request error")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await handle_request()
|
||||||
|
# Decorator should have observed duration and incremented http_requests_total with 500
|
||||||
|
assert True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_track_http_request_without_status_code(self):
|
||||||
|
"""Test track_http_request with response without status_code"""
|
||||||
|
@track_http_request
|
||||||
|
async def handle_request():
|
||||||
|
return "success" # No status_code attribute
|
||||||
|
|
||||||
|
result = await handle_request()
|
||||||
|
assert result == "success"
|
||||||
|
# Decorator should have observed duration but not incremented http_requests_total
|
||||||
|
assert True
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetricsApp:
|
||||||
|
"""Tests for metrics ASGI app"""
|
||||||
|
|
||||||
|
def test_metrics_app_exists(self):
|
||||||
|
"""Test metrics_app is created"""
|
||||||
|
assert metrics_app is not None
|
||||||
|
assert hasattr(metrics_app, '__call__')
|
||||||
266
tests/test_middleware.py
Normal file
266
tests/test_middleware.py
Normal file
@@ -0,0 +1,266 @@
|
|||||||
|
"""
|
||||||
|
Tests for AITBC middleware modules
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock, AsyncMock, patch
|
||||||
|
from fastapi import Request, Response, HTTPException
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from starlette.types import ASGIApp
|
||||||
|
|
||||||
|
from aitbc.middleware.performance import PerformanceLoggingMiddleware
|
||||||
|
from aitbc.middleware.request_id import RequestIDMiddleware
|
||||||
|
from aitbc.middleware.error_handler import ErrorHandlerMiddleware
|
||||||
|
|
||||||
|
|
||||||
|
class TestPerformanceLoggingMiddleware:
|
||||||
|
"""Tests for PerformanceLoggingMiddleware"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dispatch_adds_performance_header(self):
|
||||||
|
"""Test that middleware adds X-Process-Time header"""
|
||||||
|
app = Mock(spec=ASGIApp)
|
||||||
|
middleware = PerformanceLoggingMiddleware(app)
|
||||||
|
|
||||||
|
request = Mock(spec=Request)
|
||||||
|
request.method = "GET"
|
||||||
|
request.url = Mock()
|
||||||
|
request.url.path = "/test"
|
||||||
|
|
||||||
|
response = Mock(spec=Response)
|
||||||
|
response.status_code = 200
|
||||||
|
response.headers = {}
|
||||||
|
|
||||||
|
call_next = AsyncMock(return_value=response)
|
||||||
|
|
||||||
|
result = await middleware.dispatch(request, call_next)
|
||||||
|
|
||||||
|
assert "X-Process-Time" in result.headers
|
||||||
|
assert float(result.headers["X-Process-Time"]) >= 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dispatch_logs_performance_metrics(self):
|
||||||
|
"""Test that middleware logs performance metrics"""
|
||||||
|
app = Mock(spec=ASGIApp)
|
||||||
|
middleware = PerformanceLoggingMiddleware(app)
|
||||||
|
|
||||||
|
request = Mock(spec=Request)
|
||||||
|
request.method = "POST"
|
||||||
|
request.url = Mock()
|
||||||
|
request.url.path = "/api/test"
|
||||||
|
|
||||||
|
response = Mock(spec=Response)
|
||||||
|
response.status_code = 201
|
||||||
|
response.headers = {}
|
||||||
|
|
||||||
|
call_next = AsyncMock(return_value=response)
|
||||||
|
|
||||||
|
with patch('aitbc.middleware.performance.logger') as mock_logger:
|
||||||
|
await middleware.dispatch(request, call_next)
|
||||||
|
mock_logger.info.assert_called_once()
|
||||||
|
assert "Request performance" in mock_logger.info.call_args[0][0]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dispatch_measures_time_correctly(self):
|
||||||
|
"""Test that middleware measures request duration accurately"""
|
||||||
|
app = Mock(spec=ASGIApp)
|
||||||
|
middleware = PerformanceLoggingMiddleware(app)
|
||||||
|
|
||||||
|
request = Mock(spec=Request)
|
||||||
|
request.method = "GET"
|
||||||
|
request.url = Mock()
|
||||||
|
request.url.path = "/test"
|
||||||
|
|
||||||
|
response = Mock(spec=Response)
|
||||||
|
response.status_code = 200
|
||||||
|
response.headers = {}
|
||||||
|
|
||||||
|
call_next = AsyncMock(return_value=response)
|
||||||
|
|
||||||
|
result = await middleware.dispatch(request, call_next)
|
||||||
|
|
||||||
|
process_time = float(result.headers["X-Process-Time"])
|
||||||
|
assert 0 <= process_time < 1.0 # Should complete in under 1 second
|
||||||
|
|
||||||
|
|
||||||
|
class TestRequestIDMiddleware:
|
||||||
|
"""Tests for RequestIDMiddleware"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dispatch_generates_request_id_when_missing(self):
|
||||||
|
"""Test that middleware generates request ID when not in headers"""
|
||||||
|
app = Mock(spec=ASGIApp)
|
||||||
|
middleware = RequestIDMiddleware(app)
|
||||||
|
|
||||||
|
request = Mock(spec=Request)
|
||||||
|
request.headers = {}
|
||||||
|
request.method = "GET"
|
||||||
|
request.url = Mock()
|
||||||
|
request.url.path = "/test"
|
||||||
|
request.client = Mock()
|
||||||
|
request.client.host = "127.0.0.1"
|
||||||
|
request.state = Mock()
|
||||||
|
|
||||||
|
response = Mock(spec=Response)
|
||||||
|
response.headers = {}
|
||||||
|
response.status_code = 200
|
||||||
|
|
||||||
|
call_next = AsyncMock(return_value=response)
|
||||||
|
|
||||||
|
result = await middleware.dispatch(request, call_next)
|
||||||
|
|
||||||
|
assert "X-Request-ID" in result.headers
|
||||||
|
assert len(result.headers["X-Request-ID"]) > 0
|
||||||
|
assert request.state.request_id == result.headers["X-Request-ID"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dispatch_uses_existing_request_id_from_header(self):
|
||||||
|
"""Test that middleware uses existing request ID from header"""
|
||||||
|
app = Mock(spec=ASGIApp)
|
||||||
|
middleware = RequestIDMiddleware(app)
|
||||||
|
|
||||||
|
existing_id = "test-request-id-123"
|
||||||
|
request = Mock(spec=Request)
|
||||||
|
request.headers = {"X-Request-ID": existing_id}
|
||||||
|
request.method = "POST"
|
||||||
|
request.url = Mock()
|
||||||
|
request.url.path = "/api/test"
|
||||||
|
request.client = Mock()
|
||||||
|
request.client.host = "192.168.1.1"
|
||||||
|
request.state = Mock()
|
||||||
|
|
||||||
|
response = Mock(spec=Response)
|
||||||
|
response.headers = {}
|
||||||
|
response.status_code = 201
|
||||||
|
|
||||||
|
call_next = AsyncMock(return_value=response)
|
||||||
|
|
||||||
|
result = await middleware.dispatch(request, call_next)
|
||||||
|
|
||||||
|
assert result.headers["X-Request-ID"] == existing_id
|
||||||
|
assert request.state.request_id == existing_id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dispatch_logs_request_info(self):
|
||||||
|
"""Test that middleware logs request information"""
|
||||||
|
app = Mock(spec=ASGIApp)
|
||||||
|
middleware = RequestIDMiddleware(app)
|
||||||
|
|
||||||
|
request = Mock(spec=Request)
|
||||||
|
request.headers = {}
|
||||||
|
request.method = "GET"
|
||||||
|
request.url = Mock()
|
||||||
|
request.url.path = "/test"
|
||||||
|
request.client = Mock()
|
||||||
|
request.client.host = "127.0.0.1"
|
||||||
|
request.state = Mock()
|
||||||
|
|
||||||
|
response = Mock(spec=Response)
|
||||||
|
response.headers = {}
|
||||||
|
response.status_code = 200
|
||||||
|
|
||||||
|
call_next = AsyncMock(return_value=response)
|
||||||
|
|
||||||
|
with patch('aitbc.middleware.request_id.logger') as mock_logger:
|
||||||
|
await middleware.dispatch(request, call_next)
|
||||||
|
assert mock_logger.info.call_count >= 2 # Logs start and completion
|
||||||
|
|
||||||
|
|
||||||
|
class TestErrorHandlerMiddleware:
|
||||||
|
"""Tests for ErrorHandlerMiddleware"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dispatch_passes_through_normal_response(self):
|
||||||
|
"""Test that middleware passes through normal responses"""
|
||||||
|
app = Mock(spec=ASGIApp)
|
||||||
|
middleware = ErrorHandlerMiddleware(app)
|
||||||
|
|
||||||
|
request = Mock(spec=Request)
|
||||||
|
request.url = Mock()
|
||||||
|
request.url.path = "/test"
|
||||||
|
request.method = "GET"
|
||||||
|
|
||||||
|
response = Mock(spec=Response)
|
||||||
|
response.status_code = 200
|
||||||
|
|
||||||
|
call_next = AsyncMock(return_value=response)
|
||||||
|
|
||||||
|
result = await middleware.dispatch(request, call_next)
|
||||||
|
|
||||||
|
assert result == response
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dispatch_handles_http_exception(self):
|
||||||
|
"""Test that middleware handles HTTPException"""
|
||||||
|
app = Mock(spec=ASGIApp)
|
||||||
|
middleware = ErrorHandlerMiddleware(app)
|
||||||
|
|
||||||
|
request = Mock(spec=Request)
|
||||||
|
request.url = Mock()
|
||||||
|
request.url.path = "/api/error"
|
||||||
|
request.method = "GET"
|
||||||
|
|
||||||
|
exception = HTTPException(status_code=404, detail="Not found")
|
||||||
|
call_next = AsyncMock(side_effect=exception)
|
||||||
|
|
||||||
|
result = await middleware.dispatch(request, call_next)
|
||||||
|
|
||||||
|
assert isinstance(result, JSONResponse)
|
||||||
|
assert result.status_code == 404
|
||||||
|
content = result.body.decode() if hasattr(result, 'body') else {}
|
||||||
|
assert "error" in result.body.decode() if hasattr(result, 'body') else True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dispatch_handles_generic_exception(self):
|
||||||
|
"""Test that middleware handles generic exceptions"""
|
||||||
|
app = Mock(spec=ASGIApp)
|
||||||
|
middleware = ErrorHandlerMiddleware(app)
|
||||||
|
|
||||||
|
request = Mock(spec=Request)
|
||||||
|
request.url = Mock()
|
||||||
|
request.url.path = "/api/crash"
|
||||||
|
request.method = "POST"
|
||||||
|
|
||||||
|
exception = ValueError("Something went wrong")
|
||||||
|
call_next = AsyncMock(side_effect=exception)
|
||||||
|
|
||||||
|
result = await middleware.dispatch(request, call_next)
|
||||||
|
|
||||||
|
assert isinstance(result, JSONResponse)
|
||||||
|
assert result.status_code == 500
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dispatch_logs_http_exception(self):
|
||||||
|
"""Test that middleware logs HTTPException"""
|
||||||
|
app = Mock(spec=ASGIApp)
|
||||||
|
middleware = ErrorHandlerMiddleware(app)
|
||||||
|
|
||||||
|
request = Mock(spec=Request)
|
||||||
|
request.url = Mock()
|
||||||
|
request.url.path = "/test"
|
||||||
|
request.method = "GET"
|
||||||
|
|
||||||
|
exception = HTTPException(status_code=400, detail="Bad request")
|
||||||
|
call_next = AsyncMock(side_effect=exception)
|
||||||
|
|
||||||
|
with patch('aitbc.middleware.error_handler.logger') as mock_logger:
|
||||||
|
await middleware.dispatch(request, call_next)
|
||||||
|
mock_logger.warning.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dispatch_logs_generic_exception(self):
|
||||||
|
"""Test that middleware logs generic exceptions"""
|
||||||
|
app = Mock(spec=ASGIApp)
|
||||||
|
middleware = ErrorHandlerMiddleware(app)
|
||||||
|
|
||||||
|
request = Mock(spec=Request)
|
||||||
|
request.url = Mock()
|
||||||
|
request.url.path = "/test"
|
||||||
|
request.method = "GET"
|
||||||
|
|
||||||
|
exception = RuntimeError("Runtime error")
|
||||||
|
call_next = AsyncMock(side_effect=exception)
|
||||||
|
|
||||||
|
with patch('aitbc.middleware.error_handler.logger') as mock_logger:
|
||||||
|
await middleware.dispatch(request, call_next)
|
||||||
|
mock_logger.error.assert_called_once()
|
||||||
287
tests/test_security_headers.py
Normal file
287
tests/test_security_headers.py
Normal file
@@ -0,0 +1,287 @@
|
|||||||
|
"""
|
||||||
|
Tests for security headers and CORS utilities
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from aitbc.security_headers import (
|
||||||
|
SecurityHeaders,
|
||||||
|
CORSConfig,
|
||||||
|
SecurityHeadersMiddleware,
|
||||||
|
CORSMiddleware,
|
||||||
|
create_production_security_headers,
|
||||||
|
create_development_security_headers,
|
||||||
|
create_strict_cors_config,
|
||||||
|
create_permissive_cors_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSecurityHeaders:
|
||||||
|
"""Tests for SecurityHeaders dataclass"""
|
||||||
|
|
||||||
|
def test_default_security_headers(self):
|
||||||
|
"""Test default security headers values"""
|
||||||
|
headers = SecurityHeaders()
|
||||||
|
assert headers.X_Content_Type_Options == "nosniff"
|
||||||
|
assert headers.X_Frame_Options == "DENY"
|
||||||
|
assert headers.X_XSS_Protection == "1; mode=block"
|
||||||
|
assert headers.Strict_Transport_Security == "max-age=31536000; includeSubDomains"
|
||||||
|
assert headers.Content_Security_Policy == "default-src 'self'"
|
||||||
|
assert headers.Referrer_Policy == "strict-origin-when-cross-origin"
|
||||||
|
assert headers.Permissions_Policy == ""
|
||||||
|
assert headers.Cache_Control == "no-cache, no-store, must-revalidate"
|
||||||
|
assert headers.Pragma == "no-cache"
|
||||||
|
|
||||||
|
def test_custom_security_headers(self):
|
||||||
|
"""Test custom security headers values"""
|
||||||
|
headers = SecurityHeaders(
|
||||||
|
X_Frame_Options="SAMEORIGIN",
|
||||||
|
Content_Security_Policy="default-src 'self' https://example.com"
|
||||||
|
)
|
||||||
|
assert headers.X_Frame_Options == "SAMEORIGIN"
|
||||||
|
assert headers.Content_Security_Policy == "default-src 'self' https://example.com"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCORSConfig:
|
||||||
|
"""Tests for CORSConfig dataclass"""
|
||||||
|
|
||||||
|
def test_default_cors_config(self):
|
||||||
|
"""Test default CORS config values"""
|
||||||
|
config = CORSConfig(
|
||||||
|
allow_origins=["http://localhost:3000"],
|
||||||
|
allow_methods=["GET", "POST"],
|
||||||
|
allow_headers=["Content-Type"]
|
||||||
|
)
|
||||||
|
assert config.allow_origins == ["http://localhost:3000"]
|
||||||
|
assert config.allow_methods == ["GET", "POST"]
|
||||||
|
assert config.allow_credentials is False
|
||||||
|
assert config.expose_headers is None
|
||||||
|
assert config.max_age == 3600
|
||||||
|
|
||||||
|
def test_custom_cors_config(self):
|
||||||
|
"""Test custom CORS config values"""
|
||||||
|
config = CORSConfig(
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_methods=["GET", "POST", "PUT"],
|
||||||
|
allow_headers=["Content-Type"],
|
||||||
|
allow_credentials=True,
|
||||||
|
expose_headers=["X-Custom-Header"],
|
||||||
|
max_age=7200
|
||||||
|
)
|
||||||
|
assert config.allow_origins == ["*"]
|
||||||
|
assert config.allow_credentials is True
|
||||||
|
assert config.expose_headers == ["X-Custom-Header"]
|
||||||
|
assert config.max_age == 7200
|
||||||
|
|
||||||
|
|
||||||
|
class TestSecurityHeadersMiddleware:
|
||||||
|
"""Tests for SecurityHeadersMiddleware"""
|
||||||
|
|
||||||
|
def test_initialization_with_default_headers(self):
|
||||||
|
"""Test middleware initialization with default headers"""
|
||||||
|
middleware = SecurityHeadersMiddleware()
|
||||||
|
assert middleware.headers is not None
|
||||||
|
assert middleware.headers.X_Frame_Options == "DENY"
|
||||||
|
|
||||||
|
def test_initialization_with_custom_headers(self):
|
||||||
|
"""Test middleware initialization with custom headers"""
|
||||||
|
custom_headers = SecurityHeaders(X_Frame_Options="SAMEORIGIN")
|
||||||
|
middleware = SecurityHeadersMiddleware(custom_headers)
|
||||||
|
assert middleware.headers.X_Frame_Options == "SAMEORIGIN"
|
||||||
|
|
||||||
|
def test_get_headers(self):
|
||||||
|
"""Test get_headers returns dictionary"""
|
||||||
|
middleware = SecurityHeadersMiddleware()
|
||||||
|
headers = middleware.get_headers()
|
||||||
|
assert isinstance(headers, dict)
|
||||||
|
assert "X-Content-Type-Options" in headers
|
||||||
|
assert "X-Frame-Options" in headers
|
||||||
|
assert "X-XSS-Protection" in headers
|
||||||
|
assert "Strict-Transport-Security" in headers
|
||||||
|
assert "Content-Security-Policy" in headers
|
||||||
|
assert "Referrer-Policy" in headers
|
||||||
|
assert "Permissions-Policy" in headers
|
||||||
|
assert "Cache-Control" in headers
|
||||||
|
assert "Pragma" in headers
|
||||||
|
|
||||||
|
def test_get_headers_values(self):
|
||||||
|
"""Test get_headers returns correct values"""
|
||||||
|
middleware = SecurityHeadersMiddleware()
|
||||||
|
headers = middleware.get_headers()
|
||||||
|
assert headers["X-Content-Type-Options"] == "nosniff"
|
||||||
|
assert headers["X-Frame-Options"] == "DENY"
|
||||||
|
assert headers["X-XSS-Protection"] == "1; mode=block"
|
||||||
|
|
||||||
|
def test_apply_to_response(self):
|
||||||
|
"""Test apply_to_response adds security headers"""
|
||||||
|
middleware = SecurityHeadersMiddleware()
|
||||||
|
response_headers = {"Content-Type": "application/json"}
|
||||||
|
result = middleware.apply_to_response(response_headers)
|
||||||
|
|
||||||
|
assert "X-Content-Type-Options" in result
|
||||||
|
assert "X-Frame-Options" in result
|
||||||
|
assert result["Content-Type"] == "application/json"
|
||||||
|
|
||||||
|
def test_apply_to_response_overwrites_existing(self):
|
||||||
|
"""Test apply_to_response overwrites existing security headers"""
|
||||||
|
middleware = SecurityHeadersMiddleware()
|
||||||
|
response_headers = {"X-Frame-Options": "ALLOW"}
|
||||||
|
result = middleware.apply_to_response(response_headers)
|
||||||
|
|
||||||
|
assert result["X-Frame-Options"] == "DENY"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCORSMiddleware:
|
||||||
|
"""Tests for CORSMiddleware"""
|
||||||
|
|
||||||
|
def test_initialization(self):
|
||||||
|
"""Test CORS middleware initialization"""
|
||||||
|
config = CORSConfig(
|
||||||
|
allow_origins=["http://localhost:3000"],
|
||||||
|
allow_methods=["GET", "POST"],
|
||||||
|
allow_headers=["Content-Type"]
|
||||||
|
)
|
||||||
|
middleware = CORSMiddleware(config)
|
||||||
|
assert middleware.config == config
|
||||||
|
|
||||||
|
def test_get_cors_headers_allowed_origin(self):
|
||||||
|
"""Test get_cors_headers with allowed origin"""
|
||||||
|
config = CORSConfig(
|
||||||
|
allow_origins=["http://localhost:3000"],
|
||||||
|
allow_methods=["GET", "POST"],
|
||||||
|
allow_headers=["Content-Type"]
|
||||||
|
)
|
||||||
|
middleware = CORSMiddleware(config)
|
||||||
|
headers = middleware.get_cors_headers("http://localhost:3000")
|
||||||
|
|
||||||
|
assert headers["Access-Control-Allow-Origin"] == "http://localhost:3000"
|
||||||
|
assert headers["Access-Control-Allow-Methods"] == "GET, POST"
|
||||||
|
assert headers["Access-Control-Allow-Headers"] == "Content-Type"
|
||||||
|
assert headers["Access-Control-Max-Age"] == "3600"
|
||||||
|
|
||||||
|
def test_get_cors_headers_disallowed_origin(self):
|
||||||
|
"""Test get_cors_headers with disallowed origin"""
|
||||||
|
config = CORSConfig(
|
||||||
|
allow_origins=["http://localhost:3000"],
|
||||||
|
allow_methods=["GET", "POST"],
|
||||||
|
allow_headers=["Content-Type"]
|
||||||
|
)
|
||||||
|
middleware = CORSMiddleware(config)
|
||||||
|
headers = middleware.get_cors_headers("http://evil.com")
|
||||||
|
|
||||||
|
assert headers == {}
|
||||||
|
|
||||||
|
def test_get_cors_headers_wildcard_origin(self):
|
||||||
|
"""Test get_cors_headers with wildcard origin"""
|
||||||
|
config = CORSConfig(
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_methods=["GET", "POST"],
|
||||||
|
allow_headers=["Content-Type"]
|
||||||
|
)
|
||||||
|
middleware = CORSMiddleware(config)
|
||||||
|
headers = middleware.get_cors_headers("http://any-origin.com")
|
||||||
|
|
||||||
|
assert headers["Access-Control-Allow-Origin"] == "http://any-origin.com"
|
||||||
|
|
||||||
|
def test_get_cors_headers_with_credentials(self):
|
||||||
|
"""Test get_cors_headers with credentials enabled"""
|
||||||
|
config = CORSConfig(
|
||||||
|
allow_origins=["http://localhost:3000"],
|
||||||
|
allow_methods=["GET", "POST"],
|
||||||
|
allow_headers=["Content-Type"],
|
||||||
|
allow_credentials=True
|
||||||
|
)
|
||||||
|
middleware = CORSMiddleware(config)
|
||||||
|
headers = middleware.get_cors_headers("http://localhost:3000")
|
||||||
|
|
||||||
|
assert headers["Access-Control-Allow-Credentials"] == "true"
|
||||||
|
|
||||||
|
def test_get_cors_headers_with_expose_headers(self):
|
||||||
|
"""Test get_cors_headers with expose headers"""
|
||||||
|
config = CORSConfig(
|
||||||
|
allow_origins=["http://localhost:3000"],
|
||||||
|
allow_methods=["GET", "POST"],
|
||||||
|
allow_headers=["Content-Type"],
|
||||||
|
expose_headers=["X-Request-ID"]
|
||||||
|
)
|
||||||
|
middleware = CORSMiddleware(config)
|
||||||
|
headers = middleware.get_cors_headers("http://localhost:3000")
|
||||||
|
|
||||||
|
assert headers["Access-Control-Expose-Headers"] == "X-Request-ID"
|
||||||
|
|
||||||
|
def test_is_origin_allowed_wildcard(self):
|
||||||
|
"""Test _is_origin_allowed with wildcard"""
|
||||||
|
config = CORSConfig(
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_methods=["GET"],
|
||||||
|
allow_headers=["Content-Type"]
|
||||||
|
)
|
||||||
|
middleware = CORSMiddleware(config)
|
||||||
|
assert middleware._is_origin_allowed("http://any-origin.com") is True
|
||||||
|
|
||||||
|
def test_is_origin_allowed_specific(self):
|
||||||
|
"""Test _is_origin_allowed with specific origin"""
|
||||||
|
config = CORSConfig(
|
||||||
|
allow_origins=["http://localhost:3000"],
|
||||||
|
allow_methods=["GET"],
|
||||||
|
allow_headers=["Content-Type"]
|
||||||
|
)
|
||||||
|
middleware = CORSMiddleware(config)
|
||||||
|
assert middleware._is_origin_allowed("http://localhost:3000") is True
|
||||||
|
assert middleware._is_origin_allowed("http://evil.com") is False
|
||||||
|
|
||||||
|
def test_is_preflight_request(self):
|
||||||
|
"""Test is_preflight_request"""
|
||||||
|
config = CORSConfig(
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_methods=["GET"],
|
||||||
|
allow_headers=["Content-Type"]
|
||||||
|
)
|
||||||
|
middleware = CORSMiddleware(config)
|
||||||
|
assert middleware.is_preflight_request("OPTIONS") is True
|
||||||
|
assert middleware.is_preflight_request("GET") is False
|
||||||
|
assert middleware.is_preflight_request("options") is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestFactoryFunctions:
|
||||||
|
"""Tests for factory functions"""
|
||||||
|
|
||||||
|
def test_create_production_security_headers(self):
|
||||||
|
"""Test create_production_security_headers"""
|
||||||
|
headers = create_production_security_headers()
|
||||||
|
assert headers.X_Frame_Options == "DENY"
|
||||||
|
assert "preload" in headers.Strict_Transport_Security
|
||||||
|
assert "unsafe-inline" in headers.Content_Security_Policy
|
||||||
|
assert "geolocation=()" in headers.Permissions_Policy
|
||||||
|
|
||||||
|
def test_create_development_security_headers(self):
|
||||||
|
"""Test create_development_security_headers"""
|
||||||
|
headers = create_development_security_headers()
|
||||||
|
assert headers.X_Frame_Options == "SAMEORIGIN"
|
||||||
|
assert headers.Strict_Transport_Security == "max-age=3600"
|
||||||
|
assert headers.Permissions_Policy == ""
|
||||||
|
assert headers.Cache_Control == "no-cache"
|
||||||
|
|
||||||
|
def test_create_strict_cors_config(self):
|
||||||
|
"""Test create_strict_cors_config"""
|
||||||
|
config = create_strict_cors_config(["http://localhost:3000"])
|
||||||
|
assert "http://localhost:3000" in config.allow_origins
|
||||||
|
assert "GET" in config.allow_methods
|
||||||
|
assert "POST" in config.allow_methods
|
||||||
|
assert "PUT" in config.allow_methods
|
||||||
|
assert "DELETE" in config.allow_methods
|
||||||
|
assert "PATCH" in config.allow_methods
|
||||||
|
assert config.allow_credentials is True
|
||||||
|
assert "X-Request-ID" in config.expose_headers
|
||||||
|
assert config.max_age == 3600
|
||||||
|
|
||||||
|
def test_create_permissive_cors_config(self):
|
||||||
|
"""Test create_permissive_cors_config"""
|
||||||
|
config = create_permissive_cors_config()
|
||||||
|
assert "*" in config.allow_origins
|
||||||
|
assert "GET" in config.allow_methods
|
||||||
|
assert "POST" in config.allow_methods
|
||||||
|
assert "*" in config.allow_headers
|
||||||
|
assert config.allow_credentials is False
|
||||||
|
assert "*" in config.expose_headers
|
||||||
|
assert config.max_age == 86400
|
||||||
345
tests/test_utils.py
Normal file
345
tests/test_utils.py
Normal file
@@ -0,0 +1,345 @@
|
|||||||
|
"""
|
||||||
|
Tests for AITBC utility modules
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch, Mock
|
||||||
|
|
||||||
|
from aitbc.utils.paths import (
|
||||||
|
get_data_path,
|
||||||
|
get_config_path,
|
||||||
|
get_log_path,
|
||||||
|
get_repo_path,
|
||||||
|
ensure_dir,
|
||||||
|
ensure_file_dir,
|
||||||
|
resolve_path,
|
||||||
|
get_keystore_path,
|
||||||
|
get_blockchain_data_path,
|
||||||
|
get_marketplace_data_path,
|
||||||
|
)
|
||||||
|
from aitbc.utils.env import (
|
||||||
|
get_env_var,
|
||||||
|
get_required_env_var,
|
||||||
|
get_bool_env_var,
|
||||||
|
get_int_env_var,
|
||||||
|
get_float_env_var,
|
||||||
|
get_list_env_var,
|
||||||
|
)
|
||||||
|
from aitbc.utils.json_utils import (
|
||||||
|
load_json,
|
||||||
|
save_json,
|
||||||
|
merge_json,
|
||||||
|
json_to_string,
|
||||||
|
string_to_json,
|
||||||
|
get_nested_value,
|
||||||
|
set_nested_value,
|
||||||
|
flatten_json,
|
||||||
|
)
|
||||||
|
from aitbc.exceptions import ConfigurationError
|
||||||
|
|
||||||
|
|
||||||
|
class TestPaths:
|
||||||
|
"""Tests for path utility functions"""
|
||||||
|
|
||||||
|
def test_get_data_path_no_subpath(self):
|
||||||
|
"""Test get_data_path without subpath"""
|
||||||
|
result = get_data_path()
|
||||||
|
assert isinstance(result, Path)
|
||||||
|
|
||||||
|
def test_get_data_path_with_subpath(self):
|
||||||
|
"""Test get_data_path with subpath"""
|
||||||
|
result = get_data_path("test")
|
||||||
|
assert isinstance(result, Path)
|
||||||
|
assert str(result).endswith("test")
|
||||||
|
|
||||||
|
def test_get_config_path(self):
|
||||||
|
"""Test get_config_path"""
|
||||||
|
result = get_config_path("config.yaml")
|
||||||
|
assert isinstance(result, Path)
|
||||||
|
assert str(result).endswith("config.yaml")
|
||||||
|
|
||||||
|
def test_get_log_path(self):
|
||||||
|
"""Test get_log_path"""
|
||||||
|
result = get_log_path("app.log")
|
||||||
|
assert isinstance(result, Path)
|
||||||
|
assert str(result).endswith("app.log")
|
||||||
|
|
||||||
|
def test_get_repo_path_no_subpath(self):
|
||||||
|
"""Test get_repo_path without subpath"""
|
||||||
|
result = get_repo_path()
|
||||||
|
assert isinstance(result, Path)
|
||||||
|
|
||||||
|
def test_get_repo_path_with_subpath(self):
|
||||||
|
"""Test get_repo_path with subpath"""
|
||||||
|
result = get_repo_path("src")
|
||||||
|
assert isinstance(result, Path)
|
||||||
|
assert str(result).endswith("src")
|
||||||
|
|
||||||
|
def test_ensure_dir(self):
|
||||||
|
"""Test ensure_dir creates directory"""
|
||||||
|
import tempfile
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
test_path = Path(tmpdir) / "test" / "nested"
|
||||||
|
result = ensure_dir(test_path)
|
||||||
|
assert result.exists()
|
||||||
|
assert result.is_dir()
|
||||||
|
|
||||||
|
def test_ensure_file_dir(self):
|
||||||
|
"""Test ensure_file_dir creates parent directory"""
|
||||||
|
import tempfile
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
test_path = Path(tmpdir) / "test" / "nested" / "file.txt"
|
||||||
|
result = ensure_file_dir(test_path)
|
||||||
|
assert result.exists()
|
||||||
|
assert result.is_dir()
|
||||||
|
|
||||||
|
def test_resolve_path_absolute(self):
|
||||||
|
"""Test resolve_path with absolute path"""
|
||||||
|
result = resolve_path("/tmp/test")
|
||||||
|
assert result.is_absolute()
|
||||||
|
|
||||||
|
def test_resolve_path_relative(self):
|
||||||
|
"""Test resolve_path with relative path"""
|
||||||
|
result = resolve_path("test")
|
||||||
|
assert result.is_absolute() # Relative paths are resolved to absolute
|
||||||
|
|
||||||
|
def test_get_keystore_path_no_wallet(self):
|
||||||
|
"""Test get_keystore_path without wallet name"""
|
||||||
|
result = get_keystore_path()
|
||||||
|
assert isinstance(result, Path)
|
||||||
|
assert str(result).endswith("keystore")
|
||||||
|
|
||||||
|
def test_get_keystore_path_with_wallet(self):
|
||||||
|
"""Test get_keystore_path with wallet name"""
|
||||||
|
result = get_keystore_path("mywallet")
|
||||||
|
assert isinstance(result, Path)
|
||||||
|
assert str(result).endswith("mywallet.json")
|
||||||
|
|
||||||
|
def test_get_blockchain_data_path_default(self):
|
||||||
|
"""Test get_blockchain_data_path with default chain"""
|
||||||
|
result = get_blockchain_data_path()
|
||||||
|
assert isinstance(result, Path)
|
||||||
|
assert str(result).endswith("ait-mainnet")
|
||||||
|
|
||||||
|
def test_get_blockchain_data_path_custom(self):
|
||||||
|
"""Test get_blockchain_data_path with custom chain"""
|
||||||
|
result = get_blockchain_data_path("custom-chain")
|
||||||
|
assert isinstance(result, Path)
|
||||||
|
assert str(result).endswith("custom-chain")
|
||||||
|
|
||||||
|
def test_get_marketplace_data_path_no_subpath(self):
|
||||||
|
"""Test get_marketplace_data_path without subpath"""
|
||||||
|
result = get_marketplace_data_path()
|
||||||
|
assert isinstance(result, Path)
|
||||||
|
assert str(result).endswith("marketplace")
|
||||||
|
|
||||||
|
def test_get_marketplace_data_path_with_subpath(self):
|
||||||
|
"""Test get_marketplace_data_path with subpath"""
|
||||||
|
result = get_marketplace_data_path("orders")
|
||||||
|
assert isinstance(result, Path)
|
||||||
|
assert str(result).endswith("orders")
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnv:
|
||||||
|
"""Tests for environment variable utilities"""
|
||||||
|
|
||||||
|
def test_get_env_var_with_value(self):
|
||||||
|
"""Test get_env_var with set value"""
|
||||||
|
os.environ["TEST_VAR"] = "test_value"
|
||||||
|
result = get_env_var("TEST_VAR")
|
||||||
|
assert result == "test_value"
|
||||||
|
del os.environ["TEST_VAR"]
|
||||||
|
|
||||||
|
def test_get_env_var_with_default(self):
|
||||||
|
"""Test get_env_var with default value"""
|
||||||
|
result = get_env_var("NONEXISTENT_VAR", "default")
|
||||||
|
assert result == "default"
|
||||||
|
|
||||||
|
def test_get_required_env_var_with_value(self):
|
||||||
|
"""Test get_required_env_var with set value"""
|
||||||
|
os.environ["TEST_VAR"] = "test_value"
|
||||||
|
result = get_required_env_var("TEST_VAR")
|
||||||
|
assert result == "test_value"
|
||||||
|
del os.environ["TEST_VAR"]
|
||||||
|
|
||||||
|
def test_get_required_env_var_without_value(self):
|
||||||
|
"""Test get_required_env_var without value raises error"""
|
||||||
|
with pytest.raises(ConfigurationError):
|
||||||
|
get_required_env_var("NONEXISTENT_VAR")
|
||||||
|
|
||||||
|
def test_get_bool_env_var_true(self):
|
||||||
|
"""Test get_bool_env_var with true values"""
|
||||||
|
for value in ["true", "TRUE", "1", "yes", "YES", "on", "ON"]:
|
||||||
|
os.environ["TEST_VAR"] = value
|
||||||
|
assert get_bool_env_var("TEST_VAR") is True
|
||||||
|
del os.environ["TEST_VAR"]
|
||||||
|
|
||||||
|
def test_get_bool_env_var_false(self):
|
||||||
|
"""Test get_bool_env_var with false values"""
|
||||||
|
for value in ["false", "FALSE", "0", "no", "NO", "off", "OFF"]:
|
||||||
|
os.environ["TEST_VAR"] = value
|
||||||
|
assert get_bool_env_var("TEST_VAR") is False
|
||||||
|
del os.environ["TEST_VAR"]
|
||||||
|
|
||||||
|
def test_get_bool_env_var_default(self):
|
||||||
|
"""Test get_bool_env_var with default"""
|
||||||
|
assert get_bool_env_var("NONEXISTENT_VAR", True) is True
|
||||||
|
assert get_bool_env_var("NONEXISTENT_VAR", False) is False
|
||||||
|
|
||||||
|
def test_get_int_env_var_valid(self):
|
||||||
|
"""Test get_int_env_var with valid integer"""
|
||||||
|
os.environ["TEST_VAR"] = "42"
|
||||||
|
assert get_int_env_var("TEST_VAR") == 42
|
||||||
|
del os.environ["TEST_VAR"]
|
||||||
|
|
||||||
|
def test_get_int_env_var_invalid(self):
|
||||||
|
"""Test get_int_env_var with invalid value returns default"""
|
||||||
|
os.environ["TEST_VAR"] = "not_a_number"
|
||||||
|
assert get_int_env_var("TEST_VAR", 10) == 10
|
||||||
|
del os.environ["TEST_VAR"]
|
||||||
|
|
||||||
|
def test_get_int_env_var_default(self):
|
||||||
|
"""Test get_int_env_var with default"""
|
||||||
|
assert get_int_env_var("NONEXISTENT_VAR", 100) == 100
|
||||||
|
|
||||||
|
def test_get_float_env_var_valid(self):
|
||||||
|
"""Test get_float_env_var with valid float"""
|
||||||
|
os.environ["TEST_VAR"] = "3.14"
|
||||||
|
assert get_float_env_var("TEST_VAR") == 3.14
|
||||||
|
del os.environ["TEST_VAR"]
|
||||||
|
|
||||||
|
def test_get_float_env_var_invalid(self):
|
||||||
|
"""Test get_float_env_var with invalid value returns default"""
|
||||||
|
os.environ["TEST_VAR"] = "not_a_number"
|
||||||
|
assert get_float_env_var("TEST_VAR", 2.5) == 2.5
|
||||||
|
del os.environ["TEST_VAR"]
|
||||||
|
|
||||||
|
def test_get_float_env_var_default(self):
|
||||||
|
"""Test get_float_env_var with default"""
|
||||||
|
assert get_float_env_var("NONEXISTENT_VAR", 1.5) == 1.5
|
||||||
|
|
||||||
|
def test_get_list_env_var_valid(self):
|
||||||
|
"""Test get_list_env_var with valid list"""
|
||||||
|
os.environ["TEST_VAR"] = "item1,item2,item3"
|
||||||
|
result = get_list_env_var("TEST_VAR")
|
||||||
|
assert result == ["item1", "item2", "item3"]
|
||||||
|
del os.environ["TEST_VAR"]
|
||||||
|
|
||||||
|
def test_get_list_env_var_custom_separator(self):
|
||||||
|
"""Test get_list_env_var with custom separator"""
|
||||||
|
os.environ["TEST_VAR"] = "item1;item2;item3"
|
||||||
|
result = get_list_env_var("TEST_VAR", separator=";")
|
||||||
|
assert result == ["item1", "item2", "item3"]
|
||||||
|
del os.environ["TEST_VAR"]
|
||||||
|
|
||||||
|
def test_get_list_env_var_empty(self):
|
||||||
|
"""Test get_list_env_var with empty value returns default"""
|
||||||
|
os.environ["TEST_VAR"] = ""
|
||||||
|
result = get_list_env_var("TEST_VAR", default=["default"])
|
||||||
|
assert result == ["default"]
|
||||||
|
del os.environ["TEST_VAR"]
|
||||||
|
|
||||||
|
def test_get_list_env_var_default(self):
|
||||||
|
"""Test get_list_env_var with default"""
|
||||||
|
result = get_list_env_var("NONEXISTENT_VAR", default=["a", "b"])
|
||||||
|
assert result == ["a", "b"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestJsonUtils:
|
||||||
|
"""Tests for JSON utility functions"""
|
||||||
|
|
||||||
|
def test_json_to_string(self):
|
||||||
|
"""Test json_to_string"""
|
||||||
|
data = {"key": "value"}
|
||||||
|
result = json_to_string(data)
|
||||||
|
assert '"key"' in result
|
||||||
|
assert '"value"' in result
|
||||||
|
|
||||||
|
def test_string_to_json_valid(self):
|
||||||
|
"""Test string_to_json with valid JSON"""
|
||||||
|
json_str = '{"key": "value"}'
|
||||||
|
result = string_to_json(json_str)
|
||||||
|
assert result == {"key": "value"}
|
||||||
|
|
||||||
|
def test_string_to_json_invalid(self):
|
||||||
|
"""Test string_to_json with invalid JSON raises error"""
|
||||||
|
with pytest.raises(ConfigurationError):
|
||||||
|
string_to_json("not valid json")
|
||||||
|
|
||||||
|
def test_get_nested_value_found(self):
|
||||||
|
"""Test get_nested_value when key exists"""
|
||||||
|
data = {"a": {"b": {"c": "value"}}}
|
||||||
|
result = get_nested_value(data, "a", "b", "c")
|
||||||
|
assert result == "value"
|
||||||
|
|
||||||
|
def test_get_nested_value_not_found(self):
|
||||||
|
"""Test get_nested_value when key doesn't exist returns default"""
|
||||||
|
data = {"a": {"b": {"c": "value"}}}
|
||||||
|
result = get_nested_value(data, "a", "b", "d", default="default")
|
||||||
|
assert result == "default"
|
||||||
|
|
||||||
|
def test_get_nested_value_default_none(self):
|
||||||
|
"""Test get_nested_value with default None"""
|
||||||
|
data = {"a": {"b": {"c": "value"}}}
|
||||||
|
result = get_nested_value(data, "x", "y", "z")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_set_nested_value(self):
|
||||||
|
"""Test set_nested_value"""
|
||||||
|
data = {}
|
||||||
|
set_nested_value(data, "a", "b", "c", value="test")
|
||||||
|
assert data["a"]["b"]["c"] == "test"
|
||||||
|
|
||||||
|
def test_flatten_json(self):
|
||||||
|
"""Test flatten_json"""
|
||||||
|
data = {"a": {"b": {"c": "value"}}, "d": "simple"}
|
||||||
|
result = flatten_json(data)
|
||||||
|
assert "a.b.c" in result
|
||||||
|
assert result["a.b.c"] == "value"
|
||||||
|
assert result["d"] == "simple"
|
||||||
|
|
||||||
|
def test_flatten_json_custom_separator(self):
|
||||||
|
"""Test flatten_json with custom separator"""
|
||||||
|
data = {"a": {"b": "value"}}
|
||||||
|
result = flatten_json(data, separator="_")
|
||||||
|
assert "a_b" in result
|
||||||
|
assert result["a_b"] == "value"
|
||||||
|
|
||||||
|
def test_load_json(self, tmp_path):
|
||||||
|
"""Test load_json"""
|
||||||
|
test_file = tmp_path / "test.json"
|
||||||
|
test_file.write_text('{"key": "value"}')
|
||||||
|
result = load_json(test_file)
|
||||||
|
assert result == {"key": "value"}
|
||||||
|
|
||||||
|
def test_load_json_not_found(self, tmp_path):
|
||||||
|
"""Test load_json with non-existent file raises error"""
|
||||||
|
with pytest.raises(ConfigurationError):
|
||||||
|
load_json(tmp_path / "nonexistent.json")
|
||||||
|
|
||||||
|
def test_load_json_invalid(self, tmp_path):
|
||||||
|
"""Test load_json with invalid JSON raises error"""
|
||||||
|
test_file = tmp_path / "invalid.json"
|
||||||
|
test_file.write_text("not valid json")
|
||||||
|
with pytest.raises(ConfigurationError):
|
||||||
|
load_json(test_file)
|
||||||
|
|
||||||
|
def test_save_json(self, tmp_path):
|
||||||
|
"""Test save_json"""
|
||||||
|
test_file = tmp_path / "test.json"
|
||||||
|
data = {"key": "value"}
|
||||||
|
save_json(data, test_file)
|
||||||
|
assert test_file.exists()
|
||||||
|
result = load_json(test_file)
|
||||||
|
assert result == data
|
||||||
|
|
||||||
|
def test_merge_json(self, tmp_path):
|
||||||
|
"""Test merge_json"""
|
||||||
|
file1 = tmp_path / "file1.json"
|
||||||
|
file2 = tmp_path / "file2.json"
|
||||||
|
file1.write_text('{"a": 1, "b": 2}')
|
||||||
|
file2.write_text('{"b": 3, "c": 4}')
|
||||||
|
result = merge_json(file1, file2)
|
||||||
|
assert result == {"a": 1, "b": 3, "c": 4}
|
||||||
Reference in New Issue
Block a user