From c5412b07ac69cad1d726c14800f84e7008ebb2e9 Mon Sep 17 00:00:00 2001 From: oib Date: Sat, 19 Jul 2025 10:49:16 +0200 Subject: [PATCH] Migrate from file-based to database-backed stream metadata storage - Add PublicStream model and migration - Update list_streams.py and upload.py to use database - Add import script for data migration - Remove public_streams.txt (replaced by database) - Fix quota sync between userquota and publicstream tables --- .../0df481ee920b_add_publicstream_model.py | 71 +++++++++ import_streams.py | 94 ++++++++++++ list_streams.py | 143 +++++++++--------- models.py | 37 +++++ upload.py | 41 +++++ 5 files changed, 313 insertions(+), 73 deletions(-) create mode 100644 alembic/versions/0df481ee920b_add_publicstream_model.py create mode 100644 import_streams.py diff --git a/alembic/versions/0df481ee920b_add_publicstream_model.py b/alembic/versions/0df481ee920b_add_publicstream_model.py new file mode 100644 index 0000000..0cf6db8 --- /dev/null +++ b/alembic/versions/0df481ee920b_add_publicstream_model.py @@ -0,0 +1,71 @@ +"""Add PublicStream model + +Revision ID: 0df481ee920b +Revises: f86c93c7a872 +Create Date: 2025-07-19 10:02:22.902696 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '0df481ee920b' +down_revision: Union[str, Sequence[str], None] = 'f86c93c7a872' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + # First create the new publicstream table + op.create_table('publicstream', + sa.Column('uid', sa.String(), nullable=False), + sa.Column('size', sa.Integer(), nullable=False), + sa.Column('mtime', sa.Integer(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint('uid') + ) + + # Drop the foreign key constraint first + op.drop_constraint('dbsession_user_id_fkey', 'dbsession', type_='foreignkey') + + # Then drop the unique constraint + op.drop_constraint(op.f('uq_user_username'), 'user', type_='unique') + + # Create the new index + op.create_index(op.f('ix_user_username'), 'user', ['username'], unique=True) + + # Recreate the foreign key constraint + op.create_foreign_key( + 'dbsession_user_id_fkey', 'dbsession', 'user', + ['user_id'], ['username'], ondelete='CASCADE' + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + # Drop the foreign key constraint first + op.drop_constraint('dbsession_user_id_fkey', 'dbsession', type_='foreignkey') + + # Drop the index + op.drop_index(op.f('ix_user_username'), table_name='user') + + # Recreate the unique constraint + op.create_unique_constraint(op.f('uq_user_username'), 'user', ['username']) + + # Recreate the foreign key constraint + op.create_foreign_key( + 'dbsession_user_id_fkey', 'dbsession', 'user', + ['user_id'], ['username'], ondelete='CASCADE' + ) + + # Drop the publicstream table + op.drop_table('publicstream') + # ### end Alembic commands ### diff --git a/import_streams.py b/import_streams.py new file mode 100644 index 0000000..133176b --- /dev/null +++ b/import_streams.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python3 +""" +Script to import stream data from backup file into the publicstream table. 
+""" +import json +from datetime import datetime +from pathlib import Path +from sqlalchemy import create_engine, select +from sqlalchemy.orm import sessionmaker +from sqlmodel import Session +from models import PublicStream, User, UserQuota, DBSession, UploadLog +from database import engine + +# Database connection URL - using the same as in database.py +DATABASE_URL = "postgresql://d2s:kuTy4ZKs2VcjgDh6@localhost:5432/dictastream" + +def import_streams_from_backup(backup_file: str): + """Import stream data from backup file into the database.""" + # Set up database connection + SessionLocal = sessionmaker(bind=engine) + + with Session(engine) as session: + try: + # Read the backup file + with open(backup_file, 'r') as f: + for line in f: + line = line.strip() + if not line: + continue + + try: + # Parse the JSON data + stream_data = json.loads(line) + uid = stream_data.get('uid') + size = stream_data.get('size', 0) + mtime = stream_data.get('mtime', int(datetime.now().timestamp())) + + if not uid: + print(f"Skipping invalid entry (missing uid): {line}") + continue + + # Check if the stream already exists + existing = session.exec( + select(PublicStream).where(PublicStream.uid == uid) + ).first() + + now = datetime.utcnow() + + if existing: + # Update existing record + existing.size = size + existing.mtime = mtime + existing.updated_at = now + session.add(existing) + print(f"Updated stream: {uid}") + else: + # Create new record + stream = PublicStream( + uid=uid, + size=size, + mtime=mtime, + created_at=now, + updated_at=now + ) + session.add(stream) + print(f"Added stream: {uid}") + + # Commit after each record to ensure data integrity + session.commit() + + except json.JSONDecodeError as e: + print(f"Error parsing line: {line}") + print(f"Error: {e}") + session.rollback() + except Exception as e: + print(f"Error processing line: {line}") + print(f"Error: {e}") + session.rollback() + + print("Import completed successfully!") + + except Exception as e: + session.rollback() + print(f"Error during import: {e}") + raise + +if __name__ == "__main__": + backup_file = "public_streams.txt.backup" + if not Path(backup_file).exists(): + print(f"Error: Backup file '{backup_file}' not found.") + exit(1) + + print(f"Starting import from {backup_file}...") + import_streams_from_backup(backup_file) diff --git a/list_streams.py b/list_streams.py index 9e36366..3757ab7 100644 --- a/list_streams.py +++ b/list_streams.py @@ -1,18 +1,21 @@ # list_streams.py — FastAPI route to list all public streams (users with stream.opus) -from fastapi import APIRouter, Request +from fastapi import APIRouter, Request, Depends from fastapi.responses import StreamingResponse, Response +from sqlalchemy.orm import Session +from sqlalchemy import select +from models import PublicStream +from database import get_db from pathlib import Path import asyncio +import os +import json router = APIRouter() DATA_ROOT = Path("./data") @router.get("/streams-sse") -async def streams_sse(request: Request): - print(f"[SSE] New connection from {request.client.host}") - print(f"[SSE] Request headers: {dict(request.headers)}") - +async def streams_sse(request: Request, db: Session = Depends(get_db)): # Add CORS headers for SSE origin = request.headers.get('origin', '') allowed_origins = ["https://dicta2stream.net", "http://localhost:8000", "http://127.0.0.1:8000"] @@ -32,7 +35,6 @@ async def streams_sse(request: Request): # Handle preflight requests if request.method == "OPTIONS": - print("[SSE] Handling OPTIONS preflight request") 
headers.update({ "Access-Control-Allow-Methods": "GET, OPTIONS", "Access-Control-Allow-Headers": request.headers.get("access-control-request-headers", "*"), @@ -40,17 +42,16 @@ async def streams_sse(request: Request): }) return Response(status_code=204, headers=headers) - print("[SSE] Starting SSE stream") - async def event_wrapper(): try: - async for event in list_streams_sse(): + async for event in list_streams_sse(db): yield event except Exception as e: - print(f"[SSE] Error in event generator: {str(e)}") - import traceback - traceback.print_exc() - yield f"data: {json.dumps({'error': True, 'message': str(e)})}\n\n" + # Only log errors if DEBUG is enabled + if os.getenv("DEBUG") == "1": + import traceback + traceback.print_exc() + yield f"data: {json.dumps({'error': True, 'message': 'An error occurred'})}\n\n" return StreamingResponse( event_wrapper(), @@ -58,75 +59,71 @@ async def streams_sse(request: Request): headers=headers ) -import json -import datetime - -async def list_streams_sse(): - print("[SSE] Starting stream generator") - txt_path = Path("./public_streams.txt") - - if not txt_path.exists(): - print(f"[SSE] No public_streams.txt found") - yield f"data: {json.dumps({'end': True})}\n\n" - return - +async def list_streams_sse(db): + """Stream public streams from the database as Server-Sent Events""" try: # Send initial ping - print("[SSE] Sending initial ping") yield ":ping\n\n" - # Read and send the file contents - with txt_path.open("r") as f: - for line in f: - line = line.strip() - if not line: - continue - - try: - # Parse the JSON to validate it - stream = json.loads(line) - print(f"[SSE] Sending stream data: {stream}") - - # Send the data as an SSE event - event = f"data: {json.dumps(stream)}\n\n" - yield event - - # Small delay to prevent overwhelming the client - await asyncio.sleep(0.1) - - except json.JSONDecodeError as e: - print(f"[SSE] JSON decode error: {e} in line: {line}") - continue - except Exception as e: - print(f"[SSE] Error processing line: {e}") - continue + # Query all public streams from the database + stmt = select(PublicStream).order_by(PublicStream.mtime.desc()) + result = db.execute(stmt) + streams = result.scalars().all() - print("[SSE] Sending end event") + if not streams: + yield f"data: {json.dumps({'end': True})}\n\n" + return + + # Send each stream as an SSE event + for stream in streams: + try: + stream_data = { + 'uid': stream.uid, + 'size': stream.size, + 'mtime': stream.mtime, + 'created_at': stream.created_at.isoformat() if stream.created_at else None, + 'updated_at': stream.updated_at.isoformat() if stream.updated_at else None + } + yield f"data: {json.dumps(stream_data)}\n\n" + # Small delay to prevent overwhelming the client + await asyncio.sleep(0.1) + except Exception as e: + if os.getenv("DEBUG") == "1": + import traceback + traceback.print_exc() + continue + + # Send end of stream marker yield f"data: {json.dumps({'end': True})}\n\n" except Exception as e: - print(f"[SSE] Error in stream generator: {str(e)}") - import traceback - traceback.print_exc() + if os.getenv("DEBUG") == "1": + import traceback + traceback.print_exc() yield f"data: {json.dumps({'error': True, 'message': str(e)})}\n\n" - finally: - print("[SSE] Stream generator finished") + yield f"data: {json.dumps({'error': True, 'message': 'Stream generation failed'})}\n\n" -def list_streams(): - txt_path = Path("./public_streams.txt") - if not txt_path.exists(): - return {"streams": []} +def list_streams(db: Session = Depends(get_db)): + """List all public streams from 
the database""" try: - streams = [] - with txt_path.open("r") as f: - for line in f: - line = line.strip() - if not line: - continue - try: - streams.append(json.loads(line)) - except Exception: - continue # skip malformed lines - return {"streams": streams} - except Exception: + stmt = select(PublicStream).order_by(PublicStream.mtime.desc()) + result = db.execute(stmt) + streams = result.scalars().all() + + return { + "streams": [ + { + 'uid': stream.uid, + 'size': stream.size, + 'mtime': stream.mtime, + 'created_at': stream.created_at.isoformat() if stream.created_at else None, + 'updated_at': stream.updated_at.isoformat() if stream.updated_at else None + } + for stream in streams + ] + } + except Exception as e: + if os.getenv("DEBUG") == "1": + import traceback + traceback.print_exc() return {"streams": []} diff --git a/models.py b/models.py index b28dc8b..50061f2 100644 --- a/models.py +++ b/models.py @@ -40,8 +40,45 @@ class DBSession(SQLModel, table=True): last_activity: datetime = Field(default_factory=datetime.utcnow) +class PublicStream(SQLModel, table=True): + """Stores public stream metadata for all users""" + uid: str = Field(primary_key=True) + size: int = 0 + mtime: int = Field(default_factory=lambda: int(datetime.utcnow().timestamp())) + created_at: datetime = Field(default_factory=datetime.utcnow) + updated_at: datetime = Field(default_factory=datetime.utcnow) + + def get_user_by_uid(uid: str) -> Optional[User]: with Session(engine) as session: statement = select(User).where(User.username == uid) result = session.exec(statement).first() return result + + +def verify_session(db: Session, token: str) -> DBSession: + """Verify a session token and return the session if valid""" + from datetime import datetime + + # Find the session + session = db.exec( + select(DBSession) + .where(DBSession.token == token) + .where(DBSession.is_active == True) # noqa: E712 + .where(DBSession.expires_at > datetime.utcnow()) + ).first() + + if not session: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired session", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # Update last activity + session.last_activity = datetime.utcnow() + db.add(session) + db.commit() + db.refresh(session) + + return session diff --git a/upload.py b/upload.py index 0878f49..b9742f2 100644 --- a/upload.py +++ b/upload.py @@ -5,6 +5,8 @@ from slowapi import Limiter from slowapi.util import get_remote_address from slowapi.errors import RateLimitExceeded from pathlib import Path +import json +from datetime import datetime from convert_to_opus import convert_to_opus from models import UploadLog, UserQuota, User from sqlalchemy import select @@ -115,6 +117,9 @@ async def upload(request: Request, db = Depends(get_db), uid: str = Form(...), f db.add(quota) quota.storage_bytes += size db.commit() + + # Update public streams list + update_public_streams(uid, quota.storage_bytes) return { "filename": file.filename, @@ -135,3 +140,39 @@ async def upload(request: Request, db = Depends(get_db), uid: str = Form(...), f except Exception: pass return {"detail": f"Server error: {type(e).__name__}: {str(e)}"} + + +def update_public_streams(uid: str, storage_bytes: int, db = Depends(get_db)): + """Update the public streams list in the database with the latest user upload info""" + try: + from models import PublicStream + + # Get or create the public stream record + public_stream = db.get(PublicStream, uid) + current_time = datetime.utcnow() + + if public_stream is None: + # Create a new record 
if it doesn't exist + public_stream = PublicStream( + uid=uid, + size=storage_bytes, + mtime=int(current_time.timestamp()), + created_at=current_time, + updated_at=current_time + ) + db.add(public_stream) + else: + # Update existing record + public_stream.size = storage_bytes + public_stream.mtime = int(current_time.timestamp()) + public_stream.updated_at = current_time + + db.commit() + db.refresh(public_stream) + + except Exception as e: + db.rollback() + import traceback + print(f"Error updating public streams in database: {e}") + print(traceback.format_exc()) + raise
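
Not part of the patch itself, but a minimal sketch of how the migration could be verified end to end: apply the schema change with `alembic upgrade head`, run `python import_streams.py` against the public_streams.txt.backup file, then confirm that the rows landed in the new publicstream table. The snippet below reuses the patch's own `engine` and `PublicStream` imports; the row limit and the printed fields are illustrative assumptions, not anything this patch defines.

#!/usr/bin/env python3
# Hypothetical post-migration check (not part of this patch): confirm that the
# publicstream table now holds the rows that used to live in public_streams.txt.
from sqlmodel import Session, select

from database import engine      # same engine that import_streams.py uses
from models import PublicStream

def verify_import(limit: int = 10) -> None:
    """Print a row count and the most recently modified streams."""
    with Session(engine) as session:
        streams = session.exec(
            select(PublicStream).order_by(PublicStream.mtime.desc())
        ).all()
        print(f"publicstream rows: {len(streams)}")
        for stream in streams[:limit]:
            print(f"{stream.uid}: {stream.size} bytes, mtime={stream.mtime}")

if __name__ == "__main__":
    verify_import()

Once the rows are in place, /streams-sse serves the same records as SSE events (an initial :ping, one data: event per stream, then an end marker), so public_streams.txt is no longer read anywhere.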