Migrate from file-based to database-backed stream metadata storage

- Add PublicStream model and migration
- Update list_streams.py and upload.py to use database
- Add import script for data migration
- Remove public_streams.txt (replaced by database)
- Fix quota sync between userquota and publicstream tables
Author: oib
Date: 2025-07-19 10:49:16 +02:00
parent 402e920bc6
commit c5412b07ac
5 changed files with 313 additions and 73 deletions
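A minimal rollout sketch, assuming a standard Alembic setup with alembic.ini at the project root: apply the new migration first, then backfill the publicstream table from the retired text file with the import script added in this commit.

from alembic import command
from alembic.config import Config
from import_streams import import_streams_from_backup

command.upgrade(Config("alembic.ini"), "head")  # applies revision 0df481ee920b
import_streams_from_backup("public_streams.txt.backup")  # one-time backfill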

Alembic migration 0df481ee920b: Add PublicStream model (new file)

@@ -0,0 +1,71 @@
"""Add PublicStream model
Revision ID: 0df481ee920b
Revises: f86c93c7a872
Create Date: 2025-07-19 10:02:22.902696
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '0df481ee920b'
down_revision: Union[str, Sequence[str], None] = 'f86c93c7a872'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
# First create the new publicstream table
op.create_table('publicstream',
sa.Column('uid', sa.String(), nullable=False),
sa.Column('size', sa.Integer(), nullable=False),
sa.Column('mtime', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint('uid')
)
# Drop the foreign key constraint first
op.drop_constraint('dbsession_user_id_fkey', 'dbsession', type_='foreignkey')
# Then drop the unique constraint
op.drop_constraint(op.f('uq_user_username'), 'user', type_='unique')
# Create the new index
op.create_index(op.f('ix_user_username'), 'user', ['username'], unique=True)
# Recreate the foreign key constraint
op.create_foreign_key(
'dbsession_user_id_fkey', 'dbsession', 'user',
['user_id'], ['username'], ondelete='CASCADE'
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
# Drop the foreign key constraint first
op.drop_constraint('dbsession_user_id_fkey', 'dbsession', type_='foreignkey')
# Drop the index
op.drop_index(op.f('ix_user_username'), table_name='user')
# Recreate the unique constraint
op.create_unique_constraint(op.f('uq_user_username'), 'user', ['username'])
# Recreate the foreign key constraint
op.create_foreign_key(
'dbsession_user_id_fkey', 'dbsession', 'user',
['user_id'], ['username'], ondelete='CASCADE'
)
# Drop the publicstream table
op.drop_table('publicstream')
# ### end Alembic commands ###

import_streams.py (new file, 94 lines)

@@ -0,0 +1,94 @@
#!/usr/bin/env python3
"""
Script to import stream data from backup file into the publicstream table.
"""
import json
from datetime import datetime
from pathlib import Path
from sqlmodel import Session, select
from models import PublicStream
from database import engine
def import_streams_from_backup(backup_file: str):
"""Import stream data from backup file into the database."""
with Session(engine) as session:
try:
# Read the backup file
with open(backup_file, 'r') as f:
for line in f:
line = line.strip()
if not line:
continue
try:
# Parse the JSON data
stream_data = json.loads(line)
uid = stream_data.get('uid')
size = stream_data.get('size', 0)
mtime = stream_data.get('mtime', int(datetime.now().timestamp()))
if not uid:
print(f"Skipping invalid entry (missing uid): {line}")
continue
# Check if the stream already exists
existing = session.exec(
select(PublicStream).where(PublicStream.uid == uid)
).first()
now = datetime.utcnow()
if existing:
# Update existing record
existing.size = size
existing.mtime = mtime
existing.updated_at = now
session.add(existing)
print(f"Updated stream: {uid}")
else:
# Create new record
stream = PublicStream(
uid=uid,
size=size,
mtime=mtime,
created_at=now,
updated_at=now
)
session.add(stream)
print(f"Added stream: {uid}")
# Commit after each record to ensure data integrity
session.commit()
except json.JSONDecodeError as e:
print(f"Error parsing line: {line}")
print(f"Error: {e}")
session.rollback()
except Exception as e:
print(f"Error processing line: {line}")
print(f"Error: {e}")
session.rollback()
print("Import completed successfully!")
except Exception as e:
session.rollback()
print(f"Error during import: {e}")
raise
if __name__ == "__main__":
backup_file = "public_streams.txt.backup"
if not Path(backup_file).exists():
print(f"Error: Backup file '{backup_file}' not found.")
exit(1)
print(f"Starting import from {backup_file}...")
import_streams_from_backup(backup_file)
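For reference, the parser above expects newline-delimited JSON, one object per stream; an illustrative input line (field names taken from the code, values are placeholders):

{"uid": "alice", "size": 1048576, "mtime": 1752912142}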

list_streams.py

@@ -1,18 +1,21 @@
# list_streams.py — FastAPI route to list all public streams (users with stream.opus)
from fastapi import APIRouter, Request
from fastapi import APIRouter, Request, Depends
from fastapi.responses import StreamingResponse, Response
from sqlalchemy.orm import Session
from sqlalchemy import select
from models import PublicStream
from database import get_db
from pathlib import Path
import asyncio
import os
import json
router = APIRouter()
DATA_ROOT = Path("./data")
@router.get("/streams-sse")
async def streams_sse(request: Request):
print(f"[SSE] New connection from {request.client.host}")
print(f"[SSE] Request headers: {dict(request.headers)}")
async def streams_sse(request: Request, db: Session = Depends(get_db)):
# Add CORS headers for SSE
origin = request.headers.get('origin', '')
allowed_origins = ["https://dicta2stream.net", "http://localhost:8000", "http://127.0.0.1:8000"]
@@ -32,7 +35,6 @@ async def streams_sse(request: Request):
# Handle preflight requests
if request.method == "OPTIONS":
print("[SSE] Handling OPTIONS preflight request")
headers.update({
"Access-Control-Allow-Methods": "GET, OPTIONS",
"Access-Control-Allow-Headers": request.headers.get("access-control-request-headers", "*"),
@@ -40,17 +42,16 @@ async def streams_sse(request: Request):
})
return Response(status_code=204, headers=headers)
print("[SSE] Starting SSE stream")
async def event_wrapper():
try:
async for event in list_streams_sse():
async for event in list_streams_sse(db):
yield event
except Exception as e:
print(f"[SSE] Error in event generator: {str(e)}")
import traceback
traceback.print_exc()
yield f"data: {json.dumps({'error': True, 'message': str(e)})}\n\n"
# Only log errors if DEBUG is enabled
if os.getenv("DEBUG") == "1":
import traceback
traceback.print_exc()
yield f"data: {json.dumps({'error': True, 'message': 'An error occurred'})}\n\n"
return StreamingResponse(
event_wrapper(),
@@ -58,75 +59,71 @@ async def streams_sse(request: Request):
headers=headers
)
import json
import datetime
async def list_streams_sse():
print("[SSE] Starting stream generator")
txt_path = Path("./public_streams.txt")
if not txt_path.exists():
print(f"[SSE] No public_streams.txt found")
yield f"data: {json.dumps({'end': True})}\n\n"
return
async def list_streams_sse(db):
"""Stream public streams from the database as Server-Sent Events"""
try:
# Send initial ping
print("[SSE] Sending initial ping")
yield ":ping\n\n"
# Read and send the file contents
with txt_path.open("r") as f:
for line in f:
line = line.strip()
if not line:
continue
# Query all public streams from the database
stmt = select(PublicStream).order_by(PublicStream.mtime.desc())
result = db.execute(stmt)
streams = result.scalars().all()
try:
# Parse the JSON to validate it
stream = json.loads(line)
print(f"[SSE] Sending stream data: {stream}")
if not streams:
yield f"data: {json.dumps({'end': True})}\n\n"
return
# Send the data as an SSE event
event = f"data: {json.dumps(stream)}\n\n"
yield event
# Send each stream as an SSE event
for stream in streams:
try:
stream_data = {
'uid': stream.uid,
'size': stream.size,
'mtime': stream.mtime,
'created_at': stream.created_at.isoformat() if stream.created_at else None,
'updated_at': stream.updated_at.isoformat() if stream.updated_at else None
}
yield f"data: {json.dumps(stream_data)}\n\n"
# Small delay to prevent overwhelming the client
await asyncio.sleep(0.1)
except Exception as e:
if os.getenv("DEBUG") == "1":
import traceback
traceback.print_exc()
continue
# Small delay to prevent overwhelming the client
await asyncio.sleep(0.1)
except json.JSONDecodeError as e:
print(f"[SSE] JSON decode error: {e} in line: {line}")
continue
except Exception as e:
print(f"[SSE] Error processing line: {e}")
continue
print("[SSE] Sending end event")
# Send end of stream marker
yield f"data: {json.dumps({'end': True})}\n\n"
except Exception as e:
print(f"[SSE] Error in stream generator: {str(e)}")
import traceback
traceback.print_exc()
if os.getenv("DEBUG") == "1":
import traceback
traceback.print_exc()
yield f"data: {json.dumps({'error': True, 'message': str(e)})}\n\n"
finally:
print("[SSE] Stream generator finished")
yield f"data: {json.dumps({'error': True, 'message': 'Stream generation failed'})}\n\n"
def list_streams():
txt_path = Path("./public_streams.txt")
if not txt_path.exists():
return {"streams": []}
def list_streams(db: Session = Depends(get_db)):
"""List all public streams from the database"""
try:
streams = []
with txt_path.open("r") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
streams.append(json.loads(line))
except Exception:
continue # skip malformed lines
return {"streams": streams}
except Exception:
stmt = select(PublicStream).order_by(PublicStream.mtime.desc())
result = db.execute(stmt)
streams = result.scalars().all()
return {
"streams": [
{
'uid': stream.uid,
'size': stream.size,
'mtime': stream.mtime,
'created_at': stream.created_at.isoformat() if stream.created_at else None,
'updated_at': stream.updated_at.isoformat() if stream.updated_at else None
}
for stream in streams
]
}
except Exception as e:
if os.getenv("DEBUG") == "1":
import traceback
traceback.print_exc()
return {"streams": []}
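A small consumer sketch for the SSE endpoint above, assuming the dev server from allowed_origins (http://localhost:8000) and the httpx package:

import json
import httpx

with httpx.stream("GET", "http://localhost:8000/streams-sse", timeout=None) as response:
    for line in response.iter_lines():
        if not line.startswith("data: "):
            continue  # skip ":ping" keep-alives and blank separator lines
        payload = json.loads(line[len("data: "):])
        if payload.get("end") or payload.get("error"):
            break
        print(payload["uid"], payload["size"], payload["mtime"])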

models.py

@@ -40,8 +40,45 @@ class DBSession(SQLModel, table=True):
last_activity: datetime = Field(default_factory=datetime.utcnow)
class PublicStream(SQLModel, table=True):
"""Stores public stream metadata for all users"""
uid: str = Field(primary_key=True)
size: int = 0
mtime: int = Field(default_factory=lambda: int(datetime.utcnow().timestamp()))
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
def get_user_by_uid(uid: str) -> Optional[User]:
with Session(engine) as session:
statement = select(User).where(User.username == uid)
result = session.exec(statement).first()
return result
def verify_session(db: Session, token: str) -> DBSession:
"""Verify a session token and return the session if valid"""
from datetime import datetime
# Find the session
session = db.exec(
select(DBSession)
.where(DBSession.token == token)
.where(DBSession.is_active == True) # noqa: E712
.where(DBSession.expires_at > datetime.utcnow())
).first()
if not session:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired session",
headers={"WWW-Authenticate": "Bearer"},
)
# Update last activity
session.last_activity = datetime.utcnow()
db.add(session)
db.commit()
db.refresh(session)
return session
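A minimal usage sketch for verify_session; the route path and Authorization header handling are assumptions, not part of this commit:

from fastapi import APIRouter, Depends, Header
from sqlalchemy.orm import Session
from database import get_db
from models import verify_session

router = APIRouter()

@router.get("/whoami")
def whoami(authorization: str = Header(...), db: Session = Depends(get_db)):
    # Expects "Bearer <token>"; verify_session raises 401 on an invalid or expired token
    token = authorization.removeprefix("Bearer ").strip()
    session = verify_session(db, token)
    return {"user_id": session.user_id, "last_activity": session.last_activity}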

upload.py

@@ -5,6 +5,8 @@ from slowapi import Limiter
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from pathlib import Path
import json
from datetime import datetime
from convert_to_opus import convert_to_opus
from models import UploadLog, UserQuota, User
from sqlalchemy import select
@@ -116,6 +118,9 @@ async def upload(request: Request, db = Depends(get_db), uid: str = Form(...), f
quota.storage_bytes += size
db.commit()
# Update public streams list
update_public_streams(uid, quota.storage_bytes, db)
return {
"filename": file.filename,
"original_size": round(original_size / 1024, 1),
@@ -135,3 +140,39 @@ async def upload(request: Request, db = Depends(get_db), uid: str = Form(...), f
except Exception:
pass
return {"detail": f"Server error: {type(e).__name__}: {str(e)}"}
def update_public_streams(uid: str, storage_bytes: int, db):
"""Update the public streams list in the database with the latest user upload info"""
try:
from models import PublicStream
# Get or create the public stream record
public_stream = db.get(PublicStream, uid)
current_time = datetime.utcnow()
if public_stream is None:
# Create a new record if it doesn't exist
public_stream = PublicStream(
uid=uid,
size=storage_bytes,
mtime=int(current_time.timestamp()),
created_at=current_time,
updated_at=current_time
)
db.add(public_stream)
else:
# Update existing record
public_stream.size = storage_bytes
public_stream.mtime = int(current_time.timestamp())
public_stream.updated_at = current_time
db.commit()
db.refresh(public_stream)
except Exception as e:
db.rollback()
import traceback
print(f"Error updating public streams in database: {e}")
print(traceback.format_exc())
raise
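The commit message also mentions fixing quota sync between the userquota and publicstream tables; a reconciliation sketch under the assumption that UserQuota is keyed by uid and exposes storage_bytes (as used in the upload handler above):

from datetime import datetime
from models import PublicStream, UserQuota

def sync_public_stream_size(db, uid: str) -> None:
    """Align publicstream.size with userquota.storage_bytes for one user."""
    quota = db.get(UserQuota, uid)
    stream = db.get(PublicStream, uid)
    if quota is None or stream is None:
        return
    if stream.size != quota.storage_bytes:
        stream.size = quota.storage_bytes
        stream.mtime = int(datetime.utcnow().timestamp())
        stream.updated_at = datetime.utcnow()
        db.add(stream)
        db.commit()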