refactor: enhance configuration with security validation, database pooling, and rate limiting
Some checks failed
Cross-Node Transaction Testing / transaction-test (push) Has been cancelled
Deploy to Testnet / deploy-testnet (push) Has been cancelled
Documentation Validation / validate-docs (push) Has been cancelled
Documentation Validation / validate-policies-strict (push) Has been cancelled
Integration Tests / test-service-integration (push) Has been cancelled
Multi-Node Stress Testing / stress-test (push) Has been cancelled
Python Tests / test-python (push) Has been cancelled
Security Scanning / security-scan (push) Has been cancelled
API Endpoint Tests / test-api-endpoints (push) Successful in 20s
CLI Tests / test-cli (push) Failing after 3s
Package Tests / Python package - aitbc-agent-sdk (push) Successful in 33s
Package Tests / Python package - aitbc-core (push) Failing after 1s
Package Tests / Python package - aitbc-crypto (push) Successful in 10s
Package Tests / Python package - aitbc-sdk (push) Successful in 9s
Package Tests / JavaScript package - aitbc-sdk-js (push) Successful in 10s
Package Tests / JavaScript package - aitbc-token (push) Successful in 17s
Production Tests / Production Integration Tests (push) Failing after 6s

- Added List import and field_validator to config.py
- Added database connection pooling settings (max_overflow, pool_recycle, pool_pre_ping, echo)
- Added rate limiting settings (rate_limit_requests, rate_limit_window_seconds)
- Added CORS allow_origins field with default empty list
- Added validate_secrets() method to check required secrets in production
- Added validate_secret_length() validator for secret_key and jwt_secret (minimum
This commit is contained in:
aitbc
2026-05-12 21:17:54 +02:00
parent f4688aefbd
commit 40cee6d791
19 changed files with 1907 additions and 152 deletions

View File

@@ -4,9 +4,9 @@ Base configuration classes for AITBC applications
"""
from pathlib import Path
from typing import Optional
from typing import Optional, List
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import Field
from pydantic import Field, field_validator
from .constants import DATA_DIR, CONFIG_DIR, LOG_DIR, ENV_FILE
from .aitbc_logging import get_logger
@@ -19,25 +19,25 @@ class BaseAITBCConfig(BaseSettings):
Base configuration class for all AITBC applications.
Provides common AITBC-specific settings and environment file loading.
"""
model_config = SettingsConfigDict(
env_file=str(ENV_FILE),
env_file_encoding="utf-8",
case_sensitive=False,
extra="ignore"
)
# AITBC system directories
data_dir: Path = Field(default=DATA_DIR, description="AITBC data directory")
config_dir: Path = Field(default=CONFIG_DIR, description="AITBC configuration directory")
log_dir: Path = Field(default=LOG_DIR, description="AITBC log directory")
# Application settings
app_name: str = Field(default="AITBC Application", description="Application name")
app_version: str = Field(default="1.0.0", description="Application version")
environment: str = Field(default="development", description="Environment (development/staging/production)")
debug: bool = Field(default=False, description="Debug mode")
# Logging settings
log_level: str = Field(default="INFO", description="Log level (DEBUG/INFO/WARNING/ERROR/CRITICAL)")
log_format: str = Field(
@@ -45,39 +45,92 @@ class BaseAITBCConfig(BaseSettings):
description="Log format string"
)
class AITBCConfig(BaseAITBCConfig):
"""
Standard AITBC configuration with common settings.
Inherits from BaseAITBCConfig and adds AITBC-specific fields.
"""
# Server settings
host: str = Field(default="0.0.0.0", description="Server host address")
port: int = Field(default=8000, description="Server port")
workers: int = Field(default=1, description="Number of worker processes")
# Database settings
database_url: Optional[str] = Field(default=None, description="Database connection URL")
database_pool_size: int = Field(default=10, description="Database connection pool size")
database_max_overflow: int = Field(default=20, description="Maximum overflow connections")
database_pool_recycle: int = Field(default=3600, description="Connection recycle time in seconds")
database_pool_pre_ping: bool = Field(default=True, description="Test connections before using")
database_echo: bool = Field(default=False, description="Enable SQL query logging")
# Redis settings (if applicable)
redis_url: Optional[str] = Field(default=None, description="Redis connection URL")
redis_max_connections: int = Field(default=10, description="Redis max connections")
redis_timeout: int = Field(default=5, description="Redis timeout in seconds")
# Security settings
secret_key: Optional[str] = Field(default=None, description="Application secret key")
jwt_secret: Optional[str] = Field(default=None, description="JWT secret key")
jwt_algorithm: str = Field(default="HS256", description="JWT algorithm")
jwt_expiration_hours: int = Field(default=24, description="JWT token expiration in hours")
# Performance settings
request_timeout: int = Field(default=30, description="Request timeout in seconds")
max_request_size: int = Field(default=10 * 1024 * 1024, description="Max request size in bytes")
# Rate limiting settings
rate_limit_requests: int = Field(default=60, description="Rate limit requests per window")
rate_limit_window_seconds: int = Field(default=60, description="Rate limit window in seconds")
# CORS settings
allow_origins: List[str] = Field(default_factory=list, description="CORS allowed origins")
def validate_secrets(self) -> None:
"""Validate that all required secrets are provided."""
if self.environment == "production":
if not self.secret_key:
raise ValueError("SECRET_KEY environment variable is required in production")
if self.secret_key == "change-me-in-production":
raise ValueError("SECRET_KEY must be changed from default value")
if not self.jwt_secret:
raise ValueError("JWT_SECRET environment variable is required in production")
if self.jwt_secret == "change-me-in-production":
raise ValueError("JWT_SECRET must be changed from default value")
@field_validator("secret_key", "jwt_secret", mode="before")
@classmethod
def validate_secret_length(cls, v: Optional[str]) -> Optional[str]:
"""Validate secret key length in production."""
import os
if os.getenv("APP_ENV", "development") != "production" and not v:
return v
if not v or v.startswith("$") or v == "your_secret_here" or v == "change-me-in-production":
raise ValueError("Secret must be set to a secure value")
if len(v) < 32:
raise ValueError("Secret must be at least 32 characters long")
return v
def __init__(self, **kwargs):
"""Initialize AITBC configuration with extended logging"""
super().__init__(**kwargs)
logger.info(f"Server configured for {self.host}:{self.port}")
logger.info(f"{self.app_name} configured for {self.host}:{self.port}")
logger.debug(f"Workers: {self.workers}, Request timeout: {self.request_timeout}s")
def get_redis_cache(self):
"""Get Redis cache instance configured from settings"""
from .redis_cache import get_cache
return get_cache(
redis_url=self.redis_url,
max_connections=self.redis_max_connections,
timeout=self.redis_timeout
)
class AITBCConfig(BaseAITBCConfig):
"""
Standard AITBC configuration with common settings.
Inherits from BaseAITBCConfig and can be extended with service-specific fields.
"""
# Override defaults for standard AITBC application
app_name: str = Field(default="AITBC Application", description="Application name")
port: int = Field(default=8000, description="Server port")
def __init__(self, **kwargs):
"""Initialize AITBC configuration with extended logging"""
super().__init__(**kwargs)

View File

@@ -9,6 +9,12 @@ from typing import Any, Dict, List, Optional, Tuple
from contextlib import contextmanager
from .exceptions import DatabaseError
# SQLAlchemy support for connection pooling
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import QueuePool, StaticPool
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
class DatabaseConnection:
"""
@@ -259,3 +265,141 @@ def table_exists(db_path: Path, table_name: str) -> bool:
(table_name,)
)
return result is not None
# SQLAlchemy Connection Pooling Utilities
def create_pooled_engine(
database_url: str,
pool_size: int = 10,
max_overflow: int = 20,
pool_recycle: int = 3600,
pool_pre_ping: bool = True,
echo: bool = False,
use_static_pool: bool = False
):
"""
Create SQLAlchemy engine with connection pooling.
Args:
database_url: Database connection URL
pool_size: Size of connection pool
max_overflow: Maximum overflow connections
pool_recycle: Connection recycle time in seconds
pool_pre_ping: Test connections before using
echo: Enable SQL query logging
use_static_pool: Use StaticPool for SQLite (single connection)
Returns:
SQLAlchemy engine with connection pooling
"""
if "sqlite" in database_url and use_static_pool:
# SQLite with StaticPool (single connection, suitable for tests)
engine = create_engine(
database_url,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
echo=echo,
pool_pre_ping=pool_pre_ping,
)
elif "sqlite" in database_url:
# SQLite with QueuePool (limited pooling support)
engine = create_engine(
database_url,
connect_args={"check_same_thread": False, "timeout": 30},
poolclass=QueuePool,
pool_size=min(pool_size, 5), # SQLite has limited concurrent access
max_overflow=max_overflow,
pool_pre_ping=pool_pre_ping,
echo=echo,
)
else:
# PostgreSQL/MySQL with full connection pooling
engine = create_engine(
database_url,
poolclass=QueuePool,
pool_size=pool_size,
max_overflow=max_overflow,
pool_recycle=pool_recycle,
pool_pre_ping=pool_pre_ping,
echo=echo,
)
return engine
def create_pooled_sessionmaker(
engine,
autoflush: bool = False,
autocommit: bool = False
):
"""
Create session factory with connection pooling.
Args:
engine: SQLAlchemy engine
autoflush: Enable autoflush
autocommit: Enable autocommit
Returns:
Session factory
"""
return sessionmaker(bind=engine, autoflush=autoflush, autocommit=autocommit)
def create_async_pooled_engine(
database_url: str,
pool_size: int = 10,
max_overflow: int = 20,
pool_recycle: int = 3600,
pool_pre_ping: bool = True,
echo: bool = False
):
"""
Create async SQLAlchemy engine with connection pooling.
Args:
database_url: Database connection URL
pool_size: Size of connection pool
max_overflow: Maximum overflow connections
pool_recycle: Connection recycle time in seconds
pool_pre_ping: Test connections before using
echo: Enable SQL query logging
Returns:
Async SQLAlchemy engine with connection pooling
"""
# Convert to async URL
if "sqlite" in database_url:
async_url = database_url.replace("sqlite:///", "sqlite+aiosqlite:///")
elif "postgresql" in database_url:
async_url = database_url.replace("postgresql://", "postgresql+asyncpg://")
else:
async_url = database_url
engine = create_async_engine(
async_url,
poolclass=QueuePool,
pool_size=pool_size,
max_overflow=max_overflow,
pool_recycle=pool_recycle,
pool_pre_ping=pool_pre_ping,
echo=echo,
)
return engine
def create_async_pooled_sessionmaker(
engine,
expire_on_commit: bool = False
):
"""
Create async session factory with connection pooling.
Args:
engine: Async SQLAlchemy engine
expire_on_commit: Expire objects on commit
Returns:
Async session factory
"""
return async_sessionmaker(engine, expire_on_commit=expire_on_commit)

198
aitbc/rate_limiting.py Normal file
View File

@@ -0,0 +1,198 @@
"""
Rate limiting utilities for FastAPI applications
Provides decorators and middleware for API rate limiting
"""
from functools import wraps
from typing import Callable, Optional, Dict, Any
from fastapi import Request, HTTPException, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
from .security_hardening import RateLimiter
from .aitbc_logging import get_logger
logger = get_logger(__name__)
# Global rate limiters for different endpoints
_rate_limiters: Dict[str, RateLimiter] = {}
def get_rate_limiter(name: str, rate: int = 100, per: int = 60) -> RateLimiter:
"""
Get or create a rate limiter for a specific endpoint
Args:
name: Unique name for the rate limiter
rate: Number of requests allowed per time period
per: Time period in seconds
Returns:
RateLimiter instance
"""
if name not in _rate_limiters:
_rate_limiters[name] = RateLimiter(rate=rate, per=per)
return _rate_limiters[name]
def rate_limit(
rate: int = 100,
per: int = 60,
key_func: Optional[Callable[[Request], str]] = None,
error_message: str = "Rate limit exceeded"
) -> Callable:
"""
Decorator for rate limiting FastAPI endpoints
Args:
rate: Number of requests allowed per time period
per: Time period in seconds
key_func: Function to extract rate limit key from request (defaults to client IP)
error_message: Custom error message
Returns:
Decorated function with rate limiting
"""
def decorator(func: Callable) -> Callable:
limiter = RateLimiter(rate=rate, per=per)
@wraps(func)
async def wrapper(*args, **kwargs) -> Any:
# Extract request from args (FastAPI passes request as first arg for dependency injection)
request = None
for arg in args:
if isinstance(arg, Request):
request = arg
break
if request is None:
# Try to get request from kwargs
request = kwargs.get('request')
if request is None:
# No request available, skip rate limiting
return await func(*args, **kwargs)
# Get rate limit key
if key_func:
key = key_func(request)
else:
key = request.client.host if request.client else "unknown"
# Check rate limit
if not limiter.is_allowed(key):
logger.warning(f"Rate limit exceeded for {key} on {request.url.path}")
raise HTTPException(
status_code=429,
detail=error_message,
headers={"Retry-After": str(per)}
)
return await func(*args, **kwargs)
return wrapper
return decorator
class RateLimitMiddleware(BaseHTTPMiddleware):
"""
Middleware for rate limiting all requests
Applies rate limiting based on client IP address
"""
def __init__(
self,
app: ASGIApp,
rate: int = 100,
per: int = 60,
key_func: Optional[Callable[[Request], str]] = None,
error_message: str = "Rate limit exceeded"
) -> None:
"""
Initialize rate limit middleware
Args:
app: ASGI application
rate: Number of requests allowed per time period
per: Time period in seconds
key_func: Function to extract rate limit key from request
error_message: Custom error message
"""
super().__init__(app)
self.rate = rate
self.per = per
self.key_func = key_func
self.error_message = error_message
self._limiter = RateLimiter(rate=rate, per=per)
async def dispatch(self, request: Request, call_next: Callable) -> Response:
"""
Process request with rate limiting
Args:
request: Incoming request
call_next: Next middleware or endpoint
Returns:
Response
"""
# Get rate limit key
if self.key_func:
key = self.key_func(request)
else:
key = request.client.host if request.client else "unknown"
# Check rate limit
if not self._limiter.is_allowed(key):
logger.warning(f"Rate limit exceeded for {key} on {request.url.path}")
return Response(
content='{"detail": "' + self.error_message + '"}',
status_code=429,
media_type="application/json",
headers={"Retry-After": str(self.per)}
)
return await call_next(request)
def get_rate_limit_headers(request: Request, limiter_name: str) -> Dict[str, str]:
"""
Get rate limit headers for response
Args:
request: Request object
limiter_name: Name of the rate limiter
Returns:
Dictionary of rate limit headers
"""
limiter = _rate_limiters.get(limiter_name)
if not limiter:
return {}
key = request.client.host if request.client else "unknown"
remaining = limiter.get_remaining_requests(key)
return {
"X-RateLimit-Limit": str(limiter.rate),
"X-RateLimit-Remaining": str(remaining),
"X-RateLimit-Reset": str(limiter.per)
}
def reset_rate_limit(identifier: str, limiter_name: Optional[str] = None) -> None:
"""
Reset rate limit for an identifier
Args:
identifier: Identifier to reset (e.g., IP address, user ID)
limiter_name: Name of specific rate limiter, or None for all
"""
if limiter_name:
if limiter_name in _rate_limiters:
_rate_limiters[limiter_name].reset(identifier)
else:
for limiter in _rate_limiters.values():
limiter.reset(identifier)

328
aitbc/redis_cache.py Normal file
View File

@@ -0,0 +1,328 @@
"""
Redis caching utilities for AITBC applications
Provides distributed caching with Redis backend
"""
from typing import Optional, Any, List
import json
import hashlib
from .aitbc_logging import get_logger
logger = get_logger(__name__)
try:
import redis
REDIS_AVAILABLE = True
except ImportError:
REDIS_AVAILABLE = False
logger.warning("Redis package not installed. Caching will be disabled.")
class RedisCache:
"""
Redis cache implementation for distributed caching
"""
def __init__(
self,
redis_url: Optional[str] = None,
max_connections: int = 10,
timeout: int = 5,
default_ttl: int = 3600
):
"""
Initialize Redis cache
Args:
redis_url: Redis connection URL (e.g., redis://localhost:6379/0)
max_connections: Maximum number of connections
timeout: Connection timeout in seconds
default_ttl: Default time-to-live for cached items in seconds
"""
self.redis_url = redis_url
self.max_connections = max_connections
self.timeout = timeout
self.default_ttl = default_ttl
self._client = None
if REDIS_AVAILABLE and redis_url:
try:
self._client = redis.Redis.from_url(
redis_url,
max_connections=max_connections,
socket_timeout=timeout,
socket_connect_timeout=timeout,
decode_responses=True
)
# Test connection
self._client.ping()
logger.info(f"Connected to Redis at {redis_url}")
except Exception as e:
logger.warning(f"Failed to connect to Redis: {e}")
self._client = None
else:
logger.info("Redis caching disabled (Redis not available or no URL provided)")
def is_available(self) -> bool:
"""Check if Redis cache is available"""
return self._client is not None
def get(self, key: str) -> Optional[Any]:
"""
Get value from cache
Args:
key: Cache key
Returns:
Cached value or None if not found
"""
if not self.is_available():
return None
try:
value = self._client.get(key)
if value:
return json.loads(value)
return None
except Exception as e:
logger.error(f"Redis get error: {e}")
return None
def set(
self,
key: str,
value: Any,
ttl: Optional[int] = None
) -> bool:
"""
Set value in cache
Args:
key: Cache key
value: Value to cache (must be JSON serializable)
ttl: Time-to-live in seconds (uses default_ttl if not provided)
Returns:
True if successful, False otherwise
"""
if not self.is_available():
return False
try:
serialized = json.dumps(value)
expiry = ttl if ttl is not None else self.default_ttl
self._client.setex(key, expiry, serialized)
return True
except Exception as e:
logger.error(f"Redis set error: {e}")
return False
def delete(self, key: str) -> bool:
"""
Delete value from cache
Args:
key: Cache key
Returns:
True if successful, False otherwise
"""
if not self.is_available():
return False
try:
self._client.delete(key)
return True
except Exception as e:
logger.error(f"Redis delete error: {e}")
return False
def exists(self, key: str) -> bool:
"""
Check if key exists in cache
Args:
key: Cache key
Returns:
True if key exists, False otherwise
"""
if not self.is_available():
return False
try:
return self._client.exists(key) == 1
except Exception as e:
logger.error(f"Redis exists error: {e}")
return False
def clear(self) -> bool:
"""
Clear all cached values
Returns:
True if successful, False otherwise
"""
if not self.is_available():
return False
try:
self._client.flushdb()
return True
except Exception as e:
logger.error(f"Redis clear error: {e}")
return False
def get_many(self, keys: List[str]) -> dict[str, Any]:
"""
Get multiple values from cache
Args:
keys: List of cache keys
Returns:
Dictionary mapping keys to cached values
"""
if not self.is_available():
return {}
try:
values = self._client.mget(keys)
result = {}
for key, value in zip(keys, values):
if value:
result[key] = json.loads(value)
return result
except Exception as e:
logger.error(f"Redis get_many error: {e}")
return {}
def set_many(self, mapping: dict[str, Any], ttl: Optional[int] = None) -> bool:
"""
Set multiple values in cache
Args:
mapping: Dictionary of key-value pairs
ttl: Time-to-live in seconds
Returns:
True if successful, False otherwise
"""
if not self.is_available():
return False
try:
pipe = self._client.pipeline()
expiry = ttl if ttl is not None else self.default_ttl
for key, value in mapping.items():
serialized = json.dumps(value)
pipe.setex(key, expiry, serialized)
pipe.execute()
return True
except Exception as e:
logger.error(f"Redis set_many error: {e}")
return False
def delete_many(self, keys: List[str]) -> bool:
"""
Delete multiple values from cache
Args:
keys: List of cache keys
Returns:
True if successful, False otherwise
"""
if not self.is_available():
return False
try:
if keys:
self._client.delete(*keys)
return True
except Exception as e:
logger.error(f"Redis delete_many error: {e}")
return False
def increment(self, key: str, amount: int = 1) -> Optional[int]:
"""
Increment a counter in cache
Args:
key: Cache key
amount: Amount to increment by
Returns:
New value or None if failed
"""
if not self.is_available():
return None
try:
return self._client.incrby(key, amount)
except Exception as e:
logger.error(f"Redis increment error: {e}")
return None
def close(self) -> None:
"""Close Redis connection"""
if self._client:
self._client.close()
logger.info("Redis connection closed")
# Global cache instance
_global_cache: Optional[RedisCache] = None
def get_cache(
redis_url: Optional[str] = None,
max_connections: int = 10,
timeout: int = 5,
default_ttl: int = 3600
) -> RedisCache:
"""
Get or create global Redis cache instance
Args:
redis_url: Redis connection URL
max_connections: Maximum number of connections
timeout: Connection timeout in seconds
default_ttl: Default time-to-live for cached items
Returns:
RedisCache instance
"""
global _global_cache
if _global_cache is None and redis_url:
_global_cache = RedisCache(
redis_url=redis_url,
max_connections=max_connections,
timeout=timeout,
default_ttl=default_ttl
)
return _global_cache or RedisCache() # Return disabled cache if no URL
def cache_key(*parts: str, prefix: str = "aitbc") -> str:
"""
Generate a cache key from parts
Args:
*parts: Parts to include in the key
prefix: Key prefix
Returns:
Cache key string
"""
key_string = ":".join(str(part) for part in parts)
full_key = f"{prefix}:{key_string}"
# Hash if too long
if len(full_key) > 250:
hash_value = hashlib.sha256(full_key.encode()).hexdigest()[:16]
return f"{prefix}:hashed:{hash_value}"
return full_key