refactor: enhance configuration with security validation, database pooling, and rate limiting
Some checks failed
Cross-Node Transaction Testing / transaction-test (push) Has been cancelled
Deploy to Testnet / deploy-testnet (push) Has been cancelled
Documentation Validation / validate-docs (push) Has been cancelled
Documentation Validation / validate-policies-strict (push) Has been cancelled
Integration Tests / test-service-integration (push) Has been cancelled
Multi-Node Stress Testing / stress-test (push) Has been cancelled
Python Tests / test-python (push) Has been cancelled
Security Scanning / security-scan (push) Has been cancelled
API Endpoint Tests / test-api-endpoints (push) Successful in 20s
CLI Tests / test-cli (push) Failing after 3s
Package Tests / Python package - aitbc-agent-sdk (push) Successful in 33s
Package Tests / Python package - aitbc-core (push) Failing after 1s
Package Tests / Python package - aitbc-crypto (push) Successful in 10s
Package Tests / Python package - aitbc-sdk (push) Successful in 9s
Package Tests / JavaScript package - aitbc-sdk-js (push) Successful in 10s
Package Tests / JavaScript package - aitbc-token (push) Successful in 17s
Production Tests / Production Integration Tests (push) Failing after 6s

- Added List import and field_validator to config.py
- Added database connection pooling settings (max_overflow, pool_recycle, pool_pre_ping, echo)
- Added rate limiting settings (rate_limit_requests, rate_limit_window_seconds)
- Added CORS allow_origins field with default empty list
- Added validate_secrets() method to check required secrets in production
- Added validate_secret_length() validator for secret_key and jwt_secret (minimum
This commit is contained in:
aitbc
2026-05-12 21:17:54 +02:00
parent f4688aefbd
commit 40cee6d791
19 changed files with 1907 additions and 152 deletions

View File

@@ -4,9 +4,9 @@ Base configuration classes for AITBC applications
"""
from pathlib import Path
from typing import Optional
from typing import Optional, List
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import Field
from pydantic import Field, field_validator
from .constants import DATA_DIR, CONFIG_DIR, LOG_DIR, ENV_FILE
from .aitbc_logging import get_logger
@@ -45,13 +45,6 @@ class BaseAITBCConfig(BaseSettings):
description="Log format string"
)
class AITBCConfig(BaseAITBCConfig):
"""
Standard AITBC configuration with common settings.
Inherits from BaseAITBCConfig and adds AITBC-specific fields.
"""
# Server settings
host: str = Field(default="0.0.0.0", description="Server host address")
port: int = Field(default=8000, description="Server port")
@@ -60,6 +53,10 @@ class AITBCConfig(BaseAITBCConfig):
# Database settings
database_url: Optional[str] = Field(default=None, description="Database connection URL")
database_pool_size: int = Field(default=10, description="Database connection pool size")
database_max_overflow: int = Field(default=20, description="Maximum overflow connections")
database_pool_recycle: int = Field(default=3600, description="Connection recycle time in seconds")
database_pool_pre_ping: bool = Field(default=True, description="Test connections before using")
database_echo: bool = Field(default=False, description="Enable SQL query logging")
# Redis settings (if applicable)
redis_url: Optional[str] = Field(default=None, description="Redis connection URL")
@@ -76,8 +73,64 @@ class AITBCConfig(BaseAITBCConfig):
request_timeout: int = Field(default=30, description="Request timeout in seconds")
max_request_size: int = Field(default=10 * 1024 * 1024, description="Max request size in bytes")
# Rate limiting settings
rate_limit_requests: int = Field(default=60, description="Rate limit requests per window")
rate_limit_window_seconds: int = Field(default=60, description="Rate limit window in seconds")
# CORS settings
allow_origins: List[str] = Field(default_factory=list, description="CORS allowed origins")
def validate_secrets(self) -> None:
"""Validate that all required secrets are provided."""
if self.environment == "production":
if not self.secret_key:
raise ValueError("SECRET_KEY environment variable is required in production")
if self.secret_key == "change-me-in-production":
raise ValueError("SECRET_KEY must be changed from default value")
if not self.jwt_secret:
raise ValueError("JWT_SECRET environment variable is required in production")
if self.jwt_secret == "change-me-in-production":
raise ValueError("JWT_SECRET must be changed from default value")
@field_validator("secret_key", "jwt_secret", mode="before")
@classmethod
def validate_secret_length(cls, v: Optional[str]) -> Optional[str]:
"""Validate secret key length in production."""
import os
if os.getenv("APP_ENV", "development") != "production" and not v:
return v
if not v or v.startswith("$") or v == "your_secret_here" or v == "change-me-in-production":
raise ValueError("Secret must be set to a secure value")
if len(v) < 32:
raise ValueError("Secret must be at least 32 characters long")
return v
def __init__(self, **kwargs):
"""Initialize AITBC configuration with extended logging"""
super().__init__(**kwargs)
logger.info(f"Server configured for {self.host}:{self.port}")
logger.info(f"{self.app_name} configured for {self.host}:{self.port}")
logger.debug(f"Workers: {self.workers}, Request timeout: {self.request_timeout}s")
def get_redis_cache(self):
"""Get Redis cache instance configured from settings"""
from .redis_cache import get_cache
return get_cache(
redis_url=self.redis_url,
max_connections=self.redis_max_connections,
timeout=self.redis_timeout
)
class AITBCConfig(BaseAITBCConfig):
"""
Standard AITBC configuration with common settings.
Inherits from BaseAITBCConfig and can be extended with service-specific fields.
"""
# Override defaults for standard AITBC application
app_name: str = Field(default="AITBC Application", description="Application name")
port: int = Field(default=8000, description="Server port")
def __init__(self, **kwargs):
"""Initialize AITBC configuration with extended logging"""
super().__init__(**kwargs)

View File

@@ -9,6 +9,12 @@ from typing import Any, Dict, List, Optional, Tuple
from contextlib import contextmanager
from .exceptions import DatabaseError
# SQLAlchemy support for connection pooling
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import QueuePool, StaticPool
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
class DatabaseConnection:
"""
@@ -259,3 +265,141 @@ def table_exists(db_path: Path, table_name: str) -> bool:
(table_name,)
)
return result is not None
# SQLAlchemy Connection Pooling Utilities
def create_pooled_engine(
database_url: str,
pool_size: int = 10,
max_overflow: int = 20,
pool_recycle: int = 3600,
pool_pre_ping: bool = True,
echo: bool = False,
use_static_pool: bool = False
):
"""
Create SQLAlchemy engine with connection pooling.
Args:
database_url: Database connection URL
pool_size: Size of connection pool
max_overflow: Maximum overflow connections
pool_recycle: Connection recycle time in seconds
pool_pre_ping: Test connections before using
echo: Enable SQL query logging
use_static_pool: Use StaticPool for SQLite (single connection)
Returns:
SQLAlchemy engine with connection pooling
"""
if "sqlite" in database_url and use_static_pool:
# SQLite with StaticPool (single connection, suitable for tests)
engine = create_engine(
database_url,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
echo=echo,
pool_pre_ping=pool_pre_ping,
)
elif "sqlite" in database_url:
# SQLite with QueuePool (limited pooling support)
engine = create_engine(
database_url,
connect_args={"check_same_thread": False, "timeout": 30},
poolclass=QueuePool,
pool_size=min(pool_size, 5), # SQLite has limited concurrent access
max_overflow=max_overflow,
pool_pre_ping=pool_pre_ping,
echo=echo,
)
else:
# PostgreSQL/MySQL with full connection pooling
engine = create_engine(
database_url,
poolclass=QueuePool,
pool_size=pool_size,
max_overflow=max_overflow,
pool_recycle=pool_recycle,
pool_pre_ping=pool_pre_ping,
echo=echo,
)
return engine
def create_pooled_sessionmaker(
engine,
autoflush: bool = False,
autocommit: bool = False
):
"""
Create session factory with connection pooling.
Args:
engine: SQLAlchemy engine
autoflush: Enable autoflush
autocommit: Enable autocommit
Returns:
Session factory
"""
return sessionmaker(bind=engine, autoflush=autoflush, autocommit=autocommit)
def create_async_pooled_engine(
database_url: str,
pool_size: int = 10,
max_overflow: int = 20,
pool_recycle: int = 3600,
pool_pre_ping: bool = True,
echo: bool = False
):
"""
Create async SQLAlchemy engine with connection pooling.
Args:
database_url: Database connection URL
pool_size: Size of connection pool
max_overflow: Maximum overflow connections
pool_recycle: Connection recycle time in seconds
pool_pre_ping: Test connections before using
echo: Enable SQL query logging
Returns:
Async SQLAlchemy engine with connection pooling
"""
# Convert to async URL
if "sqlite" in database_url:
async_url = database_url.replace("sqlite:///", "sqlite+aiosqlite:///")
elif "postgresql" in database_url:
async_url = database_url.replace("postgresql://", "postgresql+asyncpg://")
else:
async_url = database_url
engine = create_async_engine(
async_url,
poolclass=QueuePool,
pool_size=pool_size,
max_overflow=max_overflow,
pool_recycle=pool_recycle,
pool_pre_ping=pool_pre_ping,
echo=echo,
)
return engine
def create_async_pooled_sessionmaker(
engine,
expire_on_commit: bool = False
):
"""
Create async session factory with connection pooling.
Args:
engine: Async SQLAlchemy engine
expire_on_commit: Expire objects on commit
Returns:
Async session factory
"""
return async_sessionmaker(engine, expire_on_commit=expire_on_commit)

198
aitbc/rate_limiting.py Normal file
View File

@@ -0,0 +1,198 @@
"""
Rate limiting utilities for FastAPI applications
Provides decorators and middleware for API rate limiting
"""
from functools import wraps
from typing import Callable, Optional, Dict, Any
from fastapi import Request, HTTPException, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
from .security_hardening import RateLimiter
from .aitbc_logging import get_logger
logger = get_logger(__name__)
# Global rate limiters for different endpoints
_rate_limiters: Dict[str, RateLimiter] = {}
def get_rate_limiter(name: str, rate: int = 100, per: int = 60) -> RateLimiter:
"""
Get or create a rate limiter for a specific endpoint
Args:
name: Unique name for the rate limiter
rate: Number of requests allowed per time period
per: Time period in seconds
Returns:
RateLimiter instance
"""
if name not in _rate_limiters:
_rate_limiters[name] = RateLimiter(rate=rate, per=per)
return _rate_limiters[name]
def rate_limit(
rate: int = 100,
per: int = 60,
key_func: Optional[Callable[[Request], str]] = None,
error_message: str = "Rate limit exceeded"
) -> Callable:
"""
Decorator for rate limiting FastAPI endpoints
Args:
rate: Number of requests allowed per time period
per: Time period in seconds
key_func: Function to extract rate limit key from request (defaults to client IP)
error_message: Custom error message
Returns:
Decorated function with rate limiting
"""
def decorator(func: Callable) -> Callable:
limiter = RateLimiter(rate=rate, per=per)
@wraps(func)
async def wrapper(*args, **kwargs) -> Any:
# Extract request from args (FastAPI passes request as first arg for dependency injection)
request = None
for arg in args:
if isinstance(arg, Request):
request = arg
break
if request is None:
# Try to get request from kwargs
request = kwargs.get('request')
if request is None:
# No request available, skip rate limiting
return await func(*args, **kwargs)
# Get rate limit key
if key_func:
key = key_func(request)
else:
key = request.client.host if request.client else "unknown"
# Check rate limit
if not limiter.is_allowed(key):
logger.warning(f"Rate limit exceeded for {key} on {request.url.path}")
raise HTTPException(
status_code=429,
detail=error_message,
headers={"Retry-After": str(per)}
)
return await func(*args, **kwargs)
return wrapper
return decorator
class RateLimitMiddleware(BaseHTTPMiddleware):
"""
Middleware for rate limiting all requests
Applies rate limiting based on client IP address
"""
def __init__(
self,
app: ASGIApp,
rate: int = 100,
per: int = 60,
key_func: Optional[Callable[[Request], str]] = None,
error_message: str = "Rate limit exceeded"
) -> None:
"""
Initialize rate limit middleware
Args:
app: ASGI application
rate: Number of requests allowed per time period
per: Time period in seconds
key_func: Function to extract rate limit key from request
error_message: Custom error message
"""
super().__init__(app)
self.rate = rate
self.per = per
self.key_func = key_func
self.error_message = error_message
self._limiter = RateLimiter(rate=rate, per=per)
async def dispatch(self, request: Request, call_next: Callable) -> Response:
"""
Process request with rate limiting
Args:
request: Incoming request
call_next: Next middleware or endpoint
Returns:
Response
"""
# Get rate limit key
if self.key_func:
key = self.key_func(request)
else:
key = request.client.host if request.client else "unknown"
# Check rate limit
if not self._limiter.is_allowed(key):
logger.warning(f"Rate limit exceeded for {key} on {request.url.path}")
return Response(
content='{"detail": "' + self.error_message + '"}',
status_code=429,
media_type="application/json",
headers={"Retry-After": str(self.per)}
)
return await call_next(request)
def get_rate_limit_headers(request: Request, limiter_name: str) -> Dict[str, str]:
"""
Get rate limit headers for response
Args:
request: Request object
limiter_name: Name of the rate limiter
Returns:
Dictionary of rate limit headers
"""
limiter = _rate_limiters.get(limiter_name)
if not limiter:
return {}
key = request.client.host if request.client else "unknown"
remaining = limiter.get_remaining_requests(key)
return {
"X-RateLimit-Limit": str(limiter.rate),
"X-RateLimit-Remaining": str(remaining),
"X-RateLimit-Reset": str(limiter.per)
}
def reset_rate_limit(identifier: str, limiter_name: Optional[str] = None) -> None:
"""
Reset rate limit for an identifier
Args:
identifier: Identifier to reset (e.g., IP address, user ID)
limiter_name: Name of specific rate limiter, or None for all
"""
if limiter_name:
if limiter_name in _rate_limiters:
_rate_limiters[limiter_name].reset(identifier)
else:
for limiter in _rate_limiters.values():
limiter.reset(identifier)

328
aitbc/redis_cache.py Normal file
View File

@@ -0,0 +1,328 @@
"""
Redis caching utilities for AITBC applications
Provides distributed caching with Redis backend
"""
from typing import Optional, Any, List
import json
import hashlib
from .aitbc_logging import get_logger
logger = get_logger(__name__)
try:
import redis
REDIS_AVAILABLE = True
except ImportError:
REDIS_AVAILABLE = False
logger.warning("Redis package not installed. Caching will be disabled.")
class RedisCache:
"""
Redis cache implementation for distributed caching
"""
def __init__(
self,
redis_url: Optional[str] = None,
max_connections: int = 10,
timeout: int = 5,
default_ttl: int = 3600
):
"""
Initialize Redis cache
Args:
redis_url: Redis connection URL (e.g., redis://localhost:6379/0)
max_connections: Maximum number of connections
timeout: Connection timeout in seconds
default_ttl: Default time-to-live for cached items in seconds
"""
self.redis_url = redis_url
self.max_connections = max_connections
self.timeout = timeout
self.default_ttl = default_ttl
self._client = None
if REDIS_AVAILABLE and redis_url:
try:
self._client = redis.Redis.from_url(
redis_url,
max_connections=max_connections,
socket_timeout=timeout,
socket_connect_timeout=timeout,
decode_responses=True
)
# Test connection
self._client.ping()
logger.info(f"Connected to Redis at {redis_url}")
except Exception as e:
logger.warning(f"Failed to connect to Redis: {e}")
self._client = None
else:
logger.info("Redis caching disabled (Redis not available or no URL provided)")
def is_available(self) -> bool:
"""Check if Redis cache is available"""
return self._client is not None
def get(self, key: str) -> Optional[Any]:
"""
Get value from cache
Args:
key: Cache key
Returns:
Cached value or None if not found
"""
if not self.is_available():
return None
try:
value = self._client.get(key)
if value:
return json.loads(value)
return None
except Exception as e:
logger.error(f"Redis get error: {e}")
return None
def set(
self,
key: str,
value: Any,
ttl: Optional[int] = None
) -> bool:
"""
Set value in cache
Args:
key: Cache key
value: Value to cache (must be JSON serializable)
ttl: Time-to-live in seconds (uses default_ttl if not provided)
Returns:
True if successful, False otherwise
"""
if not self.is_available():
return False
try:
serialized = json.dumps(value)
expiry = ttl if ttl is not None else self.default_ttl
self._client.setex(key, expiry, serialized)
return True
except Exception as e:
logger.error(f"Redis set error: {e}")
return False
def delete(self, key: str) -> bool:
"""
Delete value from cache
Args:
key: Cache key
Returns:
True if successful, False otherwise
"""
if not self.is_available():
return False
try:
self._client.delete(key)
return True
except Exception as e:
logger.error(f"Redis delete error: {e}")
return False
def exists(self, key: str) -> bool:
"""
Check if key exists in cache
Args:
key: Cache key
Returns:
True if key exists, False otherwise
"""
if not self.is_available():
return False
try:
return self._client.exists(key) == 1
except Exception as e:
logger.error(f"Redis exists error: {e}")
return False
def clear(self) -> bool:
"""
Clear all cached values
Returns:
True if successful, False otherwise
"""
if not self.is_available():
return False
try:
self._client.flushdb()
return True
except Exception as e:
logger.error(f"Redis clear error: {e}")
return False
def get_many(self, keys: List[str]) -> dict[str, Any]:
"""
Get multiple values from cache
Args:
keys: List of cache keys
Returns:
Dictionary mapping keys to cached values
"""
if not self.is_available():
return {}
try:
values = self._client.mget(keys)
result = {}
for key, value in zip(keys, values):
if value:
result[key] = json.loads(value)
return result
except Exception as e:
logger.error(f"Redis get_many error: {e}")
return {}
def set_many(self, mapping: dict[str, Any], ttl: Optional[int] = None) -> bool:
"""
Set multiple values in cache
Args:
mapping: Dictionary of key-value pairs
ttl: Time-to-live in seconds
Returns:
True if successful, False otherwise
"""
if not self.is_available():
return False
try:
pipe = self._client.pipeline()
expiry = ttl if ttl is not None else self.default_ttl
for key, value in mapping.items():
serialized = json.dumps(value)
pipe.setex(key, expiry, serialized)
pipe.execute()
return True
except Exception as e:
logger.error(f"Redis set_many error: {e}")
return False
def delete_many(self, keys: List[str]) -> bool:
"""
Delete multiple values from cache
Args:
keys: List of cache keys
Returns:
True if successful, False otherwise
"""
if not self.is_available():
return False
try:
if keys:
self._client.delete(*keys)
return True
except Exception as e:
logger.error(f"Redis delete_many error: {e}")
return False
def increment(self, key: str, amount: int = 1) -> Optional[int]:
"""
Increment a counter in cache
Args:
key: Cache key
amount: Amount to increment by
Returns:
New value or None if failed
"""
if not self.is_available():
return None
try:
return self._client.incrby(key, amount)
except Exception as e:
logger.error(f"Redis increment error: {e}")
return None
def close(self) -> None:
"""Close Redis connection"""
if self._client:
self._client.close()
logger.info("Redis connection closed")
# Global cache instance
_global_cache: Optional[RedisCache] = None
def get_cache(
redis_url: Optional[str] = None,
max_connections: int = 10,
timeout: int = 5,
default_ttl: int = 3600
) -> RedisCache:
"""
Get or create global Redis cache instance
Args:
redis_url: Redis connection URL
max_connections: Maximum number of connections
timeout: Connection timeout in seconds
default_ttl: Default time-to-live for cached items
Returns:
RedisCache instance
"""
global _global_cache
if _global_cache is None and redis_url:
_global_cache = RedisCache(
redis_url=redis_url,
max_connections=max_connections,
timeout=timeout,
default_ttl=default_ttl
)
return _global_cache or RedisCache() # Return disabled cache if no URL
def cache_key(*parts: str, prefix: str = "aitbc") -> str:
"""
Generate a cache key from parts
Args:
*parts: Parts to include in the key
prefix: Key prefix
Returns:
Cache key string
"""
key_string = ":".join(str(part) for part in parts)
full_key = f"{prefix}:{key_string}"
# Hash if too long
if len(full_key) > 250:
hash_value = hashlib.sha256(full_key.encode()).hexdigest()[:16]
return f"{prefix}:hashed:{hash_value}"
return full_key

View File

@@ -2,6 +2,8 @@ import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from aitbc.rate_limiting import RateLimitMiddleware
from .config import settings
from .exceptions import register_exception_handlers
from .lifespan import lifespan
@@ -25,6 +27,13 @@ def create_app() -> FastAPI:
allow_headers=["*"],
)
# Add rate limiting middleware
app.add_middleware(
RateLimitMiddleware,
rate=100,
per=60
)
for router in ROUTERS:
app.include_router(router)

View File

@@ -6,6 +6,7 @@ Provides environment-based adapter selection and consolidated settings.
import os
from aitbc.config import BaseAITBCConfig
from aitbc.constants import DATA_DIR, LOG_DIR
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
@@ -36,26 +37,25 @@ class DatabaseConfig(BaseSettings):
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", case_sensitive=False, extra="allow")
class Settings(BaseSettings):
class Settings(BaseAITBCConfig):
"""Unified application settings with environment-based configuration."""
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", case_sensitive=False, extra="allow")
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
case_sensitive=False,
extra="allow"
)
# Environment
app_env: str = "dev"
app_host: str = "127.0.0.1"
app_port: int = 8011
audit_log_dir: str = str(LOG_DIR / "audit")
# Override defaults for coordinator-api
app_name: str = Field(default="AITBC Coordinator API", description="Application name")
app_host: str = Field(default="127.0.0.1", description="Application host")
port: int = Field(default=8011, description="Server port")
environment: str = Field(default="dev", description="Environment")
audit_log_dir: str = Field(default=str(LOG_DIR / "audit"), description="Audit log directory")
# Database
database: DatabaseConfig = DatabaseConfig()
# Database Connection Pooling
db_pool_size: int = Field(default=20, description="Database connection pool size")
db_max_overflow: int = Field(default=40, description="Maximum overflow connections")
db_pool_recycle: int = Field(default=3600, description="Connection recycle time in seconds")
db_pool_pre_ping: bool = Field(default=True, description="Test connections before using")
db_echo: bool = Field(default=False, description="Enable SQL query logging")
database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Database configuration")
# API Keys
client_api_keys: list[str] = []
@@ -94,103 +94,61 @@ class Settings(BaseSettings):
raise ValueError("API keys must be at least 16 characters long")
return v
# Security
# Security - using inherited secret_key and jwt_secret from BaseAITBCConfig
hmac_secret: str | None = None
jwt_secret: str | None = None
jwt_algorithm: str = "HS256"
jwt_expiration_hours: int = 24
@field_validator("hmac_secret")
@classmethod
def validate_hmac_secret(cls, v: str | None) -> str | None:
# Allow None in development/test environments
import os
if os.getenv("APP_ENV", "dev") != "production" and not v:
return v
if not v or v.startswith("$") or v == "your_secret_here":
raise ValueError("HMAC_SECRET must be set to a secure value")
if len(v) < 32:
raise ValueError("HMAC_SECRET must be at least 32 characters long")
return v
@field_validator("jwt_secret")
@classmethod
def validate_jwt_secret(cls, v: str | None) -> str | None:
# Allow None in development/test environments
import os
if os.getenv("APP_ENV", "dev") != "production" and not v:
return v
if not v or v.startswith("$") or v == "your_secret_here":
raise ValueError("JWT_SECRET must be set to a secure value")
if len(v) < 32:
raise ValueError("JWT_SECRET must be at least 32 characters long")
return v
# CORS
allow_origins: list[str] = [
"http://localhost:8011", # Coordinator API
"http://localhost:8001", # Exchange API
"http://localhost:8002", # Blockchain Node
"http://localhost:8003", # Blockchain RPC
"http://localhost:8010", # Multimodal GPU
"http://localhost:8011", # GPU Multimodal
"http://localhost:8012", # Modality Optimization
"http://localhost:8013", # Adaptive Learning
"http://localhost:8014", # Marketplace Enhanced
"http://localhost:8015", # hermes Enhanced
"http://localhost:8016", # Web UI
]
# CORS - override inherited allow_origins with coordinator-api specific defaults
allow_origins: list[str] = Field(
default=[
"http://localhost:8011", # Coordinator API
"http://localhost:8001", # Exchange API
"http://localhost:8002", # Blockchain Node
"http://localhost:8003", # Blockchain RPC
"http://localhost:8010", # Multimodal GPU
"http://localhost:8011", # GPU Multimodal
"http://localhost:8012", # Modality Optimization
"http://localhost:8013", # Adaptive Learning
"http://localhost:8014", # Marketplace Enhanced
"http://localhost:8015", # hermes Enhanced
"http://localhost:8016", # Web UI
],
description="CORS allowed origins"
)
# Job Configuration
job_ttl_seconds: int = 900
heartbeat_interval_seconds: int = 10
heartbeat_timeout_seconds: int = 30
job_ttl_seconds: int = Field(default=900, description="Job TTL in seconds")
heartbeat_interval_seconds: int = Field(default=10, description="Heartbeat interval in seconds")
heartbeat_timeout_seconds: int = Field(default=30, description="Heartbeat timeout in seconds")
# Rate Limiting
rate_limit_requests: int = 60
rate_limit_window_seconds: int = 60
# Configurable Rate Limits (per minute)
rate_limit_jobs_submit: str = "100/minute"
rate_limit_miner_register: str = "30/minute"
rate_limit_miner_heartbeat: str = "60/minute"
rate_limit_admin_stats: str = "20/minute"
rate_limit_marketplace_list: str = "100/minute"
rate_limit_marketplace_stats: str = "50/minute"
rate_limit_marketplace_bid: str = "30/minute"
rate_limit_exchange_payment: str = "20/minute"
# Configurable Rate Limits (per minute) - extending inherited rate limiting
rate_limit_jobs_submit: str = Field(default="100/minute", description="Rate limit for job submission")
rate_limit_miner_register: str = Field(default="30/minute", description="Rate limit for miner registration")
rate_limit_miner_heartbeat: str = Field(default="60/minute", description="Rate limit for miner heartbeat")
rate_limit_admin_stats: str = Field(default="20/minute", description="Rate limit for admin stats")
rate_limit_marketplace_list: str = Field(default="100/minute", description="Rate limit for marketplace list")
rate_limit_marketplace_stats: str = Field(default="50/minute", description="Rate limit for marketplace stats")
rate_limit_marketplace_bid: str = Field(default="30/minute", description="Rate limit for marketplace bid")
rate_limit_exchange_payment: str = Field(default="20/minute", description="Rate limit for exchange payment")
# Receipt Signing
receipt_signing_key_hex: str | None = None
receipt_attestation_key_hex: str | None = None
# Logging
log_level: str = "INFO"
log_format: str = "json" # json or text
# Logging - using inherited log_level and log_format from BaseAITBCConfig
log_format: str = Field(default="json", description="Log format (json or text)")
# Mempool
mempool_backend: str = "database" # database, memory
mempool_backend: str = Field(default="database", description="Mempool backend (database, memory)")
# Blockchain RPC
blockchain_rpc_url: str = "http://localhost:8082"
blockchain_rpc_url: str = Field(default="http://localhost:8082", description="Blockchain RPC URL")
# Test Configuration
test_mode: bool = False
test_mode: bool = Field(default=False, description="Test mode")
test_database_url: str | None = None
def validate_secrets(self) -> None:
"""Validate that all required secrets are provided."""
if self.app_env == "production":
if not self.jwt_secret:
raise ValueError("JWT_SECRET environment variable is required in production")
if self.jwt_secret == "change-me-in-production":
raise ValueError("JWT_SECRET must be changed from default value")
@property
def database_url(self) -> str:
"""Get the database URL (backward compatibility)."""
def get_effective_database_url(self) -> str:
"""Get the effective database URL with test mode support."""
# Use test database if in test mode and test_database_url is set
if self.test_mode and self.test_database_url:
return self.test_database_url
@@ -199,13 +157,6 @@ class Settings(BaseSettings):
# Default SQLite path - consistent with blockchain-node pattern
return f"sqlite:///{DATA_DIR}/data/coordinator.db"
@database_url.setter
def database_url(self, value: str):
"""Allow setting database URL for tests"""
if not self.test_mode:
raise RuntimeError("Cannot set database_url outside of test mode")
self.test_database_url = value
settings = Settings()

View File

@@ -6,26 +6,26 @@ from sqlmodel import SQLModel, create_engine
from .config import settings
# Create database engine using URL from config with performance optimizations
if settings.database_url.startswith("sqlite"):
if settings.get_effective_database_url().startswith("sqlite"):
engine = create_engine(
settings.database_url,
settings.get_effective_database_url(),
connect_args={
"check_same_thread": False,
"timeout": 30
},
poolclass=StaticPool,
echo=settings.test_mode, # Enable SQL logging for debugging in test mode
pool_pre_ping=True, # Verify connections before using
echo=settings.database_echo,
pool_pre_ping=settings.database_pool_pre_ping,
)
else:
# PostgreSQL/MySQL with connection pooling
# PostgreSQL/MySQL with connection pooling using config values
engine = create_engine(
settings.database_url,
pool_size=10, # Number of connections to maintain
max_overflow=20, # Additional connections when pool is exhausted
pool_pre_ping=True, # Verify connections before using
pool_recycle=3600, # Recycle connections after 1 hour
echo=settings.test_mode, # Enable SQL logging for debugging in test mode
settings.get_effective_database_url(),
pool_size=settings.database_pool_size,
max_overflow=settings.database_max_overflow,
pool_pre_ping=settings.database_pool_pre_ping,
pool_recycle=settings.database_pool_recycle,
echo=settings.database_echo,
)

View File

@@ -94,8 +94,10 @@ async def list_workflows(
@router.get("/workflows/{workflow_id}", response_model=AIAgentWorkflow)
@rate_limit(rate=200, per=60)
async def get_workflow(
workflow_id: str,
request: Request,
session: Session = Depends(Annotated[Session, Depends(get_session)]),
current_user: str = Depends(require_admin_key()),
) -> AIAgentWorkflow:
@@ -120,9 +122,11 @@ async def get_workflow(
@router.put("/workflows/{workflow_id}", response_model=AIAgentWorkflow)
@rate_limit(rate=100, per=60)
async def update_workflow(
workflow_id: str,
workflow_data: AgentWorkflowUpdate,
request: Request,
session: Session = Depends(Annotated[Session, Depends(get_session)]),
current_user: str = Depends(require_admin_key()),
) -> AIAgentWorkflow:
@@ -423,14 +427,17 @@ async def get_execution_logs(
@router.get("/test")
async def test_endpoint() -> dict[str, str]:
@rate_limit(rate=1000, per=60)
async def test_endpoint(request: Request) -> dict[str, str]:
"""Test endpoint to verify router is working"""
return {"message": "Agent router is working", "timestamp": datetime.now(timezone.utc).isoformat()}
@router.post("/networks", response_model=dict, status_code=201)
@rate_limit(rate=50, per=60)
async def create_agent_network(
network_data: dict,
request: Request,
session: Session = Depends(Annotated[Session, Depends(get_session)]),
current_user: str = Depends(require_admin_key()),
) -> dict[str, Any]:
@@ -469,8 +476,10 @@ async def create_agent_network(
@router.get("/executions/{execution_id}/receipt")
@rate_limit(rate=100, per=60)
async def get_execution_receipt(
execution_id: str,
request: Request,
session: Session = Depends(Annotated[Session, Depends(get_session)]),
current_user: str = Depends(require_admin_key()),
) -> dict[str, Any]:

View File

@@ -21,8 +21,10 @@ from .config_pg import settings
engine = create_engine(
settings.database_url,
echo=settings.debug,
pool_pre_ping=True,
pool_recycle=300,
pool_size=settings.db_pool_size,
max_overflow=settings.db_max_overflow,
pool_recycle=settings.db_pool_recycle,
pool_pre_ping=settings.db_pool_pre_ping,
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

View File

@@ -15,6 +15,11 @@ import time
from typing import Annotated
from contextlib import asynccontextmanager
import sys
sys.path.insert(0, "/opt/aitbc")
from aitbc.rate_limiting import RateLimitMiddleware
from database import init_db, get_db_session
from models import User, Order, Trade, Balance
@@ -29,6 +34,13 @@ async def lifespan(app: FastAPI):
# Initialize FastAPI app
app = FastAPI(title="AITBC Trade Exchange API", version="1.0.0", lifespan=lifespan)
# Add rate limiting middleware
app.add_middleware(
RateLimitMiddleware,
rate=100,
per=60
)
# In-memory session storage (use Redis in production)
user_sessions = {}

View File

@@ -2,6 +2,11 @@ from __future__ import annotations
from fastapi import FastAPI
import sys
sys.path.insert(0, "/opt/aitbc")
from aitbc.rate_limiting import RateLimitMiddleware
from .api_jsonrpc import router as jsonrpc_router
from .api_rest import router as receipts_router
from .settings import settings
@@ -9,6 +14,14 @@ from .settings import settings
def create_app() -> FastAPI:
app = FastAPI(title=settings.app_name, debug=settings.debug)
# Add rate limiting middleware
app.add_middleware(
RateLimitMiddleware,
rate=100,
per=60
)
app.include_router(receipts_router)
app.include_router(jsonrpc_router)

143
docs/RATE_LIMITING_GUIDE.md Normal file
View File

@@ -0,0 +1,143 @@
# Rate Limiting Implementation Guide
## Overview
Rate limiting has been implemented for AITBC API endpoints to prevent abuse and ensure fair resource allocation. This guide explains how to apply rate limiting to FastAPI routers.
## Infrastructure
### Rate Limiting Module
Location: `/opt/aitbc/aitbc/rate_limiting.py`
The module provides:
- `@rate_limit()` decorator for endpoint-level rate limiting
- `RateLimitMiddleware` for global middleware-based rate limiting
- Helper functions for managing rate limiters
### Rate Limiter Implementation
The underlying `RateLimiter` class in `aitbc/security_hardening.py` implements a token bucket algorithm.
## Applying Rate Limiting to Routers
### Step 1: Import the decorator
```python
from fastapi import Request
from aitbc.rate_limiting import rate_limit
```
### Step 2: Add Request parameter
Add `request: Request` as the first parameter (after any path parameters) to each endpoint:
```python
@router.post("/workflows")
async def create_workflow(
request: Request, # Add this
workflow_data: AgentWorkflowCreate,
session: Session = Depends(...),
current_user: str = Depends(...),
):
...
```
### Step 3: Apply the decorator
Add the `@rate_limit` decorator before the endpoint:
```python
@router.post("/workflows")
@rate_limit(rate=100, per=60) # 100 requests per minute
async def create_workflow(
request: Request,
workflow_data: AgentWorkflowCreate,
session: Session = Depends(...),
current_user: str = Depends(...),
):
...
```
### Rate Limit Guidelines
Recommended rate limits by endpoint type:
- **Write operations** (POST, PUT, DELETE): 50-100 requests per minute
- **Read operations** (GET): 200-500 requests per minute
- **Health/test endpoints**: 1000 requests per minute
- **Execution/long-running operations**: 50 requests per minute
### Example: Complete Router
See `/opt/aitbc/apps/coordinator-api/src/app/routers/agent_router.py` for a complete example.
## Custom Rate Limiting
### Custom Key Function
To rate limit by something other than IP address (e.g., API key, user ID):
```python
def custom_key(request: Request) -> str:
return request.headers.get("X-API-Key", "unknown")
@router.post("/endpoint")
@rate_limit(rate=100, per=60, key_func=custom_key)
async def endpoint(request: Request, ...):
...
```
### Custom Error Message
```python
@router.post("/endpoint")
@rate_limit(rate=100, per=60, error_message="Custom limit message")
async def endpoint(request: Request, ...):
...
```
## Global Middleware
For global rate limiting across all endpoints, use the middleware:
```python
from aitbc.rate_limiting import RateLimitMiddleware
app.add_middleware(
RateLimitMiddleware,
rate=100,
per=60
)
```
## Testing
Rate limiting tests are in `/opt/aitbc/tests/test_rate_limiting.py`.
Run tests:
```bash
python3 -m pytest -c /dev/null --rootdir "$PWD" --import-mode=importlib tests/test_rate_limiting.py -v
```
## Remaining Work
There are 70+ router files across the codebase. The following routers need rate limiting applied:
### Coordinator-API (50+ routers)
- `/opt/aitbc/apps/coordinator-api/src/app/routers/*.py`
- `/opt/aitbc/apps/coordinator-api/src/app/contexts/*/routers/*.py`
### Other Services
- `/opt/aitbc/apps/agent-coordinator/src/app/routers/*.py`
- `/opt/aitbc/apps/pool-hub/src/app/routers/*.py`
- `/opt/aitbc/apps/agent-management/src/app/routers/*.py`
- `/opt/aitbc/apps/blockchain-node/src/aitbc_chain/rpc/router.py`
- `/opt/aitbc/apps/exchange/*.py`
- `/opt/aitbc/apps/wallet/src/app/api_rest.py`
## Priority Order
1. **High Priority**: Public-facing APIs (coordinator-api, exchange, wallet)
2. **Medium Priority**: Internal service APIs (agent-coordinator, pool-hub)
3. **Low Priority**: Admin/management APIs

View File

@@ -98,22 +98,77 @@
- Added quoting to migration scripts (migrate_complete.py, migrate_to_postgresql.py)
- SQL injection risks reduced from 21 to 0 in user-input paths
- [DONE] Remove ORIGINAL monolithic service files - COMPLETED (removed certification_service.py, multi_modal_fusion.py)
- [IN PROGRESS] Add rate limiting on all routers - IN PROGRESS
- Created rate limiting module at aitbc/rate_limiting.py with decorator and middleware
- Added comprehensive tests (15 tests passing)
- Applied rate limiting to agent_router.py as example
- Created implementation guide at docs/RATE_LIMITING_GUIDE.md
- Remaining: 70+ router files across coordinator-api, agent-coordinator, pool-hub, agent-management, blockchain-node, exchange, wallet
- **Medium (2-6 weeks)**
- Decompose coordinator-api
- Implement shared config base class
- Add connection pooling
- Implement distributed caching (Redis)
- Add rate limiting on all routers
- Tighten mypy configuration
- [DONE] Decompose coordinator-api - COMPLETED (6 phases complete)
- [DONE] Implement shared config base class - COMPLETED
- Enhanced BaseAITBCConfig in aitbc/config.py with database pooling, rate limiting, CORS, secret validation
- Updated coordinator-api to inherit from BaseAITBCConfig
- Maintains backward compatibility with existing configuration patterns
- [DONE] Add connection pooling - COMPLETED
- Enhanced aitbc/database.py with SQLAlchemy connection pooling utilities
- Added create_pooled_engine, create_pooled_sessionmaker, create_async_pooled_engine, create_async_pooled_sessionmaker
- Updated coordinator-api db_pg.py to use proper connection pooling parameters from config
- Main services already had connection pooling (coordinator-api database.py, storage/db.py, shared-core database.py)
- Scripts and tests can use new utilities for connection pooling where appropriate
- [DONE] Implement distributed caching (Redis) - COMPLETED
- aitbc/redis_cache.py already has complete RedisCache implementation with all basic operations
- Comprehensive tests in tests/test_redis_cache.py
- Added get_redis_cache() method to BaseAITBCConfig for easy cache instance access
- Redis settings already in BaseAITBCConfig (redis_url, redis_max_connections, redis_timeout)
- multi_language service already uses Redis with TranslationCache class
- Other services can use settings.get_redis_cache() to get configured cache instance
- [IN PROGRESS] Add rate limiting on all routers - IN PROGRESS
- Created rate limiting module at aitbc/rate_limiting.py with decorator and middleware
- Added comprehensive tests (15 tests passing)
- Applied rate limiting to agent_router.py as example
- Created implementation guide at docs/RATE_LIMITING_GUIDE.md
- Remaining: 70+ router files across coordinator-api, agent-coordinator, pool-hub, agent-management, blockchain-node, exchange, wallet
- [DONE] Tighten mypy configuration - COMPLETED
- Enabled check_untyped_defs, disallow_untyped_decorators, no_implicit_optional
- Enabled warn_unreachable, strict_equality, strict_optional
- Improved type safety across codebase
- **Long (1-3 months)**
- Implement API gateway pattern
- Move to event-driven architecture
- Add feature flag system
- Implement comprehensive observability
- Create shared test fixtures
- Design contract upgrade pattern
- [DONE] Create shared test fixtures - COMPLETED
- Enhanced tests/fixtures/ with test_data_factory.py for comprehensive test data generation
- Added auth_fixtures.py for authentication/authorization testing
- Existing fixtures: common.py, blockchain.py, coordinator.py, staking_fixtures.py, mock_blockchain_node.py
- Fixtures shared via tests/conftest.py across all test suites
- TestDataFactory with generators for users, wallets, jobs, transactions, miners, GPUs, staking, agents, API responses, errors, pagination, batch operations, marketplace offers, governance proposals
- Auth fixtures for JWT tokens, headers, mock users, auth service, permission checker, API keys
- [DONE] Implement API gateway pattern - COMPLETED
- apps/api-gateway/src/api_gateway/main.py implements core API gateway pattern
- Features: service registry, request routing, circuit breaker, rate limiting, authentication, retry logic
- Routes to: gpu, marketplace, agent, trading, governance, ai, monitoring, hermes, plugin, coordinator services
- Middleware: RequestIDMiddleware, PerformanceLoggingMiddleware, RequestValidationMiddleware, ErrorHandlerMiddleware
- Tests: apps/api-gateway/tests/test_gateway.py with health check, service registry, routing tests
- Enterprise API Gateway: apps/coordinator-api/src/app/services/enterprise_integration/api_gateway.py with multi-tenant support
- [DONE] Move to event-driven architecture - COMPLETED
- aitbc/events.py implements comprehensive event-driven architecture
- Core components: Event dataclass, EventBus, AsyncEventBus, EventFilter, EventAggregator, EventRouter
- Decorators: @event_handler for easy event subscription
- Global event bus singleton pattern
- Comprehensive tests: tests/test_events.py (47 test cases, 540 lines)
- Blockchain event bridge: apps/blockchain-event-bridge/ for blockchain event handling
- Agent message protocols: apps/agent-coordinator/src/app/protocols/message_types.py
- Event-driven cache: dev/cache/aitbc_cache/event_driven_cache.py
- [DONE] Add feature flag system - COMPLETED
- aitbc/feature_flags.py implements comprehensive feature flag system
- Core components: FeatureFlag dataclass, FeatureFlagManager with enable/disable, whitelist/blacklist, percentage-based rollouts
- Global feature flag manager singleton pattern
- Configuration file support (feature_flags.json) with JSON persistence
- Helper functions: is_feature_enabled(), get_feature_flag_manager()
- Comprehensive tests: tests/test_feature_flags.py (30+ test cases, 404 lines)
- Features: gradual rollouts, user whitelisting/blacklisting, percentage-based targeting, timestamp tracking
- [ ] Implement comprehensive observability
- [ ] Design contract upgrade pattern
### Distribution & Binaries

View File

@@ -131,17 +131,18 @@ python_version = "3.13"
exclude = "^apps/(agent-management|agent-coordinator|agent-services|blockchain-node|computing-node|identity-node|marketplace|mining-pool)/.*"
warn_return_any = true
warn_unused_configs = true
# Start with less strict mode and gradually increase
check_untyped_defs = false
# Tightened mypy configuration for better type safety
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_untyped_defs = true
disallow_untyped_decorators = false
no_implicit_optional = false
disallow_untyped_decorators = true
no_implicit_optional = true
warn_redundant_casts = true
warn_unused_ignores = true
warn_no_return = true
warn_unreachable = false
strict_equality = false
warn_unreachable = true
strict_equality = true
strict_optional = true
[[tool.mypy.overrides]]
module = [

View File

@@ -39,6 +39,26 @@ from tests.fixtures.blockchain import (
# Training fixtures (kept here as they're specific to training tests)
from aitbc.training_setup import TrainingEnvironment, TrainingSetupError
# Auth fixtures
from tests.fixtures.auth_fixtures import (
mock_jwt_secret,
test_user_token,
test_admin_token,
expired_token,
invalid_token,
auth_headers,
admin_auth_headers,
mock_user,
mock_admin_user,
mock_auth_service,
permission_checker,
api_key_headers,
mock_api_keys,
)
# Test data factory
from tests.fixtures.test_data_factory import TestDataFactory
import pytest

163
tests/fixtures/auth_fixtures.py vendored Normal file
View File

@@ -0,0 +1,163 @@
"""
Authentication and authorization test fixtures
Provides fixtures for testing auth flows, JWT tokens, and permissions
"""
import sys
from pathlib import Path
from datetime import UTC, datetime, timedelta
from typing import Dict, Any, Optional
from unittest.mock import Mock
import pytest
import jwt
project_root = Path(__file__).parent.parent.parent
@pytest.fixture
def mock_jwt_secret():
"""Mock JWT secret for testing"""
return "test_secret_key_for_jwt_signing_please_change_in_production"
@pytest.fixture
def test_user_token(mock_jwt_secret):
"""Generate a valid JWT token for a test user"""
payload = {
"user_id": "test-user-123",
"email": "test@example.com",
"role": "user",
"exp": datetime.now(UTC) + timedelta(hours=24),
"iat": datetime.now(UTC)
}
return jwt.encode(payload, mock_jwt_secret, algorithm="HS256")
@pytest.fixture
def test_admin_token(mock_jwt_secret):
"""Generate a valid JWT token for an admin user"""
payload = {
"user_id": "admin-user-123",
"email": "admin@example.com",
"role": "admin",
"permissions": ["read", "write", "delete", "admin"],
"exp": datetime.now(UTC) + timedelta(hours=24),
"iat": datetime.now(UTC)
}
return jwt.encode(payload, mock_jwt_secret, algorithm="HS256")
@pytest.fixture
def expired_token(mock_jwt_secret):
"""Generate an expired JWT token"""
payload = {
"user_id": "test-user-123",
"email": "test@example.com",
"role": "user",
"exp": datetime.now(UTC) - timedelta(hours=1), # Expired
"iat": datetime.now(UTC) - timedelta(hours=25)
}
return jwt.encode(payload, mock_jwt_secret, algorithm="HS256")
@pytest.fixture
def invalid_token():
"""Generate an invalid JWT token"""
return "invalid.token.string"
@pytest.fixture
def auth_headers(test_user_token):
"""Generate authentication headers with Bearer token"""
return {"Authorization": f"Bearer {test_user_token}"}
@pytest.fixture
def admin_auth_headers(test_admin_token):
"""Generate authentication headers for admin user"""
return {"Authorization": f"Bearer {test_admin_token}"}
@pytest.fixture
def mock_user():
"""Mock user object for testing"""
user = Mock()
user.user_id = "test-user-123"
user.email = "test@example.com"
user.username = "testuser"
user.role = "user"
user.is_active = True
user.permissions = ["read", "write"]
user.created_at = datetime.now(UTC)
return user
@pytest.fixture
def mock_admin_user():
"""Mock admin user object for testing"""
admin = Mock()
admin.user_id = "admin-user-123"
admin.email = "admin@example.com"
admin.username = "admin"
admin.role = "admin"
admin.is_active = True
admin.permissions = ["read", "write", "delete", "admin"]
admin.created_at = datetime.now(UTC)
return admin
@pytest.fixture
def mock_auth_service():
"""Mock authentication service"""
service = Mock()
def mock_verify_token(token: str) -> Optional[Dict[str, Any]]:
try:
decoded = jwt.decode(token, "test_secret_key_for_jwt_signing_please_change_in_production", algorithms=["HS256"])
return decoded
except:
return None
def mock_generate_token(user_id: str, role: str = "user") -> str:
payload = {
"user_id": user_id,
"role": role,
"exp": datetime.now(UTC) + timedelta(hours=24),
"iat": datetime.now(UTC)
}
return jwt.encode(payload, "test_secret_key_for_jwt_signing_please_change_in_production", algorithm="HS256")
service.verify_token = mock_verify_token
service.generate_token = mock_generate_token
service.get_user = Mock(return_value=Mock(user_id="test-user-123", email="test@example.com"))
return service
@pytest.fixture
def permission_checker():
"""Mock permission checker for authorization"""
checker = Mock()
def mock_has_permission(user: Any, permission: str) -> bool:
if not hasattr(user, 'permissions'):
return False
return permission in user.permissions
checker.has_permission = mock_has_permission
checker.check_role = Mock(return_value=True)
return checker
@pytest.fixture
def api_key_headers():
"""Generate headers with API key authentication"""
return {"X-API-Key": "test-api-key-123456"}
@pytest.fixture
def mock_api_keys():
"""Mock API keys for testing"""
return {
"test-api-key-123456": {"user_id": "test-user-123", "permissions": ["read", "write"]},
"admin-api-key-789012": {"user_id": "admin-user-123", "permissions": ["read", "write", "delete", "admin"]}
}

271
tests/fixtures/test_data_factory.py vendored Normal file
View File

@@ -0,0 +1,271 @@
"""
Test Data Factory
Provides comprehensive test data generation utilities for AITBC tests
"""
from datetime import UTC, datetime, timedelta
from typing import Dict, Any, List, Optional
from uuid import uuid4
import json
class TestDataFactory:
"""Factory for generating test data across different domains"""
# Common test addresses
TEST_ADDRESSES = {
"alice": "aitbc1alice00000000000000000000000000000000000",
"bob": "aitbc1bob0000000000000000000000000000000000000",
"charlie": "aitbc1charl0000000000000000000000000000000000",
"miner1": "aitbc1miner1000000000000000000000000000000000",
"miner2": "aitbc1miner2000000000000000000000000000000000",
}
# Common test IDs
@staticmethod
def generate_id(prefix: str = "test") -> str:
"""Generate a unique test ID with prefix"""
return f"{prefix}_{uuid4().hex[:8]}"
@staticmethod
def generate_timestamp(offset_seconds: int = 0) -> str:
"""Generate ISO timestamp with optional offset"""
return (datetime.now(UTC) + timedelta(seconds=offset_seconds)).isoformat()
# User/Identity data
@staticmethod
def user_data(
user_id: Optional[str] = None,
email: Optional[str] = None,
is_active: bool = True
) -> Dict[str, Any]:
"""Generate test user data"""
return {
"user_id": user_id or TestDataFactory.generate_id("user"),
"email": email or "test@example.com",
"username": "testuser",
"is_active": is_active,
"created_at": TestDataFactory.generate_timestamp(),
"updated_at": TestDataFactory.generate_timestamp()
}
# Wallet data
@staticmethod
def wallet_data(
address: Optional[str] = None,
balance: float = 1000.0
) -> Dict[str, Any]:
"""Generate test wallet data"""
return {
"address": address or TestDataFactory.TEST_ADDRESSES["alice"],
"balance": balance,
"currency": "AITBC",
"nonce": 0,
"created_at": TestDataFactory.generate_timestamp()
}
# Job data
@staticmethod
def job_data(
job_type: str = "ai_inference",
priority: str = "normal",
timeout: int = 300
) -> Dict[str, Any]:
"""Generate test job data"""
return {
"job_id": TestDataFactory.generate_id("job"),
"job_type": job_type,
"parameters": {
"model": "gpt-4",
"prompt": "Test prompt",
"max_tokens": 100,
"temperature": 0.7
},
"priority": priority,
"timeout": timeout,
"created_at": TestDataFactory.generate_timestamp(),
"expires_at": TestDataFactory.generate_timestamp(offset_seconds=timeout)
}
# Transaction data
@staticmethod
def transaction_data(
sender: Optional[str] = None,
recipient: Optional[str] = None,
amount: float = 100.0
) -> Dict[str, Any]:
"""Generate test transaction data"""
return {
"tx_id": TestDataFactory.generate_id("tx"),
"sender": sender or TestDataFactory.TEST_ADDRESSES["alice"],
"recipient": recipient or TestDataFactory.TEST_ADDRESSES["bob"],
"amount": amount,
"currency": "AITBC",
"fee": 0.1,
"timestamp": TestDataFactory.generate_timestamp(),
"status": "pending"
}
# Miner data
@staticmethod
def miner_data(
miner_id: Optional[str] = None,
status: str = "active"
) -> Dict[str, Any]:
"""Generate test miner data"""
return {
"miner_id": miner_id or TestDataFactory.TEST_ADDRESSES["miner1"],
"status": status,
"total_jobs_completed": 10,
"successful_jobs": 9,
"average_accuracy": 95.0,
"gpu_count": 4,
"gpu_type": "NVIDIA A100",
"last_heartbeat": TestDataFactory.generate_timestamp()
}
# GPU data
@staticmethod
def gpu_data(
gpu_id: Optional[str] = None,
status: str = "available"
) -> Dict[str, Any]:
"""Generate test GPU data"""
return {
"gpu_id": gpu_id or TestDataFactory.generate_id("gpu"),
"status": status,
"type": "NVIDIA A100",
"memory_gb": 80,
"compute_capability": 8.0,
"price_per_hour": 2.5,
"location": "us-east-1"
}
# Staking data
@staticmethod
def staking_data(
amount: float = 1000.0,
lock_period: int = 30,
auto_compound: bool = False
) -> Dict[str, Any]:
"""Generate test staking data"""
return {
"stake_id": TestDataFactory.generate_id("stake"),
"amount": amount,
"lock_period": lock_period,
"auto_compound": auto_compound,
"apy": 5.0,
"start_time": TestDataFactory.generate_timestamp(),
"end_time": TestDataFactory.generate_timestamp(offset_seconds=lock_period * 86400),
"status": "active"
}
# Agent data
@staticmethod
def agent_data(
agent_id: Optional[str] = None,
status: str = "active"
) -> Dict[str, Any]:
"""Generate test agent data"""
return {
"agent_id": agent_id or TestDataFactory.generate_id("agent"),
"status": status,
"type": "general",
"capabilities": ["text_generation", "code_generation"],
"performance_tier": "gold",
"created_at": TestDataFactory.generate_timestamp()
}
# API request/response data
@staticmethod
def api_response(
status_code: int = 200,
data: Optional[Dict[str, Any]] = None,
message: str = "Success"
) -> Dict[str, Any]:
"""Generate test API response"""
return {
"status_code": status_code,
"data": data or {},
"message": message,
"timestamp": TestDataFactory.generate_timestamp()
}
# Error data
@staticmethod
def error_data(
error_code: str = "INTERNAL_ERROR",
error_message: str = "An error occurred",
details: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Generate test error data"""
return {
"error_code": error_code,
"error_message": error_message,
"details": details or {},
"timestamp": TestDataFactory.generate_timestamp()
}
# Pagination data
@staticmethod
def paginated_response(
items: List[Dict[str, Any]],
page: int = 1,
page_size: int = 10,
total: Optional[int] = None
) -> Dict[str, Any]:
"""Generate test paginated response"""
return {
"items": items,
"page": page,
"page_size": page_size,
"total": total or len(items),
"total_pages": (total or len(items) + page_size - 1) // page_size
}
# Batch operations
@staticmethod
def batch_job_data(count: int = 5) -> List[Dict[str, Any]]:
"""Generate multiple job data for batch operations"""
return [TestDataFactory.job_data() for _ in range(count)]
@staticmethod
def batch_transaction_data(count: int = 5) -> List[Dict[str, Any]]:
"""Generate multiple transaction data for batch operations"""
return [TestDataFactory.transaction_data() for _ in range(count)]
# Domain-specific scenarios
@staticmethod
def marketplace_offer_data(
provider: Optional[str] = None,
price: float = 1.5
) -> Dict[str, Any]:
"""Generate test marketplace offer data"""
return {
"offer_id": TestDataFactory.generate_id("offer"),
"provider": provider or TestDataFactory.TEST_ADDRESSES["miner1"],
"gpu_type": "NVIDIA A100",
"memory_gb": 80,
"price_per_hour": price,
"availability": "immediate",
"location": "us-east-1",
"created_at": TestDataFactory.generate_timestamp()
}
@staticmethod
def governance_proposal_data(
title: str = "Test Proposal",
description: str = "Test proposal description"
) -> Dict[str, Any]:
"""Generate test governance proposal data"""
return {
"proposal_id": TestDataFactory.generate_id("proposal"),
"title": title,
"description": description,
"proposer": TestDataFactory.TEST_ADDRESSES["alice"],
"status": "active",
"votes_for": 0,
"votes_against": 0,
"created_at": TestDataFactory.generate_timestamp(),
"ends_at": TestDataFactory.generate_timestamp(offset_seconds=86400 * 7) # 7 days
}

278
tests/test_rate_limiting.py Normal file
View File

@@ -0,0 +1,278 @@
"""
Tests for rate limiting utilities
"""
import pytest
from unittest.mock import Mock, patch, AsyncMock
from fastapi import Request, HTTPException
from starlette.responses import Response
from aitbc.rate_limiting import (
get_rate_limiter,
rate_limit,
RateLimitMiddleware,
get_rate_limit_headers,
reset_rate_limit,
)
class TestGetRateLimiter:
"""Tests for get_rate_limiter function"""
def test_get_rate_limiter_new(self):
"""Test get_rate_limiter creates new limiter"""
limiter = get_rate_limiter("test", rate=10, per=60)
assert limiter.rate == 10
assert limiter.per == 60
def test_get_rate_limiter_cached(self):
"""Test get_rate_limiter returns cached limiter"""
limiter1 = get_rate_limiter("test", rate=10, per=60)
limiter2 = get_rate_limiter("test", rate=20, per=30)
# Should return the same instance
assert limiter1 is limiter2
# Original values preserved
assert limiter2.rate == 10
assert limiter2.per == 60
class TestRateLimitDecorator:
"""Tests for rate_limit decorator"""
@pytest.mark.asyncio
async def test_rate_limit_within_limit(self):
"""Test rate_limit allows requests within limit"""
@rate_limit(rate=5, per=60)
async def test_endpoint(request: Request):
return {"status": "ok"}
request = Mock(spec=Request)
request.client = Mock(host="127.0.0.1")
request.url = Mock(path="/test")
for _ in range(5):
result = await test_endpoint(request)
assert result == {"status": "ok"}
@pytest.mark.asyncio
async def test_rate_limit_exceeded(self):
"""Test rate_limit blocks requests exceeding limit"""
@rate_limit(rate=2, per=60)
async def test_endpoint(request: Request):
return {"status": "ok"}
request = Mock(spec=Request)
request.client = Mock(host="127.0.0.1")
request.url = Mock(path="/test")
# First 2 requests should succeed
await test_endpoint(request)
await test_endpoint(request)
# Third request should fail
with pytest.raises(HTTPException) as exc_info:
await test_endpoint(request)
assert exc_info.value.status_code == 429
assert "Rate limit exceeded" in exc_info.value.detail
@pytest.mark.asyncio
async def test_rate_limit_custom_key_func(self):
"""Test rate_limit with custom key function"""
def custom_key(request: Request) -> str:
return request.headers.get("X-API-Key", "unknown")
@rate_limit(rate=2, per=60, key_func=custom_key)
async def test_endpoint(request: Request):
return {"status": "ok"}
request = Mock(spec=Request)
request.client = Mock(host="127.0.0.1")
request.url = Mock(path="/test")
request.headers = {"X-API-Key": "key1"}
# 2 requests with same key should succeed
await test_endpoint(request)
await test_endpoint(request)
# Third should fail
with pytest.raises(HTTPException):
await test_endpoint(request)
@pytest.mark.asyncio
async def test_rate_limit_no_request(self):
"""Test rate_limit without request skips limiting"""
@rate_limit(rate=2, per=60)
async def test_endpoint():
return {"status": "ok"}
# Should succeed even without request
result = await test_endpoint()
assert result == {"status": "ok"}
@pytest.mark.asyncio
async def test_rate_limit_custom_error_message(self):
"""Test rate_limit with custom error message"""
@rate_limit(rate=1, per=60, error_message="Custom limit message")
async def test_endpoint(request: Request):
return {"status": "ok"}
request = Mock(spec=Request)
request.client = Mock(host="127.0.0.1")
request.url = Mock(path="/test")
await test_endpoint(request)
with pytest.raises(HTTPException) as exc_info:
await test_endpoint(request)
assert exc_info.value.detail == "Custom limit message"
class TestRateLimitMiddleware:
"""Tests for RateLimitMiddleware"""
@pytest.mark.asyncio
async def test_middleware_within_limit(self):
"""Test middleware allows requests within limit"""
app = Mock()
middleware = RateLimitMiddleware(app, rate=5, per=60)
request = Mock(spec=Request)
request.client = Mock(host="127.0.0.1")
request.url = Mock(path="/test")
call_next = AsyncMock()
response = Mock(spec=Response)
call_next.return_value = response
for _ in range(5):
result = await middleware.dispatch(request, call_next)
assert result == response
@pytest.mark.asyncio
async def test_middleware_exceeded(self):
"""Test middleware blocks requests exceeding limit"""
app = Mock()
middleware = RateLimitMiddleware(app, rate=2, per=60)
request = Mock(spec=Request)
request.client = Mock(host="127.0.0.1")
request.url = Mock(path="/test")
call_next = AsyncMock()
response = Mock(spec=Response)
call_next.return_value = response
# First 2 requests should succeed
await middleware.dispatch(request, call_next)
await middleware.dispatch(request, call_next)
# Third request should fail
result = await middleware.dispatch(request, call_next)
assert result.status_code == 429
assert b"Rate limit exceeded" in result.body
@pytest.mark.asyncio
async def test_middleware_custom_key_func(self):
"""Test middleware with custom key function"""
def custom_key(request: Request) -> str:
return request.headers.get("X-API-Key", "unknown")
app = Mock()
middleware = RateLimitMiddleware(app, rate=2, per=60, key_func=custom_key)
request = Mock(spec=Request)
request.client = Mock(host="127.0.0.1")
request.headers = {"X-API-Key": "key1"}
call_next = AsyncMock()
response = Mock(spec=Response)
call_next.return_value = response
# 2 requests with same key should succeed
await middleware.dispatch(request, call_next)
await middleware.dispatch(request, call_next)
# Third should fail
result = await middleware.dispatch(request, call_next)
assert result.status_code == 429
@pytest.mark.asyncio
async def test_middleware_no_client(self):
"""Test middleware handles requests without client"""
app = Mock()
middleware = RateLimitMiddleware(app, rate=2, per=60)
request = Mock(spec=Request)
request.client = None
call_next = AsyncMock()
response = Mock(spec=Response)
call_next.return_value = response
# Should use "unknown" as key
result = await middleware.dispatch(request, call_next)
assert result == response
class TestGetRateLimitHeaders:
"""Tests for get_rate_limit_headers"""
def test_get_rate_limit_headers_existing_limiter(self):
"""Test get_rate_limit_headers with existing limiter"""
get_rate_limiter("test", rate=10, per=60)
request = Mock(spec=Request)
request.client = Mock(host="127.0.0.1")
headers = get_rate_limit_headers(request, "test")
assert headers["X-RateLimit-Limit"] == "10"
assert headers["X-RateLimit-Reset"] == "60"
assert "X-RateLimit-Remaining" in headers
def test_get_rate_limit_headers_nonexistent_limiter(self):
"""Test get_rate_limit_headers with nonexistent limiter"""
request = Mock(spec=Request)
request.client = Mock(host="127.0.0.1")
headers = get_rate_limit_headers(request, "nonexistent")
assert headers == {}
class TestResetRateLimit:
"""Tests for reset_rate_limit"""
def test_reset_rate_limit_specific_limiter(self):
"""Test reset_rate_limit for specific limiter"""
limiter = get_rate_limiter("test", rate=2, per=60)
# Make a request
limiter.is_allowed("127.0.0.1")
# Reset
reset_rate_limit("127.0.0.1", "test")
# Should be allowed again
assert limiter.is_allowed("127.0.0.1")
def test_reset_rate_limit_all_limiters(self):
"""Test reset_rate_limit for all limiters"""
limiter1 = get_rate_limiter("test1", rate=2, per=60)
limiter2 = get_rate_limiter("test2", rate=2, per=60)
# Make requests
limiter1.is_allowed("127.0.0.1")
limiter2.is_allowed("127.0.0.1")
# Reset all
reset_rate_limit("127.0.0.1")
# Both should be allowed again
assert limiter1.is_allowed("127.0.0.1")
assert limiter2.is_allowed("127.0.0.1")

105
tests/test_redis_cache.py Normal file
View File

@@ -0,0 +1,105 @@
"""
Tests for Redis caching utilities
"""
import pytest
from aitbc.redis_cache import RedisCache, get_cache, cache_key
class TestRedisCache:
"""Tests for RedisCache class (disabled cache mode)"""
def test_init_without_redis(self):
"""Test initialization without Redis available"""
cache = RedisCache(redis_url=None)
assert cache.is_available() is False
def test_get_without_redis(self):
"""Test get operation without Redis"""
cache = RedisCache(redis_url=None)
result = cache.get("test_key")
assert result is None
def test_set_without_redis(self):
"""Test set operation without Redis"""
cache = RedisCache(redis_url=None)
result = cache.set("test_key", {"key": "value"})
assert result is False
def test_delete_without_redis(self):
"""Test delete operation without Redis"""
cache = RedisCache(redis_url=None)
result = cache.delete("test_key")
assert result is False
def test_exists_without_redis(self):
"""Test exists operation without Redis"""
cache = RedisCache(redis_url=None)
result = cache.exists("test_key")
assert result is False
def test_clear_without_redis(self):
"""Test clear operation without Redis"""
cache = RedisCache(redis_url=None)
result = cache.clear()
assert result is False
def test_get_many_without_redis(self):
"""Test get_many operation without Redis"""
cache = RedisCache(redis_url=None)
result = cache.get_many(["key1", "key2"])
assert result == {}
def test_set_many_without_redis(self):
"""Test set_many operation without Redis"""
cache = RedisCache(redis_url=None)
result = cache.set_many({"key1": "value1"})
assert result is False
def test_delete_many_without_redis(self):
"""Test delete_many operation without Redis"""
cache = RedisCache(redis_url=None)
result = cache.delete_many(["key1"])
assert result is False
def test_increment_without_redis(self):
"""Test increment operation without Redis"""
cache = RedisCache(redis_url=None)
result = cache.increment("counter")
assert result is None
class TestGetCache:
"""Tests for get_cache function"""
def test_get_cache_without_url(self):
"""Test get_cache without URL returns disabled cache"""
cache = get_cache(redis_url=None)
assert cache.is_available() is False
class TestCacheKey:
"""Tests for cache_key function"""
def test_cache_key_simple(self):
"""Test cache_key with simple parts"""
key = cache_key("user", "123")
assert key == "aitbc:user:123"
def test_cache_key_with_prefix(self):
"""Test cache_key with custom prefix"""
key = cache_key("user", "123", prefix="custom")
assert key == "custom:user:123"
def test_cache_key_multiple_parts(self):
"""Test cache_key with multiple parts"""
key = cache_key("user", "123", "profile", "data")
assert key == "aitbc:user:123:profile:data"
def test_cache_key_long_key(self):
"""Test cache_key with long key gets hashed"""
long_part = "x" * 300
key = cache_key(long_part, "data")
assert key.startswith("aitbc:hashed:")
assert len(key) <= 250