96 lines
3.1 KiB
Python
96 lines
3.1 KiB
Python
"""Authentication middleware and utilities for dicta2stream"""
|
|
from fastapi import Request, HTTPException, Depends, status
|
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
from sqlmodel import Session, select
|
|
from typing import Optional
|
|
|
|
from models import User, Session as DBSession, verify_session
|
|
from database import get_db
|
|
|
|
# Shared bearer-token scheme used by get_current_user. auto_error=True is
# HTTPBearer's default, spelled out here: requests without a usable
# "Authorization: Bearer <token>" header are rejected by this dependency
# itself, before the dependent function is called.
security = HTTPBearer(auto_error=True)
|
|
|
|
def get_current_user(
    request: Request,
    credentials: HTTPAuthorizationCredentials = Depends(security)
) -> User:
    """Resolve the authenticated user for a request (FastAPI dependency).

    The bearer token is validated with ``verify_session``. The ``User`` row
    whose ``email`` equals the session's ``uid`` is then looked up.

    Args:
        request: Incoming request. On success its ``state.session`` is set to
            the verified session object.
        credentials: Bearer credentials supplied by the ``security`` scheme.

    Returns:
        The matching ``User``.

    Raises:
        HTTPException: 401 (with a ``WWW-Authenticate: Bearer`` header) when
            the token does not yield a session, or when no user matches it.
    """

    def _unauthorized(reason: str) -> HTTPException:
        # Both failure paths share the status code and challenge header.
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=reason,
            headers={"WWW-Authenticate": "Bearer"},
        )

    bearer_token = credentials.credentials

    with get_db() as db:
        active_session = verify_session(db, bearer_token)
        if not active_session:
            raise _unauthorized("Invalid or expired session")

        # NOTE(review): the lookup keys on ``active_session.uid``, whereas
        # create_session() constructs sessions with ``user_id=`` -- confirm
        # which attribute the Session model actually defines.
        account = (
            db.query(User)
            .filter(User.email == active_session.uid)
            .first()
        )
        if not account:
            raise _unauthorized("User not found")

        # Make the verified session available to downstream handlers.
        # NOTE(review): both objects outlive the get_db() context; whether
        # their lazy attributes stay loadable depends on get_db -- verify.
        request.state.session = active_session
        return account
|
|
|
|
|
|
def get_optional_user(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(
        HTTPBearer(auto_error=False), use_cache=False
    ),
) -> Optional[User]:
    """Return the authenticated user, or ``None`` for anonymous requests.

    This uses its own ``HTTPBearer(auto_error=False)`` instance instead of
    the module-level ``security`` scheme. That scheme uses the default
    ``auto_error=True``, which rejects a request with no bearer header
    before this function runs. That would make the ``None`` path
    unreachable and turn an "optional" dependency into a mandatory one.
    With ``auto_error=False``, a missing header arrives here as ``None``.

    Args:
        request: Incoming request, forwarded to ``get_current_user``.
        credentials: Bearer credentials, or ``None`` when the request has no
            usable ``Authorization: Bearer`` header.

    Returns:
        The ``User`` when the token resolves to a valid session, otherwise
        ``None``.
    """
    if credentials is None:
        return None

    try:
        # get_current_user opens and closes its own database session.
        return get_current_user(request, credentials)
    except HTTPException:
        # Invalid/expired token or unknown user: treat as anonymous.
        return None
|
|
|
|
|
|
def create_session(user: User, request: Request) -> DBSession:
    """Create, persist and return a new session for ``user`` (valid 24 hours).

    Args:
        user: Account the session belongs to. Its ``email`` is stored as the
            session's ``user_id``.
        request: Incoming request. Supplies the ``User-Agent`` header and the
            client host recorded on the session.

    Returns:
        The committed and refreshed ``DBSession`` row.

    Raises:
        HTTPException: 500 if persisting the session fails. The database
            error is logged and chained as the exception's ``__cause__``.
    """
    import logging
    import secrets
    from datetime import datetime, timedelta, timezone

    user_agent = request.headers.get("user-agent", "")
    # request.client may be None; fall back to a placeholder address.
    ip_address = request.client.host if request.client else "0.0.0.0"

    # Unpredictable token from the `secrets` CSPRNG (32 random bytes).
    session_token = secrets.token_urlsafe(32)
    # Naive UTC timestamp -- the same value datetime.utcnow() produces, but
    # without the deprecated call (deprecated since Python 3.12).
    expires_at = datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=24)

    new_session = DBSession(
        token=session_token,
        user_id=user.email,
        ip_address=ip_address,
        user_agent=user_agent,
        expires_at=expires_at,
        is_active=True,
    )

    with get_db() as db:
        try:
            db.add(new_session)
            db.commit()
            db.refresh(new_session)  # reload state from the database
        except Exception as exc:
            db.rollback()
            # Record the real cause server-side; the client only sees a 500.
            logging.getLogger(__name__).exception("Failed to create session")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to create session",
            ) from exc
        return new_session
|