Migrate from file-based to database-backed stream metadata storage

- Add PublicStream model and migration
- Update list_streams.py and upload.py to use database
- Add import script for data migration
- Remove public_streams.txt (replaced by database)
- Fix quota sync between userquota and publicstream tables
Author: oib
Date: 2025-07-19 10:49:16 +02:00
Parent: 402e920bc6
Commit: c5412b07ac
5 changed files with 313 additions and 73 deletions
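For context, the import script below assumes the legacy public_streams.txt format: one JSON object per line carrying uid, size, and mtime. A minimal sketch of producing a compatible backup line (the uid and sizes are placeholders):

import json, time

# Illustrative only: one JSON object per line, as import_streams.py expects
record = {"uid": "devuser", "size": 1048576, "mtime": int(time.time())}
with open("public_streams.txt.backup", "a") as f:
    f.write(json.dumps(record) + "\n")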

Alembic migration (new file, 71 lines)

@ -0,0 +1,71 @@
"""Add PublicStream model
Revision ID: 0df481ee920b
Revises: f86c93c7a872
Create Date: 2025-07-19 10:02:22.902696
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '0df481ee920b'
down_revision: Union[str, Sequence[str], None] = 'f86c93c7a872'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
# First create the new publicstream table
op.create_table('publicstream',
sa.Column('uid', sa.String(), nullable=False),
sa.Column('size', sa.Integer(), nullable=False),
sa.Column('mtime', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint('uid')
)
# Drop the foreign key constraint first
op.drop_constraint('dbsession_user_id_fkey', 'dbsession', type_='foreignkey')
# Then drop the unique constraint
op.drop_constraint(op.f('uq_user_username'), 'user', type_='unique')
# Create the new index
op.create_index(op.f('ix_user_username'), 'user', ['username'], unique=True)
# Recreate the foreign key constraint
op.create_foreign_key(
'dbsession_user_id_fkey', 'dbsession', 'user',
['user_id'], ['username'], ondelete='CASCADE'
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
# Drop the foreign key constraint first
op.drop_constraint('dbsession_user_id_fkey', 'dbsession', type_='foreignkey')
# Drop the index
op.drop_index(op.f('ix_user_username'), table_name='user')
# Recreate the unique constraint
op.create_unique_constraint(op.f('uq_user_username'), 'user', ['username'])
# Recreate the foreign key constraint
op.create_foreign_key(
'dbsession_user_id_fkey', 'dbsession', 'user',
['user_id'], ['username'], ondelete='CASCADE'
)
# Drop the publicstream table
op.drop_table('publicstream')
# ### end Alembic commands ###
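The migration is applied with the standard Alembic workflow; a sketch using Alembic's Python API, assuming the project's alembic.ini sits in the working directory:

from alembic import command
from alembic.config import Config

# Assumes alembic.ini at the project root points at the dictastream database
cfg = Config("alembic.ini")
command.upgrade(cfg, "0df481ee920b")      # apply this revision (or "head")
# command.downgrade(cfg, "f86c93c7a872")  # roll back to the previous revision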

import_streams.py (new file, 94 lines)

@ -0,0 +1,94 @@
#!/usr/bin/env python3
"""
Import stream data from a public_streams.txt backup into the publicstream table.
"""
import json
from datetime import datetime
from pathlib import Path
from sqlalchemy import select
from sqlmodel import Session
from models import PublicStream
# Database connection comes from database.py; no need to duplicate the URL here
from database import engine
def import_streams_from_backup(backup_file: str):
"""Import stream data from backup file into the database."""
with Session(engine) as session:
try:
# Read the backup file
with open(backup_file, 'r') as f:
for line in f:
line = line.strip()
if not line:
continue
try:
# Parse the JSON data
stream_data = json.loads(line)
uid = stream_data.get('uid')
size = stream_data.get('size', 0)
mtime = stream_data.get('mtime', int(datetime.now().timestamp()))
if not uid:
print(f"Skipping invalid entry (missing uid): {line}")
continue
# Check if the stream already exists
existing = session.exec(
select(PublicStream).where(PublicStream.uid == uid)
).first()
now = datetime.utcnow()
if existing:
# Update existing record
existing.size = size
existing.mtime = mtime
existing.updated_at = now
session.add(existing)
print(f"Updated stream: {uid}")
else:
# Create new record
stream = PublicStream(
uid=uid,
size=size,
mtime=mtime,
created_at=now,
updated_at=now
)
session.add(stream)
print(f"Added stream: {uid}")
# Commit after each record so progress survives if a later line fails
session.commit()
except json.JSONDecodeError as e:
print(f"Error parsing line: {line}")
print(f"Error: {e}")
session.rollback()
except Exception as e:
print(f"Error processing line: {line}")
print(f"Error: {e}")
session.rollback()
print("Import completed successfully!")
except Exception as e:
session.rollback()
print(f"Error during import: {e}")
raise
if __name__ == "__main__":
backup_file = "public_streams.txt.backup"
if not Path(backup_file).exists():
print(f"Error: Backup file '{backup_file}' not found.")
exit(1)
print(f"Starting import from {backup_file}...")
import_streams_from_backup(backup_file)
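After the import, a quick row-count check (a sketch reusing the same engine and model; compare the count against the number of lines in the backup):

from sqlalchemy import func, select
from sqlmodel import Session

from database import engine
from models import PublicStream

# Sanity check: total rows in publicstream after the import
with Session(engine) as session:
    total = session.execute(select(func.count()).select_from(PublicStream)).scalar_one()
    print(f"publicstream rows: {total}")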

list_streams.py

@ -1,18 +1,21 @@
 # list_streams.py — FastAPI route to list all public streams (users with stream.opus)
-from fastapi import APIRouter, Request
+from fastapi import APIRouter, Request, Depends
 from fastapi.responses import StreamingResponse, Response
+from sqlalchemy.orm import Session
+from sqlalchemy import select
+from models import PublicStream
+from database import get_db
 from pathlib import Path
 import asyncio
+import os
+import json

 router = APIRouter()
 DATA_ROOT = Path("./data")

 @router.get("/streams-sse")
-async def streams_sse(request: Request):
-    print(f"[SSE] New connection from {request.client.host}")
-    print(f"[SSE] Request headers: {dict(request.headers)}")
+async def streams_sse(request: Request, db: Session = Depends(get_db)):
     # Add CORS headers for SSE
     origin = request.headers.get('origin', '')
     allowed_origins = ["https://dicta2stream.net", "http://localhost:8000", "http://127.0.0.1:8000"]
@ -32,7 +35,6 @@ async def streams_sse(request: Request):
     # Handle preflight requests
     if request.method == "OPTIONS":
-        print("[SSE] Handling OPTIONS preflight request")
         headers.update({
             "Access-Control-Allow-Methods": "GET, OPTIONS",
             "Access-Control-Allow-Headers": request.headers.get("access-control-request-headers", "*"),
@ -40,17 +42,16 @@ async def streams_sse(request: Request):
         })
         return Response(status_code=204, headers=headers)

-    print("[SSE] Starting SSE stream")
     async def event_wrapper():
         try:
-            async for event in list_streams_sse():
+            async for event in list_streams_sse(db):
                 yield event
         except Exception as e:
-            print(f"[SSE] Error in event generator: {str(e)}")
-            import traceback
-            traceback.print_exc()
-            yield f"data: {json.dumps({'error': True, 'message': str(e)})}\n\n"
+            # Only log errors if DEBUG is enabled
+            if os.getenv("DEBUG") == "1":
+                import traceback
+                traceback.print_exc()
+            yield f"data: {json.dumps({'error': True, 'message': 'An error occurred'})}\n\n"

     return StreamingResponse(
         event_wrapper(),
@ -58,75 +59,71 @@ async def streams_sse(request: Request):
         headers=headers
     )

-import json
-import datetime
-
-async def list_streams_sse():
-    print("[SSE] Starting stream generator")
-    txt_path = Path("./public_streams.txt")
-
-    if not txt_path.exists():
-        print(f"[SSE] No public_streams.txt found")
-        yield f"data: {json.dumps({'end': True})}\n\n"
-        return
-
-    try:
-        # Send initial ping
-        print("[SSE] Sending initial ping")
-        yield ":ping\n\n"
-
-        # Read and send the file contents
-        with txt_path.open("r") as f:
-            for line in f:
-                line = line.strip()
-                if not line:
-                    continue
-                try:
-                    # Parse the JSON to validate it
-                    stream = json.loads(line)
-                    print(f"[SSE] Sending stream data: {stream}")
-                    # Send the data as an SSE event
-                    event = f"data: {json.dumps(stream)}\n\n"
-                    yield event
-                    # Small delay to prevent overwhelming the client
-                    await asyncio.sleep(0.1)
-                except json.JSONDecodeError as e:
-                    print(f"[SSE] JSON decode error: {e} in line: {line}")
-                    continue
-                except Exception as e:
-                    print(f"[SSE] Error processing line: {e}")
-                    continue
-
-        print("[SSE] Sending end event")
-        yield f"data: {json.dumps({'end': True})}\n\n"
-    except Exception as e:
-        print(f"[SSE] Error in stream generator: {str(e)}")
-        import traceback
-        traceback.print_exc()
-        yield f"data: {json.dumps({'error': True, 'message': str(e)})}\n\n"
-    finally:
-        print("[SSE] Stream generator finished")
+async def list_streams_sse(db):
+    """Stream public streams from the database as Server-Sent Events"""
+    try:
+        # Send initial ping
+        yield ":ping\n\n"
+
+        # Query all public streams from the database
+        stmt = select(PublicStream).order_by(PublicStream.mtime.desc())
+        result = db.execute(stmt)
+        streams = result.scalars().all()
+
+        if not streams:
+            yield f"data: {json.dumps({'end': True})}\n\n"
+            return
+
+        # Send each stream as an SSE event
+        for stream in streams:
+            try:
+                stream_data = {
+                    'uid': stream.uid,
+                    'size': stream.size,
+                    'mtime': stream.mtime,
+                    'created_at': stream.created_at.isoformat() if stream.created_at else None,
+                    'updated_at': stream.updated_at.isoformat() if stream.updated_at else None
+                }
+                yield f"data: {json.dumps(stream_data)}\n\n"
+                # Small delay to prevent overwhelming the client
+                await asyncio.sleep(0.1)
+            except Exception as e:
+                if os.getenv("DEBUG") == "1":
+                    import traceback
+                    traceback.print_exc()
+                continue
+
+        # Send end of stream marker
+        yield f"data: {json.dumps({'end': True})}\n\n"
+    except Exception as e:
+        if os.getenv("DEBUG") == "1":
+            import traceback
+            traceback.print_exc()
+        yield f"data: {json.dumps({'error': True, 'message': 'Stream generation failed'})}\n\n"

-def list_streams():
-    txt_path = Path("./public_streams.txt")
-    if not txt_path.exists():
-        return {"streams": []}
-    try:
-        streams = []
-        with txt_path.open("r") as f:
-            for line in f:
-                line = line.strip()
-                if not line:
-                    continue
-                try:
-                    streams.append(json.loads(line))
-                except Exception:
-                    continue  # skip malformed lines
-        return {"streams": streams}
-    except Exception:
-        return {"streams": []}
+def list_streams(db: Session = Depends(get_db)):
+    """List all public streams from the database"""
+    try:
+        stmt = select(PublicStream).order_by(PublicStream.mtime.desc())
+        result = db.execute(stmt)
+        streams = result.scalars().all()
+
+        return {
+            "streams": [
+                {
+                    'uid': stream.uid,
+                    'size': stream.size,
+                    'mtime': stream.mtime,
+                    'created_at': stream.created_at.isoformat() if stream.created_at else None,
+                    'updated_at': stream.updated_at.isoformat() if stream.updated_at else None
+                }
+                for stream in streams
+            ]
+        }
+    except Exception as e:
+        if os.getenv("DEBUG") == "1":
+            import traceback
+            traceback.print_exc()
+        return {"streams": []}

models.py

@ -40,8 +40,45 @@ class DBSession(SQLModel, table=True):
     last_activity: datetime = Field(default_factory=datetime.utcnow)

+
+class PublicStream(SQLModel, table=True):
+    """Stores public stream metadata for all users"""
+    uid: str = Field(primary_key=True)
+    size: int = 0
+    mtime: int = Field(default_factory=lambda: int(datetime.utcnow().timestamp()))
+    created_at: datetime = Field(default_factory=datetime.utcnow)
+    updated_at: datetime = Field(default_factory=datetime.utcnow)
+
+
 def get_user_by_uid(uid: str) -> Optional[User]:
     with Session(engine) as session:
         statement = select(User).where(User.username == uid)
         result = session.exec(statement).first()
         return result
+
+
+def verify_session(db: Session, token: str) -> DBSession:
+    """Verify a session token and return the session if valid"""
+    from datetime import datetime
+
+    # Find the session
+    session = db.exec(
+        select(DBSession)
+        .where(DBSession.token == token)
+        .where(DBSession.is_active == True)  # noqa: E712
+        .where(DBSession.expires_at > datetime.utcnow())
+    ).first()
+
+    if not session:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Invalid or expired session",
+            headers={"WWW-Authenticate": "Bearer"},
+        )
+
+    # Update last activity
+    session.last_activity = datetime.utcnow()
+    db.add(session)
+    db.commit()
+    db.refresh(session)
+
+    return session
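A hypothetical route showing how verify_session would be wired up; the /me path, the Authorization header handling, and the user_id field are assumptions (the field is suggested by the dbsession_user_id_fkey constraint in the migration):

from fastapi import APIRouter, Depends, Header
from sqlmodel import Session

from database import get_db
from models import verify_session

router = APIRouter()

@router.get("/me")  # hypothetical endpoint
def whoami(authorization: str = Header(...), db: Session = Depends(get_db)):
    token = authorization.removeprefix("Bearer ")
    db_session = verify_session(db, token)  # raises 401 on invalid/expired tokens
    return {"user_id": db_session.user_id}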

upload.py

@ -5,6 +5,8 @@ from slowapi import Limiter
 from slowapi.util import get_remote_address
 from slowapi.errors import RateLimitExceeded
 from pathlib import Path
+import json
+from datetime import datetime
 from convert_to_opus import convert_to_opus
 from models import UploadLog, UserQuota, User
 from sqlalchemy import select
@ -116,6 +118,9 @@ async def upload(request: Request, db = Depends(get_db), uid: str = Form(...), f
         quota.storage_bytes += size
         db.commit()

+        # Update public streams list (pass the session explicitly)
+        update_public_streams(uid, quota.storage_bytes, db)
+
         return {
             "filename": file.filename,
             "original_size": round(original_size / 1024, 1),
@ -135,3 +140,39 @@ async def upload(request: Request, db = Depends(get_db), uid: str = Form(...), f
         except Exception:
             pass
         return {"detail": f"Server error: {type(e).__name__}: {str(e)}"}
+
+
+def update_public_streams(uid: str, storage_bytes: int, db):
+    """Update the public streams list in the database with the latest user upload info"""
+    # Note: db is passed in explicitly; Depends() only resolves on route signatures
+    try:
+        from models import PublicStream
+
+        # Get or create the public stream record
+        public_stream = db.get(PublicStream, uid)
+        current_time = datetime.utcnow()
+
+        if public_stream is None:
+            # Create a new record if it doesn't exist
+            public_stream = PublicStream(
+                uid=uid,
+                size=storage_bytes,
+                mtime=int(current_time.timestamp()),
+                created_at=current_time,
+                updated_at=current_time
+            )
+            db.add(public_stream)
+        else:
+            # Update existing record
+            public_stream.size = storage_bytes
+            public_stream.mtime = int(current_time.timestamp())
+            public_stream.updated_at = current_time
+
+        db.commit()
+        db.refresh(public_stream)
+    except Exception as e:
+        db.rollback()
+        import traceback
+        print(f"Error updating public streams in database: {e}")
+        print(traceback.format_exc())
+        raise
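Since update_public_streams takes the session explicitly, it can also be called outside a request, e.g. from a one-off maintenance shell (a sketch; the uid and byte count are placeholders):

from sqlmodel import Session
from database import engine
from upload import update_public_streams

# Resync a single user's publicstream row (placeholder values)
with Session(engine) as db:
    update_public_streams("devuser", 42_000_000, db)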