Migrate from file-based to database-backed stream metadata storage

- Add PublicStream model and migration
- Update list_streams.py and upload.py to use database
- Add import script for data migration
- Remove public_streams.txt (replaced by database)
- Fix quota sync between userquota and publicstream tables
alembic/versions/0df481ee920b_add_publicstream_model.py (new file, 71 lines)
@@ -0,0 +1,71 @@
"""Add PublicStream model

Revision ID: 0df481ee920b
Revises: f86c93c7a872
Create Date: 2025-07-19 10:02:22.902696

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '0df481ee920b'
down_revision: Union[str, Sequence[str], None] = 'f86c93c7a872'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    # First create the new publicstream table
    op.create_table('publicstream',
        sa.Column('uid', sa.String(), nullable=False),
        sa.Column('size', sa.Integer(), nullable=False),
        sa.Column('mtime', sa.Integer(), nullable=False),
        sa.Column('created_at', sa.DateTime(), nullable=False),
        sa.Column('updated_at', sa.DateTime(), nullable=False),
        sa.PrimaryKeyConstraint('uid')
    )

    # Drop the foreign key constraint first
    op.drop_constraint('dbsession_user_id_fkey', 'dbsession', type_='foreignkey')

    # Then drop the unique constraint
    op.drop_constraint(op.f('uq_user_username'), 'user', type_='unique')

    # Create the new index
    op.create_index(op.f('ix_user_username'), 'user', ['username'], unique=True)

    # Recreate the foreign key constraint
    op.create_foreign_key(
        'dbsession_user_id_fkey', 'dbsession', 'user',
        ['user_id'], ['username'], ondelete='CASCADE'
    )
    # ### end Alembic commands ###


def downgrade() -> None:
    """Downgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Drop the foreign key constraint first
    op.drop_constraint('dbsession_user_id_fkey', 'dbsession', type_='foreignkey')

    # Drop the index
    op.drop_index(op.f('ix_user_username'), table_name='user')

    # Recreate the unique constraint
    op.create_unique_constraint(op.f('uq_user_username'), 'user', ['username'])

    # Recreate the foreign key constraint
    op.create_foreign_key(
        'dbsession_user_id_fkey', 'dbsession', 'user',
        ['user_id'], ['username'], ondelete='CASCADE'
    )

    # Drop the publicstream table
    op.drop_table('publicstream')
    # ### end Alembic commands ###
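To exercise this revision outside the alembic CLI, the command API can drive the same upgrade/downgrade pair; a minimal sketch, assuming a standard alembic.ini at the project root:

    from alembic import command
    from alembic.config import Config

    cfg = Config("alembic.ini")  # assumption: default Alembic project layout
    command.upgrade(cfg, "0df481ee920b")    # creates publicstream, swaps the username constraint for a unique index
    command.downgrade(cfg, "f86c93c7a872")  # reverts to the prior revision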
import_streams.py (new file, 94 lines)
@@ -0,0 +1,94 @@
#!/usr/bin/env python3
"""
Script to import stream data from backup file into the publicstream table.
"""
import json
from datetime import datetime
from pathlib import Path
from sqlmodel import Session, select

from models import PublicStream
from database import engine


def import_streams_from_backup(backup_file: str):
    """Import stream data from backup file into the database."""
    with Session(engine) as session:
        try:
            # Read the backup file
            with open(backup_file, 'r') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue

                    try:
                        # Parse the JSON data
                        stream_data = json.loads(line)
                        uid = stream_data.get('uid')
                        size = stream_data.get('size', 0)
                        mtime = stream_data.get('mtime', int(datetime.now().timestamp()))

                        if not uid:
                            print(f"Skipping invalid entry (missing uid): {line}")
                            continue

                        # Check if the stream already exists
                        existing = session.exec(
                            select(PublicStream).where(PublicStream.uid == uid)
                        ).first()

                        now = datetime.utcnow()

                        if existing:
                            # Update existing record
                            existing.size = size
                            existing.mtime = mtime
                            existing.updated_at = now
                            session.add(existing)
                            print(f"Updated stream: {uid}")
                        else:
                            # Create new record
                            stream = PublicStream(
                                uid=uid,
                                size=size,
                                mtime=mtime,
                                created_at=now,
                                updated_at=now
                            )
                            session.add(stream)
                            print(f"Added stream: {uid}")

                        # Commit after each record to ensure data integrity
                        session.commit()

                    except json.JSONDecodeError as e:
                        print(f"Error parsing line: {line}")
                        print(f"Error: {e}")
                        session.rollback()
                    except Exception as e:
                        print(f"Error processing line: {line}")
                        print(f"Error: {e}")
                        session.rollback()

            print("Import completed successfully!")

        except Exception as e:
            session.rollback()
            print(f"Error during import: {e}")
            raise


if __name__ == "__main__":
    backup_file = "public_streams.txt.backup"
    if not Path(backup_file).exists():
        print(f"Error: Backup file '{backup_file}' not found.")
        exit(1)

    print(f"Starting import from {backup_file}...")
    import_streams_from_backup(backup_file)
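For reference, the script expects one JSON object per line with uid, size, and mtime keys (matching the .get() calls above); a usage sketch with made-up sample data:

    from import_streams import import_streams_from_backup

    # Sample backup contents (hypothetical uids and sizes):
    sample = (
        '{"uid": "alice", "size": 1048576, "mtime": 1752912000}\n'
        '{"uid": "bob", "size": 524288, "mtime": 1752825600}\n'
    )
    with open("public_streams.txt.backup", "w") as f:
        f.write(sample)

    import_streams_from_backup("public_streams.txt.backup")  # upserts one row per line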
list_streams.py (143 lines changed)
@@ -1,18 +1,21 @@
 # list_streams.py — FastAPI route to list all public streams (users with stream.opus)
 
-from fastapi import APIRouter, Request
+from fastapi import APIRouter, Request, Depends
 from fastapi.responses import StreamingResponse, Response
+from sqlalchemy.orm import Session
+from sqlalchemy import select
+from models import PublicStream
+from database import get_db
 from pathlib import Path
 import asyncio
+import os
+import json
 
 router = APIRouter()
 DATA_ROOT = Path("./data")
 
 @router.get("/streams-sse")
-async def streams_sse(request: Request):
-    print(f"[SSE] New connection from {request.client.host}")
-    print(f"[SSE] Request headers: {dict(request.headers)}")
+async def streams_sse(request: Request, db: Session = Depends(get_db)):
 
     # Add CORS headers for SSE
     origin = request.headers.get('origin', '')
     allowed_origins = ["https://dicta2stream.net", "http://localhost:8000", "http://127.0.0.1:8000"]
@@ -32,7 +35,6 @@ async def streams_sse(request: Request):
 
     # Handle preflight requests
     if request.method == "OPTIONS":
-        print("[SSE] Handling OPTIONS preflight request")
         headers.update({
             "Access-Control-Allow-Methods": "GET, OPTIONS",
             "Access-Control-Allow-Headers": request.headers.get("access-control-request-headers", "*"),
@@ -40,17 +42,16 @@ async def streams_sse(request: Request):
         })
         return Response(status_code=204, headers=headers)
 
-    print("[SSE] Starting SSE stream")
 
     async def event_wrapper():
         try:
-            async for event in list_streams_sse():
+            async for event in list_streams_sse(db):
                 yield event
         except Exception as e:
-            print(f"[SSE] Error in event generator: {str(e)}")
-            import traceback
-            traceback.print_exc()
-            yield f"data: {json.dumps({'error': True, 'message': str(e)})}\n\n"
+            # Only log errors if DEBUG is enabled
+            if os.getenv("DEBUG") == "1":
+                import traceback
+                traceback.print_exc()
+            yield f"data: {json.dumps({'error': True, 'message': 'An error occurred'})}\n\n"
 
     return StreamingResponse(
         event_wrapper(),
@@ -58,75 +59,71 @@ async def streams_sse(request: Request):
         headers=headers
     )
 
-import json
-import datetime
-
-async def list_streams_sse():
-    print("[SSE] Starting stream generator")
-    txt_path = Path("./public_streams.txt")
-
-    if not txt_path.exists():
-        print(f"[SSE] No public_streams.txt found")
-        yield f"data: {json.dumps({'end': True})}\n\n"
-        return
+async def list_streams_sse(db):
+    """Stream public streams from the database as Server-Sent Events"""
 
     try:
         # Send initial ping
-        print("[SSE] Sending initial ping")
         yield ":ping\n\n"
 
-        # Read and send the file contents
-        with txt_path.open("r") as f:
-            for line in f:
-                line = line.strip()
-                if not line:
-                    continue
-
-                try:
-                    # Parse the JSON to validate it
-                    stream = json.loads(line)
-                    print(f"[SSE] Sending stream data: {stream}")
-
-                    # Send the data as an SSE event
-                    event = f"data: {json.dumps(stream)}\n\n"
-                    yield event
-
-                    # Small delay to prevent overwhelming the client
-                    await asyncio.sleep(0.1)
-
-                except json.JSONDecodeError as e:
-                    print(f"[SSE] JSON decode error: {e} in line: {line}")
-                    continue
-                except Exception as e:
-                    print(f"[SSE] Error processing line: {e}")
-                    continue
-
-        print("[SSE] Sending end event")
+        # Query all public streams from the database
+        stmt = select(PublicStream).order_by(PublicStream.mtime.desc())
+        result = db.execute(stmt)
+        streams = result.scalars().all()
+
+        if not streams:
+            yield f"data: {json.dumps({'end': True})}\n\n"
+            return
+
+        # Send each stream as an SSE event
+        for stream in streams:
+            try:
+                stream_data = {
+                    'uid': stream.uid,
+                    'size': stream.size,
+                    'mtime': stream.mtime,
+                    'created_at': stream.created_at.isoformat() if stream.created_at else None,
+                    'updated_at': stream.updated_at.isoformat() if stream.updated_at else None
+                }
+                yield f"data: {json.dumps(stream_data)}\n\n"
+
+                # Small delay to prevent overwhelming the client
+                await asyncio.sleep(0.1)
+            except Exception as e:
+                if os.getenv("DEBUG") == "1":
+                    import traceback
+                    traceback.print_exc()
+                continue
+
+        # Send end of stream marker
         yield f"data: {json.dumps({'end': True})}\n\n"
 
     except Exception as e:
-        print(f"[SSE] Error in stream generator: {str(e)}")
-        import traceback
-        traceback.print_exc()
-        yield f"data: {json.dumps({'error': True, 'message': str(e)})}\n\n"
-    finally:
-        print("[SSE] Stream generator finished")
+        if os.getenv("DEBUG") == "1":
+            import traceback
+            traceback.print_exc()
+        yield f"data: {json.dumps({'error': True, 'message': 'Stream generation failed'})}\n\n"
 
 
-def list_streams():
-    txt_path = Path("./public_streams.txt")
-    if not txt_path.exists():
-        return {"streams": []}
+def list_streams(db: Session = Depends(get_db)):
+    """List all public streams from the database"""
     try:
-        streams = []
-        with txt_path.open("r") as f:
-            for line in f:
-                line = line.strip()
-                if not line:
-                    continue
-                try:
-                    streams.append(json.loads(line))
-                except Exception:
-                    continue  # skip malformed lines
-        return {"streams": streams}
-    except Exception:
+        stmt = select(PublicStream).order_by(PublicStream.mtime.desc())
+        result = db.execute(stmt)
+        streams = result.scalars().all()
+
+        return {
+            "streams": [
+                {
+                    'uid': stream.uid,
+                    'size': stream.size,
+                    'mtime': stream.mtime,
+                    'created_at': stream.created_at.isoformat() if stream.created_at else None,
+                    'updated_at': stream.updated_at.isoformat() if stream.updated_at else None
+                }
+                for stream in streams
+            ]
+        }
+    except Exception as e:
+        if os.getenv("DEBUG") == "1":
+            import traceback
+            traceback.print_exc()
         return {"streams": []}
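A minimal client sketch for the /streams-sse endpoint, not part of this commit; it assumes the app is reachable on localhost:8000 and uses httpx, though any client that can stream lines works:

    import json
    import httpx

    with httpx.stream("GET", "http://localhost:8000/streams-sse") as response:
        for line in response.iter_lines():
            if not line.startswith("data: "):
                continue  # skip ':ping' keep-alives and blank separators
            event = json.loads(line[len("data: "):])
            if event.get("end") or event.get("error"):
                break  # server signals completion or failure
            print(event["uid"], event["size"], event["mtime"])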
models.py (37 lines changed)
@@ -40,8 +40,45 @@ class DBSession(SQLModel, table=True):
     last_activity: datetime = Field(default_factory=datetime.utcnow)
 
 
+class PublicStream(SQLModel, table=True):
+    """Stores public stream metadata for all users"""
+    uid: str = Field(primary_key=True)
+    size: int = 0
+    mtime: int = Field(default_factory=lambda: int(datetime.utcnow().timestamp()))
+    created_at: datetime = Field(default_factory=datetime.utcnow)
+    updated_at: datetime = Field(default_factory=datetime.utcnow)
+
+
 def get_user_by_uid(uid: str) -> Optional[User]:
     with Session(engine) as session:
         statement = select(User).where(User.username == uid)
         result = session.exec(statement).first()
         return result
+
+
+def verify_session(db: Session, token: str) -> DBSession:
+    """Verify a session token and return the session if valid"""
+    from datetime import datetime
+
+    # Find the session
+    session = db.exec(
+        select(DBSession)
+        .where(DBSession.token == token)
+        .where(DBSession.is_active == True)  # noqa: E712
+        .where(DBSession.expires_at > datetime.utcnow())
+    ).first()
+
+    if not session:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Invalid or expired session",
+            headers={"WWW-Authenticate": "Bearer"},
+        )
+
+    # Update last activity
+    session.last_activity = datetime.utcnow()
+    db.add(session)
+    db.commit()
+    db.refresh(session)
+
+    return session
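verify_session is not wired up anywhere in this diff; one plausible way to use it is as a FastAPI dependency. A sketch under the assumption that clients send a Bearer token in the Authorization header (the require_session helper is hypothetical):

    from fastapi import Depends, Header
    from sqlmodel import Session

    from database import get_db
    from models import DBSession, verify_session

    def require_session(
        authorization: str = Header(...),      # assumed header convention: "Bearer <token>"
        db: Session = Depends(get_db),
    ) -> DBSession:
        token = authorization.removeprefix("Bearer ").strip()
        return verify_session(db, token)       # raises HTTP 401 if invalid or expired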
upload.py (41 lines changed)
@@ -5,6 +5,8 @@ from slowapi import Limiter
 from slowapi.util import get_remote_address
 from slowapi.errors import RateLimitExceeded
 from pathlib import Path
+import json
+from datetime import datetime
 from convert_to_opus import convert_to_opus
 from models import UploadLog, UserQuota, User
 from sqlalchemy import select
@@ -115,6 +117,9 @@ async def upload(request: Request, db = Depends(get_db), uid: str = Form(...), f
         db.add(quota)
         quota.storage_bytes += size
         db.commit()
+
+        # Update public streams list
+        update_public_streams(uid, quota.storage_bytes, db)
 
         return {
             "filename": file.filename,
@@ -135,3 +140,39 @@ async def upload(request: Request, db = Depends(get_db), uid: str = Form(...), f
         except Exception:
             pass
         return {"detail": f"Server error: {type(e).__name__}: {str(e)}"}
+
+
+def update_public_streams(uid: str, storage_bytes: int, db):
+    """Update the public streams list in the database with the latest user upload info"""
+    try:
+        from models import PublicStream
+
+        # Get or create the public stream record
+        public_stream = db.get(PublicStream, uid)
+        current_time = datetime.utcnow()
+
+        if public_stream is None:
+            # Create a new record if it doesn't exist
+            public_stream = PublicStream(
+                uid=uid,
+                size=storage_bytes,
+                mtime=int(current_time.timestamp()),
+                created_at=current_time,
+                updated_at=current_time
+            )
+            db.add(public_stream)
+        else:
+            # Update existing record
+            public_stream.size = storage_bytes
+            public_stream.mtime = int(current_time.timestamp())
+            public_stream.updated_at = current_time
+
+        db.commit()
+        db.refresh(public_stream)
+
+    except Exception as e:
+        db.rollback()
+        import traceback
+        print(f"Error updating public streams in database: {e}")
+        print(traceback.format_exc())
+        raise
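The commit message mentions fixing quota sync between the userquota and publicstream tables, but no bulk reconciliation pass appears in this diff; a hedged one-off sketch, assuming UserQuota is keyed by uid and carries the storage_bytes column used in upload.py:

    from sqlmodel import Session, select

    from database import engine
    from models import PublicStream, UserQuota  # assumption: UserQuota has uid and storage_bytes

    def resync_public_stream_sizes() -> None:
        # Make publicstream.size agree with userquota.storage_bytes for every user.
        with Session(engine) as session:
            for quota in session.exec(select(UserQuota)).all():
                stream = session.get(PublicStream, quota.uid)
                if stream is not None and stream.size != quota.storage_bytes:
                    stream.size = quota.storage_bytes
                    session.add(stream)
            session.commit()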