chore: enhance security configuration across applications
- Add root-level *.json to .gitignore to prevent wallet backup leaks - Replace wildcard CORS origins with explicit localhost URLs across all apps - Add OPTIONS method to CORS allowed methods for preflight requests - Update coordinator database to use absolute path in data/ directory to prevent duplicates - Add JWT secret validation in coordinator config (must be set via environment) - Replace deprecated get_session dependency with Session
This commit is contained in:
@@ -111,8 +111,13 @@ def create_app() -> FastAPI:
|
||||
app.add_middleware(RateLimitMiddleware, max_requests=200, window_seconds=60)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_methods=["GET", "POST"],
|
||||
allow_origins=[
|
||||
"http://localhost:3000",
|
||||
"http://localhost:8080",
|
||||
"http://localhost:8000",
|
||||
"http://localhost:8011"
|
||||
],
|
||||
allow_methods=["GET", "POST", "OPTIONS"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@@ -70,7 +70,16 @@ def create_app() -> Starlette:
|
||||
]
|
||||
|
||||
middleware = [
|
||||
Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])
|
||||
Middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=[
|
||||
"http://localhost:3000",
|
||||
"http://localhost:8080",
|
||||
"http://localhost:8000",
|
||||
"http://localhost:8011"
|
||||
],
|
||||
allow_methods=["POST", "GET", "OPTIONS"]
|
||||
)
|
||||
]
|
||||
|
||||
return Starlette(routes=routes, middleware=middleware)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from typing import List, Optional
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
@@ -9,14 +11,35 @@ class Settings(BaseSettings):
|
||||
app_host: str = "127.0.0.1"
|
||||
app_port: int = 8011
|
||||
|
||||
database_url: str = "sqlite:///./coordinator.db"
|
||||
# Use absolute path to avoid database duplicates in different working directories
|
||||
@property
|
||||
def database_url(self) -> str:
|
||||
# Find project root by looking for .git directory
|
||||
current = Path(__file__).resolve()
|
||||
while current.parent != current:
|
||||
if (current / ".git").exists():
|
||||
project_root = current
|
||||
break
|
||||
current = current.parent
|
||||
else:
|
||||
# Fallback to relative path if .git not found
|
||||
project_root = Path(__file__).resolve().parents[3]
|
||||
|
||||
db_path = project_root / "data" / "coordinator.db"
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
return f"sqlite:///{db_path}"
|
||||
|
||||
client_api_keys: List[str] = []
|
||||
miner_api_keys: List[str] = []
|
||||
admin_api_keys: List[str] = []
|
||||
|
||||
hmac_secret: Optional[str] = None
|
||||
allow_origins: List[str] = ["*"]
|
||||
allow_origins: List[str] = [
|
||||
"http://localhost:3000",
|
||||
"http://localhost:8080",
|
||||
"http://localhost:8000",
|
||||
"http://localhost:8011"
|
||||
]
|
||||
|
||||
job_ttl_seconds: int = 900
|
||||
heartbeat_interval_seconds: int = 10
|
||||
|
||||
@@ -17,7 +17,7 @@ class Settings(BaseSettings):
|
||||
database_url: str = "postgresql://localhost:5432/aitbc_coordinator"
|
||||
|
||||
# JWT Configuration
|
||||
jwt_secret: str = "change-me-in-production"
|
||||
jwt_secret: str = "" # Must be provided via environment
|
||||
jwt_algorithm: str = "HS256"
|
||||
jwt_expiration_hours: int = 24
|
||||
|
||||
@@ -51,7 +51,17 @@ class Settings(BaseSettings):
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
|
||||
def validate_secrets(self) -> None:
|
||||
"""Validate that all required secrets are provided"""
|
||||
if not self.jwt_secret:
|
||||
raise ValueError("JWT_SECRET environment variable is required")
|
||||
if self.jwt_secret == "change-me-in-production":
|
||||
raise ValueError("JWT_SECRET must be changed from default value")
|
||||
|
||||
|
||||
# Create global settings instance
|
||||
settings = Settings()
|
||||
|
||||
# Validate secrets on import
|
||||
settings.validate_secrets()
|
||||
|
||||
@@ -1,21 +1,9 @@
|
||||
from typing import Callable, Generator, Annotated
|
||||
from typing import Callable, Annotated
|
||||
from fastapi import Depends, Header, HTTPException
|
||||
from sqlmodel import Session
|
||||
|
||||
from .config import settings
|
||||
|
||||
|
||||
def get_session() -> Generator[Session, None, None]:
|
||||
"""Get database session"""
|
||||
from .database import engine
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
|
||||
|
||||
# Type alias for session dependency
|
||||
SessionDep = Annotated[Session, Depends(get_session)]
|
||||
|
||||
|
||||
class APIKeyValidator:
|
||||
def __init__(self, allowed_keys: list[str]):
|
||||
self.allowed_keys = {key.strip() for key in allowed_keys if key}
|
||||
|
||||
@@ -3,7 +3,6 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from prometheus_client import make_asgi_app
|
||||
|
||||
from .config import settings
|
||||
from .database import create_db_and_tables
|
||||
from .storage import init_db
|
||||
from .routers import (
|
||||
client,
|
||||
@@ -38,8 +37,8 @@ def create_app() -> FastAPI:
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.allow_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"]
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||
allow_headers=["*"] # Allow all headers for API keys and content types
|
||||
)
|
||||
|
||||
app.include_router(client, prefix="/v1")
|
||||
|
||||
@@ -10,7 +10,7 @@ import time
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from ..deps import get_session
|
||||
from ..storage import SessionDep
|
||||
from ..domain import User, Wallet
|
||||
from ..schemas import UserCreate, UserLogin, UserProfile, UserBalance
|
||||
|
||||
@@ -50,7 +50,7 @@ def verify_session_token(token: str) -> Optional[str]:
|
||||
@router.post("/register", response_model=UserProfile)
|
||||
async def register_user(
|
||||
user_data: UserCreate,
|
||||
session: Session = Depends(get_session)
|
||||
session: SessionDep
|
||||
) -> Dict[str, Any]:
|
||||
"""Register a new user"""
|
||||
|
||||
@@ -103,7 +103,7 @@ async def register_user(
|
||||
@router.post("/login", response_model=UserProfile)
|
||||
async def login_user(
|
||||
login_data: UserLogin,
|
||||
session: Session = Depends(get_session)
|
||||
session: SessionDep
|
||||
) -> Dict[str, Any]:
|
||||
"""Login user with wallet address"""
|
||||
|
||||
@@ -161,7 +161,7 @@ async def login_user(
|
||||
@router.get("/users/me", response_model=UserProfile)
|
||||
async def get_current_user(
|
||||
token: str,
|
||||
session: Session = Depends(get_session)
|
||||
session: SessionDep
|
||||
) -> Dict[str, Any]:
|
||||
"""Get current user profile"""
|
||||
|
||||
@@ -190,7 +190,7 @@ async def get_current_user(
|
||||
@router.get("/users/{user_id}/balance", response_model=UserBalance)
|
||||
async def get_user_balance(
|
||||
user_id: str,
|
||||
session: Session = Depends(get_session)
|
||||
session: SessionDep
|
||||
) -> Dict[str, Any]:
|
||||
"""Get user's AITBC balance"""
|
||||
|
||||
@@ -223,7 +223,7 @@ async def logout_user(token: str) -> Dict[str, str]:
|
||||
@router.get("/users/{user_id}/transactions")
|
||||
async def get_user_transactions(
|
||||
user_id: str,
|
||||
session: Session = Depends(get_session)
|
||||
session: SessionDep
|
||||
) -> Dict[str, Any]:
|
||||
"""Get user's transaction history"""
|
||||
|
||||
|
||||
@@ -30,12 +30,16 @@ Base = declarative_base()
|
||||
# Direct PostgreSQL connection for performance
|
||||
def get_pg_connection():
|
||||
"""Get direct PostgreSQL connection"""
|
||||
# Parse database URL from settings
|
||||
from urllib.parse import urlparse
|
||||
parsed = urlparse(settings.database_url)
|
||||
|
||||
return psycopg2.connect(
|
||||
host="localhost",
|
||||
database="aitbc_coordinator",
|
||||
user="aitbc_user",
|
||||
password="aitbc_password",
|
||||
port=5432,
|
||||
host=parsed.hostname or "localhost",
|
||||
database=parsed.path[1:] if parsed.path else "aitbc_coordinator",
|
||||
user=parsed.username or "aitbc_user",
|
||||
password=parsed.password or "aitbc_password",
|
||||
port=parsed.port or 5432,
|
||||
cursor_factory=RealDictCursor
|
||||
)
|
||||
|
||||
@@ -194,8 +198,16 @@ class PostgreSQLAdapter:
|
||||
if self.connection:
|
||||
self.connection.close()
|
||||
|
||||
# Global adapter instance
|
||||
db_adapter = PostgreSQLAdapter()
|
||||
# Global adapter instance (lazy initialization)
|
||||
db_adapter: Optional[PostgreSQLAdapter] = None
|
||||
|
||||
|
||||
def get_db_adapter() -> PostgreSQLAdapter:
|
||||
"""Get or create database adapter instance"""
|
||||
global db_adapter
|
||||
if db_adapter is None:
|
||||
db_adapter = PostgreSQLAdapter()
|
||||
return db_adapter
|
||||
|
||||
# Database initialization
|
||||
def init_db():
|
||||
@@ -212,7 +224,8 @@ def init_db():
|
||||
def check_db_health() -> Dict[str, Any]:
|
||||
"""Check database health"""
|
||||
try:
|
||||
result = db_adapter.execute_query("SELECT 1 as health_check")
|
||||
adapter = get_db_adapter()
|
||||
result = adapter.execute_query("SELECT 1 as health_check")
|
||||
return {
|
||||
"status": "healthy",
|
||||
"database": "postgresql",
|
||||
|
||||
@@ -5,11 +5,14 @@ FastAPI backend for the AITBC Trade Exchange
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional
|
||||
from fastapi import FastAPI, Depends, HTTPException, status
|
||||
from fastapi import FastAPI, Depends, HTTPException, status, Header
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import desc, func, and_
|
||||
from sqlalchemy.orm import Session
|
||||
import hashlib
|
||||
import time
|
||||
from typing import Annotated
|
||||
|
||||
from database import init_db, get_db_session
|
||||
from models import User, Order, Trade, Balance
|
||||
@@ -17,13 +20,59 @@ from models import User, Order, Trade, Balance
|
||||
# Initialize FastAPI app
|
||||
app = FastAPI(title="AITBC Trade Exchange API", version="1.0.0")
|
||||
|
||||
# In-memory session storage (use Redis in production)
|
||||
user_sessions = {}
|
||||
|
||||
def verify_session_token(token: str = Header(..., alias="Authorization")) -> int:
|
||||
"""Verify session token and return user_id"""
|
||||
# Remove "Bearer " prefix if present
|
||||
if token.startswith("Bearer "):
|
||||
token = token[7:]
|
||||
|
||||
if token not in user_sessions:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired token"
|
||||
)
|
||||
|
||||
session = user_sessions[token]
|
||||
|
||||
# Check if expired
|
||||
if int(time.time()) > session["expires_at"]:
|
||||
del user_sessions[token]
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token expired"
|
||||
)
|
||||
|
||||
return session["user_id"]
|
||||
|
||||
def optional_auth(token: Optional[str] = Header(None, alias="Authorization")) -> Optional[int]:
|
||||
"""Optional authentication - returns user_id if token is valid, None otherwise"""
|
||||
if not token:
|
||||
return None
|
||||
|
||||
try:
|
||||
return verify_session_token(token)
|
||||
except HTTPException:
|
||||
return None
|
||||
|
||||
# Type annotations for dependencies
|
||||
UserDep = Annotated[int, Depends(verify_session_token)]
|
||||
OptionalUserDep = Annotated[Optional[int], Depends(optional_auth)]
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_origins=[
|
||||
"http://localhost:3000",
|
||||
"http://localhost:8080",
|
||||
"http://localhost:8000",
|
||||
"http://localhost:3003"
|
||||
],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||
allow_headers=["*"], # Allow all headers for auth tokens
|
||||
)
|
||||
|
||||
# Pydantic models
|
||||
@@ -110,6 +159,41 @@ def get_recent_trades(limit: int = 20, db: Session = Depends(get_db_session)):
|
||||
trades = db.query(Trade).order_by(desc(Trade.created_at)).limit(limit).all()
|
||||
return trades
|
||||
|
||||
@app.get("/api/orders", response_model=List[OrderResponse])
|
||||
def get_orders(
|
||||
status_filter: Optional[str] = None,
|
||||
user_only: bool = False,
|
||||
db: Session = Depends(get_db_session),
|
||||
user_id: OptionalUserDep = None
|
||||
):
|
||||
"""Get all orders with optional status filter"""
|
||||
query = db.query(Order)
|
||||
|
||||
# Filter by user if requested and authenticated
|
||||
if user_only and user_id:
|
||||
query = query.filter(Order.user_id == user_id)
|
||||
|
||||
if status_filter:
|
||||
query = query.filter(Order.status == status_filter.upper())
|
||||
|
||||
orders = query.order_by(Order.created_at.desc()).all()
|
||||
return orders
|
||||
|
||||
@app.get("/api/my/orders", response_model=List[OrderResponse])
|
||||
def get_my_orders(
|
||||
user_id: UserDep,
|
||||
status_filter: Optional[str] = None,
|
||||
db: Session = Depends(get_db_session)
|
||||
):
|
||||
"""Get current user's orders"""
|
||||
query = db.query(Order).filter(Order.user_id == user_id)
|
||||
|
||||
if status_filter:
|
||||
query = query.filter(Order.status == status_filter.upper())
|
||||
|
||||
orders = query.order_by(Order.created_at.desc()).all()
|
||||
return orders
|
||||
|
||||
@app.get("/api/orders/orderbook", response_model=OrderBookResponse)
|
||||
def get_orderbook(db: Session = Depends(get_db_session)):
|
||||
"""Get current order book"""
|
||||
@@ -127,7 +211,11 @@ def get_orderbook(db: Session = Depends(get_db_session)):
|
||||
return OrderBookResponse(buys=buys, sells=sells)
|
||||
|
||||
@app.post("/api/orders", response_model=OrderResponse)
|
||||
def create_order(order: OrderCreate, db: Session = Depends(get_db_session)):
|
||||
def create_order(
|
||||
order: OrderCreate,
|
||||
db: Session = Depends(get_db_session),
|
||||
user_id: UserDep
|
||||
):
|
||||
"""Create a new order"""
|
||||
|
||||
# Validate order type
|
||||
@@ -140,7 +228,7 @@ def create_order(order: OrderCreate, db: Session = Depends(get_db_session)):
|
||||
# Create order
|
||||
total = order.amount * order.price
|
||||
db_order = Order(
|
||||
user_id=1, # TODO: Get from authentication
|
||||
user_id=user_id, # Use authenticated user_id
|
||||
order_type=order.order_type,
|
||||
amount=order.amount,
|
||||
price=order.price,
|
||||
@@ -219,6 +307,45 @@ def try_match_order(order: Order, db: Session):
|
||||
|
||||
db.commit()
|
||||
|
||||
@app.post("/api/auth/login")
|
||||
def login_user(wallet_address: str, db: Session = Depends(get_db_session)):
|
||||
"""Login with wallet address"""
|
||||
# Find or create user
|
||||
user = db.query(User).filter(User.wallet_address == wallet_address).first()
|
||||
if not user:
|
||||
user = User(
|
||||
wallet_address=wallet_address,
|
||||
email=f"{wallet_address}@aitbc.local",
|
||||
is_active=True
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
# Create session token
|
||||
token_data = f"{user.id}:{int(time.time())}"
|
||||
token = hashlib.sha256(token_data.encode()).hexdigest()
|
||||
|
||||
# Store session
|
||||
user_sessions[token] = {
|
||||
"user_id": user.id,
|
||||
"created_at": int(time.time()),
|
||||
"expires_at": int(time.time()) + 86400 # 24 hours
|
||||
}
|
||||
|
||||
return {"token": token, "user_id": user.id}
|
||||
|
||||
@app.post("/api/auth/logout")
|
||||
def logout_user(token: str = Header(..., alias="Authorization")):
|
||||
"""Logout user"""
|
||||
if token.startswith("Bearer "):
|
||||
token = token[7:]
|
||||
|
||||
if token in user_sessions:
|
||||
del user_sessions[token]
|
||||
|
||||
return {"message": "Logged out successfully"}
|
||||
|
||||
@app.get("/api/health")
|
||||
def health_check():
|
||||
"""Health check endpoint"""
|
||||
|
||||
Reference in New Issue
Block a user