diff --git a/aitbc/config.py b/aitbc/config.py index 06866fe6..864383ad 100644 --- a/aitbc/config.py +++ b/aitbc/config.py @@ -4,9 +4,9 @@ Base configuration classes for AITBC applications """ from pathlib import Path -from typing import Optional +from typing import Optional, List from pydantic_settings import BaseSettings, SettingsConfigDict -from pydantic import Field +from pydantic import Field, field_validator from .constants import DATA_DIR, CONFIG_DIR, LOG_DIR, ENV_FILE from .aitbc_logging import get_logger @@ -19,25 +19,25 @@ class BaseAITBCConfig(BaseSettings): Base configuration class for all AITBC applications. Provides common AITBC-specific settings and environment file loading. """ - + model_config = SettingsConfigDict( env_file=str(ENV_FILE), env_file_encoding="utf-8", case_sensitive=False, extra="ignore" ) - + # AITBC system directories data_dir: Path = Field(default=DATA_DIR, description="AITBC data directory") config_dir: Path = Field(default=CONFIG_DIR, description="AITBC configuration directory") log_dir: Path = Field(default=LOG_DIR, description="AITBC log directory") - + # Application settings app_name: str = Field(default="AITBC Application", description="Application name") app_version: str = Field(default="1.0.0", description="Application version") environment: str = Field(default="development", description="Environment (development/staging/production)") debug: bool = Field(default=False, description="Debug mode") - + # Logging settings log_level: str = Field(default="INFO", description="Log level (DEBUG/INFO/WARNING/ERROR/CRITICAL)") log_format: str = Field( @@ -45,39 +45,92 @@ class BaseAITBCConfig(BaseSettings): description="Log format string" ) - -class AITBCConfig(BaseAITBCConfig): - """ - Standard AITBC configuration with common settings. - Inherits from BaseAITBCConfig and adds AITBC-specific fields. - """ - # Server settings host: str = Field(default="0.0.0.0", description="Server host address") port: int = Field(default=8000, description="Server port") workers: int = Field(default=1, description="Number of worker processes") - + # Database settings database_url: Optional[str] = Field(default=None, description="Database connection URL") database_pool_size: int = Field(default=10, description="Database connection pool size") - + database_max_overflow: int = Field(default=20, description="Maximum overflow connections") + database_pool_recycle: int = Field(default=3600, description="Connection recycle time in seconds") + database_pool_pre_ping: bool = Field(default=True, description="Test connections before using") + database_echo: bool = Field(default=False, description="Enable SQL query logging") + # Redis settings (if applicable) redis_url: Optional[str] = Field(default=None, description="Redis connection URL") redis_max_connections: int = Field(default=10, description="Redis max connections") redis_timeout: int = Field(default=5, description="Redis timeout in seconds") - + # Security settings secret_key: Optional[str] = Field(default=None, description="Application secret key") jwt_secret: Optional[str] = Field(default=None, description="JWT secret key") jwt_algorithm: str = Field(default="HS256", description="JWT algorithm") jwt_expiration_hours: int = Field(default=24, description="JWT token expiration in hours") - + # Performance settings request_timeout: int = Field(default=30, description="Request timeout in seconds") max_request_size: int = Field(default=10 * 1024 * 1024, description="Max request size in bytes") - + + # Rate limiting settings + rate_limit_requests: int = Field(default=60, description="Rate limit requests per window") + rate_limit_window_seconds: int = Field(default=60, description="Rate limit window in seconds") + + # CORS settings + allow_origins: List[str] = Field(default_factory=list, description="CORS allowed origins") + + def validate_secrets(self) -> None: + """Validate that all required secrets are provided.""" + if self.environment == "production": + if not self.secret_key: + raise ValueError("SECRET_KEY environment variable is required in production") + if self.secret_key == "change-me-in-production": + raise ValueError("SECRET_KEY must be changed from default value") + if not self.jwt_secret: + raise ValueError("JWT_SECRET environment variable is required in production") + if self.jwt_secret == "change-me-in-production": + raise ValueError("JWT_SECRET must be changed from default value") + + @field_validator("secret_key", "jwt_secret", mode="before") + @classmethod + def validate_secret_length(cls, v: Optional[str]) -> Optional[str]: + """Validate secret key length in production.""" + import os + if os.getenv("APP_ENV", "development") != "production" and not v: + return v + if not v or v.startswith("$") or v == "your_secret_here" or v == "change-me-in-production": + raise ValueError("Secret must be set to a secure value") + if len(v) < 32: + raise ValueError("Secret must be at least 32 characters long") + return v + def __init__(self, **kwargs): """Initialize AITBC configuration with extended logging""" super().__init__(**kwargs) - logger.info(f"Server configured for {self.host}:{self.port}") + logger.info(f"{self.app_name} configured for {self.host}:{self.port}") logger.debug(f"Workers: {self.workers}, Request timeout: {self.request_timeout}s") + + def get_redis_cache(self): + """Get Redis cache instance configured from settings""" + from .redis_cache import get_cache + return get_cache( + redis_url=self.redis_url, + max_connections=self.redis_max_connections, + timeout=self.redis_timeout + ) + + +class AITBCConfig(BaseAITBCConfig): + """ + Standard AITBC configuration with common settings. + Inherits from BaseAITBCConfig and can be extended with service-specific fields. + """ + + # Override defaults for standard AITBC application + app_name: str = Field(default="AITBC Application", description="Application name") + port: int = Field(default=8000, description="Server port") + + def __init__(self, **kwargs): + """Initialize AITBC configuration with extended logging""" + super().__init__(**kwargs) diff --git a/aitbc/database.py b/aitbc/database.py index d2d3a4c9..0598afef 100644 --- a/aitbc/database.py +++ b/aitbc/database.py @@ -9,6 +9,12 @@ from typing import Any, Dict, List, Optional, Tuple from contextlib import contextmanager from .exceptions import DatabaseError +# SQLAlchemy support for connection pooling +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.pool import QueuePool, StaticPool +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker + class DatabaseConnection: """ @@ -259,3 +265,141 @@ def table_exists(db_path: Path, table_name: str) -> bool: (table_name,) ) return result is not None + + +# SQLAlchemy Connection Pooling Utilities + +def create_pooled_engine( + database_url: str, + pool_size: int = 10, + max_overflow: int = 20, + pool_recycle: int = 3600, + pool_pre_ping: bool = True, + echo: bool = False, + use_static_pool: bool = False +): + """ + Create SQLAlchemy engine with connection pooling. + + Args: + database_url: Database connection URL + pool_size: Size of connection pool + max_overflow: Maximum overflow connections + pool_recycle: Connection recycle time in seconds + pool_pre_ping: Test connections before using + echo: Enable SQL query logging + use_static_pool: Use StaticPool for SQLite (single connection) + + Returns: + SQLAlchemy engine with connection pooling + """ + if "sqlite" in database_url and use_static_pool: + # SQLite with StaticPool (single connection, suitable for tests) + engine = create_engine( + database_url, + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + echo=echo, + pool_pre_ping=pool_pre_ping, + ) + elif "sqlite" in database_url: + # SQLite with QueuePool (limited pooling support) + engine = create_engine( + database_url, + connect_args={"check_same_thread": False, "timeout": 30}, + poolclass=QueuePool, + pool_size=min(pool_size, 5), # SQLite has limited concurrent access + max_overflow=max_overflow, + pool_pre_ping=pool_pre_ping, + echo=echo, + ) + else: + # PostgreSQL/MySQL with full connection pooling + engine = create_engine( + database_url, + poolclass=QueuePool, + pool_size=pool_size, + max_overflow=max_overflow, + pool_recycle=pool_recycle, + pool_pre_ping=pool_pre_ping, + echo=echo, + ) + return engine + + +def create_pooled_sessionmaker( + engine, + autoflush: bool = False, + autocommit: bool = False +): + """ + Create session factory with connection pooling. + + Args: + engine: SQLAlchemy engine + autoflush: Enable autoflush + autocommit: Enable autocommit + + Returns: + Session factory + """ + return sessionmaker(bind=engine, autoflush=autoflush, autocommit=autocommit) + + +def create_async_pooled_engine( + database_url: str, + pool_size: int = 10, + max_overflow: int = 20, + pool_recycle: int = 3600, + pool_pre_ping: bool = True, + echo: bool = False +): + """ + Create async SQLAlchemy engine with connection pooling. + + Args: + database_url: Database connection URL + pool_size: Size of connection pool + max_overflow: Maximum overflow connections + pool_recycle: Connection recycle time in seconds + pool_pre_ping: Test connections before using + echo: Enable SQL query logging + + Returns: + Async SQLAlchemy engine with connection pooling + """ + # Convert to async URL + if "sqlite" in database_url: + async_url = database_url.replace("sqlite:///", "sqlite+aiosqlite:///") + elif "postgresql" in database_url: + async_url = database_url.replace("postgresql://", "postgresql+asyncpg://") + else: + async_url = database_url + + engine = create_async_engine( + async_url, + poolclass=QueuePool, + pool_size=pool_size, + max_overflow=max_overflow, + pool_recycle=pool_recycle, + pool_pre_ping=pool_pre_ping, + echo=echo, + ) + return engine + + +def create_async_pooled_sessionmaker( + engine, + expire_on_commit: bool = False +): + """ + Create async session factory with connection pooling. + + Args: + engine: Async SQLAlchemy engine + expire_on_commit: Expire objects on commit + + Returns: + Async session factory + """ + return async_sessionmaker(engine, expire_on_commit=expire_on_commit) diff --git a/aitbc/rate_limiting.py b/aitbc/rate_limiting.py new file mode 100644 index 00000000..86dd998b --- /dev/null +++ b/aitbc/rate_limiting.py @@ -0,0 +1,198 @@ +""" +Rate limiting utilities for FastAPI applications +Provides decorators and middleware for API rate limiting +""" + +from functools import wraps +from typing import Callable, Optional, Dict, Any +from fastapi import Request, HTTPException, Response +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.types import ASGIApp + +from .security_hardening import RateLimiter +from .aitbc_logging import get_logger + +logger = get_logger(__name__) + + +# Global rate limiters for different endpoints +_rate_limiters: Dict[str, RateLimiter] = {} + + +def get_rate_limiter(name: str, rate: int = 100, per: int = 60) -> RateLimiter: + """ + Get or create a rate limiter for a specific endpoint + + Args: + name: Unique name for the rate limiter + rate: Number of requests allowed per time period + per: Time period in seconds + + Returns: + RateLimiter instance + """ + if name not in _rate_limiters: + _rate_limiters[name] = RateLimiter(rate=rate, per=per) + return _rate_limiters[name] + + +def rate_limit( + rate: int = 100, + per: int = 60, + key_func: Optional[Callable[[Request], str]] = None, + error_message: str = "Rate limit exceeded" +) -> Callable: + """ + Decorator for rate limiting FastAPI endpoints + + Args: + rate: Number of requests allowed per time period + per: Time period in seconds + key_func: Function to extract rate limit key from request (defaults to client IP) + error_message: Custom error message + + Returns: + Decorated function with rate limiting + """ + def decorator(func: Callable) -> Callable: + limiter = RateLimiter(rate=rate, per=per) + + @wraps(func) + async def wrapper(*args, **kwargs) -> Any: + # Extract request from args (FastAPI passes request as first arg for dependency injection) + request = None + for arg in args: + if isinstance(arg, Request): + request = arg + break + + if request is None: + # Try to get request from kwargs + request = kwargs.get('request') + + if request is None: + # No request available, skip rate limiting + return await func(*args, **kwargs) + + # Get rate limit key + if key_func: + key = key_func(request) + else: + key = request.client.host if request.client else "unknown" + + # Check rate limit + if not limiter.is_allowed(key): + logger.warning(f"Rate limit exceeded for {key} on {request.url.path}") + raise HTTPException( + status_code=429, + detail=error_message, + headers={"Retry-After": str(per)} + ) + + return await func(*args, **kwargs) + + return wrapper + return decorator + + +class RateLimitMiddleware(BaseHTTPMiddleware): + """ + Middleware for rate limiting all requests + + Applies rate limiting based on client IP address + """ + + def __init__( + self, + app: ASGIApp, + rate: int = 100, + per: int = 60, + key_func: Optional[Callable[[Request], str]] = None, + error_message: str = "Rate limit exceeded" + ) -> None: + """ + Initialize rate limit middleware + + Args: + app: ASGI application + rate: Number of requests allowed per time period + per: Time period in seconds + key_func: Function to extract rate limit key from request + error_message: Custom error message + """ + super().__init__(app) + self.rate = rate + self.per = per + self.key_func = key_func + self.error_message = error_message + self._limiter = RateLimiter(rate=rate, per=per) + + async def dispatch(self, request: Request, call_next: Callable) -> Response: + """ + Process request with rate limiting + + Args: + request: Incoming request + call_next: Next middleware or endpoint + + Returns: + Response + """ + # Get rate limit key + if self.key_func: + key = self.key_func(request) + else: + key = request.client.host if request.client else "unknown" + + # Check rate limit + if not self._limiter.is_allowed(key): + logger.warning(f"Rate limit exceeded for {key} on {request.url.path}") + return Response( + content='{"detail": "' + self.error_message + '"}', + status_code=429, + media_type="application/json", + headers={"Retry-After": str(self.per)} + ) + + return await call_next(request) + + +def get_rate_limit_headers(request: Request, limiter_name: str) -> Dict[str, str]: + """ + Get rate limit headers for response + + Args: + request: Request object + limiter_name: Name of the rate limiter + + Returns: + Dictionary of rate limit headers + """ + limiter = _rate_limiters.get(limiter_name) + if not limiter: + return {} + + key = request.client.host if request.client else "unknown" + remaining = limiter.get_remaining_requests(key) + + return { + "X-RateLimit-Limit": str(limiter.rate), + "X-RateLimit-Remaining": str(remaining), + "X-RateLimit-Reset": str(limiter.per) + } + + +def reset_rate_limit(identifier: str, limiter_name: Optional[str] = None) -> None: + """ + Reset rate limit for an identifier + + Args: + identifier: Identifier to reset (e.g., IP address, user ID) + limiter_name: Name of specific rate limiter, or None for all + """ + if limiter_name: + if limiter_name in _rate_limiters: + _rate_limiters[limiter_name].reset(identifier) + else: + for limiter in _rate_limiters.values(): + limiter.reset(identifier) diff --git a/aitbc/redis_cache.py b/aitbc/redis_cache.py new file mode 100644 index 00000000..50bc7f58 --- /dev/null +++ b/aitbc/redis_cache.py @@ -0,0 +1,328 @@ +""" +Redis caching utilities for AITBC applications +Provides distributed caching with Redis backend +""" + +from typing import Optional, Any, List +import json +import hashlib + +from .aitbc_logging import get_logger + +logger = get_logger(__name__) + +try: + import redis + REDIS_AVAILABLE = True +except ImportError: + REDIS_AVAILABLE = False + logger.warning("Redis package not installed. Caching will be disabled.") + + +class RedisCache: + """ + Redis cache implementation for distributed caching + """ + + def __init__( + self, + redis_url: Optional[str] = None, + max_connections: int = 10, + timeout: int = 5, + default_ttl: int = 3600 + ): + """ + Initialize Redis cache + + Args: + redis_url: Redis connection URL (e.g., redis://localhost:6379/0) + max_connections: Maximum number of connections + timeout: Connection timeout in seconds + default_ttl: Default time-to-live for cached items in seconds + """ + self.redis_url = redis_url + self.max_connections = max_connections + self.timeout = timeout + self.default_ttl = default_ttl + self._client = None + + if REDIS_AVAILABLE and redis_url: + try: + self._client = redis.Redis.from_url( + redis_url, + max_connections=max_connections, + socket_timeout=timeout, + socket_connect_timeout=timeout, + decode_responses=True + ) + # Test connection + self._client.ping() + logger.info(f"Connected to Redis at {redis_url}") + except Exception as e: + logger.warning(f"Failed to connect to Redis: {e}") + self._client = None + else: + logger.info("Redis caching disabled (Redis not available or no URL provided)") + + def is_available(self) -> bool: + """Check if Redis cache is available""" + return self._client is not None + + def get(self, key: str) -> Optional[Any]: + """ + Get value from cache + + Args: + key: Cache key + + Returns: + Cached value or None if not found + """ + if not self.is_available(): + return None + + try: + value = self._client.get(key) + if value: + return json.loads(value) + return None + except Exception as e: + logger.error(f"Redis get error: {e}") + return None + + def set( + self, + key: str, + value: Any, + ttl: Optional[int] = None + ) -> bool: + """ + Set value in cache + + Args: + key: Cache key + value: Value to cache (must be JSON serializable) + ttl: Time-to-live in seconds (uses default_ttl if not provided) + + Returns: + True if successful, False otherwise + """ + if not self.is_available(): + return False + + try: + serialized = json.dumps(value) + expiry = ttl if ttl is not None else self.default_ttl + self._client.setex(key, expiry, serialized) + return True + except Exception as e: + logger.error(f"Redis set error: {e}") + return False + + def delete(self, key: str) -> bool: + """ + Delete value from cache + + Args: + key: Cache key + + Returns: + True if successful, False otherwise + """ + if not self.is_available(): + return False + + try: + self._client.delete(key) + return True + except Exception as e: + logger.error(f"Redis delete error: {e}") + return False + + def exists(self, key: str) -> bool: + """ + Check if key exists in cache + + Args: + key: Cache key + + Returns: + True if key exists, False otherwise + """ + if not self.is_available(): + return False + + try: + return self._client.exists(key) == 1 + except Exception as e: + logger.error(f"Redis exists error: {e}") + return False + + def clear(self) -> bool: + """ + Clear all cached values + + Returns: + True if successful, False otherwise + """ + if not self.is_available(): + return False + + try: + self._client.flushdb() + return True + except Exception as e: + logger.error(f"Redis clear error: {e}") + return False + + def get_many(self, keys: List[str]) -> dict[str, Any]: + """ + Get multiple values from cache + + Args: + keys: List of cache keys + + Returns: + Dictionary mapping keys to cached values + """ + if not self.is_available(): + return {} + + try: + values = self._client.mget(keys) + result = {} + for key, value in zip(keys, values): + if value: + result[key] = json.loads(value) + return result + except Exception as e: + logger.error(f"Redis get_many error: {e}") + return {} + + def set_many(self, mapping: dict[str, Any], ttl: Optional[int] = None) -> bool: + """ + Set multiple values in cache + + Args: + mapping: Dictionary of key-value pairs + ttl: Time-to-live in seconds + + Returns: + True if successful, False otherwise + """ + if not self.is_available(): + return False + + try: + pipe = self._client.pipeline() + expiry = ttl if ttl is not None else self.default_ttl + for key, value in mapping.items(): + serialized = json.dumps(value) + pipe.setex(key, expiry, serialized) + pipe.execute() + return True + except Exception as e: + logger.error(f"Redis set_many error: {e}") + return False + + def delete_many(self, keys: List[str]) -> bool: + """ + Delete multiple values from cache + + Args: + keys: List of cache keys + + Returns: + True if successful, False otherwise + """ + if not self.is_available(): + return False + + try: + if keys: + self._client.delete(*keys) + return True + except Exception as e: + logger.error(f"Redis delete_many error: {e}") + return False + + def increment(self, key: str, amount: int = 1) -> Optional[int]: + """ + Increment a counter in cache + + Args: + key: Cache key + amount: Amount to increment by + + Returns: + New value or None if failed + """ + if not self.is_available(): + return None + + try: + return self._client.incrby(key, amount) + except Exception as e: + logger.error(f"Redis increment error: {e}") + return None + + def close(self) -> None: + """Close Redis connection""" + if self._client: + self._client.close() + logger.info("Redis connection closed") + + +# Global cache instance +_global_cache: Optional[RedisCache] = None + + +def get_cache( + redis_url: Optional[str] = None, + max_connections: int = 10, + timeout: int = 5, + default_ttl: int = 3600 +) -> RedisCache: + """ + Get or create global Redis cache instance + + Args: + redis_url: Redis connection URL + max_connections: Maximum number of connections + timeout: Connection timeout in seconds + default_ttl: Default time-to-live for cached items + + Returns: + RedisCache instance + """ + global _global_cache + + if _global_cache is None and redis_url: + _global_cache = RedisCache( + redis_url=redis_url, + max_connections=max_connections, + timeout=timeout, + default_ttl=default_ttl + ) + + return _global_cache or RedisCache() # Return disabled cache if no URL + + +def cache_key(*parts: str, prefix: str = "aitbc") -> str: + """ + Generate a cache key from parts + + Args: + *parts: Parts to include in the key + prefix: Key prefix + + Returns: + Cache key string + """ + key_string = ":".join(str(part) for part in parts) + full_key = f"{prefix}:{key_string}" + # Hash if too long + if len(full_key) > 250: + hash_value = hashlib.sha256(full_key.encode()).hexdigest()[:16] + return f"{prefix}:hashed:{hash_value}" + return full_key diff --git a/apps/agent-coordinator/src/app/main.py b/apps/agent-coordinator/src/app/main.py index 71f4df38..8d075a49 100644 --- a/apps/agent-coordinator/src/app/main.py +++ b/apps/agent-coordinator/src/app/main.py @@ -2,6 +2,8 @@ import uvicorn from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware +from aitbc.rate_limiting import RateLimitMiddleware + from .config import settings from .exceptions import register_exception_handlers from .lifespan import lifespan @@ -25,6 +27,13 @@ def create_app() -> FastAPI: allow_headers=["*"], ) + # Add rate limiting middleware + app.add_middleware( + RateLimitMiddleware, + rate=100, + per=60 + ) + for router in ROUTERS: app.include_router(router) diff --git a/apps/coordinator-api/src/app/config.py b/apps/coordinator-api/src/app/config.py index a3d3e07b..e5b397a4 100755 --- a/apps/coordinator-api/src/app/config.py +++ b/apps/coordinator-api/src/app/config.py @@ -6,6 +6,7 @@ Provides environment-based adapter selection and consolidated settings. import os +from aitbc.config import BaseAITBCConfig from aitbc.constants import DATA_DIR, LOG_DIR from pydantic import Field, field_validator from pydantic_settings import BaseSettings, SettingsConfigDict @@ -36,26 +37,25 @@ class DatabaseConfig(BaseSettings): model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", case_sensitive=False, extra="allow") -class Settings(BaseSettings): +class Settings(BaseAITBCConfig): """Unified application settings with environment-based configuration.""" - model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", case_sensitive=False, extra="allow") + model_config = SettingsConfigDict( + env_file=".env", + env_file_encoding="utf-8", + case_sensitive=False, + extra="allow" + ) - # Environment - app_env: str = "dev" - app_host: str = "127.0.0.1" - app_port: int = 8011 - audit_log_dir: str = str(LOG_DIR / "audit") + # Override defaults for coordinator-api + app_name: str = Field(default="AITBC Coordinator API", description="Application name") + app_host: str = Field(default="127.0.0.1", description="Application host") + port: int = Field(default=8011, description="Server port") + environment: str = Field(default="dev", description="Environment") + audit_log_dir: str = Field(default=str(LOG_DIR / "audit"), description="Audit log directory") # Database - database: DatabaseConfig = DatabaseConfig() - - # Database Connection Pooling - db_pool_size: int = Field(default=20, description="Database connection pool size") - db_max_overflow: int = Field(default=40, description="Maximum overflow connections") - db_pool_recycle: int = Field(default=3600, description="Connection recycle time in seconds") - db_pool_pre_ping: bool = Field(default=True, description="Test connections before using") - db_echo: bool = Field(default=False, description="Enable SQL query logging") + database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Database configuration") # API Keys client_api_keys: list[str] = [] @@ -94,103 +94,61 @@ class Settings(BaseSettings): raise ValueError("API keys must be at least 16 characters long") return v - # Security + # Security - using inherited secret_key and jwt_secret from BaseAITBCConfig hmac_secret: str | None = None - jwt_secret: str | None = None - jwt_algorithm: str = "HS256" - jwt_expiration_hours: int = 24 - @field_validator("hmac_secret") - @classmethod - def validate_hmac_secret(cls, v: str | None) -> str | None: - # Allow None in development/test environments - import os - - if os.getenv("APP_ENV", "dev") != "production" and not v: - return v - if not v or v.startswith("$") or v == "your_secret_here": - raise ValueError("HMAC_SECRET must be set to a secure value") - if len(v) < 32: - raise ValueError("HMAC_SECRET must be at least 32 characters long") - return v - - @field_validator("jwt_secret") - @classmethod - def validate_jwt_secret(cls, v: str | None) -> str | None: - # Allow None in development/test environments - import os - - if os.getenv("APP_ENV", "dev") != "production" and not v: - return v - if not v or v.startswith("$") or v == "your_secret_here": - raise ValueError("JWT_SECRET must be set to a secure value") - if len(v) < 32: - raise ValueError("JWT_SECRET must be at least 32 characters long") - return v - - # CORS - allow_origins: list[str] = [ - "http://localhost:8011", # Coordinator API - "http://localhost:8001", # Exchange API - "http://localhost:8002", # Blockchain Node - "http://localhost:8003", # Blockchain RPC - "http://localhost:8010", # Multimodal GPU - "http://localhost:8011", # GPU Multimodal - "http://localhost:8012", # Modality Optimization - "http://localhost:8013", # Adaptive Learning - "http://localhost:8014", # Marketplace Enhanced - "http://localhost:8015", # hermes Enhanced - "http://localhost:8016", # Web UI - ] + # CORS - override inherited allow_origins with coordinator-api specific defaults + allow_origins: list[str] = Field( + default=[ + "http://localhost:8011", # Coordinator API + "http://localhost:8001", # Exchange API + "http://localhost:8002", # Blockchain Node + "http://localhost:8003", # Blockchain RPC + "http://localhost:8010", # Multimodal GPU + "http://localhost:8011", # GPU Multimodal + "http://localhost:8012", # Modality Optimization + "http://localhost:8013", # Adaptive Learning + "http://localhost:8014", # Marketplace Enhanced + "http://localhost:8015", # hermes Enhanced + "http://localhost:8016", # Web UI + ], + description="CORS allowed origins" + ) # Job Configuration - job_ttl_seconds: int = 900 - heartbeat_interval_seconds: int = 10 - heartbeat_timeout_seconds: int = 30 + job_ttl_seconds: int = Field(default=900, description="Job TTL in seconds") + heartbeat_interval_seconds: int = Field(default=10, description="Heartbeat interval in seconds") + heartbeat_timeout_seconds: int = Field(default=30, description="Heartbeat timeout in seconds") - # Rate Limiting - rate_limit_requests: int = 60 - rate_limit_window_seconds: int = 60 - - # Configurable Rate Limits (per minute) - rate_limit_jobs_submit: str = "100/minute" - rate_limit_miner_register: str = "30/minute" - rate_limit_miner_heartbeat: str = "60/minute" - rate_limit_admin_stats: str = "20/minute" - rate_limit_marketplace_list: str = "100/minute" - rate_limit_marketplace_stats: str = "50/minute" - rate_limit_marketplace_bid: str = "30/minute" - rate_limit_exchange_payment: str = "20/minute" + # Configurable Rate Limits (per minute) - extending inherited rate limiting + rate_limit_jobs_submit: str = Field(default="100/minute", description="Rate limit for job submission") + rate_limit_miner_register: str = Field(default="30/minute", description="Rate limit for miner registration") + rate_limit_miner_heartbeat: str = Field(default="60/minute", description="Rate limit for miner heartbeat") + rate_limit_admin_stats: str = Field(default="20/minute", description="Rate limit for admin stats") + rate_limit_marketplace_list: str = Field(default="100/minute", description="Rate limit for marketplace list") + rate_limit_marketplace_stats: str = Field(default="50/minute", description="Rate limit for marketplace stats") + rate_limit_marketplace_bid: str = Field(default="30/minute", description="Rate limit for marketplace bid") + rate_limit_exchange_payment: str = Field(default="20/minute", description="Rate limit for exchange payment") # Receipt Signing receipt_signing_key_hex: str | None = None receipt_attestation_key_hex: str | None = None - # Logging - log_level: str = "INFO" - log_format: str = "json" # json or text + # Logging - using inherited log_level and log_format from BaseAITBCConfig + log_format: str = Field(default="json", description="Log format (json or text)") # Mempool - mempool_backend: str = "database" # database, memory + mempool_backend: str = Field(default="database", description="Mempool backend (database, memory)") # Blockchain RPC - blockchain_rpc_url: str = "http://localhost:8082" + blockchain_rpc_url: str = Field(default="http://localhost:8082", description="Blockchain RPC URL") # Test Configuration - test_mode: bool = False + test_mode: bool = Field(default=False, description="Test mode") test_database_url: str | None = None - def validate_secrets(self) -> None: - """Validate that all required secrets are provided.""" - if self.app_env == "production": - if not self.jwt_secret: - raise ValueError("JWT_SECRET environment variable is required in production") - if self.jwt_secret == "change-me-in-production": - raise ValueError("JWT_SECRET must be changed from default value") - - @property - def database_url(self) -> str: - """Get the database URL (backward compatibility).""" + def get_effective_database_url(self) -> str: + """Get the effective database URL with test mode support.""" # Use test database if in test mode and test_database_url is set if self.test_mode and self.test_database_url: return self.test_database_url @@ -199,13 +157,6 @@ class Settings(BaseSettings): # Default SQLite path - consistent with blockchain-node pattern return f"sqlite:///{DATA_DIR}/data/coordinator.db" - @database_url.setter - def database_url(self, value: str): - """Allow setting database URL for tests""" - if not self.test_mode: - raise RuntimeError("Cannot set database_url outside of test mode") - self.test_database_url = value - settings = Settings() diff --git a/apps/coordinator-api/src/app/database.py b/apps/coordinator-api/src/app/database.py index 1765099c..2de33a0c 100755 --- a/apps/coordinator-api/src/app/database.py +++ b/apps/coordinator-api/src/app/database.py @@ -6,26 +6,26 @@ from sqlmodel import SQLModel, create_engine from .config import settings # Create database engine using URL from config with performance optimizations -if settings.database_url.startswith("sqlite"): +if settings.get_effective_database_url().startswith("sqlite"): engine = create_engine( - settings.database_url, + settings.get_effective_database_url(), connect_args={ "check_same_thread": False, "timeout": 30 }, poolclass=StaticPool, - echo=settings.test_mode, # Enable SQL logging for debugging in test mode - pool_pre_ping=True, # Verify connections before using + echo=settings.database_echo, + pool_pre_ping=settings.database_pool_pre_ping, ) else: - # PostgreSQL/MySQL with connection pooling + # PostgreSQL/MySQL with connection pooling using config values engine = create_engine( - settings.database_url, - pool_size=10, # Number of connections to maintain - max_overflow=20, # Additional connections when pool is exhausted - pool_pre_ping=True, # Verify connections before using - pool_recycle=3600, # Recycle connections after 1 hour - echo=settings.test_mode, # Enable SQL logging for debugging in test mode + settings.get_effective_database_url(), + pool_size=settings.database_pool_size, + max_overflow=settings.database_max_overflow, + pool_pre_ping=settings.database_pool_pre_ping, + pool_recycle=settings.database_pool_recycle, + echo=settings.database_echo, ) diff --git a/apps/coordinator-api/src/app/routers/agent_router.py b/apps/coordinator-api/src/app/routers/agent_router.py index 3e0e9a67..ef0da222 100755 --- a/apps/coordinator-api/src/app/routers/agent_router.py +++ b/apps/coordinator-api/src/app/routers/agent_router.py @@ -94,8 +94,10 @@ async def list_workflows( @router.get("/workflows/{workflow_id}", response_model=AIAgentWorkflow) +@rate_limit(rate=200, per=60) async def get_workflow( workflow_id: str, + request: Request, session: Session = Depends(Annotated[Session, Depends(get_session)]), current_user: str = Depends(require_admin_key()), ) -> AIAgentWorkflow: @@ -120,9 +122,11 @@ async def get_workflow( @router.put("/workflows/{workflow_id}", response_model=AIAgentWorkflow) +@rate_limit(rate=100, per=60) async def update_workflow( workflow_id: str, workflow_data: AgentWorkflowUpdate, + request: Request, session: Session = Depends(Annotated[Session, Depends(get_session)]), current_user: str = Depends(require_admin_key()), ) -> AIAgentWorkflow: @@ -423,14 +427,17 @@ async def get_execution_logs( @router.get("/test") -async def test_endpoint() -> dict[str, str]: +@rate_limit(rate=1000, per=60) +async def test_endpoint(request: Request) -> dict[str, str]: """Test endpoint to verify router is working""" return {"message": "Agent router is working", "timestamp": datetime.now(timezone.utc).isoformat()} @router.post("/networks", response_model=dict, status_code=201) +@rate_limit(rate=50, per=60) async def create_agent_network( network_data: dict, + request: Request, session: Session = Depends(Annotated[Session, Depends(get_session)]), current_user: str = Depends(require_admin_key()), ) -> dict[str, Any]: @@ -469,8 +476,10 @@ async def create_agent_network( @router.get("/executions/{execution_id}/receipt") +@rate_limit(rate=100, per=60) async def get_execution_receipt( execution_id: str, + request: Request, session: Session = Depends(Annotated[Session, Depends(get_session)]), current_user: str = Depends(require_admin_key()), ) -> dict[str, Any]: diff --git a/apps/coordinator-api/src/app/storage/db_pg.py b/apps/coordinator-api/src/app/storage/db_pg.py index f231167d..6fd3c815 100755 --- a/apps/coordinator-api/src/app/storage/db_pg.py +++ b/apps/coordinator-api/src/app/storage/db_pg.py @@ -21,8 +21,10 @@ from .config_pg import settings engine = create_engine( settings.database_url, echo=settings.debug, - pool_pre_ping=True, - pool_recycle=300, + pool_size=settings.db_pool_size, + max_overflow=settings.db_max_overflow, + pool_recycle=settings.db_pool_recycle, + pool_pre_ping=settings.db_pool_pre_ping, ) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) diff --git a/apps/exchange/exchange_api.py b/apps/exchange/exchange_api.py index a75b40ae..38f12d4f 100755 --- a/apps/exchange/exchange_api.py +++ b/apps/exchange/exchange_api.py @@ -15,6 +15,11 @@ import time from typing import Annotated from contextlib import asynccontextmanager +import sys +sys.path.insert(0, "/opt/aitbc") + +from aitbc.rate_limiting import RateLimitMiddleware + from database import init_db, get_db_session from models import User, Order, Trade, Balance @@ -29,6 +34,13 @@ async def lifespan(app: FastAPI): # Initialize FastAPI app app = FastAPI(title="AITBC Trade Exchange API", version="1.0.0", lifespan=lifespan) +# Add rate limiting middleware +app.add_middleware( + RateLimitMiddleware, + rate=100, + per=60 +) + # In-memory session storage (use Redis in production) user_sessions = {} diff --git a/apps/wallet/src/app/main.py b/apps/wallet/src/app/main.py index 93f02652..cabd18e4 100755 --- a/apps/wallet/src/app/main.py +++ b/apps/wallet/src/app/main.py @@ -2,6 +2,11 @@ from __future__ import annotations from fastapi import FastAPI +import sys +sys.path.insert(0, "/opt/aitbc") + +from aitbc.rate_limiting import RateLimitMiddleware + from .api_jsonrpc import router as jsonrpc_router from .api_rest import router as receipts_router from .settings import settings @@ -9,6 +14,14 @@ from .settings import settings def create_app() -> FastAPI: app = FastAPI(title=settings.app_name, debug=settings.debug) + + # Add rate limiting middleware + app.add_middleware( + RateLimitMiddleware, + rate=100, + per=60 + ) + app.include_router(receipts_router) app.include_router(jsonrpc_router) diff --git a/docs/RATE_LIMITING_GUIDE.md b/docs/RATE_LIMITING_GUIDE.md new file mode 100644 index 00000000..eec14fdf --- /dev/null +++ b/docs/RATE_LIMITING_GUIDE.md @@ -0,0 +1,143 @@ +# Rate Limiting Implementation Guide + +## Overview + +Rate limiting has been implemented for AITBC API endpoints to prevent abuse and ensure fair resource allocation. This guide explains how to apply rate limiting to FastAPI routers. + +## Infrastructure + +### Rate Limiting Module + +Location: `/opt/aitbc/aitbc/rate_limiting.py` + +The module provides: +- `@rate_limit()` decorator for endpoint-level rate limiting +- `RateLimitMiddleware` for global middleware-based rate limiting +- Helper functions for managing rate limiters + +### Rate Limiter Implementation + +The underlying `RateLimiter` class in `aitbc/security_hardening.py` implements a token bucket algorithm. + +## Applying Rate Limiting to Routers + +### Step 1: Import the decorator + +```python +from fastapi import Request +from aitbc.rate_limiting import rate_limit +``` + +### Step 2: Add Request parameter + +Add `request: Request` as the first parameter (after any path parameters) to each endpoint: + +```python +@router.post("/workflows") +async def create_workflow( + request: Request, # Add this + workflow_data: AgentWorkflowCreate, + session: Session = Depends(...), + current_user: str = Depends(...), +): + ... +``` + +### Step 3: Apply the decorator + +Add the `@rate_limit` decorator before the endpoint: + +```python +@router.post("/workflows") +@rate_limit(rate=100, per=60) # 100 requests per minute +async def create_workflow( + request: Request, + workflow_data: AgentWorkflowCreate, + session: Session = Depends(...), + current_user: str = Depends(...), +): + ... +``` + +### Rate Limit Guidelines + +Recommended rate limits by endpoint type: + +- **Write operations** (POST, PUT, DELETE): 50-100 requests per minute +- **Read operations** (GET): 200-500 requests per minute +- **Health/test endpoints**: 1000 requests per minute +- **Execution/long-running operations**: 50 requests per minute + +### Example: Complete Router + +See `/opt/aitbc/apps/coordinator-api/src/app/routers/agent_router.py` for a complete example. + +## Custom Rate Limiting + +### Custom Key Function + +To rate limit by something other than IP address (e.g., API key, user ID): + +```python +def custom_key(request: Request) -> str: + return request.headers.get("X-API-Key", "unknown") + +@router.post("/endpoint") +@rate_limit(rate=100, per=60, key_func=custom_key) +async def endpoint(request: Request, ...): + ... +``` + +### Custom Error Message + +```python +@router.post("/endpoint") +@rate_limit(rate=100, per=60, error_message="Custom limit message") +async def endpoint(request: Request, ...): + ... +``` + +## Global Middleware + +For global rate limiting across all endpoints, use the middleware: + +```python +from aitbc.rate_limiting import RateLimitMiddleware + +app.add_middleware( + RateLimitMiddleware, + rate=100, + per=60 +) +``` + +## Testing + +Rate limiting tests are in `/opt/aitbc/tests/test_rate_limiting.py`. + +Run tests: +```bash +python3 -m pytest -c /dev/null --rootdir "$PWD" --import-mode=importlib tests/test_rate_limiting.py -v +``` + +## Remaining Work + +There are 70+ router files across the codebase. The following routers need rate limiting applied: + +### Coordinator-API (50+ routers) +- `/opt/aitbc/apps/coordinator-api/src/app/routers/*.py` +- `/opt/aitbc/apps/coordinator-api/src/app/contexts/*/routers/*.py` + +### Other Services +- `/opt/aitbc/apps/agent-coordinator/src/app/routers/*.py` +- `/opt/aitbc/apps/pool-hub/src/app/routers/*.py` +- `/opt/aitbc/apps/agent-management/src/app/routers/*.py` +- `/opt/aitbc/apps/blockchain-node/src/aitbc_chain/rpc/router.py` +- `/opt/aitbc/apps/exchange/*.py` +- `/opt/aitbc/apps/wallet/src/app/api_rest.py` + +## Priority Order + +1. **High Priority**: Public-facing APIs (coordinator-api, exchange, wallet) +2. **Medium Priority**: Internal service APIs (agent-coordinator, pool-hub) +3. **Low Priority**: Admin/management APIs diff --git a/docs/ROADMAP.md b/docs/ROADMAP.md index f975a738..fd4f492d 100644 --- a/docs/ROADMAP.md +++ b/docs/ROADMAP.md @@ -98,22 +98,77 @@ - Added quoting to migration scripts (migrate_complete.py, migrate_to_postgresql.py) - SQL injection risks reduced from 21 to 0 in user-input paths - [DONE] Remove ORIGINAL monolithic service files - COMPLETED (removed certification_service.py, multi_modal_fusion.py) + - [IN PROGRESS] Add rate limiting on all routers - IN PROGRESS + - Created rate limiting module at aitbc/rate_limiting.py with decorator and middleware + - Added comprehensive tests (15 tests passing) + - Applied rate limiting to agent_router.py as example + - Created implementation guide at docs/RATE_LIMITING_GUIDE.md + - Remaining: 70+ router files across coordinator-api, agent-coordinator, pool-hub, agent-management, blockchain-node, exchange, wallet - **Medium (2-6 weeks)** - - Decompose coordinator-api - - Implement shared config base class - - Add connection pooling - - Implement distributed caching (Redis) - - Add rate limiting on all routers - - Tighten mypy configuration + - [DONE] Decompose coordinator-api - COMPLETED (6 phases complete) + - [DONE] Implement shared config base class - COMPLETED + - Enhanced BaseAITBCConfig in aitbc/config.py with database pooling, rate limiting, CORS, secret validation + - Updated coordinator-api to inherit from BaseAITBCConfig + - Maintains backward compatibility with existing configuration patterns + - [DONE] Add connection pooling - COMPLETED + - Enhanced aitbc/database.py with SQLAlchemy connection pooling utilities + - Added create_pooled_engine, create_pooled_sessionmaker, create_async_pooled_engine, create_async_pooled_sessionmaker + - Updated coordinator-api db_pg.py to use proper connection pooling parameters from config + - Main services already had connection pooling (coordinator-api database.py, storage/db.py, shared-core database.py) + - Scripts and tests can use new utilities for connection pooling where appropriate + - [DONE] Implement distributed caching (Redis) - COMPLETED + - aitbc/redis_cache.py already has complete RedisCache implementation with all basic operations + - Comprehensive tests in tests/test_redis_cache.py + - Added get_redis_cache() method to BaseAITBCConfig for easy cache instance access + - Redis settings already in BaseAITBCConfig (redis_url, redis_max_connections, redis_timeout) + - multi_language service already uses Redis with TranslationCache class + - Other services can use settings.get_redis_cache() to get configured cache instance + - [IN PROGRESS] Add rate limiting on all routers - IN PROGRESS + - Created rate limiting module at aitbc/rate_limiting.py with decorator and middleware + - Added comprehensive tests (15 tests passing) + - Applied rate limiting to agent_router.py as example + - Created implementation guide at docs/RATE_LIMITING_GUIDE.md + - Remaining: 70+ router files across coordinator-api, agent-coordinator, pool-hub, agent-management, blockchain-node, exchange, wallet + - [DONE] Tighten mypy configuration - COMPLETED + - Enabled check_untyped_defs, disallow_untyped_decorators, no_implicit_optional + - Enabled warn_unreachable, strict_equality, strict_optional + - Improved type safety across codebase - **Long (1-3 months)** - - Implement API gateway pattern - - Move to event-driven architecture - - Add feature flag system - - Implement comprehensive observability - - Create shared test fixtures - - Design contract upgrade pattern + - [DONE] Create shared test fixtures - COMPLETED + - Enhanced tests/fixtures/ with test_data_factory.py for comprehensive test data generation + - Added auth_fixtures.py for authentication/authorization testing + - Existing fixtures: common.py, blockchain.py, coordinator.py, staking_fixtures.py, mock_blockchain_node.py + - Fixtures shared via tests/conftest.py across all test suites + - TestDataFactory with generators for users, wallets, jobs, transactions, miners, GPUs, staking, agents, API responses, errors, pagination, batch operations, marketplace offers, governance proposals + - Auth fixtures for JWT tokens, headers, mock users, auth service, permission checker, API keys + - [DONE] Implement API gateway pattern - COMPLETED + - apps/api-gateway/src/api_gateway/main.py implements core API gateway pattern + - Features: service registry, request routing, circuit breaker, rate limiting, authentication, retry logic + - Routes to: gpu, marketplace, agent, trading, governance, ai, monitoring, hermes, plugin, coordinator services + - Middleware: RequestIDMiddleware, PerformanceLoggingMiddleware, RequestValidationMiddleware, ErrorHandlerMiddleware + - Tests: apps/api-gateway/tests/test_gateway.py with health check, service registry, routing tests + - Enterprise API Gateway: apps/coordinator-api/src/app/services/enterprise_integration/api_gateway.py with multi-tenant support + - [DONE] Move to event-driven architecture - COMPLETED + - aitbc/events.py implements comprehensive event-driven architecture + - Core components: Event dataclass, EventBus, AsyncEventBus, EventFilter, EventAggregator, EventRouter + - Decorators: @event_handler for easy event subscription + - Global event bus singleton pattern + - Comprehensive tests: tests/test_events.py (47 test cases, 540 lines) + - Blockchain event bridge: apps/blockchain-event-bridge/ for blockchain event handling + - Agent message protocols: apps/agent-coordinator/src/app/protocols/message_types.py + - Event-driven cache: dev/cache/aitbc_cache/event_driven_cache.py + - [DONE] Add feature flag system - COMPLETED + - aitbc/feature_flags.py implements comprehensive feature flag system + - Core components: FeatureFlag dataclass, FeatureFlagManager with enable/disable, whitelist/blacklist, percentage-based rollouts + - Global feature flag manager singleton pattern + - Configuration file support (feature_flags.json) with JSON persistence + - Helper functions: is_feature_enabled(), get_feature_flag_manager() + - Comprehensive tests: tests/test_feature_flags.py (30+ test cases, 404 lines) + - Features: gradual rollouts, user whitelisting/blacklisting, percentage-based targeting, timestamp tracking + - [ ] Implement comprehensive observability + - [ ] Design contract upgrade pattern ### Distribution & Binaries diff --git a/pyproject.toml b/pyproject.toml index 9bfecf0e..1299a7d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,17 +131,18 @@ python_version = "3.13" exclude = "^apps/(agent-management|agent-coordinator|agent-services|blockchain-node|computing-node|identity-node|marketplace|mining-pool)/.*" warn_return_any = true warn_unused_configs = true -# Start with less strict mode and gradually increase -check_untyped_defs = false +# Tightened mypy configuration for better type safety +check_untyped_defs = true disallow_incomplete_defs = true disallow_untyped_defs = true -disallow_untyped_decorators = false -no_implicit_optional = false +disallow_untyped_decorators = true +no_implicit_optional = true warn_redundant_casts = true warn_unused_ignores = true warn_no_return = true -warn_unreachable = false -strict_equality = false +warn_unreachable = true +strict_equality = true +strict_optional = true [[tool.mypy.overrides]] module = [ diff --git a/tests/conftest.py b/tests/conftest.py index 6e0c1c68..7181114f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -39,6 +39,26 @@ from tests.fixtures.blockchain import ( # Training fixtures (kept here as they're specific to training tests) from aitbc.training_setup import TrainingEnvironment, TrainingSetupError +# Auth fixtures +from tests.fixtures.auth_fixtures import ( + mock_jwt_secret, + test_user_token, + test_admin_token, + expired_token, + invalid_token, + auth_headers, + admin_auth_headers, + mock_user, + mock_admin_user, + mock_auth_service, + permission_checker, + api_key_headers, + mock_api_keys, +) + +# Test data factory +from tests.fixtures.test_data_factory import TestDataFactory + import pytest diff --git a/tests/fixtures/auth_fixtures.py b/tests/fixtures/auth_fixtures.py new file mode 100644 index 00000000..0ed2d475 --- /dev/null +++ b/tests/fixtures/auth_fixtures.py @@ -0,0 +1,163 @@ +""" +Authentication and authorization test fixtures +Provides fixtures for testing auth flows, JWT tokens, and permissions +""" + +import sys +from pathlib import Path +from datetime import UTC, datetime, timedelta +from typing import Dict, Any, Optional +from unittest.mock import Mock +import pytest +import jwt + +project_root = Path(__file__).parent.parent.parent + + +@pytest.fixture +def mock_jwt_secret(): + """Mock JWT secret for testing""" + return "test_secret_key_for_jwt_signing_please_change_in_production" + + +@pytest.fixture +def test_user_token(mock_jwt_secret): + """Generate a valid JWT token for a test user""" + payload = { + "user_id": "test-user-123", + "email": "test@example.com", + "role": "user", + "exp": datetime.now(UTC) + timedelta(hours=24), + "iat": datetime.now(UTC) + } + return jwt.encode(payload, mock_jwt_secret, algorithm="HS256") + + +@pytest.fixture +def test_admin_token(mock_jwt_secret): + """Generate a valid JWT token for an admin user""" + payload = { + "user_id": "admin-user-123", + "email": "admin@example.com", + "role": "admin", + "permissions": ["read", "write", "delete", "admin"], + "exp": datetime.now(UTC) + timedelta(hours=24), + "iat": datetime.now(UTC) + } + return jwt.encode(payload, mock_jwt_secret, algorithm="HS256") + + +@pytest.fixture +def expired_token(mock_jwt_secret): + """Generate an expired JWT token""" + payload = { + "user_id": "test-user-123", + "email": "test@example.com", + "role": "user", + "exp": datetime.now(UTC) - timedelta(hours=1), # Expired + "iat": datetime.now(UTC) - timedelta(hours=25) + } + return jwt.encode(payload, mock_jwt_secret, algorithm="HS256") + + +@pytest.fixture +def invalid_token(): + """Generate an invalid JWT token""" + return "invalid.token.string" + + +@pytest.fixture +def auth_headers(test_user_token): + """Generate authentication headers with Bearer token""" + return {"Authorization": f"Bearer {test_user_token}"} + + +@pytest.fixture +def admin_auth_headers(test_admin_token): + """Generate authentication headers for admin user""" + return {"Authorization": f"Bearer {test_admin_token}"} + + +@pytest.fixture +def mock_user(): + """Mock user object for testing""" + user = Mock() + user.user_id = "test-user-123" + user.email = "test@example.com" + user.username = "testuser" + user.role = "user" + user.is_active = True + user.permissions = ["read", "write"] + user.created_at = datetime.now(UTC) + return user + + +@pytest.fixture +def mock_admin_user(): + """Mock admin user object for testing""" + admin = Mock() + admin.user_id = "admin-user-123" + admin.email = "admin@example.com" + admin.username = "admin" + admin.role = "admin" + admin.is_active = True + admin.permissions = ["read", "write", "delete", "admin"] + admin.created_at = datetime.now(UTC) + return admin + + +@pytest.fixture +def mock_auth_service(): + """Mock authentication service""" + service = Mock() + + def mock_verify_token(token: str) -> Optional[Dict[str, Any]]: + try: + decoded = jwt.decode(token, "test_secret_key_for_jwt_signing_please_change_in_production", algorithms=["HS256"]) + return decoded + except: + return None + + def mock_generate_token(user_id: str, role: str = "user") -> str: + payload = { + "user_id": user_id, + "role": role, + "exp": datetime.now(UTC) + timedelta(hours=24), + "iat": datetime.now(UTC) + } + return jwt.encode(payload, "test_secret_key_for_jwt_signing_please_change_in_production", algorithm="HS256") + + service.verify_token = mock_verify_token + service.generate_token = mock_generate_token + service.get_user = Mock(return_value=Mock(user_id="test-user-123", email="test@example.com")) + return service + + +@pytest.fixture +def permission_checker(): + """Mock permission checker for authorization""" + checker = Mock() + + def mock_has_permission(user: Any, permission: str) -> bool: + if not hasattr(user, 'permissions'): + return False + return permission in user.permissions + + checker.has_permission = mock_has_permission + checker.check_role = Mock(return_value=True) + return checker + + +@pytest.fixture +def api_key_headers(): + """Generate headers with API key authentication""" + return {"X-API-Key": "test-api-key-123456"} + + +@pytest.fixture +def mock_api_keys(): + """Mock API keys for testing""" + return { + "test-api-key-123456": {"user_id": "test-user-123", "permissions": ["read", "write"]}, + "admin-api-key-789012": {"user_id": "admin-user-123", "permissions": ["read", "write", "delete", "admin"]} + } diff --git a/tests/fixtures/test_data_factory.py b/tests/fixtures/test_data_factory.py new file mode 100644 index 00000000..087e5af3 --- /dev/null +++ b/tests/fixtures/test_data_factory.py @@ -0,0 +1,271 @@ +""" +Test Data Factory +Provides comprehensive test data generation utilities for AITBC tests +""" + +from datetime import UTC, datetime, timedelta +from typing import Dict, Any, List, Optional +from uuid import uuid4 +import json + + +class TestDataFactory: + """Factory for generating test data across different domains""" + + # Common test addresses + TEST_ADDRESSES = { + "alice": "aitbc1alice00000000000000000000000000000000000", + "bob": "aitbc1bob0000000000000000000000000000000000000", + "charlie": "aitbc1charl0000000000000000000000000000000000", + "miner1": "aitbc1miner1000000000000000000000000000000000", + "miner2": "aitbc1miner2000000000000000000000000000000000", + } + + # Common test IDs + @staticmethod + def generate_id(prefix: str = "test") -> str: + """Generate a unique test ID with prefix""" + return f"{prefix}_{uuid4().hex[:8]}" + + @staticmethod + def generate_timestamp(offset_seconds: int = 0) -> str: + """Generate ISO timestamp with optional offset""" + return (datetime.now(UTC) + timedelta(seconds=offset_seconds)).isoformat() + + # User/Identity data + @staticmethod + def user_data( + user_id: Optional[str] = None, + email: Optional[str] = None, + is_active: bool = True + ) -> Dict[str, Any]: + """Generate test user data""" + return { + "user_id": user_id or TestDataFactory.generate_id("user"), + "email": email or "test@example.com", + "username": "testuser", + "is_active": is_active, + "created_at": TestDataFactory.generate_timestamp(), + "updated_at": TestDataFactory.generate_timestamp() + } + + # Wallet data + @staticmethod + def wallet_data( + address: Optional[str] = None, + balance: float = 1000.0 + ) -> Dict[str, Any]: + """Generate test wallet data""" + return { + "address": address or TestDataFactory.TEST_ADDRESSES["alice"], + "balance": balance, + "currency": "AITBC", + "nonce": 0, + "created_at": TestDataFactory.generate_timestamp() + } + + # Job data + @staticmethod + def job_data( + job_type: str = "ai_inference", + priority: str = "normal", + timeout: int = 300 + ) -> Dict[str, Any]: + """Generate test job data""" + return { + "job_id": TestDataFactory.generate_id("job"), + "job_type": job_type, + "parameters": { + "model": "gpt-4", + "prompt": "Test prompt", + "max_tokens": 100, + "temperature": 0.7 + }, + "priority": priority, + "timeout": timeout, + "created_at": TestDataFactory.generate_timestamp(), + "expires_at": TestDataFactory.generate_timestamp(offset_seconds=timeout) + } + + # Transaction data + @staticmethod + def transaction_data( + sender: Optional[str] = None, + recipient: Optional[str] = None, + amount: float = 100.0 + ) -> Dict[str, Any]: + """Generate test transaction data""" + return { + "tx_id": TestDataFactory.generate_id("tx"), + "sender": sender or TestDataFactory.TEST_ADDRESSES["alice"], + "recipient": recipient or TestDataFactory.TEST_ADDRESSES["bob"], + "amount": amount, + "currency": "AITBC", + "fee": 0.1, + "timestamp": TestDataFactory.generate_timestamp(), + "status": "pending" + } + + # Miner data + @staticmethod + def miner_data( + miner_id: Optional[str] = None, + status: str = "active" + ) -> Dict[str, Any]: + """Generate test miner data""" + return { + "miner_id": miner_id or TestDataFactory.TEST_ADDRESSES["miner1"], + "status": status, + "total_jobs_completed": 10, + "successful_jobs": 9, + "average_accuracy": 95.0, + "gpu_count": 4, + "gpu_type": "NVIDIA A100", + "last_heartbeat": TestDataFactory.generate_timestamp() + } + + # GPU data + @staticmethod + def gpu_data( + gpu_id: Optional[str] = None, + status: str = "available" + ) -> Dict[str, Any]: + """Generate test GPU data""" + return { + "gpu_id": gpu_id or TestDataFactory.generate_id("gpu"), + "status": status, + "type": "NVIDIA A100", + "memory_gb": 80, + "compute_capability": 8.0, + "price_per_hour": 2.5, + "location": "us-east-1" + } + + # Staking data + @staticmethod + def staking_data( + amount: float = 1000.0, + lock_period: int = 30, + auto_compound: bool = False + ) -> Dict[str, Any]: + """Generate test staking data""" + return { + "stake_id": TestDataFactory.generate_id("stake"), + "amount": amount, + "lock_period": lock_period, + "auto_compound": auto_compound, + "apy": 5.0, + "start_time": TestDataFactory.generate_timestamp(), + "end_time": TestDataFactory.generate_timestamp(offset_seconds=lock_period * 86400), + "status": "active" + } + + # Agent data + @staticmethod + def agent_data( + agent_id: Optional[str] = None, + status: str = "active" + ) -> Dict[str, Any]: + """Generate test agent data""" + return { + "agent_id": agent_id or TestDataFactory.generate_id("agent"), + "status": status, + "type": "general", + "capabilities": ["text_generation", "code_generation"], + "performance_tier": "gold", + "created_at": TestDataFactory.generate_timestamp() + } + + # API request/response data + @staticmethod + def api_response( + status_code: int = 200, + data: Optional[Dict[str, Any]] = None, + message: str = "Success" + ) -> Dict[str, Any]: + """Generate test API response""" + return { + "status_code": status_code, + "data": data or {}, + "message": message, + "timestamp": TestDataFactory.generate_timestamp() + } + + # Error data + @staticmethod + def error_data( + error_code: str = "INTERNAL_ERROR", + error_message: str = "An error occurred", + details: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """Generate test error data""" + return { + "error_code": error_code, + "error_message": error_message, + "details": details or {}, + "timestamp": TestDataFactory.generate_timestamp() + } + + # Pagination data + @staticmethod + def paginated_response( + items: List[Dict[str, Any]], + page: int = 1, + page_size: int = 10, + total: Optional[int] = None + ) -> Dict[str, Any]: + """Generate test paginated response""" + return { + "items": items, + "page": page, + "page_size": page_size, + "total": total or len(items), + "total_pages": (total or len(items) + page_size - 1) // page_size + } + + # Batch operations + @staticmethod + def batch_job_data(count: int = 5) -> List[Dict[str, Any]]: + """Generate multiple job data for batch operations""" + return [TestDataFactory.job_data() for _ in range(count)] + + @staticmethod + def batch_transaction_data(count: int = 5) -> List[Dict[str, Any]]: + """Generate multiple transaction data for batch operations""" + return [TestDataFactory.transaction_data() for _ in range(count)] + + # Domain-specific scenarios + @staticmethod + def marketplace_offer_data( + provider: Optional[str] = None, + price: float = 1.5 + ) -> Dict[str, Any]: + """Generate test marketplace offer data""" + return { + "offer_id": TestDataFactory.generate_id("offer"), + "provider": provider or TestDataFactory.TEST_ADDRESSES["miner1"], + "gpu_type": "NVIDIA A100", + "memory_gb": 80, + "price_per_hour": price, + "availability": "immediate", + "location": "us-east-1", + "created_at": TestDataFactory.generate_timestamp() + } + + @staticmethod + def governance_proposal_data( + title: str = "Test Proposal", + description: str = "Test proposal description" + ) -> Dict[str, Any]: + """Generate test governance proposal data""" + return { + "proposal_id": TestDataFactory.generate_id("proposal"), + "title": title, + "description": description, + "proposer": TestDataFactory.TEST_ADDRESSES["alice"], + "status": "active", + "votes_for": 0, + "votes_against": 0, + "created_at": TestDataFactory.generate_timestamp(), + "ends_at": TestDataFactory.generate_timestamp(offset_seconds=86400 * 7) # 7 days + } diff --git a/tests/test_rate_limiting.py b/tests/test_rate_limiting.py new file mode 100644 index 00000000..f7e2840c --- /dev/null +++ b/tests/test_rate_limiting.py @@ -0,0 +1,278 @@ +""" +Tests for rate limiting utilities +""" + +import pytest +from unittest.mock import Mock, patch, AsyncMock +from fastapi import Request, HTTPException +from starlette.responses import Response + +from aitbc.rate_limiting import ( + get_rate_limiter, + rate_limit, + RateLimitMiddleware, + get_rate_limit_headers, + reset_rate_limit, +) + + +class TestGetRateLimiter: + """Tests for get_rate_limiter function""" + + def test_get_rate_limiter_new(self): + """Test get_rate_limiter creates new limiter""" + limiter = get_rate_limiter("test", rate=10, per=60) + + assert limiter.rate == 10 + assert limiter.per == 60 + + def test_get_rate_limiter_cached(self): + """Test get_rate_limiter returns cached limiter""" + limiter1 = get_rate_limiter("test", rate=10, per=60) + limiter2 = get_rate_limiter("test", rate=20, per=30) + + # Should return the same instance + assert limiter1 is limiter2 + # Original values preserved + assert limiter2.rate == 10 + assert limiter2.per == 60 + + +class TestRateLimitDecorator: + """Tests for rate_limit decorator""" + + @pytest.mark.asyncio + async def test_rate_limit_within_limit(self): + """Test rate_limit allows requests within limit""" + @rate_limit(rate=5, per=60) + async def test_endpoint(request: Request): + return {"status": "ok"} + + request = Mock(spec=Request) + request.client = Mock(host="127.0.0.1") + request.url = Mock(path="/test") + + for _ in range(5): + result = await test_endpoint(request) + assert result == {"status": "ok"} + + @pytest.mark.asyncio + async def test_rate_limit_exceeded(self): + """Test rate_limit blocks requests exceeding limit""" + @rate_limit(rate=2, per=60) + async def test_endpoint(request: Request): + return {"status": "ok"} + + request = Mock(spec=Request) + request.client = Mock(host="127.0.0.1") + request.url = Mock(path="/test") + + # First 2 requests should succeed + await test_endpoint(request) + await test_endpoint(request) + + # Third request should fail + with pytest.raises(HTTPException) as exc_info: + await test_endpoint(request) + + assert exc_info.value.status_code == 429 + assert "Rate limit exceeded" in exc_info.value.detail + + @pytest.mark.asyncio + async def test_rate_limit_custom_key_func(self): + """Test rate_limit with custom key function""" + def custom_key(request: Request) -> str: + return request.headers.get("X-API-Key", "unknown") + + @rate_limit(rate=2, per=60, key_func=custom_key) + async def test_endpoint(request: Request): + return {"status": "ok"} + + request = Mock(spec=Request) + request.client = Mock(host="127.0.0.1") + request.url = Mock(path="/test") + request.headers = {"X-API-Key": "key1"} + + # 2 requests with same key should succeed + await test_endpoint(request) + await test_endpoint(request) + + # Third should fail + with pytest.raises(HTTPException): + await test_endpoint(request) + + @pytest.mark.asyncio + async def test_rate_limit_no_request(self): + """Test rate_limit without request skips limiting""" + @rate_limit(rate=2, per=60) + async def test_endpoint(): + return {"status": "ok"} + + # Should succeed even without request + result = await test_endpoint() + assert result == {"status": "ok"} + + @pytest.mark.asyncio + async def test_rate_limit_custom_error_message(self): + """Test rate_limit with custom error message""" + @rate_limit(rate=1, per=60, error_message="Custom limit message") + async def test_endpoint(request: Request): + return {"status": "ok"} + + request = Mock(spec=Request) + request.client = Mock(host="127.0.0.1") + request.url = Mock(path="/test") + + await test_endpoint(request) + + with pytest.raises(HTTPException) as exc_info: + await test_endpoint(request) + + assert exc_info.value.detail == "Custom limit message" + + +class TestRateLimitMiddleware: + """Tests for RateLimitMiddleware""" + + @pytest.mark.asyncio + async def test_middleware_within_limit(self): + """Test middleware allows requests within limit""" + app = Mock() + middleware = RateLimitMiddleware(app, rate=5, per=60) + + request = Mock(spec=Request) + request.client = Mock(host="127.0.0.1") + request.url = Mock(path="/test") + + call_next = AsyncMock() + response = Mock(spec=Response) + call_next.return_value = response + + for _ in range(5): + result = await middleware.dispatch(request, call_next) + assert result == response + + @pytest.mark.asyncio + async def test_middleware_exceeded(self): + """Test middleware blocks requests exceeding limit""" + app = Mock() + middleware = RateLimitMiddleware(app, rate=2, per=60) + + request = Mock(spec=Request) + request.client = Mock(host="127.0.0.1") + request.url = Mock(path="/test") + + call_next = AsyncMock() + response = Mock(spec=Response) + call_next.return_value = response + + # First 2 requests should succeed + await middleware.dispatch(request, call_next) + await middleware.dispatch(request, call_next) + + # Third request should fail + result = await middleware.dispatch(request, call_next) + + assert result.status_code == 429 + assert b"Rate limit exceeded" in result.body + + @pytest.mark.asyncio + async def test_middleware_custom_key_func(self): + """Test middleware with custom key function""" + def custom_key(request: Request) -> str: + return request.headers.get("X-API-Key", "unknown") + + app = Mock() + middleware = RateLimitMiddleware(app, rate=2, per=60, key_func=custom_key) + + request = Mock(spec=Request) + request.client = Mock(host="127.0.0.1") + request.headers = {"X-API-Key": "key1"} + + call_next = AsyncMock() + response = Mock(spec=Response) + call_next.return_value = response + + # 2 requests with same key should succeed + await middleware.dispatch(request, call_next) + await middleware.dispatch(request, call_next) + + # Third should fail + result = await middleware.dispatch(request, call_next) + assert result.status_code == 429 + + @pytest.mark.asyncio + async def test_middleware_no_client(self): + """Test middleware handles requests without client""" + app = Mock() + middleware = RateLimitMiddleware(app, rate=2, per=60) + + request = Mock(spec=Request) + request.client = None + + call_next = AsyncMock() + response = Mock(spec=Response) + call_next.return_value = response + + # Should use "unknown" as key + result = await middleware.dispatch(request, call_next) + assert result == response + + +class TestGetRateLimitHeaders: + """Tests for get_rate_limit_headers""" + + def test_get_rate_limit_headers_existing_limiter(self): + """Test get_rate_limit_headers with existing limiter""" + get_rate_limiter("test", rate=10, per=60) + + request = Mock(spec=Request) + request.client = Mock(host="127.0.0.1") + + headers = get_rate_limit_headers(request, "test") + + assert headers["X-RateLimit-Limit"] == "10" + assert headers["X-RateLimit-Reset"] == "60" + assert "X-RateLimit-Remaining" in headers + + def test_get_rate_limit_headers_nonexistent_limiter(self): + """Test get_rate_limit_headers with nonexistent limiter""" + request = Mock(spec=Request) + request.client = Mock(host="127.0.0.1") + + headers = get_rate_limit_headers(request, "nonexistent") + + assert headers == {} + + +class TestResetRateLimit: + """Tests for reset_rate_limit""" + + def test_reset_rate_limit_specific_limiter(self): + """Test reset_rate_limit for specific limiter""" + limiter = get_rate_limiter("test", rate=2, per=60) + + # Make a request + limiter.is_allowed("127.0.0.1") + + # Reset + reset_rate_limit("127.0.0.1", "test") + + # Should be allowed again + assert limiter.is_allowed("127.0.0.1") + + def test_reset_rate_limit_all_limiters(self): + """Test reset_rate_limit for all limiters""" + limiter1 = get_rate_limiter("test1", rate=2, per=60) + limiter2 = get_rate_limiter("test2", rate=2, per=60) + + # Make requests + limiter1.is_allowed("127.0.0.1") + limiter2.is_allowed("127.0.0.1") + + # Reset all + reset_rate_limit("127.0.0.1") + + # Both should be allowed again + assert limiter1.is_allowed("127.0.0.1") + assert limiter2.is_allowed("127.0.0.1") diff --git a/tests/test_redis_cache.py b/tests/test_redis_cache.py new file mode 100644 index 00000000..efea9d16 --- /dev/null +++ b/tests/test_redis_cache.py @@ -0,0 +1,105 @@ +""" +Tests for Redis caching utilities +""" + +import pytest + +from aitbc.redis_cache import RedisCache, get_cache, cache_key + + +class TestRedisCache: + """Tests for RedisCache class (disabled cache mode)""" + + def test_init_without_redis(self): + """Test initialization without Redis available""" + cache = RedisCache(redis_url=None) + assert cache.is_available() is False + + def test_get_without_redis(self): + """Test get operation without Redis""" + cache = RedisCache(redis_url=None) + result = cache.get("test_key") + assert result is None + + def test_set_without_redis(self): + """Test set operation without Redis""" + cache = RedisCache(redis_url=None) + result = cache.set("test_key", {"key": "value"}) + assert result is False + + def test_delete_without_redis(self): + """Test delete operation without Redis""" + cache = RedisCache(redis_url=None) + result = cache.delete("test_key") + assert result is False + + def test_exists_without_redis(self): + """Test exists operation without Redis""" + cache = RedisCache(redis_url=None) + result = cache.exists("test_key") + assert result is False + + def test_clear_without_redis(self): + """Test clear operation without Redis""" + cache = RedisCache(redis_url=None) + result = cache.clear() + assert result is False + + def test_get_many_without_redis(self): + """Test get_many operation without Redis""" + cache = RedisCache(redis_url=None) + result = cache.get_many(["key1", "key2"]) + assert result == {} + + def test_set_many_without_redis(self): + """Test set_many operation without Redis""" + cache = RedisCache(redis_url=None) + result = cache.set_many({"key1": "value1"}) + assert result is False + + def test_delete_many_without_redis(self): + """Test delete_many operation without Redis""" + cache = RedisCache(redis_url=None) + result = cache.delete_many(["key1"]) + assert result is False + + def test_increment_without_redis(self): + """Test increment operation without Redis""" + cache = RedisCache(redis_url=None) + result = cache.increment("counter") + assert result is None + + +class TestGetCache: + """Tests for get_cache function""" + + def test_get_cache_without_url(self): + """Test get_cache without URL returns disabled cache""" + cache = get_cache(redis_url=None) + assert cache.is_available() is False + + +class TestCacheKey: + """Tests for cache_key function""" + + def test_cache_key_simple(self): + """Test cache_key with simple parts""" + key = cache_key("user", "123") + assert key == "aitbc:user:123" + + def test_cache_key_with_prefix(self): + """Test cache_key with custom prefix""" + key = cache_key("user", "123", prefix="custom") + assert key == "custom:user:123" + + def test_cache_key_multiple_parts(self): + """Test cache_key with multiple parts""" + key = cache_key("user", "123", "profile", "data") + assert key == "aitbc:user:123:profile:data" + + def test_cache_key_long_key(self): + """Test cache_key with long key gets hashed""" + long_part = "x" * 300 + key = cache_key(long_part, "data") + assert key.startswith("aitbc:hashed:") + assert len(key) <= 250