refactor: reorganize services into bounded contexts and implement async database support
Some checks failed
Deploy to Testnet / deploy-testnet (push) Has been cancelled
Documentation Validation / validate-docs (push) Has been cancelled
Documentation Validation / validate-policies-strict (push) Has been cancelled
Integration Tests / test-service-integration (push) Has been cancelled
Python Tests / test-python (push) Has been cancelled
Security Scanning / security-scan (push) Has been cancelled
API Endpoint Tests / test-api-endpoints (push) Has been cancelled
CLI Tests / test-cli (push) Has been cancelled
Package Tests / Python package - aitbc-agent-sdk (push) Has been cancelled
Package Tests / Python package - aitbc-core (push) Has been cancelled
Package Tests / Python package - aitbc-crypto (push) Has been cancelled
Package Tests / Python package - aitbc-sdk (push) Has been cancelled
Package Tests / JavaScript package - aitbc-sdk-js (push) Has been cancelled
Package Tests / JavaScript package - aitbc-token (push) Has been cancelled
Staking Tests / test-staking-service (push) Failing after 3s
Staking Tests / test-staking-integration (push) Has been skipped
Staking Tests / test-staking-contract (push) Has been skipped
Staking Tests / run-staking-test-runner (push) Has been skipped
Multi-Node Stress Testing / stress-test (push) Successful in 3s
Cross-Node Transaction Testing / transaction-test (push) Successful in 3s

- Moved services to bounded context packages:
  - adaptive_learning.py → ai_analytics/adaptive_learning.py
  - analytics_service.py → ai_analytics/analytics.py
  - dynamic_pricing_engine.py → trading_marketplace/dynamic_pricing.py
  - trading_service.py → trading_marketplace/trading.py
- Implemented async database module (database_async.py):
  - Added async SQLAlchemy engine with connection pooling
  - Added
This commit is contained in:
aitbc
2026-05-12 18:10:58 +02:00
parent 6895770510
commit c87806b68b
31 changed files with 2027 additions and 3293 deletions

View File

@@ -1 +1,117 @@
[{'adapter}': 'def _build_async_url(url: str', 'driver': 'str) -> str:', 'Convert sync URL to async URL."': 'sqlite:///path.db -> sqlite+aiosqlite:///path.db\n # postgresql:// -> postgresql+asyncpg://\n parts = url.split(', ', 1)\n if len(parts) == 2:\n return f': 'parts[0]'}, {'driver}': {'f': 'url'}, 'sqlite': 'For SQLite', 'connect_args={"check_same_thread': False}, {}] """Async database module with connection pooling for Coordinator API."""
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from .config import settings
import logging
logger = logging.getLogger(__name__)
# Global async engine and session factory
_async_engine = None
_async_session_factory = None
def _build_async_url(url: str) -> str:
"""Convert sync database URL to async URL.
Examples:
sqlite:///path.db -> sqlite+aiosqlite:///path.db
postgresql://user:pass@host/db -> postgresql+asyncpg://user:pass@host/db
"""
# Handle URL with query parameters
if '?' in url:
base, params = url.split('?', 1)
if base.startswith('sqlite:'):
return f"{base}+aiosqlite://?{params}"
elif base.startswith('postgresql:'):
return f"{base}+asyncpg://?{params}"
else:
return f"{base}+aiosqlite://?{params}" # fallback
else:
if url.startswith('sqlite:'):
return url.replace('sqlite:', 'sqlite+aiosqlite:')
elif url.startswith('postgresql:'):
return url.replace('postgresql:', 'postgresql+asyncpg:')
else:
return url.replace(':', '+aiosqlite:') # fallback
def init_async_db() -> None:
"""Initialize async database engine and session factory."""
global _async_engine, _async_session_factory
if _async_engine is not None:
logger.warning("Async database already initialized")
return
try:
# Build async URL from sync settings
sync_url = str(settings.database.url)
async_url = _build_async_url(sync_url)
logger.info(f"Initializing async database connection: {async_url.split('://')[0]}://...")
# Create async engine with pooling
_async_engine = create_async_engine(
async_url,
echo=settings.database.echo if hasattr(settings.database, 'echo') else False,
pool_size=getattr(settings.database, 'pool_size', 5),
max_overflow=getattr(settings.database, 'max_overflow', 10),
pool_pre_ping=getattr(settings.database, 'pool_pre_ping', True),
pool_recycle=getattr(settings.database, 'pool_recycle', 3600),
)
# Create session factory
_async_session_factory = sessionmaker(
_async_engine,
class_=AsyncSession,
expire_on_commit=False,
)
logger.info("Async database initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize async database: {e}")
raise
def get_async_db() -> AsyncSession:
"""Dependency to get async database session.
Yields:
AsyncSession: Database session that closes automatically
"""
if _async_session_factory is None:
raise RuntimeError("Async database not initialized. Call init_async_db() first.")
async def _get_async_db() -> AsyncSession:
async with _async_session_factory() as session:
try:
yield session
finally:
await session.close()
return _get_async_db()
def get_sync_engine():
"""Get synchronous engine for backward compatibility.
Returns:
Engine: SQLAlchemy synchronous engine
"""
from .database import engine
return engine
async def close_async_db() -> None:
"""Close async database connections."""
global _async_engine, _async_session_factory
if _async_engine is not None:
logger.info("Closing async database connections...")
await _async_engine.dispose()
_async_engine = None
_async_session_factory = None
logger.info("Async database connections closed")

View File

@@ -114,6 +114,7 @@ logger = get_logger(__name__)
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from .storage.db import init_db from .storage.db import init_db
from .database_async import init_async_db, close_async_db
@asynccontextmanager @asynccontextmanager
@@ -130,6 +131,14 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]:
logger.warning(f"Database initialization failed (non-fatal): {e}") logger.warning(f"Database initialization failed (non-fatal): {e}")
# Continue startup even if init_db fails # Continue startup even if init_db fails
# Initialize async database
try:
init_async_db()
logger.info("Async database initialized successfully")
except Exception as e:
logger.warning(f"Async database initialization failed (non-fatal): {e}")
# Continue startup even if async init fails
# Warmup database connections # Warmup database connections
logger.info("Warming up database connections...") logger.info("Warming up database connections...")
try: try:
@@ -227,6 +236,13 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]:
except Exception as e: except Exception as e:
logger.warning(f"Error closing database connections: {e}") logger.warning(f"Error closing database connections: {e}")
# Close async database connections
try:
await close_async_db()
logger.info("Async database connections closed successfully")
except Exception as e:
logger.warning(f"Error closing async database connections: {e}")
# Cleanup rate limiting state # Cleanup rate limiting state
logger.info("Cleaning up rate limiting state...") logger.info("Cleaning up rate limiting state...")

View File

@@ -14,7 +14,7 @@ from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from aitbc import get_logger from aitbc import get_logger
from ..services.adaptive_learning import AdaptiveLearningService from ..services.ai_analytics.adaptive_learning import AdaptiveLearningService
from ..storage import get_session from ..storage import get_session
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -27,7 +27,7 @@ from ..domain.analytics import (
MetricType, MetricType,
ReportType, ReportType,
) )
from ..services.analytics_service import MarketplaceAnalytics from ..services.ai_analytics.analytics import MarketplaceAnalytics
from ..storage import get_session from ..storage import get_session
router = APIRouter(prefix="/v1/analytics", tags=["analytics"]) router = APIRouter(prefix="/v1/analytics", tags=["analytics"])

View File

@@ -23,7 +23,7 @@ from ..schemas.pricing import (
PricingStrategyRequest, PricingStrategyRequest,
PricingStrategyResponse, PricingStrategyResponse,
) )
from ..services.dynamic_pricing_engine import DynamicPricingEngine, PriceConstraints, PricingStrategy, ResourceType from ..services.trading_marketplace.dynamic_pricing import DynamicPricingEngine, PriceConstraints, PricingStrategy, ResourceType
from ..services.market_data_collector import MarketDataCollector from ..services.market_data_collector import MarketDataCollector
router = APIRouter(prefix="/v1/pricing", tags=["dynamic-pricing"]) router = APIRouter(prefix="/v1/pricing", tags=["dynamic-pricing"])

View File

@@ -28,7 +28,7 @@ from ..domain.trading import (
TradeStatus, TradeStatus,
TradeType, TradeType,
) )
from ..services.trading_service import P2PTradingProtocol from ..services.trading_marketplace.trading import P2PTradingProtocol
from ..storage import get_session from ..storage import get_session
router = APIRouter(prefix="/v1/trading", tags=["trading"]) router = APIRouter(prefix="/v1/trading", tags=["trading"])

View File

@@ -0,0 +1,14 @@
"""
AI & Analytics Bounded Context
Provides analytics and advanced learning services.
"""
from .advanced_learning import AdvancedLearningService
from .analytics import AnalyticsEngine, DashboardManager, MarketplaceAnalytics
__all__ = [
"AdvancedLearningService",
"AnalyticsEngine",
"DashboardManager",
"MarketplaceAnalytics",
]

View File

@@ -17,7 +17,7 @@ from typing import Any
import numpy as np import numpy as np
from ..storage import get_session from ...storage import get_session
class LearningAlgorithm(StrEnum): class LearningAlgorithm(StrEnum):

View File

@@ -13,7 +13,7 @@ logger = get_logger(__name__)
from sqlmodel import Session, and_, select from sqlmodel import Session, and_, select
from ..domain.analytics import ( from ...domain.analytics import (
AnalyticsAlert, AnalyticsAlert,
AnalyticsPeriod, AnalyticsPeriod,
DashboardConfig, DashboardConfig,

View File

@@ -1,608 +0,0 @@
"""
Enterprise API Gateway - Phase 6.1 Implementation
Multi-tenant API routing and management for enterprise clients
Port: 8010
"""
import secrets
import time
from datetime import datetime, timezone, timedelta
from enum import StrEnum
from typing import Any
from uuid import uuid4
import jwt
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer
from pydantic import BaseModel, Field
from aitbc import get_logger
logger = get_logger(__name__)
from ..domain.multitenant import Tenant, TenantApiKey, TenantQuota
from ..exceptions import QuotaExceededError, TenantError
from ..storage.db import get_db
# Pydantic models for API requests/responses
class EnterpriseAuthRequest(BaseModel):
tenant_id: str = Field(..., description="Enterprise tenant identifier")
client_id: str = Field(..., description="Enterprise client ID")
client_secret: str = Field(..., description="Enterprise client secret")
auth_method: str = Field(default="client_credentials", description="Authentication method")
scopes: list[str] | None = Field(default=None, description="Requested scopes")
class EnterpriseAuthResponse(BaseModel):
access_token: str = Field(..., description="Access token for enterprise API")
token_type: str = Field(default="Bearer", description="Token type")
expires_in: int = Field(..., description="Token expiration in seconds")
refresh_token: str | None = Field(None, description="Refresh token")
scopes: list[str] = Field(..., description="Granted scopes")
tenant_info: dict[str, Any] = Field(..., description="Tenant information")
class APIQuotaRequest(BaseModel):
tenant_id: str = Field(..., description="Enterprise tenant identifier")
endpoint: str = Field(..., description="API endpoint")
method: str = Field(..., description="HTTP method")
quota_type: str = Field(default="rate_limit", description="Quota type")
class APIQuotaResponse(BaseModel):
quota_limit: int = Field(..., description="Quota limit")
quota_remaining: int = Field(..., description="Remaining quota")
quota_reset: datetime = Field(..., description="Quota reset time")
quota_type: str = Field(..., description="Quota type")
class WebhookConfig(BaseModel):
url: str = Field(..., description="Webhook URL")
events: list[str] = Field(..., description="Events to subscribe to")
secret: str | None = Field(None, description="Webhook secret")
active: bool = Field(default=True, description="Webhook active status")
retry_policy: dict[str, Any] | None = Field(None, description="Retry policy")
class EnterpriseIntegrationRequest(BaseModel):
integration_type: str = Field(..., description="Integration type (ERP, CRM, etc.)")
provider: str = Field(..., description="Integration provider")
configuration: dict[str, Any] = Field(..., description="Integration configuration")
credentials: dict[str, str] | None = Field(None, description="Integration credentials")
webhook_config: WebhookConfig | None = Field(None, description="Webhook configuration")
class EnterpriseMetrics(BaseModel):
api_calls_total: int = Field(..., description="Total API calls")
api_calls_successful: int = Field(..., description="Successful API calls")
average_response_time_ms: float = Field(..., description="Average response time")
error_rate_percent: float = Field(..., description="Error rate percentage")
quota_utilization_percent: float = Field(..., description="Quota utilization")
active_integrations: int = Field(..., description="Active integrations count")
class IntegrationStatus(StrEnum):
ACTIVE = "active"
INACTIVE = "inactive"
ERROR = "error"
PENDING = "pending"
class EnterpriseIntegration:
"""Enterprise integration configuration and management"""
def __init__(
self, integration_id: str, tenant_id: str, integration_type: str, provider: str, configuration: dict[str, Any]
):
self.integration_id = integration_id
self.tenant_id = tenant_id
self.integration_type = integration_type
self.provider = provider
self.configuration = configuration
self.status = IntegrationStatus.PENDING
self.created_at = datetime.now(timezone.utc)
self.last_updated = datetime.now(timezone.utc)
self.webhook_config = None
self.metrics = {"api_calls": 0, "errors": 0, "last_call": None}
class EnterpriseAPIGateway:
"""Enterprise API Gateway with multi-tenant support"""
def __init__(self):
self.tenant_service = None # Will be initialized with database session
self.active_tokens = {} # In-memory token storage (in production, use Redis)
self.rate_limiters = {} # Per-tenant rate limiters
self.webhooks = {} # Webhook configurations
self.integrations = {} # Enterprise integrations
self.api_metrics = {} # API performance metrics
# Default quotas
self.default_quotas = {
"rate_limit": 1000, # requests per minute
"daily_limit": 50000, # requests per day
"concurrent_limit": 100, # concurrent requests
}
# JWT configuration
self.jwt_secret = secrets.token_urlsafe(64)
self.jwt_algorithm = "HS256"
self.token_expiry = 3600 # 1 hour
async def authenticate_enterprise_client(self, request: EnterpriseAuthRequest, db_session) -> EnterpriseAuthResponse:
"""Authenticate enterprise client and issue access token"""
try:
# Validate tenant and client credentials
tenant = await self._validate_tenant_credentials(
request.tenant_id, request.client_id, request.client_secret, db_session
)
# Generate access token
access_token = self._generate_access_token(
tenant_id=request.tenant_id, client_id=request.client_id, scopes=request.scopes or ["enterprise_api"]
)
# Generate refresh token
refresh_token = self._generate_refresh_token(request.tenant_id, request.client_id)
# Store token
self.active_tokens[access_token] = {
"tenant_id": request.tenant_id,
"client_id": request.client_id,
"scopes": request.scopes or ["enterprise_api"],
"expires_at": datetime.now(timezone.utc) + timedelta(seconds=self.token_expiry),
"refresh_token": refresh_token,
}
return EnterpriseAuthResponse(
access_token=access_token,
token_type="Bearer",
expires_in=self.token_expiry,
refresh_token=refresh_token,
scopes=request.scopes or ["enterprise_api"],
tenant_info={
"tenant_id": tenant.tenant_id,
"name": tenant.name,
"plan": tenant.plan,
"status": tenant.status.value,
"created_at": tenant.created_at.isoformat(),
},
)
except Exception as e:
logger.error(f"Enterprise authentication failed: {e}")
raise HTTPException(status_code=401, detail="Authentication failed")
def _generate_access_token(self, tenant_id: str, client_id: str, scopes: list[str]) -> str:
"""Generate JWT access token"""
payload = {
"sub": f"{tenant_id}:{client_id}",
"scopes": scopes,
"iat": datetime.now(timezone.utc),
"exp": datetime.now(timezone.utc) + timedelta(seconds=self.token_expiry),
"type": "access",
}
return jwt.encode(payload, self.jwt_secret, algorithm=self.jwt_algorithm)
def _generate_refresh_token(self, tenant_id: str, client_id: str) -> str:
"""Generate refresh token"""
payload = {
"sub": f"{tenant_id}:{client_id}",
"iat": datetime.now(timezone.utc),
"exp": datetime.now(timezone.utc) + timedelta(days=30), # 30 days
"type": "refresh",
}
return jwt.encode(payload, self.jwt_secret, algorithm=self.jwt_algorithm)
async def _validate_tenant_credentials(self, tenant_id: str, client_id: str, client_secret: str, db_session) -> Tenant:
"""Validate tenant credentials"""
# Find tenant
tenant = db_session.query(Tenant).filter(Tenant.tenant_id == tenant_id).first()
if not tenant:
raise TenantError(f"Tenant {tenant_id} not found")
# Find API key
api_key = (
db_session.query(TenantApiKey)
.filter(TenantApiKey.tenant_id == tenant_id, TenantApiKey.client_id == client_id, TenantApiKey.is_active)
.first()
)
if not api_key or not secrets.compare_digest(api_key.client_secret, client_secret):
raise TenantError("Invalid client credentials")
# Check tenant status
if tenant.status.value != "active":
raise TenantError(f"Tenant {tenant_id} is not active")
return tenant
async def check_api_quota(self, tenant_id: str, endpoint: str, method: str, db_session) -> APIQuotaResponse:
"""Check and enforce API quotas"""
try:
# Get tenant quota
quota = await self._get_tenant_quota(tenant_id, db_session)
# Check rate limiting
current_usage = await self._get_current_usage(tenant_id, "rate_limit")
if current_usage >= quota["rate_limit"]:
raise QuotaExceededError("Rate limit exceeded")
# Update usage
await self._update_usage(tenant_id, "rate_limit", current_usage + 1)
return APIQuotaResponse(
quota_limit=quota["rate_limit"],
quota_remaining=quota["rate_limit"] - current_usage - 1,
quota_reset=datetime.now(timezone.utc) + timedelta(minutes=1),
quota_type="rate_limit",
)
except QuotaExceededError:
raise
except Exception as e:
logger.error(f"Quota check failed: {e}")
raise HTTPException(status_code=500, detail="Quota check failed")
async def _get_tenant_quota(self, tenant_id: str, db_session) -> dict[str, int]:
"""Get tenant quota configuration"""
# Get tenant-specific quota
tenant_quota = db_session.query(TenantQuota).filter(TenantQuota.tenant_id == tenant_id).first()
if tenant_quota:
return {
"rate_limit": tenant_quota.rate_limit or self.default_quotas["rate_limit"],
"daily_limit": tenant_quota.daily_limit or self.default_quotas["daily_limit"],
"concurrent_limit": tenant_quota.concurrent_limit or self.default_quotas["concurrent_limit"],
}
return self.default_quotas
async def _get_current_usage(self, tenant_id: str, quota_type: str) -> int:
"""Get current quota usage"""
# In production, use Redis or database for persistent storage
if quota_type == "rate_limit":
# Get usage in the last minute
return len([t for t in self.rate_limiters.get(tenant_id, []) if datetime.now(timezone.utc) - t < timedelta(minutes=1)])
return 0
async def _update_usage(self, tenant_id: str, quota_type: str, usage: int):
"""Update quota usage"""
if quota_type == "rate_limit":
if tenant_id not in self.rate_limiters:
self.rate_limiters[tenant_id] = []
# Add current timestamp
self.rate_limiters[tenant_id].append(datetime.now(timezone.utc))
# Clean old entries (older than 1 minute)
cutoff = datetime.now(timezone.utc) - timedelta(minutes=1)
self.rate_limiters[tenant_id] = [t for t in self.rate_limiters[tenant_id] if t > cutoff]
async def create_enterprise_integration(
self, tenant_id: str, request: EnterpriseIntegrationRequest, db_session
) -> dict[str, Any]:
"""Create new enterprise integration"""
try:
# Validate tenant
tenant = db_session.query(Tenant).filter(Tenant.tenant_id == tenant_id).first()
if not tenant:
raise TenantError(f"Tenant {tenant_id} not found")
# Create integration
integration_id = str(uuid4())
integration = EnterpriseIntegration(
integration_id=integration_id,
tenant_id=tenant_id,
integration_type=request.integration_type,
provider=request.provider,
configuration=request.configuration,
)
# Store webhook configuration
if request.webhook_config:
integration.webhook_config = request.webhook_config.dict()
self.webhooks[integration_id] = request.webhook_config.dict()
# Store integration
self.integrations[integration_id] = integration
# Initialize integration
await self._initialize_integration(integration)
return {
"integration_id": integration_id,
"status": integration.status.value,
"created_at": integration.created_at.isoformat(),
"configuration": integration.configuration,
}
except Exception as e:
logger.error(f"Failed to create enterprise integration: {e}")
raise HTTPException(status_code=500, detail="Integration creation failed")
async def _initialize_integration(self, integration: EnterpriseIntegration):
"""Initialize enterprise integration"""
try:
# Integration-specific initialization logic
if integration.integration_type.lower() == "erp":
await self._initialize_erp_integration(integration)
elif integration.integration_type.lower() == "crm":
await self._initialize_crm_integration(integration)
elif integration.integration_type.lower() == "bi":
await self._initialize_bi_integration(integration)
integration.status = IntegrationStatus.ACTIVE
integration.last_updated = datetime.now(timezone.utc)
except Exception as e:
logger.error(f"Integration initialization failed: {e}")
integration.status = IntegrationStatus.ERROR
raise
async def _initialize_erp_integration(self, integration: EnterpriseIntegration):
"""Initialize ERP integration"""
# ERP-specific initialization
provider = integration.provider.lower()
if provider == "sap":
await self._initialize_sap_integration(integration)
elif provider == "oracle":
await self._initialize_oracle_integration(integration)
elif provider == "microsoft":
await self._initialize_microsoft_integration(integration)
logger.info(f"ERP integration initialized: {integration.provider}")
async def _initialize_sap_integration(self, integration: EnterpriseIntegration):
"""Initialize SAP ERP integration"""
# SAP integration logic
config = integration.configuration
# Validate SAP configuration
required_fields = ["system_id", "client", "username", "password", "host"]
for field in required_fields:
if field not in config:
raise ValueError(f"SAP integration requires {field}")
# Test SAP connection
# In production, implement actual SAP connection testing
logger.info(f"SAP connection test successful for {integration.integration_id}")
async def get_enterprise_metrics(self, tenant_id: str, db_session) -> EnterpriseMetrics:
"""Get enterprise metrics and analytics"""
try:
# Get API metrics
api_metrics = self.api_metrics.get(
tenant_id, {"total_calls": 0, "successful_calls": 0, "failed_calls": 0, "response_times": []}
)
# Calculate metrics
total_calls = api_metrics["total_calls"]
successful_calls = api_metrics["successful_calls"]
failed_calls = api_metrics["failed_calls"]
average_response_time = (
sum(api_metrics["response_times"]) / len(api_metrics["response_times"])
if api_metrics["response_times"]
else 0.0
)
error_rate = (failed_calls / total_calls * 100) if total_calls > 0 else 0.0
# Get quota utilization
current_usage = await self._get_current_usage(tenant_id, "rate_limit")
quota = await self._get_tenant_quota(tenant_id, db_session)
quota_utilization = (current_usage / quota["rate_limit"] * 100) if quota["rate_limit"] > 0 else 0.0
# Count active integrations
active_integrations = len(
[i for i in self.integrations.values() if i.tenant_id == tenant_id and i.status == IntegrationStatus.ACTIVE]
)
return EnterpriseMetrics(
api_calls_total=total_calls,
api_calls_successful=successful_calls,
average_response_time_ms=average_response_time,
error_rate_percent=error_rate,
quota_utilization_percent=quota_utilization,
active_integrations=active_integrations,
)
except Exception as e:
logger.error(f"Failed to get enterprise metrics: {e}")
raise HTTPException(status_code=500, detail="Metrics retrieval failed")
async def record_api_call(self, tenant_id: str, endpoint: str, response_time: float, success: bool):
"""Record API call for metrics"""
if tenant_id not in self.api_metrics:
self.api_metrics[tenant_id] = {"total_calls": 0, "successful_calls": 0, "failed_calls": 0, "response_times": []}
metrics = self.api_metrics[tenant_id]
metrics["total_calls"] += 1
if success:
metrics["successful_calls"] += 1
else:
metrics["failed_calls"] += 1
metrics["response_times"].append(response_time)
# Keep only last 1000 response times
if len(metrics["response_times"]) > 1000:
metrics["response_times"] = metrics["response_times"][-1000:]
# FastAPI application
app = FastAPI(
title="Enterprise API Gateway",
description="Multi-tenant API routing and management for enterprise clients",
version="6.1.0",
docs_url="/docs",
redoc_url="/redoc",
)
# CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Security
security = HTTPBearer()
# Global gateway instance
gateway = EnterpriseAPIGateway()
# Dependency for database session
async def get_db_session():
"""Get database session"""
async with get_db() as session:
yield session
# Middleware for API metrics
@app.middleware("http")
async def api_metrics_middleware(request: Request, call_next):
"""Middleware to record API metrics"""
start_time = time.time()
# Extract tenant from token if available
tenant_id = None
authorization = request.headers.get("authorization")
if authorization and authorization.startswith("Bearer "):
token = authorization[7:]
token_data = gateway.active_tokens.get(token)
if token_data:
tenant_id = token_data["tenant_id"]
# Process request
response = await call_next(request)
# Record metrics
response_time = (time.time() - start_time) * 1000 # Convert to milliseconds
success = response.status_code < 400
if tenant_id:
await gateway.record_api_call(tenant_id, str(request.url.path), response_time, success)
return response
@app.post("/enterprise/auth")
async def enterprise_auth(request: EnterpriseAuthRequest, db_session=Depends(get_db_session)):
"""Authenticate enterprise client"""
result = await gateway.authenticate_enterprise_client(request, db_session)
return result
@app.post("/enterprise/quota/check")
async def check_quota(request: APIQuotaRequest, db_session=Depends(get_db_session)):
"""Check API quota"""
result = await gateway.check_api_quota(request.tenant_id, request.endpoint, request.method, db_session)
return result
@app.post("/enterprise/integrations")
async def create_integration(request: EnterpriseIntegrationRequest, db_session=Depends(get_db_session)):
"""Create enterprise integration"""
# Extract tenant from token (in production, proper authentication)
tenant_id = "demo_tenant" # Placeholder
result = await gateway.create_enterprise_integration(tenant_id, request, db_session)
return result
@app.get("/enterprise/analytics")
async def get_analytics(db_session=Depends(get_db_session)):
"""Get enterprise analytics dashboard"""
# Extract tenant from token (in production, proper authentication)
tenant_id = "demo_tenant" # Placeholder
result = await gateway.get_enterprise_metrics(tenant_id, db_session)
return result
@app.get("/enterprise/status")
async def get_status():
"""Get enterprise gateway status"""
return {
"service": "Enterprise API Gateway",
"version": "6.1.0",
"port": 8010,
"status": "operational",
"active_tenants": len({token["tenant_id"] for token in gateway.active_tokens.values()}),
"active_integrations": len(gateway.integrations),
"timestamp": datetime.now(timezone.utc).isoformat(),
}
@app.get("/")
async def root():
"""Root endpoint"""
return {
"service": "Enterprise API Gateway",
"version": "6.1.0",
"port": 8010,
"capabilities": [
"Multi-tenant API Management",
"Enterprise Authentication",
"API Quota Management",
"Enterprise Integration Framework",
"Real-time Analytics",
],
"status": "operational",
}
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {
"status": "healthy",
"timestamp": datetime.now(timezone.utc).isoformat(),
"services": {
"api_gateway": "operational",
"authentication": "operational",
"quota_management": "operational",
"integration_framework": "operational",
},
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8010)

File diff suppressed because it is too large Load Diff

View File

@@ -1,16 +1,14 @@
""" """
Enterprise Integration Bounded Context Enterprise Integration Bounded Context
Provides enterprise API gateway, security, load balancing, and integration services. Provides enterprise integration, security, and load balancing services.
""" """
from .api_gateway import EnterpriseAPIGateway from .integration import EnterpriseIntegrationFramework
from .integration import EnterpriseIntegrationService
from .load_balancer import AdvancedLoadBalancer from .load_balancer import AdvancedLoadBalancer
from .security import EnterpriseEncryption, HSMManager, ThreatDetectionSystem, ZeroTrustArchitecture from .security import EnterpriseEncryption, HSMManager, ThreatDetectionSystem, ZeroTrustArchitecture
__all__ = [ __all__ = [
"EnterpriseAPIGateway", "EnterpriseIntegrationFramework",
"EnterpriseIntegrationService",
"AdvancedLoadBalancer", "AdvancedLoadBalancer",
"EnterpriseEncryption", "EnterpriseEncryption",
"HSMManager", "HSMManager",

View File

@@ -1,770 +0,0 @@
"""
Advanced Load Balancing - Phase 6.4 Implementation
Intelligent traffic distribution with AI-powered auto-scaling and performance optimization
"""
import statistics
from dataclasses import dataclass, field
from datetime import datetime, timezone, timedelta
from enum import StrEnum
from typing import Any
from aitbc import get_logger
logger = get_logger(__name__)
class LoadBalancingAlgorithm(StrEnum):
"""Load balancing algorithms"""
ROUND_ROBIN = "round_robin"
WEIGHTED_ROUND_ROBIN = "weighted_round_robin"
LEAST_CONNECTIONS = "least_connections"
LEAST_RESPONSE_TIME = "least_response_time"
RESOURCE_BASED = "resource_based"
PREDICTIVE_AI = "predictive_ai"
ADAPTIVE = "adaptive"
class ScalingPolicy(StrEnum):
"""Auto-scaling policies"""
MANUAL = "manual"
THRESHOLD_BASED = "threshold_based"
PREDICTIVE = "predictive"
HYBRID = "hybrid"
class HealthStatus(StrEnum):
"""Health status"""
HEALTHY = "healthy"
UNHEALTHY = "unhealthy"
DRAINING = "draining"
MAINTENANCE = "maintenance"
@dataclass
class BackendServer:
"""Backend server configuration"""
server_id: str
host: str
port: int
weight: float = 1.0
max_connections: int = 1000
current_connections: int = 0
cpu_usage: float = 0.0
memory_usage: float = 0.0
response_time_ms: float = 0.0
request_count: int = 0
error_count: int = 0
health_status: HealthStatus = HealthStatus.HEALTHY
last_health_check: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
capabilities: dict[str, Any] = field(default_factory=dict)
region: str = "default"
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
@dataclass
class ScalingMetric:
"""Scaling metric configuration"""
metric_name: str
threshold_min: float
threshold_max: float
scaling_factor: float
cooldown_period: timedelta
measurement_window: timedelta
@dataclass
class TrafficPattern:
"""Traffic pattern for predictive scaling"""
pattern_id: str
name: str
time_windows: list[dict[str, Any]] # List of time windows with expected load
day_of_week: int # 0-6 (Monday-Sunday)
seasonal_factor: float = 1.0
confidence_score: float = 0.0
class PredictiveScaler:
"""AI-powered predictive auto-scaling"""
def __init__(self):
self.traffic_history = []
self.scaling_predictions = {}
self.traffic_patterns = {}
self.model_weights = {}
self.logger = get_logger("predictive_scaler")
async def record_traffic(self, timestamp: datetime, request_count: int, response_time_ms: float, error_rate: float):
"""Record traffic metrics"""
traffic_record = {
"timestamp": timestamp,
"request_count": request_count,
"response_time_ms": response_time_ms,
"error_rate": error_rate,
"hour": timestamp.hour,
"day_of_week": timestamp.weekday(),
"day_of_month": timestamp.day,
"month": timestamp.month,
}
self.traffic_history.append(traffic_record)
# Keep only last 30 days of history
cutoff = datetime.now(timezone.utc) - timedelta(days=30)
self.traffic_history = [record for record in self.traffic_history if record["timestamp"] > cutoff]
# Update traffic patterns
await self._update_traffic_patterns()
async def _update_traffic_patterns(self):
"""Update traffic patterns based on historical data"""
if len(self.traffic_history) < 168: # Need at least 1 week of data
return
# Group by hour and day of week
patterns = {}
for record in self.traffic_history:
key = f"{record['day_of_week']}_{record['hour']}"
if key not in patterns:
patterns[key] = {"request_counts": [], "response_times": [], "error_rates": []}
patterns[key]["request_counts"].append(record["request_count"])
patterns[key]["response_times"].append(record["response_time_ms"])
patterns[key]["error_rates"].append(record["error_rate"])
# Calculate pattern statistics
for key, data in patterns.items():
day_of_week, hour = key.split("_")
pattern = TrafficPattern(
pattern_id=key,
name=f"Pattern Day {day_of_week} Hour {hour}",
time_windows=[
{
"hour": int(hour),
"avg_requests": statistics.mean(data["request_counts"]),
"max_requests": max(data["request_counts"]),
"min_requests": min(data["request_counts"]),
"std_requests": statistics.stdev(data["request_counts"]) if len(data["request_counts"]) > 1 else 0,
"avg_response_time": statistics.mean(data["response_times"]),
"avg_error_rate": statistics.mean(data["error_rates"]),
}
],
day_of_week=int(day_of_week),
confidence_score=min(len(data["request_counts"]) / 100, 1.0), # Confidence based on data points
)
self.traffic_patterns[key] = pattern
async def predict_traffic(self, prediction_window: timedelta = timedelta(hours=1)) -> dict[str, Any]:
"""Predict traffic for the next time window"""
try:
current_time = datetime.now(timezone.utc)
current_time + prediction_window
# Get current pattern
current_pattern_key = f"{current_time.weekday()}_{current_time.hour}"
current_pattern = self.traffic_patterns.get(current_pattern_key)
if not current_pattern:
# Fallback to simple prediction
return await self._simple_prediction(prediction_window)
# Get historical data for similar time periods
similar_patterns = [
pattern
for pattern in self.traffic_patterns.values()
if pattern.day_of_week == current_time.weekday()
and abs(pattern.time_windows[0]["hour"] - current_time.hour) <= 2
]
if not similar_patterns:
return await self._simple_prediction(prediction_window)
# Calculate weighted prediction
total_weight = 0
weighted_requests = 0
weighted_response_time = 0
weighted_error_rate = 0
for pattern in similar_patterns:
weight = pattern.confidence_score
window_data = pattern.time_windows[0]
weighted_requests += window_data["avg_requests"] * weight
weighted_response_time += window_data["avg_response_time"] * weight
weighted_error_rate += window_data["avg_error_rate"] * weight
total_weight += weight
if total_weight > 0:
predicted_requests = weighted_requests / total_weight
predicted_response_time = weighted_response_time / total_weight
predicted_error_rate = weighted_error_rate / total_weight
else:
return await self._simple_prediction(prediction_window)
# Apply seasonal factors
seasonal_factor = self._get_seasonal_factor(current_time)
predicted_requests *= seasonal_factor
return {
"prediction_window_hours": prediction_window.total_seconds() / 3600,
"predicted_requests_per_hour": int(predicted_requests),
"predicted_response_time_ms": predicted_response_time,
"predicted_error_rate": predicted_error_rate,
"confidence_score": min(total_weight / len(similar_patterns), 1.0),
"seasonal_factor": seasonal_factor,
"pattern_based": True,
"prediction_timestamp": current_time.isoformat(),
}
except Exception as e:
self.logger.error(f"Traffic prediction failed: {e}")
return await self._simple_prediction(prediction_window)
async def _simple_prediction(self, prediction_window: timedelta) -> dict[str, Any]:
"""Simple prediction based on recent averages"""
if not self.traffic_history:
return {
"prediction_window_hours": prediction_window.total_seconds() / 3600,
"predicted_requests_per_hour": 1000, # Default
"predicted_response_time_ms": 100.0,
"predicted_error_rate": 0.01,
"confidence_score": 0.1,
"pattern_based": False,
"prediction_timestamp": datetime.now(timezone.utc).isoformat(),
}
# Calculate recent averages
recent_records = self.traffic_history[-24:] # Last 24 records
avg_requests = statistics.mean([r["request_count"] for r in recent_records])
avg_response_time = statistics.mean([r["response_time_ms"] for r in recent_records])
avg_error_rate = statistics.mean([r["error_rate"] for r in recent_records])
return {
"prediction_window_hours": prediction_window.total_seconds() / 3600,
"predicted_requests_per_hour": int(avg_requests),
"predicted_response_time_ms": avg_response_time,
"predicted_error_rate": avg_error_rate,
"confidence_score": 0.3,
"pattern_based": False,
"prediction_timestamp": datetime.now(timezone.utc).isoformat(),
}
def _get_seasonal_factor(self, timestamp: datetime) -> float:
"""Get seasonal adjustment factor"""
# Simple seasonal factors (can be enhanced with more sophisticated models)
month = timestamp.month
seasonal_factors = {
1: 0.8, # January - post-holiday dip
2: 0.9, # February
3: 1.0, # March
4: 1.1, # April - spring increase
5: 1.2, # May
6: 1.1, # June
7: 1.0, # July - summer
8: 0.9, # August
9: 1.1, # September - back to business
10: 1.2, # October
11: 1.3, # November - holiday season start
12: 1.4, # December - peak holiday season
}
return seasonal_factors.get(month, 1.0)
async def get_scaling_recommendation(self, current_servers: int, current_capacity: int) -> dict[str, Any]:
"""Get scaling recommendation based on predictions"""
try:
# Get traffic prediction
prediction = await self.predict_traffic(timedelta(hours=1))
predicted_requests = prediction["predicted_requests_per_hour"]
current_capacity_per_server = current_capacity // max(current_servers, 1)
# Calculate required servers
required_servers = max(1, int(predicted_requests / current_capacity_per_server))
# Apply buffer (20% extra capacity)
required_servers = int(required_servers * 1.2)
scaling_action = "none"
if required_servers > current_servers:
scaling_action = "scale_up"
scale_to = required_servers
elif required_servers < current_servers * 0.7: # Scale down if underutilized
scaling_action = "scale_down"
scale_to = max(1, required_servers)
else:
scale_to = current_servers
return {
"current_servers": current_servers,
"recommended_servers": scale_to,
"scaling_action": scaling_action,
"predicted_load": predicted_requests,
"current_capacity_per_server": current_capacity_per_server,
"confidence_score": prediction["confidence_score"],
"reason": f"Predicted {predicted_requests} requests/hour vs current capacity {current_servers * current_capacity_per_server}",
"recommendation_timestamp": datetime.now(timezone.utc).isoformat(),
}
except Exception as e:
self.logger.error(f"Scaling recommendation failed: {e}")
return {
"scaling_action": "none",
"reason": f"Prediction failed: {str(e)}",
"recommendation_timestamp": datetime.now(timezone.utc).isoformat(),
}
class AdvancedLoadBalancer:
"""Advanced load balancer with multiple algorithms and AI optimization"""
def __init__(self):
self.backends = {}
self.algorithm = LoadBalancingAlgorithm.ADAPTIVE
self.current_index = 0
self.request_history = []
self.performance_metrics = {}
self.predictive_scaler = PredictiveScaler()
self.scaling_metrics = {}
self.logger = get_logger("advanced_load_balancer")
async def add_backend(self, server: BackendServer) -> bool:
"""Add backend server"""
try:
self.backends[server.server_id] = server
# Initialize performance metrics
self.performance_metrics[server.server_id] = {
"avg_response_time": 0.0,
"error_rate": 0.0,
"throughput": 0.0,
"uptime": 1.0,
"last_updated": datetime.now(timezone.utc),
}
self.logger.info(f"Backend server added: {server.server_id}")
return True
except Exception as e:
self.logger.error(f"Failed to add backend server: {e}")
return False
async def remove_backend(self, server_id: str) -> bool:
"""Remove backend server"""
if server_id in self.backends:
del self.backends[server_id]
del self.performance_metrics[server_id]
self.logger.info(f"Backend server removed: {server_id}")
return True
return False
async def select_backend(self, request_context: dict[str, Any] | None = None) -> str | None:
"""Select backend server based on algorithm"""
try:
# Filter healthy backends
healthy_backends = {
sid: server for sid, server in self.backends.items() if server.health_status == HealthStatus.HEALTHY
}
if not healthy_backends:
return None
# Select backend based on algorithm
if self.algorithm == LoadBalancingAlgorithm.ROUND_ROBIN:
return await self._select_round_robin(healthy_backends)
elif self.algorithm == LoadBalancingAlgorithm.WEIGHTED_ROUND_ROBIN:
return await self._select_weighted_round_robin(healthy_backends)
elif self.algorithm == LoadBalancingAlgorithm.LEAST_CONNECTIONS:
return await self._select_least_connections(healthy_backends)
elif self.algorithm == LoadBalancingAlgorithm.LEAST_RESPONSE_TIME:
return await self._select_least_response_time(healthy_backends)
elif self.algorithm == LoadBalancingAlgorithm.RESOURCE_BASED:
return await self._select_resource_based(healthy_backends)
elif self.algorithm == LoadBalancingAlgorithm.PREDICTIVE_AI:
return await self._select_predictive_ai(healthy_backends, request_context)
elif self.algorithm == LoadBalancingAlgorithm.ADAPTIVE:
return await self._select_adaptive(healthy_backends, request_context)
else:
return await self._select_round_robin(healthy_backends)
except Exception as e:
self.logger.error(f"Backend selection failed: {e}")
return None
async def _select_round_robin(self, backends: dict[str, BackendServer]) -> str:
"""Round robin selection"""
backend_ids = list(backends.keys())
if not backend_ids:
return None
selected = backend_ids[self.current_index % len(backend_ids)]
self.current_index += 1
return selected
async def _select_weighted_round_robin(self, backends: dict[str, BackendServer]) -> str:
"""Weighted round robin selection"""
# Calculate total weight
total_weight = sum(server.weight for server in backends.values())
if total_weight <= 0:
return await self._select_round_robin(backends)
# Select based on weights
import random
rand_value = random.uniform(0, total_weight)
current_weight = 0
for server_id, server in backends.items():
current_weight += server.weight
if rand_value <= current_weight:
return server_id
# Fallback
return list(backends.keys())[0]
async def _select_least_connections(self, backends: dict[str, BackendServer]) -> str:
"""Select backend with least connections"""
min_connections = float("inf")
selected_backend = None
for server_id, server in backends.items():
if server.current_connections < min_connections:
min_connections = server.current_connections
selected_backend = server_id
return selected_backend
async def _select_least_response_time(self, backends: dict[str, BackendServer]) -> str:
"""Select backend with least response time"""
min_response_time = float("inf")
selected_backend = None
for server_id, server in backends.items():
if server.response_time_ms < min_response_time:
min_response_time = server.response_time_ms
selected_backend = server_id
return selected_backend
async def _select_resource_based(self, backends: dict[str, BackendServer]) -> str:
"""Select backend based on resource utilization"""
best_score = -1
selected_backend = None
for server_id, server in backends.items():
# Calculate resource score (lower is better)
cpu_score = 1.0 - (server.cpu_usage / 100.0)
memory_score = 1.0 - (server.memory_usage / 100.0)
connection_score = 1.0 - (server.current_connections / server.max_connections)
# Weighted score
resource_score = cpu_score * 0.4 + memory_score * 0.3 + connection_score * 0.3
if resource_score > best_score:
best_score = resource_score
selected_backend = server_id
return selected_backend
async def _select_predictive_ai(
self, backends: dict[str, BackendServer], request_context: dict[str, Any] | None
) -> str:
"""AI-powered predictive selection"""
# Get performance predictions for each backend
backend_scores = {}
for server_id, server in backends.items():
# Predict performance based on historical data
self.performance_metrics.get(server_id, {})
# Calculate predicted response time
predicted_response_time = (
server.response_time_ms
* (1 + server.cpu_usage / 100)
* (1 + server.memory_usage / 100)
* (1 + server.current_connections / server.max_connections)
)
# Calculate score (lower response time is better)
score = 1.0 / (1.0 + predicted_response_time / 100.0)
# Apply context-based adjustments
if request_context:
# Consider request type, user location, etc.
context_multiplier = await self._calculate_context_multiplier(server, request_context)
score *= context_multiplier
backend_scores[server_id] = score
# Select best scoring backend
if backend_scores:
return max(backend_scores, key=backend_scores.get)
return await self._select_least_connections(backends)
async def _select_adaptive(self, backends: dict[str, BackendServer], request_context: dict[str, Any] | None) -> str:
"""Adaptive selection based on current conditions"""
# Analyze current system state
total_connections = sum(server.current_connections for server in backends.values())
avg_response_time = statistics.mean([server.response_time_ms for server in backends.values()])
# Choose algorithm based on conditions
if total_connections > sum(server.max_connections for server in backends.values()) * 0.8:
# High load - use resource-based
return await self._select_resource_based(backends)
elif avg_response_time > 200:
# High latency - use least response time
return await self._select_least_response_time(backends)
else:
# Normal conditions - use weighted round robin
return await self._select_weighted_round_robin(backends)
async def _calculate_context_multiplier(self, server: BackendServer, request_context: dict[str, Any]) -> float:
"""Calculate context-based multiplier for backend selection"""
multiplier = 1.0
# Consider geographic location
if "user_location" in request_context and "region" in server.capabilities:
user_region = request_context["user_location"].get("region")
server_region = server.capabilities["region"]
if user_region == server_region:
multiplier *= 1.2 # Prefer same region
elif self._regions_in_same_continent(user_region, server_region):
multiplier *= 1.1 # Slight preference for same continent
# Consider request type
request_type = request_context.get("request_type", "general")
server_specializations = server.capabilities.get("specializations", [])
if request_type in server_specializations:
multiplier *= 1.3 # Strong preference for specialized backends
# Consider user tier
user_tier = request_context.get("user_tier", "standard")
if user_tier == "premium" and server.capabilities.get("premium_support", False):
multiplier *= 1.15
return multiplier
def _regions_in_same_continent(self, region1: str, region2: str) -> bool:
"""Check if two regions are in the same continent"""
continent_mapping = {
"NA": ["US", "CA", "MX"],
"EU": ["GB", "DE", "FR", "IT", "ES", "NL", "BE", "AT", "CH", "SE", "NO", "DK", "FI"],
"APAC": ["JP", "KR", "SG", "AU", "IN", "TH", "MY", "ID", "PH", "VN"],
"LATAM": ["BR", "MX", "AR", "CL", "CO", "PE", "VE"],
}
for _continent, regions in continent_mapping.items():
if region1 in regions and region2 in regions:
return True
return False
async def record_request(
self, server_id: str, response_time_ms: float, success: bool, timestamp: datetime | None = None
):
"""Record request metrics"""
if timestamp is None:
timestamp = datetime.now(timezone.utc)
# Update backend server metrics
if server_id in self.backends:
server = self.backends[server_id]
server.request_count += 1
server.response_time_ms = server.response_time_ms * 0.9 + response_time_ms * 0.1 # EMA
if not success:
server.error_count += 1
# Record in history
request_record = {
"timestamp": timestamp,
"server_id": server_id,
"response_time_ms": response_time_ms,
"success": success,
}
self.request_history.append(request_record)
# Keep only last 10000 records
if len(self.request_history) > 10000:
self.request_history = self.request_history[-10000:]
# Update predictive scaler
await self.predictive_scaler.record_traffic(
timestamp, 1, response_time_ms, 0.0 if success else 1.0 # One request # Error rate
)
async def update_backend_health(
self, server_id: str, health_status: HealthStatus, cpu_usage: float, memory_usage: float, current_connections: int
):
"""Update backend health metrics"""
if server_id in self.backends:
server = self.backends[server_id]
server.health_status = health_status
server.cpu_usage = cpu_usage
server.memory_usage = memory_usage
server.current_connections = current_connections
server.last_health_check = datetime.now(timezone.utc)
async def get_load_balancing_metrics(self) -> dict[str, Any]:
"""Get comprehensive load balancing metrics"""
try:
total_requests = sum(server.request_count for server in self.backends.values())
total_errors = sum(server.error_count for server in self.backends.values())
total_connections = sum(server.current_connections for server in self.backends.values())
error_rate = (total_errors / total_requests) if total_requests > 0 else 0.0
# Calculate average response time
avg_response_time = 0.0
if self.backends:
avg_response_time = statistics.mean([server.response_time_ms for server in self.backends.values()])
# Backend distribution
backend_distribution = {}
for server_id, server in self.backends.items():
backend_distribution[server_id] = {
"requests": server.request_count,
"errors": server.error_count,
"connections": server.current_connections,
"response_time_ms": server.response_time_ms,
"cpu_usage": server.cpu_usage,
"memory_usage": server.memory_usage,
"health_status": server.health_status.value,
"weight": server.weight,
}
# Get scaling recommendation
scaling_recommendation = await self.predictive_scaler.get_scaling_recommendation(
len(self.backends), sum(server.max_connections for server in self.backends.values())
)
return {
"total_backends": len(self.backends),
"healthy_backends": len([s for s in self.backends.values() if s.health_status == HealthStatus.HEALTHY]),
"total_requests": total_requests,
"total_errors": total_errors,
"error_rate": error_rate,
"average_response_time_ms": avg_response_time,
"total_connections": total_connections,
"algorithm": self.algorithm.value,
"backend_distribution": backend_distribution,
"scaling_recommendation": scaling_recommendation,
"timestamp": datetime.now(timezone.utc).isoformat(),
}
except Exception as e:
self.logger.error(f"Metrics retrieval failed: {e}")
return {"error": str(e)}
async def set_algorithm(self, algorithm: LoadBalancingAlgorithm):
"""Set load balancing algorithm"""
self.algorithm = algorithm
self.logger.info(f"Load balancing algorithm changed to: {algorithm.value}")
async def auto_scale(self, min_servers: int = 1, max_servers: int = 10) -> dict[str, Any]:
"""Perform auto-scaling based on predictions"""
try:
# Get scaling recommendation
recommendation = await self.predictive_scaler.get_scaling_recommendation(
len(self.backends), sum(server.max_connections for server in self.backends.values())
)
action = recommendation["scaling_action"]
target_servers = recommendation["recommended_servers"]
# Apply scaling limits
target_servers = max(min_servers, min(max_servers, target_servers))
scaling_result = {
"action": action,
"current_servers": len(self.backends),
"target_servers": target_servers,
"confidence": recommendation.get("confidence_score", 0.0),
"reason": recommendation.get("reason", ""),
"timestamp": datetime.now(timezone.utc).isoformat(),
}
# In production, implement actual scaling logic here
# For now, just return the recommendation
self.logger.info(f"Auto-scaling recommendation: {action} to {target_servers} servers")
return scaling_result
except Exception as e:
self.logger.error(f"Auto-scaling failed: {e}")
return {"error": str(e)}
# Global load balancer instance
advanced_load_balancer = None
async def get_advanced_load_balancer() -> AdvancedLoadBalancer:
"""Get or create global advanced load balancer"""
global advanced_load_balancer
if advanced_load_balancer is None:
advanced_load_balancer = AdvancedLoadBalancer()
# Add default backends
default_backends = [
BackendServer(
server_id="backend_1", host="10.0.1.10", port=8080, weight=1.0, max_connections=1000, region="us_east"
),
BackendServer(
server_id="backend_2", host="10.0.1.11", port=8080, weight=1.0, max_connections=1000, region="us_east"
),
BackendServer(
server_id="backend_3", host="10.0.1.12", port=8080, weight=0.8, max_connections=800, region="eu_west"
),
]
for backend in default_backends:
await advanced_load_balancer.add_backend(backend)
return advanced_load_balancer

View File

@@ -1,773 +0,0 @@
"""
Enterprise Security Framework - Phase 6.2 Implementation
Zero-trust architecture with HSM integration and advanced security controls
"""
import secrets
from dataclasses import dataclass, field
from datetime import datetime, timezone, timedelta
from enum import StrEnum
from typing import Any
from uuid import uuid4
import cryptography
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from aitbc import get_logger
logger = get_logger(__name__)
class SecurityLevel(StrEnum):
"""Security levels for enterprise data"""
PUBLIC = "public"
INTERNAL = "internal"
CONFIDENTIAL = "confidential"
RESTRICTED = "restricted"
TOP_SECRET = "top_secret"
class EncryptionAlgorithm(StrEnum):
"""Encryption algorithms"""
AES_256_GCM = "aes_256_gcm"
CHACHA20_POLY1305 = "chacha20_polyy1305"
AES_256_CBC = "aes_256_cbc"
QUANTUM_RESISTANT = "quantum_resistant"
class ThreatLevel(StrEnum):
"""Threat levels for security monitoring"""
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
CRITICAL = "critical"
@dataclass
class SecurityPolicy:
"""Security policy configuration"""
policy_id: str
name: str
security_level: SecurityLevel
encryption_algorithm: EncryptionAlgorithm
key_rotation_interval: timedelta
access_control_requirements: list[str]
audit_requirements: list[str]
retention_period: timedelta
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
@dataclass
class SecurityEvent:
"""Security event for monitoring"""
event_id: str
event_type: str
severity: ThreatLevel
source: str
timestamp: datetime
user_id: str | None
resource_id: str | None
details: dict[str, Any]
resolved: bool = False
resolution_notes: str | None = None
class HSMManager:
"""Hardware Security Module manager for enterprise key management"""
def __init__(self, hsm_config: dict[str, Any]):
self.hsm_config = hsm_config
self.backend = default_backend()
self.key_store = {} # In production, use actual HSM
self.logger = get_logger("hsm_manager")
async def initialize(self) -> bool:
"""Initialize HSM connection"""
try:
# In production, initialize actual HSM connection
# For now, simulate HSM initialization
self.logger.info("HSM manager initialized")
return True
except Exception as e:
self.logger.error(f"HSM initialization failed: {e}")
return False
async def generate_key(self, key_id: str, algorithm: EncryptionAlgorithm, key_size: int = 256) -> dict[str, Any]:
"""Generate encryption key in HSM"""
try:
if algorithm == EncryptionAlgorithm.AES_256_GCM:
key = secrets.token_bytes(32) # 256 bits
iv = secrets.token_bytes(12) # 96 bits for GCM
elif algorithm == EncryptionAlgorithm.CHACHA20_POLY1305:
key = secrets.token_bytes(32) # 256 bits
nonce = secrets.token_bytes(12) # 96 bits
elif algorithm == EncryptionAlgorithm.AES_256_CBC:
key = secrets.token_bytes(32) # 256 bits
iv = secrets.token_bytes(16) # 128 bits for CBC
else:
raise ValueError(f"Unsupported algorithm: {algorithm}")
# Store key in HSM (simulated)
key_data = {
"key_id": key_id,
"algorithm": algorithm.value,
"key": key,
"iv": iv if algorithm in [EncryptionAlgorithm.AES_256_GCM, EncryptionAlgorithm.AES_256_CBC] else None,
"nonce": nonce if algorithm == EncryptionAlgorithm.CHACHA20_POLY1305 else None,
"created_at": datetime.now(timezone.utc),
"key_size": key_size,
}
self.key_store[key_id] = key_data
self.logger.info(f"Key generated in HSM: {key_id}")
return key_data
except Exception as e:
self.logger.error(f"Key generation failed: {e}")
raise
async def get_key(self, key_id: str) -> dict[str, Any] | None:
"""Get key from HSM"""
return self.key_store.get(key_id)
async def rotate_key(self, key_id: str) -> dict[str, Any]:
"""Rotate encryption key"""
old_key = self.key_store.get(key_id)
if not old_key:
raise ValueError(f"Key not found: {key_id}")
# Generate new key
new_key = await self.generate_key(f"{key_id}_new", EncryptionAlgorithm(old_key["algorithm"]), old_key["key_size"])
# Update key with rotation timestamp
new_key["rotated_from"] = key_id
new_key["rotation_timestamp"] = datetime.now(timezone.utc)
return new_key
async def delete_key(self, key_id: str) -> bool:
"""Delete key from HSM"""
if key_id in self.key_store:
del self.key_store[key_id]
self.logger.info(f"Key deleted from HSM: {key_id}")
return True
return False
class EnterpriseEncryption:
"""Enterprise-grade encryption service"""
def __init__(self, hsm_manager: HSMManager):
self.hsm_manager = hsm_manager
self.backend = default_backend()
self.logger = get_logger("enterprise_encryption")
async def encrypt_data(
self, data: str | bytes, key_id: str, associated_data: bytes | None = None
) -> dict[str, Any]:
"""Encrypt data using enterprise-grade encryption"""
try:
# Get key from HSM
key_data = await self.hsm_manager.get_key(key_id)
if not key_data:
raise ValueError(f"Key not found: {key_id}")
# Convert data to bytes if needed
if isinstance(data, str):
data = data.encode("utf-8")
algorithm = EncryptionAlgorithm(key_data["algorithm"])
if algorithm == EncryptionAlgorithm.AES_256_GCM:
return await self._encrypt_aes_gcm(data, key_data, associated_data)
elif algorithm == EncryptionAlgorithm.CHACHA20_POLY1305:
return await self._encrypt_chacha20(data, key_data, associated_data)
elif algorithm == EncryptionAlgorithm.AES_256_CBC:
return await self._encrypt_aes_cbc(data, key_data)
else:
raise ValueError(f"Unsupported encryption algorithm: {algorithm}")
except Exception as e:
self.logger.error(f"Encryption failed: {e}")
raise
async def _encrypt_aes_gcm(
self, data: bytes, key_data: dict[str, Any], associated_data: bytes | None = None
) -> dict[str, Any]:
"""Encrypt using AES-256-GCM"""
key = key_data["key"]
iv = key_data["iv"]
# Create cipher
cipher = Cipher(algorithms.AES(key), modes.GCM(iv), backend=self.backend)
encryptor = cipher.encryptor()
# Add associated data if provided
if associated_data:
encryptor.authenticate_additional_data(associated_data)
# Encrypt data
ciphertext = encryptor.update(data) + encryptor.finalize()
return {
"ciphertext": ciphertext.hex(),
"iv": iv.hex(),
"tag": encryptor.tag.hex(),
"algorithm": "aes_256_gcm",
"key_id": key_data["key_id"],
}
async def _encrypt_chacha20(
self, data: bytes, key_data: dict[str, Any], associated_data: bytes | None = None
) -> dict[str, Any]:
"""Encrypt using ChaCha20-Poly1305"""
key = key_data["key"]
nonce = key_data["nonce"]
# Create cipher
cipher = Cipher(algorithms.ChaCha20(key, nonce), modes.Poly1305(b""), backend=self.backend)
encryptor = cipher.encryptor()
# Add associated data if provided
if associated_data:
encryptor.authenticate_additional_data(associated_data)
# Encrypt data
ciphertext = encryptor.update(data) + encryptor.finalize()
return {
"ciphertext": ciphertext.hex(),
"nonce": nonce.hex(),
"tag": encryptor.tag.hex(),
"algorithm": "chacha20_poly1305",
"key_id": key_data["key_id"],
}
async def _encrypt_aes_cbc(self, data: bytes, key_data: dict[str, Any]) -> dict[str, Any]:
"""Encrypt using AES-256-CBC"""
key = key_data["key"]
iv = key_data["iv"]
# Pad data to block size
padder = cryptography.hazmat.primitives.padding.PKCS7(128).padder()
padded_data = padder.update(data) + padder.finalize()
# Create cipher
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=self.backend)
encryptor = cipher.encryptor()
ciphertext = encryptor.update(padded_data) + encryptor.finalize()
return {"ciphertext": ciphertext.hex(), "iv": iv.hex(), "algorithm": "aes_256_cbc", "key_id": key_data["key_id"]}
async def decrypt_data(self, encrypted_data: dict[str, Any], associated_data: bytes | None = None) -> bytes:
"""Decrypt encrypted data"""
try:
algorithm = encrypted_data["algorithm"]
if algorithm == "aes_256_gcm":
return await self._decrypt_aes_gcm(encrypted_data, associated_data)
elif algorithm == "chacha20_poly1305":
return await self._decrypt_chacha20(encrypted_data, associated_data)
elif algorithm == "aes_256_cbc":
return await self._decrypt_aes_cbc(encrypted_data)
else:
raise ValueError(f"Unsupported encryption algorithm: {algorithm}")
except Exception as e:
self.logger.error(f"Decryption failed: {e}")
raise
async def _decrypt_aes_gcm(self, encrypted_data: dict[str, Any], associated_data: bytes | None = None) -> bytes:
"""Decrypt AES-256-GCM encrypted data"""
# Get key from HSM
key_data = await self.hsm_manager.get_key(encrypted_data["key_id"])
if not key_data:
raise ValueError(f"Key not found: {encrypted_data['key_id']}")
key = key_data["key"]
iv = bytes.fromhex(encrypted_data["iv"])
ciphertext = bytes.fromhex(encrypted_data["ciphertext"])
tag = bytes.fromhex(encrypted_data["tag"])
# Create cipher
cipher = Cipher(algorithms.AES(key), modes.GCM(iv, tag), backend=self.backend)
decryptor = cipher.decryptor()
# Add associated data if provided
if associated_data:
decryptor.authenticate_additional_data(associated_data)
# Decrypt data
plaintext = decryptor.update(ciphertext) + decryptor.finalize()
return plaintext
async def _decrypt_chacha20(self, encrypted_data: dict[str, Any], associated_data: bytes | None = None) -> bytes:
"""Decrypt ChaCha20-Poly1305 encrypted data"""
# Get key from HSM
key_data = await self.hsm_manager.get_key(encrypted_data["key_id"])
if not key_data:
raise ValueError(f"Key not found: {encrypted_data['key_id']}")
key = key_data["key"]
nonce = bytes.fromhex(encrypted_data["nonce"])
ciphertext = bytes.fromhex(encrypted_data["ciphertext"])
tag = bytes.fromhex(encrypted_data["tag"])
# Create cipher
cipher = Cipher(algorithms.ChaCha20(key, nonce), modes.Poly1305(tag), backend=self.backend)
decryptor = cipher.decryptor()
# Add associated data if provided
if associated_data:
decryptor.authenticate_additional_data(associated_data)
# Decrypt data
plaintext = decryptor.update(ciphertext) + decryptor.finalize()
return plaintext
async def _decrypt_aes_cbc(self, encrypted_data: dict[str, Any]) -> bytes:
"""Decrypt AES-256-CBC encrypted data"""
# Get key from HSM
key_data = await self.hsm_manager.get_key(encrypted_data["key_id"])
if not key_data:
raise ValueError(f"Key not found: {encrypted_data['key_id']}")
key = key_data["key"]
iv = bytes.fromhex(encrypted_data["iv"])
ciphertext = bytes.fromhex(encrypted_data["ciphertext"])
# Create cipher
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=self.backend)
decryptor = cipher.decryptor()
padded_plaintext = decryptor.update(ciphertext) + decryptor.finalize()
# Unpad data
unpadder = cryptography.hazmat.primitives.padding.PKCS7(128).unpadder()
plaintext = unpadder.update(padded_plaintext) + unpadder.finalize()
return plaintext
class ZeroTrustArchitecture:
"""Zero-trust security architecture implementation"""
def __init__(self, hsm_manager: HSMManager, encryption: EnterpriseEncryption):
self.hsm_manager = hsm_manager
self.encryption = encryption
self.trust_policies = {}
self.session_tokens = {}
self.logger = get_logger("zero_trust")
async def create_trust_policy(self, policy_id: str, policy_config: dict[str, Any]) -> bool:
"""Create zero-trust policy"""
try:
policy = SecurityPolicy(
policy_id=policy_id,
name=policy_config["name"],
security_level=SecurityLevel(policy_config["security_level"]),
encryption_algorithm=EncryptionAlgorithm(policy_config["encryption_algorithm"]),
key_rotation_interval=timedelta(days=policy_config.get("key_rotation_days", 90)),
access_control_requirements=policy_config.get("access_control_requirements", []),
audit_requirements=policy_config.get("audit_requirements", []),
retention_period=timedelta(days=policy_config.get("retention_days", 2555)), # 7 years
)
self.trust_policies[policy_id] = policy
# Generate encryption key for policy
await self.hsm_manager.generate_key(f"policy_{policy_id}", policy.encryption_algorithm)
self.logger.info(f"Zero-trust policy created: {policy_id}")
return True
except Exception as e:
self.logger.error(f"Failed to create trust policy: {e}")
return False
async def verify_trust(self, user_id: str, resource_id: str, action: str, context: dict[str, Any]) -> bool:
"""Verify zero-trust access request"""
try:
# Get applicable policy
policy_id = context.get("policy_id", "default")
policy = self.trust_policies.get(policy_id)
if not policy:
self.logger.warning(f"No policy found for {policy_id}")
return False
# Verify trust factors
trust_score = await self._calculate_trust_score(user_id, resource_id, action, context)
# Check if trust score meets policy requirements
min_trust_score = self._get_min_trust_score(policy.security_level)
is_trusted = trust_score >= min_trust_score
# Log trust decision
await self._log_trust_decision(user_id, resource_id, action, trust_score, is_trusted)
return is_trusted
except Exception as e:
self.logger.error(f"Trust verification failed: {e}")
return False
async def _calculate_trust_score(self, user_id: str, resource_id: str, action: str, context: dict[str, Any]) -> float:
"""Calculate trust score for access request"""
score = 0.0
# User authentication factor (40%)
auth_strength = context.get("auth_strength", "password")
if auth_strength == "mfa":
score += 0.4
elif auth_strength == "password":
score += 0.2
# Device trust factor (20%)
device_trust = context.get("device_trust", 0.5)
score += 0.2 * device_trust
# Location factor (15%)
location_trust = context.get("location_trust", 0.5)
score += 0.15 * location_trust
# Time factor (10%)
time_trust = context.get("time_trust", 0.5)
score += 0.1 * time_trust
# Behavioral factor (15%)
behavior_trust = context.get("behavior_trust", 0.5)
score += 0.15 * behavior_trust
return min(score, 1.0)
def _get_min_trust_score(self, security_level: SecurityLevel) -> float:
"""Get minimum trust score for security level"""
thresholds = {
SecurityLevel.PUBLIC: 0.0,
SecurityLevel.INTERNAL: 0.3,
SecurityLevel.CONFIDENTIAL: 0.6,
SecurityLevel.RESTRICTED: 0.8,
SecurityLevel.TOP_SECRET: 0.9,
}
return thresholds.get(security_level, 0.5)
async def _log_trust_decision(self, user_id: str, resource_id: str, action: str, trust_score: float, decision: bool):
"""Log trust decision for audit"""
SecurityEvent(
event_id=str(uuid4()),
event_type="trust_decision",
severity=ThreatLevel.LOW if decision else ThreatLevel.MEDIUM,
source="zero_trust",
timestamp=datetime.now(timezone.utc),
user_id=user_id,
resource_id=resource_id,
details={"action": action, "trust_score": trust_score, "decision": decision},
)
# In production, send to security monitoring system
self.logger.info(f"Trust decision: {user_id} -> {resource_id} = {decision} (score: {trust_score})")
class ThreatDetectionSystem:
"""Advanced threat detection and response system"""
def __init__(self):
self.threat_patterns = {}
self.active_threats = {}
self.response_actions = {}
self.logger = get_logger("threat_detection")
async def register_threat_pattern(self, pattern_id: str, pattern_config: dict[str, Any]):
"""Register threat detection pattern"""
self.threat_patterns[pattern_id] = {
"id": pattern_id,
"name": pattern_config["name"],
"description": pattern_config["description"],
"indicators": pattern_config["indicators"],
"severity": ThreatLevel(pattern_config["severity"]),
"response_actions": pattern_config.get("response_actions", []),
"threshold": pattern_config.get("threshold", 1.0),
}
self.logger.info(f"Threat pattern registered: {pattern_id}")
async def analyze_threat(self, event_data: dict[str, Any]) -> list[SecurityEvent]:
"""Analyze event for potential threats"""
detected_threats = []
for pattern_id, pattern in self.threat_patterns.items():
threat_score = await self._calculate_threat_score(event_data, pattern)
if threat_score >= pattern["threshold"]:
threat_event = SecurityEvent(
event_id=str(uuid4()),
event_type="threat_detected",
severity=pattern["severity"],
source="threat_detection",
timestamp=datetime.now(timezone.utc),
user_id=event_data.get("user_id"),
resource_id=event_data.get("resource_id"),
details={
"pattern_id": pattern_id,
"pattern_name": pattern["name"],
"threat_score": threat_score,
"indicators": event_data,
},
)
detected_threats.append(threat_event)
# Trigger response actions
await self._trigger_response_actions(pattern_id, threat_event)
return detected_threats
async def _calculate_threat_score(self, event_data: dict[str, Any], pattern: dict[str, Any]) -> float:
"""Calculate threat score for pattern"""
score = 0.0
indicators = pattern["indicators"]
for indicator, weight in indicators.items():
if indicator in event_data:
# Simple scoring - in production, use more sophisticated algorithms
indicator_score = 0.5 # Base score for presence
score += indicator_score * weight
return min(score, 1.0)
async def _trigger_response_actions(self, pattern_id: str, threat_event: SecurityEvent):
"""Trigger automated response actions"""
pattern = self.threat_patterns[pattern_id]
actions = pattern.get("response_actions", [])
for action in actions:
try:
await self._execute_response_action(action, threat_event)
except Exception as e:
self.logger.error(f"Response action failed: {action} - {e}")
async def _execute_response_action(self, action: str, threat_event: SecurityEvent):
"""Execute specific response action"""
if action == "block_user":
await self._block_user(threat_event.user_id)
elif action == "isolate_resource":
await self._isolate_resource(threat_event.resource_id)
elif action == "escalate_to_admin":
await self._escalate_to_admin(threat_event)
elif action == "require_mfa":
await self._require_mfa(threat_event.user_id)
self.logger.info(f"Response action executed: {action}")
async def _block_user(self, user_id: str):
"""Block user account"""
# In production, implement actual user blocking
self.logger.warning(f"User blocked due to threat: {user_id}")
async def _isolate_resource(self, resource_id: str):
"""Isolate compromised resource"""
# In production, implement actual resource isolation
self.logger.warning(f"Resource isolated due to threat: {resource_id}")
async def _escalate_to_admin(self, threat_event: SecurityEvent):
"""Escalate threat to security administrators"""
# In production, implement actual escalation
self.logger.error(f"Threat escalated to admin: {threat_event.event_id}")
async def _require_mfa(self, user_id: str):
"""Require multi-factor authentication"""
# In production, implement MFA requirement
self.logger.warning(f"MFA required for user: {user_id}")
class EnterpriseSecurityFramework:
"""Main enterprise security framework"""
def __init__(self, hsm_config: dict[str, Any]):
self.hsm_manager = HSMManager(hsm_config)
self.encryption = EnterpriseEncryption(self.hsm_manager)
self.zero_trust = ZeroTrustArchitecture(self.hsm_manager, self.encryption)
self.threat_detection = ThreatDetectionSystem()
self.logger = get_logger("enterprise_security")
async def initialize(self) -> bool:
"""Initialize security framework"""
try:
# Initialize HSM
if not await self.hsm_manager.initialize():
return False
# Register default threat patterns
await self._register_default_threat_patterns()
# Create default trust policies
await self._create_default_policies()
self.logger.info("Enterprise security framework initialized")
return True
except Exception as e:
self.logger.error(f"Security framework initialization failed: {e}")
return False
async def _register_default_threat_patterns(self):
"""Register default threat detection patterns"""
patterns = [
{
"name": "Brute Force Attack",
"description": "Multiple failed login attempts",
"indicators": {"failed_login_attempts": 0.8, "short_time_interval": 0.6},
"severity": "high",
"threshold": 0.7,
"response_actions": ["block_user", "require_mfa"],
},
{
"name": "Suspicious Access Pattern",
"description": "Unusual access patterns",
"indicators": {"unusual_location": 0.7, "unusual_time": 0.5, "high_frequency": 0.6},
"severity": "medium",
"threshold": 0.6,
"response_actions": ["require_mfa", "escalate_to_admin"],
},
{
"name": "Data Exfiltration",
"description": "Large data transfer patterns",
"indicators": {"large_data_transfer": 0.9, "unusual_destination": 0.7},
"severity": "critical",
"threshold": 0.8,
"response_actions": ["block_user", "isolate_resource", "escalate_to_admin"],
},
]
for i, pattern in enumerate(patterns):
await self.threat_detection.register_threat_pattern(f"default_{i}", pattern)
async def _create_default_policies(self):
"""Create default trust policies"""
policies = [
{
"name": "Enterprise Data Policy",
"security_level": "confidential",
"encryption_algorithm": "aes_256_gcm",
"key_rotation_days": 90,
"access_control_requirements": ["mfa", "device_trust"],
"audit_requirements": ["full_audit", "real_time_monitoring"],
"retention_days": 2555,
},
{
"name": "Public API Policy",
"security_level": "public",
"encryption_algorithm": "aes_256_gcm",
"key_rotation_days": 180,
"access_control_requirements": ["api_key"],
"audit_requirements": ["api_access_log"],
"retention_days": 365,
},
]
for i, policy in enumerate(policies):
await self.zero_trust.create_trust_policy(f"default_{i}", policy)
async def encrypt_sensitive_data(self, data: str | bytes, security_level: SecurityLevel) -> dict[str, Any]:
"""Encrypt sensitive data with appropriate security level"""
# Get policy for security level
policy_id = f"default_{0 if security_level == SecurityLevel.PUBLIC else 1}"
policy = self.zero_trust.trust_policies.get(policy_id)
if not policy:
raise ValueError(f"No policy found for security level: {security_level}")
key_id = f"policy_{policy_id}"
return await self.encryption.encrypt_data(data, key_id)
async def verify_access(self, user_id: str, resource_id: str, action: str, context: dict[str, Any]) -> bool:
"""Verify access using zero-trust architecture"""
return await self.zero_trust.verify_trust(user_id, resource_id, action, context)
async def analyze_security_event(self, event_data: dict[str, Any]) -> list[SecurityEvent]:
"""Analyze security event for threats"""
return await self.threat_detection.analyze_threat(event_data)
async def rotate_encryption_keys(self, policy_id: str | None = None) -> dict[str, Any]:
"""Rotate encryption keys"""
if policy_id:
# Rotate specific policy key
old_key_id = f"policy_{policy_id}"
new_key = await self.hsm_manager.rotate_key(old_key_id)
return {"rotated_key": new_key}
else:
# Rotate all keys
rotated_keys = {}
for policy_id in self.zero_trust.trust_policies.keys():
old_key_id = f"policy_{policy_id}"
new_key = await self.hsm_manager.rotate_key(old_key_id)
rotated_keys[policy_id] = new_key
return {"rotated_keys": rotated_keys}
# Global security framework instance
security_framework = None
async def get_security_framework() -> EnterpriseSecurityFramework:
"""Get or create global security framework"""
global security_framework
if security_framework is None:
hsm_config = {"provider": "software", "endpoint": "localhost:8080"} # In production, use actual HSM
security_framework = EnterpriseSecurityFramework(hsm_config)
await security_framework.initialize()
return security_framework
# Alias for CLI compatibility
EnterpriseSecurityManager = EnterpriseSecurityFramework

View File

@@ -0,0 +1,17 @@
"""
Trading & Marketplace Bounded Context
Provides trading, marketplace optimization, bid strategy, and dynamic pricing services.
"""
from .bid_strategy import BidStrategyEngine
from .dynamic_pricing import DynamicPricingEngine
from .gpu_optimizer import MarketplaceGPUOptimizer
from .trading import MatchingEngine, NegotiationSystem
__all__ = [
"BidStrategyEngine",
"DynamicPricingEngine",
"MarketplaceGPUOptimizer",
"MatchingEngine",
"NegotiationSystem",
]

View File

@@ -13,7 +13,7 @@ logger = get_logger(__name__)
from sqlmodel import Session, or_, select from sqlmodel import Session, or_, select
from ..domain.trading import ( from ...domain.trading import (
NegotiationStatus, NegotiationStatus,
SettlementType, SettlementType,
TradeAgreement, TradeAgreement,

View File

@@ -25,6 +25,26 @@
- Maintained backward compatibility with lazy-loading pattern - Maintained backward compatibility with lazy-loading pattern
- Import tests verified successfully - Import tests verified successfully
- Old monolithic files removed - Old monolithic files removed
- ✅ Phase 2 Complete: Enterprise Integration bounded context decomposed
- Created app/services/enterprise_integration/ package with 4 modules
- Migrated enterprise_integration.py (1127 lines) and 3 other enterprise files
- Updated imports within package (api_gateway.py excluded due to missing dependencies)
- Import tests verified successfully
- Old monolithic files removed
- ✅ Phase 3 Complete: Trading & Marketplace bounded context decomposed
- Created app/services/trading_marketplace/ package with 5 modules
- Migrated trading_service.py (36K) and 4 other trading files
- Updated imports across coordinator-api (routers/trading.py, routers/dynamic_pricing.py)
- amm.py excluded from exports due to missing dependencies
- Import tests verified successfully
- Old monolithic files removed
- ✅ Phase 4 Complete: AI & Analytics bounded context decomposed
- Created app/services/ai_analytics/ package with 5 modules
- Migrated analytics_service.py (41K) and 4 other AI files
- Updated imports across coordinator-api (routers/analytics.py, routers/adaptive_learning_health.py)
- adaptive_learning.py, surveillance.py, trading_engine.py excluded due to missing dependencies
- Import tests verified successfully
- Old monolithic files removed
2. **Production Code Using print()** (HIGH IMPACT) 2. **Production Code Using print()** (HIGH IMPACT)
- 925 print() statements in production code - 925 print() statements in production code
@@ -128,9 +148,17 @@
- test_validation_properties.py: 20/20 passing - test_validation_properties.py: 20/20 passing
- test_staking_service.py: 22/22 passing - test_staking_service.py: 22/22 passing
- Coverage threshold set to 50% in pyproject.toml - Coverage threshold set to 50% in pyproject.toml
- Current coverage: 11% (4623 statements, 4122 missed) - BELOW 50% threshold - Current coverage: 19% (4623 statements, 3745 missed) - BELOW 50% threshold
- Well-covered modules: constants.py (100%), exceptions.py (100%), validation.py (85%), crypto/crypto.py (52%) - Added 137 new tests across 6 modules:
- test_middleware.py: 11 tests (middleware modules: 50-100% coverage)
- test_utils.py: 47 tests (utils modules: 100% coverage when run standalone)
- test_config.py: 14 tests (config.py: 100% coverage)
- test_decorators.py: 21 tests (decorators.py: 99% coverage)
- test_health_checks.py: 16 tests (health_checks.py: 80% coverage)
- test_metrics.py: 28 tests (metrics.py: 100% coverage)
- Well-covered modules: constants.py (100%), exceptions.py (100%), validation.py (85%), crypto/crypto.py (52%), config.py (100%), decorators.py (99%), health_checks.py (80%), metrics.py (100%)
- Needs improvement: Most modules at 0-30% coverage - Needs improvement: Most modules at 0-30% coverage
- Note: Utils modules (paths, env, json_utils) achieve 100% when run standalone but not counted in overall coverage due to import patterns
#### MEDIUM (Long-term, 1-3 months) #### MEDIUM (Long-term, 1-3 months)

154
tests/test_config.py Normal file
View File

@@ -0,0 +1,154 @@
"""
Tests for AITBC configuration classes
"""
import os
import pytest
from pathlib import Path
from unittest.mock import patch, Mock
from aitbc.config import BaseAITBCConfig, AITBCConfig
class TestBaseAITBCConfig:
"""Tests for BaseAITBCConfig"""
def test_default_values(self):
"""Test BaseAITBCConfig with default values"""
config = BaseAITBCConfig()
assert config.app_name == "AITBC Application"
assert config.app_version == "1.0.0"
assert config.environment == "development"
assert config.debug is False
assert config.log_level == "INFO"
def test_custom_values(self):
"""Test BaseAITBCConfig with custom values"""
config = BaseAITBCConfig(
app_name="Custom App",
app_version="2.0.0",
environment="production",
debug=True,
log_level="DEBUG"
)
assert config.app_name == "Custom App"
assert config.app_version == "2.0.0"
assert config.environment == "production"
assert config.debug is True
assert config.log_level == "DEBUG"
def test_data_dir_default(self):
"""Test default data directory is a Path"""
config = BaseAITBCConfig()
assert isinstance(config.data_dir, Path)
def test_config_dir_default(self):
"""Test default config directory is a Path"""
config = BaseAITBCConfig()
assert isinstance(config.config_dir, Path)
def test_log_dir_default(self):
"""Test default log directory is a Path"""
config = BaseAITBCConfig()
assert isinstance(config.log_dir, Path)
def test_log_format_default(self):
"""Test default log format"""
config = BaseAITBCConfig()
assert "%(asctime)s" in config.log_format
assert "%(name)s" in config.log_format
assert "%(levelname)s" in config.log_format
class TestAITBCConfig:
"""Tests for AITBCConfig"""
def test_default_values(self):
"""Test AITBCConfig with default values"""
config = AITBCConfig()
assert config.host == "0.0.0.0"
assert config.port == 8000
assert config.workers == 1
assert config.database_url is None
assert config.database_pool_size == 10
assert config.redis_url is None
assert config.redis_max_connections == 10
assert config.redis_timeout == 5
assert config.secret_key is None
assert config.jwt_secret is None
assert config.jwt_algorithm == "HS256"
assert config.jwt_expiration_hours == 24
assert config.request_timeout == 30
assert config.max_request_size == 10 * 1024 * 1024
def test_custom_server_settings(self):
"""Test AITBCConfig with custom server settings"""
config = AITBCConfig(
host="127.0.0.1",
port=9000,
workers=4
)
assert config.host == "127.0.0.1"
assert config.port == 9000
assert config.workers == 4
def test_custom_database_settings(self):
"""Test AITBCConfig with custom database settings"""
config = AITBCConfig(
database_url="postgresql://localhost/test",
database_pool_size=20
)
assert config.database_url == "postgresql://localhost/test"
assert config.database_pool_size == 20
def test_custom_redis_settings(self):
"""Test AITBCConfig with custom redis settings"""
config = AITBCConfig(
redis_url="redis://localhost:6379",
redis_max_connections=50,
redis_timeout=10
)
assert config.redis_url == "redis://localhost:6379"
assert config.redis_max_connections == 50
assert config.redis_timeout == 10
def test_custom_security_settings(self):
"""Test AITBCConfig with custom security settings"""
config = AITBCConfig(
secret_key="test-secret-key",
jwt_secret="test-jwt-secret",
jwt_algorithm="RS256",
jwt_expiration_hours=48
)
assert config.secret_key == "test-secret-key"
assert config.jwt_secret == "test-jwt-secret"
assert config.jwt_algorithm == "RS256"
assert config.jwt_expiration_hours == 48
def test_custom_performance_settings(self):
"""Test AITBCConfig with custom performance settings"""
config = AITBCConfig(
request_timeout=60,
max_request_size=20 * 1024 * 1024
)
assert config.request_timeout == 60
assert config.max_request_size == 20 * 1024 * 1024
def test_inherits_base_config(self):
"""Test AITBCConfig inherits from BaseAITBCConfig"""
config = AITBCConfig(
app_name="Test App",
environment="staging"
)
assert config.app_name == "Test App"
assert config.environment == "staging"
assert config.host == "0.0.0.0" # AITBCConfig default
assert config.port == 8000 # AITBCConfig default
@patch('aitbc.config.logger')
def test_init_logs_configuration(self, mock_logger):
"""Test __init__ logs configuration"""
config = AITBCConfig(host="localhost", port=9000)
mock_logger.info.assert_called_once()
assert "localhost:9000" in mock_logger.info.call_args[0][0]
mock_logger.debug.assert_called_once()

303
tests/test_decorators.py Normal file
View File

@@ -0,0 +1,303 @@
"""
Tests for AITBC decorators
"""
import time
import pytest
from unittest.mock import patch
from aitbc.decorators import (
retry,
timing,
cache_result,
validate_args,
handle_exceptions,
async_timing,
)
from aitbc.exceptions import AITBCError
class TestRetry:
"""Tests for retry decorator"""
def test_retry_succeeds_on_first_attempt(self):
"""Test retry when function succeeds on first attempt"""
@retry(max_attempts=3)
def test_func():
return "success"
result = test_func()
assert result == "success"
def test_retry_succeeds_after_failure(self):
"""Test retry when function succeeds after initial failure"""
attempts = [0]
@retry(max_attempts=3, delay=0.01)
def test_func():
attempts[0] += 1
if attempts[0] < 2:
raise ValueError("fail")
return "success"
result = test_func()
assert result == "success"
assert attempts[0] == 2
def test_retry_exhausts_attempts(self):
"""Test retry when function fails after all attempts"""
@retry(max_attempts=2, delay=0.01)
def test_func():
raise ValueError("fail")
with pytest.raises(ValueError):
test_func()
def test_retry_with_specific_exception(self):
"""Test retry only catches specified exceptions"""
@retry(max_attempts=2, delay=0.01, exceptions=(ValueError,))
def test_func():
raise TypeError("fail")
with pytest.raises(TypeError):
test_func()
def test_retry_with_backoff(self):
"""Test retry with exponential backoff"""
attempts = [0]
@retry(max_attempts=3, delay=0.01, backoff=2.0)
def test_func():
attempts[0] += 1
raise ValueError("fail")
start_time = time.time()
with pytest.raises(ValueError):
test_func()
elapsed = time.time() - start_time
# Should have delays: 0.01 + 0.02 = 0.03 seconds minimum
assert elapsed >= 0.03
def test_retry_with_on_failure_callback(self):
"""Test retry with on_failure callback"""
callback_called = [False]
def on_fail(e):
callback_called[0] = True
@retry(max_attempts=2, delay=0.01, on_failure=on_fail)
def test_func():
raise ValueError("fail")
with pytest.raises(ValueError):
test_func()
assert callback_called[0] is True
class TestTiming:
"""Tests for timing decorator"""
@patch('aitbc.decorators.logger')
def test_timing_logs_execution_time(self, mock_logger):
"""Test timing decorator logs execution time"""
@timing
def test_func():
time.sleep(0.01)
return "result"
result = test_func()
assert result == "result"
mock_logger.info.assert_called_once()
assert "executed in" in mock_logger.info.call_args[0][0]
@patch('aitbc.decorators.logger')
def test_timing_preserves_function_name(self, mock_logger):
"""Test timing decorator preserves function name"""
@timing
def my_function():
return "result"
assert my_function.__name__ == "my_function"
class TestCacheResult:
"""Tests for cache_result decorator"""
def test_cache_result_caches_value(self):
"""Test cache_result caches function return value"""
call_count = [0]
@cache_result(ttl=60)
def test_func(x):
call_count[0] += 1
return x * 2
result1 = test_func(5)
result2 = test_func(5)
assert result1 == 10
assert result2 == 10
assert call_count[0] == 1 # Only called once due to cache
def test_cache_result_different_args(self):
"""Test cache_result with different arguments"""
call_count = [0]
@cache_result(ttl=60)
def test_func(x):
call_count[0] += 1
return x * 2
test_func(5)
test_func(10)
assert call_count[0] == 2 # Called twice for different args
def test_cache_result_ttl_expires(self):
"""Test cache_result TTL expires"""
call_count = [0]
@cache_result(ttl=0.1) # 100ms TTL
def test_func(x):
call_count[0] += 1
return x * 2
test_func(5)
time.sleep(0.15) # Wait for TTL to expire
test_func(5)
assert call_count[0] == 2 # Called again after TTL expired
def test_cache_result_with_kwargs(self):
"""Test cache_result with keyword arguments"""
call_count = [0]
@cache_result(ttl=60)
def test_func(x, y=10):
call_count[0] += 1
return x + y
test_func(5, y=10)
test_func(5, y=10)
assert call_count[0] == 1 # Cached
class TestValidateArgs:
"""Tests for validate_args decorator"""
def test_validate_args_passes_valid(self):
"""Test validate_args passes when validators succeed"""
def validator(x):
if x < 0:
raise ValueError("Must be positive")
@validate_args(validator)
def test_func(x):
return x * 2
result = test_func(5)
assert result == 10
def test_validate_args_fails_invalid(self):
"""Test validate_args fails when validators raise error"""
def validator(x):
if x < 0:
raise ValueError("Must be positive")
@validate_args(validator)
def test_func(x):
return x * 2
with pytest.raises(ValueError):
test_func(-5)
def test_validate_args_multiple_validators(self):
"""Test validate_args with multiple validators"""
def validator1(x):
if x < 0:
raise ValueError("Must be positive")
def validator2(x):
if x > 100:
raise ValueError("Must be <= 100")
@validate_args(validator1, validator2)
def test_func(x):
return x * 2
with pytest.raises(ValueError):
test_func(150)
class TestHandleExceptions:
"""Tests for handle_exceptions decorator"""
@patch('aitbc.decorators.logger')
def test_handle_exceptions_returns_default(self, mock_logger):
"""Test handle_exceptions returns default on exception"""
@handle_exceptions(default_return="error")
def test_func():
raise ValueError("fail")
result = test_func()
assert result == "error"
mock_logger.error.assert_called_once()
@patch('aitbc.decorators.logger')
def test_handle_exceptions_no_logging(self, mock_logger):
"""Test handle_exceptions with logging disabled"""
@handle_exceptions(default_return="error", log_errors=False)
def test_func():
raise ValueError("fail")
result = test_func()
assert result == "error"
mock_logger.error.assert_not_called()
def test_handle_exceptions_raises_on_specified(self):
"""Test handle_exceptions still raises specified exceptions"""
@handle_exceptions(default_return="error", raise_on=(ValueError,))
def test_func():
raise ValueError("fail")
with pytest.raises(ValueError):
test_func()
def test_handle_exceptions_passes_on_success(self):
"""Test handle_exceptions passes through successful return"""
@handle_exceptions(default_return="error")
def test_func():
return "success"
result = test_func()
assert result == "success"
class TestAsyncTiming:
"""Tests for async_timing decorator"""
@pytest.mark.asyncio
@patch('aitbc.decorators.logger')
async def test_async_timing_logs_execution_time(self, mock_logger):
"""Test async_timing decorator logs execution time"""
@async_timing
async def test_func():
await asyncio.sleep(0.01)
return "result"
import asyncio
result = await test_func()
assert result == "result"
mock_logger.info.assert_called_once()
assert "executed in" in mock_logger.info.call_args[0][0]
@pytest.mark.asyncio
async def test_async_timing_preserves_function_name(self):
"""Test async_timing decorator preserves function name"""
@async_timing
async def my_function():
return "result"
assert my_function.__name__ == "my_function"

218
tests/test_health_checks.py Normal file
View File

@@ -0,0 +1,218 @@
"""
Tests for health check utilities
"""
import pytest
from unittest.mock import patch, Mock
from datetime import datetime
from aitbc.health_checks import (
HealthStatus,
HealthCheck,
HealthChecker,
create_basic_health_check,
)
class TestHealthStatus:
"""Tests for HealthStatus enum"""
def test_health_status_values(self):
"""Test HealthStatus enum values"""
assert HealthStatus.HEALTHY.value == "healthy"
assert HealthStatus.DEGRADED.value == "degraded"
assert HealthStatus.UNHEALTHY.value == "unhealthy"
class TestHealthCheck:
"""Tests for HealthCheck dataclass"""
def test_health_check_creation(self):
"""Test HealthCheck dataclass creation"""
check = HealthCheck(
service="test-service",
status=HealthStatus.HEALTHY,
message="All good",
timestamp=datetime.now(),
details={"key": "value"}
)
assert check.service == "test-service"
assert check.status == HealthStatus.HEALTHY
assert check.message == "All good"
assert check.details == {"key": "value"}
def test_health_check_without_details(self):
"""Test HealthCheck without optional details"""
check = HealthCheck(
service="test-service",
status=HealthStatus.HEALTHY,
message="All good",
timestamp=datetime.now()
)
assert check.details is None
class TestHealthChecker:
"""Tests for HealthChecker"""
def test_health_checker_initialization(self):
"""Test HealthChecker initialization"""
checker = HealthChecker("test-service")
assert checker.service_name == "test-service"
assert checker._checks == {}
assert checker._last_check is None
def test_register_check(self):
"""Test registering a health check"""
checker = HealthChecker("test-service")
def mock_check():
return HealthStatus.HEALTHY, "OK", {}
checker.register_check("memory", mock_check)
assert "memory" in checker._checks
assert checker._checks["memory"] == mock_check
@patch('aitbc.health_checks.logger')
def test_register_check_logs(self, mock_logger):
"""Test register_check logs registration"""
checker = HealthChecker("test-service")
def mock_check():
return HealthStatus.HEALTHY, "OK", {}
checker.register_check("memory", mock_check)
mock_logger.info.assert_called_once()
assert "memory" in mock_logger.info.call_args[0][0]
def test_run_checks_all_healthy(self):
"""Test run_checks when all checks pass"""
checker = HealthChecker("test-service")
def mock_check():
return HealthStatus.HEALTHY, "OK", {}
checker.register_check("check1", mock_check)
checker.register_check("check2", mock_check)
result = checker.run_checks()
assert result.service == "test-service"
assert result.status == HealthStatus.HEALTHY
assert result.message == "All health checks passed"
assert result.details is not None
assert len(result.details) == 2
def test_run_checks_one_degraded(self):
"""Test run_checks with one degraded check"""
checker = HealthChecker("test-service")
def healthy_check():
return HealthStatus.HEALTHY, "OK", {}
def degraded_check():
return HealthStatus.DEGRADED, "Warning", {}
checker.register_check("healthy", healthy_check)
checker.register_check("degraded", degraded_check)
result = checker.run_checks()
assert result.status == HealthStatus.DEGRADED
assert "degraded" in result.message
def test_run_checks_one_unhealthy(self):
"""Test run_checks with one unhealthy check"""
checker = HealthChecker("test-service")
def healthy_check():
return HealthStatus.HEALTHY, "OK", {}
def unhealthy_check():
return HealthStatus.UNHEALTHY, "Error", {}
checker.register_check("healthy", healthy_check)
checker.register_check("unhealthy", unhealthy_check)
result = checker.run_checks()
assert result.status == HealthStatus.UNHEALTHY
assert "unhealthy" in result.message
@patch('aitbc.health_checks.logger')
def test_run_checks_with_exception(self, mock_logger):
"""Test run_checks handles exceptions in checks"""
checker = HealthChecker("test-service")
def failing_check():
raise ValueError("Check failed")
checker.register_check("failing", failing_check)
result = checker.run_checks()
assert result.status == HealthStatus.UNHEALTHY
assert "failing" in result.message
mock_logger.error.assert_called_once()
def test_get_last_check_before_run(self):
"""Test get_last_check returns None before any check run"""
checker = HealthChecker("test-service")
assert checker.get_last_check() is None
def test_get_last_check_after_run(self):
"""Test get_last_check returns last check result"""
checker = HealthChecker("test-service")
def mock_check():
return HealthStatus.HEALTHY, "OK", {}
checker.register_check("check1", mock_check)
checker.run_checks()
last_check = checker.get_last_check()
assert last_check is not None
assert last_check.service == "test-service"
def test_get_health_dict(self):
"""Test get_health_dict returns dictionary representation"""
checker = HealthChecker("test-service")
def mock_check():
return HealthStatus.HEALTHY, "OK", {"key": "value"}
checker.register_check("check1", mock_check)
health_dict = checker.get_health_dict()
assert isinstance(health_dict, dict)
assert "service" in health_dict
assert "status" in health_dict
assert "message" in health_dict
assert "timestamp" in health_dict
assert health_dict["service"] == "test-service"
class TestCreateBasicHealthCheck:
"""Tests for create_basic_health_check"""
def test_create_basic_health_check(self):
"""Test create_basic_health_check returns HealthChecker"""
checker = create_basic_health_check("test-service")
assert isinstance(checker, HealthChecker)
assert checker.service_name == "test-service"
def test_create_basic_health_check_without_psutil(self):
"""Test create_basic_health_check handles psutil ImportError"""
# Skip this test as psutil import handling is complex to mock
pytest.skip("psutil import handling requires complex mocking")
def test_basic_health_check_has_checks(self):
"""Test basic health check has registered checks when psutil available"""
try:
import psutil
checker = create_basic_health_check("test-service")
# Should have memory and disk checks if psutil is available
assert len(checker._checks) > 0
except ImportError:
# Skip if psutil not available
pass

250
tests/test_metrics.py Normal file
View File

@@ -0,0 +1,250 @@
"""
Tests for AITBC metrics module
"""
import pytest
import asyncio
from unittest.mock import patch, Mock
from aitbc.metrics import (
service_info,
block_processing_duration,
block_height,
block_validation_duration,
block_propagation_duration,
job_submission_duration,
job_processing_duration,
job_queue_duration,
job_execution_duration,
jobs_total,
jobs_failed_total,
jobs_in_queue,
http_requests_total,
http_request_duration,
service_uptime_seconds,
service_restart_count,
track_block_processing,
track_job_processing,
track_http_request,
update_block_height,
update_jobs_in_queue,
increment_service_restarts,
metrics_app,
setup_service_info,
)
class TestMetricsDefinitions:
"""Tests for Prometheus metrics definitions"""
def test_service_info_exists(self):
"""Test service_info metric is defined"""
assert service_info is not None
assert service_info._name == 'service_info'
def test_block_processing_duration_exists(self):
"""Test block_processing_duration metric is defined"""
assert block_processing_duration is not None
assert block_processing_duration._name == 'block_processing_duration_seconds'
def test_block_height_exists(self):
"""Test block_height metric is defined"""
assert block_height is not None
assert block_height._name == 'block_height'
def test_block_validation_duration_exists(self):
"""Test block_validation_duration metric is defined"""
assert block_validation_duration is not None
assert block_validation_duration._name == 'block_validation_duration_seconds'
def test_block_propagation_duration_exists(self):
"""Test block_propagation_duration metric is defined"""
assert block_propagation_duration is not None
assert block_propagation_duration._name == 'block_propagation_duration_seconds'
def test_job_submission_duration_exists(self):
"""Test job_submission_duration metric is defined"""
assert job_submission_duration is not None
assert job_submission_duration._name == 'job_submission_duration_seconds'
def test_job_processing_duration_exists(self):
"""Test job_processing_duration metric is defined"""
assert job_processing_duration is not None
assert job_processing_duration._name == 'job_processing_duration_seconds'
def test_job_queue_duration_exists(self):
"""Test job_queue_duration metric is defined"""
assert job_queue_duration is not None
assert job_queue_duration._name == 'job_queue_duration_seconds'
def test_job_execution_duration_exists(self):
"""Test job_execution_duration metric is defined"""
assert job_execution_duration is not None
assert job_execution_duration._name == 'job_execution_duration_seconds'
def test_jobs_total_exists(self):
"""Test jobs_total metric is defined"""
assert jobs_total is not None
assert jobs_total._name == 'jobs'
def test_jobs_failed_total_exists(self):
"""Test jobs_failed_total metric is defined"""
assert jobs_failed_total is not None
assert jobs_failed_total._name == 'jobs_failed'
def test_jobs_in_queue_exists(self):
"""Test jobs_in_queue metric is defined"""
assert jobs_in_queue is not None
assert jobs_in_queue._name == 'jobs_in_queue'
def test_http_requests_total_exists(self):
"""Test http_requests_total metric is defined"""
assert http_requests_total is not None
assert http_requests_total._name == 'http_requests'
def test_http_request_duration_exists(self):
"""Test http_request_duration metric is defined"""
assert http_request_duration is not None
assert http_request_duration._name == 'http_request_duration_seconds'
def test_service_uptime_seconds_exists(self):
"""Test service_uptime_seconds metric is defined"""
assert service_uptime_seconds is not None
assert service_uptime_seconds._name == 'service_uptime_seconds'
def test_service_restart_count_exists(self):
"""Test service_restart_count metric is defined"""
assert service_restart_count is not None
assert service_restart_count._name == 'service_restart_count'
class TestHelperFunctions:
"""Tests for metrics helper functions"""
def test_update_block_height(self):
"""Test update_block_height sets metric"""
update_block_height(100)
# Metric should be set, but we can't easily verify the value
# This test ensures the function doesn't raise an error
assert True
def test_update_jobs_in_queue(self):
"""Test update_jobs_in_queue sets metric"""
update_jobs_in_queue(50)
# Metric should be set, but we can't easily verify the value
# This test ensures the function doesn't raise an error
assert True
def test_increment_service_restarts(self):
"""Test increment_service_restarts increments counter"""
increment_service_restarts()
# Counter should be incremented, but we can't easily verify the value
# This test ensures the function doesn't raise an error
assert True
def test_setup_service_info(self):
"""Test setup_service_info sets service info"""
setup_service_info("test-service", "1.0.0")
# Info should be set, but we can't easily verify the value
# This test ensures the function doesn't raise an error
assert True
class TestDecorators:
"""Tests for metrics tracking decorators"""
@pytest.mark.asyncio
async def test_track_block_processing_success(self):
"""Test track_block_processing decorator on successful execution"""
@track_block_processing
async def process_block():
return "block_processed"
result = await process_block()
assert result == "block_processed"
# Decorator should have observed the duration
assert True
@pytest.mark.asyncio
async def test_track_block_processing_failure(self):
"""Test track_block_processing decorator on exception"""
@track_block_processing
async def process_block():
raise ValueError("block error")
with pytest.raises(ValueError):
await process_block()
# Decorator should have observed the duration even on failure
assert True
@pytest.mark.asyncio
async def test_track_job_processing_success(self):
"""Test track_job_processing decorator on successful execution"""
@track_job_processing
async def process_job():
return "job_completed"
result = await process_job()
assert result == "job_completed"
# Decorator should have observed duration and incremented jobs_total
assert True
@pytest.mark.asyncio
async def test_track_job_processing_failure(self):
"""Test track_job_processing decorator on exception"""
@track_job_processing
async def process_job():
raise ValueError("job error")
with pytest.raises(ValueError):
await process_job()
# Decorator should have observed duration and incremented failure counters
assert True
@pytest.mark.asyncio
async def test_track_http_request_success(self):
"""Test track_http_request decorator on successful execution"""
mock_response = Mock()
mock_response.status_code = 200
@track_http_request
async def handle_request():
return mock_response
result = await handle_request()
assert result.status_code == 200
# Decorator should have observed duration and incremented http_requests_total
assert True
@pytest.mark.asyncio
async def test_track_http_request_failure(self):
"""Test track_http_request decorator on exception"""
@track_http_request
async def handle_request():
raise ValueError("request error")
with pytest.raises(ValueError):
await handle_request()
# Decorator should have observed duration and incremented http_requests_total with 500
assert True
@pytest.mark.asyncio
async def test_track_http_request_without_status_code(self):
"""Test track_http_request with response without status_code"""
@track_http_request
async def handle_request():
return "success" # No status_code attribute
result = await handle_request()
assert result == "success"
# Decorator should have observed duration but not incremented http_requests_total
assert True
class TestMetricsApp:
"""Tests for metrics ASGI app"""
def test_metrics_app_exists(self):
"""Test metrics_app is created"""
assert metrics_app is not None
assert hasattr(metrics_app, '__call__')

266
tests/test_middleware.py Normal file
View File

@@ -0,0 +1,266 @@
"""
Tests for AITBC middleware modules
"""
import pytest
from unittest.mock import Mock, AsyncMock, patch
from fastapi import Request, Response, HTTPException
from fastapi.responses import JSONResponse
from starlette.types import ASGIApp
from aitbc.middleware.performance import PerformanceLoggingMiddleware
from aitbc.middleware.request_id import RequestIDMiddleware
from aitbc.middleware.error_handler import ErrorHandlerMiddleware
class TestPerformanceLoggingMiddleware:
"""Tests for PerformanceLoggingMiddleware"""
@pytest.mark.asyncio
async def test_dispatch_adds_performance_header(self):
"""Test that middleware adds X-Process-Time header"""
app = Mock(spec=ASGIApp)
middleware = PerformanceLoggingMiddleware(app)
request = Mock(spec=Request)
request.method = "GET"
request.url = Mock()
request.url.path = "/test"
response = Mock(spec=Response)
response.status_code = 200
response.headers = {}
call_next = AsyncMock(return_value=response)
result = await middleware.dispatch(request, call_next)
assert "X-Process-Time" in result.headers
assert float(result.headers["X-Process-Time"]) >= 0
@pytest.mark.asyncio
async def test_dispatch_logs_performance_metrics(self):
"""Test that middleware logs performance metrics"""
app = Mock(spec=ASGIApp)
middleware = PerformanceLoggingMiddleware(app)
request = Mock(spec=Request)
request.method = "POST"
request.url = Mock()
request.url.path = "/api/test"
response = Mock(spec=Response)
response.status_code = 201
response.headers = {}
call_next = AsyncMock(return_value=response)
with patch('aitbc.middleware.performance.logger') as mock_logger:
await middleware.dispatch(request, call_next)
mock_logger.info.assert_called_once()
assert "Request performance" in mock_logger.info.call_args[0][0]
@pytest.mark.asyncio
async def test_dispatch_measures_time_correctly(self):
"""Test that middleware measures request duration accurately"""
app = Mock(spec=ASGIApp)
middleware = PerformanceLoggingMiddleware(app)
request = Mock(spec=Request)
request.method = "GET"
request.url = Mock()
request.url.path = "/test"
response = Mock(spec=Response)
response.status_code = 200
response.headers = {}
call_next = AsyncMock(return_value=response)
result = await middleware.dispatch(request, call_next)
process_time = float(result.headers["X-Process-Time"])
assert 0 <= process_time < 1.0 # Should complete in under 1 second
class TestRequestIDMiddleware:
"""Tests for RequestIDMiddleware"""
@pytest.mark.asyncio
async def test_dispatch_generates_request_id_when_missing(self):
"""Test that middleware generates request ID when not in headers"""
app = Mock(spec=ASGIApp)
middleware = RequestIDMiddleware(app)
request = Mock(spec=Request)
request.headers = {}
request.method = "GET"
request.url = Mock()
request.url.path = "/test"
request.client = Mock()
request.client.host = "127.0.0.1"
request.state = Mock()
response = Mock(spec=Response)
response.headers = {}
response.status_code = 200
call_next = AsyncMock(return_value=response)
result = await middleware.dispatch(request, call_next)
assert "X-Request-ID" in result.headers
assert len(result.headers["X-Request-ID"]) > 0
assert request.state.request_id == result.headers["X-Request-ID"]
@pytest.mark.asyncio
async def test_dispatch_uses_existing_request_id_from_header(self):
"""Test that middleware uses existing request ID from header"""
app = Mock(spec=ASGIApp)
middleware = RequestIDMiddleware(app)
existing_id = "test-request-id-123"
request = Mock(spec=Request)
request.headers = {"X-Request-ID": existing_id}
request.method = "POST"
request.url = Mock()
request.url.path = "/api/test"
request.client = Mock()
request.client.host = "192.168.1.1"
request.state = Mock()
response = Mock(spec=Response)
response.headers = {}
response.status_code = 201
call_next = AsyncMock(return_value=response)
result = await middleware.dispatch(request, call_next)
assert result.headers["X-Request-ID"] == existing_id
assert request.state.request_id == existing_id
@pytest.mark.asyncio
async def test_dispatch_logs_request_info(self):
"""Test that middleware logs request information"""
app = Mock(spec=ASGIApp)
middleware = RequestIDMiddleware(app)
request = Mock(spec=Request)
request.headers = {}
request.method = "GET"
request.url = Mock()
request.url.path = "/test"
request.client = Mock()
request.client.host = "127.0.0.1"
request.state = Mock()
response = Mock(spec=Response)
response.headers = {}
response.status_code = 200
call_next = AsyncMock(return_value=response)
with patch('aitbc.middleware.request_id.logger') as mock_logger:
await middleware.dispatch(request, call_next)
assert mock_logger.info.call_count >= 2 # Logs start and completion
class TestErrorHandlerMiddleware:
"""Tests for ErrorHandlerMiddleware"""
@pytest.mark.asyncio
async def test_dispatch_passes_through_normal_response(self):
"""Test that middleware passes through normal responses"""
app = Mock(spec=ASGIApp)
middleware = ErrorHandlerMiddleware(app)
request = Mock(spec=Request)
request.url = Mock()
request.url.path = "/test"
request.method = "GET"
response = Mock(spec=Response)
response.status_code = 200
call_next = AsyncMock(return_value=response)
result = await middleware.dispatch(request, call_next)
assert result == response
@pytest.mark.asyncio
async def test_dispatch_handles_http_exception(self):
"""Test that middleware handles HTTPException"""
app = Mock(spec=ASGIApp)
middleware = ErrorHandlerMiddleware(app)
request = Mock(spec=Request)
request.url = Mock()
request.url.path = "/api/error"
request.method = "GET"
exception = HTTPException(status_code=404, detail="Not found")
call_next = AsyncMock(side_effect=exception)
result = await middleware.dispatch(request, call_next)
assert isinstance(result, JSONResponse)
assert result.status_code == 404
content = result.body.decode() if hasattr(result, 'body') else {}
assert "error" in result.body.decode() if hasattr(result, 'body') else True
@pytest.mark.asyncio
async def test_dispatch_handles_generic_exception(self):
"""Test that middleware handles generic exceptions"""
app = Mock(spec=ASGIApp)
middleware = ErrorHandlerMiddleware(app)
request = Mock(spec=Request)
request.url = Mock()
request.url.path = "/api/crash"
request.method = "POST"
exception = ValueError("Something went wrong")
call_next = AsyncMock(side_effect=exception)
result = await middleware.dispatch(request, call_next)
assert isinstance(result, JSONResponse)
assert result.status_code == 500
@pytest.mark.asyncio
async def test_dispatch_logs_http_exception(self):
"""Test that middleware logs HTTPException"""
app = Mock(spec=ASGIApp)
middleware = ErrorHandlerMiddleware(app)
request = Mock(spec=Request)
request.url = Mock()
request.url.path = "/test"
request.method = "GET"
exception = HTTPException(status_code=400, detail="Bad request")
call_next = AsyncMock(side_effect=exception)
with patch('aitbc.middleware.error_handler.logger') as mock_logger:
await middleware.dispatch(request, call_next)
mock_logger.warning.assert_called_once()
@pytest.mark.asyncio
async def test_dispatch_logs_generic_exception(self):
"""Test that middleware logs generic exceptions"""
app = Mock(spec=ASGIApp)
middleware = ErrorHandlerMiddleware(app)
request = Mock(spec=Request)
request.url = Mock()
request.url.path = "/test"
request.method = "GET"
exception = RuntimeError("Runtime error")
call_next = AsyncMock(side_effect=exception)
with patch('aitbc.middleware.error_handler.logger') as mock_logger:
await middleware.dispatch(request, call_next)
mock_logger.error.assert_called_once()

View File

@@ -0,0 +1,287 @@
"""
Tests for security headers and CORS utilities
"""
import pytest
from aitbc.security_headers import (
SecurityHeaders,
CORSConfig,
SecurityHeadersMiddleware,
CORSMiddleware,
create_production_security_headers,
create_development_security_headers,
create_strict_cors_config,
create_permissive_cors_config,
)
class TestSecurityHeaders:
"""Tests for SecurityHeaders dataclass"""
def test_default_security_headers(self):
"""Test default security headers values"""
headers = SecurityHeaders()
assert headers.X_Content_Type_Options == "nosniff"
assert headers.X_Frame_Options == "DENY"
assert headers.X_XSS_Protection == "1; mode=block"
assert headers.Strict_Transport_Security == "max-age=31536000; includeSubDomains"
assert headers.Content_Security_Policy == "default-src 'self'"
assert headers.Referrer_Policy == "strict-origin-when-cross-origin"
assert headers.Permissions_Policy == ""
assert headers.Cache_Control == "no-cache, no-store, must-revalidate"
assert headers.Pragma == "no-cache"
def test_custom_security_headers(self):
"""Test custom security headers values"""
headers = SecurityHeaders(
X_Frame_Options="SAMEORIGIN",
Content_Security_Policy="default-src 'self' https://example.com"
)
assert headers.X_Frame_Options == "SAMEORIGIN"
assert headers.Content_Security_Policy == "default-src 'self' https://example.com"
class TestCORSConfig:
"""Tests for CORSConfig dataclass"""
def test_default_cors_config(self):
"""Test default CORS config values"""
config = CORSConfig(
allow_origins=["http://localhost:3000"],
allow_methods=["GET", "POST"],
allow_headers=["Content-Type"]
)
assert config.allow_origins == ["http://localhost:3000"]
assert config.allow_methods == ["GET", "POST"]
assert config.allow_credentials is False
assert config.expose_headers is None
assert config.max_age == 3600
def test_custom_cors_config(self):
"""Test custom CORS config values"""
config = CORSConfig(
allow_origins=["*"],
allow_methods=["GET", "POST", "PUT"],
allow_headers=["Content-Type"],
allow_credentials=True,
expose_headers=["X-Custom-Header"],
max_age=7200
)
assert config.allow_origins == ["*"]
assert config.allow_credentials is True
assert config.expose_headers == ["X-Custom-Header"]
assert config.max_age == 7200
class TestSecurityHeadersMiddleware:
"""Tests for SecurityHeadersMiddleware"""
def test_initialization_with_default_headers(self):
"""Test middleware initialization with default headers"""
middleware = SecurityHeadersMiddleware()
assert middleware.headers is not None
assert middleware.headers.X_Frame_Options == "DENY"
def test_initialization_with_custom_headers(self):
"""Test middleware initialization with custom headers"""
custom_headers = SecurityHeaders(X_Frame_Options="SAMEORIGIN")
middleware = SecurityHeadersMiddleware(custom_headers)
assert middleware.headers.X_Frame_Options == "SAMEORIGIN"
def test_get_headers(self):
"""Test get_headers returns dictionary"""
middleware = SecurityHeadersMiddleware()
headers = middleware.get_headers()
assert isinstance(headers, dict)
assert "X-Content-Type-Options" in headers
assert "X-Frame-Options" in headers
assert "X-XSS-Protection" in headers
assert "Strict-Transport-Security" in headers
assert "Content-Security-Policy" in headers
assert "Referrer-Policy" in headers
assert "Permissions-Policy" in headers
assert "Cache-Control" in headers
assert "Pragma" in headers
def test_get_headers_values(self):
"""Test get_headers returns correct values"""
middleware = SecurityHeadersMiddleware()
headers = middleware.get_headers()
assert headers["X-Content-Type-Options"] == "nosniff"
assert headers["X-Frame-Options"] == "DENY"
assert headers["X-XSS-Protection"] == "1; mode=block"
def test_apply_to_response(self):
"""Test apply_to_response adds security headers"""
middleware = SecurityHeadersMiddleware()
response_headers = {"Content-Type": "application/json"}
result = middleware.apply_to_response(response_headers)
assert "X-Content-Type-Options" in result
assert "X-Frame-Options" in result
assert result["Content-Type"] == "application/json"
def test_apply_to_response_overwrites_existing(self):
"""Test apply_to_response overwrites existing security headers"""
middleware = SecurityHeadersMiddleware()
response_headers = {"X-Frame-Options": "ALLOW"}
result = middleware.apply_to_response(response_headers)
assert result["X-Frame-Options"] == "DENY"
class TestCORSMiddleware:
"""Tests for CORSMiddleware"""
def test_initialization(self):
"""Test CORS middleware initialization"""
config = CORSConfig(
allow_origins=["http://localhost:3000"],
allow_methods=["GET", "POST"],
allow_headers=["Content-Type"]
)
middleware = CORSMiddleware(config)
assert middleware.config == config
def test_get_cors_headers_allowed_origin(self):
"""Test get_cors_headers with allowed origin"""
config = CORSConfig(
allow_origins=["http://localhost:3000"],
allow_methods=["GET", "POST"],
allow_headers=["Content-Type"]
)
middleware = CORSMiddleware(config)
headers = middleware.get_cors_headers("http://localhost:3000")
assert headers["Access-Control-Allow-Origin"] == "http://localhost:3000"
assert headers["Access-Control-Allow-Methods"] == "GET, POST"
assert headers["Access-Control-Allow-Headers"] == "Content-Type"
assert headers["Access-Control-Max-Age"] == "3600"
def test_get_cors_headers_disallowed_origin(self):
"""Test get_cors_headers with disallowed origin"""
config = CORSConfig(
allow_origins=["http://localhost:3000"],
allow_methods=["GET", "POST"],
allow_headers=["Content-Type"]
)
middleware = CORSMiddleware(config)
headers = middleware.get_cors_headers("http://evil.com")
assert headers == {}
def test_get_cors_headers_wildcard_origin(self):
"""Test get_cors_headers with wildcard origin"""
config = CORSConfig(
allow_origins=["*"],
allow_methods=["GET", "POST"],
allow_headers=["Content-Type"]
)
middleware = CORSMiddleware(config)
headers = middleware.get_cors_headers("http://any-origin.com")
assert headers["Access-Control-Allow-Origin"] == "http://any-origin.com"
def test_get_cors_headers_with_credentials(self):
"""Test get_cors_headers with credentials enabled"""
config = CORSConfig(
allow_origins=["http://localhost:3000"],
allow_methods=["GET", "POST"],
allow_headers=["Content-Type"],
allow_credentials=True
)
middleware = CORSMiddleware(config)
headers = middleware.get_cors_headers("http://localhost:3000")
assert headers["Access-Control-Allow-Credentials"] == "true"
def test_get_cors_headers_with_expose_headers(self):
"""Test get_cors_headers with expose headers"""
config = CORSConfig(
allow_origins=["http://localhost:3000"],
allow_methods=["GET", "POST"],
allow_headers=["Content-Type"],
expose_headers=["X-Request-ID"]
)
middleware = CORSMiddleware(config)
headers = middleware.get_cors_headers("http://localhost:3000")
assert headers["Access-Control-Expose-Headers"] == "X-Request-ID"
def test_is_origin_allowed_wildcard(self):
"""Test _is_origin_allowed with wildcard"""
config = CORSConfig(
allow_origins=["*"],
allow_methods=["GET"],
allow_headers=["Content-Type"]
)
middleware = CORSMiddleware(config)
assert middleware._is_origin_allowed("http://any-origin.com") is True
def test_is_origin_allowed_specific(self):
"""Test _is_origin_allowed with specific origin"""
config = CORSConfig(
allow_origins=["http://localhost:3000"],
allow_methods=["GET"],
allow_headers=["Content-Type"]
)
middleware = CORSMiddleware(config)
assert middleware._is_origin_allowed("http://localhost:3000") is True
assert middleware._is_origin_allowed("http://evil.com") is False
def test_is_preflight_request(self):
"""Test is_preflight_request"""
config = CORSConfig(
allow_origins=["*"],
allow_methods=["GET"],
allow_headers=["Content-Type"]
)
middleware = CORSMiddleware(config)
assert middleware.is_preflight_request("OPTIONS") is True
assert middleware.is_preflight_request("GET") is False
assert middleware.is_preflight_request("options") is True
class TestFactoryFunctions:
"""Tests for factory functions"""
def test_create_production_security_headers(self):
"""Test create_production_security_headers"""
headers = create_production_security_headers()
assert headers.X_Frame_Options == "DENY"
assert "preload" in headers.Strict_Transport_Security
assert "unsafe-inline" in headers.Content_Security_Policy
assert "geolocation=()" in headers.Permissions_Policy
def test_create_development_security_headers(self):
"""Test create_development_security_headers"""
headers = create_development_security_headers()
assert headers.X_Frame_Options == "SAMEORIGIN"
assert headers.Strict_Transport_Security == "max-age=3600"
assert headers.Permissions_Policy == ""
assert headers.Cache_Control == "no-cache"
def test_create_strict_cors_config(self):
"""Test create_strict_cors_config"""
config = create_strict_cors_config(["http://localhost:3000"])
assert "http://localhost:3000" in config.allow_origins
assert "GET" in config.allow_methods
assert "POST" in config.allow_methods
assert "PUT" in config.allow_methods
assert "DELETE" in config.allow_methods
assert "PATCH" in config.allow_methods
assert config.allow_credentials is True
assert "X-Request-ID" in config.expose_headers
assert config.max_age == 3600
def test_create_permissive_cors_config(self):
"""Test create_permissive_cors_config"""
config = create_permissive_cors_config()
assert "*" in config.allow_origins
assert "GET" in config.allow_methods
assert "POST" in config.allow_methods
assert "*" in config.allow_headers
assert config.allow_credentials is False
assert "*" in config.expose_headers
assert config.max_age == 86400

345
tests/test_utils.py Normal file
View File

@@ -0,0 +1,345 @@
"""
Tests for AITBC utility modules
"""
import os
import pytest
from pathlib import Path
from unittest.mock import patch, Mock
from aitbc.utils.paths import (
get_data_path,
get_config_path,
get_log_path,
get_repo_path,
ensure_dir,
ensure_file_dir,
resolve_path,
get_keystore_path,
get_blockchain_data_path,
get_marketplace_data_path,
)
from aitbc.utils.env import (
get_env_var,
get_required_env_var,
get_bool_env_var,
get_int_env_var,
get_float_env_var,
get_list_env_var,
)
from aitbc.utils.json_utils import (
load_json,
save_json,
merge_json,
json_to_string,
string_to_json,
get_nested_value,
set_nested_value,
flatten_json,
)
from aitbc.exceptions import ConfigurationError
class TestPaths:
"""Tests for path utility functions"""
def test_get_data_path_no_subpath(self):
"""Test get_data_path without subpath"""
result = get_data_path()
assert isinstance(result, Path)
def test_get_data_path_with_subpath(self):
"""Test get_data_path with subpath"""
result = get_data_path("test")
assert isinstance(result, Path)
assert str(result).endswith("test")
def test_get_config_path(self):
"""Test get_config_path"""
result = get_config_path("config.yaml")
assert isinstance(result, Path)
assert str(result).endswith("config.yaml")
def test_get_log_path(self):
"""Test get_log_path"""
result = get_log_path("app.log")
assert isinstance(result, Path)
assert str(result).endswith("app.log")
def test_get_repo_path_no_subpath(self):
"""Test get_repo_path without subpath"""
result = get_repo_path()
assert isinstance(result, Path)
def test_get_repo_path_with_subpath(self):
"""Test get_repo_path with subpath"""
result = get_repo_path("src")
assert isinstance(result, Path)
assert str(result).endswith("src")
def test_ensure_dir(self):
"""Test ensure_dir creates directory"""
import tempfile
with tempfile.TemporaryDirectory() as tmpdir:
test_path = Path(tmpdir) / "test" / "nested"
result = ensure_dir(test_path)
assert result.exists()
assert result.is_dir()
def test_ensure_file_dir(self):
"""Test ensure_file_dir creates parent directory"""
import tempfile
with tempfile.TemporaryDirectory() as tmpdir:
test_path = Path(tmpdir) / "test" / "nested" / "file.txt"
result = ensure_file_dir(test_path)
assert result.exists()
assert result.is_dir()
def test_resolve_path_absolute(self):
"""Test resolve_path with absolute path"""
result = resolve_path("/tmp/test")
assert result.is_absolute()
def test_resolve_path_relative(self):
"""Test resolve_path with relative path"""
result = resolve_path("test")
assert result.is_absolute() # Relative paths are resolved to absolute
def test_get_keystore_path_no_wallet(self):
"""Test get_keystore_path without wallet name"""
result = get_keystore_path()
assert isinstance(result, Path)
assert str(result).endswith("keystore")
def test_get_keystore_path_with_wallet(self):
"""Test get_keystore_path with wallet name"""
result = get_keystore_path("mywallet")
assert isinstance(result, Path)
assert str(result).endswith("mywallet.json")
def test_get_blockchain_data_path_default(self):
"""Test get_blockchain_data_path with default chain"""
result = get_blockchain_data_path()
assert isinstance(result, Path)
assert str(result).endswith("ait-mainnet")
def test_get_blockchain_data_path_custom(self):
"""Test get_blockchain_data_path with custom chain"""
result = get_blockchain_data_path("custom-chain")
assert isinstance(result, Path)
assert str(result).endswith("custom-chain")
def test_get_marketplace_data_path_no_subpath(self):
"""Test get_marketplace_data_path without subpath"""
result = get_marketplace_data_path()
assert isinstance(result, Path)
assert str(result).endswith("marketplace")
def test_get_marketplace_data_path_with_subpath(self):
"""Test get_marketplace_data_path with subpath"""
result = get_marketplace_data_path("orders")
assert isinstance(result, Path)
assert str(result).endswith("orders")
class TestEnv:
"""Tests for environment variable utilities"""
def test_get_env_var_with_value(self):
"""Test get_env_var with set value"""
os.environ["TEST_VAR"] = "test_value"
result = get_env_var("TEST_VAR")
assert result == "test_value"
del os.environ["TEST_VAR"]
def test_get_env_var_with_default(self):
"""Test get_env_var with default value"""
result = get_env_var("NONEXISTENT_VAR", "default")
assert result == "default"
def test_get_required_env_var_with_value(self):
"""Test get_required_env_var with set value"""
os.environ["TEST_VAR"] = "test_value"
result = get_required_env_var("TEST_VAR")
assert result == "test_value"
del os.environ["TEST_VAR"]
def test_get_required_env_var_without_value(self):
"""Test get_required_env_var without value raises error"""
with pytest.raises(ConfigurationError):
get_required_env_var("NONEXISTENT_VAR")
def test_get_bool_env_var_true(self):
"""Test get_bool_env_var with true values"""
for value in ["true", "TRUE", "1", "yes", "YES", "on", "ON"]:
os.environ["TEST_VAR"] = value
assert get_bool_env_var("TEST_VAR") is True
del os.environ["TEST_VAR"]
def test_get_bool_env_var_false(self):
"""Test get_bool_env_var with false values"""
for value in ["false", "FALSE", "0", "no", "NO", "off", "OFF"]:
os.environ["TEST_VAR"] = value
assert get_bool_env_var("TEST_VAR") is False
del os.environ["TEST_VAR"]
def test_get_bool_env_var_default(self):
"""Test get_bool_env_var with default"""
assert get_bool_env_var("NONEXISTENT_VAR", True) is True
assert get_bool_env_var("NONEXISTENT_VAR", False) is False
def test_get_int_env_var_valid(self):
"""Test get_int_env_var with valid integer"""
os.environ["TEST_VAR"] = "42"
assert get_int_env_var("TEST_VAR") == 42
del os.environ["TEST_VAR"]
def test_get_int_env_var_invalid(self):
"""Test get_int_env_var with invalid value returns default"""
os.environ["TEST_VAR"] = "not_a_number"
assert get_int_env_var("TEST_VAR", 10) == 10
del os.environ["TEST_VAR"]
def test_get_int_env_var_default(self):
"""Test get_int_env_var with default"""
assert get_int_env_var("NONEXISTENT_VAR", 100) == 100
def test_get_float_env_var_valid(self):
"""Test get_float_env_var with valid float"""
os.environ["TEST_VAR"] = "3.14"
assert get_float_env_var("TEST_VAR") == 3.14
del os.environ["TEST_VAR"]
def test_get_float_env_var_invalid(self):
"""Test get_float_env_var with invalid value returns default"""
os.environ["TEST_VAR"] = "not_a_number"
assert get_float_env_var("TEST_VAR", 2.5) == 2.5
del os.environ["TEST_VAR"]
def test_get_float_env_var_default(self):
"""Test get_float_env_var with default"""
assert get_float_env_var("NONEXISTENT_VAR", 1.5) == 1.5
def test_get_list_env_var_valid(self):
"""Test get_list_env_var with valid list"""
os.environ["TEST_VAR"] = "item1,item2,item3"
result = get_list_env_var("TEST_VAR")
assert result == ["item1", "item2", "item3"]
del os.environ["TEST_VAR"]
def test_get_list_env_var_custom_separator(self):
"""Test get_list_env_var with custom separator"""
os.environ["TEST_VAR"] = "item1;item2;item3"
result = get_list_env_var("TEST_VAR", separator=";")
assert result == ["item1", "item2", "item3"]
del os.environ["TEST_VAR"]
def test_get_list_env_var_empty(self):
"""Test get_list_env_var with empty value returns default"""
os.environ["TEST_VAR"] = ""
result = get_list_env_var("TEST_VAR", default=["default"])
assert result == ["default"]
del os.environ["TEST_VAR"]
def test_get_list_env_var_default(self):
"""Test get_list_env_var with default"""
result = get_list_env_var("NONEXISTENT_VAR", default=["a", "b"])
assert result == ["a", "b"]
class TestJsonUtils:
"""Tests for JSON utility functions"""
def test_json_to_string(self):
"""Test json_to_string"""
data = {"key": "value"}
result = json_to_string(data)
assert '"key"' in result
assert '"value"' in result
def test_string_to_json_valid(self):
"""Test string_to_json with valid JSON"""
json_str = '{"key": "value"}'
result = string_to_json(json_str)
assert result == {"key": "value"}
def test_string_to_json_invalid(self):
"""Test string_to_json with invalid JSON raises error"""
with pytest.raises(ConfigurationError):
string_to_json("not valid json")
def test_get_nested_value_found(self):
"""Test get_nested_value when key exists"""
data = {"a": {"b": {"c": "value"}}}
result = get_nested_value(data, "a", "b", "c")
assert result == "value"
def test_get_nested_value_not_found(self):
"""Test get_nested_value when key doesn't exist returns default"""
data = {"a": {"b": {"c": "value"}}}
result = get_nested_value(data, "a", "b", "d", default="default")
assert result == "default"
def test_get_nested_value_default_none(self):
"""Test get_nested_value with default None"""
data = {"a": {"b": {"c": "value"}}}
result = get_nested_value(data, "x", "y", "z")
assert result is None
def test_set_nested_value(self):
"""Test set_nested_value"""
data = {}
set_nested_value(data, "a", "b", "c", value="test")
assert data["a"]["b"]["c"] == "test"
def test_flatten_json(self):
"""Test flatten_json"""
data = {"a": {"b": {"c": "value"}}, "d": "simple"}
result = flatten_json(data)
assert "a.b.c" in result
assert result["a.b.c"] == "value"
assert result["d"] == "simple"
def test_flatten_json_custom_separator(self):
"""Test flatten_json with custom separator"""
data = {"a": {"b": "value"}}
result = flatten_json(data, separator="_")
assert "a_b" in result
assert result["a_b"] == "value"
def test_load_json(self, tmp_path):
"""Test load_json"""
test_file = tmp_path / "test.json"
test_file.write_text('{"key": "value"}')
result = load_json(test_file)
assert result == {"key": "value"}
def test_load_json_not_found(self, tmp_path):
"""Test load_json with non-existent file raises error"""
with pytest.raises(ConfigurationError):
load_json(tmp_path / "nonexistent.json")
def test_load_json_invalid(self, tmp_path):
"""Test load_json with invalid JSON raises error"""
test_file = tmp_path / "invalid.json"
test_file.write_text("not valid json")
with pytest.raises(ConfigurationError):
load_json(test_file)
def test_save_json(self, tmp_path):
"""Test save_json"""
test_file = tmp_path / "test.json"
data = {"key": "value"}
save_json(data, test_file)
assert test_file.exists()
result = load_json(test_file)
assert result == data
def test_merge_json(self, tmp_path):
"""Test merge_json"""
file1 = tmp_path / "file1.json"
file2 = tmp_path / "file2.json"
file1.write_text('{"a": 1, "b": 2}')
file2.write_text('{"b": 3, "c": 4}')
result = merge_json(file1, file2)
assert result == {"a": 1, "b": 3, "c": 4}