Review notes for this patch. The original text had been collapsed onto a few
long lines; it is re-flowed here. Blank context lines and context indentation
are inferred from the hunk line counts.

Amendments made during review:
  * schemas.py: validators use re.fullmatch. With re.match, '$' also matches
    just before a trailing newline, so "abc\n" passed the character checks.
  * services/payments.py: the escrow helpers are annotated "-> None", so their
    result is no longer passed to session.add(); the payment is flushed first.
  * storage/db.py: poolclass=QueuePool removed from create_async_engine, which
    rejects it; the async default pool is used instead.
  * utils/cache.py: asyncio is imported at module level (cleanup_expired_cache
    raised NameError); @cached now caches None results via a sentinel.
  * utils/circuit_breaker.py: half-open failures count as circuit opens; sync
    functions are guarded without asyncio.run(); ProtectedServiceClient.call_api
    goes through the instance's own breaker.

Open review notes (not changed here):
  * storage/db.py forces QueuePool for every SQLite URL; for an in-memory
    database each pooled connection may see its own empty database - confirm.
  * Blob hashes on the "index" lines are those of the original patch and do
    not reflect the amendments above.

diff --git a/apps/coordinator-api/src/app/config.py b/apps/coordinator-api/src/app/config.py
index f0233b8d..5ddf3c7d 100644
--- a/apps/coordinator-api/src/app/config.py
+++ b/apps/coordinator-api/src/app/config.py
@@ -55,6 +55,13 @@ class Settings(BaseSettings):
 
     # Database
     database: DatabaseConfig = DatabaseConfig()
+
+    # Database Connection Pooling
+    db_pool_size: int = Field(default=20, description="Database connection pool size")
+    db_max_overflow: int = Field(default=40, description="Maximum overflow connections")
+    db_pool_recycle: int = Field(default=3600, description="Connection recycle time in seconds")
+    db_pool_pre_ping: bool = Field(default=True, description="Test connections before using")
+    db_echo: bool = Field(default=False, description="Enable SQL query logging")
 
     # API Keys
     client_api_keys: List[str] = []
diff --git a/apps/coordinator-api/src/app/schemas.py b/apps/coordinator-api/src/app/schemas.py
index 5eb6ef4e..8203e44e 100644
--- a/apps/coordinator-api/src/app/schemas.py
+++ b/apps/coordinator-api/src/app/schemas.py
@@ -4,8 +4,9 @@ from datetime import datetime
 from typing import Any, Dict, Optional, List
 from base64 import b64encode, b64decode
 from enum import Enum
+import re
 
-from pydantic import BaseModel, Field, ConfigDict
+from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator
 
 from .types import JobState, Constraints
 
@@ -13,11 +14,36 @@ from .types import JobState, Constraints
 # Payment schemas
 class JobPaymentCreate(BaseModel):
     """Request to create a payment for a job"""
-    job_id: str
-    amount: float
-    currency: str = "AITBC"  # Jobs paid with AITBC tokens
-    payment_method: str = "aitbc_token"  # Primary method for job payments
-    escrow_timeout_seconds: int = 3600  # 1 hour default
+    job_id: str = Field(..., min_length=1, max_length=128, description="Job identifier")
+    amount: float = Field(..., gt=0, le=1_000_000, description="Payment amount in AITBC")
+    currency: str = Field(default="AITBC", description="Payment currency")
+    payment_method: str = Field(default="aitbc_token", description="Payment method")
+    escrow_timeout_seconds: int = Field(default=3600, ge=300, le=86400, description="Escrow timeout in seconds")
+
+    @field_validator('job_id')
+    @classmethod
+    def validate_job_id(cls, v: str) -> str:
+        """Validate job ID format to prevent injection attacks"""
+        if not re.fullmatch(r'[a-zA-Z0-9\-_]+', v):  # fullmatch: '$' would accept a trailing newline
+            raise ValueError('Job ID contains invalid characters')
+        return v
+
+    @field_validator('amount')
+    @classmethod
+    def validate_amount(cls, v: float) -> float:
+        """Validate and round payment amount"""
+        if v < 0.01:
+            raise ValueError('Minimum payment amount is 0.01 AITBC')
+        return round(v, 8)  # Prevent floating point precision issues
+
+    @field_validator('currency')
+    @classmethod
+    def validate_currency(cls, v: str) -> str:
+        """Validate currency code"""
+        allowed_currencies = ['AITBC', 'BTC', 'ETH', 'USDT']
+        if v.upper() not in allowed_currencies:
+            raise ValueError(f'Currency must be one of: {allowed_currencies}')
+        return v.upper()
 
 
 class JobPaymentView(BaseModel):
@@ -40,10 +66,37 @@ class JobPaymentView(BaseModel):
 
 class PaymentRequest(BaseModel):
     """Request to pay for a job"""
-    job_id: str
-    amount: float
-    currency: str = "BTC"
-    refund_address: Optional[str] = None
+    job_id: str = Field(..., min_length=1, max_length=128, description="Job identifier")
+    amount: float = Field(..., gt=0, le=1_000_000, description="Payment amount")
+    currency: str = Field(default="BTC", description="Payment currency")
+    refund_address: Optional[str] = Field(None, min_length=1, max_length=255, description="Refund address")
+
+    @field_validator('job_id')
+    @classmethod
+    def validate_job_id(cls, v: str) -> str:
+        """Validate job ID format"""
+        if not re.fullmatch(r'[a-zA-Z0-9\-_]+', v):  # fullmatch: '$' would accept a trailing newline
+            raise ValueError('Job ID contains invalid characters')
+        return v
+
+    @field_validator('amount')
+    @classmethod
+    def validate_amount(cls, v: float) -> float:
+        """Validate payment amount"""
+        if v < 0.0001:  # Minimum BTC amount
+            raise ValueError('Minimum payment amount is 0.0001')
+        return round(v, 8)
+
+    @field_validator('refund_address')
+    @classmethod
+    def validate_refund_address(cls, v: Optional[str]) -> Optional[str]:
+        """Validate refund address format"""
+        if v is None:
+            return v
+        # Basic Bitcoin address validation (whole string must match one alternative)
+        if not re.fullmatch(r'[13][a-km-zA-HJ-NP-Z1-9]{25,34}|bc1[a-z0-9]{8,87}', v):
+            raise ValueError('Invalid Bitcoin address format')
+        return v
 
 
 class PaymentReceipt(BaseModel):
@@ -111,9 +164,44 @@ class TransactionHistory(BaseModel):
     total: int
 
 class ExchangePaymentRequest(BaseModel):
-    user_id: str
-    aitbc_amount: float
-    btc_amount: float
+    """Request for Bitcoin exchange payment"""
+    user_id: str = Field(..., min_length=1, max_length=128, description="User identifier")
+    aitbc_amount: float = Field(..., gt=0, le=1_000_000, description="AITBC amount to exchange")
+    btc_amount: float = Field(..., gt=0, le=100, description="BTC amount to receive")
+
+    @field_validator('user_id')
+    @classmethod
+    def validate_user_id(cls, v: str) -> str:
+        """Validate user ID format"""
+        if not re.fullmatch(r'[a-zA-Z0-9\-_]+', v):  # fullmatch: '$' would accept a trailing newline
+            raise ValueError('User ID contains invalid characters')
+        return v
+
+    @field_validator('aitbc_amount')
+    @classmethod
+    def validate_aitbc_amount(cls, v: float) -> float:
+        """Validate AITBC amount"""
+        if v < 0.01:
+            raise ValueError('Minimum AITBC amount is 0.01')
+        return round(v, 8)
+
+    @field_validator('btc_amount')
+    @classmethod
+    def validate_btc_amount(cls, v: float) -> float:
+        """Validate BTC amount"""
+        if v < 0.0001:
+            raise ValueError('Minimum BTC amount is 0.0001')
+        return round(v, 8)
+
+    @model_validator(mode='after')
+    def validate_exchange_ratio(self) -> 'ExchangePaymentRequest':
+        """Validate that the exchange ratio is reasonable"""
+        if self.aitbc_amount > 0 and self.btc_amount > 0:
+            ratio = self.aitbc_amount / self.btc_amount
+            # AITBC/BTC ratio should be reasonable (e.g., 100,000 AITBC = 1 BTC)
+            if ratio < 1000 or ratio > 1000000:
+                raise ValueError('Exchange ratio is outside reasonable bounds')
+        return self
 
 class ExchangePaymentResponse(BaseModel):
     payment_id: str
diff --git a/apps/coordinator-api/src/app/services/payments.py b/apps/coordinator-api/src/app/services/payments.py
index 589eb2c2..c86b6209 100644
--- a/apps/coordinator-api/src/app/services/payments.py
+++ b/apps/coordinator-api/src/app/services/payments.py
@@ -26,29 +26,49 @@ class PaymentService:
         self.exchange_base_url = "http://127.0.0.1:23000"  # Exchange API URL
 
     async def create_payment(self, job_id: str, payment_data: JobPaymentCreate) -> JobPayment:
-        """Create a new payment for a job"""
-
-        # Create payment record
-        payment = JobPayment(
-            job_id=job_id,
-            amount=payment_data.amount,
-            currency=payment_data.currency,
-            payment_method=payment_data.payment_method,
-            expires_at=datetime.utcnow() + timedelta(seconds=payment_data.escrow_timeout_seconds)
-        )
-
-        self.session.add(payment)
-        self.session.commit()
-        self.session.refresh(payment)
-
-        # For AITBC token payments, use token escrow
-        if payment_data.payment_method == "aitbc_token":
-            await self._create_token_escrow(payment)
-        # Bitcoin payments only for exchange purchases
-        elif payment_data.payment_method == "bitcoin":
-            await self._create_bitcoin_escrow(payment)
-
-        return payment
+        """Create a payment record for a job and set up its escrow.
+
+        The payment row and whatever the escrow helper changes on it are
+        written with a single commit() here; on any error the session is
+        rolled back and the exception is re-raised.
+        """
+        try:
+            # Create payment record
+            payment = JobPayment(
+                job_id=job_id,
+                amount=payment_data.amount,
+                currency=payment_data.currency,
+                payment_method=payment_data.payment_method,
+                expires_at=datetime.utcnow() + timedelta(seconds=payment_data.escrow_timeout_seconds)
+            )
+
+            self.session.add(payment)
+            # Flush (not commit) so the row and any database-generated id
+            # exist for the escrow helpers while the transaction stays open.
+            self.session.flush()
+
+            # The escrow helpers are annotated "-> None", so there is nothing
+            # to pass to session.add(); adding None raises UnmappedInstanceError.
+            # NOTE(review): confirm the helpers do not commit on their own.
+            # For AITBC token payments, use token escrow
+            if payment_data.payment_method == "aitbc_token":
+                await self._create_token_escrow(payment)
+            # Bitcoin payments only for exchange purchases
+            elif payment_data.payment_method == "bitcoin":
+                await self._create_bitcoin_escrow(payment)
+
+            # Single commit for the payment and any escrow updates
+            self.session.commit()
+            self.session.refresh(payment)
+
+            logger.info(f"Payment created successfully: {payment.id}")
+            return payment
+
+        except Exception as e:
+            # Rollback all changes on any error
+            self.session.rollback()
+            logger.error(f"Failed to create payment: {e}")
+            raise
 
     async def _create_token_escrow(self, payment: JobPayment) -> None:
         """Create an escrow for AITBC token payments"""
diff --git a/apps/coordinator-api/src/app/storage/db.py b/apps/coordinator-api/src/app/storage/db.py
index 43e3b0a9..5b077680 100644
--- a/apps/coordinator-api/src/app/storage/db.py
+++ b/apps/coordinator-api/src/app/storage/db.py
@@ -12,6 +12,7 @@ from contextlib import asynccontextmanager
 from typing import Generator, AsyncGenerator
 
 from sqlalchemy import create_engine
+from sqlalchemy.pool import QueuePool
 from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
 from sqlalchemy.orm import Session, sessionmaker
 
@@ -37,16 +38,22 @@ def get_engine() -> Engine:
         if "sqlite" in effective_url:
             _engine = create_engine(
                 effective_url,
-                echo=False,
+                echo=settings.db_echo,
                 connect_args={"check_same_thread": False},
+                poolclass=QueuePool,
+                pool_size=5,
+                max_overflow=10,
+                pool_pre_ping=settings.db_pool_pre_ping,
             )
         else:
             _engine = create_engine(
                 effective_url,
-                echo=False,
-                pool_size=db_config.pool_size,
-                max_overflow=db_config.max_overflow,
-                pool_pre_ping=db_config.pool_pre_ping,
+                echo=settings.db_echo,
+                poolclass=QueuePool,
+                pool_size=settings.db_pool_size,
+                max_overflow=settings.db_max_overflow,
+                pool_recycle=settings.db_pool_recycle,
+                pool_pre_ping=settings.db_pool_pre_ping,
             )
     return _engine
 
@@ -106,10 +113,11 @@ async def get_async_engine() -> AsyncEngine:
 
         _async_engine = create_async_engine(
             async_url,
-            echo=False,
-            pool_size=db_config.pool_size,
-            max_overflow=db_config.max_overflow,
-            pool_pre_ping=db_config.pool_pre_ping,
+            echo=settings.db_echo,
+            pool_size=settings.db_pool_size,
+            max_overflow=settings.db_max_overflow,
+            pool_recycle=settings.db_pool_recycle,
+            pool_pre_ping=settings.db_pool_pre_ping,
         )
     return _async_engine
 
diff --git a/apps/coordinator-api/src/app/utils/cache.py b/apps/coordinator-api/src/app/utils/cache.py
new file mode 100644
index 00000000..7157c3e3
--- /dev/null
+++ b/apps/coordinator-api/src/app/utils/cache.py
@@ -0,0 +1,256 @@
+"""
+Caching strategy for expensive queries
+"""
+
+from datetime import datetime, timedelta
+from typing import Any, Optional, Dict
+from functools import wraps
+import asyncio
+import hashlib
+import json
+from aitbc.logging import get_logger
+
+logger = get_logger(__name__)
+
+# Sentinel that lets @cached tell "no entry" apart from a cached None result
+_MISSING = object()
+
+
+class CacheManager:
+    """Simple in-memory cache with TTL support"""
+
+    def __init__(self):
+        self._cache: Dict[str, Dict[str, Any]] = {}
+        self._stats = {
+            "hits": 0,
+            "misses": 0,
+            "sets": 0,
+            "evictions": 0
+        }
+
+    def get(self, key: str, default: Any = None) -> Any:
+        """Get value from cache; return ``default`` when the key is missing or expired"""
+        if key not in self._cache:
+            self._stats["misses"] += 1
+            return default
+
+        cache_entry = self._cache[key]
+
+        # Check if expired
+        if datetime.now() > cache_entry["expires_at"]:
+            del self._cache[key]
+            self._stats["evictions"] += 1
+            self._stats["misses"] += 1
+            return default
+
+        self._stats["hits"] += 1
+        logger.debug(f"Cache hit for key: {key}")
+        return cache_entry["value"]
+
+    def set(self, key: str, value: Any, ttl_seconds: int = 300) -> None:
+        """Set value in cache with TTL"""
+        expires_at = datetime.now() + timedelta(seconds=ttl_seconds)
+
+        self._cache[key] = {
+            "value": value,
+            "expires_at": expires_at,
+            "created_at": datetime.now(),
+            "ttl": ttl_seconds
+        }
+
+        self._stats["sets"] += 1
+        logger.debug(f"Cache set for key: {key}, TTL: {ttl_seconds}s")
+
+    def delete(self, key: str) -> bool:
+        """Delete key from cache"""
+        if key in self._cache:
+            del self._cache[key]
+            return True
+        return False
+
+    def clear(self) -> None:
+        """Clear all cache entries"""
+        self._cache.clear()
+        logger.info("Cache cleared")
+
+    def cleanup_expired(self) -> int:
+        """Remove expired entries and return count removed"""
+        now = datetime.now()
+        expired_keys = [
+            key for key, entry in self._cache.items()
+            if now > entry["expires_at"]
+        ]
+
+        for key in expired_keys:
+            del self._cache[key]
+
+        self._stats["evictions"] += len(expired_keys)
+
+        if expired_keys:
+            logger.info(f"Cleaned up {len(expired_keys)} expired cache entries")
+
+        return len(expired_keys)
+
+    def get_stats(self) -> Dict[str, Any]:
+        """Get cache statistics"""
+        total_requests = self._stats["hits"] + self._stats["misses"]
+        hit_rate = (self._stats["hits"] / total_requests * 100) if total_requests > 0 else 0
+
+        return {
+            **self._stats,
+            "total_entries": len(self._cache),
+            "hit_rate_percent": round(hit_rate, 2),
+            "total_requests": total_requests
+        }
+
+
+# Global cache manager instance
+cache_manager = CacheManager()
+
+
+def cache_key_generator(*args, **kwargs) -> str:
+    """Generate a cache key from function arguments"""
+    # Create a deterministic string representation
+    key_parts = []
+
+    # Add function args
+    for arg in args:
+        if hasattr(arg, '__dict__'):
+            # For objects, use their dict representation
+            key_parts.append(str(sorted(arg.__dict__.items())))
+        else:
+            key_parts.append(str(arg))
+
+    # Add function kwargs
+    if kwargs:
+        key_parts.append(str(sorted(kwargs.items())))
+
+    # Create hash for consistent key length
+    key_string = "|".join(key_parts)
+    return hashlib.md5(key_string.encode()).hexdigest()
+
+
+def cached(ttl_seconds: int = 300, key_prefix: str = ""):
+    """Decorator for caching function results"""
+    def decorator(func):
+        @wraps(func)
+        async def async_wrapper(*args, **kwargs):
+            # Generate cache key
+            cache_key = f"{key_prefix}{func.__name__}_{cache_key_generator(*args, **kwargs)}"
+
+            # Try to get from cache
+            cached_result = cache_manager.get(cache_key, _MISSING)
+            if cached_result is not _MISSING:
+                return cached_result
+
+            # Execute function and cache result
+            result = await func(*args, **kwargs)
+            cache_manager.set(cache_key, result, ttl_seconds)
+
+            return result
+
+        @wraps(func)
+        def sync_wrapper(*args, **kwargs):
+            # Generate cache key
+            cache_key = f"{key_prefix}{func.__name__}_{cache_key_generator(*args, **kwargs)}"
+
+            # Try to get from cache
+            cached_result = cache_manager.get(cache_key, _MISSING)
+            if cached_result is not _MISSING:
+                return cached_result
+
+            # Execute function and cache result
+            result = func(*args, **kwargs)
+            cache_manager.set(cache_key, result, ttl_seconds)
+
+            return result
+
+        # Return appropriate wrapper based on whether function is async
+        if asyncio.iscoroutinefunction(func):
+            return async_wrapper
+        else:
+            return sync_wrapper
+
+    return decorator
+
+
+# Cache configurations for different query types
+CACHE_CONFIGS = {
+    "marketplace_stats": {"ttl": 300, "prefix": "marketplace_"},  # 5 minutes
+    "job_list": {"ttl": 60, "prefix": "jobs_"},  # 1 minute
+    "miner_list": {"ttl": 120, "prefix": "miners_"},  # 2 minutes
+    "user_balance": {"ttl": 30, "prefix": "balance_"},  # 30 seconds
+    "exchange_rates": {"ttl": 600, "prefix": "rates_"},  # 10 minutes
+}
+
+
+def get_cache_config(cache_type: str) -> Dict[str, Any]:
+    """Get cache configuration for a specific type"""
+    return CACHE_CONFIGS.get(cache_type, {"ttl": 300, "prefix": ""})
+
+
+# Periodic cleanup task
+async def cleanup_expired_cache():
+    """Background task to clean up expired cache entries"""
+    while True:
+        try:
+            removed_count = cache_manager.cleanup_expired()
+            if removed_count > 0:
+                logger.info(f"Background cleanup removed {removed_count} expired entries")
+
+            # Run cleanup every 5 minutes
+            await asyncio.sleep(300)
+
+        except Exception as e:
+            logger.error(f"Cache cleanup error: {e}")
+            await asyncio.sleep(60)  # Retry after 1 minute on error
+
+
+# Cache warming utilities
+class CacheWarmer:
+    """Utility class for warming up cache with common queries"""
+
+    def __init__(self, session):
+        self.session = session
+
+    async def warm_marketplace_stats(self):
+        """Warm up marketplace statistics cache"""
+        try:
+            from ..services.marketplace import MarketplaceService
+            service = MarketplaceService(self.session)
+
+            # Cache common stats queries
+            stats = await service.get_stats()
+            cache_manager.set("marketplace_stats_overview", stats, ttl_seconds=300)
+
+            logger.info("Marketplace stats cache warmed up")
+
+        except Exception as e:
+            logger.error(f"Failed to warm marketplace stats cache: {e}")
+
+    async def warm_exchange_rates(self):
+        """Warm up exchange rates cache"""
+        try:
+            # This would call an exchange rate API
+            # For now, just set a placeholder
+            rates = {"AITBC_BTC": 0.00001, "AITBC_USD": 0.10}
+            cache_manager.set("exchange_rates_current", rates, ttl_seconds=600)
+
+            logger.info("Exchange rates cache warmed up")
+
+        except Exception as e:
+            logger.error(f"Failed to warm exchange rates cache: {e}")
+
+
+# Cache middleware for FastAPI
+async def cache_middleware(request, call_next):
+    """FastAPI middleware to add cache headers and track cache performance"""
+    response = await call_next(request)
+
+    # Add cache statistics to response headers (for debugging)
+    stats = cache_manager.get_stats()
+    response.headers["X-Cache-Hits"] = str(stats["hits"])
+    response.headers["X-Cache-Misses"] = str(stats["misses"])
+    response.headers["X-Cache-Hit-Rate"] = f"{stats['hit_rate_percent']}%"
+
+    return response
diff --git a/apps/coordinator-api/src/app/utils/circuit_breaker.py b/apps/coordinator-api/src/app/utils/circuit_breaker.py
new file mode 100644
index 00000000..d5ec7ea5
--- /dev/null
+++ b/apps/coordinator-api/src/app/utils/circuit_breaker.py
@@ -0,0 +1,356 @@
+"""
+Circuit breaker pattern for external services
+"""
+
+from enum import Enum
+from datetime import datetime, timedelta
+from typing import Any, Callable, Optional, Dict
+from functools import wraps
+import asyncio
+from aitbc.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+class CircuitState(Enum):
+    """Circuit breaker states"""
+    CLOSED = "closed"  # Normal operation
+    OPEN = "open"  # Failing, reject requests
+    HALF_OPEN = "half_open"  # Testing recovery
+
+
+class CircuitBreakerError(Exception):
+    """Custom exception for circuit breaker failures"""
+    pass
+
+
+class CircuitBreaker:
+    """Circuit breaker implementation for external service calls"""
+
+    def __init__(
+        self,
+        failure_threshold: int = 5,
+        timeout_seconds: int = 60,
+        expected_exception: type = Exception,
+        name: str = "circuit_breaker"
+    ):
+        self.failure_threshold = failure_threshold
+        self.timeout_seconds = timeout_seconds
+        self.expected_exception = expected_exception
+        self.name = name
+
+        self.failures = 0
+        self.state = CircuitState.CLOSED
+        self.last_failure_time: Optional[datetime] = None
+        self.success_count = 0
+
+        # Statistics
+        self.stats = {
+            "total_calls": 0,
+            "successful_calls": 0,
+            "failed_calls": 0,
+            "circuit_opens": 0,
+            "circuit_closes": 0
+        }
+
+    def _check_circuit(self) -> None:
+        """Count the call; raise CircuitBreakerError while OPEN and the timeout has not elapsed"""
+        self.stats["total_calls"] += 1
+
+        # Check if circuit is open
+        if self.state == CircuitState.OPEN:
+            if self._should_attempt_reset():
+                self.state = CircuitState.HALF_OPEN
+                logger.info(f"Circuit breaker '{self.name}' entering HALF_OPEN state")
+            else:
+                self.stats["failed_calls"] += 1
+                raise CircuitBreakerError(f"Circuit breaker '{self.name}' is OPEN")
+
+    async def call(self, func: Callable, *args, **kwargs) -> Any:
+        """Execute function with circuit breaker protection"""
+        self._check_circuit()
+
+        try:
+            # Execute the protected function
+            if asyncio.iscoroutinefunction(func):
+                result = await func(*args, **kwargs)
+            else:
+                result = func(*args, **kwargs)
+
+            # Success - reset circuit if needed
+            self._on_success()
+            self.stats["successful_calls"] += 1
+
+            return result
+
+        except self.expected_exception as e:
+            # Expected failure - update circuit state
+            self._on_failure()
+            self.stats["failed_calls"] += 1
+            logger.warning(f"Circuit breaker '{self.name}' failure: {e}")
+            raise
+
+    def call_sync(self, func: Callable, *args, **kwargs) -> Any:
+        """Execute a plain (non-async) function with circuit breaker protection"""
+        self._check_circuit()
+
+        try:
+            result = func(*args, **kwargs)
+
+            # Success - reset circuit if needed
+            self._on_success()
+            self.stats["successful_calls"] += 1
+
+            return result
+
+        except self.expected_exception as e:
+            # Expected failure - update circuit state
+            self._on_failure()
+            self.stats["failed_calls"] += 1
+            logger.warning(f"Circuit breaker '{self.name}' failure: {e}")
+            raise
+
+    def _should_attempt_reset(self) -> bool:
+        """Check if enough time has passed to attempt circuit reset"""
+        if self.last_failure_time is None:
+            return True
+
+        return datetime.now() - self.last_failure_time > timedelta(seconds=self.timeout_seconds)
+
+    def _on_success(self):
+        """Handle successful call"""
+        if self.state == CircuitState.HALF_OPEN:
+            # Successful call in half-open state - close circuit
+            self.state = CircuitState.CLOSED
+            self.failures = 0
+            self.success_count = 0
+            self.stats["circuit_closes"] += 1
+            logger.info(f"Circuit breaker '{self.name}' CLOSED (recovered)")
+        elif self.state == CircuitState.CLOSED:
+            # Reset failure count on success in closed state
+            self.failures = 0
+
+    def _on_failure(self):
+        """Handle failed call"""
+        self.failures += 1
+        self.last_failure_time = datetime.now()
+
+        if self.state == CircuitState.HALF_OPEN:
+            # Failure in half-open - reopen circuit
+            self.state = CircuitState.OPEN
+            self.stats["circuit_opens"] += 1
+            logger.error(f"Circuit breaker '{self.name}' OPEN (half-open test failed)")
+        elif self.failures >= self.failure_threshold:
+            # Too many failures - open circuit
+            self.state = CircuitState.OPEN
+            self.stats["circuit_opens"] += 1
+            logger.error(f"Circuit breaker '{self.name}' OPEN after {self.failures} failures")
+
+    def get_state(self) -> Dict[str, Any]:
+        """Get current circuit breaker state and statistics"""
+        return {
+            "name": self.name,
+            "state": self.state.value,
+            "failures": self.failures,
+            "failure_threshold": self.failure_threshold,
+            "timeout_seconds": self.timeout_seconds,
+            "last_failure_time": self.last_failure_time.isoformat() if self.last_failure_time else None,
+            "stats": self.stats.copy(),
+            "success_rate": (
+                (self.stats["successful_calls"] / self.stats["total_calls"] * 100)
+                if self.stats["total_calls"] > 0 else 0
+            )
+        }
+
+    def reset(self):
+        """Manually reset circuit breaker to closed state"""
+        self.state = CircuitState.CLOSED
+        self.failures = 0
+        self.last_failure_time = None
+        self.success_count = 0
+        logger.info(f"Circuit breaker '{self.name}' manually reset to CLOSED")
+
+
+def circuit_breaker(
+    failure_threshold: int = 5,
+    timeout_seconds: int = 60,
+    expected_exception: type = Exception,
+    name: Optional[str] = None
+):
+    """Decorator for adding circuit breaker protection to functions"""
+    def decorator(func):
+        breaker_name = name or f"{func.__module__}.{func.__name__}"
+        breaker = CircuitBreaker(
+            failure_threshold=failure_threshold,
+            timeout_seconds=timeout_seconds,
+            expected_exception=expected_exception,
+            name=breaker_name
+        )
+
+        # Store breaker on function for access to stats
+        func._circuit_breaker = breaker
+
+        @wraps(func)
+        async def async_wrapper(*args, **kwargs):
+            return await breaker.call(func, *args, **kwargs)
+
+        @wraps(func)
+        def sync_wrapper(*args, **kwargs):
+            # Stay synchronous: asyncio.run() would raise RuntimeError when
+            # the caller is already inside a running event loop.
+            return breaker.call_sync(func, *args, **kwargs)
+
+        # Return appropriate wrapper
+        if asyncio.iscoroutinefunction(func):
+            return async_wrapper
+        else:
+            return sync_wrapper
+
+    return decorator
+
+
+# Pre-configured circuit breakers for common external services
+class CircuitBreakers:
+    """Collection of pre-configured circuit breakers"""
+
+    def __init__(self):
+        # Blockchain RPC circuit breaker
+        self.blockchain_rpc = CircuitBreaker(
+            failure_threshold=3,
+            timeout_seconds=30,
+            expected_exception=ConnectionError,
+            name="blockchain_rpc"
+        )
+
+        # Exchange API circuit breaker
+        self.exchange_api = CircuitBreaker(
+            failure_threshold=5,
+            timeout_seconds=60,
+            expected_exception=Exception,
+            name="exchange_api"
+        )
+
+        # Wallet daemon circuit breaker
+        self.wallet_daemon = CircuitBreaker(
+            failure_threshold=3,
+            timeout_seconds=45,
+            expected_exception=ConnectionError,
+            name="wallet_daemon"
+        )
+
+        # External payment processor circuit breaker
+        self.payment_processor = CircuitBreaker(
+            failure_threshold=2,
+            timeout_seconds=120,
+            expected_exception=Exception,
+            name="payment_processor"
+        )
+
+    def get_all_states(self) -> Dict[str, Dict[str, Any]]:
+        """Get state of all circuit breakers"""
+        return {
+            "blockchain_rpc": self.blockchain_rpc.get_state(),
+            "exchange_api": self.exchange_api.get_state(),
+            "wallet_daemon": self.wallet_daemon.get_state(),
+            "payment_processor": self.payment_processor.get_state()
+        }
+
+    def reset_all(self):
+        """Reset all circuit breakers"""
+        self.blockchain_rpc.reset()
+        self.exchange_api.reset()
+        self.wallet_daemon.reset()
+        self.payment_processor.reset()
+        logger.info("All circuit breakers reset")
+
+
+# Global circuit breakers instance
+circuit_breakers = CircuitBreakers()
+
+
+# Usage examples and utilities
+class ProtectedServiceClient:
+    """Example of a service client with circuit breaker protection"""
+
+    def __init__(self, base_url: str):
+        self.base_url = base_url
+        self.circuit_breaker = CircuitBreaker(
+            failure_threshold=3,
+            timeout_seconds=60,
+            name=f"service_client_{base_url}"
+        )
+
+    async def call_api(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
+        """Protected API call, guarded by this client's own circuit breaker"""
+        # Route through self.circuit_breaker so get_health_status() reports
+        # the breaker that actually sees the traffic.
+        return await self.circuit_breaker.call(self._post, endpoint, data)
+
+    async def _post(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
+        """Send the POST request and return the decoded JSON body"""
+        import httpx
+
+        async with httpx.AsyncClient() as client:
+            response = await client.post(f"{self.base_url}{endpoint}", json=data)
+            response.raise_for_status()
+            return response.json()
+
+    def get_health_status(self) -> Dict[str, Any]:
+        """Get health status including circuit breaker state"""
+        return {
+            "service_url": self.base_url,
+            "circuit_breaker": self.circuit_breaker.get_state()
+        }
+
+
+# FastAPI endpoint for circuit breaker monitoring
+async def get_circuit_breaker_status():
+    """Get status of all circuit breakers (for monitoring)"""
+    return circuit_breakers.get_all_states()
+
+
+async def reset_circuit_breaker(breaker_name: str):
+    """Reset a specific circuit breaker (for admin operations)"""
+    breaker_map = {
+        "blockchain_rpc": circuit_breakers.blockchain_rpc,
+        "exchange_api": circuit_breakers.exchange_api,
+        "wallet_daemon": circuit_breakers.wallet_daemon,
+        "payment_processor": circuit_breakers.payment_processor
+    }
+
+    if breaker_name not in breaker_map:
+        raise ValueError(f"Unknown circuit breaker: {breaker_name}")
+
+    breaker_map[breaker_name].reset()
+    logger.info(f"Circuit breaker '{breaker_name}' reset via admin API")
+
+    return {"status": "reset", "breaker": breaker_name}
+
+
+# Background task for circuit breaker health monitoring
+async def monitor_circuit_breakers():
+    """Background task to monitor circuit breaker health"""
+    while True:
+        try:
+            states = circuit_breakers.get_all_states()
+
+            # Log any open circuits
+            for name, state in states.items():
+                if state["state"] == "open":
+                    logger.warning(f"Circuit breaker '{name}' is OPEN - check service health")
+                elif state["state"] == "half_open":
+                    logger.info(f"Circuit breaker '{name}' is HALF_OPEN - testing recovery")
+
+            # Check for circuits with high failure rates
+            for name, state in states.items():
+                if state["stats"]["total_calls"] > 10:  # Only check if enough calls
+                    success_rate = state["success_rate"]
+                    if success_rate < 80:  # Less than 80% success rate
+                        logger.warning(f"Circuit breaker '{name}' has low success rate: {success_rate:.1f}%")
+
+            # Run monitoring every 30 seconds
+            await asyncio.sleep(30)
+
+        except Exception as e:
+            logger.error(f"Circuit breaker monitoring error: {e}")
+            await asyncio.sleep(60)  # Retry after 1 minute on error