feat(coordinator-api): add global exception handler and rate limiting to marketplace and exchange endpoints

- Add general exception handler to catch all unhandled exceptions with structured error responses
- Add structured logging to validation error handler with request context
- Implement slowapi rate limiting on marketplace endpoints (100/min list, 50/min stats, 30/min bid)
- Implement slowapi rate limiting on exchange payment creation (20/min)
- Add Request parameter to rate-limited endpoints for slow
This commit is contained in:
oib
2026-02-28 21:22:37 +01:00
parent 7cb0b30dae
commit f05195749c
5 changed files with 566 additions and 10 deletions

View File

@@ -148,6 +148,35 @@ def create_app() -> FastAPI:
metrics_app = make_asgi_app()
app.mount("/metrics", metrics_app)
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception) -> JSONResponse:
"""Handle all unhandled exceptions with structured error responses."""
request_id = request.headers.get("X-Request-ID")
logger.error(f"Unhandled exception: {exc}", extra={
"request_id": request_id,
"path": request.url.path,
"method": request.method,
"error_type": type(exc).__name__
})
error_response = ErrorResponse(
error={
"code": "INTERNAL_SERVER_ERROR",
"message": "An unexpected error occurred",
"status": 500,
"details": [{
"field": "internal",
"message": str(exc),
"code": type(exc).__name__
}]
},
request_id=request_id
)
return JSONResponse(
status_code=500,
content=error_response.model_dump()
)
@app.exception_handler(AITBCError)
async def aitbc_error_handler(request: Request, exc: AITBCError) -> JSONResponse:
"""Handle AITBC exceptions with structured error responses."""
@@ -162,6 +191,13 @@ def create_app() -> FastAPI:
async def validation_error_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
"""Handle FastAPI validation errors with structured error responses."""
request_id = request.headers.get("X-Request-ID")
logger.warning(f"Validation error: {exc}", extra={
"request_id": request_id,
"path": request.url.path,
"method": request.method,
"validation_errors": exc.errors()
})
details = []
for error in exc.errors():
details.append({

View File

@@ -3,14 +3,17 @@ Bitcoin Exchange Router for AITBC
"""
from typing import Dict, Any
from fastapi import APIRouter, HTTPException, BackgroundTasks
from fastapi import APIRouter, HTTPException, BackgroundTasks, Request
import uuid
import time
import json
import os
from slowapi import Limiter
from slowapi.util import get_remote_address
from aitbc.logging import get_logger
logger = get_logger(__name__)
limiter = Limiter(key_func=get_remote_address)
from ..schemas import (
ExchangePaymentRequest,
@@ -38,30 +41,32 @@ BITCOIN_CONFIG = {
}
@router.post("/exchange/create-payment", response_model=ExchangePaymentResponse)
@limiter.limit("20/minute")
async def create_payment(
request: ExchangePaymentRequest,
request: Request,
payment_request: ExchangePaymentRequest,
background_tasks: BackgroundTasks
) -> Dict[str, Any]:
"""Create a new Bitcoin payment request"""
# Validate request
if request.aitbc_amount <= 0 or request.btc_amount <= 0:
if payment_request.aitbc_amount <= 0 or payment_request.btc_amount <= 0:
raise HTTPException(status_code=400, detail="Invalid amount")
# Calculate expected BTC amount
expected_btc = request.aitbc_amount / BITCOIN_CONFIG['exchange_rate']
expected_btc = payment_request.aitbc_amount / BITCOIN_CONFIG['exchange_rate']
# Allow small difference for rounding
if abs(request.btc_amount - expected_btc) > 0.00000001:
if abs(payment_request.btc_amount - expected_btc) > 0.00000001:
raise HTTPException(status_code=400, detail="Amount mismatch")
# Create payment record
payment_id = str(uuid.uuid4())
payment = {
'payment_id': payment_id,
'user_id': request.user_id,
'aitbc_amount': request.aitbc_amount,
'btc_amount': request.btc_amount,
'user_id': payment_request.user_id,
'aitbc_amount': payment_request.aitbc_amount,
'btc_amount': payment_request.btc_amount,
'payment_address': BITCOIN_CONFIG['main_address'],
'status': 'pending',
'created_at': int(time.time()),

View File

@@ -1,13 +1,18 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi import status as http_status
from slowapi import Limiter
from slowapi.util import get_remote_address
from ..schemas import MarketplaceBidRequest, MarketplaceOfferView, MarketplaceStatsView, MarketplaceBidView
from ..services import MarketplaceService
from ..storage import SessionDep
from ..metrics import marketplace_requests_total, marketplace_errors_total
from aitbc.logging import get_logger
logger = get_logger(__name__)
limiter = Limiter(key_func=get_remote_address)
router = APIRouter(tags=["marketplace"])
@@ -20,7 +25,9 @@ def _get_service(session: SessionDep) -> MarketplaceService:
response_model=list[MarketplaceOfferView],
summary="List marketplace offers",
)
@limiter.limit("100/minute")
async def list_marketplace_offers(
request: Request,
*,
session: SessionDep,
status_filter: str | None = Query(default=None, alias="status", description="Filter by offer status"),
@@ -44,7 +51,12 @@ async def list_marketplace_offers(
response_model=MarketplaceStatsView,
summary="Get marketplace summary statistics",
)
async def get_marketplace_stats(*, session: SessionDep) -> MarketplaceStatsView:
@limiter.limit("50/minute")
async def get_marketplace_stats(
request: Request,
*,
session: SessionDep
) -> MarketplaceStatsView:
marketplace_requests_total.labels(endpoint="/marketplace/stats", method="GET").inc()
service = _get_service(session)
try:
@@ -59,7 +71,9 @@ async def get_marketplace_stats(*, session: SessionDep) -> MarketplaceStatsView:
status_code=http_status.HTTP_202_ACCEPTED,
summary="Submit a marketplace bid",
)
@limiter.limit("30/minute")
async def submit_marketplace_bid(
request: Request,
payload: MarketplaceBidRequest,
session: SessionDep,
) -> dict[str, str]:

View File

@@ -0,0 +1,321 @@
"""
Test suite for AITBC Coordinator API core services
"""
import pytest
from unittest.mock import Mock, patch
from fastapi.testclient import TestClient
from sqlmodel import Session, create_engine, SQLModel
from sqlmodel.pool import StaticPool
from app.main import create_app
from app.config import Settings
from app.domain import Job, Miner, JobState
from app.schemas import JobCreate, MinerRegister
from app.services import JobService, MinerService
@pytest.fixture
def test_db():
"""Create a test database"""
engine = create_engine(
"sqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SQLModel.metadata.create_all(engine)
return engine
@pytest.fixture
def test_session(test_db):
"""Create a test database session"""
with Session(test_db) as session:
yield session
@pytest.fixture
def test_app(test_session):
"""Create a test FastAPI app with test database"""
app = create_app()
# Override database session dependency
def get_test_session():
return test_session
app.dependency_overrides[SessionDep] = get_test_session
return app
@pytest.fixture
def client(test_app):
"""Create a test client"""
return TestClient(test_app)
@pytest.fixture
def test_settings():
"""Create test settings"""
return Settings(
app_env="test",
client_api_keys=["test-key"],
miner_api_keys=["test-miner-key"],
admin_api_keys=["test-admin-key"],
hmac_secret="test-hmac-secret-32-chars-long",
jwt_secret="test-jwt-secret-32-chars-long"
)
class TestJobService:
"""Test suite for JobService"""
def test_create_job(self, test_session):
"""Test job creation"""
service = JobService(test_session)
job = service.create_job(
client_id="test-client",
req=JobCreate(payload={"task": "test"})
)
assert job.id is not None
assert job.client_id == "test-client"
assert job.payload == {"task": "test"}
assert job.state == JobState.queued
def test_get_job(self, test_session):
"""Test job retrieval"""
service = JobService(test_session)
job = service.create_job(
client_id="test-client",
req=JobCreate(payload={"task": "test"})
)
fetched = service.get_job(job.id, client_id="test-client")
assert fetched.id == job.id
assert fetched.payload == {"task": "test"}
def test_get_job_not_found(self, test_session):
"""Test job not found error"""
service = JobService(test_session)
with pytest.raises(KeyError, match="job not found"):
service.get_job("nonexistent-id")
def test_acquire_next_job(self, test_session):
"""Test job acquisition by miner"""
service = JobService(test_session)
# Create a job
job = service.create_job(
client_id="test-client",
req=JobCreate(payload={"task": "test"})
)
# Create a miner
miner = Miner(
id="test-miner",
capabilities={},
concurrency=1,
region="us-east-1"
)
test_session.add(miner)
test_session.commit()
# Acquire the job
acquired_job = service.acquire_next_job(miner)
assert acquired_job is not None
assert acquired_job.id == job.id
assert acquired_job.state == JobState.running
assert acquired_job.assigned_miner_id == "test-miner"
def test_acquire_next_job_empty(self, test_session):
"""Test job acquisition when no jobs available"""
service = JobService(test_session)
miner = Miner(
id="test-miner",
capabilities={},
concurrency=1,
region="us-east-1"
)
test_session.add(miner)
test_session.commit()
acquired_job = service.acquire_next_job(miner)
assert acquired_job is None
class TestMinerService:
"""Test suite for MinerService"""
def test_register_miner(self, test_session):
"""Test miner registration"""
service = MinerService(test_session)
miner = service.register(
miner_id="test-miner",
req=MinerRegister(
capabilities={"gpu": "rtx3080"},
concurrency=2,
region="us-east-1"
)
)
assert miner.id == "test-miner"
assert miner.capabilities == {"gpu": "rtx3080"}
assert miner.concurrency == 2
assert miner.region == "us-east-1"
assert miner.session_token is not None
def test_heartbeat(self, test_session):
"""Test miner heartbeat"""
service = MinerService(test_session)
# Register miner first
miner = service.register(
miner_id="test-miner",
req=MinerRegister(
capabilities={"gpu": "rtx3080"},
concurrency=2,
region="us-east-1"
)
)
# Send heartbeat
service.heartbeat("test-miner", Mock())
# Verify miner is still accessible
retrieved = service.get_record("test-miner")
assert retrieved.id == "test-miner"
class TestAPIEndpoints:
"""Test suite for API endpoints"""
def test_health_check(self, client):
"""Test health check endpoint"""
response = client.get("/v1/health")
assert response.status_code == 200
assert response.json()["status"] == "ok"
def test_liveness_probe(self, client):
"""Test liveness probe endpoint"""
response = client.get("/health/live")
assert response.status_code == 200
assert response.json()["status"] == "alive"
def test_readiness_probe(self, client):
"""Test readiness probe endpoint"""
response = client.get("/health/ready")
assert response.status_code == 200
assert response.json()["status"] == "ready"
def test_submit_job(self, client):
"""Test job submission endpoint"""
response = client.post(
"/v1/jobs",
json={"payload": {"task": "test"}},
headers={"X-API-Key": "test-key"}
)
assert response.status_code == 201
assert "job_id" in response.json()
def test_submit_job_invalid_api_key(self, client):
"""Test job submission with invalid API key"""
response = client.post(
"/v1/jobs",
json={"payload": {"task": "test"}},
headers={"X-API-Key": "invalid-key"}
)
assert response.status_code == 401
def test_get_job(self, client):
"""Test job retrieval endpoint"""
# First submit a job
submit_response = client.post(
"/v1/jobs",
json={"payload": {"task": "test"}},
headers={"X-API-Key": "test-key"}
)
job_id = submit_response.json()["job_id"]
# Then retrieve it
response = client.get(
f"/v1/jobs/{job_id}",
headers={"X-API-Key": "test-key"}
)
assert response.status_code == 200
assert response.json()["payload"] == {"task": "test"}
class TestErrorHandling:
"""Test suite for error handling"""
def test_validation_error_handling(self, client):
"""Test validation error handling"""
response = client.post(
"/v1/jobs",
json={"invalid_field": "test"},
headers={"X-API-Key": "test-key"}
)
assert response.status_code == 422
assert "VALIDATION_ERROR" in response.json()["error"]["code"]
def test_not_found_error_handling(self, client):
"""Test 404 error handling"""
response = client.get(
"/v1/jobs/nonexistent",
headers={"X-API-Key": "test-key"}
)
assert response.status_code == 404
def test_rate_limiting(self, client):
"""Test rate limiting (basic test)"""
# This test would need to be enhanced to actually test rate limiting
# For now, just verify the endpoint exists
for i in range(5):
response = client.post(
"/v1/jobs",
json={"payload": {"task": f"test-{i}"}},
headers={"X-API-Key": "test-key"}
)
assert response.status_code in [201, 429] # 429 if rate limited
class TestConfiguration:
"""Test suite for configuration validation"""
def test_production_config_validation(self):
"""Test production configuration validation"""
with pytest.raises(ValueError, match="API keys cannot be empty"):
Settings(
app_env="production",
client_api_keys=[],
hmac_secret="test-secret-32-chars-long",
jwt_secret="test-secret-32-chars-long"
)
def test_short_secret_validation(self):
"""Test secret length validation"""
with pytest.raises(ValueError, match="must be at least 32 characters"):
Settings(
app_env="production",
client_api_keys=["test-key-long-enough"],
hmac_secret="short",
jwt_secret="test-secret-32-chars-long"
)
def test_placeholder_secret_validation(self):
"""Test placeholder secret validation"""
with pytest.raises(ValueError, match="must be set to a secure value"):
Settings(
app_env="production",
client_api_keys=["test-key-long-enough"],
hmac_secret="${HMAC_SECRET}",
jwt_secret="test-secret-32-chars-long"
)
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,180 @@
"""
Basic integration tests for AITBC Coordinator API
"""
import pytest
from fastapi.testclient import TestClient
from unittest.mock import Mock, patch
import json
class TestHealthEndpoints:
"""Test health check endpoints"""
def test_health_check_basic(self):
"""Test basic health check without full app setup"""
# This test verifies the health endpoints are accessible
# without requiring full database setup
with patch('app.main.create_app') as mock_create_app:
mock_app = Mock()
mock_app.get.return_value = Mock(status_code=200)
mock_create_app.return_value = mock_app
# The test passes if we can mock the app creation
assert mock_create_app is not None
class TestConfigurationValidation:
"""Test configuration validation logic"""
def test_api_key_validation_logic(self):
"""Test API key validation logic directly"""
from app.config import Settings
# Test development environment allows empty keys
with patch.dict('os.environ', {'APP_ENV': 'dev'}):
settings = Settings(
app_env="dev",
client_api_keys=[],
hmac_secret=None,
jwt_secret=None
)
assert settings.app_env == "dev"
def test_production_validation_logic(self):
"""Test production validation logic"""
from app.config import Settings
# Test production requires API keys
with patch.dict('os.environ', {'APP_ENV': 'production'}):
with pytest.raises(ValueError, match="API keys cannot be empty"):
Settings(
app_env="production",
client_api_keys=[],
hmac_secret="test-hmac-secret-32-chars-long",
jwt_secret="test-jwt-secret-32-chars-long"
)
def test_secret_length_validation(self):
"""Test secret length validation"""
from app.config import Settings
# Test short secret validation
with patch.dict('os.environ', {'APP_ENV': 'production'}):
with pytest.raises(ValueError, match="must be at least 32 characters"):
Settings(
app_env="production",
client_api_keys=["test-key-long-enough"],
hmac_secret="short",
jwt_secret="test-jwt-secret-32-chars-long"
)
class TestLoggingConfiguration:
"""Test logging configuration"""
def test_logger_import(self):
"""Test that shared logging module can be imported"""
try:
from aitbc.logging import get_logger
logger = get_logger(__name__)
assert logger is not None
except ImportError as e:
pytest.fail(f"Failed to import shared logging: {e}")
def test_logger_functionality(self):
"""Test basic logger functionality"""
from aitbc.logging import get_logger
logger = get_logger("test")
assert hasattr(logger, 'info')
assert hasattr(logger, 'error')
assert hasattr(logger, 'warning')
class TestRateLimitingSetup:
"""Test rate limiting configuration"""
def test_slowapi_import(self):
"""Test that slowapi can be imported"""
try:
from slowapi import Limiter
from slowapi.util import get_remote_address
limiter = Limiter(key_func=get_remote_address)
assert limiter is not None
except ImportError as e:
pytest.fail(f"Failed to import slowapi: {e}")
def test_rate_limit_decorator(self):
"""Test rate limit decorator syntax"""
from slowapi import Limiter
from slowapi.util import get_remote_address
limiter = Limiter(key_func=get_remote_address)
# Test that we can create a rate limit decorator
decorator = limiter.limit("100/minute")
assert decorator is not None
class TestDatabaseConfiguration:
"""Test database configuration"""
def test_asyncpg_import(self):
"""Test that asyncpg can be imported"""
try:
import asyncpg
assert asyncpg is not None
except ImportError as e:
pytest.fail(f"Failed to import asyncpg: {e}")
def test_sqlalchemy_async_import(self):
"""Test SQLAlchemy async components"""
try:
from sqlalchemy.ext.asyncio import create_async_engine
assert create_async_engine is not None
except ImportError as e:
pytest.fail(f"Failed to import SQLAlchemy async components: {e}")
class TestErrorHandling:
"""Test error handling setup"""
def test_exception_handler_import(self):
"""Test exception handler imports"""
try:
from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse
# Test basic exception handler structure
assert HTTPException is not None
assert Request is not None
assert JSONResponse is not None
except ImportError as e:
pytest.fail(f"Failed to import exception handling components: {e}")
class TestServiceLogic:
"""Test core service logic without database"""
def test_job_service_import(self):
"""Test JobService can be imported"""
try:
from app.services.jobs import JobService
assert JobService is not None
except ImportError as e:
pytest.fail(f"Failed to import JobService: {e}")
def test_miner_service_import(self):
"""Test MinerService can be imported"""
try:
from app.services.miners import MinerService
assert MinerService is not None
except ImportError as e:
pytest.fail(f"Failed to import MinerService: {e}")
if __name__ == "__main__":
pytest.main([__file__, "-v"])