Files
2025-04-24 11:44:23 +02:00

287 lines
9.7 KiB
Python

# main.py — FastAPI backend entrypoint for dicta2stream
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
import os
import subprocess
from log import log_violation
from models import get_user_by_uid
from sqlmodel import Session, SQLModel, select
from database import get_db, engine
from models import User, UserQuota
from fastapi import Depends
from dotenv import load_dotenv
# Load variables from a .env file into the process environment before the
# os.getenv() lookups below read them.
load_dotenv()
# Ensure all tables exist at startup
SQLModel.metadata.create_all(engine)
# Shared secret that /admin/stats compares against the "x-admin-secret" header.
ADMIN_SECRET = os.getenv("ADMIN_SECRET")
import os
# Debug mode is on when the DEBUG environment variable is "1", "true" or "True".
debug_mode = os.getenv("DEBUG", "0") in ("1", "true", "True")
from fastapi.responses import JSONResponse
from fastapi.requests import Request as FastAPIRequest
from fastapi.exception_handlers import RequestValidationError
from fastapi.exceptions import HTTPException as FastAPIHTTPException
# The application object; the mounts, exception handlers and routes below attach to it.
app = FastAPI(debug=debug_mode)
from fastapi.staticfiles import StaticFiles
import os
# Make sure the directory served under /audio exists before mounting it.
# exist_ok=True replaces the separate existence check, which could race with
# another process creating the directory in between.
os.makedirs("data", exist_ok=True)
app.mount("/audio", StaticFiles(directory="data"), name="audio")
if debug_mode:
    print("[DEBUG] FastAPI running in debug mode.")
# Global error handler to always return JSON
@app.exception_handler(FastAPIHTTPException)
async def http_exception_handler(request: FastAPIRequest, exc: FastAPIHTTPException):
    """Render an HTTPException as a JSON response of the form {"detail": ...}.

    The response carries the exception's status code and, when the exception
    defines any, its extra headers (for example ``WWW-Authenticate``), so
    header information attached by the raiser is not silently lost.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={"detail": exc.detail},
        # getattr keeps this safe for exception objects without a headers attribute.
        headers=getattr(exc, "headers", None),
    )
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: FastAPIRequest, exc: RequestValidationError):
    """Answer request-validation failures with HTTP 422 and the error list as JSON."""
    payload = {"detail": exc.errors()}
    return JSONResponse(content=payload, status_code=422)
@app.exception_handler(Exception)
async def generic_exception_handler(request: FastAPIRequest, exc: Exception):
    """Turn any otherwise unhandled exception into an HTTP 500 JSON response.

    The body is {"detail": <str(exc)>}.
    """
    # NOTE(review): the exception text is sent to the client verbatim; confirm
    # that exposing internal error messages is acceptable for this deployment.
    message = str(exc)
    return JSONResponse(content={"detail": message}, status_code=500)
# include routers from submodules
from register import router as register_router
from magic import router as magic_router
from upload import router as upload_router
from redirect import router as redirect_router
from list_user_files import router as list_user_files_router
from list_streams import router as list_streams_router
# Attach every feature router to the application, in the same order as before.
for _router in (
    register_router,
    magic_router,
    upload_router,
    redirect_router,
    list_user_files_router,
    list_streams_router,
):
    app.include_router(_router)
# Expose the local "static" directory under the /static URL prefix.
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")
@app.get("/", response_class=HTMLResponse)
def serve_index():
    """Serve the front-end page (static/index.html) at the site root."""
    # Decode explicitly as UTF-8 so the result does not depend on the
    # server's locale default encoding.
    with open("static/index.html", encoding="utf-8") as f:
        return f.read()
@app.get("/me", response_class=HTMLResponse)
def serve_me():
    """Serve the same front-end page (static/index.html) for the /me path."""
    # Decode explicitly as UTF-8 so the result does not depend on the
    # server's locale default encoding.
    with open("static/index.html", encoding="utf-8") as f:
        return f.read()
@app.get("/admin/stats")
def admin_stats(request: Request, db: Session = Depends(get_db)):
    """Return aggregate usage statistics for an authenticated administrator.

    The caller must send the configured ADMIN_SECRET in the ``x-admin-secret``
    header.

    Returns:
        A dict with ``total_users`` (number of User rows), ``total_quota_mb``
        (sum of all UserQuota.storage_bytes, in MiB, rounded to 2 decimals)
        and ``violation_log_entries`` (number of lines in log.txt, 0 when the
        file does not exist).

    Raises:
        HTTPException: 403 when the header is missing or wrong, or when no
            ADMIN_SECRET is configured on the server.
    """
    import hmac

    # Authenticate first, so unauthenticated callers cannot trigger database
    # queries or file reads. An unset/empty ADMIN_SECRET must never grant
    # access: previously a request without the header matched an unset secret
    # (None == None).
    secret = request.headers.get("x-admin-secret")
    if (
        not ADMIN_SECRET
        or secret is None
        # Constant-time comparison; bytes are used because compare_digest
        # rejects non-ASCII str arguments.
        or not hmac.compare_digest(secret.encode("utf-8"), ADMIN_SECRET.encode("utf-8"))
    ):
        raise HTTPException(status_code=403, detail="Forbidden")

    users_count = len(db.exec(select(User)).all())
    total_quota_sum = sum(q.storage_bytes for q in db.exec(select(UserQuota)).all())

    try:
        with open("log.txt") as f:
            violations_log = sum(1 for _ in f)
    except FileNotFoundError:
        # No log file yet means no recorded entries.
        violations_log = 0

    return {
        "total_users": users_count,
        "total_quota_mb": round(total_quota_sum / (1024 * 1024), 2),
        "violation_log_entries": violations_log
    }
@app.get("/status")
def status():
    """Liveness probe: always reports {"status": "ok"}."""
    return dict(status="ok")
@app.get("/debug")
def debug(request: Request):
    """Echo back the caller's address and the request headers as JSON."""
    client_ip = request.client.host
    header_map = dict(request.headers)
    return {"ip": client_ip, "headers": header_map}
# Root directory; each user's files live in a sub-directory named after the username.
STREAM_DIR = "/srv/streams"
# Stream URLs are built as ICECAST_BASE_URL + ICECAST_MOUNT_PREFIX + username + ".opus".
ICECAST_BASE_URL = "https://dicta2stream.net/stream/"
ICECAST_MOUNT_PREFIX = "user-"
# Per-user limit (100 MiB) compared against the `du -sb` size of the user's directory.
MAX_QUOTA_BYTES = 100 * 1024 * 1024
@app.post("/delete-account")
async def delete_account(data: dict, request: Request, db: Session = Depends(get_db)):
    """Delete a user's database records and their stream directory.

    Args:
        data: Request body; must contain a non-empty ``uid``.
        request: Incoming request, used to read the client address.
        db: Database session.

    Returns:
        ``{"message": "User deleted"}`` on success.

    Raises:
        HTTPException: 400 when ``uid`` is missing or the user's directory
            does not resolve to a location strictly inside STREAM_DIR;
            403 when the user is unknown or the client address does not
            match the one stored for the user.
    """
    import shutil

    uid = data.get("uid")
    if not uid:
        raise HTTPException(status_code=400, detail="Missing UID")

    ip = request.client.host
    user = get_user_by_uid(uid)
    if not user or user.ip != ip:
        raise HTTPException(status_code=403, detail="Unauthorized")

    # Validate the directory BEFORE deleting anything, so a rejected request
    # leaves the database untouched. The target must be strictly inside
    # STREAM_DIR: comparing against the root plus a path separator rejects
    # both the root itself (e.g. an empty username) and sibling directories
    # that merely share the prefix (e.g. "/srv/streams2").
    stream_root = os.path.realpath(STREAM_DIR)
    real_user_dir = os.path.realpath(os.path.join(STREAM_DIR, user.username))
    if not real_user_dir.startswith(stream_root + os.sep):
        raise HTTPException(status_code=400, detail="Invalid user directory")

    # Delete user quota and user using ORM
    quota = db.get(UserQuota, uid)
    if quota:
        db.delete(quota)
    user_obj = db.get(User, user.email)
    if user_obj:
        db.delete(user_obj)
    db.commit()

    # Best-effort removal of the user's files; errors are deliberately ignored.
    if os.path.exists(real_user_dir):
        shutil.rmtree(real_user_dir, ignore_errors=True)
    return {"message": "User deleted"}
@app.post("/upload")
async def upload_audio(
    request: Request,
    uid: str = Form(...),
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
):
    """Store an uploaded audio file and transcode it to the user's Opus stream.

    The upload is written to ``upload.wav`` in the user's directory, checked
    against MAX_QUOTA_BYTES, converted with ffmpeg to ``stream.opus`` (mono,
    48 kHz, 60 kbit/s libopus) and then removed. Finally the stored quota is
    refreshed from the directory's actual size.

    Args:
        request: Incoming request, used to read the client address.
        uid: User ID form field.
        file: The uploaded audio file.
        db: Database session used for the quota update. (Previously missing,
            which made every quota update fail with a NameError.)

    Returns:
        ``{"stream_url": ...}`` with the public URL of the user's stream.

    Raises:
        HTTPException: 403 for an unknown uid, a client-address mismatch or
            an exceeded quota; 500 when ffmpeg fails.
    """
    from fastapi.concurrency import run_in_threadpool

    ip = request.client.host
    user = get_user_by_uid(uid)
    if not user:
        log_violation("UPLOAD", ip, uid, "UID not found")
        raise HTTPException(status_code=403, detail="Invalid user ID")
    if user.ip != ip:
        log_violation("UPLOAD", ip, uid, "UID/IP mismatch")
        raise HTTPException(status_code=403, detail="Device/IP mismatch")

    user_dir = os.path.join(STREAM_DIR, user.username)
    os.makedirs(user_dir, exist_ok=True)
    raw_path = os.path.join(user_dir, "upload.wav")
    final_path = os.path.join(user_dir, "stream.opus")

    # Copy the upload to disk in 1 MiB chunks instead of holding the whole
    # file in memory at once.
    with open(raw_path, "wb") as out:
        while True:
            chunk = await file.read(1024 * 1024)
            if not chunk:
                break
            out.write(chunk)

    usage = int(subprocess.check_output(["du", "-sb", user_dir]).split()[0])
    if usage > MAX_QUOTA_BYTES:
        os.remove(raw_path)
        log_violation("UPLOAD", ip, uid, "Quota exceeded")
        raise HTTPException(status_code=403, detail="Quota exceeded")

    # from detect_content_type_whisper_ollama import detect_content_type_whisper_ollama # Broken import: module not found
    # Content detection is currently disabled, so content_type stays None and
    # the rejection branch below never triggers.
    content_type = None
    if content_type in ["music", "singing"]:
        os.remove(raw_path)
        log_violation("UPLOAD", ip, uid, f"Rejected content: {content_type}")
        return JSONResponse(status_code=403, content={"error": f"{content_type.capitalize()} uploads are not allowed."})

    try:
        # ffmpeg is a blocking call; run it in the thread pool so the event
        # loop keeps serving other requests while encoding.
        await run_in_threadpool(
            subprocess.run,
            [
                "ffmpeg", "-y", "-i", raw_path,
                "-ac", "1", "-ar", "48000",
                "-c:a", "libopus", "-b:a", "60k",
                final_path
            ],
            check=True,
        )
    except subprocess.CalledProcessError as e:
        os.remove(raw_path)
        log_violation("FFMPEG", ip, uid, f"ffmpeg failed: {e}")
        raise HTTPException(status_code=500, detail="Encoding failed")
    os.remove(raw_path)

    # Best-effort quota refresh: a failure is logged but does not fail the upload.
    try:
        actual_bytes = int(subprocess.check_output(["du", "-sb", user_dir]).split()[0])
        q = db.get(UserQuota, uid)
        if q:
            q.storage_bytes = actual_bytes
            db.add(q)
            db.commit()
    except Exception as e:
        log_violation("QUOTA", ip, uid, f"Quota update failed: {e}")

    stream_url = f"{ICECAST_BASE_URL}{ICECAST_MOUNT_PREFIX}{user.username}.opus"
    return {"stream_url": stream_url}
@app.delete("/uploads/{uid}/{filename}")
def delete_file(uid: str, filename: str, request: Request, db: Session = Depends(get_db)):
    """Remove one file from the user's directory and refresh the stored quota.

    Responds 403 for an unknown uid, a client-address mismatch or a path that
    resolves outside the user's directory, and 404 when the target is not a
    regular file. On success returns ``{"status": "deleted"}``.
    """
    user = get_user_by_uid(uid)
    if not user:
        raise HTTPException(status_code=403, detail="Invalid user ID")

    ip = request.client.host
    if user.ip != ip:
        raise HTTPException(status_code=403, detail="Device/IP mismatch")

    user_dir = os.path.join(STREAM_DIR, user.username)

    # Resolve both paths so that "..", symlinks and similar tricks cannot
    # point the deletion outside the user's own directory.
    resolved_target = os.path.realpath(os.path.join(user_dir, filename))
    allowed_prefix = os.path.realpath(user_dir) + os.sep
    if not resolved_target.startswith(allowed_prefix):
        raise HTTPException(status_code=403, detail="Invalid path")
    if not os.path.isfile(resolved_target):
        raise HTTPException(status_code=404, detail="File not found")

    os.remove(resolved_target)
    log_violation("DELETE", ip, uid, f"Deleted {filename}")
    subprocess.run(["/root/scripts/refresh_user_playlist.sh", user.username])

    # Best-effort quota refresh; failures are only logged.
    try:
        du_output = subprocess.check_output(["du", "-sb", user_dir])
        actual_bytes = int(du_output.split()[0])
        quota_row = db.get(UserQuota, uid)
        if quota_row:
            quota_row.storage_bytes = actual_bytes
            db.add(quota_row)
            db.commit()
    except Exception as err:
        log_violation("QUOTA", ip, uid, f"Quota update after delete failed: {err}")

    return {"status": "deleted"}
@app.get("/confirm/{uid}")
def confirm_user(uid: str, request: Request):
    """Return the username and e-mail for ``uid`` when the caller's address matches.

    Responds 403 when the uid is unknown or the stored address differs from
    the client's.
    """
    ip = request.client.host
    user = get_user_by_uid(uid)
    authorized = bool(user) and user.ip == ip
    if not authorized:
        raise HTTPException(status_code=403, detail="Unauthorized")
    return dict(username=user.username, email=user.email)
@app.get("/me/{uid}")
def get_me(uid: str, request: Request, db: Session = Depends(get_db)):
    """Return the caller's stream URL, file listing and stored quota.

    Responds 403 when the uid is unknown or the stored address differs from
    the client's. ``files`` lists regular files in the user's directory with
    their sizes in bytes; ``quota`` is the stored usage in MiB (0 when no
    quota row exists).
    """
    ip = request.client.host
    user = get_user_by_uid(uid)
    authorized = bool(user) and user.ip == ip
    if not authorized:
        raise HTTPException(status_code=403, detail="Unauthorized access")

    user_dir = os.path.join(STREAM_DIR, user.username)
    files = []
    if os.path.exists(user_dir):
        candidates = ((name, os.path.join(user_dir, name)) for name in os.listdir(user_dir))
        files = [
            {"name": name, "size": os.path.getsize(full_path)}
            for name, full_path in candidates
            if os.path.isfile(full_path)
        ]

    quota_row = db.get(UserQuota, uid)
    if quota_row:
        quota_mb = round(quota_row.storage_bytes / (1024 * 1024), 2)
    else:
        quota_mb = 0

    stream_url = f"{ICECAST_BASE_URL}{ICECAST_MOUNT_PREFIX}{user.username}.opus"
    return {
        "stream_url": stream_url,
        "files": files,
        "quota": quota_mb
    }