Update 2025-04-13_16:25:39
This commit is contained in:
@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from starlette.authentication import (
|
||||
AuthCredentials,
|
||||
AuthenticationBackend,
|
||||
AuthenticationError,
|
||||
UnauthenticatedUser,
|
||||
)
|
||||
from starlette.requests import HTTPConnection
|
||||
from starlette.responses import PlainTextResponse, Response
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
|
||||
class AuthenticationMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
backend: AuthenticationBackend,
|
||||
on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] | None = None,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.backend = backend
|
||||
self.on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] = (
|
||||
on_error if on_error is not None else self.default_on_error
|
||||
)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] not in ["http", "websocket"]:
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
conn = HTTPConnection(scope)
|
||||
try:
|
||||
auth_result = await self.backend.authenticate(conn)
|
||||
except AuthenticationError as exc:
|
||||
response = self.on_error(conn, exc)
|
||||
if scope["type"] == "websocket":
|
||||
await send({"type": "websocket.close", "code": 1000})
|
||||
else:
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
if auth_result is None:
|
||||
auth_result = AuthCredentials(), UnauthenticatedUser()
|
||||
scope["auth"], scope["user"] = auth_result
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
@staticmethod
|
||||
def default_on_error(conn: HTTPConnection, exc: Exception) -> Response:
|
||||
return PlainTextResponse(str(exc), status_code=400)
|
Reference in New Issue
Block a user