refactor: enhance configuration with security validation, database pooling, and rate limiting
Some checks failed
Cross-Node Transaction Testing / transaction-test (push) Has been cancelled
Deploy to Testnet / deploy-testnet (push) Has been cancelled
Documentation Validation / validate-docs (push) Has been cancelled
Documentation Validation / validate-policies-strict (push) Has been cancelled
Integration Tests / test-service-integration (push) Has been cancelled
Multi-Node Stress Testing / stress-test (push) Has been cancelled
Python Tests / test-python (push) Has been cancelled
Security Scanning / security-scan (push) Has been cancelled
API Endpoint Tests / test-api-endpoints (push) Successful in 20s
CLI Tests / test-cli (push) Failing after 3s
Package Tests / Python package - aitbc-agent-sdk (push) Successful in 33s
Package Tests / Python package - aitbc-core (push) Failing after 1s
Package Tests / Python package - aitbc-crypto (push) Successful in 10s
Package Tests / Python package - aitbc-sdk (push) Successful in 9s
Package Tests / JavaScript package - aitbc-sdk-js (push) Successful in 10s
Package Tests / JavaScript package - aitbc-token (push) Successful in 17s
Production Tests / Production Integration Tests (push) Failing after 6s
Some checks failed
Cross-Node Transaction Testing / transaction-test (push) Has been cancelled
Deploy to Testnet / deploy-testnet (push) Has been cancelled
Documentation Validation / validate-docs (push) Has been cancelled
Documentation Validation / validate-policies-strict (push) Has been cancelled
Integration Tests / test-service-integration (push) Has been cancelled
Multi-Node Stress Testing / stress-test (push) Has been cancelled
Python Tests / test-python (push) Has been cancelled
Security Scanning / security-scan (push) Has been cancelled
API Endpoint Tests / test-api-endpoints (push) Successful in 20s
CLI Tests / test-cli (push) Failing after 3s
Package Tests / Python package - aitbc-agent-sdk (push) Successful in 33s
Package Tests / Python package - aitbc-core (push) Failing after 1s
Package Tests / Python package - aitbc-crypto (push) Successful in 10s
Package Tests / Python package - aitbc-sdk (push) Successful in 9s
Package Tests / JavaScript package - aitbc-sdk-js (push) Successful in 10s
Package Tests / JavaScript package - aitbc-token (push) Successful in 17s
Production Tests / Production Integration Tests (push) Failing after 6s
- Added List import and field_validator to config.py - Added database connection pooling settings (max_overflow, pool_recycle, pool_pre_ping, echo) - Added rate limiting settings (rate_limit_requests, rate_limit_window_seconds) - Added CORS allow_origins field with default empty list - Added validate_secrets() method to check required secrets in production - Added validate_secret_length() validator for secret_key and jwt_secret (minimum
This commit is contained in:
@@ -4,9 +4,9 @@ Base configuration classes for AITBC applications
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from pydantic import Field
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from .constants import DATA_DIR, CONFIG_DIR, LOG_DIR, ENV_FILE
|
||||
from .aitbc_logging import get_logger
|
||||
@@ -19,25 +19,25 @@ class BaseAITBCConfig(BaseSettings):
|
||||
Base configuration class for all AITBC applications.
|
||||
Provides common AITBC-specific settings and environment file loading.
|
||||
"""
|
||||
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=str(ENV_FILE),
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=False,
|
||||
extra="ignore"
|
||||
)
|
||||
|
||||
|
||||
# AITBC system directories
|
||||
data_dir: Path = Field(default=DATA_DIR, description="AITBC data directory")
|
||||
config_dir: Path = Field(default=CONFIG_DIR, description="AITBC configuration directory")
|
||||
log_dir: Path = Field(default=LOG_DIR, description="AITBC log directory")
|
||||
|
||||
|
||||
# Application settings
|
||||
app_name: str = Field(default="AITBC Application", description="Application name")
|
||||
app_version: str = Field(default="1.0.0", description="Application version")
|
||||
environment: str = Field(default="development", description="Environment (development/staging/production)")
|
||||
debug: bool = Field(default=False, description="Debug mode")
|
||||
|
||||
|
||||
# Logging settings
|
||||
log_level: str = Field(default="INFO", description="Log level (DEBUG/INFO/WARNING/ERROR/CRITICAL)")
|
||||
log_format: str = Field(
|
||||
@@ -45,39 +45,92 @@ class BaseAITBCConfig(BaseSettings):
|
||||
description="Log format string"
|
||||
)
|
||||
|
||||
|
||||
class AITBCConfig(BaseAITBCConfig):
|
||||
"""
|
||||
Standard AITBC configuration with common settings.
|
||||
Inherits from BaseAITBCConfig and adds AITBC-specific fields.
|
||||
"""
|
||||
|
||||
# Server settings
|
||||
host: str = Field(default="0.0.0.0", description="Server host address")
|
||||
port: int = Field(default=8000, description="Server port")
|
||||
workers: int = Field(default=1, description="Number of worker processes")
|
||||
|
||||
|
||||
# Database settings
|
||||
database_url: Optional[str] = Field(default=None, description="Database connection URL")
|
||||
database_pool_size: int = Field(default=10, description="Database connection pool size")
|
||||
|
||||
database_max_overflow: int = Field(default=20, description="Maximum overflow connections")
|
||||
database_pool_recycle: int = Field(default=3600, description="Connection recycle time in seconds")
|
||||
database_pool_pre_ping: bool = Field(default=True, description="Test connections before using")
|
||||
database_echo: bool = Field(default=False, description="Enable SQL query logging")
|
||||
|
||||
# Redis settings (if applicable)
|
||||
redis_url: Optional[str] = Field(default=None, description="Redis connection URL")
|
||||
redis_max_connections: int = Field(default=10, description="Redis max connections")
|
||||
redis_timeout: int = Field(default=5, description="Redis timeout in seconds")
|
||||
|
||||
|
||||
# Security settings
|
||||
secret_key: Optional[str] = Field(default=None, description="Application secret key")
|
||||
jwt_secret: Optional[str] = Field(default=None, description="JWT secret key")
|
||||
jwt_algorithm: str = Field(default="HS256", description="JWT algorithm")
|
||||
jwt_expiration_hours: int = Field(default=24, description="JWT token expiration in hours")
|
||||
|
||||
|
||||
# Performance settings
|
||||
request_timeout: int = Field(default=30, description="Request timeout in seconds")
|
||||
max_request_size: int = Field(default=10 * 1024 * 1024, description="Max request size in bytes")
|
||||
|
||||
|
||||
# Rate limiting settings
|
||||
rate_limit_requests: int = Field(default=60, description="Rate limit requests per window")
|
||||
rate_limit_window_seconds: int = Field(default=60, description="Rate limit window in seconds")
|
||||
|
||||
# CORS settings
|
||||
allow_origins: List[str] = Field(default_factory=list, description="CORS allowed origins")
|
||||
|
||||
def validate_secrets(self) -> None:
|
||||
"""Validate that all required secrets are provided."""
|
||||
if self.environment == "production":
|
||||
if not self.secret_key:
|
||||
raise ValueError("SECRET_KEY environment variable is required in production")
|
||||
if self.secret_key == "change-me-in-production":
|
||||
raise ValueError("SECRET_KEY must be changed from default value")
|
||||
if not self.jwt_secret:
|
||||
raise ValueError("JWT_SECRET environment variable is required in production")
|
||||
if self.jwt_secret == "change-me-in-production":
|
||||
raise ValueError("JWT_SECRET must be changed from default value")
|
||||
|
||||
@field_validator("secret_key", "jwt_secret", mode="before")
|
||||
@classmethod
|
||||
def validate_secret_length(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""Validate secret key length in production."""
|
||||
import os
|
||||
if os.getenv("APP_ENV", "development") != "production" and not v:
|
||||
return v
|
||||
if not v or v.startswith("$") or v == "your_secret_here" or v == "change-me-in-production":
|
||||
raise ValueError("Secret must be set to a secure value")
|
||||
if len(v) < 32:
|
||||
raise ValueError("Secret must be at least 32 characters long")
|
||||
return v
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize AITBC configuration with extended logging"""
|
||||
super().__init__(**kwargs)
|
||||
logger.info(f"Server configured for {self.host}:{self.port}")
|
||||
logger.info(f"{self.app_name} configured for {self.host}:{self.port}")
|
||||
logger.debug(f"Workers: {self.workers}, Request timeout: {self.request_timeout}s")
|
||||
|
||||
def get_redis_cache(self):
|
||||
"""Get Redis cache instance configured from settings"""
|
||||
from .redis_cache import get_cache
|
||||
return get_cache(
|
||||
redis_url=self.redis_url,
|
||||
max_connections=self.redis_max_connections,
|
||||
timeout=self.redis_timeout
|
||||
)
|
||||
|
||||
|
||||
class AITBCConfig(BaseAITBCConfig):
|
||||
"""
|
||||
Standard AITBC configuration with common settings.
|
||||
Inherits from BaseAITBCConfig and can be extended with service-specific fields.
|
||||
"""
|
||||
|
||||
# Override defaults for standard AITBC application
|
||||
app_name: str = Field(default="AITBC Application", description="Application name")
|
||||
port: int = Field(default=8000, description="Server port")
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize AITBC configuration with extended logging"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -9,6 +9,12 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
from contextlib import contextmanager
|
||||
from .exceptions import DatabaseError
|
||||
|
||||
# SQLAlchemy support for connection pooling
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.pool import QueuePool, StaticPool
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
|
||||
|
||||
class DatabaseConnection:
|
||||
"""
|
||||
@@ -259,3 +265,141 @@ def table_exists(db_path: Path, table_name: str) -> bool:
|
||||
(table_name,)
|
||||
)
|
||||
return result is not None
|
||||
|
||||
|
||||
# SQLAlchemy Connection Pooling Utilities
|
||||
|
||||
def create_pooled_engine(
|
||||
database_url: str,
|
||||
pool_size: int = 10,
|
||||
max_overflow: int = 20,
|
||||
pool_recycle: int = 3600,
|
||||
pool_pre_ping: bool = True,
|
||||
echo: bool = False,
|
||||
use_static_pool: bool = False
|
||||
):
|
||||
"""
|
||||
Create SQLAlchemy engine with connection pooling.
|
||||
|
||||
Args:
|
||||
database_url: Database connection URL
|
||||
pool_size: Size of connection pool
|
||||
max_overflow: Maximum overflow connections
|
||||
pool_recycle: Connection recycle time in seconds
|
||||
pool_pre_ping: Test connections before using
|
||||
echo: Enable SQL query logging
|
||||
use_static_pool: Use StaticPool for SQLite (single connection)
|
||||
|
||||
Returns:
|
||||
SQLAlchemy engine with connection pooling
|
||||
"""
|
||||
if "sqlite" in database_url and use_static_pool:
|
||||
# SQLite with StaticPool (single connection, suitable for tests)
|
||||
engine = create_engine(
|
||||
database_url,
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
echo=echo,
|
||||
pool_pre_ping=pool_pre_ping,
|
||||
)
|
||||
elif "sqlite" in database_url:
|
||||
# SQLite with QueuePool (limited pooling support)
|
||||
engine = create_engine(
|
||||
database_url,
|
||||
connect_args={"check_same_thread": False, "timeout": 30},
|
||||
poolclass=QueuePool,
|
||||
pool_size=min(pool_size, 5), # SQLite has limited concurrent access
|
||||
max_overflow=max_overflow,
|
||||
pool_pre_ping=pool_pre_ping,
|
||||
echo=echo,
|
||||
)
|
||||
else:
|
||||
# PostgreSQL/MySQL with full connection pooling
|
||||
engine = create_engine(
|
||||
database_url,
|
||||
poolclass=QueuePool,
|
||||
pool_size=pool_size,
|
||||
max_overflow=max_overflow,
|
||||
pool_recycle=pool_recycle,
|
||||
pool_pre_ping=pool_pre_ping,
|
||||
echo=echo,
|
||||
)
|
||||
return engine
|
||||
|
||||
|
||||
def create_pooled_sessionmaker(
|
||||
engine,
|
||||
autoflush: bool = False,
|
||||
autocommit: bool = False
|
||||
):
|
||||
"""
|
||||
Create session factory with connection pooling.
|
||||
|
||||
Args:
|
||||
engine: SQLAlchemy engine
|
||||
autoflush: Enable autoflush
|
||||
autocommit: Enable autocommit
|
||||
|
||||
Returns:
|
||||
Session factory
|
||||
"""
|
||||
return sessionmaker(bind=engine, autoflush=autoflush, autocommit=autocommit)
|
||||
|
||||
|
||||
def create_async_pooled_engine(
|
||||
database_url: str,
|
||||
pool_size: int = 10,
|
||||
max_overflow: int = 20,
|
||||
pool_recycle: int = 3600,
|
||||
pool_pre_ping: bool = True,
|
||||
echo: bool = False
|
||||
):
|
||||
"""
|
||||
Create async SQLAlchemy engine with connection pooling.
|
||||
|
||||
Args:
|
||||
database_url: Database connection URL
|
||||
pool_size: Size of connection pool
|
||||
max_overflow: Maximum overflow connections
|
||||
pool_recycle: Connection recycle time in seconds
|
||||
pool_pre_ping: Test connections before using
|
||||
echo: Enable SQL query logging
|
||||
|
||||
Returns:
|
||||
Async SQLAlchemy engine with connection pooling
|
||||
"""
|
||||
# Convert to async URL
|
||||
if "sqlite" in database_url:
|
||||
async_url = database_url.replace("sqlite:///", "sqlite+aiosqlite:///")
|
||||
elif "postgresql" in database_url:
|
||||
async_url = database_url.replace("postgresql://", "postgresql+asyncpg://")
|
||||
else:
|
||||
async_url = database_url
|
||||
|
||||
engine = create_async_engine(
|
||||
async_url,
|
||||
poolclass=QueuePool,
|
||||
pool_size=pool_size,
|
||||
max_overflow=max_overflow,
|
||||
pool_recycle=pool_recycle,
|
||||
pool_pre_ping=pool_pre_ping,
|
||||
echo=echo,
|
||||
)
|
||||
return engine
|
||||
|
||||
|
||||
def create_async_pooled_sessionmaker(
|
||||
engine,
|
||||
expire_on_commit: bool = False
|
||||
):
|
||||
"""
|
||||
Create async session factory with connection pooling.
|
||||
|
||||
Args:
|
||||
engine: Async SQLAlchemy engine
|
||||
expire_on_commit: Expire objects on commit
|
||||
|
||||
Returns:
|
||||
Async session factory
|
||||
"""
|
||||
return async_sessionmaker(engine, expire_on_commit=expire_on_commit)
|
||||
|
||||
198
aitbc/rate_limiting.py
Normal file
198
aitbc/rate_limiting.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
Rate limiting utilities for FastAPI applications
|
||||
Provides decorators and middleware for API rate limiting
|
||||
"""
|
||||
|
||||
from functools import wraps
|
||||
from typing import Callable, Optional, Dict, Any
|
||||
from fastapi import Request, HTTPException, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from .security_hardening import RateLimiter
|
||||
from .aitbc_logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# Global rate limiters for different endpoints
|
||||
_rate_limiters: Dict[str, RateLimiter] = {}
|
||||
|
||||
|
||||
def get_rate_limiter(name: str, rate: int = 100, per: int = 60) -> RateLimiter:
|
||||
"""
|
||||
Get or create a rate limiter for a specific endpoint
|
||||
|
||||
Args:
|
||||
name: Unique name for the rate limiter
|
||||
rate: Number of requests allowed per time period
|
||||
per: Time period in seconds
|
||||
|
||||
Returns:
|
||||
RateLimiter instance
|
||||
"""
|
||||
if name not in _rate_limiters:
|
||||
_rate_limiters[name] = RateLimiter(rate=rate, per=per)
|
||||
return _rate_limiters[name]
|
||||
|
||||
|
||||
def rate_limit(
|
||||
rate: int = 100,
|
||||
per: int = 60,
|
||||
key_func: Optional[Callable[[Request], str]] = None,
|
||||
error_message: str = "Rate limit exceeded"
|
||||
) -> Callable:
|
||||
"""
|
||||
Decorator for rate limiting FastAPI endpoints
|
||||
|
||||
Args:
|
||||
rate: Number of requests allowed per time period
|
||||
per: Time period in seconds
|
||||
key_func: Function to extract rate limit key from request (defaults to client IP)
|
||||
error_message: Custom error message
|
||||
|
||||
Returns:
|
||||
Decorated function with rate limiting
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
limiter = RateLimiter(rate=rate, per=per)
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs) -> Any:
|
||||
# Extract request from args (FastAPI passes request as first arg for dependency injection)
|
||||
request = None
|
||||
for arg in args:
|
||||
if isinstance(arg, Request):
|
||||
request = arg
|
||||
break
|
||||
|
||||
if request is None:
|
||||
# Try to get request from kwargs
|
||||
request = kwargs.get('request')
|
||||
|
||||
if request is None:
|
||||
# No request available, skip rate limiting
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
# Get rate limit key
|
||||
if key_func:
|
||||
key = key_func(request)
|
||||
else:
|
||||
key = request.client.host if request.client else "unknown"
|
||||
|
||||
# Check rate limit
|
||||
if not limiter.is_allowed(key):
|
||||
logger.warning(f"Rate limit exceeded for {key} on {request.url.path}")
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=error_message,
|
||||
headers={"Retry-After": str(per)}
|
||||
)
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware for rate limiting all requests
|
||||
|
||||
Applies rate limiting based on client IP address
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
rate: int = 100,
|
||||
per: int = 60,
|
||||
key_func: Optional[Callable[[Request], str]] = None,
|
||||
error_message: str = "Rate limit exceeded"
|
||||
) -> None:
|
||||
"""
|
||||
Initialize rate limit middleware
|
||||
|
||||
Args:
|
||||
app: ASGI application
|
||||
rate: Number of requests allowed per time period
|
||||
per: Time period in seconds
|
||||
key_func: Function to extract rate limit key from request
|
||||
error_message: Custom error message
|
||||
"""
|
||||
super().__init__(app)
|
||||
self.rate = rate
|
||||
self.per = per
|
||||
self.key_func = key_func
|
||||
self.error_message = error_message
|
||||
self._limiter = RateLimiter(rate=rate, per=per)
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
"""
|
||||
Process request with rate limiting
|
||||
|
||||
Args:
|
||||
request: Incoming request
|
||||
call_next: Next middleware or endpoint
|
||||
|
||||
Returns:
|
||||
Response
|
||||
"""
|
||||
# Get rate limit key
|
||||
if self.key_func:
|
||||
key = self.key_func(request)
|
||||
else:
|
||||
key = request.client.host if request.client else "unknown"
|
||||
|
||||
# Check rate limit
|
||||
if not self._limiter.is_allowed(key):
|
||||
logger.warning(f"Rate limit exceeded for {key} on {request.url.path}")
|
||||
return Response(
|
||||
content='{"detail": "' + self.error_message + '"}',
|
||||
status_code=429,
|
||||
media_type="application/json",
|
||||
headers={"Retry-After": str(self.per)}
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
def get_rate_limit_headers(request: Request, limiter_name: str) -> Dict[str, str]:
|
||||
"""
|
||||
Get rate limit headers for response
|
||||
|
||||
Args:
|
||||
request: Request object
|
||||
limiter_name: Name of the rate limiter
|
||||
|
||||
Returns:
|
||||
Dictionary of rate limit headers
|
||||
"""
|
||||
limiter = _rate_limiters.get(limiter_name)
|
||||
if not limiter:
|
||||
return {}
|
||||
|
||||
key = request.client.host if request.client else "unknown"
|
||||
remaining = limiter.get_remaining_requests(key)
|
||||
|
||||
return {
|
||||
"X-RateLimit-Limit": str(limiter.rate),
|
||||
"X-RateLimit-Remaining": str(remaining),
|
||||
"X-RateLimit-Reset": str(limiter.per)
|
||||
}
|
||||
|
||||
|
||||
def reset_rate_limit(identifier: str, limiter_name: Optional[str] = None) -> None:
|
||||
"""
|
||||
Reset rate limit for an identifier
|
||||
|
||||
Args:
|
||||
identifier: Identifier to reset (e.g., IP address, user ID)
|
||||
limiter_name: Name of specific rate limiter, or None for all
|
||||
"""
|
||||
if limiter_name:
|
||||
if limiter_name in _rate_limiters:
|
||||
_rate_limiters[limiter_name].reset(identifier)
|
||||
else:
|
||||
for limiter in _rate_limiters.values():
|
||||
limiter.reset(identifier)
|
||||
328
aitbc/redis_cache.py
Normal file
328
aitbc/redis_cache.py
Normal file
@@ -0,0 +1,328 @@
|
||||
"""
|
||||
Redis caching utilities for AITBC applications
|
||||
Provides distributed caching with Redis backend
|
||||
"""
|
||||
|
||||
from typing import Optional, Any, List
|
||||
import json
|
||||
import hashlib
|
||||
|
||||
from .aitbc_logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
try:
|
||||
import redis
|
||||
REDIS_AVAILABLE = True
|
||||
except ImportError:
|
||||
REDIS_AVAILABLE = False
|
||||
logger.warning("Redis package not installed. Caching will be disabled.")
|
||||
|
||||
|
||||
class RedisCache:
|
||||
"""
|
||||
Redis cache implementation for distributed caching
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_url: Optional[str] = None,
|
||||
max_connections: int = 10,
|
||||
timeout: int = 5,
|
||||
default_ttl: int = 3600
|
||||
):
|
||||
"""
|
||||
Initialize Redis cache
|
||||
|
||||
Args:
|
||||
redis_url: Redis connection URL (e.g., redis://localhost:6379/0)
|
||||
max_connections: Maximum number of connections
|
||||
timeout: Connection timeout in seconds
|
||||
default_ttl: Default time-to-live for cached items in seconds
|
||||
"""
|
||||
self.redis_url = redis_url
|
||||
self.max_connections = max_connections
|
||||
self.timeout = timeout
|
||||
self.default_ttl = default_ttl
|
||||
self._client = None
|
||||
|
||||
if REDIS_AVAILABLE and redis_url:
|
||||
try:
|
||||
self._client = redis.Redis.from_url(
|
||||
redis_url,
|
||||
max_connections=max_connections,
|
||||
socket_timeout=timeout,
|
||||
socket_connect_timeout=timeout,
|
||||
decode_responses=True
|
||||
)
|
||||
# Test connection
|
||||
self._client.ping()
|
||||
logger.info(f"Connected to Redis at {redis_url}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to connect to Redis: {e}")
|
||||
self._client = None
|
||||
else:
|
||||
logger.info("Redis caching disabled (Redis not available or no URL provided)")
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if Redis cache is available"""
|
||||
return self._client is not None
|
||||
|
||||
def get(self, key: str) -> Optional[Any]:
|
||||
"""
|
||||
Get value from cache
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
|
||||
Returns:
|
||||
Cached value or None if not found
|
||||
"""
|
||||
if not self.is_available():
|
||||
return None
|
||||
|
||||
try:
|
||||
value = self._client.get(key)
|
||||
if value:
|
||||
return json.loads(value)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Redis get error: {e}")
|
||||
return None
|
||||
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
ttl: Optional[int] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Set value in cache
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
value: Value to cache (must be JSON serializable)
|
||||
ttl: Time-to-live in seconds (uses default_ttl if not provided)
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
if not self.is_available():
|
||||
return False
|
||||
|
||||
try:
|
||||
serialized = json.dumps(value)
|
||||
expiry = ttl if ttl is not None else self.default_ttl
|
||||
self._client.setex(key, expiry, serialized)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Redis set error: {e}")
|
||||
return False
|
||||
|
||||
def delete(self, key: str) -> bool:
|
||||
"""
|
||||
Delete value from cache
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
if not self.is_available():
|
||||
return False
|
||||
|
||||
try:
|
||||
self._client.delete(key)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Redis delete error: {e}")
|
||||
return False
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
"""
|
||||
Check if key exists in cache
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
|
||||
Returns:
|
||||
True if key exists, False otherwise
|
||||
"""
|
||||
if not self.is_available():
|
||||
return False
|
||||
|
||||
try:
|
||||
return self._client.exists(key) == 1
|
||||
except Exception as e:
|
||||
logger.error(f"Redis exists error: {e}")
|
||||
return False
|
||||
|
||||
def clear(self) -> bool:
|
||||
"""
|
||||
Clear all cached values
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
if not self.is_available():
|
||||
return False
|
||||
|
||||
try:
|
||||
self._client.flushdb()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Redis clear error: {e}")
|
||||
return False
|
||||
|
||||
def get_many(self, keys: List[str]) -> dict[str, Any]:
|
||||
"""
|
||||
Get multiple values from cache
|
||||
|
||||
Args:
|
||||
keys: List of cache keys
|
||||
|
||||
Returns:
|
||||
Dictionary mapping keys to cached values
|
||||
"""
|
||||
if not self.is_available():
|
||||
return {}
|
||||
|
||||
try:
|
||||
values = self._client.mget(keys)
|
||||
result = {}
|
||||
for key, value in zip(keys, values):
|
||||
if value:
|
||||
result[key] = json.loads(value)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Redis get_many error: {e}")
|
||||
return {}
|
||||
|
||||
def set_many(self, mapping: dict[str, Any], ttl: Optional[int] = None) -> bool:
|
||||
"""
|
||||
Set multiple values in cache
|
||||
|
||||
Args:
|
||||
mapping: Dictionary of key-value pairs
|
||||
ttl: Time-to-live in seconds
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
if not self.is_available():
|
||||
return False
|
||||
|
||||
try:
|
||||
pipe = self._client.pipeline()
|
||||
expiry = ttl if ttl is not None else self.default_ttl
|
||||
for key, value in mapping.items():
|
||||
serialized = json.dumps(value)
|
||||
pipe.setex(key, expiry, serialized)
|
||||
pipe.execute()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Redis set_many error: {e}")
|
||||
return False
|
||||
|
||||
def delete_many(self, keys: List[str]) -> bool:
|
||||
"""
|
||||
Delete multiple values from cache
|
||||
|
||||
Args:
|
||||
keys: List of cache keys
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
if not self.is_available():
|
||||
return False
|
||||
|
||||
try:
|
||||
if keys:
|
||||
self._client.delete(*keys)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Redis delete_many error: {e}")
|
||||
return False
|
||||
|
||||
def increment(self, key: str, amount: int = 1) -> Optional[int]:
|
||||
"""
|
||||
Increment a counter in cache
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
amount: Amount to increment by
|
||||
|
||||
Returns:
|
||||
New value or None if failed
|
||||
"""
|
||||
if not self.is_available():
|
||||
return None
|
||||
|
||||
try:
|
||||
return self._client.incrby(key, amount)
|
||||
except Exception as e:
|
||||
logger.error(f"Redis increment error: {e}")
|
||||
return None
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close Redis connection"""
|
||||
if self._client:
|
||||
self._client.close()
|
||||
logger.info("Redis connection closed")
|
||||
|
||||
|
||||
# Global cache instance
|
||||
_global_cache: Optional[RedisCache] = None
|
||||
|
||||
|
||||
def get_cache(
|
||||
redis_url: Optional[str] = None,
|
||||
max_connections: int = 10,
|
||||
timeout: int = 5,
|
||||
default_ttl: int = 3600
|
||||
) -> RedisCache:
|
||||
"""
|
||||
Get or create global Redis cache instance
|
||||
|
||||
Args:
|
||||
redis_url: Redis connection URL
|
||||
max_connections: Maximum number of connections
|
||||
timeout: Connection timeout in seconds
|
||||
default_ttl: Default time-to-live for cached items
|
||||
|
||||
Returns:
|
||||
RedisCache instance
|
||||
"""
|
||||
global _global_cache
|
||||
|
||||
if _global_cache is None and redis_url:
|
||||
_global_cache = RedisCache(
|
||||
redis_url=redis_url,
|
||||
max_connections=max_connections,
|
||||
timeout=timeout,
|
||||
default_ttl=default_ttl
|
||||
)
|
||||
|
||||
return _global_cache or RedisCache() # Return disabled cache if no URL
|
||||
|
||||
|
||||
def cache_key(*parts: str, prefix: str = "aitbc") -> str:
|
||||
"""
|
||||
Generate a cache key from parts
|
||||
|
||||
Args:
|
||||
*parts: Parts to include in the key
|
||||
prefix: Key prefix
|
||||
|
||||
Returns:
|
||||
Cache key string
|
||||
"""
|
||||
key_string = ":".join(str(part) for part in parts)
|
||||
full_key = f"{prefix}:{key_string}"
|
||||
# Hash if too long
|
||||
if len(full_key) > 250:
|
||||
hash_value = hashlib.sha256(full_key.encode()).hexdigest()[:16]
|
||||
return f"{prefix}:hashed:{hash_value}"
|
||||
return full_key
|
||||
Reference in New Issue
Block a user