Update 2025-04-13_16:26:34
This commit is contained in:
1
venv/lib/python3.11/site-packages/starlette/__init__.py
Normal file
1
venv/lib/python3.11/site-packages/starlette/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
__version__ = "0.46.1"
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,65 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from starlette._utils import is_async_callable
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
ExceptionHandlers = dict[typing.Any, ExceptionHandler]
|
||||
StatusHandlers = dict[int, ExceptionHandler]
|
||||
|
||||
|
||||
def _lookup_exception_handler(exc_handlers: ExceptionHandlers, exc: Exception) -> ExceptionHandler | None:
|
||||
for cls in type(exc).__mro__:
|
||||
if cls in exc_handlers:
|
||||
return exc_handlers[cls]
|
||||
return None
|
||||
|
||||
|
||||
def wrap_app_handling_exceptions(app: ASGIApp, conn: Request | WebSocket) -> ASGIApp:
|
||||
exception_handlers: ExceptionHandlers
|
||||
status_handlers: StatusHandlers
|
||||
try:
|
||||
exception_handlers, status_handlers = conn.scope["starlette.exception_handlers"]
|
||||
except KeyError:
|
||||
exception_handlers, status_handlers = {}, {}
|
||||
|
||||
async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
response_started = False
|
||||
|
||||
async def sender(message: Message) -> None:
|
||||
nonlocal response_started
|
||||
|
||||
if message["type"] == "http.response.start":
|
||||
response_started = True
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
await app(scope, receive, sender)
|
||||
except Exception as exc:
|
||||
handler = None
|
||||
|
||||
if isinstance(exc, HTTPException):
|
||||
handler = status_handlers.get(exc.status_code)
|
||||
|
||||
if handler is None:
|
||||
handler = _lookup_exception_handler(exception_handlers, exc)
|
||||
|
||||
if handler is None:
|
||||
raise exc
|
||||
|
||||
if response_started:
|
||||
raise RuntimeError("Caught handled exception, but response already started.") from exc
|
||||
|
||||
if is_async_callable(handler):
|
||||
response = await handler(conn, exc)
|
||||
else:
|
||||
response = await run_in_threadpool(handler, conn, exc) # type: ignore
|
||||
if response is not None:
|
||||
await response(scope, receive, sender)
|
||||
|
||||
return wrapped_app
|
100
venv/lib/python3.11/site-packages/starlette/_utils.py
Normal file
100
venv/lib/python3.11/site-packages/starlette/_utils.py
Normal file
@ -0,0 +1,100 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import sys
|
||||
import typing
|
||||
from contextlib import contextmanager
|
||||
|
||||
from starlette.types import Scope
|
||||
|
||||
if sys.version_info >= (3, 10): # pragma: no cover
|
||||
from typing import TypeGuard
|
||||
else: # pragma: no cover
|
||||
from typing_extensions import TypeGuard
|
||||
|
||||
has_exceptiongroups = True
|
||||
if sys.version_info < (3, 11): # pragma: no cover
|
||||
try:
|
||||
from exceptiongroup import BaseExceptionGroup # type: ignore[unused-ignore,import-not-found]
|
||||
except ImportError:
|
||||
has_exceptiongroups = False
|
||||
|
||||
T = typing.TypeVar("T")
|
||||
AwaitableCallable = typing.Callable[..., typing.Awaitable[T]]
|
||||
|
||||
|
||||
@typing.overload
|
||||
def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]: ...
|
||||
|
||||
|
||||
@typing.overload
|
||||
def is_async_callable(obj: typing.Any) -> TypeGuard[AwaitableCallable[typing.Any]]: ...
|
||||
|
||||
|
||||
def is_async_callable(obj: typing.Any) -> typing.Any:
|
||||
while isinstance(obj, functools.partial):
|
||||
obj = obj.func
|
||||
|
||||
return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__))
|
||||
|
||||
|
||||
T_co = typing.TypeVar("T_co", covariant=True)
|
||||
|
||||
|
||||
class AwaitableOrContextManager(typing.Awaitable[T_co], typing.AsyncContextManager[T_co], typing.Protocol[T_co]): ...
|
||||
|
||||
|
||||
class SupportsAsyncClose(typing.Protocol):
|
||||
async def close(self) -> None: ... # pragma: no cover
|
||||
|
||||
|
||||
SupportsAsyncCloseType = typing.TypeVar("SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False)
|
||||
|
||||
|
||||
class AwaitableOrContextManagerWrapper(typing.Generic[SupportsAsyncCloseType]):
|
||||
__slots__ = ("aw", "entered")
|
||||
|
||||
def __init__(self, aw: typing.Awaitable[SupportsAsyncCloseType]) -> None:
|
||||
self.aw = aw
|
||||
|
||||
def __await__(self) -> typing.Generator[typing.Any, None, SupportsAsyncCloseType]:
|
||||
return self.aw.__await__()
|
||||
|
||||
async def __aenter__(self) -> SupportsAsyncCloseType:
|
||||
self.entered = await self.aw
|
||||
return self.entered
|
||||
|
||||
async def __aexit__(self, *args: typing.Any) -> None | bool:
|
||||
await self.entered.close()
|
||||
return None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def collapse_excgroups() -> typing.Generator[None, None, None]:
|
||||
try:
|
||||
yield
|
||||
except BaseException as exc:
|
||||
if has_exceptiongroups: # pragma: no cover
|
||||
while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1:
|
||||
exc = exc.exceptions[0]
|
||||
|
||||
raise exc
|
||||
|
||||
|
||||
def get_route_path(scope: Scope) -> str:
|
||||
path: str = scope["path"]
|
||||
root_path = scope.get("root_path", "")
|
||||
if not root_path:
|
||||
return path
|
||||
|
||||
if not path.startswith(root_path):
|
||||
return path
|
||||
|
||||
if path == root_path:
|
||||
return ""
|
||||
|
||||
if path[len(root_path)] == "/":
|
||||
return path[len(root_path) :]
|
||||
|
||||
return path
|
249
venv/lib/python3.11/site-packages/starlette/applications.py
Normal file
249
venv/lib/python3.11/site-packages/starlette/applications.py
Normal file
@ -0,0 +1,249 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
|
||||
if sys.version_info >= (3, 10): # pragma: no cover
|
||||
from typing import ParamSpec
|
||||
else: # pragma: no cover
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from starlette.datastructures import State, URLPath
|
||||
from starlette.middleware import Middleware, _MiddlewareFactory
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.errors import ServerErrorMiddleware
|
||||
from starlette.middleware.exceptions import ExceptionMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.routing import BaseRoute, Router
|
||||
from starlette.types import ASGIApp, ExceptionHandler, Lifespan, Receive, Scope, Send
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
AppType = typing.TypeVar("AppType", bound="Starlette")
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
class Starlette:
|
||||
"""Creates an Starlette application."""
|
||||
|
||||
def __init__(
|
||||
self: AppType,
|
||||
debug: bool = False,
|
||||
routes: typing.Sequence[BaseRoute] | None = None,
|
||||
middleware: typing.Sequence[Middleware] | None = None,
|
||||
exception_handlers: typing.Mapping[typing.Any, ExceptionHandler] | None = None,
|
||||
on_startup: typing.Sequence[typing.Callable[[], typing.Any]] | None = None,
|
||||
on_shutdown: typing.Sequence[typing.Callable[[], typing.Any]] | None = None,
|
||||
lifespan: Lifespan[AppType] | None = None,
|
||||
) -> None:
|
||||
"""Initializes the application.
|
||||
|
||||
Parameters:
|
||||
debug: Boolean indicating if debug tracebacks should be returned on errors.
|
||||
routes: A list of routes to serve incoming HTTP and WebSocket requests.
|
||||
middleware: A list of middleware to run for every request. A starlette
|
||||
application will always automatically include two middleware classes.
|
||||
`ServerErrorMiddleware` is added as the very outermost middleware, to handle
|
||||
any uncaught errors occurring anywhere in the entire stack.
|
||||
`ExceptionMiddleware` is added as the very innermost middleware, to deal
|
||||
with handled exception cases occurring in the routing or endpoints.
|
||||
exception_handlers: A mapping of either integer status codes,
|
||||
or exception class types onto callables which handle the exceptions.
|
||||
Exception handler callables should be of the form
|
||||
`handler(request, exc) -> response` and may be either standard functions, or
|
||||
async functions.
|
||||
on_startup: A list of callables to run on application startup.
|
||||
Startup handler callables do not take any arguments, and may be either
|
||||
standard functions, or async functions.
|
||||
on_shutdown: A list of callables to run on application shutdown.
|
||||
Shutdown handler callables do not take any arguments, and may be either
|
||||
standard functions, or async functions.
|
||||
lifespan: A lifespan context function, which can be used to perform
|
||||
startup and shutdown tasks. This is a newer style that replaces the
|
||||
`on_startup` and `on_shutdown` handlers. Use one or the other, not both.
|
||||
"""
|
||||
# The lifespan context function is a newer style that replaces
|
||||
# on_startup / on_shutdown handlers. Use one or the other, not both.
|
||||
assert lifespan is None or (on_startup is None and on_shutdown is None), (
|
||||
"Use either 'lifespan' or 'on_startup'/'on_shutdown', not both."
|
||||
)
|
||||
|
||||
self.debug = debug
|
||||
self.state = State()
|
||||
self.router = Router(routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan)
|
||||
self.exception_handlers = {} if exception_handlers is None else dict(exception_handlers)
|
||||
self.user_middleware = [] if middleware is None else list(middleware)
|
||||
self.middleware_stack: ASGIApp | None = None
|
||||
|
||||
def build_middleware_stack(self) -> ASGIApp:
|
||||
debug = self.debug
|
||||
error_handler = None
|
||||
exception_handlers: dict[typing.Any, typing.Callable[[Request, Exception], Response]] = {}
|
||||
|
||||
for key, value in self.exception_handlers.items():
|
||||
if key in (500, Exception):
|
||||
error_handler = value
|
||||
else:
|
||||
exception_handlers[key] = value
|
||||
|
||||
middleware = (
|
||||
[Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
|
||||
+ self.user_middleware
|
||||
+ [Middleware(ExceptionMiddleware, handlers=exception_handlers, debug=debug)]
|
||||
)
|
||||
|
||||
app = self.router
|
||||
for cls, args, kwargs in reversed(middleware):
|
||||
app = cls(app, *args, **kwargs)
|
||||
return app
|
||||
|
||||
@property
|
||||
def routes(self) -> list[BaseRoute]:
|
||||
return self.router.routes
|
||||
|
||||
def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
|
||||
return self.router.url_path_for(name, **path_params)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
scope["app"] = self
|
||||
if self.middleware_stack is None:
|
||||
self.middleware_stack = self.build_middleware_stack()
|
||||
await self.middleware_stack(scope, receive, send)
|
||||
|
||||
def on_event(self, event_type: str) -> typing.Callable: # type: ignore[type-arg]
|
||||
return self.router.on_event(event_type) # pragma: no cover
|
||||
|
||||
def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None:
|
||||
self.router.mount(path, app=app, name=name) # pragma: no cover
|
||||
|
||||
def host(self, host: str, app: ASGIApp, name: str | None = None) -> None:
|
||||
self.router.host(host, app=app, name=name) # pragma: no cover
|
||||
|
||||
def add_middleware(
|
||||
self,
|
||||
middleware_class: _MiddlewareFactory[P],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
if self.middleware_stack is not None: # pragma: no cover
|
||||
raise RuntimeError("Cannot add middleware after an application has started")
|
||||
self.user_middleware.insert(0, Middleware(middleware_class, *args, **kwargs))
|
||||
|
||||
def add_exception_handler(
|
||||
self,
|
||||
exc_class_or_status_code: int | type[Exception],
|
||||
handler: ExceptionHandler,
|
||||
) -> None: # pragma: no cover
|
||||
self.exception_handlers[exc_class_or_status_code] = handler
|
||||
|
||||
def add_event_handler(
|
||||
self,
|
||||
event_type: str,
|
||||
func: typing.Callable, # type: ignore[type-arg]
|
||||
) -> None: # pragma: no cover
|
||||
self.router.add_event_handler(event_type, func)
|
||||
|
||||
def add_route(
|
||||
self,
|
||||
path: str,
|
||||
route: typing.Callable[[Request], typing.Awaitable[Response] | Response],
|
||||
methods: list[str] | None = None,
|
||||
name: str | None = None,
|
||||
include_in_schema: bool = True,
|
||||
) -> None: # pragma: no cover
|
||||
self.router.add_route(path, route, methods=methods, name=name, include_in_schema=include_in_schema)
|
||||
|
||||
def add_websocket_route(
|
||||
self,
|
||||
path: str,
|
||||
route: typing.Callable[[WebSocket], typing.Awaitable[None]],
|
||||
name: str | None = None,
|
||||
) -> None: # pragma: no cover
|
||||
self.router.add_websocket_route(path, route, name=name)
|
||||
|
||||
def exception_handler(self, exc_class_or_status_code: int | type[Exception]) -> typing.Callable: # type: ignore[type-arg]
|
||||
warnings.warn(
|
||||
"The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0. "
|
||||
"Refer to https://www.starlette.io/exceptions/ for the recommended approach.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
|
||||
self.add_exception_handler(exc_class_or_status_code, func)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def route(
|
||||
self,
|
||||
path: str,
|
||||
methods: list[str] | None = None,
|
||||
name: str | None = None,
|
||||
include_in_schema: bool = True,
|
||||
) -> typing.Callable: # type: ignore[type-arg]
|
||||
"""
|
||||
We no longer document this decorator style API, and its usage is discouraged.
|
||||
Instead you should use the following approach:
|
||||
|
||||
>>> routes = [Route(path, endpoint=...), ...]
|
||||
>>> app = Starlette(routes=routes)
|
||||
"""
|
||||
warnings.warn(
|
||||
"The `route` decorator is deprecated, and will be removed in version 1.0.0. "
|
||||
"Refer to https://www.starlette.io/routing/ for the recommended approach.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
|
||||
self.router.add_route(
|
||||
path,
|
||||
func,
|
||||
methods=methods,
|
||||
name=name,
|
||||
include_in_schema=include_in_schema,
|
||||
)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def websocket_route(self, path: str, name: str | None = None) -> typing.Callable: # type: ignore[type-arg]
|
||||
"""
|
||||
We no longer document this decorator style API, and its usage is discouraged.
|
||||
Instead you should use the following approach:
|
||||
|
||||
>>> routes = [WebSocketRoute(path, endpoint=...), ...]
|
||||
>>> app = Starlette(routes=routes)
|
||||
"""
|
||||
warnings.warn(
|
||||
"The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. "
|
||||
"Refer to https://www.starlette.io/routing/#websocket-routing for the recommended approach.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
|
||||
self.router.add_websocket_route(path, func, name=name)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def middleware(self, middleware_type: str) -> typing.Callable: # type: ignore[type-arg]
|
||||
"""
|
||||
We no longer document this decorator style API, and its usage is discouraged.
|
||||
Instead you should use the following approach:
|
||||
|
||||
>>> middleware = [Middleware(...), ...]
|
||||
>>> app = Starlette(middleware=middleware)
|
||||
"""
|
||||
warnings.warn(
|
||||
"The `middleware` decorator is deprecated, and will be removed in version 1.0.0. "
|
||||
"Refer to https://www.starlette.io/middleware/#using-middleware for recommended approach.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
assert middleware_type == "http", 'Currently only middleware("http") is supported.'
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
|
||||
self.add_middleware(BaseHTTPMiddleware, dispatch=func)
|
||||
return func
|
||||
|
||||
return decorator
|
147
venv/lib/python3.11/site-packages/starlette/authentication.py
Normal file
147
venv/lib/python3.11/site-packages/starlette/authentication.py
Normal file
@ -0,0 +1,147 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import sys
|
||||
import typing
|
||||
from urllib.parse import urlencode
|
||||
|
||||
if sys.version_info >= (3, 10): # pragma: no cover
|
||||
from typing import ParamSpec
|
||||
else: # pragma: no cover
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from starlette._utils import is_async_callable
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import HTTPConnection, Request
|
||||
from starlette.responses import RedirectResponse
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
|
||||
def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
|
||||
for scope in scopes:
|
||||
if scope not in conn.auth.scopes:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def requires(
|
||||
scopes: str | typing.Sequence[str],
|
||||
status_code: int = 403,
|
||||
redirect: str | None = None,
|
||||
) -> typing.Callable[[typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any]]:
|
||||
scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)
|
||||
|
||||
def decorator(
|
||||
func: typing.Callable[_P, typing.Any],
|
||||
) -> typing.Callable[_P, typing.Any]:
|
||||
sig = inspect.signature(func)
|
||||
for idx, parameter in enumerate(sig.parameters.values()):
|
||||
if parameter.name == "request" or parameter.name == "websocket":
|
||||
type_ = parameter.name
|
||||
break
|
||||
else:
|
||||
raise Exception(f'No "request" or "websocket" argument on function "{func}"')
|
||||
|
||||
if type_ == "websocket":
|
||||
# Handle websocket functions. (Always async)
|
||||
@functools.wraps(func)
|
||||
async def websocket_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
|
||||
websocket = kwargs.get("websocket", args[idx] if idx < len(args) else None)
|
||||
assert isinstance(websocket, WebSocket)
|
||||
|
||||
if not has_required_scope(websocket, scopes_list):
|
||||
await websocket.close()
|
||||
else:
|
||||
await func(*args, **kwargs)
|
||||
|
||||
return websocket_wrapper
|
||||
|
||||
elif is_async_callable(func):
|
||||
# Handle async request/response functions.
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any:
|
||||
request = kwargs.get("request", args[idx] if idx < len(args) else None)
|
||||
assert isinstance(request, Request)
|
||||
|
||||
if not has_required_scope(request, scopes_list):
|
||||
if redirect is not None:
|
||||
orig_request_qparam = urlencode({"next": str(request.url)})
|
||||
next_url = f"{request.url_for(redirect)}?{orig_request_qparam}"
|
||||
return RedirectResponse(url=next_url, status_code=303)
|
||||
raise HTTPException(status_code=status_code)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return async_wrapper
|
||||
|
||||
else:
|
||||
# Handle sync request/response functions.
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any:
|
||||
request = kwargs.get("request", args[idx] if idx < len(args) else None)
|
||||
assert isinstance(request, Request)
|
||||
|
||||
if not has_required_scope(request, scopes_list):
|
||||
if redirect is not None:
|
||||
orig_request_qparam = urlencode({"next": str(request.url)})
|
||||
next_url = f"{request.url_for(redirect)}?{orig_request_qparam}"
|
||||
return RedirectResponse(url=next_url, status_code=303)
|
||||
raise HTTPException(status_code=status_code)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return sync_wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class AuthenticationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class AuthenticationBackend:
|
||||
async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
class AuthCredentials:
|
||||
def __init__(self, scopes: typing.Sequence[str] | None = None):
|
||||
self.scopes = [] if scopes is None else list(scopes)
|
||||
|
||||
|
||||
class BaseUser:
|
||||
@property
|
||||
def is_authenticated(self) -> bool:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
@property
|
||||
def identity(self) -> str:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
class SimpleUser(BaseUser):
|
||||
def __init__(self, username: str) -> None:
|
||||
self.username = username
|
||||
|
||||
@property
|
||||
def is_authenticated(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return self.username
|
||||
|
||||
|
||||
class UnauthenticatedUser(BaseUser):
|
||||
@property
|
||||
def is_authenticated(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return ""
|
41
venv/lib/python3.11/site-packages/starlette/background.py
Normal file
41
venv/lib/python3.11/site-packages/starlette/background.py
Normal file
@ -0,0 +1,41 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import typing
|
||||
|
||||
if sys.version_info >= (3, 10): # pragma: no cover
|
||||
from typing import ParamSpec
|
||||
else: # pragma: no cover
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from starlette._utils import is_async_callable
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
class BackgroundTask:
|
||||
def __init__(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs) -> None:
|
||||
self.func = func
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.is_async = is_async_callable(func)
|
||||
|
||||
async def __call__(self) -> None:
|
||||
if self.is_async:
|
||||
await self.func(*self.args, **self.kwargs)
|
||||
else:
|
||||
await run_in_threadpool(self.func, *self.args, **self.kwargs)
|
||||
|
||||
|
||||
class BackgroundTasks(BackgroundTask):
|
||||
def __init__(self, tasks: typing.Sequence[BackgroundTask] | None = None):
|
||||
self.tasks = list(tasks) if tasks else []
|
||||
|
||||
def add_task(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs) -> None:
|
||||
task = BackgroundTask(func, *args, **kwargs)
|
||||
self.tasks.append(task)
|
||||
|
||||
async def __call__(self) -> None:
|
||||
for task in self.tasks:
|
||||
await task()
|
62
venv/lib/python3.11/site-packages/starlette/concurrency.py
Normal file
62
venv/lib/python3.11/site-packages/starlette/concurrency.py
Normal file
@ -0,0 +1,62 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
|
||||
import anyio.to_thread
|
||||
|
||||
if sys.version_info >= (3, 10): # pragma: no cover
|
||||
from typing import ParamSpec
|
||||
else: # pragma: no cover
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = typing.TypeVar("T")
|
||||
|
||||
|
||||
async def run_until_first_complete(*args: tuple[typing.Callable, dict]) -> None: # type: ignore[type-arg]
|
||||
warnings.warn(
|
||||
"run_until_first_complete is deprecated and will be removed in a future version.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as task_group:
|
||||
|
||||
async def run(func: typing.Callable[[], typing.Coroutine]) -> None: # type: ignore[type-arg]
|
||||
await func()
|
||||
task_group.cancel_scope.cancel()
|
||||
|
||||
for func, kwargs in args:
|
||||
task_group.start_soon(run, functools.partial(func, **kwargs))
|
||||
|
||||
|
||||
async def run_in_threadpool(func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
|
||||
func = functools.partial(func, *args, **kwargs)
|
||||
return await anyio.to_thread.run_sync(func)
|
||||
|
||||
|
||||
class _StopIteration(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _next(iterator: typing.Iterator[T]) -> T:
|
||||
# We can't raise `StopIteration` from within the threadpool iterator
|
||||
# and catch it outside that context, so we coerce them into a different
|
||||
# exception type.
|
||||
try:
|
||||
return next(iterator)
|
||||
except StopIteration:
|
||||
raise _StopIteration
|
||||
|
||||
|
||||
async def iterate_in_threadpool(
|
||||
iterator: typing.Iterable[T],
|
||||
) -> typing.AsyncIterator[T]:
|
||||
as_iterator = iter(iterator)
|
||||
while True:
|
||||
try:
|
||||
yield await anyio.to_thread.run_sync(_next, as_iterator)
|
||||
except _StopIteration:
|
||||
break
|
138
venv/lib/python3.11/site-packages/starlette/config.py
Normal file
138
venv/lib/python3.11/site-packages/starlette/config.py
Normal file
@ -0,0 +1,138 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import typing
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class undefined:
|
||||
pass
|
||||
|
||||
|
||||
class EnvironError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Environ(typing.MutableMapping[str, str]):
|
||||
def __init__(self, environ: typing.MutableMapping[str, str] = os.environ):
|
||||
self._environ = environ
|
||||
self._has_been_read: set[str] = set()
|
||||
|
||||
def __getitem__(self, key: str) -> str:
|
||||
self._has_been_read.add(key)
|
||||
return self._environ.__getitem__(key)
|
||||
|
||||
def __setitem__(self, key: str, value: str) -> None:
|
||||
if key in self._has_been_read:
|
||||
raise EnvironError(f"Attempting to set environ['{key}'], but the value has already been read.")
|
||||
self._environ.__setitem__(key, value)
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
if key in self._has_been_read:
|
||||
raise EnvironError(f"Attempting to delete environ['{key}'], but the value has already been read.")
|
||||
self._environ.__delitem__(key)
|
||||
|
||||
def __iter__(self) -> typing.Iterator[str]:
|
||||
return iter(self._environ)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._environ)
|
||||
|
||||
|
||||
environ = Environ()
|
||||
|
||||
T = typing.TypeVar("T")
|
||||
|
||||
|
||||
class Config:
|
||||
def __init__(
|
||||
self,
|
||||
env_file: str | Path | None = None,
|
||||
environ: typing.Mapping[str, str] = environ,
|
||||
env_prefix: str = "",
|
||||
) -> None:
|
||||
self.environ = environ
|
||||
self.env_prefix = env_prefix
|
||||
self.file_values: dict[str, str] = {}
|
||||
if env_file is not None:
|
||||
if not os.path.isfile(env_file):
|
||||
warnings.warn(f"Config file '{env_file}' not found.")
|
||||
else:
|
||||
self.file_values = self._read_file(env_file)
|
||||
|
||||
@typing.overload
|
||||
def __call__(self, key: str, *, default: None) -> str | None: ...
|
||||
|
||||
@typing.overload
|
||||
def __call__(self, key: str, cast: type[T], default: T = ...) -> T: ...
|
||||
|
||||
@typing.overload
|
||||
def __call__(self, key: str, cast: type[str] = ..., default: str = ...) -> str: ...
|
||||
|
||||
@typing.overload
|
||||
def __call__(
|
||||
self,
|
||||
key: str,
|
||||
cast: typing.Callable[[typing.Any], T] = ...,
|
||||
default: typing.Any = ...,
|
||||
) -> T: ...
|
||||
|
||||
@typing.overload
|
||||
def __call__(self, key: str, cast: type[str] = ..., default: T = ...) -> T | str: ...
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
key: str,
|
||||
cast: typing.Callable[[typing.Any], typing.Any] | None = None,
|
||||
default: typing.Any = undefined,
|
||||
) -> typing.Any:
|
||||
return self.get(key, cast, default)
|
||||
|
||||
def get(
|
||||
self,
|
||||
key: str,
|
||||
cast: typing.Callable[[typing.Any], typing.Any] | None = None,
|
||||
default: typing.Any = undefined,
|
||||
) -> typing.Any:
|
||||
key = self.env_prefix + key
|
||||
if key in self.environ:
|
||||
value = self.environ[key]
|
||||
return self._perform_cast(key, value, cast)
|
||||
if key in self.file_values:
|
||||
value = self.file_values[key]
|
||||
return self._perform_cast(key, value, cast)
|
||||
if default is not undefined:
|
||||
return self._perform_cast(key, default, cast)
|
||||
raise KeyError(f"Config '{key}' is missing, and has no default.")
|
||||
|
||||
def _read_file(self, file_name: str | Path) -> dict[str, str]:
|
||||
file_values: dict[str, str] = {}
|
||||
with open(file_name) as input_file:
|
||||
for line in input_file.readlines():
|
||||
line = line.strip()
|
||||
if "=" in line and not line.startswith("#"):
|
||||
key, value = line.split("=", 1)
|
||||
key = key.strip()
|
||||
value = value.strip().strip("\"'")
|
||||
file_values[key] = value
|
||||
return file_values
|
||||
|
||||
def _perform_cast(
|
||||
self,
|
||||
key: str,
|
||||
value: typing.Any,
|
||||
cast: typing.Callable[[typing.Any], typing.Any] | None = None,
|
||||
) -> typing.Any:
|
||||
if cast is None or value is None:
|
||||
return value
|
||||
elif cast is bool and isinstance(value, str):
|
||||
mapping = {"true": True, "1": True, "false": False, "0": False}
|
||||
value = value.lower()
|
||||
if value not in mapping:
|
||||
raise ValueError(f"Config '{key}' has value '{value}'. Not a valid bool.")
|
||||
return mapping[value]
|
||||
try:
|
||||
return cast(value)
|
||||
except (TypeError, ValueError):
|
||||
raise ValueError(f"Config '{key}' has value '{value}'. Not a valid {cast.__name__}.")
|
89
venv/lib/python3.11/site-packages/starlette/convertors.py
Normal file
89
venv/lib/python3.11/site-packages/starlette/convertors.py
Normal file
@ -0,0 +1,89 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import typing
|
||||
import uuid
|
||||
|
||||
T = typing.TypeVar("T")
|
||||
|
||||
|
||||
class Convertor(typing.Generic[T]):
|
||||
regex: typing.ClassVar[str] = ""
|
||||
|
||||
def convert(self, value: str) -> T:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
def to_string(self, value: T) -> str:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
class StringConvertor(Convertor[str]):
|
||||
regex = "[^/]+"
|
||||
|
||||
def convert(self, value: str) -> str:
|
||||
return value
|
||||
|
||||
def to_string(self, value: str) -> str:
|
||||
value = str(value)
|
||||
assert "/" not in value, "May not contain path separators"
|
||||
assert value, "Must not be empty"
|
||||
return value
|
||||
|
||||
|
||||
class PathConvertor(Convertor[str]):
|
||||
regex = ".*"
|
||||
|
||||
def convert(self, value: str) -> str:
|
||||
return str(value)
|
||||
|
||||
def to_string(self, value: str) -> str:
|
||||
return str(value)
|
||||
|
||||
|
||||
class IntegerConvertor(Convertor[int]):
|
||||
regex = "[0-9]+"
|
||||
|
||||
def convert(self, value: str) -> int:
|
||||
return int(value)
|
||||
|
||||
def to_string(self, value: int) -> str:
|
||||
value = int(value)
|
||||
assert value >= 0, "Negative integers are not supported"
|
||||
return str(value)
|
||||
|
||||
|
||||
class FloatConvertor(Convertor[float]):
|
||||
regex = r"[0-9]+(\.[0-9]+)?"
|
||||
|
||||
def convert(self, value: str) -> float:
|
||||
return float(value)
|
||||
|
||||
def to_string(self, value: float) -> str:
|
||||
value = float(value)
|
||||
assert value >= 0.0, "Negative floats are not supported"
|
||||
assert not math.isnan(value), "NaN values are not supported"
|
||||
assert not math.isinf(value), "Infinite values are not supported"
|
||||
return ("%0.20f" % value).rstrip("0").rstrip(".")
|
||||
|
||||
|
||||
class UUIDConvertor(Convertor[uuid.UUID]):
|
||||
regex = "[0-9a-fA-F]{8}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{12}"
|
||||
|
||||
def convert(self, value: str) -> uuid.UUID:
|
||||
return uuid.UUID(value)
|
||||
|
||||
def to_string(self, value: uuid.UUID) -> str:
|
||||
return str(value)
|
||||
|
||||
|
||||
CONVERTOR_TYPES: dict[str, Convertor[typing.Any]] = {
|
||||
"str": StringConvertor(),
|
||||
"path": PathConvertor(),
|
||||
"int": IntegerConvertor(),
|
||||
"float": FloatConvertor(),
|
||||
"uuid": UUIDConvertor(),
|
||||
}
|
||||
|
||||
|
||||
def register_url_convertor(key: str, convertor: Convertor[typing.Any]) -> None:
|
||||
CONVERTOR_TYPES[key] = convertor
|
674
venv/lib/python3.11/site-packages/starlette/datastructures.py
Normal file
674
venv/lib/python3.11/site-packages/starlette/datastructures.py
Normal file
@ -0,0 +1,674 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from shlex import shlex
|
||||
from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit
|
||||
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.types import Scope
|
||||
|
||||
|
||||
class Address(typing.NamedTuple):
|
||||
host: str
|
||||
port: int
|
||||
|
||||
|
||||
_KeyType = typing.TypeVar("_KeyType")
|
||||
# Mapping keys are invariant but their values are covariant since
|
||||
# you can only read them
|
||||
# that is, you can't do `Mapping[str, Animal]()["fido"] = Dog()`
|
||||
_CovariantValueType = typing.TypeVar("_CovariantValueType", covariant=True)
|
||||
|
||||
|
||||
class URL:
|
||||
def __init__(
|
||||
self,
|
||||
url: str = "",
|
||||
scope: Scope | None = None,
|
||||
**components: typing.Any,
|
||||
) -> None:
|
||||
if scope is not None:
|
||||
assert not url, 'Cannot set both "url" and "scope".'
|
||||
assert not components, 'Cannot set both "scope" and "**components".'
|
||||
scheme = scope.get("scheme", "http")
|
||||
server = scope.get("server", None)
|
||||
path = scope["path"]
|
||||
query_string = scope.get("query_string", b"")
|
||||
|
||||
host_header = None
|
||||
for key, value in scope["headers"]:
|
||||
if key == b"host":
|
||||
host_header = value.decode("latin-1")
|
||||
break
|
||||
|
||||
if host_header is not None:
|
||||
url = f"{scheme}://{host_header}{path}"
|
||||
elif server is None:
|
||||
url = path
|
||||
else:
|
||||
host, port = server
|
||||
default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}[scheme]
|
||||
if port == default_port:
|
||||
url = f"{scheme}://{host}{path}"
|
||||
else:
|
||||
url = f"{scheme}://{host}:{port}{path}"
|
||||
|
||||
if query_string:
|
||||
url += "?" + query_string.decode()
|
||||
elif components:
|
||||
assert not url, 'Cannot set both "url" and "**components".'
|
||||
url = URL("").replace(**components).components.geturl()
|
||||
|
||||
self._url = url
|
||||
|
||||
@property
|
||||
def components(self) -> SplitResult:
|
||||
if not hasattr(self, "_components"):
|
||||
self._components = urlsplit(self._url)
|
||||
return self._components
|
||||
|
||||
@property
|
||||
def scheme(self) -> str:
|
||||
return self.components.scheme
|
||||
|
||||
@property
|
||||
def netloc(self) -> str:
|
||||
return self.components.netloc
|
||||
|
||||
@property
|
||||
def path(self) -> str:
|
||||
return self.components.path
|
||||
|
||||
@property
|
||||
def query(self) -> str:
|
||||
return self.components.query
|
||||
|
||||
@property
|
||||
def fragment(self) -> str:
|
||||
return self.components.fragment
|
||||
|
||||
@property
|
||||
def username(self) -> None | str:
|
||||
return self.components.username
|
||||
|
||||
@property
|
||||
def password(self) -> None | str:
|
||||
return self.components.password
|
||||
|
||||
@property
|
||||
def hostname(self) -> None | str:
|
||||
return self.components.hostname
|
||||
|
||||
@property
|
||||
def port(self) -> int | None:
|
||||
return self.components.port
|
||||
|
||||
@property
|
||||
def is_secure(self) -> bool:
|
||||
return self.scheme in ("https", "wss")
|
||||
|
||||
def replace(self, **kwargs: typing.Any) -> URL:
|
||||
if "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs:
|
||||
hostname = kwargs.pop("hostname", None)
|
||||
port = kwargs.pop("port", self.port)
|
||||
username = kwargs.pop("username", self.username)
|
||||
password = kwargs.pop("password", self.password)
|
||||
|
||||
if hostname is None:
|
||||
netloc = self.netloc
|
||||
_, _, hostname = netloc.rpartition("@")
|
||||
|
||||
if hostname[-1] != "]":
|
||||
hostname = hostname.rsplit(":", 1)[0]
|
||||
|
||||
netloc = hostname
|
||||
if port is not None:
|
||||
netloc += f":{port}"
|
||||
if username is not None:
|
||||
userpass = username
|
||||
if password is not None:
|
||||
userpass += f":{password}"
|
||||
netloc = f"{userpass}@{netloc}"
|
||||
|
||||
kwargs["netloc"] = netloc
|
||||
|
||||
components = self.components._replace(**kwargs)
|
||||
return self.__class__(components.geturl())
|
||||
|
||||
def include_query_params(self, **kwargs: typing.Any) -> URL:
|
||||
params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
|
||||
params.update({str(key): str(value) for key, value in kwargs.items()})
|
||||
query = urlencode(params.multi_items())
|
||||
return self.replace(query=query)
|
||||
|
||||
def replace_query_params(self, **kwargs: typing.Any) -> URL:
|
||||
query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
|
||||
return self.replace(query=query)
|
||||
|
||||
def remove_query_params(self, keys: str | typing.Sequence[str]) -> URL:
|
||||
if isinstance(keys, str):
|
||||
keys = [keys]
|
||||
params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
|
||||
for key in keys:
|
||||
params.pop(key, None)
|
||||
query = urlencode(params.multi_items())
|
||||
return self.replace(query=query)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return str(self) == str(other)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self._url
|
||||
|
||||
def __repr__(self) -> str:
|
||||
url = str(self)
|
||||
if self.password:
|
||||
url = str(self.replace(password="********"))
|
||||
return f"{self.__class__.__name__}({repr(url)})"
|
||||
|
||||
|
||||
class URLPath(str):
|
||||
"""
|
||||
A URL path string that may also hold an associated protocol and/or host.
|
||||
Used by the routing to return `url_path_for` matches.
|
||||
"""
|
||||
|
||||
def __new__(cls, path: str, protocol: str = "", host: str = "") -> URLPath:
|
||||
assert protocol in ("http", "websocket", "")
|
||||
return str.__new__(cls, path)
|
||||
|
||||
def __init__(self, path: str, protocol: str = "", host: str = "") -> None:
|
||||
self.protocol = protocol
|
||||
self.host = host
|
||||
|
||||
def make_absolute_url(self, base_url: str | URL) -> URL:
|
||||
if isinstance(base_url, str):
|
||||
base_url = URL(base_url)
|
||||
if self.protocol:
|
||||
scheme = {
|
||||
"http": {True: "https", False: "http"},
|
||||
"websocket": {True: "wss", False: "ws"},
|
||||
}[self.protocol][base_url.is_secure]
|
||||
else:
|
||||
scheme = base_url.scheme
|
||||
|
||||
netloc = self.host or base_url.netloc
|
||||
path = base_url.path.rstrip("/") + str(self)
|
||||
return URL(scheme=scheme, netloc=netloc, path=path)
|
||||
|
||||
|
||||
class Secret:
|
||||
"""
|
||||
Holds a string value that should not be revealed in tracebacks etc.
|
||||
You should cast the value to `str` at the point it is required.
|
||||
"""
|
||||
|
||||
def __init__(self, value: str):
|
||||
self._value = value
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
return f"{class_name}('**********')"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self._value
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self._value)
|
||||
|
||||
|
||||
class CommaSeparatedStrings(typing.Sequence[str]):
|
||||
def __init__(self, value: str | typing.Sequence[str]):
|
||||
if isinstance(value, str):
|
||||
splitter = shlex(value, posix=True)
|
||||
splitter.whitespace = ","
|
||||
splitter.whitespace_split = True
|
||||
self._items = [item.strip() for item in splitter]
|
||||
else:
|
||||
self._items = list(value)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._items)
|
||||
|
||||
def __getitem__(self, index: int | slice) -> typing.Any:
|
||||
return self._items[index]
|
||||
|
||||
def __iter__(self) -> typing.Iterator[str]:
|
||||
return iter(self._items)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
items = [item for item in self]
|
||||
return f"{class_name}({items!r})"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return ", ".join(repr(item) for item in self)
|
||||
|
||||
|
||||
class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]):
|
||||
_dict: dict[_KeyType, _CovariantValueType]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args: ImmutableMultiDict[_KeyType, _CovariantValueType]
|
||||
| typing.Mapping[_KeyType, _CovariantValueType]
|
||||
| typing.Iterable[tuple[_KeyType, _CovariantValueType]],
|
||||
**kwargs: typing.Any,
|
||||
) -> None:
|
||||
assert len(args) < 2, "Too many arguments."
|
||||
|
||||
value: typing.Any = args[0] if args else []
|
||||
if kwargs:
|
||||
value = ImmutableMultiDict(value).multi_items() + ImmutableMultiDict(kwargs).multi_items()
|
||||
|
||||
if not value:
|
||||
_items: list[tuple[typing.Any, typing.Any]] = []
|
||||
elif hasattr(value, "multi_items"):
|
||||
value = typing.cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value)
|
||||
_items = list(value.multi_items())
|
||||
elif hasattr(value, "items"):
|
||||
value = typing.cast(typing.Mapping[_KeyType, _CovariantValueType], value)
|
||||
_items = list(value.items())
|
||||
else:
|
||||
value = typing.cast("list[tuple[typing.Any, typing.Any]]", value)
|
||||
_items = list(value)
|
||||
|
||||
self._dict = {k: v for k, v in _items}
|
||||
self._list = _items
|
||||
|
||||
def getlist(self, key: typing.Any) -> list[_CovariantValueType]:
|
||||
return [item_value for item_key, item_value in self._list if item_key == key]
|
||||
|
||||
def keys(self) -> typing.KeysView[_KeyType]:
|
||||
return self._dict.keys()
|
||||
|
||||
def values(self) -> typing.ValuesView[_CovariantValueType]:
|
||||
return self._dict.values()
|
||||
|
||||
def items(self) -> typing.ItemsView[_KeyType, _CovariantValueType]:
|
||||
return self._dict.items()
|
||||
|
||||
def multi_items(self) -> list[tuple[_KeyType, _CovariantValueType]]:
|
||||
return list(self._list)
|
||||
|
||||
def __getitem__(self, key: _KeyType) -> _CovariantValueType:
|
||||
return self._dict[key]
|
||||
|
||||
def __contains__(self, key: typing.Any) -> bool:
|
||||
return key in self._dict
|
||||
|
||||
def __iter__(self) -> typing.Iterator[_KeyType]:
|
||||
return iter(self.keys())
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._dict)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
if not isinstance(other, self.__class__):
|
||||
return False
|
||||
return sorted(self._list) == sorted(other._list)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
items = self.multi_items()
|
||||
return f"{class_name}({items!r})"
|
||||
|
||||
|
||||
class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]):
|
||||
def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
|
||||
self.setlist(key, [value])
|
||||
|
||||
def __delitem__(self, key: typing.Any) -> None:
|
||||
self._list = [(k, v) for k, v in self._list if k != key]
|
||||
del self._dict[key]
|
||||
|
||||
def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
|
||||
self._list = [(k, v) for k, v in self._list if k != key]
|
||||
return self._dict.pop(key, default)
|
||||
|
||||
def popitem(self) -> tuple[typing.Any, typing.Any]:
|
||||
key, value = self._dict.popitem()
|
||||
self._list = [(k, v) for k, v in self._list if k != key]
|
||||
return key, value
|
||||
|
||||
def poplist(self, key: typing.Any) -> list[typing.Any]:
|
||||
values = [v for k, v in self._list if k == key]
|
||||
self.pop(key)
|
||||
return values
|
||||
|
||||
def clear(self) -> None:
|
||||
self._dict.clear()
|
||||
self._list.clear()
|
||||
|
||||
def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
|
||||
if key not in self:
|
||||
self._dict[key] = default
|
||||
self._list.append((key, default))
|
||||
|
||||
return self[key]
|
||||
|
||||
def setlist(self, key: typing.Any, values: list[typing.Any]) -> None:
|
||||
if not values:
|
||||
self.pop(key, None)
|
||||
else:
|
||||
existing_items = [(k, v) for (k, v) in self._list if k != key]
|
||||
self._list = existing_items + [(key, value) for value in values]
|
||||
self._dict[key] = values[-1]
|
||||
|
||||
def append(self, key: typing.Any, value: typing.Any) -> None:
|
||||
self._list.append((key, value))
|
||||
self._dict[key] = value
|
||||
|
||||
def update(
|
||||
self,
|
||||
*args: MultiDict | typing.Mapping[typing.Any, typing.Any] | list[tuple[typing.Any, typing.Any]],
|
||||
**kwargs: typing.Any,
|
||||
) -> None:
|
||||
value = MultiDict(*args, **kwargs)
|
||||
existing_items = [(k, v) for (k, v) in self._list if k not in value.keys()]
|
||||
self._list = existing_items + value.multi_items()
|
||||
self._dict.update(value)
|
||||
|
||||
|
||||
class QueryParams(ImmutableMultiDict[str, str]):
|
||||
"""
|
||||
An immutable multidict.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args: ImmutableMultiDict[typing.Any, typing.Any]
|
||||
| typing.Mapping[typing.Any, typing.Any]
|
||||
| list[tuple[typing.Any, typing.Any]]
|
||||
| str
|
||||
| bytes,
|
||||
**kwargs: typing.Any,
|
||||
) -> None:
|
||||
assert len(args) < 2, "Too many arguments."
|
||||
|
||||
value = args[0] if args else []
|
||||
|
||||
if isinstance(value, str):
|
||||
super().__init__(parse_qsl(value, keep_blank_values=True), **kwargs)
|
||||
elif isinstance(value, bytes):
|
||||
super().__init__(parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs)
|
||||
else:
|
||||
super().__init__(*args, **kwargs) # type: ignore[arg-type]
|
||||
self._list = [(str(k), str(v)) for k, v in self._list]
|
||||
self._dict = {str(k): str(v) for k, v in self._dict.items()}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return urlencode(self._list)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
query_string = str(self)
|
||||
return f"{class_name}({query_string!r})"
|
||||
|
||||
|
||||
class UploadFile:
|
||||
"""
|
||||
An uploaded file included as part of the request data.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file: typing.BinaryIO,
|
||||
*,
|
||||
size: int | None = None,
|
||||
filename: str | None = None,
|
||||
headers: Headers | None = None,
|
||||
) -> None:
|
||||
self.filename = filename
|
||||
self.file = file
|
||||
self.size = size
|
||||
self.headers = headers or Headers()
|
||||
|
||||
@property
|
||||
def content_type(self) -> str | None:
|
||||
return self.headers.get("content-type", None)
|
||||
|
||||
@property
|
||||
def _in_memory(self) -> bool:
|
||||
# check for SpooledTemporaryFile._rolled
|
||||
rolled_to_disk = getattr(self.file, "_rolled", True)
|
||||
return not rolled_to_disk
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
if self.size is not None:
|
||||
self.size += len(data)
|
||||
|
||||
if self._in_memory:
|
||||
self.file.write(data)
|
||||
else:
|
||||
await run_in_threadpool(self.file.write, data)
|
||||
|
||||
async def read(self, size: int = -1) -> bytes:
|
||||
if self._in_memory:
|
||||
return self.file.read(size)
|
||||
return await run_in_threadpool(self.file.read, size)
|
||||
|
||||
async def seek(self, offset: int) -> None:
|
||||
if self._in_memory:
|
||||
self.file.seek(offset)
|
||||
else:
|
||||
await run_in_threadpool(self.file.seek, offset)
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._in_memory:
|
||||
self.file.close()
|
||||
else:
|
||||
await run_in_threadpool(self.file.close)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(filename={self.filename!r}, size={self.size!r}, headers={self.headers!r})"
|
||||
|
||||
|
||||
class FormData(ImmutableMultiDict[str, typing.Union[UploadFile, str]]):
|
||||
"""
|
||||
An immutable multidict, containing both file uploads and text input.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args: FormData | typing.Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]],
|
||||
**kwargs: str | UploadFile,
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
async def close(self) -> None:
|
||||
for key, value in self.multi_items():
|
||||
if isinstance(value, UploadFile):
|
||||
await value.close()
|
||||
|
||||
|
||||
class Headers(typing.Mapping[str, str]):
|
||||
"""
|
||||
An immutable, case-insensitive multidict.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
headers: typing.Mapping[str, str] | None = None,
|
||||
raw: list[tuple[bytes, bytes]] | None = None,
|
||||
scope: typing.MutableMapping[str, typing.Any] | None = None,
|
||||
) -> None:
|
||||
self._list: list[tuple[bytes, bytes]] = []
|
||||
if headers is not None:
|
||||
assert raw is None, 'Cannot set both "headers" and "raw".'
|
||||
assert scope is None, 'Cannot set both "headers" and "scope".'
|
||||
self._list = [(key.lower().encode("latin-1"), value.encode("latin-1")) for key, value in headers.items()]
|
||||
elif raw is not None:
|
||||
assert scope is None, 'Cannot set both "raw" and "scope".'
|
||||
self._list = raw
|
||||
elif scope is not None:
|
||||
# scope["headers"] isn't necessarily a list
|
||||
# it might be a tuple or other iterable
|
||||
self._list = scope["headers"] = list(scope["headers"])
|
||||
|
||||
@property
|
||||
def raw(self) -> list[tuple[bytes, bytes]]:
|
||||
return list(self._list)
|
||||
|
||||
def keys(self) -> list[str]: # type: ignore[override]
|
||||
return [key.decode("latin-1") for key, value in self._list]
|
||||
|
||||
def values(self) -> list[str]: # type: ignore[override]
|
||||
return [value.decode("latin-1") for key, value in self._list]
|
||||
|
||||
def items(self) -> list[tuple[str, str]]: # type: ignore[override]
|
||||
return [(key.decode("latin-1"), value.decode("latin-1")) for key, value in self._list]
|
||||
|
||||
def getlist(self, key: str) -> list[str]:
|
||||
get_header_key = key.lower().encode("latin-1")
|
||||
return [item_value.decode("latin-1") for item_key, item_value in self._list if item_key == get_header_key]
|
||||
|
||||
def mutablecopy(self) -> MutableHeaders:
|
||||
return MutableHeaders(raw=self._list[:])
|
||||
|
||||
def __getitem__(self, key: str) -> str:
|
||||
get_header_key = key.lower().encode("latin-1")
|
||||
for header_key, header_value in self._list:
|
||||
if header_key == get_header_key:
|
||||
return header_value.decode("latin-1")
|
||||
raise KeyError(key)
|
||||
|
||||
def __contains__(self, key: typing.Any) -> bool:
|
||||
get_header_key = key.lower().encode("latin-1")
|
||||
for header_key, header_value in self._list:
|
||||
if header_key == get_header_key:
|
||||
return True
|
||||
return False
|
||||
|
||||
def __iter__(self) -> typing.Iterator[typing.Any]:
|
||||
return iter(self.keys())
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._list)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
if not isinstance(other, Headers):
|
||||
return False
|
||||
return sorted(self._list) == sorted(other._list)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
as_dict = dict(self.items())
|
||||
if len(as_dict) == len(self):
|
||||
return f"{class_name}({as_dict!r})"
|
||||
return f"{class_name}(raw={self.raw!r})"
|
||||
|
||||
|
||||
class MutableHeaders(Headers):
|
||||
def __setitem__(self, key: str, value: str) -> None:
|
||||
"""
|
||||
Set the header `key` to `value`, removing any duplicate entries.
|
||||
Retains insertion order.
|
||||
"""
|
||||
set_key = key.lower().encode("latin-1")
|
||||
set_value = value.encode("latin-1")
|
||||
|
||||
found_indexes: list[int] = []
|
||||
for idx, (item_key, item_value) in enumerate(self._list):
|
||||
if item_key == set_key:
|
||||
found_indexes.append(idx)
|
||||
|
||||
for idx in reversed(found_indexes[1:]):
|
||||
del self._list[idx]
|
||||
|
||||
if found_indexes:
|
||||
idx = found_indexes[0]
|
||||
self._list[idx] = (set_key, set_value)
|
||||
else:
|
||||
self._list.append((set_key, set_value))
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
"""
|
||||
Remove the header `key`.
|
||||
"""
|
||||
del_key = key.lower().encode("latin-1")
|
||||
|
||||
pop_indexes: list[int] = []
|
||||
for idx, (item_key, item_value) in enumerate(self._list):
|
||||
if item_key == del_key:
|
||||
pop_indexes.append(idx)
|
||||
|
||||
for idx in reversed(pop_indexes):
|
||||
del self._list[idx]
|
||||
|
||||
def __ior__(self, other: typing.Mapping[str, str]) -> MutableHeaders:
|
||||
if not isinstance(other, typing.Mapping):
|
||||
raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
|
||||
self.update(other)
|
||||
return self
|
||||
|
||||
def __or__(self, other: typing.Mapping[str, str]) -> MutableHeaders:
|
||||
if not isinstance(other, typing.Mapping):
|
||||
raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
|
||||
new = self.mutablecopy()
|
||||
new.update(other)
|
||||
return new
|
||||
|
||||
@property
|
||||
def raw(self) -> list[tuple[bytes, bytes]]:
|
||||
return self._list
|
||||
|
||||
def setdefault(self, key: str, value: str) -> str:
|
||||
"""
|
||||
If the header `key` does not exist, then set it to `value`.
|
||||
Returns the header value.
|
||||
"""
|
||||
set_key = key.lower().encode("latin-1")
|
||||
set_value = value.encode("latin-1")
|
||||
|
||||
for idx, (item_key, item_value) in enumerate(self._list):
|
||||
if item_key == set_key:
|
||||
return item_value.decode("latin-1")
|
||||
self._list.append((set_key, set_value))
|
||||
return value
|
||||
|
||||
def update(self, other: typing.Mapping[str, str]) -> None:
|
||||
for key, val in other.items():
|
||||
self[key] = val
|
||||
|
||||
def append(self, key: str, value: str) -> None:
|
||||
"""
|
||||
Append a header, preserving any duplicate entries.
|
||||
"""
|
||||
append_key = key.lower().encode("latin-1")
|
||||
append_value = value.encode("latin-1")
|
||||
self._list.append((append_key, append_value))
|
||||
|
||||
def add_vary_header(self, vary: str) -> None:
|
||||
existing = self.get("vary")
|
||||
if existing is not None:
|
||||
vary = ", ".join([existing, vary])
|
||||
self["vary"] = vary
|
||||
|
||||
|
||||
class State:
|
||||
"""
|
||||
An object that can be used to store arbitrary state.
|
||||
|
||||
Used for `request.state` and `app.state`.
|
||||
"""
|
||||
|
||||
_state: dict[str, typing.Any]
|
||||
|
||||
def __init__(self, state: dict[str, typing.Any] | None = None):
|
||||
if state is None:
|
||||
state = {}
|
||||
super().__setattr__("_state", state)
|
||||
|
||||
def __setattr__(self, key: typing.Any, value: typing.Any) -> None:
|
||||
self._state[key] = value
|
||||
|
||||
def __getattr__(self, key: typing.Any) -> typing.Any:
|
||||
try:
|
||||
return self._state[key]
|
||||
except KeyError:
|
||||
message = "'{}' object has no attribute '{}'"
|
||||
raise AttributeError(message.format(self.__class__.__name__, key))
|
||||
|
||||
def __delattr__(self, key: typing.Any) -> None:
|
||||
del self._state[key]
|
122
venv/lib/python3.11/site-packages/starlette/endpoints.py
Normal file
122
venv/lib/python3.11/site-packages/starlette/endpoints.py
Normal file
@ -0,0 +1,122 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import typing
|
||||
|
||||
from starlette import status
|
||||
from starlette._utils import is_async_callable
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import PlainTextResponse, Response
|
||||
from starlette.types import Message, Receive, Scope, Send
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
|
||||
class HTTPEndpoint:
|
||||
def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
assert scope["type"] == "http"
|
||||
self.scope = scope
|
||||
self.receive = receive
|
||||
self.send = send
|
||||
self._allowed_methods = [
|
||||
method
|
||||
for method in ("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS")
|
||||
if getattr(self, method.lower(), None) is not None
|
||||
]
|
||||
|
||||
def __await__(self) -> typing.Generator[typing.Any, None, None]:
|
||||
return self.dispatch().__await__()
|
||||
|
||||
async def dispatch(self) -> None:
|
||||
request = Request(self.scope, receive=self.receive)
|
||||
handler_name = "get" if request.method == "HEAD" and not hasattr(self, "head") else request.method.lower()
|
||||
|
||||
handler: typing.Callable[[Request], typing.Any] = getattr(self, handler_name, self.method_not_allowed)
|
||||
is_async = is_async_callable(handler)
|
||||
if is_async:
|
||||
response = await handler(request)
|
||||
else:
|
||||
response = await run_in_threadpool(handler, request)
|
||||
await response(self.scope, self.receive, self.send)
|
||||
|
||||
async def method_not_allowed(self, request: Request) -> Response:
|
||||
# If we're running inside a starlette application then raise an
|
||||
# exception, so that the configurable exception handler can deal with
|
||||
# returning the response. For plain ASGI apps, just return the response.
|
||||
headers = {"Allow": ", ".join(self._allowed_methods)}
|
||||
if "app" in self.scope:
|
||||
raise HTTPException(status_code=405, headers=headers)
|
||||
return PlainTextResponse("Method Not Allowed", status_code=405, headers=headers)
|
||||
|
||||
|
||||
class WebSocketEndpoint:
|
||||
encoding: str | None = None # May be "text", "bytes", or "json".
|
||||
|
||||
def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
assert scope["type"] == "websocket"
|
||||
self.scope = scope
|
||||
self.receive = receive
|
||||
self.send = send
|
||||
|
||||
def __await__(self) -> typing.Generator[typing.Any, None, None]:
|
||||
return self.dispatch().__await__()
|
||||
|
||||
async def dispatch(self) -> None:
|
||||
websocket = WebSocket(self.scope, receive=self.receive, send=self.send)
|
||||
await self.on_connect(websocket)
|
||||
|
||||
close_code = status.WS_1000_NORMAL_CLOSURE
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await websocket.receive()
|
||||
if message["type"] == "websocket.receive":
|
||||
data = await self.decode(websocket, message)
|
||||
await self.on_receive(websocket, data)
|
||||
elif message["type"] == "websocket.disconnect": # pragma: no branch
|
||||
close_code = int(message.get("code") or status.WS_1000_NORMAL_CLOSURE)
|
||||
break
|
||||
except Exception as exc:
|
||||
close_code = status.WS_1011_INTERNAL_ERROR
|
||||
raise exc
|
||||
finally:
|
||||
await self.on_disconnect(websocket, close_code)
|
||||
|
||||
async def decode(self, websocket: WebSocket, message: Message) -> typing.Any:
|
||||
if self.encoding == "text":
|
||||
if "text" not in message:
|
||||
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
|
||||
raise RuntimeError("Expected text websocket messages, but got bytes")
|
||||
return message["text"]
|
||||
|
||||
elif self.encoding == "bytes":
|
||||
if "bytes" not in message:
|
||||
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
|
||||
raise RuntimeError("Expected bytes websocket messages, but got text")
|
||||
return message["bytes"]
|
||||
|
||||
elif self.encoding == "json":
|
||||
if message.get("text") is not None:
|
||||
text = message["text"]
|
||||
else:
|
||||
text = message["bytes"].decode("utf-8")
|
||||
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.decoder.JSONDecodeError:
|
||||
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
|
||||
raise RuntimeError("Malformed JSON data received.")
|
||||
|
||||
assert self.encoding is None, f"Unsupported 'encoding' attribute {self.encoding}"
|
||||
return message["text"] if message.get("text") else message["bytes"]
|
||||
|
||||
async def on_connect(self, websocket: WebSocket) -> None:
|
||||
"""Override to handle an incoming websocket connection"""
|
||||
await websocket.accept()
|
||||
|
||||
async def on_receive(self, websocket: WebSocket, data: typing.Any) -> None:
|
||||
"""Override to handle an incoming websocket message"""
|
||||
|
||||
async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None:
|
||||
"""Override to handle a disconnecting websocket"""
|
33
venv/lib/python3.11/site-packages/starlette/exceptions.py
Normal file
33
venv/lib/python3.11/site-packages/starlette/exceptions.py
Normal file
@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import http
|
||||
from collections.abc import Mapping
|
||||
|
||||
|
||||
class HTTPException(Exception):
|
||||
def __init__(self, status_code: int, detail: str | None = None, headers: Mapping[str, str] | None = None) -> None:
|
||||
if detail is None:
|
||||
detail = http.HTTPStatus(status_code).phrase
|
||||
self.status_code = status_code
|
||||
self.detail = detail
|
||||
self.headers = headers
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.status_code}: {self.detail}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
return f"{class_name}(status_code={self.status_code!r}, detail={self.detail!r})"
|
||||
|
||||
|
||||
class WebSocketException(Exception):
|
||||
def __init__(self, code: int, reason: str | None = None) -> None:
|
||||
self.code = code
|
||||
self.reason = reason or ""
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.code}: {self.reason}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
return f"{class_name}(code={self.code!r}, reason={self.reason!r})"
|
275
venv/lib/python3.11/site-packages/starlette/formparsers.py
Normal file
275
venv/lib/python3.11/site-packages/starlette/formparsers.py
Normal file
@ -0,0 +1,275 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from tempfile import SpooledTemporaryFile
|
||||
from urllib.parse import unquote_plus
|
||||
|
||||
from starlette.datastructures import FormData, Headers, UploadFile
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import python_multipart as multipart
|
||||
from python_multipart.multipart import MultipartCallbacks, QuerystringCallbacks, parse_options_header
|
||||
else:
|
||||
try:
|
||||
try:
|
||||
import python_multipart as multipart
|
||||
from python_multipart.multipart import parse_options_header
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
import multipart
|
||||
from multipart.multipart import parse_options_header
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
multipart = None
|
||||
parse_options_header = None
|
||||
|
||||
|
||||
class FormMessage(Enum):
|
||||
FIELD_START = 1
|
||||
FIELD_NAME = 2
|
||||
FIELD_DATA = 3
|
||||
FIELD_END = 4
|
||||
END = 5
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultipartPart:
|
||||
content_disposition: bytes | None = None
|
||||
field_name: str = ""
|
||||
data: bytearray = field(default_factory=bytearray)
|
||||
file: UploadFile | None = None
|
||||
item_headers: list[tuple[bytes, bytes]] = field(default_factory=list)
|
||||
|
||||
|
||||
def _user_safe_decode(src: bytes | bytearray, codec: str) -> str:
|
||||
try:
|
||||
return src.decode(codec)
|
||||
except (UnicodeDecodeError, LookupError):
|
||||
return src.decode("latin-1")
|
||||
|
||||
|
||||
class MultiPartException(Exception):
|
||||
def __init__(self, message: str) -> None:
|
||||
self.message = message
|
||||
|
||||
|
||||
class FormParser:
|
||||
def __init__(self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]) -> None:
|
||||
assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
|
||||
self.headers = headers
|
||||
self.stream = stream
|
||||
self.messages: list[tuple[FormMessage, bytes]] = []
|
||||
|
||||
def on_field_start(self) -> None:
|
||||
message = (FormMessage.FIELD_START, b"")
|
||||
self.messages.append(message)
|
||||
|
||||
def on_field_name(self, data: bytes, start: int, end: int) -> None:
|
||||
message = (FormMessage.FIELD_NAME, data[start:end])
|
||||
self.messages.append(message)
|
||||
|
||||
def on_field_data(self, data: bytes, start: int, end: int) -> None:
|
||||
message = (FormMessage.FIELD_DATA, data[start:end])
|
||||
self.messages.append(message)
|
||||
|
||||
def on_field_end(self) -> None:
|
||||
message = (FormMessage.FIELD_END, b"")
|
||||
self.messages.append(message)
|
||||
|
||||
def on_end(self) -> None:
|
||||
message = (FormMessage.END, b"")
|
||||
self.messages.append(message)
|
||||
|
||||
async def parse(self) -> FormData:
|
||||
# Callbacks dictionary.
|
||||
callbacks: QuerystringCallbacks = {
|
||||
"on_field_start": self.on_field_start,
|
||||
"on_field_name": self.on_field_name,
|
||||
"on_field_data": self.on_field_data,
|
||||
"on_field_end": self.on_field_end,
|
||||
"on_end": self.on_end,
|
||||
}
|
||||
|
||||
# Create the parser.
|
||||
parser = multipart.QuerystringParser(callbacks)
|
||||
field_name = b""
|
||||
field_value = b""
|
||||
|
||||
items: list[tuple[str, str | UploadFile]] = []
|
||||
|
||||
# Feed the parser with data from the request.
|
||||
async for chunk in self.stream:
|
||||
if chunk:
|
||||
parser.write(chunk)
|
||||
else:
|
||||
parser.finalize()
|
||||
messages = list(self.messages)
|
||||
self.messages.clear()
|
||||
for message_type, message_bytes in messages:
|
||||
if message_type == FormMessage.FIELD_START:
|
||||
field_name = b""
|
||||
field_value = b""
|
||||
elif message_type == FormMessage.FIELD_NAME:
|
||||
field_name += message_bytes
|
||||
elif message_type == FormMessage.FIELD_DATA:
|
||||
field_value += message_bytes
|
||||
elif message_type == FormMessage.FIELD_END:
|
||||
name = unquote_plus(field_name.decode("latin-1"))
|
||||
value = unquote_plus(field_value.decode("latin-1"))
|
||||
items.append((name, value))
|
||||
|
||||
return FormData(items)
|
||||
|
||||
|
||||
class MultiPartParser:
|
||||
spool_max_size = 1024 * 1024 # 1MB
|
||||
"""The maximum size of the spooled temporary file used to store file data."""
|
||||
max_part_size = 1024 * 1024 # 1MB
|
||||
"""The maximum size of a part in the multipart request."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
headers: Headers,
|
||||
stream: typing.AsyncGenerator[bytes, None],
|
||||
*,
|
||||
max_files: int | float = 1000,
|
||||
max_fields: int | float = 1000,
|
||||
max_part_size: int = 1024 * 1024, # 1MB
|
||||
) -> None:
|
||||
assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
|
||||
self.headers = headers
|
||||
self.stream = stream
|
||||
self.max_files = max_files
|
||||
self.max_fields = max_fields
|
||||
self.items: list[tuple[str, str | UploadFile]] = []
|
||||
self._current_files = 0
|
||||
self._current_fields = 0
|
||||
self._current_partial_header_name: bytes = b""
|
||||
self._current_partial_header_value: bytes = b""
|
||||
self._current_part = MultipartPart()
|
||||
self._charset = ""
|
||||
self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = []
|
||||
self._file_parts_to_finish: list[MultipartPart] = []
|
||||
self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = []
|
||||
self.max_part_size = max_part_size
|
||||
|
||||
def on_part_begin(self) -> None:
|
||||
self._current_part = MultipartPart()
|
||||
|
||||
def on_part_data(self, data: bytes, start: int, end: int) -> None:
|
||||
message_bytes = data[start:end]
|
||||
if self._current_part.file is None:
|
||||
if len(self._current_part.data) + len(message_bytes) > self.max_part_size:
|
||||
raise MultiPartException(f"Part exceeded maximum size of {int(self.max_part_size / 1024)}KB.")
|
||||
self._current_part.data.extend(message_bytes)
|
||||
else:
|
||||
self._file_parts_to_write.append((self._current_part, message_bytes))
|
||||
|
||||
def on_part_end(self) -> None:
|
||||
if self._current_part.file is None:
|
||||
self.items.append(
|
||||
(
|
||||
self._current_part.field_name,
|
||||
_user_safe_decode(self._current_part.data, self._charset),
|
||||
)
|
||||
)
|
||||
else:
|
||||
self._file_parts_to_finish.append(self._current_part)
|
||||
# The file can be added to the items right now even though it's not
|
||||
# finished yet, because it will be finished in the `parse()` method, before
|
||||
# self.items is used in the return value.
|
||||
self.items.append((self._current_part.field_name, self._current_part.file))
|
||||
|
||||
def on_header_field(self, data: bytes, start: int, end: int) -> None:
|
||||
self._current_partial_header_name += data[start:end]
|
||||
|
||||
def on_header_value(self, data: bytes, start: int, end: int) -> None:
|
||||
self._current_partial_header_value += data[start:end]
|
||||
|
||||
def on_header_end(self) -> None:
|
||||
field = self._current_partial_header_name.lower()
|
||||
if field == b"content-disposition":
|
||||
self._current_part.content_disposition = self._current_partial_header_value
|
||||
self._current_part.item_headers.append((field, self._current_partial_header_value))
|
||||
self._current_partial_header_name = b""
|
||||
self._current_partial_header_value = b""
|
||||
|
||||
def on_headers_finished(self) -> None:
|
||||
disposition, options = parse_options_header(self._current_part.content_disposition)
|
||||
try:
|
||||
self._current_part.field_name = _user_safe_decode(options[b"name"], self._charset)
|
||||
except KeyError:
|
||||
raise MultiPartException('The Content-Disposition header field "name" must be provided.')
|
||||
if b"filename" in options:
|
||||
self._current_files += 1
|
||||
if self._current_files > self.max_files:
|
||||
raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.")
|
||||
filename = _user_safe_decode(options[b"filename"], self._charset)
|
||||
tempfile = SpooledTemporaryFile(max_size=self.spool_max_size)
|
||||
self._files_to_close_on_error.append(tempfile)
|
||||
self._current_part.file = UploadFile(
|
||||
file=tempfile, # type: ignore[arg-type]
|
||||
size=0,
|
||||
filename=filename,
|
||||
headers=Headers(raw=self._current_part.item_headers),
|
||||
)
|
||||
else:
|
||||
self._current_fields += 1
|
||||
if self._current_fields > self.max_fields:
|
||||
raise MultiPartException(f"Too many fields. Maximum number of fields is {self.max_fields}.")
|
||||
self._current_part.file = None
|
||||
|
||||
def on_end(self) -> None:
|
||||
pass
|
||||
|
||||
async def parse(self) -> FormData:
|
||||
# Parse the Content-Type header to get the multipart boundary.
|
||||
_, params = parse_options_header(self.headers["Content-Type"])
|
||||
charset = params.get(b"charset", "utf-8")
|
||||
if isinstance(charset, bytes):
|
||||
charset = charset.decode("latin-1")
|
||||
self._charset = charset
|
||||
try:
|
||||
boundary = params[b"boundary"]
|
||||
except KeyError:
|
||||
raise MultiPartException("Missing boundary in multipart.")
|
||||
|
||||
# Callbacks dictionary.
|
||||
callbacks: MultipartCallbacks = {
|
||||
"on_part_begin": self.on_part_begin,
|
||||
"on_part_data": self.on_part_data,
|
||||
"on_part_end": self.on_part_end,
|
||||
"on_header_field": self.on_header_field,
|
||||
"on_header_value": self.on_header_value,
|
||||
"on_header_end": self.on_header_end,
|
||||
"on_headers_finished": self.on_headers_finished,
|
||||
"on_end": self.on_end,
|
||||
}
|
||||
|
||||
# Create the parser.
|
||||
parser = multipart.MultipartParser(boundary, callbacks)
|
||||
try:
|
||||
# Feed the parser with data from the request.
|
||||
async for chunk in self.stream:
|
||||
parser.write(chunk)
|
||||
# Write file data, it needs to use await with the UploadFile methods
|
||||
# that call the corresponding file methods *in a threadpool*,
|
||||
# otherwise, if they were called directly in the callback methods above
|
||||
# (regular, non-async functions), that would block the event loop in
|
||||
# the main thread.
|
||||
for part, data in self._file_parts_to_write:
|
||||
assert part.file # for type checkers
|
||||
await part.file.write(data)
|
||||
for part in self._file_parts_to_finish:
|
||||
assert part.file # for type checkers
|
||||
await part.file.seek(0)
|
||||
self._file_parts_to_write.clear()
|
||||
self._file_parts_to_finish.clear()
|
||||
except MultiPartException as exc:
|
||||
# Close all the files if there was an error.
|
||||
for file in self._files_to_close_on_error:
|
||||
file.close()
|
||||
raise exc
|
||||
|
||||
parser.finalize()
|
||||
return FormData(self.items)
|
@ -0,0 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from collections.abc import Iterator
|
||||
from typing import Any, Protocol
|
||||
|
||||
if sys.version_info >= (3, 10): # pragma: no cover
|
||||
from typing import ParamSpec
|
||||
else: # pragma: no cover
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
class _MiddlewareFactory(Protocol[P]):
|
||||
def __call__(self, app: ASGIApp, /, *args: P.args, **kwargs: P.kwargs) -> ASGIApp: ... # pragma: no cover
|
||||
|
||||
|
||||
class Middleware:
|
||||
def __init__(
|
||||
self,
|
||||
cls: _MiddlewareFactory[P],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
self.cls = cls
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __iter__(self) -> Iterator[Any]:
|
||||
as_tuple = (self.cls, self.args, self.kwargs)
|
||||
return iter(as_tuple)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
args_strings = [f"{value!r}" for value in self.args]
|
||||
option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()]
|
||||
name = getattr(self.cls, "__name__", "")
|
||||
args_repr = ", ".join([name] + args_strings + option_strings)
|
||||
return f"{class_name}({args_repr})"
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from starlette.authentication import (
|
||||
AuthCredentials,
|
||||
AuthenticationBackend,
|
||||
AuthenticationError,
|
||||
UnauthenticatedUser,
|
||||
)
|
||||
from starlette.requests import HTTPConnection
|
||||
from starlette.responses import PlainTextResponse, Response
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
|
||||
class AuthenticationMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
backend: AuthenticationBackend,
|
||||
on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] | None = None,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.backend = backend
|
||||
self.on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] = (
|
||||
on_error if on_error is not None else self.default_on_error
|
||||
)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] not in ["http", "websocket"]:
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
conn = HTTPConnection(scope)
|
||||
try:
|
||||
auth_result = await self.backend.authenticate(conn)
|
||||
except AuthenticationError as exc:
|
||||
response = self.on_error(conn, exc)
|
||||
if scope["type"] == "websocket":
|
||||
await send({"type": "websocket.close", "code": 1000})
|
||||
else:
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
if auth_result is None:
|
||||
auth_result = AuthCredentials(), UnauthenticatedUser()
|
||||
scope["auth"], scope["user"] = auth_result
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
@staticmethod
|
||||
def default_on_error(conn: HTTPConnection, exc: Exception) -> Response:
|
||||
return PlainTextResponse(str(exc), status_code=400)
|
220
venv/lib/python3.11/site-packages/starlette/middleware/base.py
Normal file
220
venv/lib/python3.11/site-packages/starlette/middleware/base.py
Normal file
@ -0,0 +1,220 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
import anyio
|
||||
|
||||
from starlette._utils import collapse_excgroups
|
||||
from starlette.requests import ClientDisconnect, Request
|
||||
from starlette.responses import AsyncContentStream, Response
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
|
||||
DispatchFunction = typing.Callable[[Request, RequestResponseEndpoint], typing.Awaitable[Response]]
|
||||
T = typing.TypeVar("T")
|
||||
|
||||
|
||||
class _CachedRequest(Request):
|
||||
"""
|
||||
If the user calls Request.body() from their dispatch function
|
||||
we cache the entire request body in memory and pass that to downstream middlewares,
|
||||
but if they call Request.stream() then all we do is send an
|
||||
empty body so that downstream things don't hang forever.
|
||||
"""
|
||||
|
||||
def __init__(self, scope: Scope, receive: Receive):
|
||||
super().__init__(scope, receive)
|
||||
self._wrapped_rcv_disconnected = False
|
||||
self._wrapped_rcv_consumed = False
|
||||
self._wrapped_rc_stream = self.stream()
|
||||
|
||||
async def wrapped_receive(self) -> Message:
|
||||
# wrapped_rcv state 1: disconnected
|
||||
if self._wrapped_rcv_disconnected:
|
||||
# we've already sent a disconnect to the downstream app
|
||||
# we don't need to wait to get another one
|
||||
# (although most ASGI servers will just keep sending it)
|
||||
return {"type": "http.disconnect"}
|
||||
# wrapped_rcv state 1: consumed but not yet disconnected
|
||||
if self._wrapped_rcv_consumed:
|
||||
# since the downstream app has consumed us all that is left
|
||||
# is to send it a disconnect
|
||||
if self._is_disconnected:
|
||||
# the middleware has already seen the disconnect
|
||||
# since we know the client is disconnected no need to wait
|
||||
# for the message
|
||||
self._wrapped_rcv_disconnected = True
|
||||
return {"type": "http.disconnect"}
|
||||
# we don't know yet if the client is disconnected or not
|
||||
# so we'll wait until we get that message
|
||||
msg = await self.receive()
|
||||
if msg["type"] != "http.disconnect": # pragma: no cover
|
||||
# at this point a disconnect is all that we should be receiving
|
||||
# if we get something else, things went wrong somewhere
|
||||
raise RuntimeError(f"Unexpected message received: {msg['type']}")
|
||||
self._wrapped_rcv_disconnected = True
|
||||
return msg
|
||||
|
||||
# wrapped_rcv state 3: not yet consumed
|
||||
if getattr(self, "_body", None) is not None:
|
||||
# body() was called, we return it even if the client disconnected
|
||||
self._wrapped_rcv_consumed = True
|
||||
return {
|
||||
"type": "http.request",
|
||||
"body": self._body,
|
||||
"more_body": False,
|
||||
}
|
||||
elif self._stream_consumed:
|
||||
# stream() was called to completion
|
||||
# return an empty body so that downstream apps don't hang
|
||||
# waiting for a disconnect
|
||||
self._wrapped_rcv_consumed = True
|
||||
return {
|
||||
"type": "http.request",
|
||||
"body": b"",
|
||||
"more_body": False,
|
||||
}
|
||||
else:
|
||||
# body() was never called and stream() wasn't consumed
|
||||
try:
|
||||
stream = self.stream()
|
||||
chunk = await stream.__anext__()
|
||||
self._wrapped_rcv_consumed = self._stream_consumed
|
||||
return {
|
||||
"type": "http.request",
|
||||
"body": chunk,
|
||||
"more_body": not self._stream_consumed,
|
||||
}
|
||||
except ClientDisconnect:
|
||||
self._wrapped_rcv_disconnected = True
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
|
||||
class BaseHTTPMiddleware:
|
||||
def __init__(self, app: ASGIApp, dispatch: DispatchFunction | None = None) -> None:
|
||||
self.app = app
|
||||
self.dispatch_func = self.dispatch if dispatch is None else dispatch
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
request = _CachedRequest(scope, receive)
|
||||
wrapped_receive = request.wrapped_receive
|
||||
response_sent = anyio.Event()
|
||||
app_exc: Exception | None = None
|
||||
|
||||
async def call_next(request: Request) -> Response:
|
||||
async def receive_or_disconnect() -> Message:
|
||||
if response_sent.is_set():
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
async with anyio.create_task_group() as task_group:
|
||||
|
||||
async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
|
||||
result = await func()
|
||||
task_group.cancel_scope.cancel()
|
||||
return result
|
||||
|
||||
task_group.start_soon(wrap, response_sent.wait)
|
||||
message = await wrap(wrapped_receive)
|
||||
|
||||
if response_sent.is_set():
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
return message
|
||||
|
||||
async def send_no_error(message: Message) -> None:
|
||||
try:
|
||||
await send_stream.send(message)
|
||||
except anyio.BrokenResourceError:
|
||||
# recv_stream has been closed, i.e. response_sent has been set.
|
||||
return
|
||||
|
||||
async def coro() -> None:
|
||||
nonlocal app_exc
|
||||
|
||||
with send_stream:
|
||||
try:
|
||||
await self.app(scope, receive_or_disconnect, send_no_error)
|
||||
except Exception as exc:
|
||||
app_exc = exc
|
||||
|
||||
task_group.start_soon(coro)
|
||||
|
||||
try:
|
||||
message = await recv_stream.receive()
|
||||
info = message.get("info", None)
|
||||
if message["type"] == "http.response.debug" and info is not None:
|
||||
message = await recv_stream.receive()
|
||||
except anyio.EndOfStream:
|
||||
if app_exc is not None:
|
||||
raise app_exc
|
||||
raise RuntimeError("No response returned.")
|
||||
|
||||
assert message["type"] == "http.response.start"
|
||||
|
||||
async def body_stream() -> typing.AsyncGenerator[bytes, None]:
|
||||
async for message in recv_stream:
|
||||
assert message["type"] == "http.response.body"
|
||||
body = message.get("body", b"")
|
||||
if body:
|
||||
yield body
|
||||
if not message.get("more_body", False):
|
||||
break
|
||||
|
||||
response = _StreamingResponse(status_code=message["status"], content=body_stream(), info=info)
|
||||
response.raw_headers = message["headers"]
|
||||
return response
|
||||
|
||||
streams: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream()
|
||||
send_stream, recv_stream = streams
|
||||
with recv_stream, send_stream, collapse_excgroups():
|
||||
async with anyio.create_task_group() as task_group:
|
||||
response = await self.dispatch_func(request, call_next)
|
||||
await response(scope, wrapped_receive, send)
|
||||
response_sent.set()
|
||||
recv_stream.close()
|
||||
|
||||
if app_exc is not None:
|
||||
raise app_exc
|
||||
|
||||
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
class _StreamingResponse(Response):
|
||||
def __init__(
|
||||
self,
|
||||
content: AsyncContentStream,
|
||||
status_code: int = 200,
|
||||
headers: typing.Mapping[str, str] | None = None,
|
||||
media_type: str | None = None,
|
||||
info: typing.Mapping[str, typing.Any] | None = None,
|
||||
) -> None:
|
||||
self.info = info
|
||||
self.body_iterator = content
|
||||
self.status_code = status_code
|
||||
self.media_type = media_type
|
||||
self.init_headers(headers)
|
||||
self.background = None
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if self.info is not None:
|
||||
await send({"type": "http.response.debug", "info": self.info})
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": self.status_code,
|
||||
"headers": self.raw_headers,
|
||||
}
|
||||
)
|
||||
|
||||
async for chunk in self.body_iterator:
|
||||
await send({"type": "http.response.body", "body": chunk, "more_body": True})
|
||||
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
|
||||
if self.background:
|
||||
await self.background()
|
172
venv/lib/python3.11/site-packages/starlette/middleware/cors.py
Normal file
172
venv/lib/python3.11/site-packages/starlette/middleware/cors.py
Normal file
@ -0,0 +1,172 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import re
|
||||
import typing
|
||||
|
||||
from starlette.datastructures import Headers, MutableHeaders
|
||||
from starlette.responses import PlainTextResponse, Response
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
ALL_METHODS = ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT")
|
||||
SAFELISTED_HEADERS = {"Accept", "Accept-Language", "Content-Language", "Content-Type"}
|
||||
|
||||
|
||||
class CORSMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
allow_origins: typing.Sequence[str] = (),
|
||||
allow_methods: typing.Sequence[str] = ("GET",),
|
||||
allow_headers: typing.Sequence[str] = (),
|
||||
allow_credentials: bool = False,
|
||||
allow_origin_regex: str | None = None,
|
||||
expose_headers: typing.Sequence[str] = (),
|
||||
max_age: int = 600,
|
||||
) -> None:
|
||||
if "*" in allow_methods:
|
||||
allow_methods = ALL_METHODS
|
||||
|
||||
compiled_allow_origin_regex = None
|
||||
if allow_origin_regex is not None:
|
||||
compiled_allow_origin_regex = re.compile(allow_origin_regex)
|
||||
|
||||
allow_all_origins = "*" in allow_origins
|
||||
allow_all_headers = "*" in allow_headers
|
||||
preflight_explicit_allow_origin = not allow_all_origins or allow_credentials
|
||||
|
||||
simple_headers = {}
|
||||
if allow_all_origins:
|
||||
simple_headers["Access-Control-Allow-Origin"] = "*"
|
||||
if allow_credentials:
|
||||
simple_headers["Access-Control-Allow-Credentials"] = "true"
|
||||
if expose_headers:
|
||||
simple_headers["Access-Control-Expose-Headers"] = ", ".join(expose_headers)
|
||||
|
||||
preflight_headers = {}
|
||||
if preflight_explicit_allow_origin:
|
||||
# The origin value will be set in preflight_response() if it is allowed.
|
||||
preflight_headers["Vary"] = "Origin"
|
||||
else:
|
||||
preflight_headers["Access-Control-Allow-Origin"] = "*"
|
||||
preflight_headers.update(
|
||||
{
|
||||
"Access-Control-Allow-Methods": ", ".join(allow_methods),
|
||||
"Access-Control-Max-Age": str(max_age),
|
||||
}
|
||||
)
|
||||
allow_headers = sorted(SAFELISTED_HEADERS | set(allow_headers))
|
||||
if allow_headers and not allow_all_headers:
|
||||
preflight_headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers)
|
||||
if allow_credentials:
|
||||
preflight_headers["Access-Control-Allow-Credentials"] = "true"
|
||||
|
||||
self.app = app
|
||||
self.allow_origins = allow_origins
|
||||
self.allow_methods = allow_methods
|
||||
self.allow_headers = [h.lower() for h in allow_headers]
|
||||
self.allow_all_origins = allow_all_origins
|
||||
self.allow_all_headers = allow_all_headers
|
||||
self.preflight_explicit_allow_origin = preflight_explicit_allow_origin
|
||||
self.allow_origin_regex = compiled_allow_origin_regex
|
||||
self.simple_headers = simple_headers
|
||||
self.preflight_headers = preflight_headers
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http": # pragma: no cover
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
method = scope["method"]
|
||||
headers = Headers(scope=scope)
|
||||
origin = headers.get("origin")
|
||||
|
||||
if origin is None:
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
if method == "OPTIONS" and "access-control-request-method" in headers:
|
||||
response = self.preflight_response(request_headers=headers)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
await self.simple_response(scope, receive, send, request_headers=headers)
|
||||
|
||||
def is_allowed_origin(self, origin: str) -> bool:
|
||||
if self.allow_all_origins:
|
||||
return True
|
||||
|
||||
if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(origin):
|
||||
return True
|
||||
|
||||
return origin in self.allow_origins
|
||||
|
||||
def preflight_response(self, request_headers: Headers) -> Response:
|
||||
requested_origin = request_headers["origin"]
|
||||
requested_method = request_headers["access-control-request-method"]
|
||||
requested_headers = request_headers.get("access-control-request-headers")
|
||||
|
||||
headers = dict(self.preflight_headers)
|
||||
failures = []
|
||||
|
||||
if self.is_allowed_origin(origin=requested_origin):
|
||||
if self.preflight_explicit_allow_origin:
|
||||
# The "else" case is already accounted for in self.preflight_headers
|
||||
# and the value would be "*".
|
||||
headers["Access-Control-Allow-Origin"] = requested_origin
|
||||
else:
|
||||
failures.append("origin")
|
||||
|
||||
if requested_method not in self.allow_methods:
|
||||
failures.append("method")
|
||||
|
||||
# If we allow all headers, then we have to mirror back any requested
|
||||
# headers in the response.
|
||||
if self.allow_all_headers and requested_headers is not None:
|
||||
headers["Access-Control-Allow-Headers"] = requested_headers
|
||||
elif requested_headers is not None:
|
||||
for header in [h.lower() for h in requested_headers.split(",")]:
|
||||
if header.strip() not in self.allow_headers:
|
||||
failures.append("headers")
|
||||
break
|
||||
|
||||
# We don't strictly need to use 400 responses here, since its up to
|
||||
# the browser to enforce the CORS policy, but its more informative
|
||||
# if we do.
|
||||
if failures:
|
||||
failure_text = "Disallowed CORS " + ", ".join(failures)
|
||||
return PlainTextResponse(failure_text, status_code=400, headers=headers)
|
||||
|
||||
return PlainTextResponse("OK", status_code=200, headers=headers)
|
||||
|
||||
async def simple_response(self, scope: Scope, receive: Receive, send: Send, request_headers: Headers) -> None:
|
||||
send = functools.partial(self.send, send=send, request_headers=request_headers)
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
async def send(self, message: Message, send: Send, request_headers: Headers) -> None:
|
||||
if message["type"] != "http.response.start":
|
||||
await send(message)
|
||||
return
|
||||
|
||||
message.setdefault("headers", [])
|
||||
headers = MutableHeaders(scope=message)
|
||||
headers.update(self.simple_headers)
|
||||
origin = request_headers["Origin"]
|
||||
has_cookie = "cookie" in request_headers
|
||||
|
||||
# If request includes any cookie headers, then we must respond
|
||||
# with the specific origin instead of '*'.
|
||||
if self.allow_all_origins and has_cookie:
|
||||
self.allow_explicit_origin(headers, origin)
|
||||
|
||||
# If we only allow specific origins, then we have to mirror back
|
||||
# the Origin header in the response.
|
||||
elif not self.allow_all_origins and self.is_allowed_origin(origin=origin):
|
||||
self.allow_explicit_origin(headers, origin)
|
||||
|
||||
await send(message)
|
||||
|
||||
@staticmethod
|
||||
def allow_explicit_origin(headers: MutableHeaders, origin: str) -> None:
|
||||
headers["Access-Control-Allow-Origin"] = origin
|
||||
headers.add_vary_header("Origin")
|
260
venv/lib/python3.11/site-packages/starlette/middleware/errors.py
Normal file
260
venv/lib/python3.11/site-packages/starlette/middleware/errors.py
Normal file
@ -0,0 +1,260 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import html
|
||||
import inspect
|
||||
import sys
|
||||
import traceback
|
||||
import typing
|
||||
|
||||
from starlette._utils import is_async_callable
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import HTMLResponse, PlainTextResponse, Response
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
STYLES = """
|
||||
p {
|
||||
color: #211c1c;
|
||||
}
|
||||
.traceback-container {
|
||||
border: 1px solid #038BB8;
|
||||
}
|
||||
.traceback-title {
|
||||
background-color: #038BB8;
|
||||
color: lemonchiffon;
|
||||
padding: 12px;
|
||||
font-size: 20px;
|
||||
margin-top: 0px;
|
||||
}
|
||||
.frame-line {
|
||||
padding-left: 10px;
|
||||
font-family: monospace;
|
||||
}
|
||||
.frame-filename {
|
||||
font-family: monospace;
|
||||
}
|
||||
.center-line {
|
||||
background-color: #038BB8;
|
||||
color: #f9f6e1;
|
||||
padding: 5px 0px 5px 5px;
|
||||
}
|
||||
.lineno {
|
||||
margin-right: 5px;
|
||||
}
|
||||
.frame-title {
|
||||
font-weight: unset;
|
||||
padding: 10px 10px 10px 10px;
|
||||
background-color: #E4F4FD;
|
||||
margin-right: 10px;
|
||||
color: #191f21;
|
||||
font-size: 17px;
|
||||
border: 1px solid #c7dce8;
|
||||
}
|
||||
.collapse-btn {
|
||||
float: right;
|
||||
padding: 0px 5px 1px 5px;
|
||||
border: solid 1px #96aebb;
|
||||
cursor: pointer;
|
||||
}
|
||||
.collapsed {
|
||||
display: none;
|
||||
}
|
||||
.source-code {
|
||||
font-family: courier;
|
||||
font-size: small;
|
||||
padding-bottom: 10px;
|
||||
}
|
||||
"""
|
||||
|
||||
JS = """
|
||||
<script type="text/javascript">
|
||||
function collapse(element){
|
||||
const frameId = element.getAttribute("data-frame-id");
|
||||
const frame = document.getElementById(frameId);
|
||||
|
||||
if (frame.classList.contains("collapsed")){
|
||||
element.innerHTML = "‒";
|
||||
frame.classList.remove("collapsed");
|
||||
} else {
|
||||
element.innerHTML = "+";
|
||||
frame.classList.add("collapsed");
|
||||
}
|
||||
}
|
||||
</script>
|
||||
"""
|
||||
|
||||
TEMPLATE = """
|
||||
<html>
|
||||
<head>
|
||||
<style type='text/css'>
|
||||
{styles}
|
||||
</style>
|
||||
<title>Starlette Debugger</title>
|
||||
</head>
|
||||
<body>
|
||||
<h1>500 Server Error</h1>
|
||||
<h2>{error}</h2>
|
||||
<div class="traceback-container">
|
||||
<p class="traceback-title">Traceback</p>
|
||||
<div>{exc_html}</div>
|
||||
</div>
|
||||
{js}
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
FRAME_TEMPLATE = """
|
||||
<div>
|
||||
<p class="frame-title">File <span class="frame-filename">{frame_filename}</span>,
|
||||
line <i>{frame_lineno}</i>,
|
||||
in <b>{frame_name}</b>
|
||||
<span class="collapse-btn" data-frame-id="{frame_filename}-{frame_lineno}" onclick="collapse(this)">{collapse_button}</span>
|
||||
</p>
|
||||
<div id="{frame_filename}-{frame_lineno}" class="source-code {collapsed}">{code_context}</div>
|
||||
</div>
|
||||
""" # noqa: E501
|
||||
|
||||
LINE = """
|
||||
<p><span class="frame-line">
|
||||
<span class="lineno">{lineno}.</span> {line}</span></p>
|
||||
"""
|
||||
|
||||
CENTER_LINE = """
|
||||
<p class="center-line"><span class="frame-line center-line">
|
||||
<span class="lineno">{lineno}.</span> {line}</span></p>
|
||||
"""
|
||||
|
||||
|
||||
class ServerErrorMiddleware:
|
||||
"""
|
||||
Handles returning 500 responses when a server error occurs.
|
||||
|
||||
If 'debug' is set, then traceback responses will be returned,
|
||||
otherwise the designated 'handler' will be called.
|
||||
|
||||
This middleware class should generally be used to wrap *everything*
|
||||
else up, so that unhandled exceptions anywhere in the stack
|
||||
always result in an appropriate 500 response.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
handler: typing.Callable[[Request, Exception], typing.Any] | None = None,
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.handler = handler
|
||||
self.debug = debug
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
response_started = False
|
||||
|
||||
async def _send(message: Message) -> None:
|
||||
nonlocal response_started, send
|
||||
|
||||
if message["type"] == "http.response.start":
|
||||
response_started = True
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
await self.app(scope, receive, _send)
|
||||
except Exception as exc:
|
||||
request = Request(scope)
|
||||
if self.debug:
|
||||
# In debug mode, return traceback responses.
|
||||
response = self.debug_response(request, exc)
|
||||
elif self.handler is None:
|
||||
# Use our default 500 error handler.
|
||||
response = self.error_response(request, exc)
|
||||
else:
|
||||
# Use an installed 500 error handler.
|
||||
if is_async_callable(self.handler):
|
||||
response = await self.handler(request, exc)
|
||||
else:
|
||||
response = await run_in_threadpool(self.handler, request, exc)
|
||||
|
||||
if not response_started:
|
||||
await response(scope, receive, send)
|
||||
|
||||
# We always continue to raise the exception.
|
||||
# This allows servers to log the error, or allows test clients
|
||||
# to optionally raise the error within the test case.
|
||||
raise exc
|
||||
|
||||
def format_line(self, index: int, line: str, frame_lineno: int, frame_index: int) -> str:
|
||||
values = {
|
||||
# HTML escape - line could contain < or >
|
||||
"line": html.escape(line).replace(" ", " "),
|
||||
"lineno": (frame_lineno - frame_index) + index,
|
||||
}
|
||||
|
||||
if index != frame_index:
|
||||
return LINE.format(**values)
|
||||
return CENTER_LINE.format(**values)
|
||||
|
||||
def generate_frame_html(self, frame: inspect.FrameInfo, is_collapsed: bool) -> str:
|
||||
code_context = "".join(
|
||||
self.format_line(
|
||||
index,
|
||||
line,
|
||||
frame.lineno,
|
||||
frame.index, # type: ignore[arg-type]
|
||||
)
|
||||
for index, line in enumerate(frame.code_context or [])
|
||||
)
|
||||
|
||||
values = {
|
||||
# HTML escape - filename could contain < or >, especially if it's a virtual
|
||||
# file e.g. <stdin> in the REPL
|
||||
"frame_filename": html.escape(frame.filename),
|
||||
"frame_lineno": frame.lineno,
|
||||
# HTML escape - if you try very hard it's possible to name a function with <
|
||||
# or >
|
||||
"frame_name": html.escape(frame.function),
|
||||
"code_context": code_context,
|
||||
"collapsed": "collapsed" if is_collapsed else "",
|
||||
"collapse_button": "+" if is_collapsed else "‒",
|
||||
}
|
||||
return FRAME_TEMPLATE.format(**values)
|
||||
|
||||
def generate_html(self, exc: Exception, limit: int = 7) -> str:
|
||||
traceback_obj = traceback.TracebackException.from_exception(exc, capture_locals=True)
|
||||
|
||||
exc_html = ""
|
||||
is_collapsed = False
|
||||
exc_traceback = exc.__traceback__
|
||||
if exc_traceback is not None:
|
||||
frames = inspect.getinnerframes(exc_traceback, limit)
|
||||
for frame in reversed(frames):
|
||||
exc_html += self.generate_frame_html(frame, is_collapsed)
|
||||
is_collapsed = True
|
||||
|
||||
if sys.version_info >= (3, 13): # pragma: no cover
|
||||
exc_type_str = traceback_obj.exc_type_str
|
||||
else: # pragma: no cover
|
||||
exc_type_str = traceback_obj.exc_type.__name__
|
||||
|
||||
# escape error class and text
|
||||
error = f"{html.escape(exc_type_str)}: {html.escape(str(traceback_obj))}"
|
||||
|
||||
return TEMPLATE.format(styles=STYLES, js=JS, error=error, exc_html=exc_html)
|
||||
|
||||
def generate_plain_text(self, exc: Exception) -> str:
|
||||
return "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
|
||||
|
||||
def debug_response(self, request: Request, exc: Exception) -> Response:
|
||||
accept = request.headers.get("accept", "")
|
||||
|
||||
if "text/html" in accept:
|
||||
content = self.generate_html(exc)
|
||||
return HTMLResponse(content, status_code=500)
|
||||
content = self.generate_plain_text(exc)
|
||||
return PlainTextResponse(content, status_code=500)
|
||||
|
||||
def error_response(self, request: Request, exc: Exception) -> Response:
|
||||
return PlainTextResponse("Internal Server Error", status_code=500)
|
@ -0,0 +1,72 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from starlette._exception_handler import (
|
||||
ExceptionHandlers,
|
||||
StatusHandlers,
|
||||
wrap_app_handling_exceptions,
|
||||
)
|
||||
from starlette.exceptions import HTTPException, WebSocketException
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import PlainTextResponse, Response
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
|
||||
class ExceptionMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
handlers: typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]] | None = None,
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.debug = debug # TODO: We ought to handle 404 cases if debug is set.
|
||||
self._status_handlers: StatusHandlers = {}
|
||||
self._exception_handlers: ExceptionHandlers = {
|
||||
HTTPException: self.http_exception,
|
||||
WebSocketException: self.websocket_exception,
|
||||
}
|
||||
if handlers is not None: # pragma: no branch
|
||||
for key, value in handlers.items():
|
||||
self.add_exception_handler(key, value)
|
||||
|
||||
def add_exception_handler(
|
||||
self,
|
||||
exc_class_or_status_code: int | type[Exception],
|
||||
handler: typing.Callable[[Request, Exception], Response],
|
||||
) -> None:
|
||||
if isinstance(exc_class_or_status_code, int):
|
||||
self._status_handlers[exc_class_or_status_code] = handler
|
||||
else:
|
||||
assert issubclass(exc_class_or_status_code, Exception)
|
||||
self._exception_handlers[exc_class_or_status_code] = handler
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] not in ("http", "websocket"):
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
scope["starlette.exception_handlers"] = (
|
||||
self._exception_handlers,
|
||||
self._status_handlers,
|
||||
)
|
||||
|
||||
conn: Request | WebSocket
|
||||
if scope["type"] == "http":
|
||||
conn = Request(scope, receive, send)
|
||||
else:
|
||||
conn = WebSocket(scope, receive, send)
|
||||
|
||||
await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
|
||||
|
||||
def http_exception(self, request: Request, exc: Exception) -> Response:
|
||||
assert isinstance(exc, HTTPException)
|
||||
if exc.status_code in {204, 304}:
|
||||
return Response(status_code=exc.status_code, headers=exc.headers)
|
||||
return PlainTextResponse(exc.detail, status_code=exc.status_code, headers=exc.headers)
|
||||
|
||||
async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None:
|
||||
assert isinstance(exc, WebSocketException)
|
||||
await websocket.close(code=exc.code, reason=exc.reason) # pragma: no cover
|
141
venv/lib/python3.11/site-packages/starlette/middleware/gzip.py
Normal file
141
venv/lib/python3.11/site-packages/starlette/middleware/gzip.py
Normal file
@ -0,0 +1,141 @@
|
||||
import gzip
|
||||
import io
|
||||
import typing
|
||||
|
||||
from starlette.datastructures import Headers, MutableHeaders
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
DEFAULT_EXCLUDED_CONTENT_TYPES = ("text/event-stream",)
|
||||
|
||||
|
||||
class GZipMiddleware:
|
||||
def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9) -> None:
|
||||
self.app = app
|
||||
self.minimum_size = minimum_size
|
||||
self.compresslevel = compresslevel
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http": # pragma: no cover
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
headers = Headers(scope=scope)
|
||||
responder: ASGIApp
|
||||
if "gzip" in headers.get("Accept-Encoding", ""):
|
||||
responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel)
|
||||
else:
|
||||
responder = IdentityResponder(self.app, self.minimum_size)
|
||||
|
||||
await responder(scope, receive, send)
|
||||
|
||||
|
||||
class IdentityResponder:
|
||||
content_encoding: str
|
||||
|
||||
def __init__(self, app: ASGIApp, minimum_size: int) -> None:
|
||||
self.app = app
|
||||
self.minimum_size = minimum_size
|
||||
self.send: Send = unattached_send
|
||||
self.initial_message: Message = {}
|
||||
self.started = False
|
||||
self.content_encoding_set = False
|
||||
self.content_type_is_excluded = False
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
self.send = send
|
||||
await self.app(scope, receive, self.send_with_compression)
|
||||
|
||||
async def send_with_compression(self, message: Message) -> None:
|
||||
message_type = message["type"]
|
||||
if message_type == "http.response.start":
|
||||
# Don't send the initial message until we've determined how to
|
||||
# modify the outgoing headers correctly.
|
||||
self.initial_message = message
|
||||
headers = Headers(raw=self.initial_message["headers"])
|
||||
self.content_encoding_set = "content-encoding" in headers
|
||||
self.content_type_is_excluded = headers.get("content-type", "").startswith(DEFAULT_EXCLUDED_CONTENT_TYPES)
|
||||
elif message_type == "http.response.body" and (self.content_encoding_set or self.content_type_is_excluded):
|
||||
if not self.started:
|
||||
self.started = True
|
||||
await self.send(self.initial_message)
|
||||
await self.send(message)
|
||||
elif message_type == "http.response.body" and not self.started:
|
||||
self.started = True
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
if len(body) < self.minimum_size and not more_body:
|
||||
# Don't apply compression to small outgoing responses.
|
||||
await self.send(self.initial_message)
|
||||
await self.send(message)
|
||||
elif not more_body:
|
||||
# Standard response.
|
||||
body = self.apply_compression(body, more_body=False)
|
||||
|
||||
headers = MutableHeaders(raw=self.initial_message["headers"])
|
||||
headers.add_vary_header("Accept-Encoding")
|
||||
if body != message["body"]:
|
||||
headers["Content-Encoding"] = self.content_encoding
|
||||
headers["Content-Length"] = str(len(body))
|
||||
message["body"] = body
|
||||
|
||||
await self.send(self.initial_message)
|
||||
await self.send(message)
|
||||
else:
|
||||
# Initial body in streaming response.
|
||||
body = self.apply_compression(body, more_body=True)
|
||||
|
||||
headers = MutableHeaders(raw=self.initial_message["headers"])
|
||||
headers.add_vary_header("Accept-Encoding")
|
||||
if body != message["body"]:
|
||||
headers["Content-Encoding"] = self.content_encoding
|
||||
del headers["Content-Length"]
|
||||
message["body"] = body
|
||||
|
||||
await self.send(self.initial_message)
|
||||
await self.send(message)
|
||||
elif message_type == "http.response.body": # pragma: no branch
|
||||
# Remaining body in streaming response.
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
|
||||
message["body"] = self.apply_compression(body, more_body=more_body)
|
||||
|
||||
await self.send(message)
|
||||
|
||||
def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
|
||||
"""Apply compression on the response body.
|
||||
|
||||
If more_body is False, any compression file should be closed. If it
|
||||
isn't, it won't be closed automatically until all background tasks
|
||||
complete.
|
||||
"""
|
||||
return body
|
||||
|
||||
|
||||
class GZipResponder(IdentityResponder):
|
||||
content_encoding = "gzip"
|
||||
|
||||
def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None:
|
||||
super().__init__(app, minimum_size)
|
||||
|
||||
self.gzip_buffer = io.BytesIO()
|
||||
self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
with self.gzip_buffer, self.gzip_file:
|
||||
await super().__call__(scope, receive, send)
|
||||
|
||||
def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
|
||||
self.gzip_file.write(body)
|
||||
if not more_body:
|
||||
self.gzip_file.close()
|
||||
|
||||
body = self.gzip_buffer.getvalue()
|
||||
self.gzip_buffer.seek(0)
|
||||
self.gzip_buffer.truncate()
|
||||
|
||||
return body
|
||||
|
||||
|
||||
async def unattached_send(message: Message) -> typing.NoReturn:
|
||||
raise RuntimeError("send awaitable not set") # pragma: no cover
|
@ -0,0 +1,19 @@
|
||||
from starlette.datastructures import URL
|
||||
from starlette.responses import RedirectResponse
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
|
||||
class HTTPSRedirectMiddleware:
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] in ("http", "websocket") and scope["scheme"] in ("http", "ws"):
|
||||
url = URL(scope=scope)
|
||||
redirect_scheme = {"http": "https", "ws": "wss"}[url.scheme]
|
||||
netloc = url.hostname if url.port in (80, 443) else url.netloc
|
||||
url = url.replace(scheme=redirect_scheme, netloc=netloc)
|
||||
response = RedirectResponse(url, status_code=307)
|
||||
await response(scope, receive, send)
|
||||
else:
|
||||
await self.app(scope, receive, send)
|
@ -0,0 +1,85 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import typing
|
||||
from base64 import b64decode, b64encode
|
||||
|
||||
import itsdangerous
|
||||
from itsdangerous.exc import BadSignature
|
||||
|
||||
from starlette.datastructures import MutableHeaders, Secret
|
||||
from starlette.requests import HTTPConnection
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
|
||||
class SessionMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
secret_key: str | Secret,
|
||||
session_cookie: str = "session",
|
||||
max_age: int | None = 14 * 24 * 60 * 60, # 14 days, in seconds
|
||||
path: str = "/",
|
||||
same_site: typing.Literal["lax", "strict", "none"] = "lax",
|
||||
https_only: bool = False,
|
||||
domain: str | None = None,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.signer = itsdangerous.TimestampSigner(str(secret_key))
|
||||
self.session_cookie = session_cookie
|
||||
self.max_age = max_age
|
||||
self.path = path
|
||||
self.security_flags = "httponly; samesite=" + same_site
|
||||
if https_only: # Secure flag can be used with HTTPS only
|
||||
self.security_flags += "; secure"
|
||||
if domain is not None:
|
||||
self.security_flags += f"; domain={domain}"
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] not in ("http", "websocket"): # pragma: no cover
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
connection = HTTPConnection(scope)
|
||||
initial_session_was_empty = True
|
||||
|
||||
if self.session_cookie in connection.cookies:
|
||||
data = connection.cookies[self.session_cookie].encode("utf-8")
|
||||
try:
|
||||
data = self.signer.unsign(data, max_age=self.max_age)
|
||||
scope["session"] = json.loads(b64decode(data))
|
||||
initial_session_was_empty = False
|
||||
except BadSignature:
|
||||
scope["session"] = {}
|
||||
else:
|
||||
scope["session"] = {}
|
||||
|
||||
async def send_wrapper(message: Message) -> None:
|
||||
if message["type"] == "http.response.start":
|
||||
if scope["session"]:
|
||||
# We have session data to persist.
|
||||
data = b64encode(json.dumps(scope["session"]).encode("utf-8"))
|
||||
data = self.signer.sign(data)
|
||||
headers = MutableHeaders(scope=message)
|
||||
header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format(
|
||||
session_cookie=self.session_cookie,
|
||||
data=data.decode("utf-8"),
|
||||
path=self.path,
|
||||
max_age=f"Max-Age={self.max_age}; " if self.max_age else "",
|
||||
security_flags=self.security_flags,
|
||||
)
|
||||
headers.append("Set-Cookie", header_value)
|
||||
elif not initial_session_was_empty:
|
||||
# The session has been cleared.
|
||||
headers = MutableHeaders(scope=message)
|
||||
header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format(
|
||||
session_cookie=self.session_cookie,
|
||||
data="null",
|
||||
path=self.path,
|
||||
expires="expires=Thu, 01 Jan 1970 00:00:00 GMT; ",
|
||||
security_flags=self.security_flags,
|
||||
)
|
||||
headers.append("Set-Cookie", header_value)
|
||||
await send(message)
|
||||
|
||||
await self.app(scope, receive, send_wrapper)
|
@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from starlette.datastructures import URL, Headers
|
||||
from starlette.responses import PlainTextResponse, RedirectResponse, Response
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'."
|
||||
|
||||
|
||||
class TrustedHostMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
allowed_hosts: typing.Sequence[str] | None = None,
|
||||
www_redirect: bool = True,
|
||||
) -> None:
|
||||
if allowed_hosts is None:
|
||||
allowed_hosts = ["*"]
|
||||
|
||||
for pattern in allowed_hosts:
|
||||
assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD
|
||||
if pattern.startswith("*") and pattern != "*":
|
||||
assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD
|
||||
self.app = app
|
||||
self.allowed_hosts = list(allowed_hosts)
|
||||
self.allow_any = "*" in allowed_hosts
|
||||
self.www_redirect = www_redirect
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if self.allow_any or scope["type"] not in (
|
||||
"http",
|
||||
"websocket",
|
||||
): # pragma: no cover
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
headers = Headers(scope=scope)
|
||||
host = headers.get("host", "").split(":")[0]
|
||||
is_valid_host = False
|
||||
found_www_redirect = False
|
||||
for pattern in self.allowed_hosts:
|
||||
if host == pattern or (pattern.startswith("*") and host.endswith(pattern[1:])):
|
||||
is_valid_host = True
|
||||
break
|
||||
elif "www." + host == pattern:
|
||||
found_www_redirect = True
|
||||
|
||||
if is_valid_host:
|
||||
await self.app(scope, receive, send)
|
||||
else:
|
||||
response: Response
|
||||
if found_www_redirect and self.www_redirect:
|
||||
url = URL(scope=scope)
|
||||
redirect_url = url.replace(netloc="www." + url.netloc)
|
||||
response = RedirectResponse(url=str(redirect_url))
|
||||
else:
|
||||
response = PlainTextResponse("Invalid host header", status_code=400)
|
||||
await response(scope, receive, send)
|
152
venv/lib/python3.11/site-packages/starlette/middleware/wsgi.py
Normal file
152
venv/lib/python3.11/site-packages/starlette/middleware/wsgi.py
Normal file
@ -0,0 +1,152 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import math
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
|
||||
import anyio
|
||||
from anyio.abc import ObjectReceiveStream, ObjectSendStream
|
||||
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
warnings.warn(
|
||||
"starlette.middleware.wsgi is deprecated and will be removed in a future release. "
|
||||
"Please refer to https://github.com/abersheeran/a2wsgi as a replacement.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
|
||||
def build_environ(scope: Scope, body: bytes) -> dict[str, typing.Any]:
|
||||
"""
|
||||
Builds a scope and request body into a WSGI environ object.
|
||||
"""
|
||||
|
||||
script_name = scope.get("root_path", "").encode("utf8").decode("latin1")
|
||||
path_info = scope["path"].encode("utf8").decode("latin1")
|
||||
if path_info.startswith(script_name):
|
||||
path_info = path_info[len(script_name) :]
|
||||
|
||||
environ = {
|
||||
"REQUEST_METHOD": scope["method"],
|
||||
"SCRIPT_NAME": script_name,
|
||||
"PATH_INFO": path_info,
|
||||
"QUERY_STRING": scope["query_string"].decode("ascii"),
|
||||
"SERVER_PROTOCOL": f"HTTP/{scope['http_version']}",
|
||||
"wsgi.version": (1, 0),
|
||||
"wsgi.url_scheme": scope.get("scheme", "http"),
|
||||
"wsgi.input": io.BytesIO(body),
|
||||
"wsgi.errors": sys.stdout,
|
||||
"wsgi.multithread": True,
|
||||
"wsgi.multiprocess": True,
|
||||
"wsgi.run_once": False,
|
||||
}
|
||||
|
||||
# Get server name and port - required in WSGI, not in ASGI
|
||||
server = scope.get("server") or ("localhost", 80)
|
||||
environ["SERVER_NAME"] = server[0]
|
||||
environ["SERVER_PORT"] = server[1]
|
||||
|
||||
# Get client IP address
|
||||
if scope.get("client"):
|
||||
environ["REMOTE_ADDR"] = scope["client"][0]
|
||||
|
||||
# Go through headers and make them into environ entries
|
||||
for name, value in scope.get("headers", []):
|
||||
name = name.decode("latin1")
|
||||
if name == "content-length":
|
||||
corrected_name = "CONTENT_LENGTH"
|
||||
elif name == "content-type":
|
||||
corrected_name = "CONTENT_TYPE"
|
||||
else:
|
||||
corrected_name = f"HTTP_{name}".upper().replace("-", "_")
|
||||
# HTTPbis say only ASCII chars are allowed in headers, but we latin1 just in
|
||||
# case
|
||||
value = value.decode("latin1")
|
||||
if corrected_name in environ:
|
||||
value = environ[corrected_name] + "," + value
|
||||
environ[corrected_name] = value
|
||||
return environ
|
||||
|
||||
|
||||
class WSGIMiddleware:
|
||||
def __init__(self, app: typing.Callable[..., typing.Any]) -> None:
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
assert scope["type"] == "http"
|
||||
responder = WSGIResponder(self.app, scope)
|
||||
await responder(receive, send)
|
||||
|
||||
|
||||
class WSGIResponder:
|
||||
stream_send: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
|
||||
stream_receive: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
|
||||
|
||||
def __init__(self, app: typing.Callable[..., typing.Any], scope: Scope) -> None:
|
||||
self.app = app
|
||||
self.scope = scope
|
||||
self.status = None
|
||||
self.response_headers = None
|
||||
self.stream_send, self.stream_receive = anyio.create_memory_object_stream(math.inf)
|
||||
self.response_started = False
|
||||
self.exc_info: typing.Any = None
|
||||
|
||||
async def __call__(self, receive: Receive, send: Send) -> None:
|
||||
body = b""
|
||||
more_body = True
|
||||
while more_body:
|
||||
message = await receive()
|
||||
body += message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
environ = build_environ(self.scope, body)
|
||||
|
||||
async with anyio.create_task_group() as task_group:
|
||||
task_group.start_soon(self.sender, send)
|
||||
async with self.stream_send:
|
||||
await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response)
|
||||
if self.exc_info is not None:
|
||||
raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2])
|
||||
|
||||
async def sender(self, send: Send) -> None:
|
||||
async with self.stream_receive:
|
||||
async for message in self.stream_receive:
|
||||
await send(message)
|
||||
|
||||
def start_response(
|
||||
self,
|
||||
status: str,
|
||||
response_headers: list[tuple[str, str]],
|
||||
exc_info: typing.Any = None,
|
||||
) -> None:
|
||||
self.exc_info = exc_info
|
||||
if not self.response_started: # pragma: no branch
|
||||
self.response_started = True
|
||||
status_code_string, _ = status.split(" ", 1)
|
||||
status_code = int(status_code_string)
|
||||
headers = [
|
||||
(name.strip().encode("ascii").lower(), value.strip().encode("ascii"))
|
||||
for name, value in response_headers
|
||||
]
|
||||
anyio.from_thread.run(
|
||||
self.stream_send.send,
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": status_code,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
|
||||
def wsgi(
|
||||
self,
|
||||
environ: dict[str, typing.Any],
|
||||
start_response: typing.Callable[..., typing.Any],
|
||||
) -> None:
|
||||
for chunk in self.app(environ, start_response):
|
||||
anyio.from_thread.run(
|
||||
self.stream_send.send,
|
||||
{"type": "http.response.body", "body": chunk, "more_body": True},
|
||||
)
|
||||
|
||||
anyio.from_thread.run(self.stream_send.send, {"type": "http.response.body", "body": b""})
|
322
venv/lib/python3.11/site-packages/starlette/requests.py
Normal file
322
venv/lib/python3.11/site-packages/starlette/requests.py
Normal file
@ -0,0 +1,322 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import typing
|
||||
from http import cookies as http_cookies
|
||||
|
||||
import anyio
|
||||
|
||||
from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper
|
||||
from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.formparsers import FormParser, MultiPartException, MultiPartParser
|
||||
from starlette.types import Message, Receive, Scope, Send
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from python_multipart.multipart import parse_options_header
|
||||
|
||||
from starlette.applications import Starlette
|
||||
from starlette.routing import Router
|
||||
else:
|
||||
try:
|
||||
try:
|
||||
from python_multipart.multipart import parse_options_header
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
from multipart.multipart import parse_options_header
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
parse_options_header = None
|
||||
|
||||
|
||||
SERVER_PUSH_HEADERS_TO_COPY = {
|
||||
"accept",
|
||||
"accept-encoding",
|
||||
"accept-language",
|
||||
"cache-control",
|
||||
"user-agent",
|
||||
}
|
||||
|
||||
|
||||
def cookie_parser(cookie_string: str) -> dict[str, str]:
|
||||
"""
|
||||
This function parses a ``Cookie`` HTTP header into a dict of key/value pairs.
|
||||
|
||||
It attempts to mimic browser cookie parsing behavior: browsers and web servers
|
||||
frequently disregard the spec (RFC 6265) when setting and reading cookies,
|
||||
so we attempt to suit the common scenarios here.
|
||||
|
||||
This function has been adapted from Django 3.1.0.
|
||||
Note: we are explicitly _NOT_ using `SimpleCookie.load` because it is based
|
||||
on an outdated spec and will fail on lots of input we want to support
|
||||
"""
|
||||
cookie_dict: dict[str, str] = {}
|
||||
for chunk in cookie_string.split(";"):
|
||||
if "=" in chunk:
|
||||
key, val = chunk.split("=", 1)
|
||||
else:
|
||||
# Assume an empty name per
|
||||
# https://bugzilla.mozilla.org/show_bug.cgi?id=169091
|
||||
key, val = "", chunk
|
||||
key, val = key.strip(), val.strip()
|
||||
if key or val:
|
||||
# unquote using Python's algorithm.
|
||||
cookie_dict[key] = http_cookies._unquote(val)
|
||||
return cookie_dict
|
||||
|
||||
|
||||
class ClientDisconnect(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class HTTPConnection(typing.Mapping[str, typing.Any]):
|
||||
"""
|
||||
A base class for incoming HTTP connections, that is used to provide
|
||||
any functionality that is common to both `Request` and `WebSocket`.
|
||||
"""
|
||||
|
||||
def __init__(self, scope: Scope, receive: Receive | None = None) -> None:
|
||||
assert scope["type"] in ("http", "websocket")
|
||||
self.scope = scope
|
||||
|
||||
def __getitem__(self, key: str) -> typing.Any:
|
||||
return self.scope[key]
|
||||
|
||||
def __iter__(self) -> typing.Iterator[str]:
|
||||
return iter(self.scope)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.scope)
|
||||
|
||||
# Don't use the `abc.Mapping.__eq__` implementation.
|
||||
# Connection instances should never be considered equal
|
||||
# unless `self is other`.
|
||||
__eq__ = object.__eq__
|
||||
__hash__ = object.__hash__
|
||||
|
||||
@property
|
||||
def app(self) -> typing.Any:
|
||||
return self.scope["app"]
|
||||
|
||||
@property
|
||||
def url(self) -> URL:
|
||||
if not hasattr(self, "_url"): # pragma: no branch
|
||||
self._url = URL(scope=self.scope)
|
||||
return self._url
|
||||
|
||||
@property
|
||||
def base_url(self) -> URL:
|
||||
if not hasattr(self, "_base_url"):
|
||||
base_url_scope = dict(self.scope)
|
||||
# This is used by request.url_for, it might be used inside a Mount which
|
||||
# would have its own child scope with its own root_path, but the base URL
|
||||
# for url_for should still be the top level app root path.
|
||||
app_root_path = base_url_scope.get("app_root_path", base_url_scope.get("root_path", ""))
|
||||
path = app_root_path
|
||||
if not path.endswith("/"):
|
||||
path += "/"
|
||||
base_url_scope["path"] = path
|
||||
base_url_scope["query_string"] = b""
|
||||
base_url_scope["root_path"] = app_root_path
|
||||
self._base_url = URL(scope=base_url_scope)
|
||||
return self._base_url
|
||||
|
||||
@property
|
||||
def headers(self) -> Headers:
|
||||
if not hasattr(self, "_headers"):
|
||||
self._headers = Headers(scope=self.scope)
|
||||
return self._headers
|
||||
|
||||
@property
|
||||
def query_params(self) -> QueryParams:
|
||||
if not hasattr(self, "_query_params"): # pragma: no branch
|
||||
self._query_params = QueryParams(self.scope["query_string"])
|
||||
return self._query_params
|
||||
|
||||
@property
|
||||
def path_params(self) -> dict[str, typing.Any]:
|
||||
return self.scope.get("path_params", {})
|
||||
|
||||
@property
|
||||
def cookies(self) -> dict[str, str]:
|
||||
if not hasattr(self, "_cookies"):
|
||||
cookies: dict[str, str] = {}
|
||||
cookie_header = self.headers.get("cookie")
|
||||
|
||||
if cookie_header:
|
||||
cookies = cookie_parser(cookie_header)
|
||||
self._cookies = cookies
|
||||
return self._cookies
|
||||
|
||||
@property
|
||||
def client(self) -> Address | None:
|
||||
# client is a 2 item tuple of (host, port), None if missing
|
||||
host_port = self.scope.get("client")
|
||||
if host_port is not None:
|
||||
return Address(*host_port)
|
||||
return None
|
||||
|
||||
@property
|
||||
def session(self) -> dict[str, typing.Any]:
|
||||
assert "session" in self.scope, "SessionMiddleware must be installed to access request.session"
|
||||
return self.scope["session"] # type: ignore[no-any-return]
|
||||
|
||||
@property
|
||||
def auth(self) -> typing.Any:
|
||||
assert "auth" in self.scope, "AuthenticationMiddleware must be installed to access request.auth"
|
||||
return self.scope["auth"]
|
||||
|
||||
@property
|
||||
def user(self) -> typing.Any:
|
||||
assert "user" in self.scope, "AuthenticationMiddleware must be installed to access request.user"
|
||||
return self.scope["user"]
|
||||
|
||||
@property
|
||||
def state(self) -> State:
|
||||
if not hasattr(self, "_state"):
|
||||
# Ensure 'state' has an empty dict if it's not already populated.
|
||||
self.scope.setdefault("state", {})
|
||||
# Create a state instance with a reference to the dict in which it should
|
||||
# store info
|
||||
self._state = State(self.scope["state"])
|
||||
return self._state
|
||||
|
||||
def url_for(self, name: str, /, **path_params: typing.Any) -> URL:
|
||||
url_path_provider: Router | Starlette | None = self.scope.get("router") or self.scope.get("app")
|
||||
if url_path_provider is None:
|
||||
raise RuntimeError("The `url_for` method can only be used inside a Starlette application or with a router.")
|
||||
url_path = url_path_provider.url_path_for(name, **path_params)
|
||||
return url_path.make_absolute_url(base_url=self.base_url)
|
||||
|
||||
|
||||
async def empty_receive() -> typing.NoReturn:
|
||||
raise RuntimeError("Receive channel has not been made available")
|
||||
|
||||
|
||||
async def empty_send(message: Message) -> typing.NoReturn:
|
||||
raise RuntimeError("Send channel has not been made available")
|
||||
|
||||
|
||||
class Request(HTTPConnection):
|
||||
_form: FormData | None
|
||||
|
||||
def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send):
|
||||
super().__init__(scope)
|
||||
assert scope["type"] == "http"
|
||||
self._receive = receive
|
||||
self._send = send
|
||||
self._stream_consumed = False
|
||||
self._is_disconnected = False
|
||||
self._form = None
|
||||
|
||||
@property
|
||||
def method(self) -> str:
|
||||
return typing.cast(str, self.scope["method"])
|
||||
|
||||
@property
|
||||
def receive(self) -> Receive:
|
||||
return self._receive
|
||||
|
||||
async def stream(self) -> typing.AsyncGenerator[bytes, None]:
|
||||
if hasattr(self, "_body"):
|
||||
yield self._body
|
||||
yield b""
|
||||
return
|
||||
if self._stream_consumed:
|
||||
raise RuntimeError("Stream consumed")
|
||||
while not self._stream_consumed:
|
||||
message = await self._receive()
|
||||
if message["type"] == "http.request":
|
||||
body = message.get("body", b"")
|
||||
if not message.get("more_body", False):
|
||||
self._stream_consumed = True
|
||||
if body:
|
||||
yield body
|
||||
elif message["type"] == "http.disconnect": # pragma: no branch
|
||||
self._is_disconnected = True
|
||||
raise ClientDisconnect()
|
||||
yield b""
|
||||
|
||||
async def body(self) -> bytes:
|
||||
if not hasattr(self, "_body"):
|
||||
chunks: list[bytes] = []
|
||||
async for chunk in self.stream():
|
||||
chunks.append(chunk)
|
||||
self._body = b"".join(chunks)
|
||||
return self._body
|
||||
|
||||
async def json(self) -> typing.Any:
|
||||
if not hasattr(self, "_json"): # pragma: no branch
|
||||
body = await self.body()
|
||||
self._json = json.loads(body)
|
||||
return self._json
|
||||
|
||||
async def _get_form(
|
||||
self,
|
||||
*,
|
||||
max_files: int | float = 1000,
|
||||
max_fields: int | float = 1000,
|
||||
max_part_size: int = 1024 * 1024,
|
||||
) -> FormData:
|
||||
if self._form is None: # pragma: no branch
|
||||
assert parse_options_header is not None, (
|
||||
"The `python-multipart` library must be installed to use form parsing."
|
||||
)
|
||||
content_type_header = self.headers.get("Content-Type")
|
||||
content_type: bytes
|
||||
content_type, _ = parse_options_header(content_type_header)
|
||||
if content_type == b"multipart/form-data":
|
||||
try:
|
||||
multipart_parser = MultiPartParser(
|
||||
self.headers,
|
||||
self.stream(),
|
||||
max_files=max_files,
|
||||
max_fields=max_fields,
|
||||
max_part_size=max_part_size,
|
||||
)
|
||||
self._form = await multipart_parser.parse()
|
||||
except MultiPartException as exc:
|
||||
if "app" in self.scope:
|
||||
raise HTTPException(status_code=400, detail=exc.message)
|
||||
raise exc
|
||||
elif content_type == b"application/x-www-form-urlencoded":
|
||||
form_parser = FormParser(self.headers, self.stream())
|
||||
self._form = await form_parser.parse()
|
||||
else:
|
||||
self._form = FormData()
|
||||
return self._form
|
||||
|
||||
def form(
|
||||
self,
|
||||
*,
|
||||
max_files: int | float = 1000,
|
||||
max_fields: int | float = 1000,
|
||||
max_part_size: int = 1024 * 1024,
|
||||
) -> AwaitableOrContextManager[FormData]:
|
||||
return AwaitableOrContextManagerWrapper(
|
||||
self._get_form(max_files=max_files, max_fields=max_fields, max_part_size=max_part_size)
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._form is not None: # pragma: no branch
|
||||
await self._form.close()
|
||||
|
||||
async def is_disconnected(self) -> bool:
|
||||
if not self._is_disconnected:
|
||||
message: Message = {}
|
||||
|
||||
# If message isn't immediately available, move on
|
||||
with anyio.CancelScope() as cs:
|
||||
cs.cancel()
|
||||
message = await self._receive()
|
||||
|
||||
if message.get("type") == "http.disconnect":
|
||||
self._is_disconnected = True
|
||||
|
||||
return self._is_disconnected
|
||||
|
||||
async def send_push_promise(self, path: str) -> None:
|
||||
if "http.response.push" in self.scope.get("extensions", {}):
|
||||
raw_headers: list[tuple[bytes, bytes]] = []
|
||||
for name in SERVER_PUSH_HEADERS_TO_COPY:
|
||||
for value in self.headers.getlist(name):
|
||||
raw_headers.append((name.encode("latin-1"), value.encode("latin-1")))
|
||||
await self._send({"type": "http.response.push", "path": path, "headers": raw_headers})
|
536
venv/lib/python3.11/site-packages/starlette/responses.py
Normal file
536
venv/lib/python3.11/site-packages/starlette/responses.py
Normal file
@ -0,0 +1,536 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import http.cookies
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import stat
|
||||
import typing
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
from email.utils import format_datetime, formatdate
|
||||
from functools import partial
|
||||
from mimetypes import guess_type
|
||||
from secrets import token_hex
|
||||
from urllib.parse import quote
|
||||
|
||||
import anyio
|
||||
import anyio.to_thread
|
||||
|
||||
from starlette._utils import collapse_excgroups
|
||||
from starlette.background import BackgroundTask
|
||||
from starlette.concurrency import iterate_in_threadpool
|
||||
from starlette.datastructures import URL, Headers, MutableHeaders
|
||||
from starlette.requests import ClientDisconnect
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
|
||||
class Response:
|
||||
media_type = None
|
||||
charset = "utf-8"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content: typing.Any = None,
|
||||
status_code: int = 200,
|
||||
headers: typing.Mapping[str, str] | None = None,
|
||||
media_type: str | None = None,
|
||||
background: BackgroundTask | None = None,
|
||||
) -> None:
|
||||
self.status_code = status_code
|
||||
if media_type is not None:
|
||||
self.media_type = media_type
|
||||
self.background = background
|
||||
self.body = self.render(content)
|
||||
self.init_headers(headers)
|
||||
|
||||
def render(self, content: typing.Any) -> bytes | memoryview:
|
||||
if content is None:
|
||||
return b""
|
||||
if isinstance(content, (bytes, memoryview)):
|
||||
return content
|
||||
return content.encode(self.charset) # type: ignore
|
||||
|
||||
def init_headers(self, headers: typing.Mapping[str, str] | None = None) -> None:
|
||||
if headers is None:
|
||||
raw_headers: list[tuple[bytes, bytes]] = []
|
||||
populate_content_length = True
|
||||
populate_content_type = True
|
||||
else:
|
||||
raw_headers = [(k.lower().encode("latin-1"), v.encode("latin-1")) for k, v in headers.items()]
|
||||
keys = [h[0] for h in raw_headers]
|
||||
populate_content_length = b"content-length" not in keys
|
||||
populate_content_type = b"content-type" not in keys
|
||||
|
||||
body = getattr(self, "body", None)
|
||||
if (
|
||||
body is not None
|
||||
and populate_content_length
|
||||
and not (self.status_code < 200 or self.status_code in (204, 304))
|
||||
):
|
||||
content_length = str(len(body))
|
||||
raw_headers.append((b"content-length", content_length.encode("latin-1")))
|
||||
|
||||
content_type = self.media_type
|
||||
if content_type is not None and populate_content_type:
|
||||
if content_type.startswith("text/") and "charset=" not in content_type.lower():
|
||||
content_type += "; charset=" + self.charset
|
||||
raw_headers.append((b"content-type", content_type.encode("latin-1")))
|
||||
|
||||
self.raw_headers = raw_headers
|
||||
|
||||
@property
|
||||
def headers(self) -> MutableHeaders:
|
||||
if not hasattr(self, "_headers"):
|
||||
self._headers = MutableHeaders(raw=self.raw_headers)
|
||||
return self._headers
|
||||
|
||||
def set_cookie(
|
||||
self,
|
||||
key: str,
|
||||
value: str = "",
|
||||
max_age: int | None = None,
|
||||
expires: datetime | str | int | None = None,
|
||||
path: str | None = "/",
|
||||
domain: str | None = None,
|
||||
secure: bool = False,
|
||||
httponly: bool = False,
|
||||
samesite: typing.Literal["lax", "strict", "none"] | None = "lax",
|
||||
) -> None:
|
||||
cookie: http.cookies.BaseCookie[str] = http.cookies.SimpleCookie()
|
||||
cookie[key] = value
|
||||
if max_age is not None:
|
||||
cookie[key]["max-age"] = max_age
|
||||
if expires is not None:
|
||||
if isinstance(expires, datetime):
|
||||
cookie[key]["expires"] = format_datetime(expires, usegmt=True)
|
||||
else:
|
||||
cookie[key]["expires"] = expires
|
||||
if path is not None:
|
||||
cookie[key]["path"] = path
|
||||
if domain is not None:
|
||||
cookie[key]["domain"] = domain
|
||||
if secure:
|
||||
cookie[key]["secure"] = True
|
||||
if httponly:
|
||||
cookie[key]["httponly"] = True
|
||||
if samesite is not None:
|
||||
assert samesite.lower() in [
|
||||
"strict",
|
||||
"lax",
|
||||
"none",
|
||||
], "samesite must be either 'strict', 'lax' or 'none'"
|
||||
cookie[key]["samesite"] = samesite
|
||||
cookie_val = cookie.output(header="").strip()
|
||||
self.raw_headers.append((b"set-cookie", cookie_val.encode("latin-1")))
|
||||
|
||||
def delete_cookie(
|
||||
self,
|
||||
key: str,
|
||||
path: str = "/",
|
||||
domain: str | None = None,
|
||||
secure: bool = False,
|
||||
httponly: bool = False,
|
||||
samesite: typing.Literal["lax", "strict", "none"] | None = "lax",
|
||||
) -> None:
|
||||
self.set_cookie(
|
||||
key,
|
||||
max_age=0,
|
||||
expires=0,
|
||||
path=path,
|
||||
domain=domain,
|
||||
secure=secure,
|
||||
httponly=httponly,
|
||||
samesite=samesite,
|
||||
)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
prefix = "websocket." if scope["type"] == "websocket" else ""
|
||||
await send(
|
||||
{
|
||||
"type": prefix + "http.response.start",
|
||||
"status": self.status_code,
|
||||
"headers": self.raw_headers,
|
||||
}
|
||||
)
|
||||
await send({"type": prefix + "http.response.body", "body": self.body})
|
||||
|
||||
if self.background is not None:
|
||||
await self.background()
|
||||
|
||||
|
||||
class HTMLResponse(Response):
|
||||
media_type = "text/html"
|
||||
|
||||
|
||||
class PlainTextResponse(Response):
|
||||
media_type = "text/plain"
|
||||
|
||||
|
||||
class JSONResponse(Response):
|
||||
media_type = "application/json"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content: typing.Any,
|
||||
status_code: int = 200,
|
||||
headers: typing.Mapping[str, str] | None = None,
|
||||
media_type: str | None = None,
|
||||
background: BackgroundTask | None = None,
|
||||
) -> None:
|
||||
super().__init__(content, status_code, headers, media_type, background)
|
||||
|
||||
def render(self, content: typing.Any) -> bytes:
|
||||
return json.dumps(
|
||||
content,
|
||||
ensure_ascii=False,
|
||||
allow_nan=False,
|
||||
indent=None,
|
||||
separators=(",", ":"),
|
||||
).encode("utf-8")
|
||||
|
||||
|
||||
class RedirectResponse(Response):
|
||||
def __init__(
|
||||
self,
|
||||
url: str | URL,
|
||||
status_code: int = 307,
|
||||
headers: typing.Mapping[str, str] | None = None,
|
||||
background: BackgroundTask | None = None,
|
||||
) -> None:
|
||||
super().__init__(content=b"", status_code=status_code, headers=headers, background=background)
|
||||
self.headers["location"] = quote(str(url), safe=":/%#?=@[]!$&'()*+,;")
|
||||
|
||||
|
||||
Content = typing.Union[str, bytes, memoryview]
|
||||
SyncContentStream = typing.Iterable[Content]
|
||||
AsyncContentStream = typing.AsyncIterable[Content]
|
||||
ContentStream = typing.Union[AsyncContentStream, SyncContentStream]
|
||||
|
||||
|
||||
class StreamingResponse(Response):
|
||||
body_iterator: AsyncContentStream
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content: ContentStream,
|
||||
status_code: int = 200,
|
||||
headers: typing.Mapping[str, str] | None = None,
|
||||
media_type: str | None = None,
|
||||
background: BackgroundTask | None = None,
|
||||
) -> None:
|
||||
if isinstance(content, typing.AsyncIterable):
|
||||
self.body_iterator = content
|
||||
else:
|
||||
self.body_iterator = iterate_in_threadpool(content)
|
||||
self.status_code = status_code
|
||||
self.media_type = self.media_type if media_type is None else media_type
|
||||
self.background = background
|
||||
self.init_headers(headers)
|
||||
|
||||
async def listen_for_disconnect(self, receive: Receive) -> None:
|
||||
while True:
|
||||
message = await receive()
|
||||
if message["type"] == "http.disconnect":
|
||||
break
|
||||
|
||||
async def stream_response(self, send: Send) -> None:
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": self.status_code,
|
||||
"headers": self.raw_headers,
|
||||
}
|
||||
)
|
||||
async for chunk in self.body_iterator:
|
||||
if not isinstance(chunk, (bytes, memoryview)):
|
||||
chunk = chunk.encode(self.charset)
|
||||
await send({"type": "http.response.body", "body": chunk, "more_body": True})
|
||||
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
spec_version = tuple(map(int, scope.get("asgi", {}).get("spec_version", "2.0").split(".")))
|
||||
|
||||
if spec_version >= (2, 4):
|
||||
try:
|
||||
await self.stream_response(send)
|
||||
except OSError:
|
||||
raise ClientDisconnect()
|
||||
else:
|
||||
with collapse_excgroups():
|
||||
async with anyio.create_task_group() as task_group:
|
||||
|
||||
async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
|
||||
await func()
|
||||
task_group.cancel_scope.cancel()
|
||||
|
||||
task_group.start_soon(wrap, partial(self.stream_response, send))
|
||||
await wrap(partial(self.listen_for_disconnect, receive))
|
||||
|
||||
if self.background is not None:
|
||||
await self.background()
|
||||
|
||||
|
||||
class MalformedRangeHeader(Exception):
|
||||
def __init__(self, content: str = "Malformed range header.") -> None:
|
||||
self.content = content
|
||||
|
||||
|
||||
class RangeNotSatisfiable(Exception):
|
||||
def __init__(self, max_size: int) -> None:
|
||||
self.max_size = max_size
|
||||
|
||||
|
||||
_RANGE_PATTERN = re.compile(r"(\d*)-(\d*)")
|
||||
|
||||
|
||||
class FileResponse(Response):
|
||||
chunk_size = 64 * 1024
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str | os.PathLike[str],
|
||||
status_code: int = 200,
|
||||
headers: typing.Mapping[str, str] | None = None,
|
||||
media_type: str | None = None,
|
||||
background: BackgroundTask | None = None,
|
||||
filename: str | None = None,
|
||||
stat_result: os.stat_result | None = None,
|
||||
method: str | None = None,
|
||||
content_disposition_type: str = "attachment",
|
||||
) -> None:
|
||||
self.path = path
|
||||
self.status_code = status_code
|
||||
self.filename = filename
|
||||
if method is not None:
|
||||
warnings.warn(
|
||||
"The 'method' parameter is not used, and it will be removed.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
if media_type is None:
|
||||
media_type = guess_type(filename or path)[0] or "text/plain"
|
||||
self.media_type = media_type
|
||||
self.background = background
|
||||
self.init_headers(headers)
|
||||
self.headers.setdefault("accept-ranges", "bytes")
|
||||
if self.filename is not None:
|
||||
content_disposition_filename = quote(self.filename)
|
||||
if content_disposition_filename != self.filename:
|
||||
content_disposition = f"{content_disposition_type}; filename*=utf-8''{content_disposition_filename}"
|
||||
else:
|
||||
content_disposition = f'{content_disposition_type}; filename="{self.filename}"'
|
||||
self.headers.setdefault("content-disposition", content_disposition)
|
||||
self.stat_result = stat_result
|
||||
if stat_result is not None:
|
||||
self.set_stat_headers(stat_result)
|
||||
|
||||
def set_stat_headers(self, stat_result: os.stat_result) -> None:
|
||||
content_length = str(stat_result.st_size)
|
||||
last_modified = formatdate(stat_result.st_mtime, usegmt=True)
|
||||
etag_base = str(stat_result.st_mtime) + "-" + str(stat_result.st_size)
|
||||
etag = f'"{hashlib.md5(etag_base.encode(), usedforsecurity=False).hexdigest()}"'
|
||||
|
||||
self.headers.setdefault("content-length", content_length)
|
||||
self.headers.setdefault("last-modified", last_modified)
|
||||
self.headers.setdefault("etag", etag)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
send_header_only: bool = scope["method"].upper() == "HEAD"
|
||||
if self.stat_result is None:
|
||||
try:
|
||||
stat_result = await anyio.to_thread.run_sync(os.stat, self.path)
|
||||
self.set_stat_headers(stat_result)
|
||||
except FileNotFoundError:
|
||||
raise RuntimeError(f"File at path {self.path} does not exist.")
|
||||
else:
|
||||
mode = stat_result.st_mode
|
||||
if not stat.S_ISREG(mode):
|
||||
raise RuntimeError(f"File at path {self.path} is not a file.")
|
||||
else:
|
||||
stat_result = self.stat_result
|
||||
|
||||
headers = Headers(scope=scope)
|
||||
http_range = headers.get("range")
|
||||
http_if_range = headers.get("if-range")
|
||||
|
||||
if http_range is None or (http_if_range is not None and not self._should_use_range(http_if_range)):
|
||||
await self._handle_simple(send, send_header_only)
|
||||
else:
|
||||
try:
|
||||
ranges = self._parse_range_header(http_range, stat_result.st_size)
|
||||
except MalformedRangeHeader as exc:
|
||||
return await PlainTextResponse(exc.content, status_code=400)(scope, receive, send)
|
||||
except RangeNotSatisfiable as exc:
|
||||
response = PlainTextResponse(status_code=416, headers={"Content-Range": f"*/{exc.max_size}"})
|
||||
return await response(scope, receive, send)
|
||||
|
||||
if len(ranges) == 1:
|
||||
start, end = ranges[0]
|
||||
await self._handle_single_range(send, start, end, stat_result.st_size, send_header_only)
|
||||
else:
|
||||
await self._handle_multiple_ranges(send, ranges, stat_result.st_size, send_header_only)
|
||||
|
||||
if self.background is not None:
|
||||
await self.background()
|
||||
|
||||
async def _handle_simple(self, send: Send, send_header_only: bool) -> None:
|
||||
await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers})
|
||||
if send_header_only:
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
else:
|
||||
async with await anyio.open_file(self.path, mode="rb") as file:
|
||||
more_body = True
|
||||
while more_body:
|
||||
chunk = await file.read(self.chunk_size)
|
||||
more_body = len(chunk) == self.chunk_size
|
||||
await send({"type": "http.response.body", "body": chunk, "more_body": more_body})
|
||||
|
||||
async def _handle_single_range(
|
||||
self, send: Send, start: int, end: int, file_size: int, send_header_only: bool
|
||||
) -> None:
|
||||
self.headers["content-range"] = f"bytes {start}-{end - 1}/{file_size}"
|
||||
self.headers["content-length"] = str(end - start)
|
||||
await send({"type": "http.response.start", "status": 206, "headers": self.raw_headers})
|
||||
if send_header_only:
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
else:
|
||||
async with await anyio.open_file(self.path, mode="rb") as file:
|
||||
await file.seek(start)
|
||||
more_body = True
|
||||
while more_body:
|
||||
chunk = await file.read(min(self.chunk_size, end - start))
|
||||
start += len(chunk)
|
||||
more_body = len(chunk) == self.chunk_size and start < end
|
||||
await send({"type": "http.response.body", "body": chunk, "more_body": more_body})
|
||||
|
||||
async def _handle_multiple_ranges(
|
||||
self,
|
||||
send: Send,
|
||||
ranges: list[tuple[int, int]],
|
||||
file_size: int,
|
||||
send_header_only: bool,
|
||||
) -> None:
|
||||
# In firefox and chrome, they use boundary with 95-96 bits entropy (that's roughly 13 bytes).
|
||||
boundary = token_hex(13)
|
||||
content_length, header_generator = self.generate_multipart(
|
||||
ranges, boundary, file_size, self.headers["content-type"]
|
||||
)
|
||||
self.headers["content-range"] = f"multipart/byteranges; boundary={boundary}"
|
||||
self.headers["content-length"] = str(content_length)
|
||||
await send({"type": "http.response.start", "status": 206, "headers": self.raw_headers})
|
||||
if send_header_only:
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
else:
|
||||
async with await anyio.open_file(self.path, mode="rb") as file:
|
||||
for start, end in ranges:
|
||||
await send({"type": "http.response.body", "body": header_generator(start, end), "more_body": True})
|
||||
await file.seek(start)
|
||||
while start < end:
|
||||
chunk = await file.read(min(self.chunk_size, end - start))
|
||||
start += len(chunk)
|
||||
await send({"type": "http.response.body", "body": chunk, "more_body": True})
|
||||
await send({"type": "http.response.body", "body": b"\n", "more_body": True})
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.body",
|
||||
"body": f"\n--{boundary}--\n".encode("latin-1"),
|
||||
"more_body": False,
|
||||
}
|
||||
)
|
||||
|
||||
def _should_use_range(self, http_if_range: str) -> bool:
|
||||
return http_if_range == self.headers["last-modified"] or http_if_range == self.headers["etag"]
|
||||
|
||||
@staticmethod
|
||||
def _parse_range_header(http_range: str, file_size: int) -> list[tuple[int, int]]:
|
||||
ranges: list[tuple[int, int]] = []
|
||||
try:
|
||||
units, range_ = http_range.split("=", 1)
|
||||
except ValueError:
|
||||
raise MalformedRangeHeader()
|
||||
|
||||
units = units.strip().lower()
|
||||
|
||||
if units != "bytes":
|
||||
raise MalformedRangeHeader("Only support bytes range")
|
||||
|
||||
ranges = [
|
||||
(
|
||||
int(_[0]) if _[0] else file_size - int(_[1]),
|
||||
int(_[1]) + 1 if _[0] and _[1] and int(_[1]) < file_size else file_size,
|
||||
)
|
||||
for _ in _RANGE_PATTERN.findall(range_)
|
||||
if _ != ("", "")
|
||||
]
|
||||
|
||||
if len(ranges) == 0:
|
||||
raise MalformedRangeHeader("Range header: range must be requested")
|
||||
|
||||
if any(not (0 <= start < file_size) for start, _ in ranges):
|
||||
raise RangeNotSatisfiable(file_size)
|
||||
|
||||
if any(start > end for start, end in ranges):
|
||||
raise MalformedRangeHeader("Range header: start must be less than end")
|
||||
|
||||
if len(ranges) == 1:
|
||||
return ranges
|
||||
|
||||
# Merge ranges
|
||||
result: list[tuple[int, int]] = []
|
||||
for start, end in ranges:
|
||||
for p in range(len(result)):
|
||||
p_start, p_end = result[p]
|
||||
if start > p_end:
|
||||
continue
|
||||
elif end < p_start:
|
||||
result.insert(p, (start, end)) # THIS IS NOT REACHED!
|
||||
break
|
||||
else:
|
||||
result[p] = (min(start, p_start), max(end, p_end))
|
||||
break
|
||||
else:
|
||||
result.append((start, end))
|
||||
|
||||
return result
|
||||
|
||||
def generate_multipart(
|
||||
self,
|
||||
ranges: typing.Sequence[tuple[int, int]],
|
||||
boundary: str,
|
||||
max_size: int,
|
||||
content_type: str,
|
||||
) -> tuple[int, typing.Callable[[int, int], bytes]]:
|
||||
r"""
|
||||
Multipart response headers generator.
|
||||
|
||||
```
|
||||
--{boundary}\n
|
||||
Content-Type: {content_type}\n
|
||||
Content-Range: bytes {start}-{end-1}/{max_size}\n
|
||||
\n
|
||||
..........content...........\n
|
||||
--{boundary}\n
|
||||
Content-Type: {content_type}\n
|
||||
Content-Range: bytes {start}-{end-1}/{max_size}\n
|
||||
\n
|
||||
..........content...........\n
|
||||
--{boundary}--\n
|
||||
```
|
||||
"""
|
||||
boundary_len = len(boundary)
|
||||
static_header_part_len = 44 + boundary_len + len(content_type) + len(str(max_size))
|
||||
content_length = sum(
|
||||
(len(str(start)) + len(str(end - 1)) + static_header_part_len) # Headers
|
||||
+ (end - start) # Content
|
||||
for start, end in ranges
|
||||
) + (
|
||||
5 + boundary_len # --boundary--\n
|
||||
)
|
||||
return (
|
||||
content_length,
|
||||
lambda start, end: (
|
||||
f"--{boundary}\nContent-Type: {content_type}\nContent-Range: bytes {start}-{end - 1}/{max_size}\n\n"
|
||||
).encode("latin-1"),
|
||||
)
|
874
venv/lib/python3.11/site-packages/starlette/routing.py
Normal file
874
venv/lib/python3.11/site-packages/starlette/routing.py
Normal file
@ -0,0 +1,874 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
import re
|
||||
import traceback
|
||||
import types
|
||||
import typing
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
from enum import Enum
|
||||
|
||||
from starlette._exception_handler import wrap_app_handling_exceptions
|
||||
from starlette._utils import get_route_path, is_async_callable
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.convertors import CONVERTOR_TYPES, Convertor
|
||||
from starlette.datastructures import URL, Headers, URLPath
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import PlainTextResponse, RedirectResponse, Response
|
||||
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
|
||||
from starlette.websockets import WebSocket, WebSocketClose
|
||||
|
||||
|
||||
class NoMatchFound(Exception):
|
||||
"""
|
||||
Raised by `.url_for(name, **path_params)` and `.url_path_for(name, **path_params)`
|
||||
if no matching route exists.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, path_params: dict[str, typing.Any]) -> None:
|
||||
params = ", ".join(list(path_params.keys()))
|
||||
super().__init__(f'No route exists for name "{name}" and params "{params}".')
|
||||
|
||||
|
||||
class Match(Enum):
|
||||
NONE = 0
|
||||
PARTIAL = 1
|
||||
FULL = 2
|
||||
|
||||
|
||||
def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover
|
||||
"""
|
||||
Correctly determines if an object is a coroutine function,
|
||||
including those wrapped in functools.partial objects.
|
||||
"""
|
||||
warnings.warn(
|
||||
"iscoroutinefunction_or_partial is deprecated, and will be removed in a future release.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
while isinstance(obj, functools.partial):
|
||||
obj = obj.func
|
||||
return inspect.iscoroutinefunction(obj)
|
||||
|
||||
|
||||
def request_response(
|
||||
func: typing.Callable[[Request], typing.Awaitable[Response] | Response],
|
||||
) -> ASGIApp:
|
||||
"""
|
||||
Takes a function or coroutine `func(request) -> response`,
|
||||
and returns an ASGI application.
|
||||
"""
|
||||
f: typing.Callable[[Request], typing.Awaitable[Response]] = (
|
||||
func if is_async_callable(func) else functools.partial(run_in_threadpool, func) # type:ignore
|
||||
)
|
||||
|
||||
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
request = Request(scope, receive, send)
|
||||
|
||||
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
response = await f(request)
|
||||
await response(scope, receive, send)
|
||||
|
||||
await wrap_app_handling_exceptions(app, request)(scope, receive, send)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def websocket_session(
|
||||
func: typing.Callable[[WebSocket], typing.Awaitable[None]],
|
||||
) -> ASGIApp:
|
||||
"""
|
||||
Takes a coroutine `func(session)`, and returns an ASGI application.
|
||||
"""
|
||||
# assert asyncio.iscoroutinefunction(func), "WebSocket endpoints must be async"
|
||||
|
||||
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
session = WebSocket(scope, receive=receive, send=send)
|
||||
|
||||
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
await func(session)
|
||||
|
||||
await wrap_app_handling_exceptions(app, session)(scope, receive, send)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def get_name(endpoint: typing.Callable[..., typing.Any]) -> str:
|
||||
return getattr(endpoint, "__name__", endpoint.__class__.__name__)
|
||||
|
||||
|
||||
def replace_params(
|
||||
path: str,
|
||||
param_convertors: dict[str, Convertor[typing.Any]],
|
||||
path_params: dict[str, str],
|
||||
) -> tuple[str, dict[str, str]]:
|
||||
for key, value in list(path_params.items()):
|
||||
if "{" + key + "}" in path:
|
||||
convertor = param_convertors[key]
|
||||
value = convertor.to_string(value)
|
||||
path = path.replace("{" + key + "}", value)
|
||||
path_params.pop(key)
|
||||
return path, path_params
|
||||
|
||||
|
||||
# Match parameters in URL paths, eg. '{param}', and '{param:int}'
|
||||
PARAM_REGEX = re.compile("{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}")
|
||||
|
||||
|
||||
def compile_path(
|
||||
path: str,
|
||||
) -> tuple[typing.Pattern[str], str, dict[str, Convertor[typing.Any]]]:
|
||||
"""
|
||||
Given a path string, like: "/{username:str}",
|
||||
or a host string, like: "{subdomain}.mydomain.org", return a three-tuple
|
||||
of (regex, format, {param_name:convertor}).
|
||||
|
||||
regex: "/(?P<username>[^/]+)"
|
||||
format: "/{username}"
|
||||
convertors: {"username": StringConvertor()}
|
||||
"""
|
||||
is_host = not path.startswith("/")
|
||||
|
||||
path_regex = "^"
|
||||
path_format = ""
|
||||
duplicated_params = set()
|
||||
|
||||
idx = 0
|
||||
param_convertors = {}
|
||||
for match in PARAM_REGEX.finditer(path):
|
||||
param_name, convertor_type = match.groups("str")
|
||||
convertor_type = convertor_type.lstrip(":")
|
||||
assert convertor_type in CONVERTOR_TYPES, f"Unknown path convertor '{convertor_type}'"
|
||||
convertor = CONVERTOR_TYPES[convertor_type]
|
||||
|
||||
path_regex += re.escape(path[idx : match.start()])
|
||||
path_regex += f"(?P<{param_name}>{convertor.regex})"
|
||||
|
||||
path_format += path[idx : match.start()]
|
||||
path_format += "{%s}" % param_name
|
||||
|
||||
if param_name in param_convertors:
|
||||
duplicated_params.add(param_name)
|
||||
|
||||
param_convertors[param_name] = convertor
|
||||
|
||||
idx = match.end()
|
||||
|
||||
if duplicated_params:
|
||||
names = ", ".join(sorted(duplicated_params))
|
||||
ending = "s" if len(duplicated_params) > 1 else ""
|
||||
raise ValueError(f"Duplicated param name{ending} {names} at path {path}")
|
||||
|
||||
if is_host:
|
||||
# Align with `Host.matches()` behavior, which ignores port.
|
||||
hostname = path[idx:].split(":")[0]
|
||||
path_regex += re.escape(hostname) + "$"
|
||||
else:
|
||||
path_regex += re.escape(path[idx:]) + "$"
|
||||
|
||||
path_format += path[idx:]
|
||||
|
||||
return re.compile(path_regex), path_format, param_convertors
|
||||
|
||||
|
||||
class BaseRoute:
|
||||
def matches(self, scope: Scope) -> tuple[Match, Scope]:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
"""
|
||||
A route may be used in isolation as a stand-alone ASGI app.
|
||||
This is a somewhat contrived case, as they'll almost always be used
|
||||
within a Router, but could be useful for some tooling and minimal apps.
|
||||
"""
|
||||
match, child_scope = self.matches(scope)
|
||||
if match == Match.NONE:
|
||||
if scope["type"] == "http":
|
||||
response = PlainTextResponse("Not Found", status_code=404)
|
||||
await response(scope, receive, send)
|
||||
elif scope["type"] == "websocket": # pragma: no branch
|
||||
websocket_close = WebSocketClose()
|
||||
await websocket_close(scope, receive, send)
|
||||
return
|
||||
|
||||
scope.update(child_scope)
|
||||
await self.handle(scope, receive, send)
|
||||
|
||||
|
||||
class Route(BaseRoute):
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
endpoint: typing.Callable[..., typing.Any],
|
||||
*,
|
||||
methods: list[str] | None = None,
|
||||
name: str | None = None,
|
||||
include_in_schema: bool = True,
|
||||
middleware: typing.Sequence[Middleware] | None = None,
|
||||
) -> None:
|
||||
assert path.startswith("/"), "Routed paths must start with '/'"
|
||||
self.path = path
|
||||
self.endpoint = endpoint
|
||||
self.name = get_name(endpoint) if name is None else name
|
||||
self.include_in_schema = include_in_schema
|
||||
|
||||
endpoint_handler = endpoint
|
||||
while isinstance(endpoint_handler, functools.partial):
|
||||
endpoint_handler = endpoint_handler.func
|
||||
if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler):
|
||||
# Endpoint is function or method. Treat it as `func(request) -> response`.
|
||||
self.app = request_response(endpoint)
|
||||
if methods is None:
|
||||
methods = ["GET"]
|
||||
else:
|
||||
# Endpoint is a class. Treat it as ASGI.
|
||||
self.app = endpoint
|
||||
|
||||
if middleware is not None:
|
||||
for cls, args, kwargs in reversed(middleware):
|
||||
self.app = cls(self.app, *args, **kwargs)
|
||||
|
||||
if methods is None:
|
||||
self.methods = None
|
||||
else:
|
||||
self.methods = {method.upper() for method in methods}
|
||||
if "GET" in self.methods:
|
||||
self.methods.add("HEAD")
|
||||
|
||||
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
||||
|
||||
def matches(self, scope: Scope) -> tuple[Match, Scope]:
|
||||
path_params: dict[str, typing.Any]
|
||||
if scope["type"] == "http":
|
||||
route_path = get_route_path(scope)
|
||||
match = self.path_regex.match(route_path)
|
||||
if match:
|
||||
matched_params = match.groupdict()
|
||||
for key, value in matched_params.items():
|
||||
matched_params[key] = self.param_convertors[key].convert(value)
|
||||
path_params = dict(scope.get("path_params", {}))
|
||||
path_params.update(matched_params)
|
||||
child_scope = {"endpoint": self.endpoint, "path_params": path_params}
|
||||
if self.methods and scope["method"] not in self.methods:
|
||||
return Match.PARTIAL, child_scope
|
||||
else:
|
||||
return Match.FULL, child_scope
|
||||
return Match.NONE, {}
|
||||
|
||||
def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
|
||||
seen_params = set(path_params.keys())
|
||||
expected_params = set(self.param_convertors.keys())
|
||||
|
||||
if name != self.name or seen_params != expected_params:
|
||||
raise NoMatchFound(name, path_params)
|
||||
|
||||
path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
|
||||
assert not remaining_params
|
||||
return URLPath(path=path, protocol="http")
|
||||
|
||||
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if self.methods and scope["method"] not in self.methods:
|
||||
headers = {"Allow": ", ".join(self.methods)}
|
||||
if "app" in scope:
|
||||
raise HTTPException(status_code=405, headers=headers)
|
||||
else:
|
||||
response = PlainTextResponse("Method Not Allowed", status_code=405, headers=headers)
|
||||
await response(scope, receive, send)
|
||||
else:
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return (
|
||||
isinstance(other, Route)
|
||||
and self.path == other.path
|
||||
and self.endpoint == other.endpoint
|
||||
and self.methods == other.methods
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
methods = sorted(self.methods or [])
|
||||
path, name = self.path, self.name
|
||||
return f"{class_name}(path={path!r}, name={name!r}, methods={methods!r})"
|
||||
|
||||
|
||||
class WebSocketRoute(BaseRoute):
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
endpoint: typing.Callable[..., typing.Any],
|
||||
*,
|
||||
name: str | None = None,
|
||||
middleware: typing.Sequence[Middleware] | None = None,
|
||||
) -> None:
|
||||
assert path.startswith("/"), "Routed paths must start with '/'"
|
||||
self.path = path
|
||||
self.endpoint = endpoint
|
||||
self.name = get_name(endpoint) if name is None else name
|
||||
|
||||
endpoint_handler = endpoint
|
||||
while isinstance(endpoint_handler, functools.partial):
|
||||
endpoint_handler = endpoint_handler.func
|
||||
if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler):
|
||||
# Endpoint is function or method. Treat it as `func(websocket)`.
|
||||
self.app = websocket_session(endpoint)
|
||||
else:
|
||||
# Endpoint is a class. Treat it as ASGI.
|
||||
self.app = endpoint
|
||||
|
||||
if middleware is not None:
|
||||
for cls, args, kwargs in reversed(middleware):
|
||||
self.app = cls(self.app, *args, **kwargs)
|
||||
|
||||
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
||||
|
||||
def matches(self, scope: Scope) -> tuple[Match, Scope]:
|
||||
path_params: dict[str, typing.Any]
|
||||
if scope["type"] == "websocket":
|
||||
route_path = get_route_path(scope)
|
||||
match = self.path_regex.match(route_path)
|
||||
if match:
|
||||
matched_params = match.groupdict()
|
||||
for key, value in matched_params.items():
|
||||
matched_params[key] = self.param_convertors[key].convert(value)
|
||||
path_params = dict(scope.get("path_params", {}))
|
||||
path_params.update(matched_params)
|
||||
child_scope = {"endpoint": self.endpoint, "path_params": path_params}
|
||||
return Match.FULL, child_scope
|
||||
return Match.NONE, {}
|
||||
|
||||
def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
|
||||
seen_params = set(path_params.keys())
|
||||
expected_params = set(self.param_convertors.keys())
|
||||
|
||||
if name != self.name or seen_params != expected_params:
|
||||
raise NoMatchFound(name, path_params)
|
||||
|
||||
path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
|
||||
assert not remaining_params
|
||||
return URLPath(path=path, protocol="websocket")
|
||||
|
||||
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return isinstance(other, WebSocketRoute) and self.path == other.path and self.endpoint == other.endpoint
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(path={self.path!r}, name={self.name!r})"
|
||||
|
||||
|
||||
class Mount(BaseRoute):
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
app: ASGIApp | None = None,
|
||||
routes: typing.Sequence[BaseRoute] | None = None,
|
||||
name: str | None = None,
|
||||
*,
|
||||
middleware: typing.Sequence[Middleware] | None = None,
|
||||
) -> None:
|
||||
assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
|
||||
assert app is not None or routes is not None, "Either 'app=...', or 'routes=' must be specified"
|
||||
self.path = path.rstrip("/")
|
||||
if app is not None:
|
||||
self._base_app: ASGIApp = app
|
||||
else:
|
||||
self._base_app = Router(routes=routes)
|
||||
self.app = self._base_app
|
||||
if middleware is not None:
|
||||
for cls, args, kwargs in reversed(middleware):
|
||||
self.app = cls(self.app, *args, **kwargs)
|
||||
self.name = name
|
||||
self.path_regex, self.path_format, self.param_convertors = compile_path(self.path + "/{path:path}")
|
||||
|
||||
@property
|
||||
def routes(self) -> list[BaseRoute]:
|
||||
return getattr(self._base_app, "routes", [])
|
||||
|
||||
def matches(self, scope: Scope) -> tuple[Match, Scope]:
|
||||
path_params: dict[str, typing.Any]
|
||||
if scope["type"] in ("http", "websocket"): # pragma: no branch
|
||||
root_path = scope.get("root_path", "")
|
||||
route_path = get_route_path(scope)
|
||||
match = self.path_regex.match(route_path)
|
||||
if match:
|
||||
matched_params = match.groupdict()
|
||||
for key, value in matched_params.items():
|
||||
matched_params[key] = self.param_convertors[key].convert(value)
|
||||
remaining_path = "/" + matched_params.pop("path")
|
||||
matched_path = route_path[: -len(remaining_path)]
|
||||
path_params = dict(scope.get("path_params", {}))
|
||||
path_params.update(matched_params)
|
||||
child_scope = {
|
||||
"path_params": path_params,
|
||||
# app_root_path will only be set at the top level scope,
|
||||
# initialized with the (optional) value of a root_path
|
||||
# set above/before Starlette. And even though any
|
||||
# mount will have its own child scope with its own respective
|
||||
# root_path, the app_root_path will always be available in all
|
||||
# the child scopes with the same top level value because it's
|
||||
# set only once here with a default, any other child scope will
|
||||
# just inherit that app_root_path default value stored in the
|
||||
# scope. All this is needed to support Request.url_for(), as it
|
||||
# uses the app_root_path to build the URL path.
|
||||
"app_root_path": scope.get("app_root_path", root_path),
|
||||
"root_path": root_path + matched_path,
|
||||
"endpoint": self.app,
|
||||
}
|
||||
return Match.FULL, child_scope
|
||||
return Match.NONE, {}
|
||||
|
||||
def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
|
||||
if self.name is not None and name == self.name and "path" in path_params:
|
||||
# 'name' matches "<mount_name>".
|
||||
path_params["path"] = path_params["path"].lstrip("/")
|
||||
path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
|
||||
if not remaining_params:
|
||||
return URLPath(path=path)
|
||||
elif self.name is None or name.startswith(self.name + ":"):
|
||||
if self.name is None:
|
||||
# No mount name.
|
||||
remaining_name = name
|
||||
else:
|
||||
# 'name' matches "<mount_name>:<child_name>".
|
||||
remaining_name = name[len(self.name) + 1 :]
|
||||
path_kwarg = path_params.get("path")
|
||||
path_params["path"] = ""
|
||||
path_prefix, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
|
||||
if path_kwarg is not None:
|
||||
remaining_params["path"] = path_kwarg
|
||||
for route in self.routes or []:
|
||||
try:
|
||||
url = route.url_path_for(remaining_name, **remaining_params)
|
||||
return URLPath(path=path_prefix.rstrip("/") + str(url), protocol=url.protocol)
|
||||
except NoMatchFound:
|
||||
pass
|
||||
raise NoMatchFound(name, path_params)
|
||||
|
||||
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return isinstance(other, Mount) and self.path == other.path and self.app == other.app
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
name = self.name or ""
|
||||
return f"{class_name}(path={self.path!r}, name={name!r}, app={self.app!r})"
|
||||
|
||||
|
||||
class Host(BaseRoute):
|
||||
def __init__(self, host: str, app: ASGIApp, name: str | None = None) -> None:
|
||||
assert not host.startswith("/"), "Host must not start with '/'"
|
||||
self.host = host
|
||||
self.app = app
|
||||
self.name = name
|
||||
self.host_regex, self.host_format, self.param_convertors = compile_path(host)
|
||||
|
||||
@property
|
||||
def routes(self) -> list[BaseRoute]:
|
||||
return getattr(self.app, "routes", [])
|
||||
|
||||
def matches(self, scope: Scope) -> tuple[Match, Scope]:
|
||||
if scope["type"] in ("http", "websocket"): # pragma:no branch
|
||||
headers = Headers(scope=scope)
|
||||
host = headers.get("host", "").split(":")[0]
|
||||
match = self.host_regex.match(host)
|
||||
if match:
|
||||
matched_params = match.groupdict()
|
||||
for key, value in matched_params.items():
|
||||
matched_params[key] = self.param_convertors[key].convert(value)
|
||||
path_params = dict(scope.get("path_params", {}))
|
||||
path_params.update(matched_params)
|
||||
child_scope = {"path_params": path_params, "endpoint": self.app}
|
||||
return Match.FULL, child_scope
|
||||
return Match.NONE, {}
|
||||
|
||||
def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
|
||||
if self.name is not None and name == self.name and "path" in path_params:
|
||||
# 'name' matches "<mount_name>".
|
||||
path = path_params.pop("path")
|
||||
host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params)
|
||||
if not remaining_params:
|
||||
return URLPath(path=path, host=host)
|
||||
elif self.name is None or name.startswith(self.name + ":"):
|
||||
if self.name is None:
|
||||
# No mount name.
|
||||
remaining_name = name
|
||||
else:
|
||||
# 'name' matches "<mount_name>:<child_name>".
|
||||
remaining_name = name[len(self.name) + 1 :]
|
||||
host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params)
|
||||
for route in self.routes or []:
|
||||
try:
|
||||
url = route.url_path_for(remaining_name, **remaining_params)
|
||||
return URLPath(path=str(url), protocol=url.protocol, host=host)
|
||||
except NoMatchFound:
|
||||
pass
|
||||
raise NoMatchFound(name, path_params)
|
||||
|
||||
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return isinstance(other, Host) and self.host == other.host and self.app == other.app
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
name = self.name or ""
|
||||
return f"{class_name}(host={self.host!r}, name={name!r}, app={self.app!r})"
|
||||
|
||||
|
||||
_T = typing.TypeVar("_T")
|
||||
|
||||
|
||||
class _AsyncLiftContextManager(typing.AsyncContextManager[_T]):
|
||||
def __init__(self, cm: typing.ContextManager[_T]):
|
||||
self._cm = cm
|
||||
|
||||
async def __aenter__(self) -> _T:
|
||||
return self._cm.__enter__()
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: types.TracebackType | None,
|
||||
) -> bool | None:
|
||||
return self._cm.__exit__(exc_type, exc_value, traceback)
|
||||
|
||||
|
||||
def _wrap_gen_lifespan_context(
|
||||
lifespan_context: typing.Callable[[typing.Any], typing.Generator[typing.Any, typing.Any, typing.Any]],
|
||||
) -> typing.Callable[[typing.Any], typing.AsyncContextManager[typing.Any]]:
|
||||
cmgr = contextlib.contextmanager(lifespan_context)
|
||||
|
||||
@functools.wraps(cmgr)
|
||||
def wrapper(app: typing.Any) -> _AsyncLiftContextManager[typing.Any]:
|
||||
return _AsyncLiftContextManager(cmgr(app))
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class _DefaultLifespan:
|
||||
def __init__(self, router: Router):
|
||||
self._router = router
|
||||
|
||||
async def __aenter__(self) -> None:
|
||||
await self._router.startup()
|
||||
|
||||
async def __aexit__(self, *exc_info: object) -> None:
|
||||
await self._router.shutdown()
|
||||
|
||||
def __call__(self: _T, app: object) -> _T:
|
||||
return self
|
||||
|
||||
|
||||
class Router:
|
||||
def __init__(
|
||||
self,
|
||||
routes: typing.Sequence[BaseRoute] | None = None,
|
||||
redirect_slashes: bool = True,
|
||||
default: ASGIApp | None = None,
|
||||
on_startup: typing.Sequence[typing.Callable[[], typing.Any]] | None = None,
|
||||
on_shutdown: typing.Sequence[typing.Callable[[], typing.Any]] | None = None,
|
||||
# the generic to Lifespan[AppType] is the type of the top level application
|
||||
# which the router cannot know statically, so we use typing.Any
|
||||
lifespan: Lifespan[typing.Any] | None = None,
|
||||
*,
|
||||
middleware: typing.Sequence[Middleware] | None = None,
|
||||
) -> None:
|
||||
self.routes = [] if routes is None else list(routes)
|
||||
self.redirect_slashes = redirect_slashes
|
||||
self.default = self.not_found if default is None else default
|
||||
self.on_startup = [] if on_startup is None else list(on_startup)
|
||||
self.on_shutdown = [] if on_shutdown is None else list(on_shutdown)
|
||||
|
||||
if on_startup or on_shutdown:
|
||||
warnings.warn(
|
||||
"The on_startup and on_shutdown parameters are deprecated, and they "
|
||||
"will be removed on version 1.0. Use the lifespan parameter instead. "
|
||||
"See more about it on https://www.starlette.io/lifespan/.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
if lifespan:
|
||||
warnings.warn(
|
||||
"The `lifespan` parameter cannot be used with `on_startup` or "
|
||||
"`on_shutdown`. Both `on_startup` and `on_shutdown` will be "
|
||||
"ignored."
|
||||
)
|
||||
|
||||
if lifespan is None:
|
||||
self.lifespan_context: Lifespan[typing.Any] = _DefaultLifespan(self)
|
||||
|
||||
elif inspect.isasyncgenfunction(lifespan):
|
||||
warnings.warn(
|
||||
"async generator function lifespans are deprecated, "
|
||||
"use an @contextlib.asynccontextmanager function instead",
|
||||
DeprecationWarning,
|
||||
)
|
||||
self.lifespan_context = asynccontextmanager(
|
||||
lifespan,
|
||||
)
|
||||
elif inspect.isgeneratorfunction(lifespan):
|
||||
warnings.warn(
|
||||
"generator function lifespans are deprecated, use an @contextlib.asynccontextmanager function instead",
|
||||
DeprecationWarning,
|
||||
)
|
||||
self.lifespan_context = _wrap_gen_lifespan_context(
|
||||
lifespan,
|
||||
)
|
||||
else:
|
||||
self.lifespan_context = lifespan
|
||||
|
||||
self.middleware_stack = self.app
|
||||
if middleware:
|
||||
for cls, args, kwargs in reversed(middleware):
|
||||
self.middleware_stack = cls(self.middleware_stack, *args, **kwargs)
|
||||
|
||||
async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] == "websocket":
|
||||
websocket_close = WebSocketClose()
|
||||
await websocket_close(scope, receive, send)
|
||||
return
|
||||
|
||||
# If we're running inside a starlette application then raise an
|
||||
# exception, so that the configurable exception handler can deal with
|
||||
# returning the response. For plain ASGI apps, just return the response.
|
||||
if "app" in scope:
|
||||
raise HTTPException(status_code=404)
|
||||
else:
|
||||
response = PlainTextResponse("Not Found", status_code=404)
|
||||
await response(scope, receive, send)
|
||||
|
||||
def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
|
||||
for route in self.routes:
|
||||
try:
|
||||
return route.url_path_for(name, **path_params)
|
||||
except NoMatchFound:
|
||||
pass
|
||||
raise NoMatchFound(name, path_params)
|
||||
|
||||
async def startup(self) -> None:
|
||||
"""
|
||||
Run any `.on_startup` event handlers.
|
||||
"""
|
||||
for handler in self.on_startup:
|
||||
if is_async_callable(handler):
|
||||
await handler()
|
||||
else:
|
||||
handler()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""
|
||||
Run any `.on_shutdown` event handlers.
|
||||
"""
|
||||
for handler in self.on_shutdown:
|
||||
if is_async_callable(handler):
|
||||
await handler()
|
||||
else:
|
||||
handler()
|
||||
|
||||
async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
"""
|
||||
Handle ASGI lifespan messages, which allows us to manage application
|
||||
startup and shutdown events.
|
||||
"""
|
||||
started = False
|
||||
app: typing.Any = scope.get("app")
|
||||
await receive()
|
||||
try:
|
||||
async with self.lifespan_context(app) as maybe_state:
|
||||
if maybe_state is not None:
|
||||
if "state" not in scope:
|
||||
raise RuntimeError('The server does not support "state" in the lifespan scope.')
|
||||
scope["state"].update(maybe_state)
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
started = True
|
||||
await receive()
|
||||
except BaseException:
|
||||
exc_text = traceback.format_exc()
|
||||
if started:
|
||||
await send({"type": "lifespan.shutdown.failed", "message": exc_text})
|
||||
else:
|
||||
await send({"type": "lifespan.startup.failed", "message": exc_text})
|
||||
raise
|
||||
else:
|
||||
await send({"type": "lifespan.shutdown.complete"})
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
"""
|
||||
The main entry point to the Router class.
|
||||
"""
|
||||
await self.middleware_stack(scope, receive, send)
|
||||
|
||||
async def app(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
assert scope["type"] in ("http", "websocket", "lifespan")
|
||||
|
||||
if "router" not in scope:
|
||||
scope["router"] = self
|
||||
|
||||
if scope["type"] == "lifespan":
|
||||
await self.lifespan(scope, receive, send)
|
||||
return
|
||||
|
||||
partial = None
|
||||
|
||||
for route in self.routes:
|
||||
# Determine if any route matches the incoming scope,
|
||||
# and hand over to the matching route if found.
|
||||
match, child_scope = route.matches(scope)
|
||||
if match == Match.FULL:
|
||||
scope.update(child_scope)
|
||||
await route.handle(scope, receive, send)
|
||||
return
|
||||
elif match == Match.PARTIAL and partial is None:
|
||||
partial = route
|
||||
partial_scope = child_scope
|
||||
|
||||
if partial is not None:
|
||||
# Handle partial matches. These are cases where an endpoint is
|
||||
# able to handle the request, but is not a preferred option.
|
||||
# We use this in particular to deal with "405 Method Not Allowed".
|
||||
scope.update(partial_scope)
|
||||
await partial.handle(scope, receive, send)
|
||||
return
|
||||
|
||||
route_path = get_route_path(scope)
|
||||
if scope["type"] == "http" and self.redirect_slashes and route_path != "/":
|
||||
redirect_scope = dict(scope)
|
||||
if route_path.endswith("/"):
|
||||
redirect_scope["path"] = redirect_scope["path"].rstrip("/")
|
||||
else:
|
||||
redirect_scope["path"] = redirect_scope["path"] + "/"
|
||||
|
||||
for route in self.routes:
|
||||
match, child_scope = route.matches(redirect_scope)
|
||||
if match != Match.NONE:
|
||||
redirect_url = URL(scope=redirect_scope)
|
||||
response = RedirectResponse(url=str(redirect_url))
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
await self.default(scope, receive, send)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return isinstance(other, Router) and self.routes == other.routes
|
||||
|
||||
def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None: # pragma: no cover
|
||||
route = Mount(path, app=app, name=name)
|
||||
self.routes.append(route)
|
||||
|
||||
def host(self, host: str, app: ASGIApp, name: str | None = None) -> None: # pragma: no cover
|
||||
route = Host(host, app=app, name=name)
|
||||
self.routes.append(route)
|
||||
|
||||
def add_route(
|
||||
self,
|
||||
path: str,
|
||||
endpoint: typing.Callable[[Request], typing.Awaitable[Response] | Response],
|
||||
methods: list[str] | None = None,
|
||||
name: str | None = None,
|
||||
include_in_schema: bool = True,
|
||||
) -> None: # pragma: no cover
|
||||
route = Route(
|
||||
path,
|
||||
endpoint=endpoint,
|
||||
methods=methods,
|
||||
name=name,
|
||||
include_in_schema=include_in_schema,
|
||||
)
|
||||
self.routes.append(route)
|
||||
|
||||
def add_websocket_route(
|
||||
self,
|
||||
path: str,
|
||||
endpoint: typing.Callable[[WebSocket], typing.Awaitable[None]],
|
||||
name: str | None = None,
|
||||
) -> None: # pragma: no cover
|
||||
route = WebSocketRoute(path, endpoint=endpoint, name=name)
|
||||
self.routes.append(route)
|
||||
|
||||
def route(
|
||||
self,
|
||||
path: str,
|
||||
methods: list[str] | None = None,
|
||||
name: str | None = None,
|
||||
include_in_schema: bool = True,
|
||||
) -> typing.Callable: # type: ignore[type-arg]
|
||||
"""
|
||||
We no longer document this decorator style API, and its usage is discouraged.
|
||||
Instead you should use the following approach:
|
||||
|
||||
>>> routes = [Route(path, endpoint=...), ...]
|
||||
>>> app = Starlette(routes=routes)
|
||||
"""
|
||||
warnings.warn(
|
||||
"The `route` decorator is deprecated, and will be removed in version 1.0.0."
|
||||
"Refer to https://www.starlette.io/routing/#http-routing for the recommended approach.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
|
||||
self.add_route(
|
||||
path,
|
||||
func,
|
||||
methods=methods,
|
||||
name=name,
|
||||
include_in_schema=include_in_schema,
|
||||
)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def websocket_route(self, path: str, name: str | None = None) -> typing.Callable: # type: ignore[type-arg]
|
||||
"""
|
||||
We no longer document this decorator style API, and its usage is discouraged.
|
||||
Instead you should use the following approach:
|
||||
|
||||
>>> routes = [WebSocketRoute(path, endpoint=...), ...]
|
||||
>>> app = Starlette(routes=routes)
|
||||
"""
|
||||
warnings.warn(
|
||||
"The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. Refer to "
|
||||
"https://www.starlette.io/routing/#websocket-routing for the recommended approach.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
|
||||
self.add_websocket_route(path, func, name=name)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def add_event_handler(self, event_type: str, func: typing.Callable[[], typing.Any]) -> None: # pragma: no cover
|
||||
assert event_type in ("startup", "shutdown")
|
||||
|
||||
if event_type == "startup":
|
||||
self.on_startup.append(func)
|
||||
else:
|
||||
self.on_shutdown.append(func)
|
||||
|
||||
def on_event(self, event_type: str) -> typing.Callable: # type: ignore[type-arg]
|
||||
warnings.warn(
|
||||
"The `on_event` decorator is deprecated, and will be removed in version 1.0.0. "
|
||||
"Refer to https://www.starlette.io/lifespan/ for recommended approach.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
|
||||
self.add_event_handler(event_type, func)
|
||||
return func
|
||||
|
||||
return decorator
|
147
venv/lib/python3.11/site-packages/starlette/schemas.py
Normal file
147
venv/lib/python3.11/site-packages/starlette/schemas.py
Normal file
@ -0,0 +1,147 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import re
|
||||
import typing
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.routing import BaseRoute, Host, Mount, Route
|
||||
|
||||
try:
|
||||
import yaml
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
yaml = None # type: ignore[assignment]
|
||||
|
||||
|
||||
class OpenAPIResponse(Response):
|
||||
media_type = "application/vnd.oai.openapi"
|
||||
|
||||
def render(self, content: typing.Any) -> bytes:
|
||||
assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse."
|
||||
assert isinstance(content, dict), "The schema passed to OpenAPIResponse should be a dictionary."
|
||||
return yaml.dump(content, default_flow_style=False).encode("utf-8")
|
||||
|
||||
|
||||
class EndpointInfo(typing.NamedTuple):
|
||||
path: str
|
||||
http_method: str
|
||||
func: typing.Callable[..., typing.Any]
|
||||
|
||||
|
||||
_remove_converter_pattern = re.compile(r":\w+}")
|
||||
|
||||
|
||||
class BaseSchemaGenerator:
|
||||
def get_schema(self, routes: list[BaseRoute]) -> dict[str, typing.Any]:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]:
|
||||
"""
|
||||
Given the routes, yields the following information:
|
||||
|
||||
- path
|
||||
eg: /users/
|
||||
- http_method
|
||||
one of 'get', 'post', 'put', 'patch', 'delete', 'options'
|
||||
- func
|
||||
method ready to extract the docstring
|
||||
"""
|
||||
endpoints_info: list[EndpointInfo] = []
|
||||
|
||||
for route in routes:
|
||||
if isinstance(route, (Mount, Host)):
|
||||
routes = route.routes or []
|
||||
if isinstance(route, Mount):
|
||||
path = self._remove_converter(route.path)
|
||||
else:
|
||||
path = ""
|
||||
sub_endpoints = [
|
||||
EndpointInfo(
|
||||
path="".join((path, sub_endpoint.path)),
|
||||
http_method=sub_endpoint.http_method,
|
||||
func=sub_endpoint.func,
|
||||
)
|
||||
for sub_endpoint in self.get_endpoints(routes)
|
||||
]
|
||||
endpoints_info.extend(sub_endpoints)
|
||||
|
||||
elif not isinstance(route, Route) or not route.include_in_schema:
|
||||
continue
|
||||
|
||||
elif inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint):
|
||||
path = self._remove_converter(route.path)
|
||||
for method in route.methods or ["GET"]:
|
||||
if method == "HEAD":
|
||||
continue
|
||||
endpoints_info.append(EndpointInfo(path, method.lower(), route.endpoint))
|
||||
else:
|
||||
path = self._remove_converter(route.path)
|
||||
for method in ["get", "post", "put", "patch", "delete", "options"]:
|
||||
if not hasattr(route.endpoint, method):
|
||||
continue
|
||||
func = getattr(route.endpoint, method)
|
||||
endpoints_info.append(EndpointInfo(path, method.lower(), func))
|
||||
|
||||
return endpoints_info
|
||||
|
||||
def _remove_converter(self, path: str) -> str:
|
||||
"""
|
||||
Remove the converter from the path.
|
||||
For example, a route like this:
|
||||
Route("/users/{id:int}", endpoint=get_user, methods=["GET"])
|
||||
Should be represented as `/users/{id}` in the OpenAPI schema.
|
||||
"""
|
||||
return _remove_converter_pattern.sub("}", path)
|
||||
|
||||
def parse_docstring(self, func_or_method: typing.Callable[..., typing.Any]) -> dict[str, typing.Any]:
|
||||
"""
|
||||
Given a function, parse the docstring as YAML and return a dictionary of info.
|
||||
"""
|
||||
docstring = func_or_method.__doc__
|
||||
if not docstring:
|
||||
return {}
|
||||
|
||||
assert yaml is not None, "`pyyaml` must be installed to use parse_docstring."
|
||||
|
||||
# We support having regular docstrings before the schema
|
||||
# definition. Here we return just the schema part from
|
||||
# the docstring.
|
||||
docstring = docstring.split("---")[-1]
|
||||
|
||||
parsed = yaml.safe_load(docstring)
|
||||
|
||||
if not isinstance(parsed, dict):
|
||||
# A regular docstring (not yaml formatted) can return
|
||||
# a simple string here, which wouldn't follow the schema.
|
||||
return {}
|
||||
|
||||
return parsed
|
||||
|
||||
def OpenAPIResponse(self, request: Request) -> Response:
|
||||
routes = request.app.routes
|
||||
schema = self.get_schema(routes=routes)
|
||||
return OpenAPIResponse(schema)
|
||||
|
||||
|
||||
class SchemaGenerator(BaseSchemaGenerator):
|
||||
def __init__(self, base_schema: dict[str, typing.Any]) -> None:
|
||||
self.base_schema = base_schema
|
||||
|
||||
def get_schema(self, routes: list[BaseRoute]) -> dict[str, typing.Any]:
|
||||
schema = dict(self.base_schema)
|
||||
schema.setdefault("paths", {})
|
||||
endpoints_info = self.get_endpoints(routes)
|
||||
|
||||
for endpoint in endpoints_info:
|
||||
parsed = self.parse_docstring(endpoint.func)
|
||||
|
||||
if not parsed:
|
||||
continue
|
||||
|
||||
if endpoint.path not in schema["paths"]:
|
||||
schema["paths"][endpoint.path] = {}
|
||||
|
||||
schema["paths"][endpoint.path][endpoint.http_method] = parsed
|
||||
|
||||
return schema
|
220
venv/lib/python3.11/site-packages/starlette/staticfiles.py
Normal file
220
venv/lib/python3.11/site-packages/starlette/staticfiles.py
Normal file
@ -0,0 +1,220 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import errno
|
||||
import importlib.util
|
||||
import os
|
||||
import stat
|
||||
import typing
|
||||
from email.utils import parsedate
|
||||
|
||||
import anyio
|
||||
import anyio.to_thread
|
||||
|
||||
from starlette._utils import get_route_path
|
||||
from starlette.datastructures import URL, Headers
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.responses import FileResponse, RedirectResponse, Response
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
PathLike = typing.Union[str, "os.PathLike[str]"]
|
||||
|
||||
|
||||
class NotModifiedResponse(Response):
|
||||
NOT_MODIFIED_HEADERS = (
|
||||
"cache-control",
|
||||
"content-location",
|
||||
"date",
|
||||
"etag",
|
||||
"expires",
|
||||
"vary",
|
||||
)
|
||||
|
||||
def __init__(self, headers: Headers):
|
||||
super().__init__(
|
||||
status_code=304,
|
||||
headers={name: value for name, value in headers.items() if name in self.NOT_MODIFIED_HEADERS},
|
||||
)
|
||||
|
||||
|
||||
class StaticFiles:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
directory: PathLike | None = None,
|
||||
packages: list[str | tuple[str, str]] | None = None,
|
||||
html: bool = False,
|
||||
check_dir: bool = True,
|
||||
follow_symlink: bool = False,
|
||||
) -> None:
|
||||
self.directory = directory
|
||||
self.packages = packages
|
||||
self.all_directories = self.get_directories(directory, packages)
|
||||
self.html = html
|
||||
self.config_checked = False
|
||||
self.follow_symlink = follow_symlink
|
||||
if check_dir and directory is not None and not os.path.isdir(directory):
|
||||
raise RuntimeError(f"Directory '{directory}' does not exist")
|
||||
|
||||
def get_directories(
|
||||
self,
|
||||
directory: PathLike | None = None,
|
||||
packages: list[str | tuple[str, str]] | None = None,
|
||||
) -> list[PathLike]:
|
||||
"""
|
||||
Given `directory` and `packages` arguments, return a list of all the
|
||||
directories that should be used for serving static files from.
|
||||
"""
|
||||
directories = []
|
||||
if directory is not None:
|
||||
directories.append(directory)
|
||||
|
||||
for package in packages or []:
|
||||
if isinstance(package, tuple):
|
||||
package, statics_dir = package
|
||||
else:
|
||||
statics_dir = "statics"
|
||||
spec = importlib.util.find_spec(package)
|
||||
assert spec is not None, f"Package {package!r} could not be found."
|
||||
assert spec.origin is not None, f"Package {package!r} could not be found."
|
||||
package_directory = os.path.normpath(os.path.join(spec.origin, "..", statics_dir))
|
||||
assert os.path.isdir(package_directory), (
|
||||
f"Directory '{statics_dir!r}' in package {package!r} could not be found."
|
||||
)
|
||||
directories.append(package_directory)
|
||||
|
||||
return directories
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
"""
|
||||
The ASGI entry point.
|
||||
"""
|
||||
assert scope["type"] == "http"
|
||||
|
||||
if not self.config_checked:
|
||||
await self.check_config()
|
||||
self.config_checked = True
|
||||
|
||||
path = self.get_path(scope)
|
||||
response = await self.get_response(path, scope)
|
||||
await response(scope, receive, send)
|
||||
|
||||
def get_path(self, scope: Scope) -> str:
|
||||
"""
|
||||
Given the ASGI scope, return the `path` string to serve up,
|
||||
with OS specific path separators, and any '..', '.' components removed.
|
||||
"""
|
||||
route_path = get_route_path(scope)
|
||||
return os.path.normpath(os.path.join(*route_path.split("/")))
|
||||
|
||||
async def get_response(self, path: str, scope: Scope) -> Response:
|
||||
"""
|
||||
Returns an HTTP response, given the incoming path, method and request headers.
|
||||
"""
|
||||
if scope["method"] not in ("GET", "HEAD"):
|
||||
raise HTTPException(status_code=405)
|
||||
|
||||
try:
|
||||
full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, path)
|
||||
except PermissionError:
|
||||
raise HTTPException(status_code=401)
|
||||
except OSError as exc:
|
||||
# Filename is too long, so it can't be a valid static file.
|
||||
if exc.errno == errno.ENAMETOOLONG:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
raise exc
|
||||
|
||||
if stat_result and stat.S_ISREG(stat_result.st_mode):
|
||||
# We have a static file to serve.
|
||||
return self.file_response(full_path, stat_result, scope)
|
||||
|
||||
elif stat_result and stat.S_ISDIR(stat_result.st_mode) and self.html:
|
||||
# We're in HTML mode, and have got a directory URL.
|
||||
# Check if we have 'index.html' file to serve.
|
||||
index_path = os.path.join(path, "index.html")
|
||||
full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, index_path)
|
||||
if stat_result is not None and stat.S_ISREG(stat_result.st_mode):
|
||||
if not scope["path"].endswith("/"):
|
||||
# Directory URLs should redirect to always end in "/".
|
||||
url = URL(scope=scope)
|
||||
url = url.replace(path=url.path + "/")
|
||||
return RedirectResponse(url=url)
|
||||
return self.file_response(full_path, stat_result, scope)
|
||||
|
||||
if self.html:
|
||||
# Check for '404.html' if we're in HTML mode.
|
||||
full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, "404.html")
|
||||
if stat_result and stat.S_ISREG(stat_result.st_mode):
|
||||
return FileResponse(full_path, stat_result=stat_result, status_code=404)
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
def lookup_path(self, path: str) -> tuple[str, os.stat_result | None]:
|
||||
for directory in self.all_directories:
|
||||
joined_path = os.path.join(directory, path)
|
||||
if self.follow_symlink:
|
||||
full_path = os.path.abspath(joined_path)
|
||||
directory = os.path.abspath(directory)
|
||||
else:
|
||||
full_path = os.path.realpath(joined_path)
|
||||
directory = os.path.realpath(directory)
|
||||
if os.path.commonpath([full_path, directory]) != str(directory):
|
||||
# Don't allow misbehaving clients to break out of the static files directory.
|
||||
continue
|
||||
try:
|
||||
return full_path, os.stat(full_path)
|
||||
except (FileNotFoundError, NotADirectoryError):
|
||||
continue
|
||||
return "", None
|
||||
|
||||
def file_response(
|
||||
self,
|
||||
full_path: PathLike,
|
||||
stat_result: os.stat_result,
|
||||
scope: Scope,
|
||||
status_code: int = 200,
|
||||
) -> Response:
|
||||
request_headers = Headers(scope=scope)
|
||||
|
||||
response = FileResponse(full_path, status_code=status_code, stat_result=stat_result)
|
||||
if self.is_not_modified(response.headers, request_headers):
|
||||
return NotModifiedResponse(response.headers)
|
||||
return response
|
||||
|
||||
async def check_config(self) -> None:
|
||||
"""
|
||||
Perform a one-off configuration check that StaticFiles is actually
|
||||
pointed at a directory, so that we can raise loud errors rather than
|
||||
just returning 404 responses.
|
||||
"""
|
||||
if self.directory is None:
|
||||
return
|
||||
|
||||
try:
|
||||
stat_result = await anyio.to_thread.run_sync(os.stat, self.directory)
|
||||
except FileNotFoundError:
|
||||
raise RuntimeError(f"StaticFiles directory '{self.directory}' does not exist.")
|
||||
if not (stat.S_ISDIR(stat_result.st_mode) or stat.S_ISLNK(stat_result.st_mode)):
|
||||
raise RuntimeError(f"StaticFiles path '{self.directory}' is not a directory.")
|
||||
|
||||
def is_not_modified(self, response_headers: Headers, request_headers: Headers) -> bool:
|
||||
"""
|
||||
Given the request and response headers, return `True` if an HTTP
|
||||
"Not Modified" response could be returned instead.
|
||||
"""
|
||||
try:
|
||||
if_none_match = request_headers["if-none-match"]
|
||||
etag = response_headers["etag"]
|
||||
if etag in [tag.strip(" W/") for tag in if_none_match.split(",")]:
|
||||
return True
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
try:
|
||||
if_modified_since = parsedate(request_headers["if-modified-since"])
|
||||
last_modified = parsedate(response_headers["last-modified"])
|
||||
if if_modified_since is not None and last_modified is not None and if_modified_since >= last_modified:
|
||||
return True
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
return False
|
95
venv/lib/python3.11/site-packages/starlette/status.py
Normal file
95
venv/lib/python3.11/site-packages/starlette/status.py
Normal file
@ -0,0 +1,95 @@
|
||||
"""
|
||||
HTTP codes
|
||||
See HTTP Status Code Registry:
|
||||
https://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml
|
||||
|
||||
And RFC 2324 - https://tools.ietf.org/html/rfc2324
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
HTTP_100_CONTINUE = 100
|
||||
HTTP_101_SWITCHING_PROTOCOLS = 101
|
||||
HTTP_102_PROCESSING = 102
|
||||
HTTP_103_EARLY_HINTS = 103
|
||||
HTTP_200_OK = 200
|
||||
HTTP_201_CREATED = 201
|
||||
HTTP_202_ACCEPTED = 202
|
||||
HTTP_203_NON_AUTHORITATIVE_INFORMATION = 203
|
||||
HTTP_204_NO_CONTENT = 204
|
||||
HTTP_205_RESET_CONTENT = 205
|
||||
HTTP_206_PARTIAL_CONTENT = 206
|
||||
HTTP_207_MULTI_STATUS = 207
|
||||
HTTP_208_ALREADY_REPORTED = 208
|
||||
HTTP_226_IM_USED = 226
|
||||
HTTP_300_MULTIPLE_CHOICES = 300
|
||||
HTTP_301_MOVED_PERMANENTLY = 301
|
||||
HTTP_302_FOUND = 302
|
||||
HTTP_303_SEE_OTHER = 303
|
||||
HTTP_304_NOT_MODIFIED = 304
|
||||
HTTP_305_USE_PROXY = 305
|
||||
HTTP_306_RESERVED = 306
|
||||
HTTP_307_TEMPORARY_REDIRECT = 307
|
||||
HTTP_308_PERMANENT_REDIRECT = 308
|
||||
HTTP_400_BAD_REQUEST = 400
|
||||
HTTP_401_UNAUTHORIZED = 401
|
||||
HTTP_402_PAYMENT_REQUIRED = 402
|
||||
HTTP_403_FORBIDDEN = 403
|
||||
HTTP_404_NOT_FOUND = 404
|
||||
HTTP_405_METHOD_NOT_ALLOWED = 405
|
||||
HTTP_406_NOT_ACCEPTABLE = 406
|
||||
HTTP_407_PROXY_AUTHENTICATION_REQUIRED = 407
|
||||
HTTP_408_REQUEST_TIMEOUT = 408
|
||||
HTTP_409_CONFLICT = 409
|
||||
HTTP_410_GONE = 410
|
||||
HTTP_411_LENGTH_REQUIRED = 411
|
||||
HTTP_412_PRECONDITION_FAILED = 412
|
||||
HTTP_413_REQUEST_ENTITY_TOO_LARGE = 413
|
||||
HTTP_414_REQUEST_URI_TOO_LONG = 414
|
||||
HTTP_415_UNSUPPORTED_MEDIA_TYPE = 415
|
||||
HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE = 416
|
||||
HTTP_417_EXPECTATION_FAILED = 417
|
||||
HTTP_418_IM_A_TEAPOT = 418
|
||||
HTTP_421_MISDIRECTED_REQUEST = 421
|
||||
HTTP_422_UNPROCESSABLE_ENTITY = 422
|
||||
HTTP_423_LOCKED = 423
|
||||
HTTP_424_FAILED_DEPENDENCY = 424
|
||||
HTTP_425_TOO_EARLY = 425
|
||||
HTTP_426_UPGRADE_REQUIRED = 426
|
||||
HTTP_428_PRECONDITION_REQUIRED = 428
|
||||
HTTP_429_TOO_MANY_REQUESTS = 429
|
||||
HTTP_431_REQUEST_HEADER_FIELDS_TOO_LARGE = 431
|
||||
HTTP_451_UNAVAILABLE_FOR_LEGAL_REASONS = 451
|
||||
HTTP_500_INTERNAL_SERVER_ERROR = 500
|
||||
HTTP_501_NOT_IMPLEMENTED = 501
|
||||
HTTP_502_BAD_GATEWAY = 502
|
||||
HTTP_503_SERVICE_UNAVAILABLE = 503
|
||||
HTTP_504_GATEWAY_TIMEOUT = 504
|
||||
HTTP_505_HTTP_VERSION_NOT_SUPPORTED = 505
|
||||
HTTP_506_VARIANT_ALSO_NEGOTIATES = 506
|
||||
HTTP_507_INSUFFICIENT_STORAGE = 507
|
||||
HTTP_508_LOOP_DETECTED = 508
|
||||
HTTP_510_NOT_EXTENDED = 510
|
||||
HTTP_511_NETWORK_AUTHENTICATION_REQUIRED = 511
|
||||
|
||||
|
||||
"""
|
||||
WebSocket codes
|
||||
https://www.iana.org/assignments/websocket/websocket.xml#close-code-number
|
||||
https://developer.mozilla.org/en-US/docs/Web/API/CloseEvent
|
||||
"""
|
||||
WS_1000_NORMAL_CLOSURE = 1000
|
||||
WS_1001_GOING_AWAY = 1001
|
||||
WS_1002_PROTOCOL_ERROR = 1002
|
||||
WS_1003_UNSUPPORTED_DATA = 1003
|
||||
WS_1005_NO_STATUS_RCVD = 1005
|
||||
WS_1006_ABNORMAL_CLOSURE = 1006
|
||||
WS_1007_INVALID_FRAME_PAYLOAD_DATA = 1007
|
||||
WS_1008_POLICY_VIOLATION = 1008
|
||||
WS_1009_MESSAGE_TOO_BIG = 1009
|
||||
WS_1010_MANDATORY_EXT = 1010
|
||||
WS_1011_INTERNAL_ERROR = 1011
|
||||
WS_1012_SERVICE_RESTART = 1012
|
||||
WS_1013_TRY_AGAIN_LATER = 1013
|
||||
WS_1014_BAD_GATEWAY = 1014
|
||||
WS_1015_TLS_HANDSHAKE = 1015
|
216
venv/lib/python3.11/site-packages/starlette/templating.py
Normal file
216
venv/lib/python3.11/site-packages/starlette/templating.py
Normal file
@ -0,0 +1,216 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import warnings
|
||||
from os import PathLike
|
||||
|
||||
from starlette.background import BackgroundTask
|
||||
from starlette.datastructures import URL
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import HTMLResponse
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
try:
|
||||
import jinja2
|
||||
|
||||
# @contextfunction was renamed to @pass_context in Jinja 3.0, and was removed in 3.1
|
||||
# hence we try to get pass_context (most installs will be >=3.1)
|
||||
# and fall back to contextfunction,
|
||||
# adding a type ignore for mypy to let us access an attribute that may not exist
|
||||
if hasattr(jinja2, "pass_context"):
|
||||
pass_context = jinja2.pass_context
|
||||
else: # pragma: no cover
|
||||
pass_context = jinja2.contextfunction # type: ignore[attr-defined]
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
jinja2 = None # type: ignore[assignment]
|
||||
|
||||
|
||||
class _TemplateResponse(HTMLResponse):
|
||||
def __init__(
|
||||
self,
|
||||
template: typing.Any,
|
||||
context: dict[str, typing.Any],
|
||||
status_code: int = 200,
|
||||
headers: typing.Mapping[str, str] | None = None,
|
||||
media_type: str | None = None,
|
||||
background: BackgroundTask | None = None,
|
||||
):
|
||||
self.template = template
|
||||
self.context = context
|
||||
content = template.render(context)
|
||||
super().__init__(content, status_code, headers, media_type, background)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
request = self.context.get("request", {})
|
||||
extensions = request.get("extensions", {})
|
||||
if "http.response.debug" in extensions: # pragma: no branch
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.debug",
|
||||
"info": {
|
||||
"template": self.template,
|
||||
"context": self.context,
|
||||
},
|
||||
}
|
||||
)
|
||||
await super().__call__(scope, receive, send)
|
||||
|
||||
|
||||
class Jinja2Templates:
|
||||
"""
|
||||
templates = Jinja2Templates("templates")
|
||||
|
||||
return templates.TemplateResponse("index.html", {"request": request})
|
||||
"""
|
||||
|
||||
@typing.overload
|
||||
def __init__(
|
||||
self,
|
||||
directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]],
|
||||
*,
|
||||
context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None,
|
||||
**env_options: typing.Any,
|
||||
) -> None: ...
|
||||
|
||||
@typing.overload
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
env: jinja2.Environment,
|
||||
context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None,
|
||||
) -> None: ...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]] | None = None,
|
||||
*,
|
||||
context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None,
|
||||
env: jinja2.Environment | None = None,
|
||||
**env_options: typing.Any,
|
||||
) -> None:
|
||||
if env_options:
|
||||
warnings.warn(
|
||||
"Extra environment options are deprecated. Use a preconfigured jinja2.Environment instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates"
|
||||
assert bool(directory) ^ bool(env), "either 'directory' or 'env' arguments must be passed"
|
||||
self.context_processors = context_processors or []
|
||||
if directory is not None:
|
||||
self.env = self._create_env(directory, **env_options)
|
||||
elif env is not None: # pragma: no branch
|
||||
self.env = env
|
||||
|
||||
self._setup_env_defaults(self.env)
|
||||
|
||||
def _create_env(
|
||||
self,
|
||||
directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]],
|
||||
**env_options: typing.Any,
|
||||
) -> jinja2.Environment:
|
||||
loader = jinja2.FileSystemLoader(directory)
|
||||
env_options.setdefault("loader", loader)
|
||||
env_options.setdefault("autoescape", True)
|
||||
|
||||
return jinja2.Environment(**env_options)
|
||||
|
||||
def _setup_env_defaults(self, env: jinja2.Environment) -> None:
|
||||
@pass_context
|
||||
def url_for(
|
||||
context: dict[str, typing.Any],
|
||||
name: str,
|
||||
/,
|
||||
**path_params: typing.Any,
|
||||
) -> URL:
|
||||
request: Request = context["request"]
|
||||
return request.url_for(name, **path_params)
|
||||
|
||||
env.globals.setdefault("url_for", url_for)
|
||||
|
||||
def get_template(self, name: str) -> jinja2.Template:
|
||||
return self.env.get_template(name)
|
||||
|
||||
@typing.overload
|
||||
def TemplateResponse(
|
||||
self,
|
||||
request: Request,
|
||||
name: str,
|
||||
context: dict[str, typing.Any] | None = None,
|
||||
status_code: int = 200,
|
||||
headers: typing.Mapping[str, str] | None = None,
|
||||
media_type: str | None = None,
|
||||
background: BackgroundTask | None = None,
|
||||
) -> _TemplateResponse: ...
|
||||
|
||||
@typing.overload
|
||||
def TemplateResponse(
|
||||
self,
|
||||
name: str,
|
||||
context: dict[str, typing.Any] | None = None,
|
||||
status_code: int = 200,
|
||||
headers: typing.Mapping[str, str] | None = None,
|
||||
media_type: str | None = None,
|
||||
background: BackgroundTask | None = None,
|
||||
) -> _TemplateResponse:
|
||||
# Deprecated usage
|
||||
...
|
||||
|
||||
def TemplateResponse(self, *args: typing.Any, **kwargs: typing.Any) -> _TemplateResponse:
|
||||
if args:
|
||||
if isinstance(args[0], str): # the first argument is template name (old style)
|
||||
warnings.warn(
|
||||
"The `name` is not the first parameter anymore. "
|
||||
"The first parameter should be the `Request` instance.\n"
|
||||
'Replace `TemplateResponse(name, {"request": request})` by `TemplateResponse(request, name)`.',
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
name = args[0]
|
||||
context = args[1] if len(args) > 1 else kwargs.get("context", {})
|
||||
status_code = args[2] if len(args) > 2 else kwargs.get("status_code", 200)
|
||||
headers = args[2] if len(args) > 2 else kwargs.get("headers")
|
||||
media_type = args[3] if len(args) > 3 else kwargs.get("media_type")
|
||||
background = args[4] if len(args) > 4 else kwargs.get("background")
|
||||
|
||||
if "request" not in context:
|
||||
raise ValueError('context must include a "request" key')
|
||||
request = context["request"]
|
||||
else: # the first argument is a request instance (new style)
|
||||
request = args[0]
|
||||
name = args[1] if len(args) > 1 else kwargs["name"]
|
||||
context = args[2] if len(args) > 2 else kwargs.get("context", {})
|
||||
status_code = args[3] if len(args) > 3 else kwargs.get("status_code", 200)
|
||||
headers = args[4] if len(args) > 4 else kwargs.get("headers")
|
||||
media_type = args[5] if len(args) > 5 else kwargs.get("media_type")
|
||||
background = args[6] if len(args) > 6 else kwargs.get("background")
|
||||
else: # all arguments are kwargs
|
||||
if "request" not in kwargs:
|
||||
warnings.warn(
|
||||
"The `TemplateResponse` now requires the `request` argument.\n"
|
||||
'Replace `TemplateResponse(name, {"context": context})` by `TemplateResponse(request, name)`.',
|
||||
DeprecationWarning,
|
||||
)
|
||||
if "request" not in kwargs.get("context", {}):
|
||||
raise ValueError('context must include a "request" key')
|
||||
|
||||
context = kwargs.get("context", {})
|
||||
request = kwargs.get("request", context.get("request"))
|
||||
name = typing.cast(str, kwargs["name"])
|
||||
status_code = kwargs.get("status_code", 200)
|
||||
headers = kwargs.get("headers")
|
||||
media_type = kwargs.get("media_type")
|
||||
background = kwargs.get("background")
|
||||
|
||||
context.setdefault("request", request)
|
||||
for context_processor in self.context_processors:
|
||||
context.update(context_processor(request))
|
||||
|
||||
template = self.get_template(name)
|
||||
return _TemplateResponse(
|
||||
template,
|
||||
context,
|
||||
status_code=status_code,
|
||||
headers=headers,
|
||||
media_type=media_type,
|
||||
background=background,
|
||||
)
|
731
venv/lib/python3.11/site-packages/starlette/testclient.py
Normal file
731
venv/lib/python3.11/site-packages/starlette/testclient.py
Normal file
@ -0,0 +1,731 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import inspect
|
||||
import io
|
||||
import json
|
||||
import math
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
from concurrent.futures import Future
|
||||
from types import GeneratorType
|
||||
from urllib.parse import unquote, urljoin
|
||||
|
||||
import anyio
|
||||
import anyio.abc
|
||||
import anyio.from_thread
|
||||
from anyio.streams.stapled import StapledObjectStream
|
||||
|
||||
from starlette._utils import is_async_callable
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
|
||||
if sys.version_info >= (3, 10): # pragma: no cover
|
||||
from typing import TypeGuard
|
||||
else: # pragma: no cover
|
||||
from typing_extensions import TypeGuard
|
||||
|
||||
try:
|
||||
import httpx
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
raise RuntimeError(
|
||||
"The starlette.testclient module requires the httpx package to be installed.\n"
|
||||
"You can install this with:\n"
|
||||
" $ pip install httpx\n"
|
||||
)
|
||||
_PortalFactoryType = typing.Callable[[], typing.ContextManager[anyio.abc.BlockingPortal]]
|
||||
|
||||
ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]]
|
||||
ASGI2App = typing.Callable[[Scope], ASGIInstance]
|
||||
ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
|
||||
|
||||
|
||||
_RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str], bytes]]
|
||||
|
||||
|
||||
def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]:
|
||||
if inspect.isclass(app):
|
||||
return hasattr(app, "__await__")
|
||||
return is_async_callable(app)
|
||||
|
||||
|
||||
class _WrapASGI2:
|
||||
"""
|
||||
Provide an ASGI3 interface onto an ASGI2 app.
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGI2App) -> None:
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
instance = self.app(scope)
|
||||
await instance(receive, send)
|
||||
|
||||
|
||||
class _AsyncBackend(typing.TypedDict):
|
||||
backend: str
|
||||
backend_options: dict[str, typing.Any]
|
||||
|
||||
|
||||
class _Upgrade(Exception):
|
||||
def __init__(self, session: WebSocketTestSession) -> None:
|
||||
self.session = session
|
||||
|
||||
|
||||
class WebSocketDenialResponse( # type: ignore[misc]
|
||||
httpx.Response,
|
||||
WebSocketDisconnect,
|
||||
):
|
||||
"""
|
||||
A special case of `WebSocketDisconnect`, raised in the `TestClient` if the
|
||||
`WebSocket` is closed before being accepted with a `send_denial_response()`.
|
||||
"""
|
||||
|
||||
|
||||
class WebSocketTestSession:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGI3App,
|
||||
scope: Scope,
|
||||
portal_factory: _PortalFactoryType,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.scope = scope
|
||||
self.accepted_subprotocol = None
|
||||
self.portal_factory = portal_factory
|
||||
self.extra_headers = None
|
||||
|
||||
def __enter__(self) -> WebSocketTestSession:
|
||||
with contextlib.ExitStack() as stack:
|
||||
self.portal = portal = stack.enter_context(self.portal_factory())
|
||||
fut, cs = portal.start_task(self._run)
|
||||
stack.callback(fut.result)
|
||||
stack.callback(portal.call, cs.cancel)
|
||||
self.send({"type": "websocket.connect"})
|
||||
message = self.receive()
|
||||
self._raise_on_close(message)
|
||||
self.accepted_subprotocol = message.get("subprotocol", None)
|
||||
self.extra_headers = message.get("headers", None)
|
||||
stack.callback(self.close, 1000)
|
||||
self.exit_stack = stack.pop_all()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: typing.Any) -> bool | None:
|
||||
return self.exit_stack.__exit__(*args)
|
||||
|
||||
async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None:
|
||||
"""
|
||||
The sub-thread in which the websocket session runs.
|
||||
"""
|
||||
send: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
|
||||
send_tx, send_rx = send
|
||||
receive: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
|
||||
receive_tx, receive_rx = receive
|
||||
with send_tx, send_rx, receive_tx, receive_rx, anyio.CancelScope() as cs:
|
||||
self._receive_tx = receive_tx
|
||||
self._send_rx = send_rx
|
||||
task_status.started(cs)
|
||||
await self.app(self.scope, receive_rx.receive, send_tx.send)
|
||||
|
||||
# wait for cs.cancel to be called before closing streams
|
||||
await anyio.sleep_forever()
|
||||
|
||||
def _raise_on_close(self, message: Message) -> None:
|
||||
if message["type"] == "websocket.close":
|
||||
raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", ""))
|
||||
elif message["type"] == "websocket.http.response.start":
|
||||
status_code: int = message["status"]
|
||||
headers: list[tuple[bytes, bytes]] = message["headers"]
|
||||
body: list[bytes] = []
|
||||
while True:
|
||||
message = self.receive()
|
||||
assert message["type"] == "websocket.http.response.body"
|
||||
body.append(message["body"])
|
||||
if not message.get("more_body", False):
|
||||
break
|
||||
raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body))
|
||||
|
||||
def send(self, message: Message) -> None:
|
||||
self.portal.call(self._receive_tx.send, message)
|
||||
|
||||
def send_text(self, data: str) -> None:
|
||||
self.send({"type": "websocket.receive", "text": data})
|
||||
|
||||
def send_bytes(self, data: bytes) -> None:
|
||||
self.send({"type": "websocket.receive", "bytes": data})
|
||||
|
||||
def send_json(self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text") -> None:
|
||||
text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
|
||||
if mode == "text":
|
||||
self.send({"type": "websocket.receive", "text": text})
|
||||
else:
|
||||
self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})
|
||||
|
||||
def close(self, code: int = 1000, reason: str | None = None) -> None:
|
||||
self.send({"type": "websocket.disconnect", "code": code, "reason": reason})
|
||||
|
||||
def receive(self) -> Message:
|
||||
return self.portal.call(self._send_rx.receive)
|
||||
|
||||
def receive_text(self) -> str:
|
||||
message = self.receive()
|
||||
self._raise_on_close(message)
|
||||
return typing.cast(str, message["text"])
|
||||
|
||||
def receive_bytes(self) -> bytes:
|
||||
message = self.receive()
|
||||
self._raise_on_close(message)
|
||||
return typing.cast(bytes, message["bytes"])
|
||||
|
||||
def receive_json(self, mode: typing.Literal["text", "binary"] = "text") -> typing.Any:
|
||||
message = self.receive()
|
||||
self._raise_on_close(message)
|
||||
if mode == "text":
|
||||
text = message["text"]
|
||||
else:
|
||||
text = message["bytes"].decode("utf-8")
|
||||
return json.loads(text)
|
||||
|
||||
|
||||
class _TestClientTransport(httpx.BaseTransport):
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGI3App,
|
||||
portal_factory: _PortalFactoryType,
|
||||
raise_server_exceptions: bool = True,
|
||||
root_path: str = "",
|
||||
*,
|
||||
client: tuple[str, int],
|
||||
app_state: dict[str, typing.Any],
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.raise_server_exceptions = raise_server_exceptions
|
||||
self.root_path = root_path
|
||||
self.portal_factory = portal_factory
|
||||
self.app_state = app_state
|
||||
self.client = client
|
||||
|
||||
def handle_request(self, request: httpx.Request) -> httpx.Response:
|
||||
scheme = request.url.scheme
|
||||
netloc = request.url.netloc.decode(encoding="ascii")
|
||||
path = request.url.path
|
||||
raw_path = request.url.raw_path
|
||||
query = request.url.query.decode(encoding="ascii")
|
||||
|
||||
default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]
|
||||
|
||||
if ":" in netloc:
|
||||
host, port_string = netloc.split(":", 1)
|
||||
port = int(port_string)
|
||||
else:
|
||||
host = netloc
|
||||
port = default_port
|
||||
|
||||
# Include the 'host' header.
|
||||
if "host" in request.headers:
|
||||
headers: list[tuple[bytes, bytes]] = []
|
||||
elif port == default_port: # pragma: no cover
|
||||
headers = [(b"host", host.encode())]
|
||||
else: # pragma: no cover
|
||||
headers = [(b"host", (f"{host}:{port}").encode())]
|
||||
|
||||
# Include other request headers.
|
||||
headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()]
|
||||
|
||||
scope: dict[str, typing.Any]
|
||||
|
||||
if scheme in {"ws", "wss"}:
|
||||
subprotocol = request.headers.get("sec-websocket-protocol", None)
|
||||
if subprotocol is None:
|
||||
subprotocols: typing.Sequence[str] = []
|
||||
else:
|
||||
subprotocols = [value.strip() for value in subprotocol.split(",")]
|
||||
scope = {
|
||||
"type": "websocket",
|
||||
"path": unquote(path),
|
||||
"raw_path": raw_path.split(b"?", 1)[0],
|
||||
"root_path": self.root_path,
|
||||
"scheme": scheme,
|
||||
"query_string": query.encode(),
|
||||
"headers": headers,
|
||||
"client": self.client,
|
||||
"server": [host, port],
|
||||
"subprotocols": subprotocols,
|
||||
"state": self.app_state.copy(),
|
||||
"extensions": {"websocket.http.response": {}},
|
||||
}
|
||||
session = WebSocketTestSession(self.app, scope, self.portal_factory)
|
||||
raise _Upgrade(session)
|
||||
|
||||
scope = {
|
||||
"type": "http",
|
||||
"http_version": "1.1",
|
||||
"method": request.method,
|
||||
"path": unquote(path),
|
||||
"raw_path": raw_path.split(b"?", 1)[0],
|
||||
"root_path": self.root_path,
|
||||
"scheme": scheme,
|
||||
"query_string": query.encode(),
|
||||
"headers": headers,
|
||||
"client": self.client,
|
||||
"server": [host, port],
|
||||
"extensions": {"http.response.debug": {}},
|
||||
"state": self.app_state.copy(),
|
||||
}
|
||||
|
||||
request_complete = False
|
||||
response_started = False
|
||||
response_complete: anyio.Event
|
||||
raw_kwargs: dict[str, typing.Any] = {"stream": io.BytesIO()}
|
||||
template = None
|
||||
context = None
|
||||
|
||||
async def receive() -> Message:
|
||||
nonlocal request_complete
|
||||
|
||||
if request_complete:
|
||||
if not response_complete.is_set():
|
||||
await response_complete.wait()
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
body = request.read()
|
||||
if isinstance(body, str):
|
||||
body_bytes: bytes = body.encode("utf-8") # pragma: no cover
|
||||
elif body is None:
|
||||
body_bytes = b"" # pragma: no cover
|
||||
elif isinstance(body, GeneratorType):
|
||||
try: # pragma: no cover
|
||||
chunk = body.send(None)
|
||||
if isinstance(chunk, str):
|
||||
chunk = chunk.encode("utf-8")
|
||||
return {"type": "http.request", "body": chunk, "more_body": True}
|
||||
except StopIteration: # pragma: no cover
|
||||
request_complete = True
|
||||
return {"type": "http.request", "body": b""}
|
||||
else:
|
||||
body_bytes = body
|
||||
|
||||
request_complete = True
|
||||
return {"type": "http.request", "body": body_bytes}
|
||||
|
||||
async def send(message: Message) -> None:
|
||||
nonlocal raw_kwargs, response_started, template, context
|
||||
|
||||
if message["type"] == "http.response.start":
|
||||
assert not response_started, 'Received multiple "http.response.start" messages.'
|
||||
raw_kwargs["status_code"] = message["status"]
|
||||
raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])]
|
||||
response_started = True
|
||||
elif message["type"] == "http.response.body":
|
||||
assert response_started, 'Received "http.response.body" without "http.response.start".'
|
||||
assert not response_complete.is_set(), 'Received "http.response.body" after response completed.'
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
if request.method != "HEAD":
|
||||
raw_kwargs["stream"].write(body)
|
||||
if not more_body:
|
||||
raw_kwargs["stream"].seek(0)
|
||||
response_complete.set()
|
||||
elif message["type"] == "http.response.debug":
|
||||
template = message["info"]["template"]
|
||||
context = message["info"]["context"]
|
||||
|
||||
try:
|
||||
with self.portal_factory() as portal:
|
||||
response_complete = portal.call(anyio.Event)
|
||||
portal.call(self.app, scope, receive, send)
|
||||
except BaseException as exc:
|
||||
if self.raise_server_exceptions:
|
||||
raise exc
|
||||
|
||||
if self.raise_server_exceptions:
|
||||
assert response_started, "TestClient did not receive any response."
|
||||
elif not response_started:
|
||||
raw_kwargs = {
|
||||
"status_code": 500,
|
||||
"headers": [],
|
||||
"stream": io.BytesIO(),
|
||||
}
|
||||
|
||||
raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].read())
|
||||
|
||||
response = httpx.Response(**raw_kwargs, request=request)
|
||||
if template is not None:
|
||||
response.template = template # type: ignore[attr-defined]
|
||||
response.context = context # type: ignore[attr-defined]
|
||||
return response
|
||||
|
||||
|
||||
class TestClient(httpx.Client):
|
||||
__test__ = False
|
||||
task: Future[None]
|
||||
portal: anyio.abc.BlockingPortal | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
base_url: str = "http://testserver",
|
||||
raise_server_exceptions: bool = True,
|
||||
root_path: str = "",
|
||||
backend: typing.Literal["asyncio", "trio"] = "asyncio",
|
||||
backend_options: dict[str, typing.Any] | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
follow_redirects: bool = True,
|
||||
client: tuple[str, int] = ("testclient", 50000),
|
||||
) -> None:
|
||||
self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {})
|
||||
if _is_asgi3(app):
|
||||
asgi_app = app
|
||||
else:
|
||||
app = typing.cast(ASGI2App, app) # type: ignore[assignment]
|
||||
asgi_app = _WrapASGI2(app) # type: ignore[arg-type]
|
||||
self.app = asgi_app
|
||||
self.app_state: dict[str, typing.Any] = {}
|
||||
transport = _TestClientTransport(
|
||||
self.app,
|
||||
portal_factory=self._portal_factory,
|
||||
raise_server_exceptions=raise_server_exceptions,
|
||||
root_path=root_path,
|
||||
app_state=self.app_state,
|
||||
client=client,
|
||||
)
|
||||
if headers is None:
|
||||
headers = {}
|
||||
headers.setdefault("user-agent", "testclient")
|
||||
super().__init__(
|
||||
base_url=base_url,
|
||||
headers=headers,
|
||||
transport=transport,
|
||||
follow_redirects=follow_redirects,
|
||||
cookies=cookies,
|
||||
)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
|
||||
if self.portal is not None:
|
||||
yield self.portal
|
||||
else:
|
||||
with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal:
|
||||
yield portal
|
||||
|
||||
def request( # type: ignore[override]
|
||||
self,
|
||||
method: str,
|
||||
url: httpx._types.URLTypes,
|
||||
*,
|
||||
content: httpx._types.RequestContent | None = None,
|
||||
data: _RequestData | None = None,
|
||||
files: httpx._types.RequestFiles | None = None,
|
||||
json: typing.Any = None,
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, typing.Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
if timeout is not httpx.USE_CLIENT_DEFAULT:
|
||||
warnings.warn(
|
||||
"You should not use the 'timeout' argument with the TestClient. "
|
||||
"See https://github.com/encode/starlette/issues/1108 for more information.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
url = self._merge_url(url)
|
||||
return super().request(
|
||||
method,
|
||||
url,
|
||||
content=content,
|
||||
data=data,
|
||||
files=files,
|
||||
json=json,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
|
||||
def get( # type: ignore[override]
|
||||
self,
|
||||
url: httpx._types.URLTypes,
|
||||
*,
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, typing.Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
return super().get(
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
|
||||
def options( # type: ignore[override]
|
||||
self,
|
||||
url: httpx._types.URLTypes,
|
||||
*,
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, typing.Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
return super().options(
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
|
||||
def head( # type: ignore[override]
|
||||
self,
|
||||
url: httpx._types.URLTypes,
|
||||
*,
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, typing.Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
return super().head(
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
|
||||
def post( # type: ignore[override]
|
||||
self,
|
||||
url: httpx._types.URLTypes,
|
||||
*,
|
||||
content: httpx._types.RequestContent | None = None,
|
||||
data: _RequestData | None = None,
|
||||
files: httpx._types.RequestFiles | None = None,
|
||||
json: typing.Any = None,
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, typing.Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
return super().post(
|
||||
url,
|
||||
content=content,
|
||||
data=data,
|
||||
files=files,
|
||||
json=json,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
|
||||
def put( # type: ignore[override]
|
||||
self,
|
||||
url: httpx._types.URLTypes,
|
||||
*,
|
||||
content: httpx._types.RequestContent | None = None,
|
||||
data: _RequestData | None = None,
|
||||
files: httpx._types.RequestFiles | None = None,
|
||||
json: typing.Any = None,
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, typing.Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
return super().put(
|
||||
url,
|
||||
content=content,
|
||||
data=data,
|
||||
files=files,
|
||||
json=json,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
|
||||
def patch( # type: ignore[override]
|
||||
self,
|
||||
url: httpx._types.URLTypes,
|
||||
*,
|
||||
content: httpx._types.RequestContent | None = None,
|
||||
data: _RequestData | None = None,
|
||||
files: httpx._types.RequestFiles | None = None,
|
||||
json: typing.Any = None,
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, typing.Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
return super().patch(
|
||||
url,
|
||||
content=content,
|
||||
data=data,
|
||||
files=files,
|
||||
json=json,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
|
||||
def delete( # type: ignore[override]
|
||||
self,
|
||||
url: httpx._types.URLTypes,
|
||||
*,
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, typing.Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
return super().delete(
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
|
||||
def websocket_connect(
|
||||
self,
|
||||
url: str,
|
||||
subprotocols: typing.Sequence[str] | None = None,
|
||||
**kwargs: typing.Any,
|
||||
) -> WebSocketTestSession:
|
||||
url = urljoin("ws://testserver", url)
|
||||
headers = kwargs.get("headers", {})
|
||||
headers.setdefault("connection", "upgrade")
|
||||
headers.setdefault("sec-websocket-key", "testserver==")
|
||||
headers.setdefault("sec-websocket-version", "13")
|
||||
if subprotocols is not None:
|
||||
headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
|
||||
kwargs["headers"] = headers
|
||||
try:
|
||||
super().request("GET", url, **kwargs)
|
||||
except _Upgrade as exc:
|
||||
session = exc.session
|
||||
else:
|
||||
raise RuntimeError("Expected WebSocket upgrade") # pragma: no cover
|
||||
|
||||
return session
|
||||
|
||||
def __enter__(self) -> TestClient:
|
||||
with contextlib.ExitStack() as stack:
|
||||
self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend))
|
||||
|
||||
@stack.callback
|
||||
def reset_portal() -> None:
|
||||
self.portal = None
|
||||
|
||||
send: anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any] | None] = (
|
||||
anyio.create_memory_object_stream(math.inf)
|
||||
)
|
||||
receive: anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any]] = (
|
||||
anyio.create_memory_object_stream(math.inf)
|
||||
)
|
||||
for channel in (*send, *receive):
|
||||
stack.callback(channel.close)
|
||||
self.stream_send = StapledObjectStream(*send)
|
||||
self.stream_receive = StapledObjectStream(*receive)
|
||||
self.task = portal.start_task_soon(self.lifespan)
|
||||
portal.call(self.wait_startup)
|
||||
|
||||
@stack.callback
|
||||
def wait_shutdown() -> None:
|
||||
portal.call(self.wait_shutdown)
|
||||
|
||||
self.exit_stack = stack.pop_all()
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: typing.Any) -> None:
|
||||
self.exit_stack.close()
|
||||
|
||||
async def lifespan(self) -> None:
|
||||
scope = {"type": "lifespan", "state": self.app_state}
|
||||
try:
|
||||
await self.app(scope, self.stream_receive.receive, self.stream_send.send)
|
||||
finally:
|
||||
await self.stream_send.send(None)
|
||||
|
||||
async def wait_startup(self) -> None:
|
||||
await self.stream_receive.send({"type": "lifespan.startup"})
|
||||
|
||||
async def receive() -> typing.Any:
|
||||
message = await self.stream_send.receive()
|
||||
if message is None:
|
||||
self.task.result()
|
||||
return message
|
||||
|
||||
message = await receive()
|
||||
assert message["type"] in (
|
||||
"lifespan.startup.complete",
|
||||
"lifespan.startup.failed",
|
||||
)
|
||||
if message["type"] == "lifespan.startup.failed":
|
||||
await receive()
|
||||
|
||||
async def wait_shutdown(self) -> None:
|
||||
async def receive() -> typing.Any:
|
||||
message = await self.stream_send.receive()
|
||||
if message is None:
|
||||
self.task.result()
|
||||
return message
|
||||
|
||||
await self.stream_receive.send({"type": "lifespan.shutdown"})
|
||||
message = await receive()
|
||||
assert message["type"] in (
|
||||
"lifespan.shutdown.complete",
|
||||
"lifespan.shutdown.failed",
|
||||
)
|
||||
if message["type"] == "lifespan.shutdown.failed":
|
||||
await receive()
|
24
venv/lib/python3.11/site-packages/starlette/types.py
Normal file
24
venv/lib/python3.11/site-packages/starlette/types.py
Normal file
@ -0,0 +1,24 @@
|
||||
import typing
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
AppType = typing.TypeVar("AppType")
|
||||
|
||||
Scope = typing.MutableMapping[str, typing.Any]
|
||||
Message = typing.MutableMapping[str, typing.Any]
|
||||
|
||||
Receive = typing.Callable[[], typing.Awaitable[Message]]
|
||||
Send = typing.Callable[[Message], typing.Awaitable[None]]
|
||||
|
||||
ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
|
||||
|
||||
StatelessLifespan = typing.Callable[[AppType], typing.AsyncContextManager[None]]
|
||||
StatefulLifespan = typing.Callable[[AppType], typing.AsyncContextManager[typing.Mapping[str, typing.Any]]]
|
||||
Lifespan = typing.Union[StatelessLifespan[AppType], StatefulLifespan[AppType]]
|
||||
|
||||
HTTPExceptionHandler = typing.Callable[["Request", Exception], "Response | typing.Awaitable[Response]"]
|
||||
WebSocketExceptionHandler = typing.Callable[["WebSocket", Exception], typing.Awaitable[None]]
|
||||
ExceptionHandler = typing.Union[HTTPExceptionHandler, WebSocketExceptionHandler]
|
195
venv/lib/python3.11/site-packages/starlette/websockets.py
Normal file
195
venv/lib/python3.11/site-packages/starlette/websockets.py
Normal file
@ -0,0 +1,195 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import json
|
||||
import typing
|
||||
|
||||
from starlette.requests import HTTPConnection
|
||||
from starlette.responses import Response
|
||||
from starlette.types import Message, Receive, Scope, Send
|
||||
|
||||
|
||||
class WebSocketState(enum.Enum):
|
||||
CONNECTING = 0
|
||||
CONNECTED = 1
|
||||
DISCONNECTED = 2
|
||||
RESPONSE = 3
|
||||
|
||||
|
||||
class WebSocketDisconnect(Exception):
|
||||
def __init__(self, code: int = 1000, reason: str | None = None) -> None:
|
||||
self.code = code
|
||||
self.reason = reason or ""
|
||||
|
||||
|
||||
class WebSocket(HTTPConnection):
|
||||
def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
super().__init__(scope)
|
||||
assert scope["type"] == "websocket"
|
||||
self._receive = receive
|
||||
self._send = send
|
||||
self.client_state = WebSocketState.CONNECTING
|
||||
self.application_state = WebSocketState.CONNECTING
|
||||
|
||||
async def receive(self) -> Message:
|
||||
"""
|
||||
Receive ASGI websocket messages, ensuring valid state transitions.
|
||||
"""
|
||||
if self.client_state == WebSocketState.CONNECTING:
|
||||
message = await self._receive()
|
||||
message_type = message["type"]
|
||||
if message_type != "websocket.connect":
|
||||
raise RuntimeError(f'Expected ASGI message "websocket.connect", but got {message_type!r}')
|
||||
self.client_state = WebSocketState.CONNECTED
|
||||
return message
|
||||
elif self.client_state == WebSocketState.CONNECTED:
|
||||
message = await self._receive()
|
||||
message_type = message["type"]
|
||||
if message_type not in {"websocket.receive", "websocket.disconnect"}:
|
||||
raise RuntimeError(
|
||||
f'Expected ASGI message "websocket.receive" or "websocket.disconnect", but got {message_type!r}'
|
||||
)
|
||||
if message_type == "websocket.disconnect":
|
||||
self.client_state = WebSocketState.DISCONNECTED
|
||||
return message
|
||||
else:
|
||||
raise RuntimeError('Cannot call "receive" once a disconnect message has been received.')
|
||||
|
||||
async def send(self, message: Message) -> None:
|
||||
"""
|
||||
Send ASGI websocket messages, ensuring valid state transitions.
|
||||
"""
|
||||
if self.application_state == WebSocketState.CONNECTING:
|
||||
message_type = message["type"]
|
||||
if message_type not in {"websocket.accept", "websocket.close", "websocket.http.response.start"}:
|
||||
raise RuntimeError(
|
||||
'Expected ASGI message "websocket.accept", "websocket.close" or "websocket.http.response.start", '
|
||||
f"but got {message_type!r}"
|
||||
)
|
||||
if message_type == "websocket.close":
|
||||
self.application_state = WebSocketState.DISCONNECTED
|
||||
elif message_type == "websocket.http.response.start":
|
||||
self.application_state = WebSocketState.RESPONSE
|
||||
else:
|
||||
self.application_state = WebSocketState.CONNECTED
|
||||
await self._send(message)
|
||||
elif self.application_state == WebSocketState.CONNECTED:
|
||||
message_type = message["type"]
|
||||
if message_type not in {"websocket.send", "websocket.close"}:
|
||||
raise RuntimeError(
|
||||
f'Expected ASGI message "websocket.send" or "websocket.close", but got {message_type!r}'
|
||||
)
|
||||
if message_type == "websocket.close":
|
||||
self.application_state = WebSocketState.DISCONNECTED
|
||||
try:
|
||||
await self._send(message)
|
||||
except OSError:
|
||||
self.application_state = WebSocketState.DISCONNECTED
|
||||
raise WebSocketDisconnect(code=1006)
|
||||
elif self.application_state == WebSocketState.RESPONSE:
|
||||
message_type = message["type"]
|
||||
if message_type != "websocket.http.response.body":
|
||||
raise RuntimeError(f'Expected ASGI message "websocket.http.response.body", but got {message_type!r}')
|
||||
if not message.get("more_body", False):
|
||||
self.application_state = WebSocketState.DISCONNECTED
|
||||
await self._send(message)
|
||||
else:
|
||||
raise RuntimeError('Cannot call "send" once a close message has been sent.')
|
||||
|
||||
async def accept(
|
||||
self,
|
||||
subprotocol: str | None = None,
|
||||
headers: typing.Iterable[tuple[bytes, bytes]] | None = None,
|
||||
) -> None:
|
||||
headers = headers or []
|
||||
|
||||
if self.client_state == WebSocketState.CONNECTING: # pragma: no branch
|
||||
# If we haven't yet seen the 'connect' message, then wait for it first.
|
||||
await self.receive()
|
||||
await self.send({"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers})
|
||||
|
||||
def _raise_on_disconnect(self, message: Message) -> None:
|
||||
if message["type"] == "websocket.disconnect":
|
||||
raise WebSocketDisconnect(message["code"], message.get("reason"))
|
||||
|
||||
async def receive_text(self) -> str:
|
||||
if self.application_state != WebSocketState.CONNECTED:
|
||||
raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
|
||||
message = await self.receive()
|
||||
self._raise_on_disconnect(message)
|
||||
return typing.cast(str, message["text"])
|
||||
|
||||
async def receive_bytes(self) -> bytes:
|
||||
if self.application_state != WebSocketState.CONNECTED:
|
||||
raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
|
||||
message = await self.receive()
|
||||
self._raise_on_disconnect(message)
|
||||
return typing.cast(bytes, message["bytes"])
|
||||
|
||||
async def receive_json(self, mode: str = "text") -> typing.Any:
|
||||
if mode not in {"text", "binary"}:
|
||||
raise RuntimeError('The "mode" argument should be "text" or "binary".')
|
||||
if self.application_state != WebSocketState.CONNECTED:
|
||||
raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
|
||||
message = await self.receive()
|
||||
self._raise_on_disconnect(message)
|
||||
|
||||
if mode == "text":
|
||||
text = message["text"]
|
||||
else:
|
||||
text = message["bytes"].decode("utf-8")
|
||||
return json.loads(text)
|
||||
|
||||
async def iter_text(self) -> typing.AsyncIterator[str]:
|
||||
try:
|
||||
while True:
|
||||
yield await self.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
|
||||
async def iter_bytes(self) -> typing.AsyncIterator[bytes]:
|
||||
try:
|
||||
while True:
|
||||
yield await self.receive_bytes()
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
|
||||
async def iter_json(self) -> typing.AsyncIterator[typing.Any]:
|
||||
try:
|
||||
while True:
|
||||
yield await self.receive_json()
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
|
||||
async def send_text(self, data: str) -> None:
|
||||
await self.send({"type": "websocket.send", "text": data})
|
||||
|
||||
async def send_bytes(self, data: bytes) -> None:
|
||||
await self.send({"type": "websocket.send", "bytes": data})
|
||||
|
||||
async def send_json(self, data: typing.Any, mode: str = "text") -> None:
|
||||
if mode not in {"text", "binary"}:
|
||||
raise RuntimeError('The "mode" argument should be "text" or "binary".')
|
||||
text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
|
||||
if mode == "text":
|
||||
await self.send({"type": "websocket.send", "text": text})
|
||||
else:
|
||||
await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})
|
||||
|
||||
async def close(self, code: int = 1000, reason: str | None = None) -> None:
|
||||
await self.send({"type": "websocket.close", "code": code, "reason": reason or ""})
|
||||
|
||||
async def send_denial_response(self, response: Response) -> None:
|
||||
if "websocket.http.response" in self.scope.get("extensions", {}):
|
||||
await response(self.scope, self.receive, self.send)
|
||||
else:
|
||||
raise RuntimeError("The server doesn't support the Websocket Denial Response extension.")
|
||||
|
||||
|
||||
class WebSocketClose:
|
||||
def __init__(self, code: int = 1000, reason: str | None = None) -> None:
|
||||
self.code = code
|
||||
self.reason = reason or ""
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
await send({"type": "websocket.close", "code": self.code, "reason": self.reason})
|
Reference in New Issue
Block a user