feat(coordinator-api): add database connection pooling configuration and comprehensive payment validation
- Add configurable database connection pool settings (pool_size, max_overflow, pool_recycle, pool_pre_ping, echo) - Replace hardcoded pool settings with environment-configurable values from settings - Add QueuePool explicitly to both sync and async database engines - Implement comprehensive Pydantic validators for payment schemas (JobPaymentCreate, PaymentRequest, ExchangePaymentRequest) - Add regex
This commit is contained in:
@@ -55,6 +55,13 @@ class Settings(BaseSettings):
|
||||
|
||||
# Database
|
||||
database: DatabaseConfig = DatabaseConfig()
|
||||
|
||||
# Database Connection Pooling
|
||||
db_pool_size: int = Field(default=20, description="Database connection pool size")
|
||||
db_max_overflow: int = Field(default=40, description="Maximum overflow connections")
|
||||
db_pool_recycle: int = Field(default=3600, description="Connection recycle time in seconds")
|
||||
db_pool_pre_ping: bool = Field(default=True, description="Test connections before using")
|
||||
db_echo: bool = Field(default=False, description="Enable SQL query logging")
|
||||
|
||||
# API Keys
|
||||
client_api_keys: List[str] = []
|
||||
|
||||
@@ -4,8 +4,9 @@ from datetime import datetime
|
||||
from typing import Any, Dict, Optional, List
|
||||
from base64 import b64encode, b64decode
|
||||
from enum import Enum
|
||||
import re
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator
|
||||
|
||||
from .types import JobState, Constraints
|
||||
|
||||
@@ -13,11 +14,36 @@ from .types import JobState, Constraints
|
||||
# Payment schemas
|
||||
class JobPaymentCreate(BaseModel):
|
||||
"""Request to create a payment for a job"""
|
||||
job_id: str
|
||||
amount: float
|
||||
currency: str = "AITBC" # Jobs paid with AITBC tokens
|
||||
payment_method: str = "aitbc_token" # Primary method for job payments
|
||||
escrow_timeout_seconds: int = 3600 # 1 hour default
|
||||
job_id: str = Field(..., min_length=1, max_length=128, description="Job identifier")
|
||||
amount: float = Field(..., gt=0, le=1_000_000, description="Payment amount in AITBC")
|
||||
currency: str = Field(default="AITBC", description="Payment currency")
|
||||
payment_method: str = Field(default="aitbc_token", description="Payment method")
|
||||
escrow_timeout_seconds: int = Field(default=3600, ge=300, le=86400, description="Escrow timeout in seconds")
|
||||
|
||||
@field_validator('job_id')
|
||||
@classmethod
|
||||
def validate_job_id(cls, v: str) -> str:
|
||||
"""Validate job ID format to prevent injection attacks"""
|
||||
if not re.match(r'^[a-zA-Z0-9\-_]+$', v):
|
||||
raise ValueError('Job ID contains invalid characters')
|
||||
return v
|
||||
|
||||
@field_validator('amount')
|
||||
@classmethod
|
||||
def validate_amount(cls, v: float) -> float:
|
||||
"""Validate and round payment amount"""
|
||||
if v < 0.01:
|
||||
raise ValueError('Minimum payment amount is 0.01 AITBC')
|
||||
return round(v, 8) # Prevent floating point precision issues
|
||||
|
||||
@field_validator('currency')
|
||||
@classmethod
|
||||
def validate_currency(cls, v: str) -> str:
|
||||
"""Validate currency code"""
|
||||
allowed_currencies = ['AITBC', 'BTC', 'ETH', 'USDT']
|
||||
if v.upper() not in allowed_currencies:
|
||||
raise ValueError(f'Currency must be one of: {allowed_currencies}')
|
||||
return v.upper()
|
||||
|
||||
|
||||
class JobPaymentView(BaseModel):
|
||||
@@ -40,10 +66,37 @@ class JobPaymentView(BaseModel):
|
||||
|
||||
class PaymentRequest(BaseModel):
|
||||
"""Request to pay for a job"""
|
||||
job_id: str
|
||||
amount: float
|
||||
currency: str = "BTC"
|
||||
refund_address: Optional[str] = None
|
||||
job_id: str = Field(..., min_length=1, max_length=128, description="Job identifier")
|
||||
amount: float = Field(..., gt=0, le=1_000_000, description="Payment amount")
|
||||
currency: str = Field(default="BTC", description="Payment currency")
|
||||
refund_address: Optional[str] = Field(None, min_length=1, max_length=255, description="Refund address")
|
||||
|
||||
@field_validator('job_id')
|
||||
@classmethod
|
||||
def validate_job_id(cls, v: str) -> str:
|
||||
"""Validate job ID format"""
|
||||
if not re.match(r'^[a-zA-Z0-9\-_]+$', v):
|
||||
raise ValueError('Job ID contains invalid characters')
|
||||
return v
|
||||
|
||||
@field_validator('amount')
|
||||
@classmethod
|
||||
def validate_amount(cls, v: float) -> float:
|
||||
"""Validate payment amount"""
|
||||
if v < 0.0001: # Minimum BTC amount
|
||||
raise ValueError('Minimum payment amount is 0.0001')
|
||||
return round(v, 8)
|
||||
|
||||
@field_validator('refund_address')
|
||||
@classmethod
|
||||
def validate_refund_address(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""Validate refund address format"""
|
||||
if v is None:
|
||||
return v
|
||||
# Basic Bitcoin address validation
|
||||
if not re.match(r'^[13][a-km-zA-HJ-NP-Z1-9]{25,34}$|^bc1[a-z0-9]{8,87}$', v):
|
||||
raise ValueError('Invalid Bitcoin address format')
|
||||
return v
|
||||
|
||||
|
||||
class PaymentReceipt(BaseModel):
|
||||
@@ -111,9 +164,44 @@ class TransactionHistory(BaseModel):
|
||||
total: int
|
||||
|
||||
class ExchangePaymentRequest(BaseModel):
|
||||
user_id: str
|
||||
aitbc_amount: float
|
||||
btc_amount: float
|
||||
"""Request for Bitcoin exchange payment"""
|
||||
user_id: str = Field(..., min_length=1, max_length=128, description="User identifier")
|
||||
aitbc_amount: float = Field(..., gt=0, le=1_000_000, description="AITBC amount to exchange")
|
||||
btc_amount: float = Field(..., gt=0, le=100, description="BTC amount to receive")
|
||||
|
||||
@field_validator('user_id')
|
||||
@classmethod
|
||||
def validate_user_id(cls, v: str) -> str:
|
||||
"""Validate user ID format"""
|
||||
if not re.match(r'^[a-zA-Z0-9\-_]+$', v):
|
||||
raise ValueError('User ID contains invalid characters')
|
||||
return v
|
||||
|
||||
@field_validator('aitbc_amount')
|
||||
@classmethod
|
||||
def validate_aitbc_amount(cls, v: float) -> float:
|
||||
"""Validate AITBC amount"""
|
||||
if v < 0.01:
|
||||
raise ValueError('Minimum AITBC amount is 0.01')
|
||||
return round(v, 8)
|
||||
|
||||
@field_validator('btc_amount')
|
||||
@classmethod
|
||||
def validate_btc_amount(cls, v: float) -> float:
|
||||
"""Validate BTC amount"""
|
||||
if v < 0.0001:
|
||||
raise ValueError('Minimum BTC amount is 0.0001')
|
||||
return round(v, 8)
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_exchange_ratio(self) -> 'ExchangePaymentRequest':
|
||||
"""Validate that the exchange ratio is reasonable"""
|
||||
if self.aitbc_amount > 0 and self.btc_amount > 0:
|
||||
ratio = self.aitbc_amount / self.btc_amount
|
||||
# AITBC/BTC ratio should be reasonable (e.g., 100,000 AITBC = 1 BTC)
|
||||
if ratio < 1000 or ratio > 1000000:
|
||||
raise ValueError('Exchange ratio is outside reasonable bounds')
|
||||
return self
|
||||
|
||||
class ExchangePaymentResponse(BaseModel):
|
||||
payment_id: str
|
||||
|
||||
@@ -26,29 +26,40 @@ class PaymentService:
|
||||
self.exchange_base_url = "http://127.0.0.1:23000" # Exchange API URL
|
||||
|
||||
async def create_payment(self, job_id: str, payment_data: JobPaymentCreate) -> JobPayment:
|
||||
"""Create a new payment for a job"""
|
||||
|
||||
# Create payment record
|
||||
payment = JobPayment(
|
||||
job_id=job_id,
|
||||
amount=payment_data.amount,
|
||||
currency=payment_data.currency,
|
||||
payment_method=payment_data.payment_method,
|
||||
expires_at=datetime.utcnow() + timedelta(seconds=payment_data.escrow_timeout_seconds)
|
||||
)
|
||||
|
||||
self.session.add(payment)
|
||||
self.session.commit()
|
||||
self.session.refresh(payment)
|
||||
|
||||
# For AITBC token payments, use token escrow
|
||||
if payment_data.payment_method == "aitbc_token":
|
||||
await self._create_token_escrow(payment)
|
||||
# Bitcoin payments only for exchange purchases
|
||||
elif payment_data.payment_method == "bitcoin":
|
||||
await self._create_bitcoin_escrow(payment)
|
||||
|
||||
return payment
|
||||
"""Create a new payment for a job with ACID compliance"""
|
||||
try:
|
||||
# Create payment record
|
||||
payment = JobPayment(
|
||||
job_id=job_id,
|
||||
amount=payment_data.amount,
|
||||
currency=payment_data.currency,
|
||||
payment_method=payment_data.payment_method,
|
||||
expires_at=datetime.utcnow() + timedelta(seconds=payment_data.escrow_timeout_seconds)
|
||||
)
|
||||
|
||||
self.session.add(payment)
|
||||
|
||||
# For AITBC token payments, use token escrow
|
||||
if payment_data.payment_method == "aitbc_token":
|
||||
escrow = await self._create_token_escrow(payment)
|
||||
self.session.add(escrow)
|
||||
# Bitcoin payments only for exchange purchases
|
||||
elif payment_data.payment_method == "bitcoin":
|
||||
escrow = await self._create_bitcoin_escrow(payment)
|
||||
self.session.add(escrow)
|
||||
|
||||
# Single atomic commit - all or nothing
|
||||
self.session.commit()
|
||||
self.session.refresh(payment)
|
||||
|
||||
logger.info(f"Payment created successfully: {payment.id}")
|
||||
return payment
|
||||
|
||||
except Exception as e:
|
||||
# Rollback all changes on any error
|
||||
self.session.rollback()
|
||||
logger.error(f"Failed to create payment: {e}")
|
||||
raise
|
||||
|
||||
async def _create_token_escrow(self, payment: JobPayment) -> None:
|
||||
"""Create an escrow for AITBC token payments"""
|
||||
|
||||
@@ -12,6 +12,7 @@ from contextlib import asynccontextmanager
|
||||
from typing import Generator, AsyncGenerator
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.pool import QueuePool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
@@ -37,16 +38,22 @@ def get_engine() -> Engine:
|
||||
if "sqlite" in effective_url:
|
||||
_engine = create_engine(
|
||||
effective_url,
|
||||
echo=False,
|
||||
echo=settings.db_echo,
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=QueuePool,
|
||||
pool_size=5,
|
||||
max_overflow=10,
|
||||
pool_pre_ping=settings.db_pool_pre_ping,
|
||||
)
|
||||
else:
|
||||
_engine = create_engine(
|
||||
effective_url,
|
||||
echo=False,
|
||||
pool_size=db_config.pool_size,
|
||||
max_overflow=db_config.max_overflow,
|
||||
pool_pre_ping=db_config.pool_pre_ping,
|
||||
echo=settings.db_echo,
|
||||
poolclass=QueuePool,
|
||||
pool_size=settings.db_pool_size,
|
||||
max_overflow=settings.db_max_overflow,
|
||||
pool_recycle=settings.db_pool_recycle,
|
||||
pool_pre_ping=settings.db_pool_pre_ping,
|
||||
)
|
||||
return _engine
|
||||
|
||||
@@ -106,10 +113,12 @@ async def get_async_engine() -> AsyncEngine:
|
||||
|
||||
_async_engine = create_async_engine(
|
||||
async_url,
|
||||
echo=False,
|
||||
pool_size=db_config.pool_size,
|
||||
max_overflow=db_config.max_overflow,
|
||||
pool_pre_ping=db_config.pool_pre_ping,
|
||||
echo=settings.db_echo,
|
||||
poolclass=QueuePool,
|
||||
pool_size=settings.db_pool_size,
|
||||
max_overflow=settings.db_max_overflow,
|
||||
pool_recycle=settings.db_pool_recycle,
|
||||
pool_pre_ping=settings.db_pool_pre_ping,
|
||||
)
|
||||
return _async_engine
|
||||
|
||||
|
||||
253
apps/coordinator-api/src/app/utils/cache.py
Normal file
253
apps/coordinator-api/src/app/utils/cache.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""
|
||||
Caching strategy for expensive queries
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Optional, Dict
|
||||
from functools import wraps
|
||||
import hashlib
|
||||
import json
|
||||
from aitbc.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class CacheManager:
    """Simple in-memory cache with per-entry TTL support.

    Entries are kept in a plain dict keyed by string. Each entry records the
    cached value, its expiry time, its creation time and the TTL it was
    stored with. Hit/miss/set/eviction counters are maintained for
    ``get_stats``.

    Expiry timestamps are timezone-aware UTC datetimes. Naive local time
    jumps at DST transitions, which would silently extend or shorten every
    live TTL by an hour.
    """

    def __init__(self):
        self._cache: Dict[str, Dict[str, Any]] = {}
        self._stats = {
            "hits": 0,
            "misses": 0,
            "sets": 0,
            "evictions": 0
        }

    @staticmethod
    def _now() -> datetime:
        """Return the current instant as an aware UTC datetime."""
        # Local import: the module header only imports datetime/timedelta.
        from datetime import timezone
        return datetime.now(timezone.utc)

    def get(self, key: str) -> Optional[Any]:
        """Return the cached value for ``key``, or None on a miss.

        An entry whose expiry time has passed is removed and counted as both
        an eviction and a miss.
        """
        if key not in self._cache:
            self._stats["misses"] += 1
            return None

        cache_entry = self._cache[key]

        # Lazily expire on read
        if self._now() > cache_entry["expires_at"]:
            del self._cache[key]
            self._stats["evictions"] += 1
            self._stats["misses"] += 1
            return None

        self._stats["hits"] += 1
        logger.debug(f"Cache hit for key: {key}")
        return cache_entry["value"]

    def set(self, key: str, value: Any, ttl_seconds: int = 300) -> None:
        """Store ``value`` under ``key`` for ``ttl_seconds`` seconds.

        An existing entry with the same key is overwritten.
        """
        created_at = self._now()

        self._cache[key] = {
            "value": value,
            "expires_at": created_at + timedelta(seconds=ttl_seconds),
            "created_at": created_at,
            "ttl": ttl_seconds
        }

        self._stats["sets"] += 1
        logger.debug(f"Cache set for key: {key}, TTL: {ttl_seconds}s")

    def delete(self, key: str) -> bool:
        """Remove ``key`` from the cache; return True if it was present."""
        if key in self._cache:
            del self._cache[key]
            return True
        return False

    def clear(self) -> None:
        """Drop every cache entry (statistics counters are left untouched)."""
        self._cache.clear()
        logger.info("Cache cleared")

    def cleanup_expired(self) -> int:
        """Remove all expired entries and return how many were removed."""
        now = self._now()
        # Collect first: deleting while iterating the dict would raise.
        expired_keys = [
            key for key, entry in self._cache.items()
            if now > entry["expires_at"]
        ]

        for key in expired_keys:
            del self._cache[key]

        self._stats["evictions"] += len(expired_keys)

        if expired_keys:
            logger.info(f"Cleaned up {len(expired_keys)} expired cache entries")

        return len(expired_keys)

    def get_stats(self) -> Dict[str, Any]:
        """Return the counters plus entry count, request total and hit rate (%)."""
        total_requests = self._stats["hits"] + self._stats["misses"]
        hit_rate = (self._stats["hits"] / total_requests * 100) if total_requests > 0 else 0

        return {
            **self._stats,
            "total_entries": len(self._cache),
            "hit_rate_percent": round(hit_rate, 2),
            "total_requests": total_requests
        }


# Global cache manager instance
cache_manager = CacheManager()
|
||||
|
||||
|
||||
def cache_key_generator(*args, **kwargs) -> str:
    """Generate a deterministic cache key from call arguments.

    Each positional argument becomes a tagged tuple: objects exposing
    ``__dict__`` are described by their type name and sorted attributes,
    everything else by its ``repr``. Keyword arguments are sorted by name so
    their order does not matter. The ``repr`` of the resulting structure is
    hashed with MD5 purely to get a fixed-length key (not for security).

    Using ``repr`` and a structured container instead of ``str`` joined with
    ``"|"`` keeps distinct calls distinct: ``("a|b",)`` vs ``("a", "b")`` and
    ``1`` vs ``"1"`` no longer map to the same key.

    Returns:
        A 32-character hexadecimal digest.
    """
    key_parts = []

    for arg in args:
        if hasattr(arg, '__dict__'):
            # Attribute names are unique strings, so sorting the items never
            # has to compare the (possibly unorderable) values.
            key_parts.append(
                ("obj", type(arg).__qualname__, repr(sorted(arg.__dict__.items())))
            )
        else:
            key_parts.append(("arg", repr(arg)))

    if kwargs:
        key_parts.append(("kwargs", repr(sorted(kwargs.items()))))

    # Hash for a consistent key length
    return hashlib.md5(repr(key_parts).encode()).hexdigest()
|
||||
|
||||
|
||||
def cached(ttl_seconds: int = 300, key_prefix: str = ""):
    """Decorator that caches a function's results in the global cache manager.

    Works for both regular and ``async def`` functions. The cache key is
    built from ``key_prefix``, the function's module-qualified name and a
    hash of the call arguments.

    Args:
        ttl_seconds: How long each cached result stays valid.
        key_prefix: String prepended to every key produced by this decorator.

    Notes:
        * Results are stored wrapped in a 1-tuple. The cache manager returns
          None for a miss, so without the wrapper a function that legitimately
          returns None would be re-executed on every call.
        * Keys include ``__module__`` and ``__qualname__`` so two functions
          that share a short name cannot read each other's entries.
    """
    def decorator(func):
        import asyncio

        func_id = f"{func.__module__}.{func.__qualname__}"

        def build_key(args, kwargs) -> str:
            return f"{key_prefix}{func_id}_{cache_key_generator(*args, **kwargs)}"

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            cache_key = build_key(args, kwargs)

            # A hit is always a 1-tuple; None means miss or expired.
            boxed = cache_manager.get(cache_key)
            if boxed is not None:
                return boxed[0]

            result = await func(*args, **kwargs)
            cache_manager.set(cache_key, (result,), ttl_seconds)
            return result

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            cache_key = build_key(args, kwargs)

            boxed = cache_manager.get(cache_key)
            if boxed is not None:
                return boxed[0]

            result = func(*args, **kwargs)
            cache_manager.set(cache_key, (result,), ttl_seconds)
            return result

        # Return the wrapper matching the decorated function's kind
        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        return sync_wrapper

    return decorator
|
||||
|
||||
|
||||
# Cache configurations for different query types
|
||||
CACHE_CONFIGS = {
    "marketplace_stats": {"ttl": 300, "prefix": "marketplace_"},  # 5 minutes
    "job_list": {"ttl": 60, "prefix": "jobs_"},  # 1 minute
    "miner_list": {"ttl": 120, "prefix": "miners_"},  # 2 minutes
    "user_balance": {"ttl": 30, "prefix": "balance_"},  # 30 seconds
    "exchange_rates": {"ttl": 600, "prefix": "rates_"},  # 10 minutes
}


def get_cache_config(cache_type: str) -> Dict[str, Any]:
    """Return the cache configuration for ``cache_type``.

    Unknown types fall back to a 300-second TTL with an empty prefix.

    The result is a fresh copy, so callers may modify it without altering
    the shared ``CACHE_CONFIGS`` table.
    """
    return dict(CACHE_CONFIGS.get(cache_type, {"ttl": 300, "prefix": ""}))
|
||||
|
||||
|
||||
# Periodic cleanup task
|
||||
async def cleanup_expired_cache():
    """Background task that periodically purges expired cache entries.

    Runs forever: cleans up, sleeps 5 minutes, repeats. If a cleanup pass
    raises, the error is logged and the loop retries after 1 minute.
    """
    # asyncio is not imported at module level in this file, so import it
    # here. Without this the first sleep raises NameError, and the retry
    # sleep in the handler raises it again, killing the task.
    import asyncio

    while True:
        try:
            removed_count = cache_manager.cleanup_expired()
            if removed_count > 0:
                logger.info(f"Background cleanup removed {removed_count} expired entries")

            # Run cleanup every 5 minutes
            await asyncio.sleep(300)

        except Exception as e:
            logger.error(f"Cache cleanup error: {e}")
            await asyncio.sleep(60)  # Retry after 1 minute on error
|
||||
|
||||
|
||||
# Cache warming utilities
|
||||
class CacheWarmer:
    """Pre-populates the global cache with commonly requested data.

    Both warm-up methods are best-effort: any failure is logged and
    swallowed so that cache warming never breaks the caller.
    """

    def __init__(self, session):
        # Session handed to the services used while warming.
        self.session = session

    async def warm_marketplace_stats(self):
        """Fetch marketplace statistics and cache them for 5 minutes."""
        try:
            # Imported lazily, at call time rather than at module import.
            from ..services.marketplace import MarketplaceService

            overview = await MarketplaceService(self.session).get_stats()
            cache_manager.set("marketplace_stats_overview", overview, ttl_seconds=300)
            logger.info("Marketplace stats cache warmed up")
        except Exception as e:
            logger.error(f"Failed to warm marketplace stats cache: {e}")

    async def warm_exchange_rates(self):
        """Cache the current exchange rates for 10 minutes.

        The rates are hard-coded placeholders; a real implementation would
        query an exchange rate API.
        """
        try:
            placeholder_rates = {"AITBC_BTC": 0.00001, "AITBC_USD": 0.10}
            cache_manager.set("exchange_rates_current", placeholder_rates, ttl_seconds=600)
            logger.info("Exchange rates cache warmed up")
        except Exception as e:
            logger.error(f"Failed to warm exchange rates cache: {e}")
|
||||
|
||||
|
||||
# Cache middleware for FastAPI
|
||||
async def cache_middleware(request, call_next):
    """FastAPI middleware that attaches cache statistics to every response.

    The downstream handler runs first; the current hit count, miss count and
    hit rate are then added as ``X-Cache-*`` headers (intended for debugging).
    """
    response = await call_next(request)

    snapshot = cache_manager.get_stats()
    debug_headers = {
        "X-Cache-Hits": str(snapshot["hits"]),
        "X-Cache-Misses": str(snapshot["misses"]),
        "X-Cache-Hit-Rate": f"{snapshot['hit_rate_percent']}%",
    }
    for header_name, header_value in debug_headers.items():
        response.headers[header_name] = header_value

    return response
|
||||
324
apps/coordinator-api/src/app/utils/circuit_breaker.py
Normal file
324
apps/coordinator-api/src/app/utils/circuit_breaker.py
Normal file
@@ -0,0 +1,324 @@
|
||||
"""
|
||||
Circuit breaker pattern for external services
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Callable, Optional, Dict
|
||||
from functools import wraps
|
||||
import asyncio
|
||||
from aitbc.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class CircuitState(Enum):
    """The three states a circuit breaker can be in."""

    # Normal operation
    CLOSED = "closed"
    # Failing, reject requests
    OPEN = "open"
    # Testing recovery
    HALF_OPEN = "half_open"
|
||||
|
||||
|
||||
class CircuitBreakerError(Exception):
    """Raised when a call is rejected because the circuit breaker is open."""
|
||||
|
||||
|
||||
class CircuitBreaker:
    """Circuit breaker for calls to external services.

    State machine:
      * CLOSED: calls pass through. ``failure_threshold`` consecutive
        expected failures open the circuit.
      * OPEN: calls are rejected with ``CircuitBreakerError`` until
        ``timeout_seconds`` have elapsed since the last failure.
      * HALF_OPEN: calls pass through as probes. A success closes the
        circuit; an expected failure re-opens it.

    Only exceptions matching ``expected_exception`` count as failures; any
    other exception propagates without touching the breaker state.
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        timeout_seconds: int = 60,
        expected_exception: type = Exception,
        name: str = "circuit_breaker"
    ):
        self.failure_threshold = failure_threshold
        self.timeout_seconds = timeout_seconds
        self.expected_exception = expected_exception
        self.name = name

        self.failures = 0
        self.state = CircuitState.CLOSED
        self.last_failure_time: Optional[datetime] = None
        self.success_count = 0

        # Statistics
        self.stats = {
            "total_calls": 0,
            "successful_calls": 0,
            "failed_calls": 0,
            "circuit_opens": 0,
            "circuit_closes": 0
        }

    async def call(self, func: Callable, *args, **kwargs) -> Any:
        """Execute ``func(*args, **kwargs)`` under circuit breaker protection.

        Coroutine functions are awaited; plain callables are invoked directly.

        Raises:
            CircuitBreakerError: if the circuit is OPEN and the timeout has
                not elapsed (the rejection is counted as a failed call).
            Exception: whatever ``func`` raises is re-raised unchanged.
        """
        self.stats["total_calls"] += 1

        # Check if circuit is open
        if self.state == CircuitState.OPEN:
            if self._should_attempt_reset():
                self.state = CircuitState.HALF_OPEN
                logger.info(f"Circuit breaker '{self.name}' entering HALF_OPEN state")
            else:
                self.stats["failed_calls"] += 1
                raise CircuitBreakerError(f"Circuit breaker '{self.name}' is OPEN")

        try:
            # Execute the protected function
            if asyncio.iscoroutinefunction(func):
                result = await func(*args, **kwargs)
            else:
                result = func(*args, **kwargs)

            # Success - reset circuit if needed
            self._on_success()
            self.stats["successful_calls"] += 1

            return result

        except self.expected_exception as e:
            # Expected failure - update circuit state
            self._on_failure()
            self.stats["failed_calls"] += 1
            logger.warning(f"Circuit breaker '{self.name}' failure: {e}")
            raise

    def _should_attempt_reset(self) -> bool:
        """Return True once ``timeout_seconds`` have passed since the last failure."""
        if self.last_failure_time is None:
            return True

        return datetime.now() - self.last_failure_time > timedelta(seconds=self.timeout_seconds)

    def _on_success(self):
        """Update state after a successful call."""
        if self.state == CircuitState.HALF_OPEN:
            # Successful call in half-open state - close circuit
            self.state = CircuitState.CLOSED
            self.failures = 0
            self.success_count = 0
            self.stats["circuit_closes"] += 1
            logger.info(f"Circuit breaker '{self.name}' CLOSED (recovered)")
        elif self.state == CircuitState.CLOSED:
            # Reset failure count on success in closed state
            self.failures = 0

    def _on_failure(self):
        """Update state after an expected failure."""
        self.failures += 1
        self.last_failure_time = datetime.now()

        if self.state == CircuitState.HALF_OPEN:
            # Failure in half-open - reopen circuit. Count it like any other
            # open so "circuit_opens" covers every transition into OPEN.
            self.state = CircuitState.OPEN
            self.stats["circuit_opens"] += 1
            logger.error(f"Circuit breaker '{self.name}' OPEN (half-open test failed)")
        elif self.failures >= self.failure_threshold:
            # Too many failures - open circuit
            self.state = CircuitState.OPEN
            self.stats["circuit_opens"] += 1
            logger.error(f"Circuit breaker '{self.name}' OPEN after {self.failures} failures")

    def get_state(self) -> Dict[str, Any]:
        """Return a snapshot of configuration, state, counters and success rate (%)."""
        return {
            "name": self.name,
            "state": self.state.value,
            "failures": self.failures,
            "failure_threshold": self.failure_threshold,
            "timeout_seconds": self.timeout_seconds,
            "last_failure_time": self.last_failure_time.isoformat() if self.last_failure_time else None,
            "stats": self.stats.copy(),
            "success_rate": (
                (self.stats["successful_calls"] / self.stats["total_calls"] * 100)
                if self.stats["total_calls"] > 0 else 0
            )
        }

    def reset(self):
        """Manually force the breaker back to CLOSED; statistics are kept."""
        self.state = CircuitState.CLOSED
        self.failures = 0
        self.last_failure_time = None
        self.success_count = 0
        logger.info(f"Circuit breaker '{self.name}' manually reset to CLOSED")
|
||||
|
||||
|
||||
def circuit_breaker(
    failure_threshold: int = 5,
    timeout_seconds: int = 60,
    expected_exception: type = Exception,
    name: Optional[str] = None
):
    """Decorator that guards a function with its own ``CircuitBreaker``.

    One breaker is created per decorated function and shared by all calls to
    it. It is exposed as ``_circuit_breaker`` for access to its statistics.

    Args:
        failure_threshold: Consecutive failures before the circuit opens.
        timeout_seconds: How long the circuit stays open before a probe.
        expected_exception: Exception type(s) counted as failures.
        name: Breaker name; defaults to ``"<module>.<function name>"``.
    """
    def decorator(func):
        breaker_name = name or f"{func.__module__}.{func.__name__}"
        breaker = CircuitBreaker(
            failure_threshold=failure_threshold,
            timeout_seconds=timeout_seconds,
            expected_exception=expected_exception,
            name=breaker_name
        )

        # Store breaker on function for access to stats (``wraps`` copies it
        # onto the returned wrapper as well).
        func._circuit_breaker = breaker

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            return await breaker.call(func, *args, **kwargs)

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            # For a plain function ``breaker.call`` never awaits anything, so
            # its coroutine finishes on the first step. Drive it by hand
            # rather than via asyncio.run(), which raises RuntimeError when
            # called from a running event loop and creates a new loop on
            # every call.
            coro = breaker.call(func, *args, **kwargs)
            try:
                coro.send(None)
            except StopIteration as finished:
                return finished.value
            # Only reachable if the coroutine suspended, which a synchronous
            # function should never cause.
            coro.close()
            raise RuntimeError(
                f"Circuit breaker '{breaker_name}' suspended while running a synchronous function"
            )

        # Return appropriate wrapper
        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        return sync_wrapper

    return decorator
|
||||
|
||||
|
||||
# Pre-configured circuit breakers for common external services
|
||||
class CircuitBreakers:
    """Registry of pre-configured circuit breakers for external services.

    Each breaker is a public attribute named after the service it guards.
    """

    def __init__(self):
        # Blockchain RPC: trips after 3 connection errors, probes after 30s
        self.blockchain_rpc = CircuitBreaker(
            name="blockchain_rpc",
            failure_threshold=3,
            timeout_seconds=30,
            expected_exception=ConnectionError,
        )

        # Exchange API: any exception counts, 5 failures, probes after 60s
        self.exchange_api = CircuitBreaker(
            name="exchange_api",
            failure_threshold=5,
            timeout_seconds=60,
            expected_exception=Exception,
        )

        # Wallet daemon: trips after 3 connection errors, probes after 45s
        self.wallet_daemon = CircuitBreaker(
            name="wallet_daemon",
            failure_threshold=3,
            timeout_seconds=45,
            expected_exception=ConnectionError,
        )

        # External payment processor: any exception, 2 failures, probes after 120s
        self.payment_processor = CircuitBreaker(
            name="payment_processor",
            failure_threshold=2,
            timeout_seconds=120,
            expected_exception=Exception,
        )

    def get_all_states(self) -> Dict[str, Dict[str, Any]]:
        """Return ``get_state()`` of every breaker, keyed by breaker name."""
        service_names = ("blockchain_rpc", "exchange_api", "wallet_daemon", "payment_processor")
        return {label: getattr(self, label).get_state() for label in service_names}

    def reset_all(self):
        """Reset every breaker to CLOSED."""
        for breaker in (
            self.blockchain_rpc,
            self.exchange_api,
            self.wallet_daemon,
            self.payment_processor,
        ):
            breaker.reset()
        logger.info("All circuit breakers reset")


# Global circuit breakers instance
circuit_breakers = CircuitBreakers()
|
||||
|
||||
|
||||
# Usage examples and utilities
|
||||
class ProtectedServiceClient:
    """Example of a service client with circuit breaker protection.

    Every instance owns one ``CircuitBreaker``. ``call_api`` routes its
    requests through that breaker, so the state reported by
    ``get_health_status`` reflects the calls this client actually made and
    clients for different base URLs do not trip each other's circuit.
    """

    def __init__(self, base_url: str):
        self.base_url = base_url
        self.circuit_breaker = CircuitBreaker(
            failure_threshold=3,
            timeout_seconds=60,
            name=f"service_client_{base_url}"
        )

    async def call_api(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``data`` as JSON to ``base_url + endpoint`` and return the JSON reply.

        Raises:
            CircuitBreakerError: if this client's circuit is open.
            Exception: any error raised by the HTTP request (including a
                non-success status) is counted as a failure and re-raised.
        """
        return await self.circuit_breaker.call(self._post_json, endpoint, data)

    async def _post_json(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Perform the raw HTTP POST without any breaker logic."""
        import httpx

        async with httpx.AsyncClient() as client:
            response = await client.post(f"{self.base_url}{endpoint}", json=data)
            response.raise_for_status()
            return response.json()

    def get_health_status(self) -> Dict[str, Any]:
        """Return the service URL together with this client's breaker state."""
        return {
            "service_url": self.base_url,
            "circuit_breaker": self.circuit_breaker.get_state()
        }
|
||||
|
||||
|
||||
# FastAPI endpoint for circuit breaker monitoring
|
||||
async def get_circuit_breaker_status():
    """Return the state of every global circuit breaker, keyed by name (for monitoring)."""
    all_states = circuit_breakers.get_all_states()
    return all_states
|
||||
|
||||
|
||||
async def reset_circuit_breaker(breaker_name: str):
    """Reset one of the global circuit breakers by name (admin operation).

    Args:
        breaker_name: One of ``blockchain_rpc``, ``exchange_api``,
            ``wallet_daemon`` or ``payment_processor``.

    Returns:
        A dict confirming the reset and echoing the breaker name.

    Raises:
        ValueError: if ``breaker_name`` is not a known breaker.
    """
    known_breakers = {
        "blockchain_rpc": circuit_breakers.blockchain_rpc,
        "exchange_api": circuit_breakers.exchange_api,
        "wallet_daemon": circuit_breakers.wallet_daemon,
        "payment_processor": circuit_breakers.payment_processor
    }

    target = known_breakers.get(breaker_name)
    if target is None:
        raise ValueError(f"Unknown circuit breaker: {breaker_name}")

    target.reset()
    logger.info(f"Circuit breaker '{breaker_name}' reset via admin API")

    return {"status": "reset", "breaker": breaker_name}
|
||||
|
||||
|
||||
# Background task for circuit breaker health monitoring
|
||||
async def monitor_circuit_breakers():
    """Background task that logs circuit breaker health every 30 seconds.

    Each pass first reports breakers that are open or half-open, then
    reports breakers with more than 10 recorded calls whose success rate is
    below 80%. If a pass raises, the error is logged and the loop retries
    after 60 seconds.
    """
    while True:
        try:
            snapshot = circuit_breakers.get_all_states()

            # First pass: report circuits that are not closed
            for breaker_name, info in snapshot.items():
                current = info["state"]
                if current == "open":
                    logger.warning(f"Circuit breaker '{breaker_name}' is OPEN - check service health")
                elif current == "half_open":
                    logger.info(f"Circuit breaker '{breaker_name}' is HALF_OPEN - testing recovery")

            # Second pass: report low success rates
            for breaker_name, info in snapshot.items():
                if not info["stats"]["total_calls"] > 10:
                    continue  # Too few calls to judge
                rate = info["success_rate"]
                if rate < 80:
                    logger.warning(f"Circuit breaker '{breaker_name}' has low success rate: {rate:.1f}%")

            # Run monitoring every 30 seconds
            await asyncio.sleep(30)

        except Exception as e:
            logger.error(f"Circuit breaker monitoring error: {e}")
            await asyncio.sleep(60)  # Retry after 1 minute on error
|
||||
Reference in New Issue
Block a user