From c87806b68bf1b9a93a06ff2623368188ce5d5b0e Mon Sep 17 00:00:00 2001 From: aitbc Date: Tue, 12 May 2026 18:10:58 +0200 Subject: [PATCH] refactor: reorganize services into bounded contexts and implement async database support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Moved services to bounded context packages: - adaptive_learning.py → ai_analytics/adaptive_learning.py - analytics_service.py → ai_analytics/analytics.py - dynamic_pricing_engine.py → trading_marketplace/dynamic_pricing.py - trading_service.py → trading_marketplace/trading.py - Implemented async database module (database_async.py): - Added async SQLAlchemy engine with connection pooling - Added --- .../coordinator-api/src/app/database_async.py | 118 +- apps/coordinator-api/src/app/main.py | 16 + .../app/routers/adaptive_learning_health.py | 2 +- .../src/app/routers/analytics.py | 2 +- .../src/app/routers/dynamic_pricing.py | 2 +- .../src/app/routers/trading.py | 2 +- .../src/app/services/ai_analytics/__init__.py | 14 + .../{ => ai_analytics}/adaptive_learning.py | 2 +- .../{ => ai_analytics}/advanced_learning.py | 0 .../analytics.py} | 2 +- .../surveillance.py} | 0 .../trading_engine.py} | 0 .../app/services/enterprise_api_gateway.py | 608 --------- .../app/services/enterprise_integration.py | 1127 ----------------- .../enterprise_integration/__init__.py | 8 +- .../app/services/enterprise_load_balancer.py | 770 ----------- .../src/app/services/enterprise_security.py | 773 ----------- .../services/trading_marketplace/__init__.py | 17 + .../amm.py} | 0 .../bid_strategy.py} | 0 .../dynamic_pricing.py} | 0 .../gpu_optimizer.py} | 0 .../trading.py} | 2 +- docs/ROADMAP.md | 32 +- tests/test_config.py | 154 +++ tests/test_decorators.py | 303 +++++ tests/test_health_checks.py | 218 ++++ tests/test_metrics.py | 250 ++++ tests/test_middleware.py | 266 ++++ tests/test_security_headers.py | 287 +++++ tests/test_utils.py | 345 +++++ 31 files changed, 2027 insertions(+), 3293 deletions(-) create mode 100644 apps/coordinator-api/src/app/services/ai_analytics/__init__.py rename apps/coordinator-api/src/app/services/{ => ai_analytics}/adaptive_learning.py (99%) rename apps/coordinator-api/src/app/services/{ => ai_analytics}/advanced_learning.py (100%) rename apps/coordinator-api/src/app/services/{analytics_service.py => ai_analytics/analytics.py} (99%) rename apps/coordinator-api/src/app/services/{ai_surveillance.py => ai_analytics/surveillance.py} (100%) rename apps/coordinator-api/src/app/services/{ai_trading_engine.py => ai_analytics/trading_engine.py} (100%) delete mode 100755 apps/coordinator-api/src/app/services/enterprise_api_gateway.py delete mode 100755 apps/coordinator-api/src/app/services/enterprise_integration.py delete mode 100755 apps/coordinator-api/src/app/services/enterprise_load_balancer.py delete mode 100755 apps/coordinator-api/src/app/services/enterprise_security.py create mode 100644 apps/coordinator-api/src/app/services/trading_marketplace/__init__.py rename apps/coordinator-api/src/app/services/{amm_service.py => trading_marketplace/amm.py} (100%) rename apps/coordinator-api/src/app/services/{bid_strategy_engine.py => trading_marketplace/bid_strategy.py} (100%) rename apps/coordinator-api/src/app/services/{dynamic_pricing_engine.py => trading_marketplace/dynamic_pricing.py} (100%) rename apps/coordinator-api/src/app/services/{marketplace_gpu_optimizer.py => trading_marketplace/gpu_optimizer.py} (100%) rename apps/coordinator-api/src/app/services/{trading_service.py => trading_marketplace/trading.py} (99%) create mode 100644 tests/test_config.py create mode 100644 tests/test_decorators.py create mode 100644 tests/test_health_checks.py create mode 100644 tests/test_metrics.py create mode 100644 tests/test_middleware.py create mode 100644 tests/test_security_headers.py create mode 100644 tests/test_utils.py diff --git a/apps/coordinator-api/src/app/database_async.py b/apps/coordinator-api/src/app/database_async.py index 48f84b00..378e5eef 100644 --- a/apps/coordinator-api/src/app/database_async.py +++ b/apps/coordinator-api/src/app/database_async.py @@ -1 +1,117 @@ -[{'adapter}': 'def _build_async_url(url: str', 'driver': 'str) -> str:', 'Convert sync URL to async URL."': 'sqlite:///path.db -> sqlite+aiosqlite:///path.db\n # postgresql:// -> postgresql+asyncpg://\n parts = url.split(', ', 1)\n if len(parts) == 2:\n return f': 'parts[0]'}, {'driver}': {'f': 'url'}, 'sqlite': 'For SQLite', 'connect_args={"check_same_thread': False}, {}] \ No newline at end of file +"""Async database module with connection pooling for Coordinator API.""" + +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker +from .config import settings +import logging + +logger = logging.getLogger(__name__) + +# Global async engine and session factory +_async_engine = None +_async_session_factory = None + + +def _build_async_url(url: str) -> str: + """Convert sync database URL to async URL. + + Examples: + sqlite:///path.db -> sqlite+aiosqlite:///path.db + postgresql://user:pass@host/db -> postgresql+asyncpg://user:pass@host/db + """ + # Handle URL with query parameters + if '?' in url: + base, params = url.split('?', 1) + if base.startswith('sqlite:'): + return f"{base}+aiosqlite://?{params}" + elif base.startswith('postgresql:'): + return f"{base}+asyncpg://?{params}" + else: + return f"{base}+aiosqlite://?{params}" # fallback + else: + if url.startswith('sqlite:'): + return url.replace('sqlite:', 'sqlite+aiosqlite:') + elif url.startswith('postgresql:'): + return url.replace('postgresql:', 'postgresql+asyncpg:') + else: + return url.replace(':', '+aiosqlite:') # fallback + + +def init_async_db() -> None: + """Initialize async database engine and session factory.""" + global _async_engine, _async_session_factory + + if _async_engine is not None: + logger.warning("Async database already initialized") + return + + try: + # Build async URL from sync settings + sync_url = str(settings.database.url) + async_url = _build_async_url(sync_url) + + logger.info(f"Initializing async database connection: {async_url.split('://')[0]}://...") + + # Create async engine with pooling + _async_engine = create_async_engine( + async_url, + echo=settings.database.echo if hasattr(settings.database, 'echo') else False, + pool_size=getattr(settings.database, 'pool_size', 5), + max_overflow=getattr(settings.database, 'max_overflow', 10), + pool_pre_ping=getattr(settings.database, 'pool_pre_ping', True), + pool_recycle=getattr(settings.database, 'pool_recycle', 3600), + ) + + # Create session factory + _async_session_factory = sessionmaker( + _async_engine, + class_=AsyncSession, + expire_on_commit=False, + ) + + logger.info("Async database initialized successfully") + + except Exception as e: + logger.error(f"Failed to initialize async database: {e}") + raise + + +def get_async_db() -> AsyncSession: + """Dependency to get async database session. + + Yields: + AsyncSession: Database session that closes automatically + """ + if _async_session_factory is None: + raise RuntimeError("Async database not initialized. Call init_async_db() first.") + + async def _get_async_db() -> AsyncSession: + async with _async_session_factory() as session: + try: + yield session + finally: + await session.close() + + return _get_async_db() + + +def get_sync_engine(): + """Get synchronous engine for backward compatibility. + + Returns: + Engine: SQLAlchemy synchronous engine + """ + from .database import engine + return engine + + +async def close_async_db() -> None: + """Close async database connections.""" + global _async_engine, _async_session_factory + + if _async_engine is not None: + logger.info("Closing async database connections...") + await _async_engine.dispose() + _async_engine = None + _async_session_factory = None + logger.info("Async database connections closed") \ No newline at end of file diff --git a/apps/coordinator-api/src/app/main.py b/apps/coordinator-api/src/app/main.py index 14b7fd7d..6cc1247d 100755 --- a/apps/coordinator-api/src/app/main.py +++ b/apps/coordinator-api/src/app/main.py @@ -114,6 +114,7 @@ logger = get_logger(__name__) from contextlib import asynccontextmanager from .storage.db import init_db +from .database_async import init_async_db, close_async_db @asynccontextmanager @@ -130,6 +131,14 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]: logger.warning(f"Database initialization failed (non-fatal): {e}") # Continue startup even if init_db fails + # Initialize async database + try: + init_async_db() + logger.info("Async database initialized successfully") + except Exception as e: + logger.warning(f"Async database initialization failed (non-fatal): {e}") + # Continue startup even if async init fails + # Warmup database connections logger.info("Warming up database connections...") try: @@ -227,6 +236,13 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]: except Exception as e: logger.warning(f"Error closing database connections: {e}") + # Close async database connections + try: + await close_async_db() + logger.info("Async database connections closed successfully") + except Exception as e: + logger.warning(f"Error closing async database connections: {e}") + # Cleanup rate limiting state logger.info("Cleaning up rate limiting state...") diff --git a/apps/coordinator-api/src/app/routers/adaptive_learning_health.py b/apps/coordinator-api/src/app/routers/adaptive_learning_health.py index 2ce7ba25..d55e8544 100755 --- a/apps/coordinator-api/src/app/routers/adaptive_learning_health.py +++ b/apps/coordinator-api/src/app/routers/adaptive_learning_health.py @@ -14,7 +14,7 @@ from fastapi import APIRouter, Depends from sqlalchemy.orm import Session from aitbc import get_logger -from ..services.adaptive_learning import AdaptiveLearningService +from ..services.ai_analytics.adaptive_learning import AdaptiveLearningService from ..storage import get_session logger = get_logger(__name__) diff --git a/apps/coordinator-api/src/app/routers/analytics.py b/apps/coordinator-api/src/app/routers/analytics.py index 0b1771b4..107a2c85 100755 --- a/apps/coordinator-api/src/app/routers/analytics.py +++ b/apps/coordinator-api/src/app/routers/analytics.py @@ -27,7 +27,7 @@ from ..domain.analytics import ( MetricType, ReportType, ) -from ..services.analytics_service import MarketplaceAnalytics +from ..services.ai_analytics.analytics import MarketplaceAnalytics from ..storage import get_session router = APIRouter(prefix="/v1/analytics", tags=["analytics"]) diff --git a/apps/coordinator-api/src/app/routers/dynamic_pricing.py b/apps/coordinator-api/src/app/routers/dynamic_pricing.py index 3c612f03..24d82ee7 100755 --- a/apps/coordinator-api/src/app/routers/dynamic_pricing.py +++ b/apps/coordinator-api/src/app/routers/dynamic_pricing.py @@ -23,7 +23,7 @@ from ..schemas.pricing import ( PricingStrategyRequest, PricingStrategyResponse, ) -from ..services.dynamic_pricing_engine import DynamicPricingEngine, PriceConstraints, PricingStrategy, ResourceType +from ..services.trading_marketplace.dynamic_pricing import DynamicPricingEngine, PriceConstraints, PricingStrategy, ResourceType from ..services.market_data_collector import MarketDataCollector router = APIRouter(prefix="/v1/pricing", tags=["dynamic-pricing"]) diff --git a/apps/coordinator-api/src/app/routers/trading.py b/apps/coordinator-api/src/app/routers/trading.py index 7b9a3c9b..6c5e6640 100755 --- a/apps/coordinator-api/src/app/routers/trading.py +++ b/apps/coordinator-api/src/app/routers/trading.py @@ -28,7 +28,7 @@ from ..domain.trading import ( TradeStatus, TradeType, ) -from ..services.trading_service import P2PTradingProtocol +from ..services.trading_marketplace.trading import P2PTradingProtocol from ..storage import get_session router = APIRouter(prefix="/v1/trading", tags=["trading"]) diff --git a/apps/coordinator-api/src/app/services/ai_analytics/__init__.py b/apps/coordinator-api/src/app/services/ai_analytics/__init__.py new file mode 100644 index 00000000..0bc6f324 --- /dev/null +++ b/apps/coordinator-api/src/app/services/ai_analytics/__init__.py @@ -0,0 +1,14 @@ +""" +AI & Analytics Bounded Context +Provides analytics and advanced learning services. +""" + +from .advanced_learning import AdvancedLearningService +from .analytics import AnalyticsEngine, DashboardManager, MarketplaceAnalytics + +__all__ = [ + "AdvancedLearningService", + "AnalyticsEngine", + "DashboardManager", + "MarketplaceAnalytics", +] diff --git a/apps/coordinator-api/src/app/services/adaptive_learning.py b/apps/coordinator-api/src/app/services/ai_analytics/adaptive_learning.py similarity index 99% rename from apps/coordinator-api/src/app/services/adaptive_learning.py rename to apps/coordinator-api/src/app/services/ai_analytics/adaptive_learning.py index 8f085a88..f550950c 100755 --- a/apps/coordinator-api/src/app/services/adaptive_learning.py +++ b/apps/coordinator-api/src/app/services/ai_analytics/adaptive_learning.py @@ -17,7 +17,7 @@ from typing import Any import numpy as np -from ..storage import get_session +from ...storage import get_session class LearningAlgorithm(StrEnum): diff --git a/apps/coordinator-api/src/app/services/advanced_learning.py b/apps/coordinator-api/src/app/services/ai_analytics/advanced_learning.py similarity index 100% rename from apps/coordinator-api/src/app/services/advanced_learning.py rename to apps/coordinator-api/src/app/services/ai_analytics/advanced_learning.py diff --git a/apps/coordinator-api/src/app/services/analytics_service.py b/apps/coordinator-api/src/app/services/ai_analytics/analytics.py similarity index 99% rename from apps/coordinator-api/src/app/services/analytics_service.py rename to apps/coordinator-api/src/app/services/ai_analytics/analytics.py index 5436af6f..13f63a29 100755 --- a/apps/coordinator-api/src/app/services/analytics_service.py +++ b/apps/coordinator-api/src/app/services/ai_analytics/analytics.py @@ -13,7 +13,7 @@ logger = get_logger(__name__) from sqlmodel import Session, and_, select -from ..domain.analytics import ( +from ...domain.analytics import ( AnalyticsAlert, AnalyticsPeriod, DashboardConfig, diff --git a/apps/coordinator-api/src/app/services/ai_surveillance.py b/apps/coordinator-api/src/app/services/ai_analytics/surveillance.py similarity index 100% rename from apps/coordinator-api/src/app/services/ai_surveillance.py rename to apps/coordinator-api/src/app/services/ai_analytics/surveillance.py diff --git a/apps/coordinator-api/src/app/services/ai_trading_engine.py b/apps/coordinator-api/src/app/services/ai_analytics/trading_engine.py similarity index 100% rename from apps/coordinator-api/src/app/services/ai_trading_engine.py rename to apps/coordinator-api/src/app/services/ai_analytics/trading_engine.py diff --git a/apps/coordinator-api/src/app/services/enterprise_api_gateway.py b/apps/coordinator-api/src/app/services/enterprise_api_gateway.py deleted file mode 100755 index e7faa712..00000000 --- a/apps/coordinator-api/src/app/services/enterprise_api_gateway.py +++ /dev/null @@ -1,608 +0,0 @@ -""" -Enterprise API Gateway - Phase 6.1 Implementation -Multi-tenant API routing and management for enterprise clients -Port: 8010 -""" - -import secrets -import time -from datetime import datetime, timezone, timedelta -from enum import StrEnum -from typing import Any -from uuid import uuid4 - -import jwt -from fastapi import Depends, FastAPI, HTTPException, Request -from fastapi.middleware.cors import CORSMiddleware -from fastapi.security import HTTPBearer -from pydantic import BaseModel, Field - -from aitbc import get_logger - -logger = get_logger(__name__) - -from ..domain.multitenant import Tenant, TenantApiKey, TenantQuota -from ..exceptions import QuotaExceededError, TenantError -from ..storage.db import get_db - - -# Pydantic models for API requests/responses -class EnterpriseAuthRequest(BaseModel): - tenant_id: str = Field(..., description="Enterprise tenant identifier") - client_id: str = Field(..., description="Enterprise client ID") - client_secret: str = Field(..., description="Enterprise client secret") - auth_method: str = Field(default="client_credentials", description="Authentication method") - scopes: list[str] | None = Field(default=None, description="Requested scopes") - - -class EnterpriseAuthResponse(BaseModel): - access_token: str = Field(..., description="Access token for enterprise API") - token_type: str = Field(default="Bearer", description="Token type") - expires_in: int = Field(..., description="Token expiration in seconds") - refresh_token: str | None = Field(None, description="Refresh token") - scopes: list[str] = Field(..., description="Granted scopes") - tenant_info: dict[str, Any] = Field(..., description="Tenant information") - - -class APIQuotaRequest(BaseModel): - tenant_id: str = Field(..., description="Enterprise tenant identifier") - endpoint: str = Field(..., description="API endpoint") - method: str = Field(..., description="HTTP method") - quota_type: str = Field(default="rate_limit", description="Quota type") - - -class APIQuotaResponse(BaseModel): - quota_limit: int = Field(..., description="Quota limit") - quota_remaining: int = Field(..., description="Remaining quota") - quota_reset: datetime = Field(..., description="Quota reset time") - quota_type: str = Field(..., description="Quota type") - - -class WebhookConfig(BaseModel): - url: str = Field(..., description="Webhook URL") - events: list[str] = Field(..., description="Events to subscribe to") - secret: str | None = Field(None, description="Webhook secret") - active: bool = Field(default=True, description="Webhook active status") - retry_policy: dict[str, Any] | None = Field(None, description="Retry policy") - - -class EnterpriseIntegrationRequest(BaseModel): - integration_type: str = Field(..., description="Integration type (ERP, CRM, etc.)") - provider: str = Field(..., description="Integration provider") - configuration: dict[str, Any] = Field(..., description="Integration configuration") - credentials: dict[str, str] | None = Field(None, description="Integration credentials") - webhook_config: WebhookConfig | None = Field(None, description="Webhook configuration") - - -class EnterpriseMetrics(BaseModel): - api_calls_total: int = Field(..., description="Total API calls") - api_calls_successful: int = Field(..., description="Successful API calls") - average_response_time_ms: float = Field(..., description="Average response time") - error_rate_percent: float = Field(..., description="Error rate percentage") - quota_utilization_percent: float = Field(..., description="Quota utilization") - active_integrations: int = Field(..., description="Active integrations count") - - -class IntegrationStatus(StrEnum): - ACTIVE = "active" - INACTIVE = "inactive" - ERROR = "error" - PENDING = "pending" - - -class EnterpriseIntegration: - """Enterprise integration configuration and management""" - - def __init__( - self, integration_id: str, tenant_id: str, integration_type: str, provider: str, configuration: dict[str, Any] - ): - self.integration_id = integration_id - self.tenant_id = tenant_id - self.integration_type = integration_type - self.provider = provider - self.configuration = configuration - self.status = IntegrationStatus.PENDING - self.created_at = datetime.now(timezone.utc) - self.last_updated = datetime.now(timezone.utc) - self.webhook_config = None - self.metrics = {"api_calls": 0, "errors": 0, "last_call": None} - - -class EnterpriseAPIGateway: - """Enterprise API Gateway with multi-tenant support""" - - def __init__(self): - self.tenant_service = None # Will be initialized with database session - self.active_tokens = {} # In-memory token storage (in production, use Redis) - self.rate_limiters = {} # Per-tenant rate limiters - self.webhooks = {} # Webhook configurations - self.integrations = {} # Enterprise integrations - self.api_metrics = {} # API performance metrics - - # Default quotas - self.default_quotas = { - "rate_limit": 1000, # requests per minute - "daily_limit": 50000, # requests per day - "concurrent_limit": 100, # concurrent requests - } - - # JWT configuration - self.jwt_secret = secrets.token_urlsafe(64) - self.jwt_algorithm = "HS256" - self.token_expiry = 3600 # 1 hour - - async def authenticate_enterprise_client(self, request: EnterpriseAuthRequest, db_session) -> EnterpriseAuthResponse: - """Authenticate enterprise client and issue access token""" - - try: - # Validate tenant and client credentials - tenant = await self._validate_tenant_credentials( - request.tenant_id, request.client_id, request.client_secret, db_session - ) - - # Generate access token - access_token = self._generate_access_token( - tenant_id=request.tenant_id, client_id=request.client_id, scopes=request.scopes or ["enterprise_api"] - ) - - # Generate refresh token - refresh_token = self._generate_refresh_token(request.tenant_id, request.client_id) - - # Store token - self.active_tokens[access_token] = { - "tenant_id": request.tenant_id, - "client_id": request.client_id, - "scopes": request.scopes or ["enterprise_api"], - "expires_at": datetime.now(timezone.utc) + timedelta(seconds=self.token_expiry), - "refresh_token": refresh_token, - } - - return EnterpriseAuthResponse( - access_token=access_token, - token_type="Bearer", - expires_in=self.token_expiry, - refresh_token=refresh_token, - scopes=request.scopes or ["enterprise_api"], - tenant_info={ - "tenant_id": tenant.tenant_id, - "name": tenant.name, - "plan": tenant.plan, - "status": tenant.status.value, - "created_at": tenant.created_at.isoformat(), - }, - ) - - except Exception as e: - logger.error(f"Enterprise authentication failed: {e}") - raise HTTPException(status_code=401, detail="Authentication failed") - - def _generate_access_token(self, tenant_id: str, client_id: str, scopes: list[str]) -> str: - """Generate JWT access token""" - - payload = { - "sub": f"{tenant_id}:{client_id}", - "scopes": scopes, - "iat": datetime.now(timezone.utc), - "exp": datetime.now(timezone.utc) + timedelta(seconds=self.token_expiry), - "type": "access", - } - - return jwt.encode(payload, self.jwt_secret, algorithm=self.jwt_algorithm) - - def _generate_refresh_token(self, tenant_id: str, client_id: str) -> str: - """Generate refresh token""" - - payload = { - "sub": f"{tenant_id}:{client_id}", - "iat": datetime.now(timezone.utc), - "exp": datetime.now(timezone.utc) + timedelta(days=30), # 30 days - "type": "refresh", - } - - return jwt.encode(payload, self.jwt_secret, algorithm=self.jwt_algorithm) - - async def _validate_tenant_credentials(self, tenant_id: str, client_id: str, client_secret: str, db_session) -> Tenant: - """Validate tenant credentials""" - - # Find tenant - tenant = db_session.query(Tenant).filter(Tenant.tenant_id == tenant_id).first() - if not tenant: - raise TenantError(f"Tenant {tenant_id} not found") - - # Find API key - api_key = ( - db_session.query(TenantApiKey) - .filter(TenantApiKey.tenant_id == tenant_id, TenantApiKey.client_id == client_id, TenantApiKey.is_active) - .first() - ) - - if not api_key or not secrets.compare_digest(api_key.client_secret, client_secret): - raise TenantError("Invalid client credentials") - - # Check tenant status - if tenant.status.value != "active": - raise TenantError(f"Tenant {tenant_id} is not active") - - return tenant - - async def check_api_quota(self, tenant_id: str, endpoint: str, method: str, db_session) -> APIQuotaResponse: - """Check and enforce API quotas""" - - try: - # Get tenant quota - quota = await self._get_tenant_quota(tenant_id, db_session) - - # Check rate limiting - current_usage = await self._get_current_usage(tenant_id, "rate_limit") - - if current_usage >= quota["rate_limit"]: - raise QuotaExceededError("Rate limit exceeded") - - # Update usage - await self._update_usage(tenant_id, "rate_limit", current_usage + 1) - - return APIQuotaResponse( - quota_limit=quota["rate_limit"], - quota_remaining=quota["rate_limit"] - current_usage - 1, - quota_reset=datetime.now(timezone.utc) + timedelta(minutes=1), - quota_type="rate_limit", - ) - - except QuotaExceededError: - raise - except Exception as e: - logger.error(f"Quota check failed: {e}") - raise HTTPException(status_code=500, detail="Quota check failed") - - async def _get_tenant_quota(self, tenant_id: str, db_session) -> dict[str, int]: - """Get tenant quota configuration""" - - # Get tenant-specific quota - tenant_quota = db_session.query(TenantQuota).filter(TenantQuota.tenant_id == tenant_id).first() - - if tenant_quota: - return { - "rate_limit": tenant_quota.rate_limit or self.default_quotas["rate_limit"], - "daily_limit": tenant_quota.daily_limit or self.default_quotas["daily_limit"], - "concurrent_limit": tenant_quota.concurrent_limit or self.default_quotas["concurrent_limit"], - } - - return self.default_quotas - - async def _get_current_usage(self, tenant_id: str, quota_type: str) -> int: - """Get current quota usage""" - - # In production, use Redis or database for persistent storage - - if quota_type == "rate_limit": - # Get usage in the last minute - return len([t for t in self.rate_limiters.get(tenant_id, []) if datetime.now(timezone.utc) - t < timedelta(minutes=1)]) - - return 0 - - async def _update_usage(self, tenant_id: str, quota_type: str, usage: int): - """Update quota usage""" - - if quota_type == "rate_limit": - if tenant_id not in self.rate_limiters: - self.rate_limiters[tenant_id] = [] - - # Add current timestamp - self.rate_limiters[tenant_id].append(datetime.now(timezone.utc)) - - # Clean old entries (older than 1 minute) - cutoff = datetime.now(timezone.utc) - timedelta(minutes=1) - self.rate_limiters[tenant_id] = [t for t in self.rate_limiters[tenant_id] if t > cutoff] - - async def create_enterprise_integration( - self, tenant_id: str, request: EnterpriseIntegrationRequest, db_session - ) -> dict[str, Any]: - """Create new enterprise integration""" - - try: - # Validate tenant - tenant = db_session.query(Tenant).filter(Tenant.tenant_id == tenant_id).first() - if not tenant: - raise TenantError(f"Tenant {tenant_id} not found") - - # Create integration - integration_id = str(uuid4()) - integration = EnterpriseIntegration( - integration_id=integration_id, - tenant_id=tenant_id, - integration_type=request.integration_type, - provider=request.provider, - configuration=request.configuration, - ) - - # Store webhook configuration - if request.webhook_config: - integration.webhook_config = request.webhook_config.dict() - self.webhooks[integration_id] = request.webhook_config.dict() - - # Store integration - self.integrations[integration_id] = integration - - # Initialize integration - await self._initialize_integration(integration) - - return { - "integration_id": integration_id, - "status": integration.status.value, - "created_at": integration.created_at.isoformat(), - "configuration": integration.configuration, - } - - except Exception as e: - logger.error(f"Failed to create enterprise integration: {e}") - raise HTTPException(status_code=500, detail="Integration creation failed") - - async def _initialize_integration(self, integration: EnterpriseIntegration): - """Initialize enterprise integration""" - - try: - # Integration-specific initialization logic - if integration.integration_type.lower() == "erp": - await self._initialize_erp_integration(integration) - elif integration.integration_type.lower() == "crm": - await self._initialize_crm_integration(integration) - elif integration.integration_type.lower() == "bi": - await self._initialize_bi_integration(integration) - - integration.status = IntegrationStatus.ACTIVE - integration.last_updated = datetime.now(timezone.utc) - - except Exception as e: - logger.error(f"Integration initialization failed: {e}") - integration.status = IntegrationStatus.ERROR - raise - - async def _initialize_erp_integration(self, integration: EnterpriseIntegration): - """Initialize ERP integration""" - - # ERP-specific initialization - provider = integration.provider.lower() - - if provider == "sap": - await self._initialize_sap_integration(integration) - elif provider == "oracle": - await self._initialize_oracle_integration(integration) - elif provider == "microsoft": - await self._initialize_microsoft_integration(integration) - - logger.info(f"ERP integration initialized: {integration.provider}") - - async def _initialize_sap_integration(self, integration: EnterpriseIntegration): - """Initialize SAP ERP integration""" - - # SAP integration logic - config = integration.configuration - - # Validate SAP configuration - required_fields = ["system_id", "client", "username", "password", "host"] - for field in required_fields: - if field not in config: - raise ValueError(f"SAP integration requires {field}") - - # Test SAP connection - # In production, implement actual SAP connection testing - logger.info(f"SAP connection test successful for {integration.integration_id}") - - async def get_enterprise_metrics(self, tenant_id: str, db_session) -> EnterpriseMetrics: - """Get enterprise metrics and analytics""" - - try: - # Get API metrics - api_metrics = self.api_metrics.get( - tenant_id, {"total_calls": 0, "successful_calls": 0, "failed_calls": 0, "response_times": []} - ) - - # Calculate metrics - total_calls = api_metrics["total_calls"] - successful_calls = api_metrics["successful_calls"] - failed_calls = api_metrics["failed_calls"] - - average_response_time = ( - sum(api_metrics["response_times"]) / len(api_metrics["response_times"]) - if api_metrics["response_times"] - else 0.0 - ) - - error_rate = (failed_calls / total_calls * 100) if total_calls > 0 else 0.0 - - # Get quota utilization - current_usage = await self._get_current_usage(tenant_id, "rate_limit") - quota = await self._get_tenant_quota(tenant_id, db_session) - quota_utilization = (current_usage / quota["rate_limit"] * 100) if quota["rate_limit"] > 0 else 0.0 - - # Count active integrations - active_integrations = len( - [i for i in self.integrations.values() if i.tenant_id == tenant_id and i.status == IntegrationStatus.ACTIVE] - ) - - return EnterpriseMetrics( - api_calls_total=total_calls, - api_calls_successful=successful_calls, - average_response_time_ms=average_response_time, - error_rate_percent=error_rate, - quota_utilization_percent=quota_utilization, - active_integrations=active_integrations, - ) - - except Exception as e: - logger.error(f"Failed to get enterprise metrics: {e}") - raise HTTPException(status_code=500, detail="Metrics retrieval failed") - - async def record_api_call(self, tenant_id: str, endpoint: str, response_time: float, success: bool): - """Record API call for metrics""" - - if tenant_id not in self.api_metrics: - self.api_metrics[tenant_id] = {"total_calls": 0, "successful_calls": 0, "failed_calls": 0, "response_times": []} - - metrics = self.api_metrics[tenant_id] - metrics["total_calls"] += 1 - - if success: - metrics["successful_calls"] += 1 - else: - metrics["failed_calls"] += 1 - - metrics["response_times"].append(response_time) - - # Keep only last 1000 response times - if len(metrics["response_times"]) > 1000: - metrics["response_times"] = metrics["response_times"][-1000:] - - -# FastAPI application -app = FastAPI( - title="Enterprise API Gateway", - description="Multi-tenant API routing and management for enterprise clients", - version="6.1.0", - docs_url="/docs", - redoc_url="/redoc", -) - -# CORS middleware -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -# Security -security = HTTPBearer() - -# Global gateway instance -gateway = EnterpriseAPIGateway() - - -# Dependency for database session -async def get_db_session(): - """Get database session""" - - async with get_db() as session: - yield session - - -# Middleware for API metrics -@app.middleware("http") -async def api_metrics_middleware(request: Request, call_next): - """Middleware to record API metrics""" - - start_time = time.time() - - # Extract tenant from token if available - tenant_id = None - authorization = request.headers.get("authorization") - if authorization and authorization.startswith("Bearer "): - token = authorization[7:] - token_data = gateway.active_tokens.get(token) - if token_data: - tenant_id = token_data["tenant_id"] - - # Process request - response = await call_next(request) - - # Record metrics - response_time = (time.time() - start_time) * 1000 # Convert to milliseconds - success = response.status_code < 400 - - if tenant_id: - await gateway.record_api_call(tenant_id, str(request.url.path), response_time, success) - - return response - - -@app.post("/enterprise/auth") -async def enterprise_auth(request: EnterpriseAuthRequest, db_session=Depends(get_db_session)): - """Authenticate enterprise client""" - - result = await gateway.authenticate_enterprise_client(request, db_session) - return result - - -@app.post("/enterprise/quota/check") -async def check_quota(request: APIQuotaRequest, db_session=Depends(get_db_session)): - """Check API quota""" - - result = await gateway.check_api_quota(request.tenant_id, request.endpoint, request.method, db_session) - return result - - -@app.post("/enterprise/integrations") -async def create_integration(request: EnterpriseIntegrationRequest, db_session=Depends(get_db_session)): - """Create enterprise integration""" - - # Extract tenant from token (in production, proper authentication) - tenant_id = "demo_tenant" # Placeholder - - result = await gateway.create_enterprise_integration(tenant_id, request, db_session) - return result - - -@app.get("/enterprise/analytics") -async def get_analytics(db_session=Depends(get_db_session)): - """Get enterprise analytics dashboard""" - - # Extract tenant from token (in production, proper authentication) - tenant_id = "demo_tenant" # Placeholder - - result = await gateway.get_enterprise_metrics(tenant_id, db_session) - return result - - -@app.get("/enterprise/status") -async def get_status(): - """Get enterprise gateway status""" - - return { - "service": "Enterprise API Gateway", - "version": "6.1.0", - "port": 8010, - "status": "operational", - "active_tenants": len({token["tenant_id"] for token in gateway.active_tokens.values()}), - "active_integrations": len(gateway.integrations), - "timestamp": datetime.now(timezone.utc).isoformat(), - } - - -@app.get("/") -async def root(): - """Root endpoint""" - return { - "service": "Enterprise API Gateway", - "version": "6.1.0", - "port": 8010, - "capabilities": [ - "Multi-tenant API Management", - "Enterprise Authentication", - "API Quota Management", - "Enterprise Integration Framework", - "Real-time Analytics", - ], - "status": "operational", - } - - -@app.get("/health") -async def health_check(): - """Health check endpoint""" - return { - "status": "healthy", - "timestamp": datetime.now(timezone.utc).isoformat(), - "services": { - "api_gateway": "operational", - "authentication": "operational", - "quota_management": "operational", - "integration_framework": "operational", - }, - } - - -if __name__ == "__main__": - import uvicorn - - uvicorn.run(app, host="0.0.0.0", port=8010) diff --git a/apps/coordinator-api/src/app/services/enterprise_integration.py b/apps/coordinator-api/src/app/services/enterprise_integration.py deleted file mode 100755 index 4faf59b6..00000000 --- a/apps/coordinator-api/src/app/services/enterprise_integration.py +++ /dev/null @@ -1,1127 +0,0 @@ -""" -Enterprise Integration Framework - Phase 6.1 Implementation -ERP, CRM, and business system connectors for enterprise clients -""" - -import asyncio -import json -import xml.etree.ElementTree as ET -from dataclasses import dataclass, field -from datetime import datetime, timezone, timedelta -from enum import Enum -from typing import Any, Dict, List, Optional, Union -from uuid import uuid4 - -import aiohttp -from pydantic import BaseModel, Field, validator - -from aitbc import get_logger - -logger = get_logger(__name__) - - - -class IntegrationType(str, Enum): - """Enterprise integration types""" - ERP = "erp" - CRM = "crm" - BI = "bi" - HR = "hr" - FINANCE = "finance" - CUSTOM = "custom" - -class IntegrationProvider(str, Enum): - """Supported integration providers""" - SAP = "sap" - ORACLE = "oracle" - MICROSOFT = "microsoft" - SALESFORCE = "salesforce" - HUBSPOT = "hubspot" - TABLEAU = "tableau" - POWERBI = "powerbi" - WORKDAY = "workday" - -class DataFormat(str, Enum): - """Data exchange formats""" - JSON = "json" - XML = "xml" - CSV = "csv" - ODATA = "odata" - SOAP = "soap" - REST = "rest" - -@dataclass -class IntegrationConfig: - """Integration configuration""" - integration_id: str - tenant_id: str - integration_type: IntegrationType - provider: IntegrationProvider - endpoint_url: str - authentication: Dict[str, str] - data_format: DataFormat - mapping_rules: Dict[str, Any] = field(default_factory=dict) - retry_policy: Dict[str, Any] = field(default_factory=dict) - rate_limits: Dict[str, int] = field(default_factory=dict) - webhook_config: Optional[Dict[str, Any]] = None - created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) - last_sync: Optional[datetime] = None - status: str = "active" - -class IntegrationRequest(BaseModel): - """Integration request model""" - integration_id: str = Field(..., description="Integration identifier") - operation: str = Field(..., description="Operation to perform") - data: Dict[str, Any] = Field(..., description="Request data") - parameters: Optional[Dict[str, Any]] = Field(default=None, description="Additional parameters") - -class IntegrationResponse(BaseModel): - """Integration response model""" - success: bool = Field(..., description="Operation success status") - data: Optional[Dict[str, Any]] = Field(None, description="Response data") - error: Optional[str] = Field(None, description="Error message") - metadata: Dict[str, Any] = Field(default_factory=dict, description="Response metadata") - -class ERPIntegration: - """Base ERP integration class""" - - def __init__(self, config: IntegrationConfig): - self.config = config - self.session = None - self.logger = get_logger(f"erp.{config.provider.value}") - - async def initialize(self): - """Initialize ERP connection (generic mock implementation)""" - try: - # Create generic HTTP session - self.session = aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=30) - ) - self.logger.info(f"Generic ERP connection initialized for {self.config.integration_id}") - return True - except Exception as e: - self.logger.error(f"ERP initialization failed: {e}") - raise - - async def test_connection(self) -> bool: - """Test ERP connection (generic mock implementation)""" - try: - # Generic connection test - always returns True for mock - self.logger.info(f"Generic ERP connection test passed for {self.config.integration_id}") - return True - except Exception as e: - self.logger.error(f"ERP connection test failed: {e}") - return False - - async def sync_data(self, data_type: str, filters: Optional[Dict] = None) -> IntegrationResponse: - """Sync data from ERP (generic mock implementation)""" - try: - # Generic sync - returns mock data - mock_data = { - "data_type": data_type, - "records": [], - "count": 0, - "timestamp": datetime.now(timezone.utc).isoformat() - } - return IntegrationResponse( - success=True, - data=mock_data, - metadata={"sync_type": "generic_mock"} - ) - except Exception as e: - self.logger.error(f"ERP data sync failed: {e}") - return IntegrationResponse( - success=False, - error=str(e) - ) - - async def push_data(self, data_type: str, data: Dict[str, Any]) -> IntegrationResponse: - """Push data to ERP (generic mock implementation)""" - try: - # Generic push - returns success - return IntegrationResponse( - success=True, - data={"data_type": data_type, "pushed": True}, - metadata={"push_type": "generic_mock"} - ) - except Exception as e: - self.logger.error(f"ERP data push failed: {e}") - return IntegrationResponse( - success=False, - error=str(e) - ) - - async def close(self): - """Close ERP connection""" - if self.session: - await self.session.close() - -class SAPIntegration(ERPIntegration): - """SAP ERP integration""" - - def __init__(self, config: IntegrationConfig): - super().__init__(config) - self.system_id = config.authentication.get("system_id") - self.client = config.authentication.get("client") - self.username = config.authentication.get("username") - self.password = config.authentication.get("password") - self.language = config.authentication.get("language", "EN") - - async def initialize(self): - """Initialize SAP connection""" - try: - # Create HTTP session for SAP web services - self.session = aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=30), - auth=aiohttp.BasicAuth(self.username, self.password) - ) - - # Test connection - if await self.test_connection(): - self.logger.info(f"SAP connection established for {self.config.integration_id}") - return True - else: - raise Exception("SAP connection test failed") - - except Exception as e: - self.logger.error(f"SAP initialization failed: {e}") - raise - - async def test_connection(self) -> bool: - """Test SAP connection""" - try: - # SAP system info endpoint - url = f"{self.config.endpoint_url}/sap/bc/ping" - - async with self.session.get(url) as response: - if response.status == 200: - return True - else: - self.logger.error(f"SAP ping failed: {response.status}") - return False - - except Exception as e: - self.logger.error(f"SAP connection test failed: {e}") - return False - - async def sync_data(self, data_type: str, filters: Optional[Dict] = None) -> IntegrationResponse: - """Sync data from SAP""" - - try: - if data_type == "customers": - return await self._sync_customers(filters) - elif data_type == "orders": - return await self._sync_orders(filters) - elif data_type == "products": - return await self._sync_products(filters) - else: - return IntegrationResponse( - success=False, - error=f"Unsupported data type: {data_type}" - ) - - except Exception as e: - self.logger.error(f"SAP data sync failed: {e}") - return IntegrationResponse( - success=False, - error=str(e) - ) - - async def _sync_customers(self, filters: Optional[Dict] = None) -> IntegrationResponse: - """Sync customer data from SAP""" - - try: - # SAP BAPI customer list endpoint - url = f"{self.config.endpoint_url}/sap/bc/sap/rfc/customer_list" - - params = { - "client": self.client, - "language": self.language - } - - if filters: - params.update(filters) - - async with self.session.get(url, params=params) as response: - if response.status == 200: - data = await response.json() - - # Apply mapping rules - mapped_data = self._apply_mapping_rules(data, "customers") - - return IntegrationResponse( - success=True, - data=mapped_data, - metadata={ - "records_count": len(mapped_data.get("customers", [])), - "sync_time": datetime.now(timezone.utc).isoformat() - } - ) - else: - error_text = await response.text() - return IntegrationResponse( - success=False, - error=f"SAP API error: {response.status} - {error_text}" - ) - - except Exception as e: - return IntegrationResponse( - success=False, - error=str(e) - ) - - async def _sync_orders(self, filters: Optional[Dict] = None) -> IntegrationResponse: - """Sync order data from SAP""" - - try: - # SAP sales order endpoint - url = f"{self.config.endpoint_url}/sap/bc/sap/rfc/sales_orders" - - params = { - "client": self.client, - "language": self.language - } - - if filters: - params.update(filters) - - async with self.session.get(url, params=params) as response: - if response.status == 200: - data = await response.json() - - # Apply mapping rules - mapped_data = self._apply_mapping_rules(data, "orders") - - return IntegrationResponse( - success=True, - data=mapped_data, - metadata={ - "records_count": len(mapped_data.get("orders", [])), - "sync_time": datetime.now(timezone.utc).isoformat() - } - ) - else: - error_text = await response.text() - return IntegrationResponse( - success=False, - error=f"SAP API error: {response.status} - {error_text}" - ) - - except Exception as e: - return IntegrationResponse( - success=False, - error=str(e) - ) - - async def _sync_products(self, filters: Optional[Dict] = None) -> IntegrationResponse: - """Sync product data from SAP""" - - try: - # SAP material master endpoint - url = f"{self.config.endpoint_url}/sap/bc/sap/rfc/material_master" - - params = { - "client": self.client, - "language": self.language - } - - if filters: - params.update(filters) - - async with self.session.get(url, params=params) as response: - if response.status == 200: - data = await response.json() - - # Apply mapping rules - mapped_data = self._apply_mapping_rules(data, "products") - - return IntegrationResponse( - success=True, - data=mapped_data, - metadata={ - "records_count": len(mapped_data.get("products", [])), - "sync_time": datetime.now(timezone.utc).isoformat() - } - ) - else: - error_text = await response.text() - return IntegrationResponse( - success=False, - error=f"SAP API error: {response.status} - {error_text}" - ) - - except Exception as e: - return IntegrationResponse( - success=False, - error=str(e) - ) - - def _apply_mapping_rules(self, data: Dict[str, Any], data_type: str) -> Dict[str, Any]: - """Apply mapping rules to transform data""" - - mapping_rules = self.config.mapping_rules.get(data_type, {}) - mapped_data = {} - - # Apply field mappings - for sap_field, aitbc_field in mapping_rules.get("field_mappings", {}).items(): - if sap_field in data: - mapped_data[aitbc_field] = data[sap_field] - - # Apply transformations - transformations = mapping_rules.get("transformations", {}) - for field, transform in transformations.items(): - if field in mapped_data: - # Apply transformation logic - if transform["type"] == "date_format": - # Date format transformation - mapped_data[field] = self._transform_date(mapped_data[field], transform["format"]) - elif transform["type"] == "numeric": - # Numeric transformation - mapped_data[field] = self._transform_numeric(mapped_data[field], transform) - - return {data_type: mapped_data} - - def _transform_date(self, date_value: str, format_str: str) -> str: - """Transform date format""" - try: - # Parse SAP date format and convert to target format - # SAP typically uses YYYYMMDD format - if len(date_value) == 8 and date_value.isdigit(): - year = date_value[:4] - month = date_value[4:6] - day = date_value[6:8] - return f"{year}-{month}-{day}" - return date_value - except (ValueError, IndexError, AttributeError, TypeError): - return date_value - - def _transform_numeric(self, value: str, transform: Dict[str, Any]) -> Union[str, int, float]: - """Transform numeric values""" - try: - if transform.get("type") == "decimal": - return float(value) / (10 ** transform.get("scale", 2)) - elif transform.get("type") == "integer": - return int(float(value)) - return value - except Exception: - return value - -class OracleIntegration(ERPIntegration): - """Oracle ERP integration""" - - def __init__(self, config: IntegrationConfig): - super().__init__(config) - self.service_name = config.authentication.get("service_name") - self.username = config.authentication.get("username") - self.password = config.authentication.get("password") - - async def initialize(self): - """Initialize Oracle connection""" - try: - # Create HTTP session for Oracle REST APIs - self.session = aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=30), - auth=aiohttp.BasicAuth(self.username, self.password) - ) - - # Test connection - if await self.test_connection(): - self.logger.info(f"Oracle connection established for {self.config.integration_id}") - return True - else: - raise Exception("Oracle connection test failed") - - except Exception as e: - self.logger.error(f"Oracle initialization failed: {e}") - raise - - async def test_connection(self) -> bool: - """Test Oracle connection""" - try: - # Oracle Fusion Cloud REST API endpoint - url = f"{self.config.endpoint_url}/fscmRestApi/resources/latest/version" - - async with self.session.get(url) as response: - if response.status == 200: - return True - else: - self.logger.error(f"Oracle version check failed: {response.status}") - return False - - except Exception as e: - self.logger.error(f"Oracle connection test failed: {e}") - return False - - async def sync_data(self, data_type: str, filters: Optional[Dict] = None) -> IntegrationResponse: - """Sync data from Oracle""" - - try: - if data_type == "customers": - return await self._sync_customers(filters) - elif data_type == "orders": - return await self._sync_orders(filters) - elif data_type == "products": - return await self._sync_products(filters) - else: - return IntegrationResponse( - success=False, - error=f"Unsupported data type: {data_type}" - ) - - except Exception as e: - self.logger.error(f"Oracle data sync failed: {e}") - return IntegrationResponse( - success=False, - error=str(e) - ) - - async def _sync_customers(self, filters: Optional[Dict] = None) -> IntegrationResponse: - """Sync customer data from Oracle""" - - try: - # Oracle Fusion Cloud Customer endpoint - url = f"{self.config.endpoint_url}/fscmRestApi/resources/latest/customerAccounts" - - params = {} - if filters: - params.update(filters) - - async with self.session.get(url, params=params) as response: - if response.status == 200: - data = await response.json() - - # Apply mapping rules - mapped_data = self._apply_mapping_rules(data, "customers") - - return IntegrationResponse( - success=True, - data=mapped_data, - metadata={ - "records_count": len(mapped_data.get("customers", [])), - "sync_time": datetime.now(timezone.utc).isoformat() - } - ) - else: - error_text = await response.text() - return IntegrationResponse( - success=False, - error=f"Oracle API error: {response.status} - {error_text}" - ) - - except Exception as e: - return IntegrationResponse( - success=False, - error=str(e) - ) - - def _apply_mapping_rules(self, data: Dict[str, Any], data_type: str) -> Dict[str, Any]: - """Apply mapping rules to transform data""" - - mapping_rules = self.config.mapping_rules.get(data_type, {}) - mapped_data = {} - - # Apply field mappings - for oracle_field, aitbc_field in mapping_rules.get("field_mappings", {}).items(): - if oracle_field in data: - mapped_data[aitbc_field] = data[oracle_field] - - return {data_type: mapped_data} - -class CRMIntegration: - """Base CRM integration class""" - - def __init__(self, config: IntegrationConfig): - self.config = config - self.session = None - self.logger = get_logger(f"crm.{config.provider.value}") - - async def initialize(self): - """Initialize CRM connection (generic mock implementation)""" - try: - # Create generic HTTP session - self.session = aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=30) - ) - self.logger.info(f"Generic CRM connection initialized for {self.config.integration_id}") - return True - except Exception as e: - self.logger.error(f"CRM initialization failed: {e}") - raise - - async def test_connection(self) -> bool: - """Test CRM connection (generic mock implementation)""" - try: - # Generic connection test - always returns True for mock - self.logger.info(f"Generic CRM connection test passed for {self.config.integration_id}") - return True - except Exception as e: - self.logger.error(f"CRM connection test failed: {e}") - return False - - async def sync_contacts(self, filters: Optional[Dict] = None) -> IntegrationResponse: - """Sync contacts from CRM (generic mock implementation)""" - try: - mock_data = { - "contacts": [], - "count": 0, - "timestamp": datetime.now(timezone.utc).isoformat() - } - return IntegrationResponse( - success=True, - data=mock_data, - metadata={"sync_type": "generic_mock"} - ) - except Exception as e: - self.logger.error(f"CRM contact sync failed: {e}") - return IntegrationResponse( - success=False, - error=str(e) - ) - - async def sync_opportunities(self, filters: Optional[Dict] = None) -> IntegrationResponse: - """Sync opportunities from CRM (generic mock implementation)""" - try: - mock_data = { - "opportunities": [], - "count": 0, - "timestamp": datetime.now(timezone.utc).isoformat() - } - return IntegrationResponse( - success=True, - data=mock_data, - metadata={"sync_type": "generic_mock"} - ) - except Exception as e: - self.logger.error(f"CRM opportunity sync failed: {e}") - return IntegrationResponse( - success=False, - error=str(e) - ) - - async def create_lead(self, lead_data: Dict[str, Any]) -> IntegrationResponse: - """Create lead in CRM (generic mock implementation)""" - try: - return IntegrationResponse( - success=True, - data={"lead_id": str(uuid4()), "created": True}, - metadata={"create_type": "generic_mock"} - ) - except Exception as e: - self.logger.error(f"CRM lead creation failed: {e}") - return IntegrationResponse( - success=False, - error=str(e) - ) - - async def close(self): - """Close CRM connection""" - if self.session: - await self.session.close() - -class SalesforceIntegration(CRMIntegration): - """Salesforce CRM integration""" - - def __init__(self, config: IntegrationConfig): - super().__init__(config) - self.client_id = config.authentication.get("client_id") - self.client_secret = config.authentication.get("client_secret") - self.username = config.authentication.get("username") - self.password = config.authentication.get("password") - self.security_token = config.authentication.get("security_token") - self.access_token = None - - async def initialize(self): - """Initialize Salesforce connection""" - try: - # Create HTTP session - self.session = aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=30) - ) - - # Authenticate with Salesforce - if await self._authenticate(): - self.logger.info(f"Salesforce connection established for {self.config.integration_id}") - return True - else: - raise Exception("Salesforce authentication failed") - - except Exception as e: - self.logger.error(f"Salesforce initialization failed: {e}") - raise - - async def _authenticate(self) -> bool: - """Authenticate with Salesforce""" - - try: - # Salesforce OAuth2 endpoint - url = f"{self.config.endpoint_url}/services/oauth2/token" - - data = { - "grant_type": "password", - "client_id": self.client_id, - "client_secret": self.client_secret, - "username": self.username, - "password": f"{self.password}{self.security_token}" - } - - async with self.session.post(url, data=data) as response: - if response.status == 200: - token_data = await response.json() - self.access_token = token_data["access_token"] - return True - else: - error_text = await response.text() - self.logger.error(f"Salesforce authentication failed: {error_text}") - return False - - except Exception as e: - self.logger.error(f"Salesforce authentication error: {e}") - return False - - async def test_connection(self) -> bool: - """Test Salesforce connection""" - - try: - if not self.access_token: - return False - - # Salesforce identity endpoint - url = f"{self.config.endpoint_url}/services/oauth2/userinfo" - - headers = { - "Authorization": f"Bearer {self.access_token}" - } - - async with self.session.get(url, headers=headers) as response: - return response.status == 200 - - except Exception as e: - self.logger.error(f"Salesforce connection test failed: {e}") - return False - - async def sync_contacts(self, filters: Optional[Dict] = None) -> IntegrationResponse: - """Sync contacts from Salesforce""" - - try: - if not self.access_token: - return IntegrationResponse( - success=False, - error="Not authenticated" - ) - - # Salesforce contacts endpoint - url = f"{self.config.endpoint_url}/services/data/v52.0/sobjects/Contact" - - headers = { - "Authorization": f"Bearer {self.access_token}", - "Content-Type": "application/json" - } - - params = {} - if filters: - params.update(filters) - - async with self.session.get(url, headers=headers, params=params) as response: - if response.status == 200: - data = await response.json() - - # Apply mapping rules - mapped_data = self._apply_mapping_rules(data, "contacts") - - return IntegrationResponse( - success=True, - data=mapped_data, - metadata={ - "records_count": len(data.get("records", [])), - "sync_time": datetime.now(timezone.utc).isoformat() - } - ) - else: - error_text = await response.text() - return IntegrationResponse( - success=False, - error=f"Salesforce API error: {response.status} - {error_text}" - ) - - except Exception as e: - self.logger.error(f"Salesforce contacts sync failed: {e}") - return IntegrationResponse( - success=False, - error=str(e) - ) - - def _apply_mapping_rules(self, data: Dict[str, Any], data_type: str) -> Dict[str, Any]: - """Apply mapping rules to transform data""" - - mapping_rules = self.config.mapping_rules.get(data_type, {}) - mapped_data = {} - - # Apply field mappings - for salesforce_field, aitbc_field in mapping_rules.get("field_mappings", {}).items(): - if salesforce_field in data: - mapped_data[aitbc_field] = data[salesforce_field] - - return {data_type: mapped_data} - -class BillingIntegration: - """Base billing integration class""" - - def __init__(self, config: IntegrationConfig): - self.config = config - self.session = None - self.logger = get_logger(f"billing.{config.provider.value}") - - async def initialize(self): - """Initialize billing connection (generic mock implementation)""" - try: - # Create generic HTTP session - self.session = aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=30) - ) - self.logger.info(f"Generic billing connection initialized for {self.config.integration_id}") - return True - except Exception as e: - self.logger.error(f"Billing initialization failed: {e}") - raise - - async def test_connection(self) -> bool: - """Test billing connection (generic mock implementation)""" - try: - # Generic connection test - always returns True for mock - self.logger.info(f"Generic billing connection test passed for {self.config.integration_id}") - return True - except Exception as e: - self.logger.error(f"Billing connection test failed: {e}") - return False - - async def generate_invoice(self, billing_data: Dict[str, Any]) -> IntegrationResponse: - """Generate invoice (generic mock implementation)""" - try: - return IntegrationResponse( - success=True, - data={"invoice_id": str(uuid4()), "status": "generated"}, - metadata={"billing_type": "generic_mock"} - ) - except Exception as e: - self.logger.error(f"Invoice generation failed: {e}") - return IntegrationResponse( - success=False, - error=str(e) - ) - - async def process_payment(self, payment_data: Dict[str, Any]) -> IntegrationResponse: - """Process payment (generic mock implementation)""" - try: - return IntegrationResponse( - success=True, - data={"payment_id": str(uuid4()), "status": "processed"}, - metadata={"payment_type": "generic_mock"} - ) - except Exception as e: - self.logger.error(f"Payment processing failed: {e}") - return IntegrationResponse( - success=False, - error=str(e) - ) - - async def track_usage(self, usage_data: Dict[str, Any]) -> IntegrationResponse: - """Track usage (generic mock implementation)""" - try: - return IntegrationResponse( - success=True, - data={"usage_id": str(uuid4()), "tracked": True}, - metadata={"tracking_type": "generic_mock"} - ) - except Exception as e: - self.logger.error(f"Usage tracking failed: {e}") - return IntegrationResponse( - success=False, - error=str(e) - ) - - async def close(self): - """Close billing connection""" - if self.session: - await self.session.close() - -class ComplianceIntegration: - """Base compliance integration class""" - - def __init__(self, config: IntegrationConfig): - self.config = config - self.session = None - self.logger = get_logger(f"compliance.{config.provider.value}") - - async def initialize(self): - """Initialize compliance connection (generic mock implementation)""" - try: - # Create generic HTTP session - self.session = aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=30) - ) - self.logger.info(f"Generic compliance connection initialized for {self.config.integration_id}") - return True - except Exception as e: - self.logger.error(f"Compliance initialization failed: {e}") - raise - - async def test_connection(self) -> bool: - """Test compliance connection (generic mock implementation)""" - try: - # Generic connection test - always returns True for mock - self.logger.info(f"Generic compliance connection test passed for {self.config.integration_id}") - return True - except Exception as e: - self.logger.error(f"Compliance connection test failed: {e}") - return False - - async def log_audit(self, audit_data: Dict[str, Any]) -> IntegrationResponse: - """Log audit event (generic mock implementation)""" - try: - return IntegrationResponse( - success=True, - data={"audit_id": str(uuid4()), "logged": True}, - metadata={"audit_type": "generic_mock"} - ) - except Exception as e: - self.logger.error(f"Audit logging failed: {e}") - return IntegrationResponse( - success=False, - error=str(e) - ) - - async def enforce_policy(self, policy_data: Dict[str, Any]) -> IntegrationResponse: - """Enforce compliance policy (generic mock implementation)""" - try: - return IntegrationResponse( - success=True, - data={"policy_id": str(uuid4()), "enforced": True}, - metadata={"policy_type": "generic_mock"} - ) - except Exception as e: - self.logger.error(f"Policy enforcement failed: {e}") - return IntegrationResponse( - success=False, - error=str(e) - ) - - async def generate_report(self, report_data: Dict[str, Any]) -> IntegrationResponse: - """Generate compliance report (generic mock implementation)""" - try: - return IntegrationResponse( - success=True, - data={"report_id": str(uuid4()), "generated": True}, - metadata={"report_type": "generic_mock"} - ) - except Exception as e: - self.logger.error(f"Report generation failed: {e}") - return IntegrationResponse( - success=False, - error=str(e) - ) - - async def close(self): - """Close compliance connection""" - if self.session: - await self.session.close() - -class EnterpriseIntegrationFramework: - """Enterprise integration framework manager""" - - def __init__(self): - self.integrations = {} # Active integrations - self.logger = logger - - async def create_integration(self, config: IntegrationConfig) -> bool: - """Create and initialize enterprise integration""" - - try: - # Create integration instance based on type and provider - integration = await self._create_integration_instance(config) - - # Initialize integration - await integration.initialize() - - # Store integration - self.integrations[config.integration_id] = integration - - self.logger.info(f"Enterprise integration created: {config.integration_id}") - return True - - except Exception as e: - self.logger.error(f"Failed to create integration {config.integration_id}: {e}") - return False - - async def _create_integration_instance(self, config: IntegrationConfig): - """Create integration instance based on configuration""" - - if config.integration_type == IntegrationType.ERP: - if config.provider == IntegrationProvider.SAP: - return SAPIntegration(config) - elif config.provider == IntegrationProvider.ORACLE: - return OracleIntegration(config) - else: - raise ValueError(f"Unsupported ERP provider: {config.provider}") - - elif config.integration_type == IntegrationType.CRM: - if config.provider == IntegrationProvider.SALESFORCE: - return SalesforceIntegration(config) - else: - raise ValueError(f"Unsupported CRM provider: {config.provider}") - - else: - raise ValueError(f"Unsupported integration type: {config.integration_type}") - - async def execute_integration_request(self, request: IntegrationRequest) -> IntegrationResponse: - """Execute integration request""" - - try: - integration = self.integrations.get(request.integration_id) - if not integration: - return IntegrationResponse( - success=False, - error=f"Integration not found: {request.integration_id}" - ) - - # Execute operation based on integration type - if isinstance(integration, ERPIntegration): - if request.operation == "sync_data": - data_type = request.parameters.get("data_type", "customers") - filters = request.parameters.get("filters") - return await integration.sync_data(data_type, filters) - elif request.operation == "push_data": - data_type = request.parameters.get("data_type", "customers") - return await integration.push_data(data_type, request.data) - - elif isinstance(integration, CRMIntegration): - if request.operation == "sync_contacts": - filters = request.parameters.get("filters") - return await integration.sync_contacts(filters) - elif request.operation == "sync_opportunities": - filters = request.parameters.get("filters") - return await integration.sync_opportunities(filters) - elif request.operation == "create_lead": - return await integration.create_lead(request.data) - - return IntegrationResponse( - success=False, - error=f"Unsupported operation: {request.operation}" - ) - - except Exception as e: - self.logger.error(f"Integration request failed: {e}") - return IntegrationResponse( - success=False, - error=str(e) - ) - - async def test_integration(self, integration_id: str) -> bool: - """Test integration connection""" - - integration = self.integrations.get(integration_id) - if not integration: - return False - - return await integration.test_connection() - - async def get_integration_status(self, integration_id: str) -> Dict[str, Any]: - """Get integration status""" - - integration = self.integrations.get(integration_id) - if not integration: - return {"status": "not_found"} - - return { - "integration_id": integration_id, - "integration_type": integration.config.integration_type.value, - "provider": integration.config.provider.value, - "endpoint_url": integration.config.endpoint_url, - "status": "active", - "last_test": datetime.now(timezone.utc).isoformat() - } - - async def close_integration(self, integration_id: str): - """Close integration connection""" - - integration = self.integrations.get(integration_id) - if integration: - await integration.close() - del self.integrations[integration_id] - self.logger.info(f"Integration closed: {integration_id}") - - async def close_all_integrations(self): - """Close all integration connections""" - - for integration_id in list(self.integrations.keys()): - await self.close_integration(integration_id) - -# Global integration framework instance -integration_framework = EnterpriseIntegrationFramework() - -# CLI Interface Functions -def create_tenant(name: str, domain: str) -> str: - """Create a new tenant""" - return api_gateway.create_tenant(name, domain) - -def get_tenant_info(tenant_id: str) -> Optional[Dict[str, Any]]: - """Get tenant information""" - tenant = api_gateway.get_tenant(tenant_id) - if tenant: - return { - "tenant_id": tenant.tenant_id, - "name": tenant.name, - "domain": tenant.domain, - "status": tenant.status.value, - "created_at": tenant.created_at.isoformat(), - "features": tenant.features - } - return None - -def generate_api_key(tenant_id: str) -> str: - """Generate API key for tenant""" - return security_manager.generate_api_key(tenant_id) - -def register_integration(tenant_id: str, name: str, integration_type: str, config: Dict[str, Any]) -> str: - """Register third-party integration""" - return integration_framework.register_integration(tenant_id, name, IntegrationType(integration_type), config) - -def get_system_status() -> Dict[str, Any]: - """Get enterprise integration system status""" - return { - "tenants": len(api_gateway.tenants), - "endpoints": len(api_gateway.endpoints), - "integrations": len(api_gateway.integrations), - "security_events": len(api_gateway.security_events), - "system_health": "operational" - } - -def list_tenants() -> List[Dict[str, Any]]: - """List all tenants""" - return [ - { - "tenant_id": tenant.tenant_id, - "name": tenant.name, - "domain": tenant.domain, - "status": tenant.status.value, - "features": tenant.features - } - for tenant in api_gateway.tenants.values() - ] - -def list_integrations(tenant_id: Optional[str] = None) -> List[Dict[str, Any]]: - """List integrations""" - integrations = api_gateway.integrations.values() - if tenant_id: - integrations = [i for i in integrations if i.tenant_id == tenant_id] - - return [ - { - "integration_id": i.integration_id, - "name": i.name, - "type": i.type.value, - "tenant_id": i.tenant_id, - "status": i.status, - "created_at": i.created_at.isoformat() - } - for i in integrations - ] diff --git a/apps/coordinator-api/src/app/services/enterprise_integration/__init__.py b/apps/coordinator-api/src/app/services/enterprise_integration/__init__.py index 77b647a6..9c749f5e 100644 --- a/apps/coordinator-api/src/app/services/enterprise_integration/__init__.py +++ b/apps/coordinator-api/src/app/services/enterprise_integration/__init__.py @@ -1,16 +1,14 @@ """ Enterprise Integration Bounded Context -Provides enterprise API gateway, security, load balancing, and integration services. +Provides enterprise integration, security, and load balancing services. """ -from .api_gateway import EnterpriseAPIGateway -from .integration import EnterpriseIntegrationService +from .integration import EnterpriseIntegrationFramework from .load_balancer import AdvancedLoadBalancer from .security import EnterpriseEncryption, HSMManager, ThreatDetectionSystem, ZeroTrustArchitecture __all__ = [ - "EnterpriseAPIGateway", - "EnterpriseIntegrationService", + "EnterpriseIntegrationFramework", "AdvancedLoadBalancer", "EnterpriseEncryption", "HSMManager", diff --git a/apps/coordinator-api/src/app/services/enterprise_load_balancer.py b/apps/coordinator-api/src/app/services/enterprise_load_balancer.py deleted file mode 100755 index bcdaa775..00000000 --- a/apps/coordinator-api/src/app/services/enterprise_load_balancer.py +++ /dev/null @@ -1,770 +0,0 @@ -""" -Advanced Load Balancing - Phase 6.4 Implementation -Intelligent traffic distribution with AI-powered auto-scaling and performance optimization -""" - -import statistics -from dataclasses import dataclass, field -from datetime import datetime, timezone, timedelta -from enum import StrEnum -from typing import Any - -from aitbc import get_logger - -logger = get_logger(__name__) - - -class LoadBalancingAlgorithm(StrEnum): - """Load balancing algorithms""" - - ROUND_ROBIN = "round_robin" - WEIGHTED_ROUND_ROBIN = "weighted_round_robin" - LEAST_CONNECTIONS = "least_connections" - LEAST_RESPONSE_TIME = "least_response_time" - RESOURCE_BASED = "resource_based" - PREDICTIVE_AI = "predictive_ai" - ADAPTIVE = "adaptive" - - -class ScalingPolicy(StrEnum): - """Auto-scaling policies""" - - MANUAL = "manual" - THRESHOLD_BASED = "threshold_based" - PREDICTIVE = "predictive" - HYBRID = "hybrid" - - -class HealthStatus(StrEnum): - """Health status""" - - HEALTHY = "healthy" - UNHEALTHY = "unhealthy" - DRAINING = "draining" - MAINTENANCE = "maintenance" - - -@dataclass -class BackendServer: - """Backend server configuration""" - - server_id: str - host: str - port: int - weight: float = 1.0 - max_connections: int = 1000 - current_connections: int = 0 - cpu_usage: float = 0.0 - memory_usage: float = 0.0 - response_time_ms: float = 0.0 - request_count: int = 0 - error_count: int = 0 - health_status: HealthStatus = HealthStatus.HEALTHY - last_health_check: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) - capabilities: dict[str, Any] = field(default_factory=dict) - region: str = "default" - created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) - - -@dataclass -class ScalingMetric: - """Scaling metric configuration""" - - metric_name: str - threshold_min: float - threshold_max: float - scaling_factor: float - cooldown_period: timedelta - measurement_window: timedelta - - -@dataclass -class TrafficPattern: - """Traffic pattern for predictive scaling""" - - pattern_id: str - name: str - time_windows: list[dict[str, Any]] # List of time windows with expected load - day_of_week: int # 0-6 (Monday-Sunday) - seasonal_factor: float = 1.0 - confidence_score: float = 0.0 - - -class PredictiveScaler: - """AI-powered predictive auto-scaling""" - - def __init__(self): - self.traffic_history = [] - self.scaling_predictions = {} - self.traffic_patterns = {} - self.model_weights = {} - self.logger = get_logger("predictive_scaler") - - async def record_traffic(self, timestamp: datetime, request_count: int, response_time_ms: float, error_rate: float): - """Record traffic metrics""" - - traffic_record = { - "timestamp": timestamp, - "request_count": request_count, - "response_time_ms": response_time_ms, - "error_rate": error_rate, - "hour": timestamp.hour, - "day_of_week": timestamp.weekday(), - "day_of_month": timestamp.day, - "month": timestamp.month, - } - - self.traffic_history.append(traffic_record) - - # Keep only last 30 days of history - cutoff = datetime.now(timezone.utc) - timedelta(days=30) - self.traffic_history = [record for record in self.traffic_history if record["timestamp"] > cutoff] - - # Update traffic patterns - await self._update_traffic_patterns() - - async def _update_traffic_patterns(self): - """Update traffic patterns based on historical data""" - - if len(self.traffic_history) < 168: # Need at least 1 week of data - return - - # Group by hour and day of week - patterns = {} - - for record in self.traffic_history: - key = f"{record['day_of_week']}_{record['hour']}" - - if key not in patterns: - patterns[key] = {"request_counts": [], "response_times": [], "error_rates": []} - - patterns[key]["request_counts"].append(record["request_count"]) - patterns[key]["response_times"].append(record["response_time_ms"]) - patterns[key]["error_rates"].append(record["error_rate"]) - - # Calculate pattern statistics - for key, data in patterns.items(): - day_of_week, hour = key.split("_") - - pattern = TrafficPattern( - pattern_id=key, - name=f"Pattern Day {day_of_week} Hour {hour}", - time_windows=[ - { - "hour": int(hour), - "avg_requests": statistics.mean(data["request_counts"]), - "max_requests": max(data["request_counts"]), - "min_requests": min(data["request_counts"]), - "std_requests": statistics.stdev(data["request_counts"]) if len(data["request_counts"]) > 1 else 0, - "avg_response_time": statistics.mean(data["response_times"]), - "avg_error_rate": statistics.mean(data["error_rates"]), - } - ], - day_of_week=int(day_of_week), - confidence_score=min(len(data["request_counts"]) / 100, 1.0), # Confidence based on data points - ) - - self.traffic_patterns[key] = pattern - - async def predict_traffic(self, prediction_window: timedelta = timedelta(hours=1)) -> dict[str, Any]: - """Predict traffic for the next time window""" - - try: - current_time = datetime.now(timezone.utc) - current_time + prediction_window - - # Get current pattern - current_pattern_key = f"{current_time.weekday()}_{current_time.hour}" - current_pattern = self.traffic_patterns.get(current_pattern_key) - - if not current_pattern: - # Fallback to simple prediction - return await self._simple_prediction(prediction_window) - - # Get historical data for similar time periods - similar_patterns = [ - pattern - for pattern in self.traffic_patterns.values() - if pattern.day_of_week == current_time.weekday() - and abs(pattern.time_windows[0]["hour"] - current_time.hour) <= 2 - ] - - if not similar_patterns: - return await self._simple_prediction(prediction_window) - - # Calculate weighted prediction - total_weight = 0 - weighted_requests = 0 - weighted_response_time = 0 - weighted_error_rate = 0 - - for pattern in similar_patterns: - weight = pattern.confidence_score - window_data = pattern.time_windows[0] - - weighted_requests += window_data["avg_requests"] * weight - weighted_response_time += window_data["avg_response_time"] * weight - weighted_error_rate += window_data["avg_error_rate"] * weight - total_weight += weight - - if total_weight > 0: - predicted_requests = weighted_requests / total_weight - predicted_response_time = weighted_response_time / total_weight - predicted_error_rate = weighted_error_rate / total_weight - else: - return await self._simple_prediction(prediction_window) - - # Apply seasonal factors - seasonal_factor = self._get_seasonal_factor(current_time) - predicted_requests *= seasonal_factor - - return { - "prediction_window_hours": prediction_window.total_seconds() / 3600, - "predicted_requests_per_hour": int(predicted_requests), - "predicted_response_time_ms": predicted_response_time, - "predicted_error_rate": predicted_error_rate, - "confidence_score": min(total_weight / len(similar_patterns), 1.0), - "seasonal_factor": seasonal_factor, - "pattern_based": True, - "prediction_timestamp": current_time.isoformat(), - } - - except Exception as e: - self.logger.error(f"Traffic prediction failed: {e}") - return await self._simple_prediction(prediction_window) - - async def _simple_prediction(self, prediction_window: timedelta) -> dict[str, Any]: - """Simple prediction based on recent averages""" - - if not self.traffic_history: - return { - "prediction_window_hours": prediction_window.total_seconds() / 3600, - "predicted_requests_per_hour": 1000, # Default - "predicted_response_time_ms": 100.0, - "predicted_error_rate": 0.01, - "confidence_score": 0.1, - "pattern_based": False, - "prediction_timestamp": datetime.now(timezone.utc).isoformat(), - } - - # Calculate recent averages - recent_records = self.traffic_history[-24:] # Last 24 records - - avg_requests = statistics.mean([r["request_count"] for r in recent_records]) - avg_response_time = statistics.mean([r["response_time_ms"] for r in recent_records]) - avg_error_rate = statistics.mean([r["error_rate"] for r in recent_records]) - - return { - "prediction_window_hours": prediction_window.total_seconds() / 3600, - "predicted_requests_per_hour": int(avg_requests), - "predicted_response_time_ms": avg_response_time, - "predicted_error_rate": avg_error_rate, - "confidence_score": 0.3, - "pattern_based": False, - "prediction_timestamp": datetime.now(timezone.utc).isoformat(), - } - - def _get_seasonal_factor(self, timestamp: datetime) -> float: - """Get seasonal adjustment factor""" - - # Simple seasonal factors (can be enhanced with more sophisticated models) - month = timestamp.month - - seasonal_factors = { - 1: 0.8, # January - post-holiday dip - 2: 0.9, # February - 3: 1.0, # March - 4: 1.1, # April - spring increase - 5: 1.2, # May - 6: 1.1, # June - 7: 1.0, # July - summer - 8: 0.9, # August - 9: 1.1, # September - back to business - 10: 1.2, # October - 11: 1.3, # November - holiday season start - 12: 1.4, # December - peak holiday season - } - - return seasonal_factors.get(month, 1.0) - - async def get_scaling_recommendation(self, current_servers: int, current_capacity: int) -> dict[str, Any]: - """Get scaling recommendation based on predictions""" - - try: - # Get traffic prediction - prediction = await self.predict_traffic(timedelta(hours=1)) - - predicted_requests = prediction["predicted_requests_per_hour"] - current_capacity_per_server = current_capacity // max(current_servers, 1) - - # Calculate required servers - required_servers = max(1, int(predicted_requests / current_capacity_per_server)) - - # Apply buffer (20% extra capacity) - required_servers = int(required_servers * 1.2) - - scaling_action = "none" - if required_servers > current_servers: - scaling_action = "scale_up" - scale_to = required_servers - elif required_servers < current_servers * 0.7: # Scale down if underutilized - scaling_action = "scale_down" - scale_to = max(1, required_servers) - else: - scale_to = current_servers - - return { - "current_servers": current_servers, - "recommended_servers": scale_to, - "scaling_action": scaling_action, - "predicted_load": predicted_requests, - "current_capacity_per_server": current_capacity_per_server, - "confidence_score": prediction["confidence_score"], - "reason": f"Predicted {predicted_requests} requests/hour vs current capacity {current_servers * current_capacity_per_server}", - "recommendation_timestamp": datetime.now(timezone.utc).isoformat(), - } - - except Exception as e: - self.logger.error(f"Scaling recommendation failed: {e}") - return { - "scaling_action": "none", - "reason": f"Prediction failed: {str(e)}", - "recommendation_timestamp": datetime.now(timezone.utc).isoformat(), - } - - -class AdvancedLoadBalancer: - """Advanced load balancer with multiple algorithms and AI optimization""" - - def __init__(self): - self.backends = {} - self.algorithm = LoadBalancingAlgorithm.ADAPTIVE - self.current_index = 0 - self.request_history = [] - self.performance_metrics = {} - self.predictive_scaler = PredictiveScaler() - self.scaling_metrics = {} - self.logger = get_logger("advanced_load_balancer") - - async def add_backend(self, server: BackendServer) -> bool: - """Add backend server""" - - try: - self.backends[server.server_id] = server - - # Initialize performance metrics - self.performance_metrics[server.server_id] = { - "avg_response_time": 0.0, - "error_rate": 0.0, - "throughput": 0.0, - "uptime": 1.0, - "last_updated": datetime.now(timezone.utc), - } - - self.logger.info(f"Backend server added: {server.server_id}") - return True - - except Exception as e: - self.logger.error(f"Failed to add backend server: {e}") - return False - - async def remove_backend(self, server_id: str) -> bool: - """Remove backend server""" - - if server_id in self.backends: - del self.backends[server_id] - del self.performance_metrics[server_id] - - self.logger.info(f"Backend server removed: {server_id}") - return True - - return False - - async def select_backend(self, request_context: dict[str, Any] | None = None) -> str | None: - """Select backend server based on algorithm""" - - try: - # Filter healthy backends - healthy_backends = { - sid: server for sid, server in self.backends.items() if server.health_status == HealthStatus.HEALTHY - } - - if not healthy_backends: - return None - - # Select backend based on algorithm - if self.algorithm == LoadBalancingAlgorithm.ROUND_ROBIN: - return await self._select_round_robin(healthy_backends) - elif self.algorithm == LoadBalancingAlgorithm.WEIGHTED_ROUND_ROBIN: - return await self._select_weighted_round_robin(healthy_backends) - elif self.algorithm == LoadBalancingAlgorithm.LEAST_CONNECTIONS: - return await self._select_least_connections(healthy_backends) - elif self.algorithm == LoadBalancingAlgorithm.LEAST_RESPONSE_TIME: - return await self._select_least_response_time(healthy_backends) - elif self.algorithm == LoadBalancingAlgorithm.RESOURCE_BASED: - return await self._select_resource_based(healthy_backends) - elif self.algorithm == LoadBalancingAlgorithm.PREDICTIVE_AI: - return await self._select_predictive_ai(healthy_backends, request_context) - elif self.algorithm == LoadBalancingAlgorithm.ADAPTIVE: - return await self._select_adaptive(healthy_backends, request_context) - else: - return await self._select_round_robin(healthy_backends) - - except Exception as e: - self.logger.error(f"Backend selection failed: {e}") - return None - - async def _select_round_robin(self, backends: dict[str, BackendServer]) -> str: - """Round robin selection""" - - backend_ids = list(backends.keys()) - - if not backend_ids: - return None - - selected = backend_ids[self.current_index % len(backend_ids)] - self.current_index += 1 - - return selected - - async def _select_weighted_round_robin(self, backends: dict[str, BackendServer]) -> str: - """Weighted round robin selection""" - - # Calculate total weight - total_weight = sum(server.weight for server in backends.values()) - - if total_weight <= 0: - return await self._select_round_robin(backends) - - # Select based on weights - import random - - rand_value = random.uniform(0, total_weight) - - current_weight = 0 - for server_id, server in backends.items(): - current_weight += server.weight - if rand_value <= current_weight: - return server_id - - # Fallback - return list(backends.keys())[0] - - async def _select_least_connections(self, backends: dict[str, BackendServer]) -> str: - """Select backend with least connections""" - - min_connections = float("inf") - selected_backend = None - - for server_id, server in backends.items(): - if server.current_connections < min_connections: - min_connections = server.current_connections - selected_backend = server_id - - return selected_backend - - async def _select_least_response_time(self, backends: dict[str, BackendServer]) -> str: - """Select backend with least response time""" - - min_response_time = float("inf") - selected_backend = None - - for server_id, server in backends.items(): - if server.response_time_ms < min_response_time: - min_response_time = server.response_time_ms - selected_backend = server_id - - return selected_backend - - async def _select_resource_based(self, backends: dict[str, BackendServer]) -> str: - """Select backend based on resource utilization""" - - best_score = -1 - selected_backend = None - - for server_id, server in backends.items(): - # Calculate resource score (lower is better) - cpu_score = 1.0 - (server.cpu_usage / 100.0) - memory_score = 1.0 - (server.memory_usage / 100.0) - connection_score = 1.0 - (server.current_connections / server.max_connections) - - # Weighted score - resource_score = cpu_score * 0.4 + memory_score * 0.3 + connection_score * 0.3 - - if resource_score > best_score: - best_score = resource_score - selected_backend = server_id - - return selected_backend - - async def _select_predictive_ai( - self, backends: dict[str, BackendServer], request_context: dict[str, Any] | None - ) -> str: - """AI-powered predictive selection""" - - # Get performance predictions for each backend - backend_scores = {} - - for server_id, server in backends.items(): - # Predict performance based on historical data - self.performance_metrics.get(server_id, {}) - - # Calculate predicted response time - predicted_response_time = ( - server.response_time_ms - * (1 + server.cpu_usage / 100) - * (1 + server.memory_usage / 100) - * (1 + server.current_connections / server.max_connections) - ) - - # Calculate score (lower response time is better) - score = 1.0 / (1.0 + predicted_response_time / 100.0) - - # Apply context-based adjustments - if request_context: - # Consider request type, user location, etc. - context_multiplier = await self._calculate_context_multiplier(server, request_context) - score *= context_multiplier - - backend_scores[server_id] = score - - # Select best scoring backend - if backend_scores: - return max(backend_scores, key=backend_scores.get) - - return await self._select_least_connections(backends) - - async def _select_adaptive(self, backends: dict[str, BackendServer], request_context: dict[str, Any] | None) -> str: - """Adaptive selection based on current conditions""" - - # Analyze current system state - total_connections = sum(server.current_connections for server in backends.values()) - avg_response_time = statistics.mean([server.response_time_ms for server in backends.values()]) - - # Choose algorithm based on conditions - if total_connections > sum(server.max_connections for server in backends.values()) * 0.8: - # High load - use resource-based - return await self._select_resource_based(backends) - elif avg_response_time > 200: - # High latency - use least response time - return await self._select_least_response_time(backends) - else: - # Normal conditions - use weighted round robin - return await self._select_weighted_round_robin(backends) - - async def _calculate_context_multiplier(self, server: BackendServer, request_context: dict[str, Any]) -> float: - """Calculate context-based multiplier for backend selection""" - - multiplier = 1.0 - - # Consider geographic location - if "user_location" in request_context and "region" in server.capabilities: - user_region = request_context["user_location"].get("region") - server_region = server.capabilities["region"] - - if user_region == server_region: - multiplier *= 1.2 # Prefer same region - elif self._regions_in_same_continent(user_region, server_region): - multiplier *= 1.1 # Slight preference for same continent - - # Consider request type - request_type = request_context.get("request_type", "general") - server_specializations = server.capabilities.get("specializations", []) - - if request_type in server_specializations: - multiplier *= 1.3 # Strong preference for specialized backends - - # Consider user tier - user_tier = request_context.get("user_tier", "standard") - if user_tier == "premium" and server.capabilities.get("premium_support", False): - multiplier *= 1.15 - - return multiplier - - def _regions_in_same_continent(self, region1: str, region2: str) -> bool: - """Check if two regions are in the same continent""" - - continent_mapping = { - "NA": ["US", "CA", "MX"], - "EU": ["GB", "DE", "FR", "IT", "ES", "NL", "BE", "AT", "CH", "SE", "NO", "DK", "FI"], - "APAC": ["JP", "KR", "SG", "AU", "IN", "TH", "MY", "ID", "PH", "VN"], - "LATAM": ["BR", "MX", "AR", "CL", "CO", "PE", "VE"], - } - - for _continent, regions in continent_mapping.items(): - if region1 in regions and region2 in regions: - return True - - return False - - async def record_request( - self, server_id: str, response_time_ms: float, success: bool, timestamp: datetime | None = None - ): - """Record request metrics""" - - if timestamp is None: - timestamp = datetime.now(timezone.utc) - - # Update backend server metrics - if server_id in self.backends: - server = self.backends[server_id] - server.request_count += 1 - server.response_time_ms = server.response_time_ms * 0.9 + response_time_ms * 0.1 # EMA - - if not success: - server.error_count += 1 - - # Record in history - request_record = { - "timestamp": timestamp, - "server_id": server_id, - "response_time_ms": response_time_ms, - "success": success, - } - - self.request_history.append(request_record) - - # Keep only last 10000 records - if len(self.request_history) > 10000: - self.request_history = self.request_history[-10000:] - - # Update predictive scaler - await self.predictive_scaler.record_traffic( - timestamp, 1, response_time_ms, 0.0 if success else 1.0 # One request # Error rate - ) - - async def update_backend_health( - self, server_id: str, health_status: HealthStatus, cpu_usage: float, memory_usage: float, current_connections: int - ): - """Update backend health metrics""" - - if server_id in self.backends: - server = self.backends[server_id] - server.health_status = health_status - server.cpu_usage = cpu_usage - server.memory_usage = memory_usage - server.current_connections = current_connections - server.last_health_check = datetime.now(timezone.utc) - - async def get_load_balancing_metrics(self) -> dict[str, Any]: - """Get comprehensive load balancing metrics""" - - try: - total_requests = sum(server.request_count for server in self.backends.values()) - total_errors = sum(server.error_count for server in self.backends.values()) - total_connections = sum(server.current_connections for server in self.backends.values()) - - error_rate = (total_errors / total_requests) if total_requests > 0 else 0.0 - - # Calculate average response time - avg_response_time = 0.0 - if self.backends: - avg_response_time = statistics.mean([server.response_time_ms for server in self.backends.values()]) - - # Backend distribution - backend_distribution = {} - for server_id, server in self.backends.items(): - backend_distribution[server_id] = { - "requests": server.request_count, - "errors": server.error_count, - "connections": server.current_connections, - "response_time_ms": server.response_time_ms, - "cpu_usage": server.cpu_usage, - "memory_usage": server.memory_usage, - "health_status": server.health_status.value, - "weight": server.weight, - } - - # Get scaling recommendation - scaling_recommendation = await self.predictive_scaler.get_scaling_recommendation( - len(self.backends), sum(server.max_connections for server in self.backends.values()) - ) - - return { - "total_backends": len(self.backends), - "healthy_backends": len([s for s in self.backends.values() if s.health_status == HealthStatus.HEALTHY]), - "total_requests": total_requests, - "total_errors": total_errors, - "error_rate": error_rate, - "average_response_time_ms": avg_response_time, - "total_connections": total_connections, - "algorithm": self.algorithm.value, - "backend_distribution": backend_distribution, - "scaling_recommendation": scaling_recommendation, - "timestamp": datetime.now(timezone.utc).isoformat(), - } - - except Exception as e: - self.logger.error(f"Metrics retrieval failed: {e}") - return {"error": str(e)} - - async def set_algorithm(self, algorithm: LoadBalancingAlgorithm): - """Set load balancing algorithm""" - - self.algorithm = algorithm - self.logger.info(f"Load balancing algorithm changed to: {algorithm.value}") - - async def auto_scale(self, min_servers: int = 1, max_servers: int = 10) -> dict[str, Any]: - """Perform auto-scaling based on predictions""" - - try: - # Get scaling recommendation - recommendation = await self.predictive_scaler.get_scaling_recommendation( - len(self.backends), sum(server.max_connections for server in self.backends.values()) - ) - - action = recommendation["scaling_action"] - target_servers = recommendation["recommended_servers"] - - # Apply scaling limits - target_servers = max(min_servers, min(max_servers, target_servers)) - - scaling_result = { - "action": action, - "current_servers": len(self.backends), - "target_servers": target_servers, - "confidence": recommendation.get("confidence_score", 0.0), - "reason": recommendation.get("reason", ""), - "timestamp": datetime.now(timezone.utc).isoformat(), - } - - # In production, implement actual scaling logic here - # For now, just return the recommendation - - self.logger.info(f"Auto-scaling recommendation: {action} to {target_servers} servers") - - return scaling_result - - except Exception as e: - self.logger.error(f"Auto-scaling failed: {e}") - return {"error": str(e)} - - -# Global load balancer instance -advanced_load_balancer = None - - -async def get_advanced_load_balancer() -> AdvancedLoadBalancer: - """Get or create global advanced load balancer""" - - global advanced_load_balancer - if advanced_load_balancer is None: - advanced_load_balancer = AdvancedLoadBalancer() - - # Add default backends - default_backends = [ - BackendServer( - server_id="backend_1", host="10.0.1.10", port=8080, weight=1.0, max_connections=1000, region="us_east" - ), - BackendServer( - server_id="backend_2", host="10.0.1.11", port=8080, weight=1.0, max_connections=1000, region="us_east" - ), - BackendServer( - server_id="backend_3", host="10.0.1.12", port=8080, weight=0.8, max_connections=800, region="eu_west" - ), - ] - - for backend in default_backends: - await advanced_load_balancer.add_backend(backend) - - return advanced_load_balancer diff --git a/apps/coordinator-api/src/app/services/enterprise_security.py b/apps/coordinator-api/src/app/services/enterprise_security.py deleted file mode 100755 index 258beb32..00000000 --- a/apps/coordinator-api/src/app/services/enterprise_security.py +++ /dev/null @@ -1,773 +0,0 @@ -""" -Enterprise Security Framework - Phase 6.2 Implementation -Zero-trust architecture with HSM integration and advanced security controls -""" - -import secrets -from dataclasses import dataclass, field -from datetime import datetime, timezone, timedelta -from enum import StrEnum -from typing import Any -from uuid import uuid4 - -import cryptography -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes - -from aitbc import get_logger - -logger = get_logger(__name__) - - -class SecurityLevel(StrEnum): - """Security levels for enterprise data""" - - PUBLIC = "public" - INTERNAL = "internal" - CONFIDENTIAL = "confidential" - RESTRICTED = "restricted" - TOP_SECRET = "top_secret" - - -class EncryptionAlgorithm(StrEnum): - """Encryption algorithms""" - - AES_256_GCM = "aes_256_gcm" - CHACHA20_POLY1305 = "chacha20_polyy1305" - AES_256_CBC = "aes_256_cbc" - QUANTUM_RESISTANT = "quantum_resistant" - - -class ThreatLevel(StrEnum): - """Threat levels for security monitoring""" - - LOW = "low" - MEDIUM = "medium" - HIGH = "high" - CRITICAL = "critical" - - -@dataclass -class SecurityPolicy: - """Security policy configuration""" - - policy_id: str - name: str - security_level: SecurityLevel - encryption_algorithm: EncryptionAlgorithm - key_rotation_interval: timedelta - access_control_requirements: list[str] - audit_requirements: list[str] - retention_period: timedelta - created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) - updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) - - -@dataclass -class SecurityEvent: - """Security event for monitoring""" - - event_id: str - event_type: str - severity: ThreatLevel - source: str - timestamp: datetime - user_id: str | None - resource_id: str | None - details: dict[str, Any] - resolved: bool = False - resolution_notes: str | None = None - - -class HSMManager: - """Hardware Security Module manager for enterprise key management""" - - def __init__(self, hsm_config: dict[str, Any]): - self.hsm_config = hsm_config - self.backend = default_backend() - self.key_store = {} # In production, use actual HSM - self.logger = get_logger("hsm_manager") - - async def initialize(self) -> bool: - """Initialize HSM connection""" - try: - # In production, initialize actual HSM connection - # For now, simulate HSM initialization - self.logger.info("HSM manager initialized") - return True - except Exception as e: - self.logger.error(f"HSM initialization failed: {e}") - return False - - async def generate_key(self, key_id: str, algorithm: EncryptionAlgorithm, key_size: int = 256) -> dict[str, Any]: - """Generate encryption key in HSM""" - - try: - if algorithm == EncryptionAlgorithm.AES_256_GCM: - key = secrets.token_bytes(32) # 256 bits - iv = secrets.token_bytes(12) # 96 bits for GCM - elif algorithm == EncryptionAlgorithm.CHACHA20_POLY1305: - key = secrets.token_bytes(32) # 256 bits - nonce = secrets.token_bytes(12) # 96 bits - elif algorithm == EncryptionAlgorithm.AES_256_CBC: - key = secrets.token_bytes(32) # 256 bits - iv = secrets.token_bytes(16) # 128 bits for CBC - else: - raise ValueError(f"Unsupported algorithm: {algorithm}") - - # Store key in HSM (simulated) - key_data = { - "key_id": key_id, - "algorithm": algorithm.value, - "key": key, - "iv": iv if algorithm in [EncryptionAlgorithm.AES_256_GCM, EncryptionAlgorithm.AES_256_CBC] else None, - "nonce": nonce if algorithm == EncryptionAlgorithm.CHACHA20_POLY1305 else None, - "created_at": datetime.now(timezone.utc), - "key_size": key_size, - } - - self.key_store[key_id] = key_data - - self.logger.info(f"Key generated in HSM: {key_id}") - return key_data - - except Exception as e: - self.logger.error(f"Key generation failed: {e}") - raise - - async def get_key(self, key_id: str) -> dict[str, Any] | None: - """Get key from HSM""" - return self.key_store.get(key_id) - - async def rotate_key(self, key_id: str) -> dict[str, Any]: - """Rotate encryption key""" - - old_key = self.key_store.get(key_id) - if not old_key: - raise ValueError(f"Key not found: {key_id}") - - # Generate new key - new_key = await self.generate_key(f"{key_id}_new", EncryptionAlgorithm(old_key["algorithm"]), old_key["key_size"]) - - # Update key with rotation timestamp - new_key["rotated_from"] = key_id - new_key["rotation_timestamp"] = datetime.now(timezone.utc) - - return new_key - - async def delete_key(self, key_id: str) -> bool: - """Delete key from HSM""" - if key_id in self.key_store: - del self.key_store[key_id] - self.logger.info(f"Key deleted from HSM: {key_id}") - return True - return False - - -class EnterpriseEncryption: - """Enterprise-grade encryption service""" - - def __init__(self, hsm_manager: HSMManager): - self.hsm_manager = hsm_manager - self.backend = default_backend() - self.logger = get_logger("enterprise_encryption") - - async def encrypt_data( - self, data: str | bytes, key_id: str, associated_data: bytes | None = None - ) -> dict[str, Any]: - """Encrypt data using enterprise-grade encryption""" - - try: - # Get key from HSM - key_data = await self.hsm_manager.get_key(key_id) - if not key_data: - raise ValueError(f"Key not found: {key_id}") - - # Convert data to bytes if needed - if isinstance(data, str): - data = data.encode("utf-8") - - algorithm = EncryptionAlgorithm(key_data["algorithm"]) - - if algorithm == EncryptionAlgorithm.AES_256_GCM: - return await self._encrypt_aes_gcm(data, key_data, associated_data) - elif algorithm == EncryptionAlgorithm.CHACHA20_POLY1305: - return await self._encrypt_chacha20(data, key_data, associated_data) - elif algorithm == EncryptionAlgorithm.AES_256_CBC: - return await self._encrypt_aes_cbc(data, key_data) - else: - raise ValueError(f"Unsupported encryption algorithm: {algorithm}") - - except Exception as e: - self.logger.error(f"Encryption failed: {e}") - raise - - async def _encrypt_aes_gcm( - self, data: bytes, key_data: dict[str, Any], associated_data: bytes | None = None - ) -> dict[str, Any]: - """Encrypt using AES-256-GCM""" - - key = key_data["key"] - iv = key_data["iv"] - - # Create cipher - cipher = Cipher(algorithms.AES(key), modes.GCM(iv), backend=self.backend) - - encryptor = cipher.encryptor() - - # Add associated data if provided - if associated_data: - encryptor.authenticate_additional_data(associated_data) - - # Encrypt data - ciphertext = encryptor.update(data) + encryptor.finalize() - - return { - "ciphertext": ciphertext.hex(), - "iv": iv.hex(), - "tag": encryptor.tag.hex(), - "algorithm": "aes_256_gcm", - "key_id": key_data["key_id"], - } - - async def _encrypt_chacha20( - self, data: bytes, key_data: dict[str, Any], associated_data: bytes | None = None - ) -> dict[str, Any]: - """Encrypt using ChaCha20-Poly1305""" - - key = key_data["key"] - nonce = key_data["nonce"] - - # Create cipher - cipher = Cipher(algorithms.ChaCha20(key, nonce), modes.Poly1305(b""), backend=self.backend) - - encryptor = cipher.encryptor() - - # Add associated data if provided - if associated_data: - encryptor.authenticate_additional_data(associated_data) - - # Encrypt data - ciphertext = encryptor.update(data) + encryptor.finalize() - - return { - "ciphertext": ciphertext.hex(), - "nonce": nonce.hex(), - "tag": encryptor.tag.hex(), - "algorithm": "chacha20_poly1305", - "key_id": key_data["key_id"], - } - - async def _encrypt_aes_cbc(self, data: bytes, key_data: dict[str, Any]) -> dict[str, Any]: - """Encrypt using AES-256-CBC""" - - key = key_data["key"] - iv = key_data["iv"] - - # Pad data to block size - padder = cryptography.hazmat.primitives.padding.PKCS7(128).padder() - padded_data = padder.update(data) + padder.finalize() - - # Create cipher - cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=self.backend) - - encryptor = cipher.encryptor() - ciphertext = encryptor.update(padded_data) + encryptor.finalize() - - return {"ciphertext": ciphertext.hex(), "iv": iv.hex(), "algorithm": "aes_256_cbc", "key_id": key_data["key_id"]} - - async def decrypt_data(self, encrypted_data: dict[str, Any], associated_data: bytes | None = None) -> bytes: - """Decrypt encrypted data""" - - try: - algorithm = encrypted_data["algorithm"] - - if algorithm == "aes_256_gcm": - return await self._decrypt_aes_gcm(encrypted_data, associated_data) - elif algorithm == "chacha20_poly1305": - return await self._decrypt_chacha20(encrypted_data, associated_data) - elif algorithm == "aes_256_cbc": - return await self._decrypt_aes_cbc(encrypted_data) - else: - raise ValueError(f"Unsupported encryption algorithm: {algorithm}") - - except Exception as e: - self.logger.error(f"Decryption failed: {e}") - raise - - async def _decrypt_aes_gcm(self, encrypted_data: dict[str, Any], associated_data: bytes | None = None) -> bytes: - """Decrypt AES-256-GCM encrypted data""" - - # Get key from HSM - key_data = await self.hsm_manager.get_key(encrypted_data["key_id"]) - if not key_data: - raise ValueError(f"Key not found: {encrypted_data['key_id']}") - - key = key_data["key"] - iv = bytes.fromhex(encrypted_data["iv"]) - ciphertext = bytes.fromhex(encrypted_data["ciphertext"]) - tag = bytes.fromhex(encrypted_data["tag"]) - - # Create cipher - cipher = Cipher(algorithms.AES(key), modes.GCM(iv, tag), backend=self.backend) - - decryptor = cipher.decryptor() - - # Add associated data if provided - if associated_data: - decryptor.authenticate_additional_data(associated_data) - - # Decrypt data - plaintext = decryptor.update(ciphertext) + decryptor.finalize() - - return plaintext - - async def _decrypt_chacha20(self, encrypted_data: dict[str, Any], associated_data: bytes | None = None) -> bytes: - """Decrypt ChaCha20-Poly1305 encrypted data""" - - # Get key from HSM - key_data = await self.hsm_manager.get_key(encrypted_data["key_id"]) - if not key_data: - raise ValueError(f"Key not found: {encrypted_data['key_id']}") - - key = key_data["key"] - nonce = bytes.fromhex(encrypted_data["nonce"]) - ciphertext = bytes.fromhex(encrypted_data["ciphertext"]) - tag = bytes.fromhex(encrypted_data["tag"]) - - # Create cipher - cipher = Cipher(algorithms.ChaCha20(key, nonce), modes.Poly1305(tag), backend=self.backend) - - decryptor = cipher.decryptor() - - # Add associated data if provided - if associated_data: - decryptor.authenticate_additional_data(associated_data) - - # Decrypt data - plaintext = decryptor.update(ciphertext) + decryptor.finalize() - - return plaintext - - async def _decrypt_aes_cbc(self, encrypted_data: dict[str, Any]) -> bytes: - """Decrypt AES-256-CBC encrypted data""" - - # Get key from HSM - key_data = await self.hsm_manager.get_key(encrypted_data["key_id"]) - if not key_data: - raise ValueError(f"Key not found: {encrypted_data['key_id']}") - - key = key_data["key"] - iv = bytes.fromhex(encrypted_data["iv"]) - ciphertext = bytes.fromhex(encrypted_data["ciphertext"]) - - # Create cipher - cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=self.backend) - - decryptor = cipher.decryptor() - padded_plaintext = decryptor.update(ciphertext) + decryptor.finalize() - - # Unpad data - unpadder = cryptography.hazmat.primitives.padding.PKCS7(128).unpadder() - plaintext = unpadder.update(padded_plaintext) + unpadder.finalize() - - return plaintext - - -class ZeroTrustArchitecture: - """Zero-trust security architecture implementation""" - - def __init__(self, hsm_manager: HSMManager, encryption: EnterpriseEncryption): - self.hsm_manager = hsm_manager - self.encryption = encryption - self.trust_policies = {} - self.session_tokens = {} - self.logger = get_logger("zero_trust") - - async def create_trust_policy(self, policy_id: str, policy_config: dict[str, Any]) -> bool: - """Create zero-trust policy""" - - try: - policy = SecurityPolicy( - policy_id=policy_id, - name=policy_config["name"], - security_level=SecurityLevel(policy_config["security_level"]), - encryption_algorithm=EncryptionAlgorithm(policy_config["encryption_algorithm"]), - key_rotation_interval=timedelta(days=policy_config.get("key_rotation_days", 90)), - access_control_requirements=policy_config.get("access_control_requirements", []), - audit_requirements=policy_config.get("audit_requirements", []), - retention_period=timedelta(days=policy_config.get("retention_days", 2555)), # 7 years - ) - - self.trust_policies[policy_id] = policy - - # Generate encryption key for policy - await self.hsm_manager.generate_key(f"policy_{policy_id}", policy.encryption_algorithm) - - self.logger.info(f"Zero-trust policy created: {policy_id}") - return True - - except Exception as e: - self.logger.error(f"Failed to create trust policy: {e}") - return False - - async def verify_trust(self, user_id: str, resource_id: str, action: str, context: dict[str, Any]) -> bool: - """Verify zero-trust access request""" - - try: - # Get applicable policy - policy_id = context.get("policy_id", "default") - policy = self.trust_policies.get(policy_id) - - if not policy: - self.logger.warning(f"No policy found for {policy_id}") - return False - - # Verify trust factors - trust_score = await self._calculate_trust_score(user_id, resource_id, action, context) - - # Check if trust score meets policy requirements - min_trust_score = self._get_min_trust_score(policy.security_level) - - is_trusted = trust_score >= min_trust_score - - # Log trust decision - await self._log_trust_decision(user_id, resource_id, action, trust_score, is_trusted) - - return is_trusted - - except Exception as e: - self.logger.error(f"Trust verification failed: {e}") - return False - - async def _calculate_trust_score(self, user_id: str, resource_id: str, action: str, context: dict[str, Any]) -> float: - """Calculate trust score for access request""" - - score = 0.0 - - # User authentication factor (40%) - auth_strength = context.get("auth_strength", "password") - if auth_strength == "mfa": - score += 0.4 - elif auth_strength == "password": - score += 0.2 - - # Device trust factor (20%) - device_trust = context.get("device_trust", 0.5) - score += 0.2 * device_trust - - # Location factor (15%) - location_trust = context.get("location_trust", 0.5) - score += 0.15 * location_trust - - # Time factor (10%) - time_trust = context.get("time_trust", 0.5) - score += 0.1 * time_trust - - # Behavioral factor (15%) - behavior_trust = context.get("behavior_trust", 0.5) - score += 0.15 * behavior_trust - - return min(score, 1.0) - - def _get_min_trust_score(self, security_level: SecurityLevel) -> float: - """Get minimum trust score for security level""" - - thresholds = { - SecurityLevel.PUBLIC: 0.0, - SecurityLevel.INTERNAL: 0.3, - SecurityLevel.CONFIDENTIAL: 0.6, - SecurityLevel.RESTRICTED: 0.8, - SecurityLevel.TOP_SECRET: 0.9, - } - - return thresholds.get(security_level, 0.5) - - async def _log_trust_decision(self, user_id: str, resource_id: str, action: str, trust_score: float, decision: bool): - """Log trust decision for audit""" - - SecurityEvent( - event_id=str(uuid4()), - event_type="trust_decision", - severity=ThreatLevel.LOW if decision else ThreatLevel.MEDIUM, - source="zero_trust", - timestamp=datetime.now(timezone.utc), - user_id=user_id, - resource_id=resource_id, - details={"action": action, "trust_score": trust_score, "decision": decision}, - ) - - # In production, send to security monitoring system - self.logger.info(f"Trust decision: {user_id} -> {resource_id} = {decision} (score: {trust_score})") - - -class ThreatDetectionSystem: - """Advanced threat detection and response system""" - - def __init__(self): - self.threat_patterns = {} - self.active_threats = {} - self.response_actions = {} - self.logger = get_logger("threat_detection") - - async def register_threat_pattern(self, pattern_id: str, pattern_config: dict[str, Any]): - """Register threat detection pattern""" - - self.threat_patterns[pattern_id] = { - "id": pattern_id, - "name": pattern_config["name"], - "description": pattern_config["description"], - "indicators": pattern_config["indicators"], - "severity": ThreatLevel(pattern_config["severity"]), - "response_actions": pattern_config.get("response_actions", []), - "threshold": pattern_config.get("threshold", 1.0), - } - - self.logger.info(f"Threat pattern registered: {pattern_id}") - - async def analyze_threat(self, event_data: dict[str, Any]) -> list[SecurityEvent]: - """Analyze event for potential threats""" - - detected_threats = [] - - for pattern_id, pattern in self.threat_patterns.items(): - threat_score = await self._calculate_threat_score(event_data, pattern) - - if threat_score >= pattern["threshold"]: - threat_event = SecurityEvent( - event_id=str(uuid4()), - event_type="threat_detected", - severity=pattern["severity"], - source="threat_detection", - timestamp=datetime.now(timezone.utc), - user_id=event_data.get("user_id"), - resource_id=event_data.get("resource_id"), - details={ - "pattern_id": pattern_id, - "pattern_name": pattern["name"], - "threat_score": threat_score, - "indicators": event_data, - }, - ) - - detected_threats.append(threat_event) - - # Trigger response actions - await self._trigger_response_actions(pattern_id, threat_event) - - return detected_threats - - async def _calculate_threat_score(self, event_data: dict[str, Any], pattern: dict[str, Any]) -> float: - """Calculate threat score for pattern""" - - score = 0.0 - indicators = pattern["indicators"] - - for indicator, weight in indicators.items(): - if indicator in event_data: - # Simple scoring - in production, use more sophisticated algorithms - indicator_score = 0.5 # Base score for presence - score += indicator_score * weight - - return min(score, 1.0) - - async def _trigger_response_actions(self, pattern_id: str, threat_event: SecurityEvent): - """Trigger automated response actions""" - - pattern = self.threat_patterns[pattern_id] - actions = pattern.get("response_actions", []) - - for action in actions: - try: - await self._execute_response_action(action, threat_event) - except Exception as e: - self.logger.error(f"Response action failed: {action} - {e}") - - async def _execute_response_action(self, action: str, threat_event: SecurityEvent): - """Execute specific response action""" - - if action == "block_user": - await self._block_user(threat_event.user_id) - elif action == "isolate_resource": - await self._isolate_resource(threat_event.resource_id) - elif action == "escalate_to_admin": - await self._escalate_to_admin(threat_event) - elif action == "require_mfa": - await self._require_mfa(threat_event.user_id) - - self.logger.info(f"Response action executed: {action}") - - async def _block_user(self, user_id: str): - """Block user account""" - # In production, implement actual user blocking - self.logger.warning(f"User blocked due to threat: {user_id}") - - async def _isolate_resource(self, resource_id: str): - """Isolate compromised resource""" - # In production, implement actual resource isolation - self.logger.warning(f"Resource isolated due to threat: {resource_id}") - - async def _escalate_to_admin(self, threat_event: SecurityEvent): - """Escalate threat to security administrators""" - # In production, implement actual escalation - self.logger.error(f"Threat escalated to admin: {threat_event.event_id}") - - async def _require_mfa(self, user_id: str): - """Require multi-factor authentication""" - # In production, implement MFA requirement - self.logger.warning(f"MFA required for user: {user_id}") - - -class EnterpriseSecurityFramework: - """Main enterprise security framework""" - - def __init__(self, hsm_config: dict[str, Any]): - self.hsm_manager = HSMManager(hsm_config) - self.encryption = EnterpriseEncryption(self.hsm_manager) - self.zero_trust = ZeroTrustArchitecture(self.hsm_manager, self.encryption) - self.threat_detection = ThreatDetectionSystem() - self.logger = get_logger("enterprise_security") - - async def initialize(self) -> bool: - """Initialize security framework""" - - try: - # Initialize HSM - if not await self.hsm_manager.initialize(): - return False - - # Register default threat patterns - await self._register_default_threat_patterns() - - # Create default trust policies - await self._create_default_policies() - - self.logger.info("Enterprise security framework initialized") - return True - - except Exception as e: - self.logger.error(f"Security framework initialization failed: {e}") - return False - - async def _register_default_threat_patterns(self): - """Register default threat detection patterns""" - - patterns = [ - { - "name": "Brute Force Attack", - "description": "Multiple failed login attempts", - "indicators": {"failed_login_attempts": 0.8, "short_time_interval": 0.6}, - "severity": "high", - "threshold": 0.7, - "response_actions": ["block_user", "require_mfa"], - }, - { - "name": "Suspicious Access Pattern", - "description": "Unusual access patterns", - "indicators": {"unusual_location": 0.7, "unusual_time": 0.5, "high_frequency": 0.6}, - "severity": "medium", - "threshold": 0.6, - "response_actions": ["require_mfa", "escalate_to_admin"], - }, - { - "name": "Data Exfiltration", - "description": "Large data transfer patterns", - "indicators": {"large_data_transfer": 0.9, "unusual_destination": 0.7}, - "severity": "critical", - "threshold": 0.8, - "response_actions": ["block_user", "isolate_resource", "escalate_to_admin"], - }, - ] - - for i, pattern in enumerate(patterns): - await self.threat_detection.register_threat_pattern(f"default_{i}", pattern) - - async def _create_default_policies(self): - """Create default trust policies""" - - policies = [ - { - "name": "Enterprise Data Policy", - "security_level": "confidential", - "encryption_algorithm": "aes_256_gcm", - "key_rotation_days": 90, - "access_control_requirements": ["mfa", "device_trust"], - "audit_requirements": ["full_audit", "real_time_monitoring"], - "retention_days": 2555, - }, - { - "name": "Public API Policy", - "security_level": "public", - "encryption_algorithm": "aes_256_gcm", - "key_rotation_days": 180, - "access_control_requirements": ["api_key"], - "audit_requirements": ["api_access_log"], - "retention_days": 365, - }, - ] - - for i, policy in enumerate(policies): - await self.zero_trust.create_trust_policy(f"default_{i}", policy) - - async def encrypt_sensitive_data(self, data: str | bytes, security_level: SecurityLevel) -> dict[str, Any]: - """Encrypt sensitive data with appropriate security level""" - - # Get policy for security level - policy_id = f"default_{0 if security_level == SecurityLevel.PUBLIC else 1}" - policy = self.zero_trust.trust_policies.get(policy_id) - - if not policy: - raise ValueError(f"No policy found for security level: {security_level}") - - key_id = f"policy_{policy_id}" - - return await self.encryption.encrypt_data(data, key_id) - - async def verify_access(self, user_id: str, resource_id: str, action: str, context: dict[str, Any]) -> bool: - """Verify access using zero-trust architecture""" - - return await self.zero_trust.verify_trust(user_id, resource_id, action, context) - - async def analyze_security_event(self, event_data: dict[str, Any]) -> list[SecurityEvent]: - """Analyze security event for threats""" - - return await self.threat_detection.analyze_threat(event_data) - - async def rotate_encryption_keys(self, policy_id: str | None = None) -> dict[str, Any]: - """Rotate encryption keys""" - - if policy_id: - # Rotate specific policy key - old_key_id = f"policy_{policy_id}" - new_key = await self.hsm_manager.rotate_key(old_key_id) - return {"rotated_key": new_key} - else: - # Rotate all keys - rotated_keys = {} - for policy_id in self.zero_trust.trust_policies.keys(): - old_key_id = f"policy_{policy_id}" - new_key = await self.hsm_manager.rotate_key(old_key_id) - rotated_keys[policy_id] = new_key - - return {"rotated_keys": rotated_keys} - - -# Global security framework instance -security_framework = None - - -async def get_security_framework() -> EnterpriseSecurityFramework: - """Get or create global security framework""" - - global security_framework - if security_framework is None: - hsm_config = {"provider": "software", "endpoint": "localhost:8080"} # In production, use actual HSM - - security_framework = EnterpriseSecurityFramework(hsm_config) - await security_framework.initialize() - - return security_framework - - -# Alias for CLI compatibility -EnterpriseSecurityManager = EnterpriseSecurityFramework diff --git a/apps/coordinator-api/src/app/services/trading_marketplace/__init__.py b/apps/coordinator-api/src/app/services/trading_marketplace/__init__.py new file mode 100644 index 00000000..531d0f63 --- /dev/null +++ b/apps/coordinator-api/src/app/services/trading_marketplace/__init__.py @@ -0,0 +1,17 @@ +""" +Trading & Marketplace Bounded Context +Provides trading, marketplace optimization, bid strategy, and dynamic pricing services. +""" + +from .bid_strategy import BidStrategyEngine +from .dynamic_pricing import DynamicPricingEngine +from .gpu_optimizer import MarketplaceGPUOptimizer +from .trading import MatchingEngine, NegotiationSystem + +__all__ = [ + "BidStrategyEngine", + "DynamicPricingEngine", + "MarketplaceGPUOptimizer", + "MatchingEngine", + "NegotiationSystem", +] diff --git a/apps/coordinator-api/src/app/services/amm_service.py b/apps/coordinator-api/src/app/services/trading_marketplace/amm.py similarity index 100% rename from apps/coordinator-api/src/app/services/amm_service.py rename to apps/coordinator-api/src/app/services/trading_marketplace/amm.py diff --git a/apps/coordinator-api/src/app/services/bid_strategy_engine.py b/apps/coordinator-api/src/app/services/trading_marketplace/bid_strategy.py similarity index 100% rename from apps/coordinator-api/src/app/services/bid_strategy_engine.py rename to apps/coordinator-api/src/app/services/trading_marketplace/bid_strategy.py diff --git a/apps/coordinator-api/src/app/services/dynamic_pricing_engine.py b/apps/coordinator-api/src/app/services/trading_marketplace/dynamic_pricing.py similarity index 100% rename from apps/coordinator-api/src/app/services/dynamic_pricing_engine.py rename to apps/coordinator-api/src/app/services/trading_marketplace/dynamic_pricing.py diff --git a/apps/coordinator-api/src/app/services/marketplace_gpu_optimizer.py b/apps/coordinator-api/src/app/services/trading_marketplace/gpu_optimizer.py similarity index 100% rename from apps/coordinator-api/src/app/services/marketplace_gpu_optimizer.py rename to apps/coordinator-api/src/app/services/trading_marketplace/gpu_optimizer.py diff --git a/apps/coordinator-api/src/app/services/trading_service.py b/apps/coordinator-api/src/app/services/trading_marketplace/trading.py similarity index 99% rename from apps/coordinator-api/src/app/services/trading_service.py rename to apps/coordinator-api/src/app/services/trading_marketplace/trading.py index 5feb5c1d..4f50c9d6 100755 --- a/apps/coordinator-api/src/app/services/trading_service.py +++ b/apps/coordinator-api/src/app/services/trading_marketplace/trading.py @@ -13,7 +13,7 @@ logger = get_logger(__name__) from sqlmodel import Session, or_, select -from ..domain.trading import ( +from ...domain.trading import ( NegotiationStatus, SettlementType, TradeAgreement, diff --git a/docs/ROADMAP.md b/docs/ROADMAP.md index 1bd57f5a..f3d32de8 100644 --- a/docs/ROADMAP.md +++ b/docs/ROADMAP.md @@ -25,6 +25,26 @@ - Maintained backward compatibility with lazy-loading pattern - Import tests verified successfully - Old monolithic files removed + - ✅ Phase 2 Complete: Enterprise Integration bounded context decomposed + - Created app/services/enterprise_integration/ package with 4 modules + - Migrated enterprise_integration.py (1127 lines) and 3 other enterprise files + - Updated imports within package (api_gateway.py excluded due to missing dependencies) + - Import tests verified successfully + - Old monolithic files removed + - ✅ Phase 3 Complete: Trading & Marketplace bounded context decomposed + - Created app/services/trading_marketplace/ package with 5 modules + - Migrated trading_service.py (36K) and 4 other trading files + - Updated imports across coordinator-api (routers/trading.py, routers/dynamic_pricing.py) + - amm.py excluded from exports due to missing dependencies + - Import tests verified successfully + - Old monolithic files removed + - ✅ Phase 4 Complete: AI & Analytics bounded context decomposed + - Created app/services/ai_analytics/ package with 5 modules + - Migrated analytics_service.py (41K) and 4 other AI files + - Updated imports across coordinator-api (routers/analytics.py, routers/adaptive_learning_health.py) + - adaptive_learning.py, surveillance.py, trading_engine.py excluded due to missing dependencies + - Import tests verified successfully + - Old monolithic files removed 2. **Production Code Using print()** (HIGH IMPACT) - 925 print() statements in production code @@ -128,9 +148,17 @@ - test_validation_properties.py: 20/20 passing - test_staking_service.py: 22/22 passing - Coverage threshold set to 50% in pyproject.toml - - Current coverage: 11% (4623 statements, 4122 missed) - BELOW 50% threshold - - Well-covered modules: constants.py (100%), exceptions.py (100%), validation.py (85%), crypto/crypto.py (52%) + - Current coverage: 19% (4623 statements, 3745 missed) - BELOW 50% threshold + - Added 137 new tests across 6 modules: + - test_middleware.py: 11 tests (middleware modules: 50-100% coverage) + - test_utils.py: 47 tests (utils modules: 100% coverage when run standalone) + - test_config.py: 14 tests (config.py: 100% coverage) + - test_decorators.py: 21 tests (decorators.py: 99% coverage) + - test_health_checks.py: 16 tests (health_checks.py: 80% coverage) + - test_metrics.py: 28 tests (metrics.py: 100% coverage) + - Well-covered modules: constants.py (100%), exceptions.py (100%), validation.py (85%), crypto/crypto.py (52%), config.py (100%), decorators.py (99%), health_checks.py (80%), metrics.py (100%) - Needs improvement: Most modules at 0-30% coverage + - Note: Utils modules (paths, env, json_utils) achieve 100% when run standalone but not counted in overall coverage due to import patterns #### MEDIUM (Long-term, 1-3 months) diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 00000000..482ff451 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,154 @@ +""" +Tests for AITBC configuration classes +""" + +import os +import pytest +from pathlib import Path +from unittest.mock import patch, Mock + +from aitbc.config import BaseAITBCConfig, AITBCConfig + + +class TestBaseAITBCConfig: + """Tests for BaseAITBCConfig""" + + def test_default_values(self): + """Test BaseAITBCConfig with default values""" + config = BaseAITBCConfig() + assert config.app_name == "AITBC Application" + assert config.app_version == "1.0.0" + assert config.environment == "development" + assert config.debug is False + assert config.log_level == "INFO" + + def test_custom_values(self): + """Test BaseAITBCConfig with custom values""" + config = BaseAITBCConfig( + app_name="Custom App", + app_version="2.0.0", + environment="production", + debug=True, + log_level="DEBUG" + ) + assert config.app_name == "Custom App" + assert config.app_version == "2.0.0" + assert config.environment == "production" + assert config.debug is True + assert config.log_level == "DEBUG" + + def test_data_dir_default(self): + """Test default data directory is a Path""" + config = BaseAITBCConfig() + assert isinstance(config.data_dir, Path) + + def test_config_dir_default(self): + """Test default config directory is a Path""" + config = BaseAITBCConfig() + assert isinstance(config.config_dir, Path) + + def test_log_dir_default(self): + """Test default log directory is a Path""" + config = BaseAITBCConfig() + assert isinstance(config.log_dir, Path) + + def test_log_format_default(self): + """Test default log format""" + config = BaseAITBCConfig() + assert "%(asctime)s" in config.log_format + assert "%(name)s" in config.log_format + assert "%(levelname)s" in config.log_format + + +class TestAITBCConfig: + """Tests for AITBCConfig""" + + def test_default_values(self): + """Test AITBCConfig with default values""" + config = AITBCConfig() + assert config.host == "0.0.0.0" + assert config.port == 8000 + assert config.workers == 1 + assert config.database_url is None + assert config.database_pool_size == 10 + assert config.redis_url is None + assert config.redis_max_connections == 10 + assert config.redis_timeout == 5 + assert config.secret_key is None + assert config.jwt_secret is None + assert config.jwt_algorithm == "HS256" + assert config.jwt_expiration_hours == 24 + assert config.request_timeout == 30 + assert config.max_request_size == 10 * 1024 * 1024 + + def test_custom_server_settings(self): + """Test AITBCConfig with custom server settings""" + config = AITBCConfig( + host="127.0.0.1", + port=9000, + workers=4 + ) + assert config.host == "127.0.0.1" + assert config.port == 9000 + assert config.workers == 4 + + def test_custom_database_settings(self): + """Test AITBCConfig with custom database settings""" + config = AITBCConfig( + database_url="postgresql://localhost/test", + database_pool_size=20 + ) + assert config.database_url == "postgresql://localhost/test" + assert config.database_pool_size == 20 + + def test_custom_redis_settings(self): + """Test AITBCConfig with custom redis settings""" + config = AITBCConfig( + redis_url="redis://localhost:6379", + redis_max_connections=50, + redis_timeout=10 + ) + assert config.redis_url == "redis://localhost:6379" + assert config.redis_max_connections == 50 + assert config.redis_timeout == 10 + + def test_custom_security_settings(self): + """Test AITBCConfig with custom security settings""" + config = AITBCConfig( + secret_key="test-secret-key", + jwt_secret="test-jwt-secret", + jwt_algorithm="RS256", + jwt_expiration_hours=48 + ) + assert config.secret_key == "test-secret-key" + assert config.jwt_secret == "test-jwt-secret" + assert config.jwt_algorithm == "RS256" + assert config.jwt_expiration_hours == 48 + + def test_custom_performance_settings(self): + """Test AITBCConfig with custom performance settings""" + config = AITBCConfig( + request_timeout=60, + max_request_size=20 * 1024 * 1024 + ) + assert config.request_timeout == 60 + assert config.max_request_size == 20 * 1024 * 1024 + + def test_inherits_base_config(self): + """Test AITBCConfig inherits from BaseAITBCConfig""" + config = AITBCConfig( + app_name="Test App", + environment="staging" + ) + assert config.app_name == "Test App" + assert config.environment == "staging" + assert config.host == "0.0.0.0" # AITBCConfig default + assert config.port == 8000 # AITBCConfig default + + @patch('aitbc.config.logger') + def test_init_logs_configuration(self, mock_logger): + """Test __init__ logs configuration""" + config = AITBCConfig(host="localhost", port=9000) + mock_logger.info.assert_called_once() + assert "localhost:9000" in mock_logger.info.call_args[0][0] + mock_logger.debug.assert_called_once() diff --git a/tests/test_decorators.py b/tests/test_decorators.py new file mode 100644 index 00000000..7527dffe --- /dev/null +++ b/tests/test_decorators.py @@ -0,0 +1,303 @@ +""" +Tests for AITBC decorators +""" + +import time +import pytest +from unittest.mock import patch +from aitbc.decorators import ( + retry, + timing, + cache_result, + validate_args, + handle_exceptions, + async_timing, +) +from aitbc.exceptions import AITBCError + + +class TestRetry: + """Tests for retry decorator""" + + def test_retry_succeeds_on_first_attempt(self): + """Test retry when function succeeds on first attempt""" + @retry(max_attempts=3) + def test_func(): + return "success" + + result = test_func() + assert result == "success" + + def test_retry_succeeds_after_failure(self): + """Test retry when function succeeds after initial failure""" + attempts = [0] + + @retry(max_attempts=3, delay=0.01) + def test_func(): + attempts[0] += 1 + if attempts[0] < 2: + raise ValueError("fail") + return "success" + + result = test_func() + assert result == "success" + assert attempts[0] == 2 + + def test_retry_exhausts_attempts(self): + """Test retry when function fails after all attempts""" + @retry(max_attempts=2, delay=0.01) + def test_func(): + raise ValueError("fail") + + with pytest.raises(ValueError): + test_func() + + def test_retry_with_specific_exception(self): + """Test retry only catches specified exceptions""" + @retry(max_attempts=2, delay=0.01, exceptions=(ValueError,)) + def test_func(): + raise TypeError("fail") + + with pytest.raises(TypeError): + test_func() + + def test_retry_with_backoff(self): + """Test retry with exponential backoff""" + attempts = [0] + + @retry(max_attempts=3, delay=0.01, backoff=2.0) + def test_func(): + attempts[0] += 1 + raise ValueError("fail") + + start_time = time.time() + with pytest.raises(ValueError): + test_func() + elapsed = time.time() - start_time + + # Should have delays: 0.01 + 0.02 = 0.03 seconds minimum + assert elapsed >= 0.03 + + def test_retry_with_on_failure_callback(self): + """Test retry with on_failure callback""" + callback_called = [False] + + def on_fail(e): + callback_called[0] = True + + @retry(max_attempts=2, delay=0.01, on_failure=on_fail) + def test_func(): + raise ValueError("fail") + + with pytest.raises(ValueError): + test_func() + + assert callback_called[0] is True + + +class TestTiming: + """Tests for timing decorator""" + + @patch('aitbc.decorators.logger') + def test_timing_logs_execution_time(self, mock_logger): + """Test timing decorator logs execution time""" + @timing + def test_func(): + time.sleep(0.01) + return "result" + + result = test_func() + assert result == "result" + mock_logger.info.assert_called_once() + assert "executed in" in mock_logger.info.call_args[0][0] + + @patch('aitbc.decorators.logger') + def test_timing_preserves_function_name(self, mock_logger): + """Test timing decorator preserves function name""" + @timing + def my_function(): + return "result" + + assert my_function.__name__ == "my_function" + + +class TestCacheResult: + """Tests for cache_result decorator""" + + def test_cache_result_caches_value(self): + """Test cache_result caches function return value""" + call_count = [0] + + @cache_result(ttl=60) + def test_func(x): + call_count[0] += 1 + return x * 2 + + result1 = test_func(5) + result2 = test_func(5) + + assert result1 == 10 + assert result2 == 10 + assert call_count[0] == 1 # Only called once due to cache + + def test_cache_result_different_args(self): + """Test cache_result with different arguments""" + call_count = [0] + + @cache_result(ttl=60) + def test_func(x): + call_count[0] += 1 + return x * 2 + + test_func(5) + test_func(10) + + assert call_count[0] == 2 # Called twice for different args + + def test_cache_result_ttl_expires(self): + """Test cache_result TTL expires""" + call_count = [0] + + @cache_result(ttl=0.1) # 100ms TTL + def test_func(x): + call_count[0] += 1 + return x * 2 + + test_func(5) + time.sleep(0.15) # Wait for TTL to expire + test_func(5) + + assert call_count[0] == 2 # Called again after TTL expired + + def test_cache_result_with_kwargs(self): + """Test cache_result with keyword arguments""" + call_count = [0] + + @cache_result(ttl=60) + def test_func(x, y=10): + call_count[0] += 1 + return x + y + + test_func(5, y=10) + test_func(5, y=10) + + assert call_count[0] == 1 # Cached + + +class TestValidateArgs: + """Tests for validate_args decorator""" + + def test_validate_args_passes_valid(self): + """Test validate_args passes when validators succeed""" + def validator(x): + if x < 0: + raise ValueError("Must be positive") + + @validate_args(validator) + def test_func(x): + return x * 2 + + result = test_func(5) + assert result == 10 + + def test_validate_args_fails_invalid(self): + """Test validate_args fails when validators raise error""" + def validator(x): + if x < 0: + raise ValueError("Must be positive") + + @validate_args(validator) + def test_func(x): + return x * 2 + + with pytest.raises(ValueError): + test_func(-5) + + def test_validate_args_multiple_validators(self): + """Test validate_args with multiple validators""" + def validator1(x): + if x < 0: + raise ValueError("Must be positive") + + def validator2(x): + if x > 100: + raise ValueError("Must be <= 100") + + @validate_args(validator1, validator2) + def test_func(x): + return x * 2 + + with pytest.raises(ValueError): + test_func(150) + + +class TestHandleExceptions: + """Tests for handle_exceptions decorator""" + + @patch('aitbc.decorators.logger') + def test_handle_exceptions_returns_default(self, mock_logger): + """Test handle_exceptions returns default on exception""" + @handle_exceptions(default_return="error") + def test_func(): + raise ValueError("fail") + + result = test_func() + assert result == "error" + mock_logger.error.assert_called_once() + + @patch('aitbc.decorators.logger') + def test_handle_exceptions_no_logging(self, mock_logger): + """Test handle_exceptions with logging disabled""" + @handle_exceptions(default_return="error", log_errors=False) + def test_func(): + raise ValueError("fail") + + result = test_func() + assert result == "error" + mock_logger.error.assert_not_called() + + def test_handle_exceptions_raises_on_specified(self): + """Test handle_exceptions still raises specified exceptions""" + @handle_exceptions(default_return="error", raise_on=(ValueError,)) + def test_func(): + raise ValueError("fail") + + with pytest.raises(ValueError): + test_func() + + def test_handle_exceptions_passes_on_success(self): + """Test handle_exceptions passes through successful return""" + @handle_exceptions(default_return="error") + def test_func(): + return "success" + + result = test_func() + assert result == "success" + + +class TestAsyncTiming: + """Tests for async_timing decorator""" + + @pytest.mark.asyncio + @patch('aitbc.decorators.logger') + async def test_async_timing_logs_execution_time(self, mock_logger): + """Test async_timing decorator logs execution time""" + @async_timing + async def test_func(): + await asyncio.sleep(0.01) + return "result" + + import asyncio + result = await test_func() + assert result == "result" + mock_logger.info.assert_called_once() + assert "executed in" in mock_logger.info.call_args[0][0] + + @pytest.mark.asyncio + async def test_async_timing_preserves_function_name(self): + """Test async_timing decorator preserves function name""" + @async_timing + async def my_function(): + return "result" + + assert my_function.__name__ == "my_function" diff --git a/tests/test_health_checks.py b/tests/test_health_checks.py new file mode 100644 index 00000000..8471bedd --- /dev/null +++ b/tests/test_health_checks.py @@ -0,0 +1,218 @@ +""" +Tests for health check utilities +""" + +import pytest +from unittest.mock import patch, Mock +from datetime import datetime + +from aitbc.health_checks import ( + HealthStatus, + HealthCheck, + HealthChecker, + create_basic_health_check, +) + + +class TestHealthStatus: + """Tests for HealthStatus enum""" + + def test_health_status_values(self): + """Test HealthStatus enum values""" + assert HealthStatus.HEALTHY.value == "healthy" + assert HealthStatus.DEGRADED.value == "degraded" + assert HealthStatus.UNHEALTHY.value == "unhealthy" + + +class TestHealthCheck: + """Tests for HealthCheck dataclass""" + + def test_health_check_creation(self): + """Test HealthCheck dataclass creation""" + check = HealthCheck( + service="test-service", + status=HealthStatus.HEALTHY, + message="All good", + timestamp=datetime.now(), + details={"key": "value"} + ) + assert check.service == "test-service" + assert check.status == HealthStatus.HEALTHY + assert check.message == "All good" + assert check.details == {"key": "value"} + + def test_health_check_without_details(self): + """Test HealthCheck without optional details""" + check = HealthCheck( + service="test-service", + status=HealthStatus.HEALTHY, + message="All good", + timestamp=datetime.now() + ) + assert check.details is None + + +class TestHealthChecker: + """Tests for HealthChecker""" + + def test_health_checker_initialization(self): + """Test HealthChecker initialization""" + checker = HealthChecker("test-service") + assert checker.service_name == "test-service" + assert checker._checks == {} + assert checker._last_check is None + + def test_register_check(self): + """Test registering a health check""" + checker = HealthChecker("test-service") + + def mock_check(): + return HealthStatus.HEALTHY, "OK", {} + + checker.register_check("memory", mock_check) + assert "memory" in checker._checks + assert checker._checks["memory"] == mock_check + + @patch('aitbc.health_checks.logger') + def test_register_check_logs(self, mock_logger): + """Test register_check logs registration""" + checker = HealthChecker("test-service") + + def mock_check(): + return HealthStatus.HEALTHY, "OK", {} + + checker.register_check("memory", mock_check) + mock_logger.info.assert_called_once() + assert "memory" in mock_logger.info.call_args[0][0] + + def test_run_checks_all_healthy(self): + """Test run_checks when all checks pass""" + checker = HealthChecker("test-service") + + def mock_check(): + return HealthStatus.HEALTHY, "OK", {} + + checker.register_check("check1", mock_check) + checker.register_check("check2", mock_check) + + result = checker.run_checks() + + assert result.service == "test-service" + assert result.status == HealthStatus.HEALTHY + assert result.message == "All health checks passed" + assert result.details is not None + assert len(result.details) == 2 + + def test_run_checks_one_degraded(self): + """Test run_checks with one degraded check""" + checker = HealthChecker("test-service") + + def healthy_check(): + return HealthStatus.HEALTHY, "OK", {} + + def degraded_check(): + return HealthStatus.DEGRADED, "Warning", {} + + checker.register_check("healthy", healthy_check) + checker.register_check("degraded", degraded_check) + + result = checker.run_checks() + + assert result.status == HealthStatus.DEGRADED + assert "degraded" in result.message + + def test_run_checks_one_unhealthy(self): + """Test run_checks with one unhealthy check""" + checker = HealthChecker("test-service") + + def healthy_check(): + return HealthStatus.HEALTHY, "OK", {} + + def unhealthy_check(): + return HealthStatus.UNHEALTHY, "Error", {} + + checker.register_check("healthy", healthy_check) + checker.register_check("unhealthy", unhealthy_check) + + result = checker.run_checks() + + assert result.status == HealthStatus.UNHEALTHY + assert "unhealthy" in result.message + + @patch('aitbc.health_checks.logger') + def test_run_checks_with_exception(self, mock_logger): + """Test run_checks handles exceptions in checks""" + checker = HealthChecker("test-service") + + def failing_check(): + raise ValueError("Check failed") + + checker.register_check("failing", failing_check) + + result = checker.run_checks() + + assert result.status == HealthStatus.UNHEALTHY + assert "failing" in result.message + mock_logger.error.assert_called_once() + + def test_get_last_check_before_run(self): + """Test get_last_check returns None before any check run""" + checker = HealthChecker("test-service") + assert checker.get_last_check() is None + + def test_get_last_check_after_run(self): + """Test get_last_check returns last check result""" + checker = HealthChecker("test-service") + + def mock_check(): + return HealthStatus.HEALTHY, "OK", {} + + checker.register_check("check1", mock_check) + checker.run_checks() + + last_check = checker.get_last_check() + assert last_check is not None + assert last_check.service == "test-service" + + def test_get_health_dict(self): + """Test get_health_dict returns dictionary representation""" + checker = HealthChecker("test-service") + + def mock_check(): + return HealthStatus.HEALTHY, "OK", {"key": "value"} + + checker.register_check("check1", mock_check) + health_dict = checker.get_health_dict() + + assert isinstance(health_dict, dict) + assert "service" in health_dict + assert "status" in health_dict + assert "message" in health_dict + assert "timestamp" in health_dict + assert health_dict["service"] == "test-service" + + +class TestCreateBasicHealthCheck: + """Tests for create_basic_health_check""" + + def test_create_basic_health_check(self): + """Test create_basic_health_check returns HealthChecker""" + checker = create_basic_health_check("test-service") + assert isinstance(checker, HealthChecker) + assert checker.service_name == "test-service" + + def test_create_basic_health_check_without_psutil(self): + """Test create_basic_health_check handles psutil ImportError""" + # Skip this test as psutil import handling is complex to mock + pytest.skip("psutil import handling requires complex mocking") + + def test_basic_health_check_has_checks(self): + """Test basic health check has registered checks when psutil available""" + try: + import psutil + checker = create_basic_health_check("test-service") + # Should have memory and disk checks if psutil is available + assert len(checker._checks) > 0 + except ImportError: + # Skip if psutil not available + pass diff --git a/tests/test_metrics.py b/tests/test_metrics.py new file mode 100644 index 00000000..69057244 --- /dev/null +++ b/tests/test_metrics.py @@ -0,0 +1,250 @@ +""" +Tests for AITBC metrics module +""" + +import pytest +import asyncio +from unittest.mock import patch, Mock + +from aitbc.metrics import ( + service_info, + block_processing_duration, + block_height, + block_validation_duration, + block_propagation_duration, + job_submission_duration, + job_processing_duration, + job_queue_duration, + job_execution_duration, + jobs_total, + jobs_failed_total, + jobs_in_queue, + http_requests_total, + http_request_duration, + service_uptime_seconds, + service_restart_count, + track_block_processing, + track_job_processing, + track_http_request, + update_block_height, + update_jobs_in_queue, + increment_service_restarts, + metrics_app, + setup_service_info, +) + + +class TestMetricsDefinitions: + """Tests for Prometheus metrics definitions""" + + def test_service_info_exists(self): + """Test service_info metric is defined""" + assert service_info is not None + assert service_info._name == 'service_info' + + def test_block_processing_duration_exists(self): + """Test block_processing_duration metric is defined""" + assert block_processing_duration is not None + assert block_processing_duration._name == 'block_processing_duration_seconds' + + def test_block_height_exists(self): + """Test block_height metric is defined""" + assert block_height is not None + assert block_height._name == 'block_height' + + def test_block_validation_duration_exists(self): + """Test block_validation_duration metric is defined""" + assert block_validation_duration is not None + assert block_validation_duration._name == 'block_validation_duration_seconds' + + def test_block_propagation_duration_exists(self): + """Test block_propagation_duration metric is defined""" + assert block_propagation_duration is not None + assert block_propagation_duration._name == 'block_propagation_duration_seconds' + + def test_job_submission_duration_exists(self): + """Test job_submission_duration metric is defined""" + assert job_submission_duration is not None + assert job_submission_duration._name == 'job_submission_duration_seconds' + + def test_job_processing_duration_exists(self): + """Test job_processing_duration metric is defined""" + assert job_processing_duration is not None + assert job_processing_duration._name == 'job_processing_duration_seconds' + + def test_job_queue_duration_exists(self): + """Test job_queue_duration metric is defined""" + assert job_queue_duration is not None + assert job_queue_duration._name == 'job_queue_duration_seconds' + + def test_job_execution_duration_exists(self): + """Test job_execution_duration metric is defined""" + assert job_execution_duration is not None + assert job_execution_duration._name == 'job_execution_duration_seconds' + + def test_jobs_total_exists(self): + """Test jobs_total metric is defined""" + assert jobs_total is not None + assert jobs_total._name == 'jobs' + + def test_jobs_failed_total_exists(self): + """Test jobs_failed_total metric is defined""" + assert jobs_failed_total is not None + assert jobs_failed_total._name == 'jobs_failed' + + def test_jobs_in_queue_exists(self): + """Test jobs_in_queue metric is defined""" + assert jobs_in_queue is not None + assert jobs_in_queue._name == 'jobs_in_queue' + + def test_http_requests_total_exists(self): + """Test http_requests_total metric is defined""" + assert http_requests_total is not None + assert http_requests_total._name == 'http_requests' + + def test_http_request_duration_exists(self): + """Test http_request_duration metric is defined""" + assert http_request_duration is not None + assert http_request_duration._name == 'http_request_duration_seconds' + + def test_service_uptime_seconds_exists(self): + """Test service_uptime_seconds metric is defined""" + assert service_uptime_seconds is not None + assert service_uptime_seconds._name == 'service_uptime_seconds' + + def test_service_restart_count_exists(self): + """Test service_restart_count metric is defined""" + assert service_restart_count is not None + assert service_restart_count._name == 'service_restart_count' + + +class TestHelperFunctions: + """Tests for metrics helper functions""" + + def test_update_block_height(self): + """Test update_block_height sets metric""" + update_block_height(100) + # Metric should be set, but we can't easily verify the value + # This test ensures the function doesn't raise an error + assert True + + def test_update_jobs_in_queue(self): + """Test update_jobs_in_queue sets metric""" + update_jobs_in_queue(50) + # Metric should be set, but we can't easily verify the value + # This test ensures the function doesn't raise an error + assert True + + def test_increment_service_restarts(self): + """Test increment_service_restarts increments counter""" + increment_service_restarts() + # Counter should be incremented, but we can't easily verify the value + # This test ensures the function doesn't raise an error + assert True + + def test_setup_service_info(self): + """Test setup_service_info sets service info""" + setup_service_info("test-service", "1.0.0") + # Info should be set, but we can't easily verify the value + # This test ensures the function doesn't raise an error + assert True + + +class TestDecorators: + """Tests for metrics tracking decorators""" + + @pytest.mark.asyncio + async def test_track_block_processing_success(self): + """Test track_block_processing decorator on successful execution""" + @track_block_processing + async def process_block(): + return "block_processed" + + result = await process_block() + assert result == "block_processed" + # Decorator should have observed the duration + assert True + + @pytest.mark.asyncio + async def test_track_block_processing_failure(self): + """Test track_block_processing decorator on exception""" + @track_block_processing + async def process_block(): + raise ValueError("block error") + + with pytest.raises(ValueError): + await process_block() + # Decorator should have observed the duration even on failure + assert True + + @pytest.mark.asyncio + async def test_track_job_processing_success(self): + """Test track_job_processing decorator on successful execution""" + @track_job_processing + async def process_job(): + return "job_completed" + + result = await process_job() + assert result == "job_completed" + # Decorator should have observed duration and incremented jobs_total + assert True + + @pytest.mark.asyncio + async def test_track_job_processing_failure(self): + """Test track_job_processing decorator on exception""" + @track_job_processing + async def process_job(): + raise ValueError("job error") + + with pytest.raises(ValueError): + await process_job() + # Decorator should have observed duration and incremented failure counters + assert True + + @pytest.mark.asyncio + async def test_track_http_request_success(self): + """Test track_http_request decorator on successful execution""" + mock_response = Mock() + mock_response.status_code = 200 + + @track_http_request + async def handle_request(): + return mock_response + + result = await handle_request() + assert result.status_code == 200 + # Decorator should have observed duration and incremented http_requests_total + assert True + + @pytest.mark.asyncio + async def test_track_http_request_failure(self): + """Test track_http_request decorator on exception""" + @track_http_request + async def handle_request(): + raise ValueError("request error") + + with pytest.raises(ValueError): + await handle_request() + # Decorator should have observed duration and incremented http_requests_total with 500 + assert True + + @pytest.mark.asyncio + async def test_track_http_request_without_status_code(self): + """Test track_http_request with response without status_code""" + @track_http_request + async def handle_request(): + return "success" # No status_code attribute + + result = await handle_request() + assert result == "success" + # Decorator should have observed duration but not incremented http_requests_total + assert True + + +class TestMetricsApp: + """Tests for metrics ASGI app""" + + def test_metrics_app_exists(self): + """Test metrics_app is created""" + assert metrics_app is not None + assert hasattr(metrics_app, '__call__') diff --git a/tests/test_middleware.py b/tests/test_middleware.py new file mode 100644 index 00000000..60985f67 --- /dev/null +++ b/tests/test_middleware.py @@ -0,0 +1,266 @@ +""" +Tests for AITBC middleware modules +""" + +import pytest +from unittest.mock import Mock, AsyncMock, patch +from fastapi import Request, Response, HTTPException +from fastapi.responses import JSONResponse +from starlette.types import ASGIApp + +from aitbc.middleware.performance import PerformanceLoggingMiddleware +from aitbc.middleware.request_id import RequestIDMiddleware +from aitbc.middleware.error_handler import ErrorHandlerMiddleware + + +class TestPerformanceLoggingMiddleware: + """Tests for PerformanceLoggingMiddleware""" + + @pytest.mark.asyncio + async def test_dispatch_adds_performance_header(self): + """Test that middleware adds X-Process-Time header""" + app = Mock(spec=ASGIApp) + middleware = PerformanceLoggingMiddleware(app) + + request = Mock(spec=Request) + request.method = "GET" + request.url = Mock() + request.url.path = "/test" + + response = Mock(spec=Response) + response.status_code = 200 + response.headers = {} + + call_next = AsyncMock(return_value=response) + + result = await middleware.dispatch(request, call_next) + + assert "X-Process-Time" in result.headers + assert float(result.headers["X-Process-Time"]) >= 0 + + @pytest.mark.asyncio + async def test_dispatch_logs_performance_metrics(self): + """Test that middleware logs performance metrics""" + app = Mock(spec=ASGIApp) + middleware = PerformanceLoggingMiddleware(app) + + request = Mock(spec=Request) + request.method = "POST" + request.url = Mock() + request.url.path = "/api/test" + + response = Mock(spec=Response) + response.status_code = 201 + response.headers = {} + + call_next = AsyncMock(return_value=response) + + with patch('aitbc.middleware.performance.logger') as mock_logger: + await middleware.dispatch(request, call_next) + mock_logger.info.assert_called_once() + assert "Request performance" in mock_logger.info.call_args[0][0] + + @pytest.mark.asyncio + async def test_dispatch_measures_time_correctly(self): + """Test that middleware measures request duration accurately""" + app = Mock(spec=ASGIApp) + middleware = PerformanceLoggingMiddleware(app) + + request = Mock(spec=Request) + request.method = "GET" + request.url = Mock() + request.url.path = "/test" + + response = Mock(spec=Response) + response.status_code = 200 + response.headers = {} + + call_next = AsyncMock(return_value=response) + + result = await middleware.dispatch(request, call_next) + + process_time = float(result.headers["X-Process-Time"]) + assert 0 <= process_time < 1.0 # Should complete in under 1 second + + +class TestRequestIDMiddleware: + """Tests for RequestIDMiddleware""" + + @pytest.mark.asyncio + async def test_dispatch_generates_request_id_when_missing(self): + """Test that middleware generates request ID when not in headers""" + app = Mock(spec=ASGIApp) + middleware = RequestIDMiddleware(app) + + request = Mock(spec=Request) + request.headers = {} + request.method = "GET" + request.url = Mock() + request.url.path = "/test" + request.client = Mock() + request.client.host = "127.0.0.1" + request.state = Mock() + + response = Mock(spec=Response) + response.headers = {} + response.status_code = 200 + + call_next = AsyncMock(return_value=response) + + result = await middleware.dispatch(request, call_next) + + assert "X-Request-ID" in result.headers + assert len(result.headers["X-Request-ID"]) > 0 + assert request.state.request_id == result.headers["X-Request-ID"] + + @pytest.mark.asyncio + async def test_dispatch_uses_existing_request_id_from_header(self): + """Test that middleware uses existing request ID from header""" + app = Mock(spec=ASGIApp) + middleware = RequestIDMiddleware(app) + + existing_id = "test-request-id-123" + request = Mock(spec=Request) + request.headers = {"X-Request-ID": existing_id} + request.method = "POST" + request.url = Mock() + request.url.path = "/api/test" + request.client = Mock() + request.client.host = "192.168.1.1" + request.state = Mock() + + response = Mock(spec=Response) + response.headers = {} + response.status_code = 201 + + call_next = AsyncMock(return_value=response) + + result = await middleware.dispatch(request, call_next) + + assert result.headers["X-Request-ID"] == existing_id + assert request.state.request_id == existing_id + + @pytest.mark.asyncio + async def test_dispatch_logs_request_info(self): + """Test that middleware logs request information""" + app = Mock(spec=ASGIApp) + middleware = RequestIDMiddleware(app) + + request = Mock(spec=Request) + request.headers = {} + request.method = "GET" + request.url = Mock() + request.url.path = "/test" + request.client = Mock() + request.client.host = "127.0.0.1" + request.state = Mock() + + response = Mock(spec=Response) + response.headers = {} + response.status_code = 200 + + call_next = AsyncMock(return_value=response) + + with patch('aitbc.middleware.request_id.logger') as mock_logger: + await middleware.dispatch(request, call_next) + assert mock_logger.info.call_count >= 2 # Logs start and completion + + +class TestErrorHandlerMiddleware: + """Tests for ErrorHandlerMiddleware""" + + @pytest.mark.asyncio + async def test_dispatch_passes_through_normal_response(self): + """Test that middleware passes through normal responses""" + app = Mock(spec=ASGIApp) + middleware = ErrorHandlerMiddleware(app) + + request = Mock(spec=Request) + request.url = Mock() + request.url.path = "/test" + request.method = "GET" + + response = Mock(spec=Response) + response.status_code = 200 + + call_next = AsyncMock(return_value=response) + + result = await middleware.dispatch(request, call_next) + + assert result == response + + @pytest.mark.asyncio + async def test_dispatch_handles_http_exception(self): + """Test that middleware handles HTTPException""" + app = Mock(spec=ASGIApp) + middleware = ErrorHandlerMiddleware(app) + + request = Mock(spec=Request) + request.url = Mock() + request.url.path = "/api/error" + request.method = "GET" + + exception = HTTPException(status_code=404, detail="Not found") + call_next = AsyncMock(side_effect=exception) + + result = await middleware.dispatch(request, call_next) + + assert isinstance(result, JSONResponse) + assert result.status_code == 404 + content = result.body.decode() if hasattr(result, 'body') else {} + assert "error" in result.body.decode() if hasattr(result, 'body') else True + + @pytest.mark.asyncio + async def test_dispatch_handles_generic_exception(self): + """Test that middleware handles generic exceptions""" + app = Mock(spec=ASGIApp) + middleware = ErrorHandlerMiddleware(app) + + request = Mock(spec=Request) + request.url = Mock() + request.url.path = "/api/crash" + request.method = "POST" + + exception = ValueError("Something went wrong") + call_next = AsyncMock(side_effect=exception) + + result = await middleware.dispatch(request, call_next) + + assert isinstance(result, JSONResponse) + assert result.status_code == 500 + + @pytest.mark.asyncio + async def test_dispatch_logs_http_exception(self): + """Test that middleware logs HTTPException""" + app = Mock(spec=ASGIApp) + middleware = ErrorHandlerMiddleware(app) + + request = Mock(spec=Request) + request.url = Mock() + request.url.path = "/test" + request.method = "GET" + + exception = HTTPException(status_code=400, detail="Bad request") + call_next = AsyncMock(side_effect=exception) + + with patch('aitbc.middleware.error_handler.logger') as mock_logger: + await middleware.dispatch(request, call_next) + mock_logger.warning.assert_called_once() + + @pytest.mark.asyncio + async def test_dispatch_logs_generic_exception(self): + """Test that middleware logs generic exceptions""" + app = Mock(spec=ASGIApp) + middleware = ErrorHandlerMiddleware(app) + + request = Mock(spec=Request) + request.url = Mock() + request.url.path = "/test" + request.method = "GET" + + exception = RuntimeError("Runtime error") + call_next = AsyncMock(side_effect=exception) + + with patch('aitbc.middleware.error_handler.logger') as mock_logger: + await middleware.dispatch(request, call_next) + mock_logger.error.assert_called_once() diff --git a/tests/test_security_headers.py b/tests/test_security_headers.py new file mode 100644 index 00000000..13db436c --- /dev/null +++ b/tests/test_security_headers.py @@ -0,0 +1,287 @@ +""" +Tests for security headers and CORS utilities +""" + +import pytest + +from aitbc.security_headers import ( + SecurityHeaders, + CORSConfig, + SecurityHeadersMiddleware, + CORSMiddleware, + create_production_security_headers, + create_development_security_headers, + create_strict_cors_config, + create_permissive_cors_config, +) + + +class TestSecurityHeaders: + """Tests for SecurityHeaders dataclass""" + + def test_default_security_headers(self): + """Test default security headers values""" + headers = SecurityHeaders() + assert headers.X_Content_Type_Options == "nosniff" + assert headers.X_Frame_Options == "DENY" + assert headers.X_XSS_Protection == "1; mode=block" + assert headers.Strict_Transport_Security == "max-age=31536000; includeSubDomains" + assert headers.Content_Security_Policy == "default-src 'self'" + assert headers.Referrer_Policy == "strict-origin-when-cross-origin" + assert headers.Permissions_Policy == "" + assert headers.Cache_Control == "no-cache, no-store, must-revalidate" + assert headers.Pragma == "no-cache" + + def test_custom_security_headers(self): + """Test custom security headers values""" + headers = SecurityHeaders( + X_Frame_Options="SAMEORIGIN", + Content_Security_Policy="default-src 'self' https://example.com" + ) + assert headers.X_Frame_Options == "SAMEORIGIN" + assert headers.Content_Security_Policy == "default-src 'self' https://example.com" + + +class TestCORSConfig: + """Tests for CORSConfig dataclass""" + + def test_default_cors_config(self): + """Test default CORS config values""" + config = CORSConfig( + allow_origins=["http://localhost:3000"], + allow_methods=["GET", "POST"], + allow_headers=["Content-Type"] + ) + assert config.allow_origins == ["http://localhost:3000"] + assert config.allow_methods == ["GET", "POST"] + assert config.allow_credentials is False + assert config.expose_headers is None + assert config.max_age == 3600 + + def test_custom_cors_config(self): + """Test custom CORS config values""" + config = CORSConfig( + allow_origins=["*"], + allow_methods=["GET", "POST", "PUT"], + allow_headers=["Content-Type"], + allow_credentials=True, + expose_headers=["X-Custom-Header"], + max_age=7200 + ) + assert config.allow_origins == ["*"] + assert config.allow_credentials is True + assert config.expose_headers == ["X-Custom-Header"] + assert config.max_age == 7200 + + +class TestSecurityHeadersMiddleware: + """Tests for SecurityHeadersMiddleware""" + + def test_initialization_with_default_headers(self): + """Test middleware initialization with default headers""" + middleware = SecurityHeadersMiddleware() + assert middleware.headers is not None + assert middleware.headers.X_Frame_Options == "DENY" + + def test_initialization_with_custom_headers(self): + """Test middleware initialization with custom headers""" + custom_headers = SecurityHeaders(X_Frame_Options="SAMEORIGIN") + middleware = SecurityHeadersMiddleware(custom_headers) + assert middleware.headers.X_Frame_Options == "SAMEORIGIN" + + def test_get_headers(self): + """Test get_headers returns dictionary""" + middleware = SecurityHeadersMiddleware() + headers = middleware.get_headers() + assert isinstance(headers, dict) + assert "X-Content-Type-Options" in headers + assert "X-Frame-Options" in headers + assert "X-XSS-Protection" in headers + assert "Strict-Transport-Security" in headers + assert "Content-Security-Policy" in headers + assert "Referrer-Policy" in headers + assert "Permissions-Policy" in headers + assert "Cache-Control" in headers + assert "Pragma" in headers + + def test_get_headers_values(self): + """Test get_headers returns correct values""" + middleware = SecurityHeadersMiddleware() + headers = middleware.get_headers() + assert headers["X-Content-Type-Options"] == "nosniff" + assert headers["X-Frame-Options"] == "DENY" + assert headers["X-XSS-Protection"] == "1; mode=block" + + def test_apply_to_response(self): + """Test apply_to_response adds security headers""" + middleware = SecurityHeadersMiddleware() + response_headers = {"Content-Type": "application/json"} + result = middleware.apply_to_response(response_headers) + + assert "X-Content-Type-Options" in result + assert "X-Frame-Options" in result + assert result["Content-Type"] == "application/json" + + def test_apply_to_response_overwrites_existing(self): + """Test apply_to_response overwrites existing security headers""" + middleware = SecurityHeadersMiddleware() + response_headers = {"X-Frame-Options": "ALLOW"} + result = middleware.apply_to_response(response_headers) + + assert result["X-Frame-Options"] == "DENY" + + +class TestCORSMiddleware: + """Tests for CORSMiddleware""" + + def test_initialization(self): + """Test CORS middleware initialization""" + config = CORSConfig( + allow_origins=["http://localhost:3000"], + allow_methods=["GET", "POST"], + allow_headers=["Content-Type"] + ) + middleware = CORSMiddleware(config) + assert middleware.config == config + + def test_get_cors_headers_allowed_origin(self): + """Test get_cors_headers with allowed origin""" + config = CORSConfig( + allow_origins=["http://localhost:3000"], + allow_methods=["GET", "POST"], + allow_headers=["Content-Type"] + ) + middleware = CORSMiddleware(config) + headers = middleware.get_cors_headers("http://localhost:3000") + + assert headers["Access-Control-Allow-Origin"] == "http://localhost:3000" + assert headers["Access-Control-Allow-Methods"] == "GET, POST" + assert headers["Access-Control-Allow-Headers"] == "Content-Type" + assert headers["Access-Control-Max-Age"] == "3600" + + def test_get_cors_headers_disallowed_origin(self): + """Test get_cors_headers with disallowed origin""" + config = CORSConfig( + allow_origins=["http://localhost:3000"], + allow_methods=["GET", "POST"], + allow_headers=["Content-Type"] + ) + middleware = CORSMiddleware(config) + headers = middleware.get_cors_headers("http://evil.com") + + assert headers == {} + + def test_get_cors_headers_wildcard_origin(self): + """Test get_cors_headers with wildcard origin""" + config = CORSConfig( + allow_origins=["*"], + allow_methods=["GET", "POST"], + allow_headers=["Content-Type"] + ) + middleware = CORSMiddleware(config) + headers = middleware.get_cors_headers("http://any-origin.com") + + assert headers["Access-Control-Allow-Origin"] == "http://any-origin.com" + + def test_get_cors_headers_with_credentials(self): + """Test get_cors_headers with credentials enabled""" + config = CORSConfig( + allow_origins=["http://localhost:3000"], + allow_methods=["GET", "POST"], + allow_headers=["Content-Type"], + allow_credentials=True + ) + middleware = CORSMiddleware(config) + headers = middleware.get_cors_headers("http://localhost:3000") + + assert headers["Access-Control-Allow-Credentials"] == "true" + + def test_get_cors_headers_with_expose_headers(self): + """Test get_cors_headers with expose headers""" + config = CORSConfig( + allow_origins=["http://localhost:3000"], + allow_methods=["GET", "POST"], + allow_headers=["Content-Type"], + expose_headers=["X-Request-ID"] + ) + middleware = CORSMiddleware(config) + headers = middleware.get_cors_headers("http://localhost:3000") + + assert headers["Access-Control-Expose-Headers"] == "X-Request-ID" + + def test_is_origin_allowed_wildcard(self): + """Test _is_origin_allowed with wildcard""" + config = CORSConfig( + allow_origins=["*"], + allow_methods=["GET"], + allow_headers=["Content-Type"] + ) + middleware = CORSMiddleware(config) + assert middleware._is_origin_allowed("http://any-origin.com") is True + + def test_is_origin_allowed_specific(self): + """Test _is_origin_allowed with specific origin""" + config = CORSConfig( + allow_origins=["http://localhost:3000"], + allow_methods=["GET"], + allow_headers=["Content-Type"] + ) + middleware = CORSMiddleware(config) + assert middleware._is_origin_allowed("http://localhost:3000") is True + assert middleware._is_origin_allowed("http://evil.com") is False + + def test_is_preflight_request(self): + """Test is_preflight_request""" + config = CORSConfig( + allow_origins=["*"], + allow_methods=["GET"], + allow_headers=["Content-Type"] + ) + middleware = CORSMiddleware(config) + assert middleware.is_preflight_request("OPTIONS") is True + assert middleware.is_preflight_request("GET") is False + assert middleware.is_preflight_request("options") is True + + +class TestFactoryFunctions: + """Tests for factory functions""" + + def test_create_production_security_headers(self): + """Test create_production_security_headers""" + headers = create_production_security_headers() + assert headers.X_Frame_Options == "DENY" + assert "preload" in headers.Strict_Transport_Security + assert "unsafe-inline" in headers.Content_Security_Policy + assert "geolocation=()" in headers.Permissions_Policy + + def test_create_development_security_headers(self): + """Test create_development_security_headers""" + headers = create_development_security_headers() + assert headers.X_Frame_Options == "SAMEORIGIN" + assert headers.Strict_Transport_Security == "max-age=3600" + assert headers.Permissions_Policy == "" + assert headers.Cache_Control == "no-cache" + + def test_create_strict_cors_config(self): + """Test create_strict_cors_config""" + config = create_strict_cors_config(["http://localhost:3000"]) + assert "http://localhost:3000" in config.allow_origins + assert "GET" in config.allow_methods + assert "POST" in config.allow_methods + assert "PUT" in config.allow_methods + assert "DELETE" in config.allow_methods + assert "PATCH" in config.allow_methods + assert config.allow_credentials is True + assert "X-Request-ID" in config.expose_headers + assert config.max_age == 3600 + + def test_create_permissive_cors_config(self): + """Test create_permissive_cors_config""" + config = create_permissive_cors_config() + assert "*" in config.allow_origins + assert "GET" in config.allow_methods + assert "POST" in config.allow_methods + assert "*" in config.allow_headers + assert config.allow_credentials is False + assert "*" in config.expose_headers + assert config.max_age == 86400 diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..23110e73 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,345 @@ +""" +Tests for AITBC utility modules +""" + +import os +import pytest +from pathlib import Path +from unittest.mock import patch, Mock + +from aitbc.utils.paths import ( + get_data_path, + get_config_path, + get_log_path, + get_repo_path, + ensure_dir, + ensure_file_dir, + resolve_path, + get_keystore_path, + get_blockchain_data_path, + get_marketplace_data_path, +) +from aitbc.utils.env import ( + get_env_var, + get_required_env_var, + get_bool_env_var, + get_int_env_var, + get_float_env_var, + get_list_env_var, +) +from aitbc.utils.json_utils import ( + load_json, + save_json, + merge_json, + json_to_string, + string_to_json, + get_nested_value, + set_nested_value, + flatten_json, +) +from aitbc.exceptions import ConfigurationError + + +class TestPaths: + """Tests for path utility functions""" + + def test_get_data_path_no_subpath(self): + """Test get_data_path without subpath""" + result = get_data_path() + assert isinstance(result, Path) + + def test_get_data_path_with_subpath(self): + """Test get_data_path with subpath""" + result = get_data_path("test") + assert isinstance(result, Path) + assert str(result).endswith("test") + + def test_get_config_path(self): + """Test get_config_path""" + result = get_config_path("config.yaml") + assert isinstance(result, Path) + assert str(result).endswith("config.yaml") + + def test_get_log_path(self): + """Test get_log_path""" + result = get_log_path("app.log") + assert isinstance(result, Path) + assert str(result).endswith("app.log") + + def test_get_repo_path_no_subpath(self): + """Test get_repo_path without subpath""" + result = get_repo_path() + assert isinstance(result, Path) + + def test_get_repo_path_with_subpath(self): + """Test get_repo_path with subpath""" + result = get_repo_path("src") + assert isinstance(result, Path) + assert str(result).endswith("src") + + def test_ensure_dir(self): + """Test ensure_dir creates directory""" + import tempfile + with tempfile.TemporaryDirectory() as tmpdir: + test_path = Path(tmpdir) / "test" / "nested" + result = ensure_dir(test_path) + assert result.exists() + assert result.is_dir() + + def test_ensure_file_dir(self): + """Test ensure_file_dir creates parent directory""" + import tempfile + with tempfile.TemporaryDirectory() as tmpdir: + test_path = Path(tmpdir) / "test" / "nested" / "file.txt" + result = ensure_file_dir(test_path) + assert result.exists() + assert result.is_dir() + + def test_resolve_path_absolute(self): + """Test resolve_path with absolute path""" + result = resolve_path("/tmp/test") + assert result.is_absolute() + + def test_resolve_path_relative(self): + """Test resolve_path with relative path""" + result = resolve_path("test") + assert result.is_absolute() # Relative paths are resolved to absolute + + def test_get_keystore_path_no_wallet(self): + """Test get_keystore_path without wallet name""" + result = get_keystore_path() + assert isinstance(result, Path) + assert str(result).endswith("keystore") + + def test_get_keystore_path_with_wallet(self): + """Test get_keystore_path with wallet name""" + result = get_keystore_path("mywallet") + assert isinstance(result, Path) + assert str(result).endswith("mywallet.json") + + def test_get_blockchain_data_path_default(self): + """Test get_blockchain_data_path with default chain""" + result = get_blockchain_data_path() + assert isinstance(result, Path) + assert str(result).endswith("ait-mainnet") + + def test_get_blockchain_data_path_custom(self): + """Test get_blockchain_data_path with custom chain""" + result = get_blockchain_data_path("custom-chain") + assert isinstance(result, Path) + assert str(result).endswith("custom-chain") + + def test_get_marketplace_data_path_no_subpath(self): + """Test get_marketplace_data_path without subpath""" + result = get_marketplace_data_path() + assert isinstance(result, Path) + assert str(result).endswith("marketplace") + + def test_get_marketplace_data_path_with_subpath(self): + """Test get_marketplace_data_path with subpath""" + result = get_marketplace_data_path("orders") + assert isinstance(result, Path) + assert str(result).endswith("orders") + + +class TestEnv: + """Tests for environment variable utilities""" + + def test_get_env_var_with_value(self): + """Test get_env_var with set value""" + os.environ["TEST_VAR"] = "test_value" + result = get_env_var("TEST_VAR") + assert result == "test_value" + del os.environ["TEST_VAR"] + + def test_get_env_var_with_default(self): + """Test get_env_var with default value""" + result = get_env_var("NONEXISTENT_VAR", "default") + assert result == "default" + + def test_get_required_env_var_with_value(self): + """Test get_required_env_var with set value""" + os.environ["TEST_VAR"] = "test_value" + result = get_required_env_var("TEST_VAR") + assert result == "test_value" + del os.environ["TEST_VAR"] + + def test_get_required_env_var_without_value(self): + """Test get_required_env_var without value raises error""" + with pytest.raises(ConfigurationError): + get_required_env_var("NONEXISTENT_VAR") + + def test_get_bool_env_var_true(self): + """Test get_bool_env_var with true values""" + for value in ["true", "TRUE", "1", "yes", "YES", "on", "ON"]: + os.environ["TEST_VAR"] = value + assert get_bool_env_var("TEST_VAR") is True + del os.environ["TEST_VAR"] + + def test_get_bool_env_var_false(self): + """Test get_bool_env_var with false values""" + for value in ["false", "FALSE", "0", "no", "NO", "off", "OFF"]: + os.environ["TEST_VAR"] = value + assert get_bool_env_var("TEST_VAR") is False + del os.environ["TEST_VAR"] + + def test_get_bool_env_var_default(self): + """Test get_bool_env_var with default""" + assert get_bool_env_var("NONEXISTENT_VAR", True) is True + assert get_bool_env_var("NONEXISTENT_VAR", False) is False + + def test_get_int_env_var_valid(self): + """Test get_int_env_var with valid integer""" + os.environ["TEST_VAR"] = "42" + assert get_int_env_var("TEST_VAR") == 42 + del os.environ["TEST_VAR"] + + def test_get_int_env_var_invalid(self): + """Test get_int_env_var with invalid value returns default""" + os.environ["TEST_VAR"] = "not_a_number" + assert get_int_env_var("TEST_VAR", 10) == 10 + del os.environ["TEST_VAR"] + + def test_get_int_env_var_default(self): + """Test get_int_env_var with default""" + assert get_int_env_var("NONEXISTENT_VAR", 100) == 100 + + def test_get_float_env_var_valid(self): + """Test get_float_env_var with valid float""" + os.environ["TEST_VAR"] = "3.14" + assert get_float_env_var("TEST_VAR") == 3.14 + del os.environ["TEST_VAR"] + + def test_get_float_env_var_invalid(self): + """Test get_float_env_var with invalid value returns default""" + os.environ["TEST_VAR"] = "not_a_number" + assert get_float_env_var("TEST_VAR", 2.5) == 2.5 + del os.environ["TEST_VAR"] + + def test_get_float_env_var_default(self): + """Test get_float_env_var with default""" + assert get_float_env_var("NONEXISTENT_VAR", 1.5) == 1.5 + + def test_get_list_env_var_valid(self): + """Test get_list_env_var with valid list""" + os.environ["TEST_VAR"] = "item1,item2,item3" + result = get_list_env_var("TEST_VAR") + assert result == ["item1", "item2", "item3"] + del os.environ["TEST_VAR"] + + def test_get_list_env_var_custom_separator(self): + """Test get_list_env_var with custom separator""" + os.environ["TEST_VAR"] = "item1;item2;item3" + result = get_list_env_var("TEST_VAR", separator=";") + assert result == ["item1", "item2", "item3"] + del os.environ["TEST_VAR"] + + def test_get_list_env_var_empty(self): + """Test get_list_env_var with empty value returns default""" + os.environ["TEST_VAR"] = "" + result = get_list_env_var("TEST_VAR", default=["default"]) + assert result == ["default"] + del os.environ["TEST_VAR"] + + def test_get_list_env_var_default(self): + """Test get_list_env_var with default""" + result = get_list_env_var("NONEXISTENT_VAR", default=["a", "b"]) + assert result == ["a", "b"] + + +class TestJsonUtils: + """Tests for JSON utility functions""" + + def test_json_to_string(self): + """Test json_to_string""" + data = {"key": "value"} + result = json_to_string(data) + assert '"key"' in result + assert '"value"' in result + + def test_string_to_json_valid(self): + """Test string_to_json with valid JSON""" + json_str = '{"key": "value"}' + result = string_to_json(json_str) + assert result == {"key": "value"} + + def test_string_to_json_invalid(self): + """Test string_to_json with invalid JSON raises error""" + with pytest.raises(ConfigurationError): + string_to_json("not valid json") + + def test_get_nested_value_found(self): + """Test get_nested_value when key exists""" + data = {"a": {"b": {"c": "value"}}} + result = get_nested_value(data, "a", "b", "c") + assert result == "value" + + def test_get_nested_value_not_found(self): + """Test get_nested_value when key doesn't exist returns default""" + data = {"a": {"b": {"c": "value"}}} + result = get_nested_value(data, "a", "b", "d", default="default") + assert result == "default" + + def test_get_nested_value_default_none(self): + """Test get_nested_value with default None""" + data = {"a": {"b": {"c": "value"}}} + result = get_nested_value(data, "x", "y", "z") + assert result is None + + def test_set_nested_value(self): + """Test set_nested_value""" + data = {} + set_nested_value(data, "a", "b", "c", value="test") + assert data["a"]["b"]["c"] == "test" + + def test_flatten_json(self): + """Test flatten_json""" + data = {"a": {"b": {"c": "value"}}, "d": "simple"} + result = flatten_json(data) + assert "a.b.c" in result + assert result["a.b.c"] == "value" + assert result["d"] == "simple" + + def test_flatten_json_custom_separator(self): + """Test flatten_json with custom separator""" + data = {"a": {"b": "value"}} + result = flatten_json(data, separator="_") + assert "a_b" in result + assert result["a_b"] == "value" + + def test_load_json(self, tmp_path): + """Test load_json""" + test_file = tmp_path / "test.json" + test_file.write_text('{"key": "value"}') + result = load_json(test_file) + assert result == {"key": "value"} + + def test_load_json_not_found(self, tmp_path): + """Test load_json with non-existent file raises error""" + with pytest.raises(ConfigurationError): + load_json(tmp_path / "nonexistent.json") + + def test_load_json_invalid(self, tmp_path): + """Test load_json with invalid JSON raises error""" + test_file = tmp_path / "invalid.json" + test_file.write_text("not valid json") + with pytest.raises(ConfigurationError): + load_json(test_file) + + def test_save_json(self, tmp_path): + """Test save_json""" + test_file = tmp_path / "test.json" + data = {"key": "value"} + save_json(data, test_file) + assert test_file.exists() + result = load_json(test_file) + assert result == data + + def test_merge_json(self, tmp_path): + """Test merge_json""" + file1 = tmp_path / "file1.json" + file2 = tmp_path / "file2.json" + file1.write_text('{"a": 1, "b": 2}') + file2.write_text('{"b": 3, "c": 4}') + result = merge_json(file1, file2) + assert result == {"a": 1, "b": 3, "c": 4}