Update 2025-05-21_08:58:06
This commit is contained in:
180
main.py
180
main.py
@ -1,18 +1,24 @@
|
||||
# main.py — FastAPI backend entrypoint for dicta2stream
|
||||
|
||||
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, HTMLResponse
|
||||
from fastapi import FastAPI, Request, Response, status, Form, UploadFile, File, Depends
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse, StreamingResponse, JSONResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
import os
|
||||
import subprocess
|
||||
from log import log_violation
|
||||
from models import get_user_by_uid
|
||||
|
||||
from sqlmodel import Session, SQLModel, select
|
||||
import io
|
||||
import traceback
|
||||
import shutil
|
||||
import mimetypes
|
||||
from typing import Optional
|
||||
from models import User, UploadLog
|
||||
from sqlmodel import Session, select, SQLModel
|
||||
from database import get_db, engine
|
||||
from models import User, UserQuota
|
||||
|
||||
from fastapi import Depends
|
||||
from log import log_violation
|
||||
import secrets
|
||||
import time
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
@ -31,16 +37,72 @@ from fastapi.exception_handlers import RequestValidationError
|
||||
from fastapi.exceptions import HTTPException as FastAPIHTTPException
|
||||
|
||||
app = FastAPI(debug=debug_mode)
|
||||
|
||||
# --- CORS Middleware for SSE and API access ---
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["https://dicta2stream.net", "http://localhost:8000", "http://127.0.0.1:8000"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
import os
|
||||
if not os.path.exists("data"):
|
||||
os.makedirs("data")
|
||||
app.mount("/audio", StaticFiles(directory="data"), name="audio")
|
||||
# Secure audio file serving endpoint (replaces static mount)
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi import Security
|
||||
|
||||
def get_current_user(request: Request, db: Session = Depends(get_db)):
|
||||
# Use your existing session/cookie/token mechanism here
|
||||
uid = request.headers.get("x-uid") or request.query_params.get("uid") or request.cookies.get("uid")
|
||||
if not uid:
|
||||
raise HTTPException(status_code=403, detail="Not authenticated")
|
||||
user = get_user_by_uid(uid)
|
||||
if not user or not user.confirmed:
|
||||
raise HTTPException(status_code=403, detail="Invalid user")
|
||||
return user
|
||||
|
||||
from range_response import range_response
|
||||
|
||||
@app.get("/audio/{uid}/{filename}")
|
||||
def get_audio(uid: str, filename: str, request: Request, db: Session = Depends(get_db)):
|
||||
# Allow public access ONLY to stream.opus
|
||||
user_dir = os.path.join("data", uid)
|
||||
file_path = os.path.join(user_dir, filename)
|
||||
real_user_dir = os.path.realpath(user_dir)
|
||||
real_file_path = os.path.realpath(file_path)
|
||||
if not real_file_path.startswith(real_user_dir):
|
||||
raise HTTPException(status_code=403, detail="Path traversal detected")
|
||||
if not os.path.isfile(real_file_path):
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
if filename == "stream.opus":
|
||||
# Use range_response for browser seeking support
|
||||
return range_response(request, real_file_path, content_type="audio/ogg")
|
||||
# Otherwise, require authentication and owner check
|
||||
try:
|
||||
from fastapi import Security
|
||||
current_user = get_current_user(request, db)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=403, detail="Not allowed")
|
||||
if uid != current_user.username:
|
||||
raise HTTPException(status_code=403, detail="Not allowed")
|
||||
return FileResponse(real_file_path, media_type="audio/ogg")
|
||||
|
||||
if debug_mode:
|
||||
print("[DEBUG] FastAPI running in debug mode.")
|
||||
|
||||
# Global error handler to always return JSON
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
from models import get_user_by_uid, UserQuota
|
||||
|
||||
@app.exception_handler(RateLimitExceeded)
|
||||
async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
|
||||
return JSONResponse(status_code=429, content={"detail": "Rate limit exceeded. Please try again later."})
|
||||
|
||||
@app.exception_handler(FastAPIHTTPException)
|
||||
async def http_exception_handler(request: FastAPIRequest, exc: FastAPIHTTPException):
|
||||
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
|
||||
@ -57,20 +119,64 @@ async def generic_exception_handler(request: FastAPIRequest, exc: Exception):
|
||||
from register import router as register_router
|
||||
from magic import router as magic_router
|
||||
from upload import router as upload_router
|
||||
from redirect import router as redirect_router
|
||||
from streams import router as streams_router
|
||||
from list_user_files import router as list_user_files_router
|
||||
|
||||
app.include_router(streams_router)
|
||||
|
||||
from list_streams import router as list_streams_router
|
||||
|
||||
app.include_router(register_router)
|
||||
app.include_router(magic_router)
|
||||
app.include_router(upload_router)
|
||||
app.include_router(redirect_router)
|
||||
app.include_router(list_user_files_router)
|
||||
app.include_router(list_streams_router)
|
||||
|
||||
# Serve static files
|
||||
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||
|
||||
@app.post("/log-client")
|
||||
async def log_client(request: Request):
|
||||
try:
|
||||
data = await request.json()
|
||||
msg = data.get("msg", "")
|
||||
ip = request.client.host
|
||||
timestamp = datetime.utcnow().isoformat()
|
||||
log_dir = os.path.join(os.path.dirname(__file__), "log")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_path = os.path.join(log_dir, "debug.log")
|
||||
log_entry = f"[{timestamp}] IP={ip} MSG={msg}\n"
|
||||
with open(log_path, "a") as f:
|
||||
f.write(log_entry)
|
||||
if os.getenv("DEBUG", "0") in ("1", "true", "True"):
|
||||
print(f"[CLIENT-DEBUG] {log_entry.strip()}")
|
||||
return {"status": "ok"}
|
||||
except Exception as e:
|
||||
# Enhanced error logging
|
||||
import sys
|
||||
import traceback
|
||||
error_log_dir = os.path.join(os.path.dirname(__file__), "log")
|
||||
os.makedirs(error_log_dir, exist_ok=True)
|
||||
error_log_path = os.path.join(error_log_dir, "debug-errors.log")
|
||||
tb = traceback.format_exc()
|
||||
try:
|
||||
req_body = await request.body()
|
||||
except Exception:
|
||||
req_body = b"<failed to read body>"
|
||||
error_entry = (
|
||||
f"[{datetime.utcnow().isoformat()}] /log-client ERROR: {type(e).__name__}: {e}\n"
|
||||
f"Request IP: {getattr(request.client, 'host', None)}\n"
|
||||
f"Request body: {req_body}\n"
|
||||
f"Traceback:\n{tb}\n"
|
||||
)
|
||||
try:
|
||||
with open(error_log_path, "a") as ef:
|
||||
ef.write(error_entry)
|
||||
except Exception as ef_exc:
|
||||
print(f"[CLIENT-DEBUG-ERROR] Failed to write error log: {ef_exc}", file=sys.stderr)
|
||||
print(error_entry, file=sys.stderr)
|
||||
return {"status": "error", "detail": str(e)}
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
def serve_index():
|
||||
with open("static/index.html") as f:
|
||||
@ -116,9 +222,6 @@ def debug(request: Request):
|
||||
"headers": dict(request.headers),
|
||||
}
|
||||
|
||||
STREAM_DIR = "/srv/streams"
|
||||
ICECAST_BASE_URL = "https://dicta2stream.net/stream/"
|
||||
ICECAST_MOUNT_PREFIX = "user-"
|
||||
MAX_QUOTA_BYTES = 100 * 1024 * 1024
|
||||
|
||||
@app.post("/delete-account")
|
||||
@ -142,47 +245,15 @@ async def delete_account(data: dict, request: Request, db: Session = Depends(get
|
||||
db.commit()
|
||||
|
||||
import shutil
|
||||
user_dir = os.path.join(STREAM_DIR, user.username)
|
||||
# Only allow deletion within STREAM_DIR
|
||||
user_dir = os.path.join('data', user.username)
|
||||
real_user_dir = os.path.realpath(user_dir)
|
||||
if not real_user_dir.startswith(os.path.realpath(STREAM_DIR)):
|
||||
if not real_user_dir.startswith(os.path.realpath('data')):
|
||||
raise HTTPException(status_code=400, detail="Invalid user directory")
|
||||
if os.path.exists(real_user_dir):
|
||||
shutil.rmtree(real_user_dir, ignore_errors=True)
|
||||
|
||||
return {"message": "User deleted"}
|
||||
|
||||
@app.post("/upload")
|
||||
async def upload_audio(
|
||||
request: Request,
|
||||
uid: str = Form(...),
|
||||
file: UploadFile = File(...)
|
||||
):
|
||||
ip = request.client.host
|
||||
user = get_user_by_uid(uid)
|
||||
if not user:
|
||||
log_violation("UPLOAD", ip, uid, "UID not found")
|
||||
raise HTTPException(status_code=403, detail="Invalid user ID")
|
||||
|
||||
if user.ip != ip:
|
||||
log_violation("UPLOAD", ip, uid, "UID/IP mismatch")
|
||||
raise HTTPException(status_code=403, detail="Device/IP mismatch")
|
||||
|
||||
user_dir = os.path.join(STREAM_DIR, user.username)
|
||||
os.makedirs(user_dir, exist_ok=True)
|
||||
raw_path = os.path.join(user_dir, "upload.wav")
|
||||
final_path = os.path.join(user_dir, "stream.opus")
|
||||
|
||||
with open(raw_path, "wb") as out:
|
||||
content = await file.read()
|
||||
out.write(content)
|
||||
|
||||
usage = subprocess.check_output(["du", "-sb", user_dir]).split()[0]
|
||||
if int(usage) > MAX_QUOTA_BYTES:
|
||||
os.remove(raw_path)
|
||||
log_violation("UPLOAD", ip, uid, "Quota exceeded")
|
||||
raise HTTPException(status_code=403, detail="Quota exceeded")
|
||||
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
# from detect_content_type_whisper_ollama import detect_content_type_whisper_ollama # Broken import: module not found
|
||||
content_type = None
|
||||
@ -214,8 +285,7 @@ async def upload_audio(
|
||||
except Exception as e:
|
||||
log_violation("QUOTA", ip, uid, f"Quota update failed: {e}")
|
||||
|
||||
stream_url = f"{ICECAST_BASE_URL}{ICECAST_MOUNT_PREFIX}{user.username}.opus"
|
||||
return {"stream_url": stream_url}
|
||||
return {}
|
||||
|
||||
@app.delete("/uploads/{uid}/{filename}")
|
||||
def delete_file(uid: str, filename: str, request: Request, db: Session = Depends(get_db)):
|
||||
@ -227,7 +297,7 @@ def delete_file(uid: str, filename: str, request: Request, db: Session = Depends
|
||||
if user.ip != ip:
|
||||
raise HTTPException(status_code=403, detail="Device/IP mismatch")
|
||||
|
||||
user_dir = os.path.join(STREAM_DIR, user.username)
|
||||
user_dir = os.path.join('data', user.username)
|
||||
target_path = os.path.join(user_dir, filename)
|
||||
# Prevent path traversal attacks
|
||||
real_target_path = os.path.realpath(target_path)
|
||||
@ -268,7 +338,7 @@ def get_me(uid: str, request: Request, db: Session = Depends(get_db)):
|
||||
if not user or user.ip != ip:
|
||||
raise HTTPException(status_code=403, detail="Unauthorized access")
|
||||
|
||||
user_dir = os.path.join(STREAM_DIR, user.username)
|
||||
user_dir = os.path.join('data', user.username)
|
||||
files = []
|
||||
if os.path.exists(user_dir):
|
||||
for f in os.listdir(user_dir):
|
||||
@ -280,7 +350,7 @@ def get_me(uid: str, request: Request, db: Session = Depends(get_db)):
|
||||
quota_mb = round(q.storage_bytes / (1024 * 1024), 2) if q else 0
|
||||
|
||||
return {
|
||||
"stream_url": f"{ICECAST_BASE_URL}{ICECAST_MOUNT_PREFIX}{user.username}.opus",
|
||||
|
||||
"files": files,
|
||||
"quota": quota_mb
|
||||
}
|
||||
|
Reference in New Issue
Block a user