feat(coordinator-api): add database connection pooling configuration and comprehensive payment validation

- Add configurable database connection pool settings (pool_size, max_overflow, pool_recycle, pool_pre_ping, echo)
- Replace hardcoded pool settings with environment-configurable values from settings
- Add QueuePool explicitly to both sync and async database engines
- Implement comprehensive Pydantic validators for payment schemas (JobPaymentCreate, PaymentRequest, ExchangePaymentRequest)
- Add regex-based format checks for job IDs, user IDs, and Bitcoin refund addresses
This commit is contained in:
oib
2026-02-28 21:35:58 +01:00
parent 57b12a882a
commit 93ffaf53de
6 changed files with 737 additions and 45 deletions

View File

@@ -55,6 +55,13 @@ class Settings(BaseSettings):
# Database
database: DatabaseConfig = DatabaseConfig()
# Database Connection Pooling
db_pool_size: int = Field(default=20, description="Database connection pool size")
db_max_overflow: int = Field(default=40, description="Maximum overflow connections")
db_pool_recycle: int = Field(default=3600, description="Connection recycle time in seconds")
db_pool_pre_ping: bool = Field(default=True, description="Test connections before using")
db_echo: bool = Field(default=False, description="Enable SQL query logging")
# API Keys
client_api_keys: List[str] = []

View File

@@ -4,8 +4,9 @@ from datetime import datetime
from typing import Any, Dict, Optional, List
from base64 import b64encode, b64decode
from enum import Enum
import re
from pydantic import BaseModel, Field, ConfigDict
from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator
from .types import JobState, Constraints
@@ -13,11 +14,36 @@ from .types import JobState, Constraints
# Payment schemas
class JobPaymentCreate(BaseModel):
"""Request to create a payment for a job"""
job_id: str
amount: float
currency: str = "AITBC" # Jobs paid with AITBC tokens
payment_method: str = "aitbc_token" # Primary method for job payments
escrow_timeout_seconds: int = 3600 # 1 hour default
job_id: str = Field(..., min_length=1, max_length=128, description="Job identifier")
amount: float = Field(..., gt=0, le=1_000_000, description="Payment amount in AITBC")
currency: str = Field(default="AITBC", description="Payment currency")
payment_method: str = Field(default="aitbc_token", description="Payment method")
escrow_timeout_seconds: int = Field(default=3600, ge=300, le=86400, description="Escrow timeout in seconds")
@field_validator('job_id')
@classmethod
def validate_job_id(cls, v: str) -> str:
    """Accept only job IDs made of ASCII letters, digits, '-' and '_'.

    Args:
        v: Candidate job identifier.

    Returns:
        The identifier unchanged when it is well-formed.

    Raises:
        ValueError: If the identifier contains any other character.
    """
    # fullmatch instead of match: with re.match the trailing `$` also matches
    # just before a final newline, so "job-1\n" used to pass this check.
    if not re.fullmatch(r'^[a-zA-Z0-9\-_]+$', v):
        raise ValueError('Job ID contains invalid characters')
    return v
@field_validator('amount')
@classmethod
def validate_amount(cls, v: float) -> float:
    """Enforce the 0.01 AITBC floor and return the amount rounded to 8 places.

    Raises:
        ValueError: If the amount is below 0.01.
    """
    below_floor = v < 0.01
    if below_floor:
        raise ValueError('Minimum payment amount is 0.01 AITBC')
    # Eight decimal places is the precision kept for stored amounts.
    rounded = round(v, 8)
    return rounded
@field_validator('currency')
@classmethod
def validate_currency(cls, v: str) -> str:
    """Upper-case the currency code and check it against the allowed list.

    Returns:
        The upper-cased currency code.

    Raises:
        ValueError: If the upper-cased code is not in the allowed list.
    """
    normalized = v.upper()
    allowed_currencies = ['AITBC', 'BTC', 'ETH', 'USDT']
    if normalized in allowed_currencies:
        return normalized
    raise ValueError(f'Currency must be one of: {allowed_currencies}')
class JobPaymentView(BaseModel):
@@ -40,10 +66,37 @@ class JobPaymentView(BaseModel):
class PaymentRequest(BaseModel):
"""Request to pay for a job"""
job_id: str
amount: float
currency: str = "BTC"
refund_address: Optional[str] = None
job_id: str = Field(..., min_length=1, max_length=128, description="Job identifier")
amount: float = Field(..., gt=0, le=1_000_000, description="Payment amount")
currency: str = Field(default="BTC", description="Payment currency")
refund_address: Optional[str] = Field(None, min_length=1, max_length=255, description="Refund address")
@field_validator('job_id')
@classmethod
def validate_job_id(cls, v: str) -> str:
    """Accept only job IDs made of ASCII letters, digits, '-' and '_'.

    Args:
        v: Candidate job identifier.

    Returns:
        The identifier unchanged when it is well-formed.

    Raises:
        ValueError: If the identifier contains any other character.
    """
    # re.match + `$` tolerates one trailing newline; fullmatch requires the
    # pattern to consume the entire string.
    if not re.fullmatch(r'^[a-zA-Z0-9\-_]+$', v):
        raise ValueError('Job ID contains invalid characters')
    return v
@field_validator('amount')
@classmethod
def validate_amount(cls, v: float) -> float:
    """Enforce the 0.0001 floor and return the amount rounded to 8 places.

    Raises:
        ValueError: If the amount is below 0.0001.
    """
    below_floor = v < 0.0001  # Minimum BTC amount
    if below_floor:
        raise ValueError('Minimum payment amount is 0.0001')
    rounded = round(v, 8)
    return rounded
@field_validator('refund_address')
@classmethod
def validate_refund_address(cls, v: Optional[str]) -> Optional[str]:
    """Check that an optional refund address looks like a Bitcoin address.

    This is a shape check only (leading "1"/"3" base58-style strings, or
    lowercase "bc1..." strings); no checksum is verified.

    Args:
        v: Refund address, or None when the caller supplied none.

    Returns:
        The address unchanged, or None.

    Raises:
        ValueError: If a non-None address does not match either shape.
    """
    if v is None:
        return v
    # fullmatch instead of match: `$` would otherwise accept an address
    # followed by a single trailing newline.
    if not re.fullmatch(r'^[13][a-km-zA-HJ-NP-Z1-9]{25,34}$|^bc1[a-z0-9]{8,87}$', v):
        raise ValueError('Invalid Bitcoin address format')
    return v
class PaymentReceipt(BaseModel):
@@ -111,9 +164,44 @@ class TransactionHistory(BaseModel):
total: int
class ExchangePaymentRequest(BaseModel):
user_id: str
aitbc_amount: float
btc_amount: float
"""Request for Bitcoin exchange payment"""
user_id: str = Field(..., min_length=1, max_length=128, description="User identifier")
aitbc_amount: float = Field(..., gt=0, le=1_000_000, description="AITBC amount to exchange")
btc_amount: float = Field(..., gt=0, le=100, description="BTC amount to receive")
@field_validator('user_id')
@classmethod
def validate_user_id(cls, v: str) -> str:
    """Accept only user IDs made of ASCII letters, digits, '-' and '_'.

    Args:
        v: Candidate user identifier.

    Returns:
        The identifier unchanged when it is well-formed.

    Raises:
        ValueError: If the identifier contains any other character.
    """
    # fullmatch closes the trailing-newline gap that re.match + `$` leaves open.
    if not re.fullmatch(r'^[a-zA-Z0-9\-_]+$', v):
        raise ValueError('User ID contains invalid characters')
    return v
@field_validator('aitbc_amount')
@classmethod
def validate_aitbc_amount(cls, v: float) -> float:
    """Enforce the 0.01 AITBC floor and return the amount rounded to 8 places.

    Raises:
        ValueError: If the amount is below 0.01.
    """
    below_floor = v < 0.01
    if below_floor:
        raise ValueError('Minimum AITBC amount is 0.01')
    rounded = round(v, 8)
    return rounded
@field_validator('btc_amount')
@classmethod
def validate_btc_amount(cls, v: float) -> float:
    """Enforce the 0.0001 BTC floor and return the amount rounded to 8 places.

    Raises:
        ValueError: If the amount is below 0.0001.
    """
    below_floor = v < 0.0001
    if below_floor:
        raise ValueError('Minimum BTC amount is 0.0001')
    rounded = round(v, 8)
    return rounded
@model_validator(mode='after')
def validate_exchange_ratio(self) -> 'ExchangePaymentRequest':
    """Reject requests whose AITBC-per-BTC ratio is outside [1000, 1000000].

    The ratio is only checked when both amounts are positive.

    Returns:
        The validated model instance.

    Raises:
        ValueError: If the ratio falls outside the accepted range.
    """
    aitbc, btc = self.aitbc_amount, self.btc_amount
    if not (aitbc > 0 and btc > 0):
        return self
    # AITBC/BTC ratio should be reasonable (e.g., 100,000 AITBC = 1 BTC)
    ratio = aitbc / btc
    if ratio < 1000 or ratio > 1000000:
        raise ValueError('Exchange ratio is outside reasonable bounds')
    return self
class ExchangePaymentResponse(BaseModel):
payment_id: str

View File

@@ -26,29 +26,40 @@ class PaymentService:
self.exchange_base_url = "http://127.0.0.1:23000" # Exchange API URL
# NOTE(review): this span is a rendered diff hunk. The pre-change body (up to
# the first `return payment`) is immediately followed by the post-change body
# (second docstring onward); only the latter reflects the code after the commit.
async def create_payment(self, job_id: str, payment_data: JobPaymentCreate) -> JobPayment:
"""Create a new payment for a job"""
# Create payment record
payment = JobPayment(
job_id=job_id,
amount=payment_data.amount,
currency=payment_data.currency,
payment_method=payment_data.payment_method,
expires_at=datetime.utcnow() + timedelta(seconds=payment_data.escrow_timeout_seconds)
)
self.session.add(payment)
self.session.commit()
self.session.refresh(payment)
# For AITBC token payments, use token escrow
if payment_data.payment_method == "aitbc_token":
await self._create_token_escrow(payment)
# Bitcoin payments only for exchange purchases
elif payment_data.payment_method == "bitcoin":
await self._create_bitcoin_escrow(payment)
return payment
"""Create a new payment for a job with ACID compliance"""
try:
# Create payment record
payment = JobPayment(
job_id=job_id,
amount=payment_data.amount,
currency=payment_data.currency,
payment_method=payment_data.payment_method,
expires_at=datetime.utcnow() + timedelta(seconds=payment_data.escrow_timeout_seconds)
)
# The payment is only staged here; nothing is committed until the single
# commit further down.
# NOTE(review): the escrow helpers below therefore run before the payment has
# been committed or refreshed — confirm they do not rely on payment.id.
self.session.add(payment)
# For AITBC token payments, use token escrow
if payment_data.payment_method == "aitbc_token":
# NOTE(review): _create_token_escrow is annotated `-> None` in the line that
# follows this method; if it still returns None, the session.add(escrow) call
# below is handed None. Confirm the helper now returns the escrow record.
escrow = await self._create_token_escrow(payment)
self.session.add(escrow)
# Bitcoin payments only for exchange purchases
elif payment_data.payment_method == "bitcoin":
# NOTE(review): same question for _create_bitcoin_escrow — its return value
# is not visible in this hunk.
escrow = await self._create_bitcoin_escrow(payment)
self.session.add(escrow)
# Single atomic commit - all or nothing
self.session.commit()
self.session.refresh(payment)
logger.info(f"Payment created successfully: {payment.id}")
return payment
except Exception as e:
# Rollback all changes on any error
# The rollback covers this session only; the exception is logged and re-raised
# unchanged so callers still see the original error.
self.session.rollback()
logger.error(f"Failed to create payment: {e}")
raise
async def _create_token_escrow(self, payment: JobPayment) -> None:
"""Create an escrow for AITBC token payments"""

View File

@@ -12,6 +12,7 @@ from contextlib import asynccontextmanager
from typing import Generator, AsyncGenerator
from sqlalchemy import create_engine
from sqlalchemy.pool import QueuePool
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import Session, sessionmaker
@@ -37,16 +38,22 @@ def get_engine() -> Engine:
# NOTE(review): rendered diff hunk — removed and added keyword arguments appear
# together in each create_engine(...) call below.
if "sqlite" in effective_url:
_engine = create_engine(
effective_url,
echo=False,
echo=settings.db_echo,
connect_args={"check_same_thread": False},
poolclass=QueuePool,
# NOTE(review): the SQLite branch keeps fixed pool_size=5 / max_overflow=10 and
# passes no pool_recycle, whereas the other branch reads settings.db_pool_*.
# Confirm this difference is intentional.
pool_size=5,
max_overflow=10,
pool_pre_ping=settings.db_pool_pre_ping,
)
else:
_engine = create_engine(
effective_url,
echo=False,
pool_size=db_config.pool_size,
max_overflow=db_config.max_overflow,
pool_pre_ping=db_config.pool_pre_ping,
echo=settings.db_echo,
poolclass=QueuePool,
pool_size=settings.db_pool_size,
max_overflow=settings.db_max_overflow,
pool_recycle=settings.db_pool_recycle,
pool_pre_ping=settings.db_pool_pre_ping,
)
return _engine
@@ -106,10 +113,12 @@ async def get_async_engine() -> AsyncEngine:
# NOTE(review): rendered diff hunk — removed and added keyword arguments appear
# together in the call below.
_async_engine = create_async_engine(
async_url,
echo=False,
pool_size=db_config.pool_size,
max_overflow=db_config.max_overflow,
pool_pre_ping=db_config.pool_pre_ping,
echo=settings.db_echo,
# NOTE(review): SQLAlchemy documents QueuePool as not asyncio-compatible;
# create_async_engine uses AsyncAdaptedQueuePool by default. Confirm this
# argument is accepted by the installed version, otherwise drop it or pass
# AsyncAdaptedQueuePool.
poolclass=QueuePool,
pool_size=settings.db_pool_size,
max_overflow=settings.db_max_overflow,
pool_recycle=settings.db_pool_recycle,
pool_pre_ping=settings.db_pool_pre_ping,
)
return _async_engine

View File

@@ -0,0 +1,253 @@
"""
Caching strategy for expensive queries
"""
from datetime import datetime, timedelta
from typing import Any, Optional, Dict
from functools import wraps
import hashlib
import json
from aitbc.logging import get_logger
logger = get_logger(__name__)
class CacheManager:
    """In-process key/value cache whose entries expire after a per-entry TTL.

    Each entry is stored as a dict with "value", "expires_at", "created_at"
    and "ttl". Hit/miss/set/eviction counters are kept alongside the data
    and exposed through get_stats().
    """

    def __init__(self):
        self._cache: Dict[str, Dict[str, Any]] = {}
        self._stats = dict(hits=0, misses=0, sets=0, evictions=0)

    def get(self, key: str) -> Optional[Any]:
        """Return the cached value for *key*, or None on a miss.

        An entry past its expiry is removed on access and counted as both an
        eviction and a miss.
        """
        try:
            entry = self._cache[key]
        except KeyError:
            self._stats["misses"] += 1
            return None

        if datetime.now() > entry["expires_at"]:
            # Lazy expiry: drop the stale entry the first time it is read.
            del self._cache[key]
            self._stats["evictions"] += 1
            self._stats["misses"] += 1
            return None

        self._stats["hits"] += 1
        logger.debug(f"Cache hit for key: {key}")
        return entry["value"]

    def set(self, key: str, value: Any, ttl_seconds: int = 300) -> None:
        """Store *value* under *key*, replacing any existing entry.

        Args:
            key: Cache key.
            value: Object to store.
            ttl_seconds: Lifetime of the entry in seconds.
        """
        expires_at = datetime.now() + timedelta(seconds=ttl_seconds)
        entry = {
            "value": value,
            "expires_at": expires_at,
            "created_at": datetime.now(),
            "ttl": ttl_seconds,
        }
        self._cache[key] = entry
        self._stats["sets"] += 1
        logger.debug(f"Cache set for key: {key}, TTL: {ttl_seconds}s")

    def delete(self, key: str) -> bool:
        """Remove *key*; return True if it was present, False otherwise."""
        if key not in self._cache:
            return False
        del self._cache[key]
        return True

    def clear(self) -> None:
        """Drop every entry (the counters are left as they are)."""
        self._cache.clear()
        logger.info("Cache cleared")

    def cleanup_expired(self) -> int:
        """Remove all entries past their expiry and return how many were removed."""
        cutoff = datetime.now()
        stale = [k for k, entry in self._cache.items() if cutoff > entry["expires_at"]]
        for k in stale:
            del self._cache[k]

        removed = len(stale)
        self._stats["evictions"] += removed
        if stale:
            logger.info(f"Cleaned up {removed} expired cache entries")
        return removed

    def get_stats(self) -> Dict[str, Any]:
        """Return the counters plus entry count, request count and hit rate (%)."""
        hits = self._stats["hits"]
        total_requests = hits + self._stats["misses"]
        # Guard against division by zero before any lookup has happened.
        hit_rate = (hits / total_requests * 100) if total_requests > 0 else 0
        return {
            **self._stats,
            "total_entries": len(self._cache),
            "hit_rate_percent": round(hit_rate, 2),
            "total_requests": total_requests,
        }


# Global cache manager instance
cache_manager = CacheManager()
def cache_key_generator(*args, **kwargs) -> str:
    """Derive a fixed-length cache key from call arguments.

    Positional arguments that carry a ``__dict__`` are represented by their
    sorted attribute items; everything else by ``str()``. Keyword arguments
    are appended as one sorted item list. The parts are joined with "|" and
    hashed with MD5 (used purely to get a uniform key length).

    Returns:
        32-character hexadecimal digest.
    """
    def _describe(value) -> str:
        # Objects are keyed by their attributes rather than their default repr.
        if hasattr(value, '__dict__'):
            return str(sorted(value.__dict__.items()))
        return str(value)

    key_parts = [_describe(arg) for arg in args]
    if kwargs:
        key_parts.append(str(sorted(kwargs.items())))

    # NOTE(review): parts are joined without escaping, so ("a|b",) and
    # ("a", "b") produce the same key.
    joined = "|".join(key_parts)
    return hashlib.md5(joined.encode()).hexdigest()
def cached(ttl_seconds: int = 300, key_prefix: str = ""):
    """Decorator factory that memoizes a function's results in ``cache_manager``.

    Works for both plain and ``async def`` functions. The cache key is
    ``key_prefix`` + function name + "_" + a hash of the call arguments.

    Args:
        ttl_seconds: Lifetime of each cached result in seconds.
        key_prefix: String prepended to every key generated for the function.

    Returns:
        A decorator producing a wrapper of the same kind (sync or async) as
        the decorated function.
    """
    def decorator(func):
        import asyncio

        def _key_for(args, kwargs) -> str:
            return f"{key_prefix}{func.__name__}_{cache_key_generator(*args, **kwargs)}"

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            cache_key = _key_for(args, kwargs)
            outcome = cache_manager.get(cache_key)
            # A stored None is indistinguishable from a miss, so functions
            # returning None are re-executed on every call.
            if outcome is None:
                outcome = await func(*args, **kwargs)
                cache_manager.set(cache_key, outcome, ttl_seconds)
            return outcome

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            cache_key = _key_for(args, kwargs)
            outcome = cache_manager.get(cache_key)
            if outcome is None:
                outcome = func(*args, **kwargs)
                cache_manager.set(cache_key, outcome, ttl_seconds)
            return outcome

        # Hand back the wrapper matching the decorated function's kind.
        return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper

    return decorator
# Per-query-type cache settings: name -> {"ttl": seconds, "prefix": key prefix}
CACHE_CONFIGS: Dict[str, Dict[str, Any]] = {
    query_type: {"ttl": ttl, "prefix": prefix}
    for query_type, ttl, prefix in (
        ("marketplace_stats", 300, "marketplace_"),  # 5 minutes
        ("job_list", 60, "jobs_"),                   # 1 minute
        ("miner_list", 120, "miners_"),              # 2 minutes
        ("user_balance", 30, "balance_"),            # 30 seconds
        ("exchange_rates", 600, "rates_"),           # 10 minutes
    )
}


def get_cache_config(cache_type: str) -> Dict[str, Any]:
    """Look up the cache settings for *cache_type*.

    Returns:
        The matching entry of CACHE_CONFIGS, or a fresh
        ``{"ttl": 300, "prefix": ""}`` dict for unknown types.
    """
    fallback = {"ttl": 300, "prefix": ""}
    return CACHE_CONFIGS.get(cache_type, fallback)
# Periodic cleanup task
async def cleanup_expired_cache():
    """Background task that purges expired cache entries forever.

    Runs ``cache_manager.cleanup_expired()`` every 5 minutes. If a pass
    raises, the error is logged and the next attempt happens after 1 minute.
    The loop only ends when the task is cancelled.
    """
    # This module never imports asyncio at top level (the only import lives
    # inside `cached`), so the name must be brought into scope here; without
    # it both sleeps below raise NameError and the task dies on first run.
    import asyncio

    while True:
        try:
            removed_count = cache_manager.cleanup_expired()
            if removed_count > 0:
                logger.info(f"Background cleanup removed {removed_count} expired entries")

            # Run cleanup every 5 minutes
            await asyncio.sleep(300)
        except Exception as e:
            logger.error(f"Cache cleanup error: {e}")
            await asyncio.sleep(60)  # Retry after 1 minute on error
# Cache warming utilities
class CacheWarmer:
    """Pre-populates the shared cache with commonly requested results."""

    def __init__(self, session):
        # Session passed on to the services whose results are cached.
        self.session = session

    async def warm_marketplace_stats(self):
        """Fetch marketplace stats and cache them for 5 minutes.

        Failures are logged and swallowed; warming is best-effort.
        """
        try:
            from ..services.marketplace import MarketplaceService

            overview = await MarketplaceService(self.session).get_stats()
            cache_manager.set("marketplace_stats_overview", overview, ttl_seconds=300)
            logger.info("Marketplace stats cache warmed up")
        except Exception as exc:
            logger.error(f"Failed to warm marketplace stats cache: {exc}")

    async def warm_exchange_rates(self):
        """Cache exchange rates for 10 minutes.

        The rates are hard-coded placeholders; no exchange-rate API is
        called yet. Failures are logged and swallowed.
        """
        try:
            placeholder_rates = {"AITBC_BTC": 0.00001, "AITBC_USD": 0.10}
            cache_manager.set("exchange_rates_current", placeholder_rates, ttl_seconds=600)
            logger.info("Exchange rates cache warmed up")
        except Exception as exc:
            logger.error(f"Failed to warm exchange rates cache: {exc}")
# Cache middleware for FastAPI
async def cache_middleware(request, call_next):
    """Middleware that annotates every response with cache statistics.

    The downstream handler runs first; the current hit count, miss count and
    hit rate of ``cache_manager`` are then attached as X-Cache-* headers
    (intended for debugging).
    """
    response = await call_next(request)

    stats = cache_manager.get_stats()
    debug_headers = {
        "X-Cache-Hits": str(stats["hits"]),
        "X-Cache-Misses": str(stats["misses"]),
        "X-Cache-Hit-Rate": f"{stats['hit_rate_percent']}%",
    }
    for header, value in debug_headers.items():
        response.headers[header] = value

    return response

View File

@@ -0,0 +1,324 @@
"""
Circuit breaker pattern for external services
"""
from enum import Enum
from datetime import datetime, timedelta
from typing import Any, Callable, Optional, Dict
from functools import wraps
import asyncio
from aitbc.logging import get_logger
logger = get_logger(__name__)
class CircuitState(Enum):
    """Circuit breaker states"""

    CLOSED = "closed"        # Normal operation
    OPEN = "open"            # Failing, reject requests
    HALF_OPEN = "half_open"  # Testing recovery


class CircuitBreakerError(Exception):
    """Raised by CircuitBreaker.call when a call is rejected because the circuit is OPEN."""


class CircuitBreaker:
    """Circuit breaker for calls to external services.

    The breaker starts CLOSED. After ``failure_threshold`` consecutive
    failures of type ``expected_exception`` it goes OPEN and rejects calls
    with CircuitBreakerError. Once ``timeout_seconds`` have passed since the
    last failure, the next call is let through in HALF_OPEN state: success
    closes the circuit, failure re-opens it.
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        timeout_seconds: int = 60,
        expected_exception: type = Exception,
        name: str = "circuit_breaker"
    ):
        """
        Args:
            failure_threshold: Consecutive failures that open the circuit.
            timeout_seconds: Time after the last failure before a trial call
                is allowed.
            expected_exception: Exception type (or tuple of types) counted as
                a failure; other exceptions pass through untouched.
            name: Label used in log messages and get_state().
        """
        self.failure_threshold = failure_threshold
        self.timeout_seconds = timeout_seconds
        self.expected_exception = expected_exception
        self.name = name

        self.failures = 0
        self.state = CircuitState.CLOSED
        self.last_failure_time: Optional[datetime] = None
        self.success_count = 0

        # Statistics
        self.stats = {
            "total_calls": 0,
            "successful_calls": 0,
            "failed_calls": 0,
            "circuit_opens": 0,
            "circuit_closes": 0
        }

    async def call(self, func: Callable, *args, **kwargs) -> Any:
        """Run *func* under circuit breaker protection.

        Args:
            func: Sync or async callable to execute.
            *args, **kwargs: Forwarded to *func*.

        Returns:
            Whatever *func* returns (awaited if it is awaitable).

        Raises:
            CircuitBreakerError: If the circuit is OPEN and the timeout has
                not elapsed; *func* is not called in that case.
            Exception: Anything raised by *func* is re-raised.
        """
        import inspect

        self.stats["total_calls"] += 1

        # Check if circuit is open
        if self.state == CircuitState.OPEN:
            if self._should_attempt_reset():
                self.state = CircuitState.HALF_OPEN
                logger.info(f"Circuit breaker '{self.name}' entering HALF_OPEN state")
            else:
                # Rejected calls are counted as failed calls.
                self.stats["failed_calls"] += 1
                raise CircuitBreakerError(f"Circuit breaker '{self.name}' is OPEN")

        try:
            # Execute the protected function. Awaiting on the *result* rather
            # than branching on iscoroutinefunction(func) also covers lambdas
            # and callable objects that return a coroutine; previously those
            # were recorded as successes and the coroutine came back unawaited.
            result = func(*args, **kwargs)
            if inspect.isawaitable(result):
                result = await result

            # Success - reset circuit if needed
            self._on_success()
            self.stats["successful_calls"] += 1
            return result

        except self.expected_exception as e:
            # Expected failure - update circuit state
            self._on_failure()
            self.stats["failed_calls"] += 1
            logger.warning(f"Circuit breaker '{self.name}' failure: {e}")
            raise

    def _should_attempt_reset(self) -> bool:
        """Return True once ``timeout_seconds`` have passed since the last failure."""
        if self.last_failure_time is None:
            return True
        return datetime.now() - self.last_failure_time > timedelta(seconds=self.timeout_seconds)

    def _on_success(self):
        """Update state after a successful call."""
        if self.state == CircuitState.HALF_OPEN:
            # Successful call in half-open state - close circuit
            self.state = CircuitState.CLOSED
            self.failures = 0
            self.success_count = 0
            self.stats["circuit_closes"] += 1
            logger.info(f"Circuit breaker '{self.name}' CLOSED (recovered)")
        elif self.state == CircuitState.CLOSED:
            # Only consecutive failures count towards the threshold.
            self.failures = 0

    def _on_failure(self):
        """Update state after a failed call."""
        self.failures += 1
        self.last_failure_time = datetime.now()

        if self.state == CircuitState.HALF_OPEN:
            # Failure in half-open - reopen circuit
            self.state = CircuitState.OPEN
            logger.error(f"Circuit breaker '{self.name}' OPEN (half-open test failed)")
        elif self.failures >= self.failure_threshold:
            # Too many failures - open circuit
            self.state = CircuitState.OPEN
            self.stats["circuit_opens"] += 1
            logger.error(f"Circuit breaker '{self.name}' OPEN after {self.failures} failures")

    def get_state(self) -> Dict[str, Any]:
        """Return a snapshot of configuration, state, counters and success rate (%)."""
        return {
            "name": self.name,
            "state": self.state.value,
            "failures": self.failures,
            "failure_threshold": self.failure_threshold,
            "timeout_seconds": self.timeout_seconds,
            "last_failure_time": self.last_failure_time.isoformat() if self.last_failure_time else None,
            "stats": self.stats.copy(),
            "success_rate": (
                (self.stats["successful_calls"] / self.stats["total_calls"] * 100)
                if self.stats["total_calls"] > 0 else 0
            )
        }

    def reset(self):
        """Manually force the breaker back to CLOSED (counters in stats are kept)."""
        self.state = CircuitState.CLOSED
        self.failures = 0
        self.last_failure_time = None
        self.success_count = 0
        logger.info(f"Circuit breaker '{self.name}' manually reset to CLOSED")
def circuit_breaker(
    failure_threshold: int = 5,
    timeout_seconds: int = 60,
    expected_exception: type = Exception,
    name: str = None
):
    """Decorator factory that routes every call of a function through a CircuitBreaker.

    One breaker is created per decorated function (at decoration time) and is
    reachable as ``<function>._circuit_breaker``.

    Args:
        failure_threshold: Consecutive failures that open the circuit.
        timeout_seconds: Time before a trial call is allowed again.
        expected_exception: Exception type counted as a failure.
        name: Breaker name; defaults to "<module>.<function name>".
    """
    def decorator(func):
        breaker = CircuitBreaker(
            failure_threshold=failure_threshold,
            timeout_seconds=timeout_seconds,
            expected_exception=expected_exception,
            name=name or f"{func.__module__}.{func.__name__}",
        )

        # Set before wrapping so functools.wraps copies it onto the wrapper too.
        func._circuit_breaker = breaker

        if asyncio.iscoroutinefunction(func):
            @wraps(func)
            async def guarded(*args, **kwargs):
                return await breaker.call(func, *args, **kwargs)
        else:
            @wraps(func)
            def guarded(*args, **kwargs):
                # NOTE(review): asyncio.run() raises RuntimeError when invoked
                # from a thread that already runs an event loop.
                return asyncio.run(breaker.call(func, *args, **kwargs))

        return guarded

    return decorator
# Pre-configured circuit breakers for common external services
class CircuitBreakers:
    """Registry of the pre-configured breakers, one per external dependency."""

    def __init__(self):
        # Blockchain RPC: trips after 3 connection errors, retry after 30 s.
        self.blockchain_rpc = CircuitBreaker(
            name="blockchain_rpc",
            failure_threshold=3,
            timeout_seconds=30,
            expected_exception=ConnectionError,
        )
        # Exchange API: trips after 5 errors of any kind, retry after 60 s.
        self.exchange_api = CircuitBreaker(
            name="exchange_api",
            failure_threshold=5,
            timeout_seconds=60,
            expected_exception=Exception,
        )
        # Wallet daemon: trips after 3 connection errors, retry after 45 s.
        self.wallet_daemon = CircuitBreaker(
            name="wallet_daemon",
            failure_threshold=3,
            timeout_seconds=45,
            expected_exception=ConnectionError,
        )
        # External payment processor: trips after 2 errors, retry after 120 s.
        self.payment_processor = CircuitBreaker(
            name="payment_processor",
            failure_threshold=2,
            timeout_seconds=120,
            expected_exception=Exception,
        )

    def _breakers(self):
        """Return (name, breaker) pairs in their fixed reporting order."""
        return [
            ("blockchain_rpc", self.blockchain_rpc),
            ("exchange_api", self.exchange_api),
            ("wallet_daemon", self.wallet_daemon),
            ("payment_processor", self.payment_processor),
        ]

    def get_all_states(self) -> Dict[str, Dict[str, Any]]:
        """Return ``get_state()`` of every breaker, keyed by breaker name."""
        return {label: breaker.get_state() for label, breaker in self._breakers()}

    def reset_all(self):
        """Force every breaker back to CLOSED."""
        for _, breaker in self._breakers():
            breaker.reset()
        logger.info("All circuit breakers reset")


# Global circuit breakers instance
circuit_breakers = CircuitBreakers()
# Usage examples and utilities
class ProtectedServiceClient:
    """Example HTTP client whose API calls go through a per-instance circuit breaker."""

    def __init__(self, base_url: str):
        """
        Args:
            base_url: Prefix prepended to every endpoint passed to call_api().
        """
        self.base_url = base_url
        self.circuit_breaker = CircuitBreaker(
            failure_threshold=3,
            timeout_seconds=60,
            name=f"service_client_{base_url}"
        )

    async def call_api(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """POST *data* as JSON to ``base_url + endpoint`` under breaker protection.

        The call is routed through ``self.circuit_breaker``, the same breaker
        get_health_status() reports. Previously a separate decorator-created
        breaker (shared by all instances) guarded this method, so the health
        status never reflected real failures.

        Returns:
            The decoded JSON response body.

        Raises:
            CircuitBreakerError: If the circuit is open.
            Exception: HTTP/transport errors from the request are re-raised.
        """
        return await self.circuit_breaker.call(self._post_json, endpoint, data)

    async def _post_json(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Perform the unprotected POST and return the decoded JSON body."""
        import httpx

        async with httpx.AsyncClient() as client:
            response = await client.post(f"{self.base_url}{endpoint}", json=data)
            response.raise_for_status()
            return response.json()

    def get_health_status(self) -> Dict[str, Any]:
        """Return the service URL together with this client's breaker state."""
        return {
            "service_url": self.base_url,
            "circuit_breaker": self.circuit_breaker.get_state()
        }
# FastAPI endpoint for circuit breaker monitoring
async def get_circuit_breaker_status():
    """Return the state snapshot of every global circuit breaker, keyed by name."""
    snapshot = circuit_breakers.get_all_states()
    return snapshot
async def reset_circuit_breaker(breaker_name: str):
    """Reset one of the global circuit breakers by name (admin operation).

    Args:
        breaker_name: One of "blockchain_rpc", "exchange_api",
            "wallet_daemon" or "payment_processor".

    Returns:
        ``{"status": "reset", "breaker": breaker_name}``.

    Raises:
        ValueError: If *breaker_name* is not a known breaker.
    """
    breaker_map = {
        "blockchain_rpc": circuit_breakers.blockchain_rpc,
        "exchange_api": circuit_breakers.exchange_api,
        "wallet_daemon": circuit_breakers.wallet_daemon,
        "payment_processor": circuit_breakers.payment_processor,
    }

    target = breaker_map.get(breaker_name)
    if target is None:
        raise ValueError(f"Unknown circuit breaker: {breaker_name}")

    target.reset()
    logger.info(f"Circuit breaker '{breaker_name}' reset via admin API")
    return {"status": "reset", "breaker": breaker_name}
# Background task for circuit breaker health monitoring
async def monitor_circuit_breakers():
    """Background task that logs circuit breaker health forever.

    Every 30 seconds it logs each breaker that is open or half-open, then
    warns about any breaker with more than 10 recorded calls and a success
    rate below 80%. If a pass raises, the error is logged and the next pass
    happens after 60 seconds.
    """
    while True:
        try:
            states = circuit_breakers.get_all_states()

            # First pass: report circuits that are not closed.
            for name, state in states.items():
                current = state["state"]
                if current == "open":
                    logger.warning(f"Circuit breaker '{name}' is OPEN - check service health")
                elif current == "half_open":
                    logger.info(f"Circuit breaker '{name}' is HALF_OPEN - testing recovery")

            # Second pass: report low success rates.
            for name, state in states.items():
                if not state["stats"]["total_calls"] > 10:
                    continue  # Too few calls for the rate to be meaningful
                success_rate = state["success_rate"]
                if success_rate < 80:
                    logger.warning(f"Circuit breaker '{name}' has low success rate: {success_rate:.1f}%")

            # Run monitoring every 30 seconds
            await asyncio.sleep(30)
        except Exception as e:
            logger.error(f"Circuit breaker monitoring error: {e}")
            await asyncio.sleep(60)  # Retry after 1 minute on error