Update 2025-04-13_16:26:34

This commit is contained in:
root
2025-04-13 16:26:35 +02:00
commit 0e49903693
2239 changed files with 407432 additions and 0 deletions

View File

@ -0,0 +1 @@
__version__ = "0.46.1"

View File

@ -0,0 +1,65 @@
from __future__ import annotations
import typing
from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
from starlette.websockets import WebSocket
ExceptionHandlers = dict[typing.Any, ExceptionHandler]
StatusHandlers = dict[int, ExceptionHandler]
def _lookup_exception_handler(exc_handlers: ExceptionHandlers, exc: Exception) -> ExceptionHandler | None:
for cls in type(exc).__mro__:
if cls in exc_handlers:
return exc_handlers[cls]
return None
def wrap_app_handling_exceptions(app: ASGIApp, conn: Request | WebSocket) -> ASGIApp:
exception_handlers: ExceptionHandlers
status_handlers: StatusHandlers
try:
exception_handlers, status_handlers = conn.scope["starlette.exception_handlers"]
except KeyError:
exception_handlers, status_handlers = {}, {}
async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None:
response_started = False
async def sender(message: Message) -> None:
nonlocal response_started
if message["type"] == "http.response.start":
response_started = True
await send(message)
try:
await app(scope, receive, sender)
except Exception as exc:
handler = None
if isinstance(exc, HTTPException):
handler = status_handlers.get(exc.status_code)
if handler is None:
handler = _lookup_exception_handler(exception_handlers, exc)
if handler is None:
raise exc
if response_started:
raise RuntimeError("Caught handled exception, but response already started.") from exc
if is_async_callable(handler):
response = await handler(conn, exc)
else:
response = await run_in_threadpool(handler, conn, exc) # type: ignore
if response is not None:
await response(scope, receive, sender)
return wrapped_app

View File

@ -0,0 +1,100 @@
from __future__ import annotations
import functools
import inspect
import sys
import typing
from contextlib import contextmanager
from starlette.types import Scope
if sys.version_info >= (3, 10): # pragma: no cover
from typing import TypeGuard
else: # pragma: no cover
from typing_extensions import TypeGuard
has_exceptiongroups = True
if sys.version_info < (3, 11): # pragma: no cover
try:
from exceptiongroup import BaseExceptionGroup # type: ignore[unused-ignore,import-not-found]
except ImportError:
has_exceptiongroups = False
T = typing.TypeVar("T")
AwaitableCallable = typing.Callable[..., typing.Awaitable[T]]
@typing.overload
def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]: ...
@typing.overload
def is_async_callable(obj: typing.Any) -> TypeGuard[AwaitableCallable[typing.Any]]: ...
def is_async_callable(obj: typing.Any) -> typing.Any:
while isinstance(obj, functools.partial):
obj = obj.func
return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__))
T_co = typing.TypeVar("T_co", covariant=True)
class AwaitableOrContextManager(typing.Awaitable[T_co], typing.AsyncContextManager[T_co], typing.Protocol[T_co]): ...
class SupportsAsyncClose(typing.Protocol):
async def close(self) -> None: ... # pragma: no cover
SupportsAsyncCloseType = typing.TypeVar("SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False)
class AwaitableOrContextManagerWrapper(typing.Generic[SupportsAsyncCloseType]):
__slots__ = ("aw", "entered")
def __init__(self, aw: typing.Awaitable[SupportsAsyncCloseType]) -> None:
self.aw = aw
def __await__(self) -> typing.Generator[typing.Any, None, SupportsAsyncCloseType]:
return self.aw.__await__()
async def __aenter__(self) -> SupportsAsyncCloseType:
self.entered = await self.aw
return self.entered
async def __aexit__(self, *args: typing.Any) -> None | bool:
await self.entered.close()
return None
@contextmanager
def collapse_excgroups() -> typing.Generator[None, None, None]:
try:
yield
except BaseException as exc:
if has_exceptiongroups: # pragma: no cover
while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1:
exc = exc.exceptions[0]
raise exc
def get_route_path(scope: Scope) -> str:
path: str = scope["path"]
root_path = scope.get("root_path", "")
if not root_path:
return path
if not path.startswith(root_path):
return path
if path == root_path:
return ""
if path[len(root_path)] == "/":
return path[len(root_path) :]
return path

View File

@ -0,0 +1,249 @@
from __future__ import annotations
import sys
import typing
import warnings
if sys.version_info >= (3, 10): # pragma: no cover
from typing import ParamSpec
else: # pragma: no cover
from typing_extensions import ParamSpec
from starlette.datastructures import State, URLPath
from starlette.middleware import Middleware, _MiddlewareFactory
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.middleware.exceptions import ExceptionMiddleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import BaseRoute, Router
from starlette.types import ASGIApp, ExceptionHandler, Lifespan, Receive, Scope, Send
from starlette.websockets import WebSocket
AppType = typing.TypeVar("AppType", bound="Starlette")
P = ParamSpec("P")
class Starlette:
"""Creates an Starlette application."""
def __init__(
self: AppType,
debug: bool = False,
routes: typing.Sequence[BaseRoute] | None = None,
middleware: typing.Sequence[Middleware] | None = None,
exception_handlers: typing.Mapping[typing.Any, ExceptionHandler] | None = None,
on_startup: typing.Sequence[typing.Callable[[], typing.Any]] | None = None,
on_shutdown: typing.Sequence[typing.Callable[[], typing.Any]] | None = None,
lifespan: Lifespan[AppType] | None = None,
) -> None:
"""Initializes the application.
Parameters:
debug: Boolean indicating if debug tracebacks should be returned on errors.
routes: A list of routes to serve incoming HTTP and WebSocket requests.
middleware: A list of middleware to run for every request. A starlette
application will always automatically include two middleware classes.
`ServerErrorMiddleware` is added as the very outermost middleware, to handle
any uncaught errors occurring anywhere in the entire stack.
`ExceptionMiddleware` is added as the very innermost middleware, to deal
with handled exception cases occurring in the routing or endpoints.
exception_handlers: A mapping of either integer status codes,
or exception class types onto callables which handle the exceptions.
Exception handler callables should be of the form
`handler(request, exc) -> response` and may be either standard functions, or
async functions.
on_startup: A list of callables to run on application startup.
Startup handler callables do not take any arguments, and may be either
standard functions, or async functions.
on_shutdown: A list of callables to run on application shutdown.
Shutdown handler callables do not take any arguments, and may be either
standard functions, or async functions.
lifespan: A lifespan context function, which can be used to perform
startup and shutdown tasks. This is a newer style that replaces the
`on_startup` and `on_shutdown` handlers. Use one or the other, not both.
"""
# The lifespan context function is a newer style that replaces
# on_startup / on_shutdown handlers. Use one or the other, not both.
assert lifespan is None or (on_startup is None and on_shutdown is None), (
"Use either 'lifespan' or 'on_startup'/'on_shutdown', not both."
)
self.debug = debug
self.state = State()
self.router = Router(routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan)
self.exception_handlers = {} if exception_handlers is None else dict(exception_handlers)
self.user_middleware = [] if middleware is None else list(middleware)
self.middleware_stack: ASGIApp | None = None
def build_middleware_stack(self) -> ASGIApp:
debug = self.debug
error_handler = None
exception_handlers: dict[typing.Any, typing.Callable[[Request, Exception], Response]] = {}
for key, value in self.exception_handlers.items():
if key in (500, Exception):
error_handler = value
else:
exception_handlers[key] = value
middleware = (
[Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
+ self.user_middleware
+ [Middleware(ExceptionMiddleware, handlers=exception_handlers, debug=debug)]
)
app = self.router
for cls, args, kwargs in reversed(middleware):
app = cls(app, *args, **kwargs)
return app
@property
def routes(self) -> list[BaseRoute]:
return self.router.routes
def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
return self.router.url_path_for(name, **path_params)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
scope["app"] = self
if self.middleware_stack is None:
self.middleware_stack = self.build_middleware_stack()
await self.middleware_stack(scope, receive, send)
def on_event(self, event_type: str) -> typing.Callable: # type: ignore[type-arg]
return self.router.on_event(event_type) # pragma: no cover
def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None:
self.router.mount(path, app=app, name=name) # pragma: no cover
def host(self, host: str, app: ASGIApp, name: str | None = None) -> None:
self.router.host(host, app=app, name=name) # pragma: no cover
def add_middleware(
self,
middleware_class: _MiddlewareFactory[P],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
if self.middleware_stack is not None: # pragma: no cover
raise RuntimeError("Cannot add middleware after an application has started")
self.user_middleware.insert(0, Middleware(middleware_class, *args, **kwargs))
def add_exception_handler(
self,
exc_class_or_status_code: int | type[Exception],
handler: ExceptionHandler,
) -> None: # pragma: no cover
self.exception_handlers[exc_class_or_status_code] = handler
def add_event_handler(
self,
event_type: str,
func: typing.Callable, # type: ignore[type-arg]
) -> None: # pragma: no cover
self.router.add_event_handler(event_type, func)
def add_route(
self,
path: str,
route: typing.Callable[[Request], typing.Awaitable[Response] | Response],
methods: list[str] | None = None,
name: str | None = None,
include_in_schema: bool = True,
) -> None: # pragma: no cover
self.router.add_route(path, route, methods=methods, name=name, include_in_schema=include_in_schema)
def add_websocket_route(
self,
path: str,
route: typing.Callable[[WebSocket], typing.Awaitable[None]],
name: str | None = None,
) -> None: # pragma: no cover
self.router.add_websocket_route(path, route, name=name)
def exception_handler(self, exc_class_or_status_code: int | type[Exception]) -> typing.Callable: # type: ignore[type-arg]
warnings.warn(
"The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0. "
"Refer to https://www.starlette.io/exceptions/ for the recommended approach.",
DeprecationWarning,
)
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
self.add_exception_handler(exc_class_or_status_code, func)
return func
return decorator
def route(
self,
path: str,
methods: list[str] | None = None,
name: str | None = None,
include_in_schema: bool = True,
) -> typing.Callable: # type: ignore[type-arg]
"""
We no longer document this decorator style API, and its usage is discouraged.
Instead you should use the following approach:
>>> routes = [Route(path, endpoint=...), ...]
>>> app = Starlette(routes=routes)
"""
warnings.warn(
"The `route` decorator is deprecated, and will be removed in version 1.0.0. "
"Refer to https://www.starlette.io/routing/ for the recommended approach.",
DeprecationWarning,
)
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
self.router.add_route(
path,
func,
methods=methods,
name=name,
include_in_schema=include_in_schema,
)
return func
return decorator
def websocket_route(self, path: str, name: str | None = None) -> typing.Callable: # type: ignore[type-arg]
"""
We no longer document this decorator style API, and its usage is discouraged.
Instead you should use the following approach:
>>> routes = [WebSocketRoute(path, endpoint=...), ...]
>>> app = Starlette(routes=routes)
"""
warnings.warn(
"The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. "
"Refer to https://www.starlette.io/routing/#websocket-routing for the recommended approach.",
DeprecationWarning,
)
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
self.router.add_websocket_route(path, func, name=name)
return func
return decorator
def middleware(self, middleware_type: str) -> typing.Callable: # type: ignore[type-arg]
"""
We no longer document this decorator style API, and its usage is discouraged.
Instead you should use the following approach:
>>> middleware = [Middleware(...), ...]
>>> app = Starlette(middleware=middleware)
"""
warnings.warn(
"The `middleware` decorator is deprecated, and will be removed in version 1.0.0. "
"Refer to https://www.starlette.io/middleware/#using-middleware for recommended approach.",
DeprecationWarning,
)
assert middleware_type == "http", 'Currently only middleware("http") is supported.'
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
self.add_middleware(BaseHTTPMiddleware, dispatch=func)
return func
return decorator

View File

@ -0,0 +1,147 @@
from __future__ import annotations
import functools
import inspect
import sys
import typing
from urllib.parse import urlencode
if sys.version_info >= (3, 10): # pragma: no cover
from typing import ParamSpec
else: # pragma: no cover
from typing_extensions import ParamSpec
from starlette._utils import is_async_callable
from starlette.exceptions import HTTPException
from starlette.requests import HTTPConnection, Request
from starlette.responses import RedirectResponse
from starlette.websockets import WebSocket
_P = ParamSpec("_P")
def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
for scope in scopes:
if scope not in conn.auth.scopes:
return False
return True
def requires(
scopes: str | typing.Sequence[str],
status_code: int = 403,
redirect: str | None = None,
) -> typing.Callable[[typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any]]:
scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)
def decorator(
func: typing.Callable[_P, typing.Any],
) -> typing.Callable[_P, typing.Any]:
sig = inspect.signature(func)
for idx, parameter in enumerate(sig.parameters.values()):
if parameter.name == "request" or parameter.name == "websocket":
type_ = parameter.name
break
else:
raise Exception(f'No "request" or "websocket" argument on function "{func}"')
if type_ == "websocket":
# Handle websocket functions. (Always async)
@functools.wraps(func)
async def websocket_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
websocket = kwargs.get("websocket", args[idx] if idx < len(args) else None)
assert isinstance(websocket, WebSocket)
if not has_required_scope(websocket, scopes_list):
await websocket.close()
else:
await func(*args, **kwargs)
return websocket_wrapper
elif is_async_callable(func):
# Handle async request/response functions.
@functools.wraps(func)
async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any:
request = kwargs.get("request", args[idx] if idx < len(args) else None)
assert isinstance(request, Request)
if not has_required_scope(request, scopes_list):
if redirect is not None:
orig_request_qparam = urlencode({"next": str(request.url)})
next_url = f"{request.url_for(redirect)}?{orig_request_qparam}"
return RedirectResponse(url=next_url, status_code=303)
raise HTTPException(status_code=status_code)
return await func(*args, **kwargs)
return async_wrapper
else:
# Handle sync request/response functions.
@functools.wraps(func)
def sync_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any:
request = kwargs.get("request", args[idx] if idx < len(args) else None)
assert isinstance(request, Request)
if not has_required_scope(request, scopes_list):
if redirect is not None:
orig_request_qparam = urlencode({"next": str(request.url)})
next_url = f"{request.url_for(redirect)}?{orig_request_qparam}"
return RedirectResponse(url=next_url, status_code=303)
raise HTTPException(status_code=status_code)
return func(*args, **kwargs)
return sync_wrapper
return decorator
class AuthenticationError(Exception):
pass
class AuthenticationBackend:
async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
raise NotImplementedError() # pragma: no cover
class AuthCredentials:
def __init__(self, scopes: typing.Sequence[str] | None = None):
self.scopes = [] if scopes is None else list(scopes)
class BaseUser:
@property
def is_authenticated(self) -> bool:
raise NotImplementedError() # pragma: no cover
@property
def display_name(self) -> str:
raise NotImplementedError() # pragma: no cover
@property
def identity(self) -> str:
raise NotImplementedError() # pragma: no cover
class SimpleUser(BaseUser):
def __init__(self, username: str) -> None:
self.username = username
@property
def is_authenticated(self) -> bool:
return True
@property
def display_name(self) -> str:
return self.username
class UnauthenticatedUser(BaseUser):
@property
def is_authenticated(self) -> bool:
return False
@property
def display_name(self) -> str:
return ""

View File

@ -0,0 +1,41 @@
from __future__ import annotations
import sys
import typing
if sys.version_info >= (3, 10): # pragma: no cover
from typing import ParamSpec
else: # pragma: no cover
from typing_extensions import ParamSpec
from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
P = ParamSpec("P")
class BackgroundTask:
def __init__(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs) -> None:
self.func = func
self.args = args
self.kwargs = kwargs
self.is_async = is_async_callable(func)
async def __call__(self) -> None:
if self.is_async:
await self.func(*self.args, **self.kwargs)
else:
await run_in_threadpool(self.func, *self.args, **self.kwargs)
class BackgroundTasks(BackgroundTask):
def __init__(self, tasks: typing.Sequence[BackgroundTask] | None = None):
self.tasks = list(tasks) if tasks else []
def add_task(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs) -> None:
task = BackgroundTask(func, *args, **kwargs)
self.tasks.append(task)
async def __call__(self) -> None:
for task in self.tasks:
await task()

View File

@ -0,0 +1,62 @@
from __future__ import annotations
import functools
import sys
import typing
import warnings
import anyio.to_thread
if sys.version_info >= (3, 10): # pragma: no cover
from typing import ParamSpec
else: # pragma: no cover
from typing_extensions import ParamSpec
P = ParamSpec("P")
T = typing.TypeVar("T")
async def run_until_first_complete(*args: tuple[typing.Callable, dict]) -> None: # type: ignore[type-arg]
warnings.warn(
"run_until_first_complete is deprecated and will be removed in a future version.",
DeprecationWarning,
)
async with anyio.create_task_group() as task_group:
async def run(func: typing.Callable[[], typing.Coroutine]) -> None: # type: ignore[type-arg]
await func()
task_group.cancel_scope.cancel()
for func, kwargs in args:
task_group.start_soon(run, functools.partial(func, **kwargs))
async def run_in_threadpool(func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
func = functools.partial(func, *args, **kwargs)
return await anyio.to_thread.run_sync(func)
class _StopIteration(Exception):
pass
def _next(iterator: typing.Iterator[T]) -> T:
# We can't raise `StopIteration` from within the threadpool iterator
# and catch it outside that context, so we coerce them into a different
# exception type.
try:
return next(iterator)
except StopIteration:
raise _StopIteration
async def iterate_in_threadpool(
iterator: typing.Iterable[T],
) -> typing.AsyncIterator[T]:
as_iterator = iter(iterator)
while True:
try:
yield await anyio.to_thread.run_sync(_next, as_iterator)
except _StopIteration:
break

View File

@ -0,0 +1,138 @@
from __future__ import annotations
import os
import typing
import warnings
from pathlib import Path
class undefined:
pass
class EnvironError(Exception):
pass
class Environ(typing.MutableMapping[str, str]):
def __init__(self, environ: typing.MutableMapping[str, str] = os.environ):
self._environ = environ
self._has_been_read: set[str] = set()
def __getitem__(self, key: str) -> str:
self._has_been_read.add(key)
return self._environ.__getitem__(key)
def __setitem__(self, key: str, value: str) -> None:
if key in self._has_been_read:
raise EnvironError(f"Attempting to set environ['{key}'], but the value has already been read.")
self._environ.__setitem__(key, value)
def __delitem__(self, key: str) -> None:
if key in self._has_been_read:
raise EnvironError(f"Attempting to delete environ['{key}'], but the value has already been read.")
self._environ.__delitem__(key)
def __iter__(self) -> typing.Iterator[str]:
return iter(self._environ)
def __len__(self) -> int:
return len(self._environ)
environ = Environ()
T = typing.TypeVar("T")
class Config:
def __init__(
self,
env_file: str | Path | None = None,
environ: typing.Mapping[str, str] = environ,
env_prefix: str = "",
) -> None:
self.environ = environ
self.env_prefix = env_prefix
self.file_values: dict[str, str] = {}
if env_file is not None:
if not os.path.isfile(env_file):
warnings.warn(f"Config file '{env_file}' not found.")
else:
self.file_values = self._read_file(env_file)
@typing.overload
def __call__(self, key: str, *, default: None) -> str | None: ...
@typing.overload
def __call__(self, key: str, cast: type[T], default: T = ...) -> T: ...
@typing.overload
def __call__(self, key: str, cast: type[str] = ..., default: str = ...) -> str: ...
@typing.overload
def __call__(
self,
key: str,
cast: typing.Callable[[typing.Any], T] = ...,
default: typing.Any = ...,
) -> T: ...
@typing.overload
def __call__(self, key: str, cast: type[str] = ..., default: T = ...) -> T | str: ...
def __call__(
self,
key: str,
cast: typing.Callable[[typing.Any], typing.Any] | None = None,
default: typing.Any = undefined,
) -> typing.Any:
return self.get(key, cast, default)
def get(
self,
key: str,
cast: typing.Callable[[typing.Any], typing.Any] | None = None,
default: typing.Any = undefined,
) -> typing.Any:
key = self.env_prefix + key
if key in self.environ:
value = self.environ[key]
return self._perform_cast(key, value, cast)
if key in self.file_values:
value = self.file_values[key]
return self._perform_cast(key, value, cast)
if default is not undefined:
return self._perform_cast(key, default, cast)
raise KeyError(f"Config '{key}' is missing, and has no default.")
def _read_file(self, file_name: str | Path) -> dict[str, str]:
file_values: dict[str, str] = {}
with open(file_name) as input_file:
for line in input_file.readlines():
line = line.strip()
if "=" in line and not line.startswith("#"):
key, value = line.split("=", 1)
key = key.strip()
value = value.strip().strip("\"'")
file_values[key] = value
return file_values
def _perform_cast(
self,
key: str,
value: typing.Any,
cast: typing.Callable[[typing.Any], typing.Any] | None = None,
) -> typing.Any:
if cast is None or value is None:
return value
elif cast is bool and isinstance(value, str):
mapping = {"true": True, "1": True, "false": False, "0": False}
value = value.lower()
if value not in mapping:
raise ValueError(f"Config '{key}' has value '{value}'. Not a valid bool.")
return mapping[value]
try:
return cast(value)
except (TypeError, ValueError):
raise ValueError(f"Config '{key}' has value '{value}'. Not a valid {cast.__name__}.")

View File

@ -0,0 +1,89 @@
from __future__ import annotations
import math
import typing
import uuid
T = typing.TypeVar("T")
class Convertor(typing.Generic[T]):
regex: typing.ClassVar[str] = ""
def convert(self, value: str) -> T:
raise NotImplementedError() # pragma: no cover
def to_string(self, value: T) -> str:
raise NotImplementedError() # pragma: no cover
class StringConvertor(Convertor[str]):
regex = "[^/]+"
def convert(self, value: str) -> str:
return value
def to_string(self, value: str) -> str:
value = str(value)
assert "/" not in value, "May not contain path separators"
assert value, "Must not be empty"
return value
class PathConvertor(Convertor[str]):
regex = ".*"
def convert(self, value: str) -> str:
return str(value)
def to_string(self, value: str) -> str:
return str(value)
class IntegerConvertor(Convertor[int]):
regex = "[0-9]+"
def convert(self, value: str) -> int:
return int(value)
def to_string(self, value: int) -> str:
value = int(value)
assert value >= 0, "Negative integers are not supported"
return str(value)
class FloatConvertor(Convertor[float]):
regex = r"[0-9]+(\.[0-9]+)?"
def convert(self, value: str) -> float:
return float(value)
def to_string(self, value: float) -> str:
value = float(value)
assert value >= 0.0, "Negative floats are not supported"
assert not math.isnan(value), "NaN values are not supported"
assert not math.isinf(value), "Infinite values are not supported"
return ("%0.20f" % value).rstrip("0").rstrip(".")
class UUIDConvertor(Convertor[uuid.UUID]):
regex = "[0-9a-fA-F]{8}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{12}"
def convert(self, value: str) -> uuid.UUID:
return uuid.UUID(value)
def to_string(self, value: uuid.UUID) -> str:
return str(value)
CONVERTOR_TYPES: dict[str, Convertor[typing.Any]] = {
"str": StringConvertor(),
"path": PathConvertor(),
"int": IntegerConvertor(),
"float": FloatConvertor(),
"uuid": UUIDConvertor(),
}
def register_url_convertor(key: str, convertor: Convertor[typing.Any]) -> None:
CONVERTOR_TYPES[key] = convertor

View File

@ -0,0 +1,674 @@
from __future__ import annotations
import typing
from shlex import shlex
from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit
from starlette.concurrency import run_in_threadpool
from starlette.types import Scope
class Address(typing.NamedTuple):
host: str
port: int
_KeyType = typing.TypeVar("_KeyType")
# Mapping keys are invariant but their values are covariant since
# you can only read them
# that is, you can't do `Mapping[str, Animal]()["fido"] = Dog()`
_CovariantValueType = typing.TypeVar("_CovariantValueType", covariant=True)
class URL:
def __init__(
self,
url: str = "",
scope: Scope | None = None,
**components: typing.Any,
) -> None:
if scope is not None:
assert not url, 'Cannot set both "url" and "scope".'
assert not components, 'Cannot set both "scope" and "**components".'
scheme = scope.get("scheme", "http")
server = scope.get("server", None)
path = scope["path"]
query_string = scope.get("query_string", b"")
host_header = None
for key, value in scope["headers"]:
if key == b"host":
host_header = value.decode("latin-1")
break
if host_header is not None:
url = f"{scheme}://{host_header}{path}"
elif server is None:
url = path
else:
host, port = server
default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}[scheme]
if port == default_port:
url = f"{scheme}://{host}{path}"
else:
url = f"{scheme}://{host}:{port}{path}"
if query_string:
url += "?" + query_string.decode()
elif components:
assert not url, 'Cannot set both "url" and "**components".'
url = URL("").replace(**components).components.geturl()
self._url = url
@property
def components(self) -> SplitResult:
if not hasattr(self, "_components"):
self._components = urlsplit(self._url)
return self._components
@property
def scheme(self) -> str:
return self.components.scheme
@property
def netloc(self) -> str:
return self.components.netloc
@property
def path(self) -> str:
return self.components.path
@property
def query(self) -> str:
return self.components.query
@property
def fragment(self) -> str:
return self.components.fragment
@property
def username(self) -> None | str:
return self.components.username
@property
def password(self) -> None | str:
return self.components.password
@property
def hostname(self) -> None | str:
return self.components.hostname
@property
def port(self) -> int | None:
return self.components.port
@property
def is_secure(self) -> bool:
return self.scheme in ("https", "wss")
def replace(self, **kwargs: typing.Any) -> URL:
if "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs:
hostname = kwargs.pop("hostname", None)
port = kwargs.pop("port", self.port)
username = kwargs.pop("username", self.username)
password = kwargs.pop("password", self.password)
if hostname is None:
netloc = self.netloc
_, _, hostname = netloc.rpartition("@")
if hostname[-1] != "]":
hostname = hostname.rsplit(":", 1)[0]
netloc = hostname
if port is not None:
netloc += f":{port}"
if username is not None:
userpass = username
if password is not None:
userpass += f":{password}"
netloc = f"{userpass}@{netloc}"
kwargs["netloc"] = netloc
components = self.components._replace(**kwargs)
return self.__class__(components.geturl())
def include_query_params(self, **kwargs: typing.Any) -> URL:
params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
params.update({str(key): str(value) for key, value in kwargs.items()})
query = urlencode(params.multi_items())
return self.replace(query=query)
def replace_query_params(self, **kwargs: typing.Any) -> URL:
query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
return self.replace(query=query)
def remove_query_params(self, keys: str | typing.Sequence[str]) -> URL:
if isinstance(keys, str):
keys = [keys]
params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
for key in keys:
params.pop(key, None)
query = urlencode(params.multi_items())
return self.replace(query=query)
def __eq__(self, other: typing.Any) -> bool:
return str(self) == str(other)
def __str__(self) -> str:
return self._url
def __repr__(self) -> str:
url = str(self)
if self.password:
url = str(self.replace(password="********"))
return f"{self.__class__.__name__}({repr(url)})"
class URLPath(str):
"""
A URL path string that may also hold an associated protocol and/or host.
Used by the routing to return `url_path_for` matches.
"""
def __new__(cls, path: str, protocol: str = "", host: str = "") -> URLPath:
assert protocol in ("http", "websocket", "")
return str.__new__(cls, path)
def __init__(self, path: str, protocol: str = "", host: str = "") -> None:
self.protocol = protocol
self.host = host
def make_absolute_url(self, base_url: str | URL) -> URL:
if isinstance(base_url, str):
base_url = URL(base_url)
if self.protocol:
scheme = {
"http": {True: "https", False: "http"},
"websocket": {True: "wss", False: "ws"},
}[self.protocol][base_url.is_secure]
else:
scheme = base_url.scheme
netloc = self.host or base_url.netloc
path = base_url.path.rstrip("/") + str(self)
return URL(scheme=scheme, netloc=netloc, path=path)
class Secret:
"""
Holds a string value that should not be revealed in tracebacks etc.
You should cast the value to `str` at the point it is required.
"""
def __init__(self, value: str):
self._value = value
def __repr__(self) -> str:
class_name = self.__class__.__name__
return f"{class_name}('**********')"
def __str__(self) -> str:
return self._value
def __bool__(self) -> bool:
return bool(self._value)
class CommaSeparatedStrings(typing.Sequence[str]):
def __init__(self, value: str | typing.Sequence[str]):
if isinstance(value, str):
splitter = shlex(value, posix=True)
splitter.whitespace = ","
splitter.whitespace_split = True
self._items = [item.strip() for item in splitter]
else:
self._items = list(value)
def __len__(self) -> int:
return len(self._items)
def __getitem__(self, index: int | slice) -> typing.Any:
return self._items[index]
def __iter__(self) -> typing.Iterator[str]:
return iter(self._items)
def __repr__(self) -> str:
class_name = self.__class__.__name__
items = [item for item in self]
return f"{class_name}({items!r})"
def __str__(self) -> str:
return ", ".join(repr(item) for item in self)
class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]):
_dict: dict[_KeyType, _CovariantValueType]
def __init__(
self,
*args: ImmutableMultiDict[_KeyType, _CovariantValueType]
| typing.Mapping[_KeyType, _CovariantValueType]
| typing.Iterable[tuple[_KeyType, _CovariantValueType]],
**kwargs: typing.Any,
) -> None:
assert len(args) < 2, "Too many arguments."
value: typing.Any = args[0] if args else []
if kwargs:
value = ImmutableMultiDict(value).multi_items() + ImmutableMultiDict(kwargs).multi_items()
if not value:
_items: list[tuple[typing.Any, typing.Any]] = []
elif hasattr(value, "multi_items"):
value = typing.cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value)
_items = list(value.multi_items())
elif hasattr(value, "items"):
value = typing.cast(typing.Mapping[_KeyType, _CovariantValueType], value)
_items = list(value.items())
else:
value = typing.cast("list[tuple[typing.Any, typing.Any]]", value)
_items = list(value)
self._dict = {k: v for k, v in _items}
self._list = _items
def getlist(self, key: typing.Any) -> list[_CovariantValueType]:
return [item_value for item_key, item_value in self._list if item_key == key]
def keys(self) -> typing.KeysView[_KeyType]:
return self._dict.keys()
def values(self) -> typing.ValuesView[_CovariantValueType]:
return self._dict.values()
def items(self) -> typing.ItemsView[_KeyType, _CovariantValueType]:
return self._dict.items()
def multi_items(self) -> list[tuple[_KeyType, _CovariantValueType]]:
return list(self._list)
def __getitem__(self, key: _KeyType) -> _CovariantValueType:
return self._dict[key]
def __contains__(self, key: typing.Any) -> bool:
return key in self._dict
def __iter__(self) -> typing.Iterator[_KeyType]:
return iter(self.keys())
def __len__(self) -> int:
return len(self._dict)
def __eq__(self, other: typing.Any) -> bool:
if not isinstance(other, self.__class__):
return False
return sorted(self._list) == sorted(other._list)
def __repr__(self) -> str:
class_name = self.__class__.__name__
items = self.multi_items()
return f"{class_name}({items!r})"
class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]):
def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
self.setlist(key, [value])
def __delitem__(self, key: typing.Any) -> None:
self._list = [(k, v) for k, v in self._list if k != key]
del self._dict[key]
def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
self._list = [(k, v) for k, v in self._list if k != key]
return self._dict.pop(key, default)
def popitem(self) -> tuple[typing.Any, typing.Any]:
key, value = self._dict.popitem()
self._list = [(k, v) for k, v in self._list if k != key]
return key, value
def poplist(self, key: typing.Any) -> list[typing.Any]:
values = [v for k, v in self._list if k == key]
self.pop(key)
return values
def clear(self) -> None:
self._dict.clear()
self._list.clear()
def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
if key not in self:
self._dict[key] = default
self._list.append((key, default))
return self[key]
def setlist(self, key: typing.Any, values: list[typing.Any]) -> None:
if not values:
self.pop(key, None)
else:
existing_items = [(k, v) for (k, v) in self._list if k != key]
self._list = existing_items + [(key, value) for value in values]
self._dict[key] = values[-1]
def append(self, key: typing.Any, value: typing.Any) -> None:
self._list.append((key, value))
self._dict[key] = value
def update(
self,
*args: MultiDict | typing.Mapping[typing.Any, typing.Any] | list[tuple[typing.Any, typing.Any]],
**kwargs: typing.Any,
) -> None:
value = MultiDict(*args, **kwargs)
existing_items = [(k, v) for (k, v) in self._list if k not in value.keys()]
self._list = existing_items + value.multi_items()
self._dict.update(value)
class QueryParams(ImmutableMultiDict[str, str]):
"""
An immutable multidict.
"""
def __init__(
self,
*args: ImmutableMultiDict[typing.Any, typing.Any]
| typing.Mapping[typing.Any, typing.Any]
| list[tuple[typing.Any, typing.Any]]
| str
| bytes,
**kwargs: typing.Any,
) -> None:
assert len(args) < 2, "Too many arguments."
value = args[0] if args else []
if isinstance(value, str):
super().__init__(parse_qsl(value, keep_blank_values=True), **kwargs)
elif isinstance(value, bytes):
super().__init__(parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs)
else:
super().__init__(*args, **kwargs) # type: ignore[arg-type]
self._list = [(str(k), str(v)) for k, v in self._list]
self._dict = {str(k): str(v) for k, v in self._dict.items()}
def __str__(self) -> str:
return urlencode(self._list)
def __repr__(self) -> str:
class_name = self.__class__.__name__
query_string = str(self)
return f"{class_name}({query_string!r})"
class UploadFile:
"""
An uploaded file included as part of the request data.
"""
def __init__(
self,
file: typing.BinaryIO,
*,
size: int | None = None,
filename: str | None = None,
headers: Headers | None = None,
) -> None:
self.filename = filename
self.file = file
self.size = size
self.headers = headers or Headers()
@property
def content_type(self) -> str | None:
return self.headers.get("content-type", None)
@property
def _in_memory(self) -> bool:
# check for SpooledTemporaryFile._rolled
rolled_to_disk = getattr(self.file, "_rolled", True)
return not rolled_to_disk
async def write(self, data: bytes) -> None:
if self.size is not None:
self.size += len(data)
if self._in_memory:
self.file.write(data)
else:
await run_in_threadpool(self.file.write, data)
async def read(self, size: int = -1) -> bytes:
if self._in_memory:
return self.file.read(size)
return await run_in_threadpool(self.file.read, size)
async def seek(self, offset: int) -> None:
if self._in_memory:
self.file.seek(offset)
else:
await run_in_threadpool(self.file.seek, offset)
async def close(self) -> None:
if self._in_memory:
self.file.close()
else:
await run_in_threadpool(self.file.close)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(filename={self.filename!r}, size={self.size!r}, headers={self.headers!r})"
class FormData(ImmutableMultiDict[str, typing.Union[UploadFile, str]]):
"""
An immutable multidict, containing both file uploads and text input.
"""
def __init__(
self,
*args: FormData | typing.Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]],
**kwargs: str | UploadFile,
) -> None:
super().__init__(*args, **kwargs)
async def close(self) -> None:
for key, value in self.multi_items():
if isinstance(value, UploadFile):
await value.close()
class Headers(typing.Mapping[str, str]):
"""
An immutable, case-insensitive multidict.
"""
def __init__(
self,
headers: typing.Mapping[str, str] | None = None,
raw: list[tuple[bytes, bytes]] | None = None,
scope: typing.MutableMapping[str, typing.Any] | None = None,
) -> None:
self._list: list[tuple[bytes, bytes]] = []
if headers is not None:
assert raw is None, 'Cannot set both "headers" and "raw".'
assert scope is None, 'Cannot set both "headers" and "scope".'
self._list = [(key.lower().encode("latin-1"), value.encode("latin-1")) for key, value in headers.items()]
elif raw is not None:
assert scope is None, 'Cannot set both "raw" and "scope".'
self._list = raw
elif scope is not None:
# scope["headers"] isn't necessarily a list
# it might be a tuple or other iterable
self._list = scope["headers"] = list(scope["headers"])
@property
def raw(self) -> list[tuple[bytes, bytes]]:
return list(self._list)
def keys(self) -> list[str]: # type: ignore[override]
return [key.decode("latin-1") for key, value in self._list]
def values(self) -> list[str]: # type: ignore[override]
return [value.decode("latin-1") for key, value in self._list]
def items(self) -> list[tuple[str, str]]: # type: ignore[override]
return [(key.decode("latin-1"), value.decode("latin-1")) for key, value in self._list]
def getlist(self, key: str) -> list[str]:
get_header_key = key.lower().encode("latin-1")
return [item_value.decode("latin-1") for item_key, item_value in self._list if item_key == get_header_key]
def mutablecopy(self) -> MutableHeaders:
return MutableHeaders(raw=self._list[:])
def __getitem__(self, key: str) -> str:
get_header_key = key.lower().encode("latin-1")
for header_key, header_value in self._list:
if header_key == get_header_key:
return header_value.decode("latin-1")
raise KeyError(key)
def __contains__(self, key: typing.Any) -> bool:
get_header_key = key.lower().encode("latin-1")
for header_key, header_value in self._list:
if header_key == get_header_key:
return True
return False
def __iter__(self) -> typing.Iterator[typing.Any]:
return iter(self.keys())
def __len__(self) -> int:
return len(self._list)
def __eq__(self, other: typing.Any) -> bool:
if not isinstance(other, Headers):
return False
return sorted(self._list) == sorted(other._list)
def __repr__(self) -> str:
class_name = self.__class__.__name__
as_dict = dict(self.items())
if len(as_dict) == len(self):
return f"{class_name}({as_dict!r})"
return f"{class_name}(raw={self.raw!r})"
class MutableHeaders(Headers):
def __setitem__(self, key: str, value: str) -> None:
"""
Set the header `key` to `value`, removing any duplicate entries.
Retains insertion order.
"""
set_key = key.lower().encode("latin-1")
set_value = value.encode("latin-1")
found_indexes: list[int] = []
for idx, (item_key, item_value) in enumerate(self._list):
if item_key == set_key:
found_indexes.append(idx)
for idx in reversed(found_indexes[1:]):
del self._list[idx]
if found_indexes:
idx = found_indexes[0]
self._list[idx] = (set_key, set_value)
else:
self._list.append((set_key, set_value))
def __delitem__(self, key: str) -> None:
"""
Remove the header `key`.
"""
del_key = key.lower().encode("latin-1")
pop_indexes: list[int] = []
for idx, (item_key, item_value) in enumerate(self._list):
if item_key == del_key:
pop_indexes.append(idx)
for idx in reversed(pop_indexes):
del self._list[idx]
def __ior__(self, other: typing.Mapping[str, str]) -> MutableHeaders:
if not isinstance(other, typing.Mapping):
raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
self.update(other)
return self
def __or__(self, other: typing.Mapping[str, str]) -> MutableHeaders:
if not isinstance(other, typing.Mapping):
raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
new = self.mutablecopy()
new.update(other)
return new
@property
def raw(self) -> list[tuple[bytes, bytes]]:
return self._list
def setdefault(self, key: str, value: str) -> str:
"""
If the header `key` does not exist, then set it to `value`.
Returns the header value.
"""
set_key = key.lower().encode("latin-1")
set_value = value.encode("latin-1")
for idx, (item_key, item_value) in enumerate(self._list):
if item_key == set_key:
return item_value.decode("latin-1")
self._list.append((set_key, set_value))
return value
def update(self, other: typing.Mapping[str, str]) -> None:
for key, val in other.items():
self[key] = val
def append(self, key: str, value: str) -> None:
"""
Append a header, preserving any duplicate entries.
"""
append_key = key.lower().encode("latin-1")
append_value = value.encode("latin-1")
self._list.append((append_key, append_value))
def add_vary_header(self, vary: str) -> None:
existing = self.get("vary")
if existing is not None:
vary = ", ".join([existing, vary])
self["vary"] = vary
class State:
"""
An object that can be used to store arbitrary state.
Used for `request.state` and `app.state`.
"""
_state: dict[str, typing.Any]
def __init__(self, state: dict[str, typing.Any] | None = None):
if state is None:
state = {}
super().__setattr__("_state", state)
def __setattr__(self, key: typing.Any, value: typing.Any) -> None:
self._state[key] = value
def __getattr__(self, key: typing.Any) -> typing.Any:
try:
return self._state[key]
except KeyError:
message = "'{}' object has no attribute '{}'"
raise AttributeError(message.format(self.__class__.__name__, key))
def __delattr__(self, key: typing.Any) -> None:
del self._state[key]

View File

@ -0,0 +1,122 @@
from __future__ import annotations
import json
import typing
from starlette import status
from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.types import Message, Receive, Scope, Send
from starlette.websockets import WebSocket
class HTTPEndpoint:
def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
assert scope["type"] == "http"
self.scope = scope
self.receive = receive
self.send = send
self._allowed_methods = [
method
for method in ("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS")
if getattr(self, method.lower(), None) is not None
]
def __await__(self) -> typing.Generator[typing.Any, None, None]:
return self.dispatch().__await__()
async def dispatch(self) -> None:
request = Request(self.scope, receive=self.receive)
handler_name = "get" if request.method == "HEAD" and not hasattr(self, "head") else request.method.lower()
handler: typing.Callable[[Request], typing.Any] = getattr(self, handler_name, self.method_not_allowed)
is_async = is_async_callable(handler)
if is_async:
response = await handler(request)
else:
response = await run_in_threadpool(handler, request)
await response(self.scope, self.receive, self.send)
async def method_not_allowed(self, request: Request) -> Response:
# If we're running inside a starlette application then raise an
# exception, so that the configurable exception handler can deal with
# returning the response. For plain ASGI apps, just return the response.
headers = {"Allow": ", ".join(self._allowed_methods)}
if "app" in self.scope:
raise HTTPException(status_code=405, headers=headers)
return PlainTextResponse("Method Not Allowed", status_code=405, headers=headers)
class WebSocketEndpoint:
encoding: str | None = None # May be "text", "bytes", or "json".
def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
assert scope["type"] == "websocket"
self.scope = scope
self.receive = receive
self.send = send
def __await__(self) -> typing.Generator[typing.Any, None, None]:
return self.dispatch().__await__()
async def dispatch(self) -> None:
websocket = WebSocket(self.scope, receive=self.receive, send=self.send)
await self.on_connect(websocket)
close_code = status.WS_1000_NORMAL_CLOSURE
try:
while True:
message = await websocket.receive()
if message["type"] == "websocket.receive":
data = await self.decode(websocket, message)
await self.on_receive(websocket, data)
elif message["type"] == "websocket.disconnect": # pragma: no branch
close_code = int(message.get("code") or status.WS_1000_NORMAL_CLOSURE)
break
except Exception as exc:
close_code = status.WS_1011_INTERNAL_ERROR
raise exc
finally:
await self.on_disconnect(websocket, close_code)
async def decode(self, websocket: WebSocket, message: Message) -> typing.Any:
if self.encoding == "text":
if "text" not in message:
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
raise RuntimeError("Expected text websocket messages, but got bytes")
return message["text"]
elif self.encoding == "bytes":
if "bytes" not in message:
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
raise RuntimeError("Expected bytes websocket messages, but got text")
return message["bytes"]
elif self.encoding == "json":
if message.get("text") is not None:
text = message["text"]
else:
text = message["bytes"].decode("utf-8")
try:
return json.loads(text)
except json.decoder.JSONDecodeError:
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
raise RuntimeError("Malformed JSON data received.")
assert self.encoding is None, f"Unsupported 'encoding' attribute {self.encoding}"
return message["text"] if message.get("text") else message["bytes"]
async def on_connect(self, websocket: WebSocket) -> None:
"""Override to handle an incoming websocket connection"""
await websocket.accept()
async def on_receive(self, websocket: WebSocket, data: typing.Any) -> None:
"""Override to handle an incoming websocket message"""
async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None:
"""Override to handle a disconnecting websocket"""

View File

@ -0,0 +1,33 @@
from __future__ import annotations
import http
from collections.abc import Mapping
class HTTPException(Exception):
def __init__(self, status_code: int, detail: str | None = None, headers: Mapping[str, str] | None = None) -> None:
if detail is None:
detail = http.HTTPStatus(status_code).phrase
self.status_code = status_code
self.detail = detail
self.headers = headers
def __str__(self) -> str:
return f"{self.status_code}: {self.detail}"
def __repr__(self) -> str:
class_name = self.__class__.__name__
return f"{class_name}(status_code={self.status_code!r}, detail={self.detail!r})"
class WebSocketException(Exception):
def __init__(self, code: int, reason: str | None = None) -> None:
self.code = code
self.reason = reason or ""
def __str__(self) -> str:
return f"{self.code}: {self.reason}"
def __repr__(self) -> str:
class_name = self.__class__.__name__
return f"{class_name}(code={self.code!r}, reason={self.reason!r})"

View File

@ -0,0 +1,275 @@
from __future__ import annotations
import typing
from dataclasses import dataclass, field
from enum import Enum
from tempfile import SpooledTemporaryFile
from urllib.parse import unquote_plus
from starlette.datastructures import FormData, Headers, UploadFile
if typing.TYPE_CHECKING:
import python_multipart as multipart
from python_multipart.multipart import MultipartCallbacks, QuerystringCallbacks, parse_options_header
else:
try:
try:
import python_multipart as multipart
from python_multipart.multipart import parse_options_header
except ModuleNotFoundError: # pragma: no cover
import multipart
from multipart.multipart import parse_options_header
except ModuleNotFoundError: # pragma: no cover
multipart = None
parse_options_header = None
class FormMessage(Enum):
FIELD_START = 1
FIELD_NAME = 2
FIELD_DATA = 3
FIELD_END = 4
END = 5
@dataclass
class MultipartPart:
content_disposition: bytes | None = None
field_name: str = ""
data: bytearray = field(default_factory=bytearray)
file: UploadFile | None = None
item_headers: list[tuple[bytes, bytes]] = field(default_factory=list)
def _user_safe_decode(src: bytes | bytearray, codec: str) -> str:
try:
return src.decode(codec)
except (UnicodeDecodeError, LookupError):
return src.decode("latin-1")
class MultiPartException(Exception):
def __init__(self, message: str) -> None:
self.message = message
class FormParser:
def __init__(self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]) -> None:
assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
self.headers = headers
self.stream = stream
self.messages: list[tuple[FormMessage, bytes]] = []
def on_field_start(self) -> None:
message = (FormMessage.FIELD_START, b"")
self.messages.append(message)
def on_field_name(self, data: bytes, start: int, end: int) -> None:
message = (FormMessage.FIELD_NAME, data[start:end])
self.messages.append(message)
def on_field_data(self, data: bytes, start: int, end: int) -> None:
message = (FormMessage.FIELD_DATA, data[start:end])
self.messages.append(message)
def on_field_end(self) -> None:
message = (FormMessage.FIELD_END, b"")
self.messages.append(message)
def on_end(self) -> None:
message = (FormMessage.END, b"")
self.messages.append(message)
async def parse(self) -> FormData:
# Callbacks dictionary.
callbacks: QuerystringCallbacks = {
"on_field_start": self.on_field_start,
"on_field_name": self.on_field_name,
"on_field_data": self.on_field_data,
"on_field_end": self.on_field_end,
"on_end": self.on_end,
}
# Create the parser.
parser = multipart.QuerystringParser(callbacks)
field_name = b""
field_value = b""
items: list[tuple[str, str | UploadFile]] = []
# Feed the parser with data from the request.
async for chunk in self.stream:
if chunk:
parser.write(chunk)
else:
parser.finalize()
messages = list(self.messages)
self.messages.clear()
for message_type, message_bytes in messages:
if message_type == FormMessage.FIELD_START:
field_name = b""
field_value = b""
elif message_type == FormMessage.FIELD_NAME:
field_name += message_bytes
elif message_type == FormMessage.FIELD_DATA:
field_value += message_bytes
elif message_type == FormMessage.FIELD_END:
name = unquote_plus(field_name.decode("latin-1"))
value = unquote_plus(field_value.decode("latin-1"))
items.append((name, value))
return FormData(items)
class MultiPartParser:
spool_max_size = 1024 * 1024 # 1MB
"""The maximum size of the spooled temporary file used to store file data."""
max_part_size = 1024 * 1024 # 1MB
"""The maximum size of a part in the multipart request."""
def __init__(
self,
headers: Headers,
stream: typing.AsyncGenerator[bytes, None],
*,
max_files: int | float = 1000,
max_fields: int | float = 1000,
max_part_size: int = 1024 * 1024, # 1MB
) -> None:
assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
self.headers = headers
self.stream = stream
self.max_files = max_files
self.max_fields = max_fields
self.items: list[tuple[str, str | UploadFile]] = []
self._current_files = 0
self._current_fields = 0
self._current_partial_header_name: bytes = b""
self._current_partial_header_value: bytes = b""
self._current_part = MultipartPart()
self._charset = ""
self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = []
self._file_parts_to_finish: list[MultipartPart] = []
self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = []
self.max_part_size = max_part_size
def on_part_begin(self) -> None:
self._current_part = MultipartPart()
def on_part_data(self, data: bytes, start: int, end: int) -> None:
message_bytes = data[start:end]
if self._current_part.file is None:
if len(self._current_part.data) + len(message_bytes) > self.max_part_size:
raise MultiPartException(f"Part exceeded maximum size of {int(self.max_part_size / 1024)}KB.")
self._current_part.data.extend(message_bytes)
else:
self._file_parts_to_write.append((self._current_part, message_bytes))
def on_part_end(self) -> None:
if self._current_part.file is None:
self.items.append(
(
self._current_part.field_name,
_user_safe_decode(self._current_part.data, self._charset),
)
)
else:
self._file_parts_to_finish.append(self._current_part)
# The file can be added to the items right now even though it's not
# finished yet, because it will be finished in the `parse()` method, before
# self.items is used in the return value.
self.items.append((self._current_part.field_name, self._current_part.file))
def on_header_field(self, data: bytes, start: int, end: int) -> None:
self._current_partial_header_name += data[start:end]
def on_header_value(self, data: bytes, start: int, end: int) -> None:
self._current_partial_header_value += data[start:end]
def on_header_end(self) -> None:
field = self._current_partial_header_name.lower()
if field == b"content-disposition":
self._current_part.content_disposition = self._current_partial_header_value
self._current_part.item_headers.append((field, self._current_partial_header_value))
self._current_partial_header_name = b""
self._current_partial_header_value = b""
def on_headers_finished(self) -> None:
disposition, options = parse_options_header(self._current_part.content_disposition)
try:
self._current_part.field_name = _user_safe_decode(options[b"name"], self._charset)
except KeyError:
raise MultiPartException('The Content-Disposition header field "name" must be provided.')
if b"filename" in options:
self._current_files += 1
if self._current_files > self.max_files:
raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.")
filename = _user_safe_decode(options[b"filename"], self._charset)
tempfile = SpooledTemporaryFile(max_size=self.spool_max_size)
self._files_to_close_on_error.append(tempfile)
self._current_part.file = UploadFile(
file=tempfile, # type: ignore[arg-type]
size=0,
filename=filename,
headers=Headers(raw=self._current_part.item_headers),
)
else:
self._current_fields += 1
if self._current_fields > self.max_fields:
raise MultiPartException(f"Too many fields. Maximum number of fields is {self.max_fields}.")
self._current_part.file = None
def on_end(self) -> None:
pass
async def parse(self) -> FormData:
# Parse the Content-Type header to get the multipart boundary.
_, params = parse_options_header(self.headers["Content-Type"])
charset = params.get(b"charset", "utf-8")
if isinstance(charset, bytes):
charset = charset.decode("latin-1")
self._charset = charset
try:
boundary = params[b"boundary"]
except KeyError:
raise MultiPartException("Missing boundary in multipart.")
# Callbacks dictionary.
callbacks: MultipartCallbacks = {
"on_part_begin": self.on_part_begin,
"on_part_data": self.on_part_data,
"on_part_end": self.on_part_end,
"on_header_field": self.on_header_field,
"on_header_value": self.on_header_value,
"on_header_end": self.on_header_end,
"on_headers_finished": self.on_headers_finished,
"on_end": self.on_end,
}
# Create the parser.
parser = multipart.MultipartParser(boundary, callbacks)
try:
# Feed the parser with data from the request.
async for chunk in self.stream:
parser.write(chunk)
# Write file data, it needs to use await with the UploadFile methods
# that call the corresponding file methods *in a threadpool*,
# otherwise, if they were called directly in the callback methods above
# (regular, non-async functions), that would block the event loop in
# the main thread.
for part, data in self._file_parts_to_write:
assert part.file # for type checkers
await part.file.write(data)
for part in self._file_parts_to_finish:
assert part.file # for type checkers
await part.file.seek(0)
self._file_parts_to_write.clear()
self._file_parts_to_finish.clear()
except MultiPartException as exc:
# Close all the files if there was an error.
for file in self._files_to_close_on_error:
file.close()
raise exc
parser.finalize()
return FormData(self.items)

View File

@ -0,0 +1,42 @@
from __future__ import annotations
import sys
from collections.abc import Iterator
from typing import Any, Protocol
if sys.version_info >= (3, 10): # pragma: no cover
from typing import ParamSpec
else: # pragma: no cover
from typing_extensions import ParamSpec
from starlette.types import ASGIApp
P = ParamSpec("P")
class _MiddlewareFactory(Protocol[P]):
def __call__(self, app: ASGIApp, /, *args: P.args, **kwargs: P.kwargs) -> ASGIApp: ... # pragma: no cover
class Middleware:
def __init__(
self,
cls: _MiddlewareFactory[P],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
self.cls = cls
self.args = args
self.kwargs = kwargs
def __iter__(self) -> Iterator[Any]:
as_tuple = (self.cls, self.args, self.kwargs)
return iter(as_tuple)
def __repr__(self) -> str:
class_name = self.__class__.__name__
args_strings = [f"{value!r}" for value in self.args]
option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()]
name = getattr(self.cls, "__name__", "")
args_repr = ", ".join([name] + args_strings + option_strings)
return f"{class_name}({args_repr})"

View File

@ -0,0 +1,52 @@
from __future__ import annotations
import typing
from starlette.authentication import (
AuthCredentials,
AuthenticationBackend,
AuthenticationError,
UnauthenticatedUser,
)
from starlette.requests import HTTPConnection
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Receive, Scope, Send
class AuthenticationMiddleware:
def __init__(
self,
app: ASGIApp,
backend: AuthenticationBackend,
on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] | None = None,
) -> None:
self.app = app
self.backend = backend
self.on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] = (
on_error if on_error is not None else self.default_on_error
)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] not in ["http", "websocket"]:
await self.app(scope, receive, send)
return
conn = HTTPConnection(scope)
try:
auth_result = await self.backend.authenticate(conn)
except AuthenticationError as exc:
response = self.on_error(conn, exc)
if scope["type"] == "websocket":
await send({"type": "websocket.close", "code": 1000})
else:
await response(scope, receive, send)
return
if auth_result is None:
auth_result = AuthCredentials(), UnauthenticatedUser()
scope["auth"], scope["user"] = auth_result
await self.app(scope, receive, send)
@staticmethod
def default_on_error(conn: HTTPConnection, exc: Exception) -> Response:
return PlainTextResponse(str(exc), status_code=400)

View File

@ -0,0 +1,220 @@
from __future__ import annotations
import typing
import anyio
from starlette._utils import collapse_excgroups
from starlette.requests import ClientDisconnect, Request
from starlette.responses import AsyncContentStream, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
DispatchFunction = typing.Callable[[Request, RequestResponseEndpoint], typing.Awaitable[Response]]
T = typing.TypeVar("T")
class _CachedRequest(Request):
"""
If the user calls Request.body() from their dispatch function
we cache the entire request body in memory and pass that to downstream middlewares,
but if they call Request.stream() then all we do is send an
empty body so that downstream things don't hang forever.
"""
def __init__(self, scope: Scope, receive: Receive):
super().__init__(scope, receive)
self._wrapped_rcv_disconnected = False
self._wrapped_rcv_consumed = False
self._wrapped_rc_stream = self.stream()
async def wrapped_receive(self) -> Message:
# wrapped_rcv state 1: disconnected
if self._wrapped_rcv_disconnected:
# we've already sent a disconnect to the downstream app
# we don't need to wait to get another one
# (although most ASGI servers will just keep sending it)
return {"type": "http.disconnect"}
# wrapped_rcv state 1: consumed but not yet disconnected
if self._wrapped_rcv_consumed:
# since the downstream app has consumed us all that is left
# is to send it a disconnect
if self._is_disconnected:
# the middleware has already seen the disconnect
# since we know the client is disconnected no need to wait
# for the message
self._wrapped_rcv_disconnected = True
return {"type": "http.disconnect"}
# we don't know yet if the client is disconnected or not
# so we'll wait until we get that message
msg = await self.receive()
if msg["type"] != "http.disconnect": # pragma: no cover
# at this point a disconnect is all that we should be receiving
# if we get something else, things went wrong somewhere
raise RuntimeError(f"Unexpected message received: {msg['type']}")
self._wrapped_rcv_disconnected = True
return msg
# wrapped_rcv state 3: not yet consumed
if getattr(self, "_body", None) is not None:
# body() was called, we return it even if the client disconnected
self._wrapped_rcv_consumed = True
return {
"type": "http.request",
"body": self._body,
"more_body": False,
}
elif self._stream_consumed:
# stream() was called to completion
# return an empty body so that downstream apps don't hang
# waiting for a disconnect
self._wrapped_rcv_consumed = True
return {
"type": "http.request",
"body": b"",
"more_body": False,
}
else:
# body() was never called and stream() wasn't consumed
try:
stream = self.stream()
chunk = await stream.__anext__()
self._wrapped_rcv_consumed = self._stream_consumed
return {
"type": "http.request",
"body": chunk,
"more_body": not self._stream_consumed,
}
except ClientDisconnect:
self._wrapped_rcv_disconnected = True
return {"type": "http.disconnect"}
class BaseHTTPMiddleware:
def __init__(self, app: ASGIApp, dispatch: DispatchFunction | None = None) -> None:
self.app = app
self.dispatch_func = self.dispatch if dispatch is None else dispatch
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
request = _CachedRequest(scope, receive)
wrapped_receive = request.wrapped_receive
response_sent = anyio.Event()
app_exc: Exception | None = None
async def call_next(request: Request) -> Response:
async def receive_or_disconnect() -> Message:
if response_sent.is_set():
return {"type": "http.disconnect"}
async with anyio.create_task_group() as task_group:
async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
result = await func()
task_group.cancel_scope.cancel()
return result
task_group.start_soon(wrap, response_sent.wait)
message = await wrap(wrapped_receive)
if response_sent.is_set():
return {"type": "http.disconnect"}
return message
async def send_no_error(message: Message) -> None:
try:
await send_stream.send(message)
except anyio.BrokenResourceError:
# recv_stream has been closed, i.e. response_sent has been set.
return
async def coro() -> None:
nonlocal app_exc
with send_stream:
try:
await self.app(scope, receive_or_disconnect, send_no_error)
except Exception as exc:
app_exc = exc
task_group.start_soon(coro)
try:
message = await recv_stream.receive()
info = message.get("info", None)
if message["type"] == "http.response.debug" and info is not None:
message = await recv_stream.receive()
except anyio.EndOfStream:
if app_exc is not None:
raise app_exc
raise RuntimeError("No response returned.")
assert message["type"] == "http.response.start"
async def body_stream() -> typing.AsyncGenerator[bytes, None]:
async for message in recv_stream:
assert message["type"] == "http.response.body"
body = message.get("body", b"")
if body:
yield body
if not message.get("more_body", False):
break
response = _StreamingResponse(status_code=message["status"], content=body_stream(), info=info)
response.raw_headers = message["headers"]
return response
streams: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream()
send_stream, recv_stream = streams
with recv_stream, send_stream, collapse_excgroups():
async with anyio.create_task_group() as task_group:
response = await self.dispatch_func(request, call_next)
await response(scope, wrapped_receive, send)
response_sent.set()
recv_stream.close()
if app_exc is not None:
raise app_exc
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
raise NotImplementedError() # pragma: no cover
class _StreamingResponse(Response):
def __init__(
self,
content: AsyncContentStream,
status_code: int = 200,
headers: typing.Mapping[str, str] | None = None,
media_type: str | None = None,
info: typing.Mapping[str, typing.Any] | None = None,
) -> None:
self.info = info
self.body_iterator = content
self.status_code = status_code
self.media_type = media_type
self.init_headers(headers)
self.background = None
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.info is not None:
await send({"type": "http.response.debug", "info": self.info})
await send(
{
"type": "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
}
)
async for chunk in self.body_iterator:
await send({"type": "http.response.body", "body": chunk, "more_body": True})
await send({"type": "http.response.body", "body": b"", "more_body": False})
if self.background:
await self.background()

View File

@ -0,0 +1,172 @@
from __future__ import annotations
import functools
import re
import typing
from starlette.datastructures import Headers, MutableHeaders
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
ALL_METHODS = ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT")
SAFELISTED_HEADERS = {"Accept", "Accept-Language", "Content-Language", "Content-Type"}
class CORSMiddleware:
def __init__(
self,
app: ASGIApp,
allow_origins: typing.Sequence[str] = (),
allow_methods: typing.Sequence[str] = ("GET",),
allow_headers: typing.Sequence[str] = (),
allow_credentials: bool = False,
allow_origin_regex: str | None = None,
expose_headers: typing.Sequence[str] = (),
max_age: int = 600,
) -> None:
if "*" in allow_methods:
allow_methods = ALL_METHODS
compiled_allow_origin_regex = None
if allow_origin_regex is not None:
compiled_allow_origin_regex = re.compile(allow_origin_regex)
allow_all_origins = "*" in allow_origins
allow_all_headers = "*" in allow_headers
preflight_explicit_allow_origin = not allow_all_origins or allow_credentials
simple_headers = {}
if allow_all_origins:
simple_headers["Access-Control-Allow-Origin"] = "*"
if allow_credentials:
simple_headers["Access-Control-Allow-Credentials"] = "true"
if expose_headers:
simple_headers["Access-Control-Expose-Headers"] = ", ".join(expose_headers)
preflight_headers = {}
if preflight_explicit_allow_origin:
# The origin value will be set in preflight_response() if it is allowed.
preflight_headers["Vary"] = "Origin"
else:
preflight_headers["Access-Control-Allow-Origin"] = "*"
preflight_headers.update(
{
"Access-Control-Allow-Methods": ", ".join(allow_methods),
"Access-Control-Max-Age": str(max_age),
}
)
allow_headers = sorted(SAFELISTED_HEADERS | set(allow_headers))
if allow_headers and not allow_all_headers:
preflight_headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers)
if allow_credentials:
preflight_headers["Access-Control-Allow-Credentials"] = "true"
self.app = app
self.allow_origins = allow_origins
self.allow_methods = allow_methods
self.allow_headers = [h.lower() for h in allow_headers]
self.allow_all_origins = allow_all_origins
self.allow_all_headers = allow_all_headers
self.preflight_explicit_allow_origin = preflight_explicit_allow_origin
self.allow_origin_regex = compiled_allow_origin_regex
self.simple_headers = simple_headers
self.preflight_headers = preflight_headers
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http": # pragma: no cover
await self.app(scope, receive, send)
return
method = scope["method"]
headers = Headers(scope=scope)
origin = headers.get("origin")
if origin is None:
await self.app(scope, receive, send)
return
if method == "OPTIONS" and "access-control-request-method" in headers:
response = self.preflight_response(request_headers=headers)
await response(scope, receive, send)
return
await self.simple_response(scope, receive, send, request_headers=headers)
def is_allowed_origin(self, origin: str) -> bool:
if self.allow_all_origins:
return True
if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(origin):
return True
return origin in self.allow_origins
def preflight_response(self, request_headers: Headers) -> Response:
requested_origin = request_headers["origin"]
requested_method = request_headers["access-control-request-method"]
requested_headers = request_headers.get("access-control-request-headers")
headers = dict(self.preflight_headers)
failures = []
if self.is_allowed_origin(origin=requested_origin):
if self.preflight_explicit_allow_origin:
# The "else" case is already accounted for in self.preflight_headers
# and the value would be "*".
headers["Access-Control-Allow-Origin"] = requested_origin
else:
failures.append("origin")
if requested_method not in self.allow_methods:
failures.append("method")
# If we allow all headers, then we have to mirror back any requested
# headers in the response.
if self.allow_all_headers and requested_headers is not None:
headers["Access-Control-Allow-Headers"] = requested_headers
elif requested_headers is not None:
for header in [h.lower() for h in requested_headers.split(",")]:
if header.strip() not in self.allow_headers:
failures.append("headers")
break
# We don't strictly need to use 400 responses here, since its up to
# the browser to enforce the CORS policy, but its more informative
# if we do.
if failures:
failure_text = "Disallowed CORS " + ", ".join(failures)
return PlainTextResponse(failure_text, status_code=400, headers=headers)
return PlainTextResponse("OK", status_code=200, headers=headers)
async def simple_response(self, scope: Scope, receive: Receive, send: Send, request_headers: Headers) -> None:
send = functools.partial(self.send, send=send, request_headers=request_headers)
await self.app(scope, receive, send)
async def send(self, message: Message, send: Send, request_headers: Headers) -> None:
if message["type"] != "http.response.start":
await send(message)
return
message.setdefault("headers", [])
headers = MutableHeaders(scope=message)
headers.update(self.simple_headers)
origin = request_headers["Origin"]
has_cookie = "cookie" in request_headers
# If request includes any cookie headers, then we must respond
# with the specific origin instead of '*'.
if self.allow_all_origins and has_cookie:
self.allow_explicit_origin(headers, origin)
# If we only allow specific origins, then we have to mirror back
# the Origin header in the response.
elif not self.allow_all_origins and self.is_allowed_origin(origin=origin):
self.allow_explicit_origin(headers, origin)
await send(message)
@staticmethod
def allow_explicit_origin(headers: MutableHeaders, origin: str) -> None:
headers["Access-Control-Allow-Origin"] = origin
headers.add_vary_header("Origin")

View File

@ -0,0 +1,260 @@
from __future__ import annotations
import html
import inspect
import sys
import traceback
import typing
from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
from starlette.responses import HTMLResponse, PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
STYLES = """
p {
color: #211c1c;
}
.traceback-container {
border: 1px solid #038BB8;
}
.traceback-title {
background-color: #038BB8;
color: lemonchiffon;
padding: 12px;
font-size: 20px;
margin-top: 0px;
}
.frame-line {
padding-left: 10px;
font-family: monospace;
}
.frame-filename {
font-family: monospace;
}
.center-line {
background-color: #038BB8;
color: #f9f6e1;
padding: 5px 0px 5px 5px;
}
.lineno {
margin-right: 5px;
}
.frame-title {
font-weight: unset;
padding: 10px 10px 10px 10px;
background-color: #E4F4FD;
margin-right: 10px;
color: #191f21;
font-size: 17px;
border: 1px solid #c7dce8;
}
.collapse-btn {
float: right;
padding: 0px 5px 1px 5px;
border: solid 1px #96aebb;
cursor: pointer;
}
.collapsed {
display: none;
}
.source-code {
font-family: courier;
font-size: small;
padding-bottom: 10px;
}
"""
JS = """
<script type="text/javascript">
function collapse(element){
const frameId = element.getAttribute("data-frame-id");
const frame = document.getElementById(frameId);
if (frame.classList.contains("collapsed")){
element.innerHTML = "&#8210;";
frame.classList.remove("collapsed");
} else {
element.innerHTML = "+";
frame.classList.add("collapsed");
}
}
</script>
"""
TEMPLATE = """
<html>
<head>
<style type='text/css'>
{styles}
</style>
<title>Starlette Debugger</title>
</head>
<body>
<h1>500 Server Error</h1>
<h2>{error}</h2>
<div class="traceback-container">
<p class="traceback-title">Traceback</p>
<div>{exc_html}</div>
</div>
{js}
</body>
</html>
"""
FRAME_TEMPLATE = """
<div>
<p class="frame-title">File <span class="frame-filename">{frame_filename}</span>,
line <i>{frame_lineno}</i>,
in <b>{frame_name}</b>
<span class="collapse-btn" data-frame-id="{frame_filename}-{frame_lineno}" onclick="collapse(this)">{collapse_button}</span>
</p>
<div id="{frame_filename}-{frame_lineno}" class="source-code {collapsed}">{code_context}</div>
</div>
""" # noqa: E501
LINE = """
<p><span class="frame-line">
<span class="lineno">{lineno}.</span> {line}</span></p>
"""
CENTER_LINE = """
<p class="center-line"><span class="frame-line center-line">
<span class="lineno">{lineno}.</span> {line}</span></p>
"""
class ServerErrorMiddleware:
"""
Handles returning 500 responses when a server error occurs.
If 'debug' is set, then traceback responses will be returned,
otherwise the designated 'handler' will be called.
This middleware class should generally be used to wrap *everything*
else up, so that unhandled exceptions anywhere in the stack
always result in an appropriate 500 response.
"""
def __init__(
self,
app: ASGIApp,
handler: typing.Callable[[Request, Exception], typing.Any] | None = None,
debug: bool = False,
) -> None:
self.app = app
self.handler = handler
self.debug = debug
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
response_started = False
async def _send(message: Message) -> None:
nonlocal response_started, send
if message["type"] == "http.response.start":
response_started = True
await send(message)
try:
await self.app(scope, receive, _send)
except Exception as exc:
request = Request(scope)
if self.debug:
# In debug mode, return traceback responses.
response = self.debug_response(request, exc)
elif self.handler is None:
# Use our default 500 error handler.
response = self.error_response(request, exc)
else:
# Use an installed 500 error handler.
if is_async_callable(self.handler):
response = await self.handler(request, exc)
else:
response = await run_in_threadpool(self.handler, request, exc)
if not response_started:
await response(scope, receive, send)
# We always continue to raise the exception.
# This allows servers to log the error, or allows test clients
# to optionally raise the error within the test case.
raise exc
def format_line(self, index: int, line: str, frame_lineno: int, frame_index: int) -> str:
values = {
# HTML escape - line could contain < or >
"line": html.escape(line).replace(" ", "&nbsp"),
"lineno": (frame_lineno - frame_index) + index,
}
if index != frame_index:
return LINE.format(**values)
return CENTER_LINE.format(**values)
def generate_frame_html(self, frame: inspect.FrameInfo, is_collapsed: bool) -> str:
code_context = "".join(
self.format_line(
index,
line,
frame.lineno,
frame.index, # type: ignore[arg-type]
)
for index, line in enumerate(frame.code_context or [])
)
values = {
# HTML escape - filename could contain < or >, especially if it's a virtual
# file e.g. <stdin> in the REPL
"frame_filename": html.escape(frame.filename),
"frame_lineno": frame.lineno,
# HTML escape - if you try very hard it's possible to name a function with <
# or >
"frame_name": html.escape(frame.function),
"code_context": code_context,
"collapsed": "collapsed" if is_collapsed else "",
"collapse_button": "+" if is_collapsed else "&#8210;",
}
return FRAME_TEMPLATE.format(**values)
def generate_html(self, exc: Exception, limit: int = 7) -> str:
traceback_obj = traceback.TracebackException.from_exception(exc, capture_locals=True)
exc_html = ""
is_collapsed = False
exc_traceback = exc.__traceback__
if exc_traceback is not None:
frames = inspect.getinnerframes(exc_traceback, limit)
for frame in reversed(frames):
exc_html += self.generate_frame_html(frame, is_collapsed)
is_collapsed = True
if sys.version_info >= (3, 13): # pragma: no cover
exc_type_str = traceback_obj.exc_type_str
else: # pragma: no cover
exc_type_str = traceback_obj.exc_type.__name__
# escape error class and text
error = f"{html.escape(exc_type_str)}: {html.escape(str(traceback_obj))}"
return TEMPLATE.format(styles=STYLES, js=JS, error=error, exc_html=exc_html)
def generate_plain_text(self, exc: Exception) -> str:
return "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
def debug_response(self, request: Request, exc: Exception) -> Response:
accept = request.headers.get("accept", "")
if "text/html" in accept:
content = self.generate_html(exc)
return HTMLResponse(content, status_code=500)
content = self.generate_plain_text(exc)
return PlainTextResponse(content, status_code=500)
def error_response(self, request: Request, exc: Exception) -> Response:
return PlainTextResponse("Internal Server Error", status_code=500)

View File

@ -0,0 +1,72 @@
from __future__ import annotations
import typing
from starlette._exception_handler import (
ExceptionHandlers,
StatusHandlers,
wrap_app_handling_exceptions,
)
from starlette.exceptions import HTTPException, WebSocketException
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.websockets import WebSocket
class ExceptionMiddleware:
def __init__(
self,
app: ASGIApp,
handlers: typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]] | None = None,
debug: bool = False,
) -> None:
self.app = app
self.debug = debug # TODO: We ought to handle 404 cases if debug is set.
self._status_handlers: StatusHandlers = {}
self._exception_handlers: ExceptionHandlers = {
HTTPException: self.http_exception,
WebSocketException: self.websocket_exception,
}
if handlers is not None: # pragma: no branch
for key, value in handlers.items():
self.add_exception_handler(key, value)
def add_exception_handler(
self,
exc_class_or_status_code: int | type[Exception],
handler: typing.Callable[[Request, Exception], Response],
) -> None:
if isinstance(exc_class_or_status_code, int):
self._status_handlers[exc_class_or_status_code] = handler
else:
assert issubclass(exc_class_or_status_code, Exception)
self._exception_handlers[exc_class_or_status_code] = handler
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] not in ("http", "websocket"):
await self.app(scope, receive, send)
return
scope["starlette.exception_handlers"] = (
self._exception_handlers,
self._status_handlers,
)
conn: Request | WebSocket
if scope["type"] == "http":
conn = Request(scope, receive, send)
else:
conn = WebSocket(scope, receive, send)
await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
def http_exception(self, request: Request, exc: Exception) -> Response:
assert isinstance(exc, HTTPException)
if exc.status_code in {204, 304}:
return Response(status_code=exc.status_code, headers=exc.headers)
return PlainTextResponse(exc.detail, status_code=exc.status_code, headers=exc.headers)
async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None:
assert isinstance(exc, WebSocketException)
await websocket.close(code=exc.code, reason=exc.reason) # pragma: no cover

View File

@ -0,0 +1,141 @@
import gzip
import io
import typing
from starlette.datastructures import Headers, MutableHeaders
from starlette.types import ASGIApp, Message, Receive, Scope, Send
DEFAULT_EXCLUDED_CONTENT_TYPES = ("text/event-stream",)
class GZipMiddleware:
def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9) -> None:
self.app = app
self.minimum_size = minimum_size
self.compresslevel = compresslevel
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http": # pragma: no cover
await self.app(scope, receive, send)
return
headers = Headers(scope=scope)
responder: ASGIApp
if "gzip" in headers.get("Accept-Encoding", ""):
responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel)
else:
responder = IdentityResponder(self.app, self.minimum_size)
await responder(scope, receive, send)
class IdentityResponder:
content_encoding: str
def __init__(self, app: ASGIApp, minimum_size: int) -> None:
self.app = app
self.minimum_size = minimum_size
self.send: Send = unattached_send
self.initial_message: Message = {}
self.started = False
self.content_encoding_set = False
self.content_type_is_excluded = False
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
self.send = send
await self.app(scope, receive, self.send_with_compression)
async def send_with_compression(self, message: Message) -> None:
message_type = message["type"]
if message_type == "http.response.start":
# Don't send the initial message until we've determined how to
# modify the outgoing headers correctly.
self.initial_message = message
headers = Headers(raw=self.initial_message["headers"])
self.content_encoding_set = "content-encoding" in headers
self.content_type_is_excluded = headers.get("content-type", "").startswith(DEFAULT_EXCLUDED_CONTENT_TYPES)
elif message_type == "http.response.body" and (self.content_encoding_set or self.content_type_is_excluded):
if not self.started:
self.started = True
await self.send(self.initial_message)
await self.send(message)
elif message_type == "http.response.body" and not self.started:
self.started = True
body = message.get("body", b"")
more_body = message.get("more_body", False)
if len(body) < self.minimum_size and not more_body:
# Don't apply compression to small outgoing responses.
await self.send(self.initial_message)
await self.send(message)
elif not more_body:
# Standard response.
body = self.apply_compression(body, more_body=False)
headers = MutableHeaders(raw=self.initial_message["headers"])
headers.add_vary_header("Accept-Encoding")
if body != message["body"]:
headers["Content-Encoding"] = self.content_encoding
headers["Content-Length"] = str(len(body))
message["body"] = body
await self.send(self.initial_message)
await self.send(message)
else:
# Initial body in streaming response.
body = self.apply_compression(body, more_body=True)
headers = MutableHeaders(raw=self.initial_message["headers"])
headers.add_vary_header("Accept-Encoding")
if body != message["body"]:
headers["Content-Encoding"] = self.content_encoding
del headers["Content-Length"]
message["body"] = body
await self.send(self.initial_message)
await self.send(message)
elif message_type == "http.response.body": # pragma: no branch
# Remaining body in streaming response.
body = message.get("body", b"")
more_body = message.get("more_body", False)
message["body"] = self.apply_compression(body, more_body=more_body)
await self.send(message)
def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
"""Apply compression on the response body.
If more_body is False, any compression file should be closed. If it
isn't, it won't be closed automatically until all background tasks
complete.
"""
return body
class GZipResponder(IdentityResponder):
content_encoding = "gzip"
def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None:
super().__init__(app, minimum_size)
self.gzip_buffer = io.BytesIO()
self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
with self.gzip_buffer, self.gzip_file:
await super().__call__(scope, receive, send)
def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
self.gzip_file.write(body)
if not more_body:
self.gzip_file.close()
body = self.gzip_buffer.getvalue()
self.gzip_buffer.seek(0)
self.gzip_buffer.truncate()
return body
async def unattached_send(message: Message) -> typing.NoReturn:
raise RuntimeError("send awaitable not set") # pragma: no cover

View File

@ -0,0 +1,19 @@
from starlette.datastructures import URL
from starlette.responses import RedirectResponse
from starlette.types import ASGIApp, Receive, Scope, Send
class HTTPSRedirectMiddleware:
def __init__(self, app: ASGIApp) -> None:
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] in ("http", "websocket") and scope["scheme"] in ("http", "ws"):
url = URL(scope=scope)
redirect_scheme = {"http": "https", "ws": "wss"}[url.scheme]
netloc = url.hostname if url.port in (80, 443) else url.netloc
url = url.replace(scheme=redirect_scheme, netloc=netloc)
response = RedirectResponse(url, status_code=307)
await response(scope, receive, send)
else:
await self.app(scope, receive, send)

View File

@ -0,0 +1,85 @@
from __future__ import annotations
import json
import typing
from base64 import b64decode, b64encode
import itsdangerous
from itsdangerous.exc import BadSignature
from starlette.datastructures import MutableHeaders, Secret
from starlette.requests import HTTPConnection
from starlette.types import ASGIApp, Message, Receive, Scope, Send
class SessionMiddleware:
def __init__(
self,
app: ASGIApp,
secret_key: str | Secret,
session_cookie: str = "session",
max_age: int | None = 14 * 24 * 60 * 60, # 14 days, in seconds
path: str = "/",
same_site: typing.Literal["lax", "strict", "none"] = "lax",
https_only: bool = False,
domain: str | None = None,
) -> None:
self.app = app
self.signer = itsdangerous.TimestampSigner(str(secret_key))
self.session_cookie = session_cookie
self.max_age = max_age
self.path = path
self.security_flags = "httponly; samesite=" + same_site
if https_only: # Secure flag can be used with HTTPS only
self.security_flags += "; secure"
if domain is not None:
self.security_flags += f"; domain={domain}"
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] not in ("http", "websocket"): # pragma: no cover
await self.app(scope, receive, send)
return
connection = HTTPConnection(scope)
initial_session_was_empty = True
if self.session_cookie in connection.cookies:
data = connection.cookies[self.session_cookie].encode("utf-8")
try:
data = self.signer.unsign(data, max_age=self.max_age)
scope["session"] = json.loads(b64decode(data))
initial_session_was_empty = False
except BadSignature:
scope["session"] = {}
else:
scope["session"] = {}
async def send_wrapper(message: Message) -> None:
if message["type"] == "http.response.start":
if scope["session"]:
# We have session data to persist.
data = b64encode(json.dumps(scope["session"]).encode("utf-8"))
data = self.signer.sign(data)
headers = MutableHeaders(scope=message)
header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format(
session_cookie=self.session_cookie,
data=data.decode("utf-8"),
path=self.path,
max_age=f"Max-Age={self.max_age}; " if self.max_age else "",
security_flags=self.security_flags,
)
headers.append("Set-Cookie", header_value)
elif not initial_session_was_empty:
# The session has been cleared.
headers = MutableHeaders(scope=message)
header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format(
session_cookie=self.session_cookie,
data="null",
path=self.path,
expires="expires=Thu, 01 Jan 1970 00:00:00 GMT; ",
security_flags=self.security_flags,
)
headers.append("Set-Cookie", header_value)
await send(message)
await self.app(scope, receive, send_wrapper)

View File

@ -0,0 +1,60 @@
from __future__ import annotations
import typing
from starlette.datastructures import URL, Headers
from starlette.responses import PlainTextResponse, RedirectResponse, Response
from starlette.types import ASGIApp, Receive, Scope, Send
ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'."
class TrustedHostMiddleware:
def __init__(
self,
app: ASGIApp,
allowed_hosts: typing.Sequence[str] | None = None,
www_redirect: bool = True,
) -> None:
if allowed_hosts is None:
allowed_hosts = ["*"]
for pattern in allowed_hosts:
assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD
if pattern.startswith("*") and pattern != "*":
assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD
self.app = app
self.allowed_hosts = list(allowed_hosts)
self.allow_any = "*" in allowed_hosts
self.www_redirect = www_redirect
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.allow_any or scope["type"] not in (
"http",
"websocket",
): # pragma: no cover
await self.app(scope, receive, send)
return
headers = Headers(scope=scope)
host = headers.get("host", "").split(":")[0]
is_valid_host = False
found_www_redirect = False
for pattern in self.allowed_hosts:
if host == pattern or (pattern.startswith("*") and host.endswith(pattern[1:])):
is_valid_host = True
break
elif "www." + host == pattern:
found_www_redirect = True
if is_valid_host:
await self.app(scope, receive, send)
else:
response: Response
if found_www_redirect and self.www_redirect:
url = URL(scope=scope)
redirect_url = url.replace(netloc="www." + url.netloc)
response = RedirectResponse(url=str(redirect_url))
else:
response = PlainTextResponse("Invalid host header", status_code=400)
await response(scope, receive, send)

View File

@ -0,0 +1,152 @@
from __future__ import annotations
import io
import math
import sys
import typing
import warnings
import anyio
from anyio.abc import ObjectReceiveStream, ObjectSendStream
from starlette.types import Receive, Scope, Send
warnings.warn(
"starlette.middleware.wsgi is deprecated and will be removed in a future release. "
"Please refer to https://github.com/abersheeran/a2wsgi as a replacement.",
DeprecationWarning,
)
def build_environ(scope: Scope, body: bytes) -> dict[str, typing.Any]:
"""
Builds a scope and request body into a WSGI environ object.
"""
script_name = scope.get("root_path", "").encode("utf8").decode("latin1")
path_info = scope["path"].encode("utf8").decode("latin1")
if path_info.startswith(script_name):
path_info = path_info[len(script_name) :]
environ = {
"REQUEST_METHOD": scope["method"],
"SCRIPT_NAME": script_name,
"PATH_INFO": path_info,
"QUERY_STRING": scope["query_string"].decode("ascii"),
"SERVER_PROTOCOL": f"HTTP/{scope['http_version']}",
"wsgi.version": (1, 0),
"wsgi.url_scheme": scope.get("scheme", "http"),
"wsgi.input": io.BytesIO(body),
"wsgi.errors": sys.stdout,
"wsgi.multithread": True,
"wsgi.multiprocess": True,
"wsgi.run_once": False,
}
# Get server name and port - required in WSGI, not in ASGI
server = scope.get("server") or ("localhost", 80)
environ["SERVER_NAME"] = server[0]
environ["SERVER_PORT"] = server[1]
# Get client IP address
if scope.get("client"):
environ["REMOTE_ADDR"] = scope["client"][0]
# Go through headers and make them into environ entries
for name, value in scope.get("headers", []):
name = name.decode("latin1")
if name == "content-length":
corrected_name = "CONTENT_LENGTH"
elif name == "content-type":
corrected_name = "CONTENT_TYPE"
else:
corrected_name = f"HTTP_{name}".upper().replace("-", "_")
# HTTPbis say only ASCII chars are allowed in headers, but we latin1 just in
# case
value = value.decode("latin1")
if corrected_name in environ:
value = environ[corrected_name] + "," + value
environ[corrected_name] = value
return environ
class WSGIMiddleware:
def __init__(self, app: typing.Callable[..., typing.Any]) -> None:
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
assert scope["type"] == "http"
responder = WSGIResponder(self.app, scope)
await responder(receive, send)
class WSGIResponder:
stream_send: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
stream_receive: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
def __init__(self, app: typing.Callable[..., typing.Any], scope: Scope) -> None:
self.app = app
self.scope = scope
self.status = None
self.response_headers = None
self.stream_send, self.stream_receive = anyio.create_memory_object_stream(math.inf)
self.response_started = False
self.exc_info: typing.Any = None
async def __call__(self, receive: Receive, send: Send) -> None:
body = b""
more_body = True
while more_body:
message = await receive()
body += message.get("body", b"")
more_body = message.get("more_body", False)
environ = build_environ(self.scope, body)
async with anyio.create_task_group() as task_group:
task_group.start_soon(self.sender, send)
async with self.stream_send:
await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response)
if self.exc_info is not None:
raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2])
async def sender(self, send: Send) -> None:
async with self.stream_receive:
async for message in self.stream_receive:
await send(message)
def start_response(
self,
status: str,
response_headers: list[tuple[str, str]],
exc_info: typing.Any = None,
) -> None:
self.exc_info = exc_info
if not self.response_started: # pragma: no branch
self.response_started = True
status_code_string, _ = status.split(" ", 1)
status_code = int(status_code_string)
headers = [
(name.strip().encode("ascii").lower(), value.strip().encode("ascii"))
for name, value in response_headers
]
anyio.from_thread.run(
self.stream_send.send,
{
"type": "http.response.start",
"status": status_code,
"headers": headers,
},
)
def wsgi(
self,
environ: dict[str, typing.Any],
start_response: typing.Callable[..., typing.Any],
) -> None:
for chunk in self.app(environ, start_response):
anyio.from_thread.run(
self.stream_send.send,
{"type": "http.response.body", "body": chunk, "more_body": True},
)
anyio.from_thread.run(self.stream_send.send, {"type": "http.response.body", "body": b""})

View File

@ -0,0 +1,322 @@
from __future__ import annotations
import json
import typing
from http import cookies as http_cookies
import anyio
from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper
from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
from starlette.exceptions import HTTPException
from starlette.formparsers import FormParser, MultiPartException, MultiPartParser
from starlette.types import Message, Receive, Scope, Send
if typing.TYPE_CHECKING:
from python_multipart.multipart import parse_options_header
from starlette.applications import Starlette
from starlette.routing import Router
else:
try:
try:
from python_multipart.multipart import parse_options_header
except ModuleNotFoundError: # pragma: no cover
from multipart.multipart import parse_options_header
except ModuleNotFoundError: # pragma: no cover
parse_options_header = None
SERVER_PUSH_HEADERS_TO_COPY = {
"accept",
"accept-encoding",
"accept-language",
"cache-control",
"user-agent",
}
def cookie_parser(cookie_string: str) -> dict[str, str]:
"""
This function parses a ``Cookie`` HTTP header into a dict of key/value pairs.
It attempts to mimic browser cookie parsing behavior: browsers and web servers
frequently disregard the spec (RFC 6265) when setting and reading cookies,
so we attempt to suit the common scenarios here.
This function has been adapted from Django 3.1.0.
Note: we are explicitly _NOT_ using `SimpleCookie.load` because it is based
on an outdated spec and will fail on lots of input we want to support
"""
cookie_dict: dict[str, str] = {}
for chunk in cookie_string.split(";"):
if "=" in chunk:
key, val = chunk.split("=", 1)
else:
# Assume an empty name per
# https://bugzilla.mozilla.org/show_bug.cgi?id=169091
key, val = "", chunk
key, val = key.strip(), val.strip()
if key or val:
# unquote using Python's algorithm.
cookie_dict[key] = http_cookies._unquote(val)
return cookie_dict
class ClientDisconnect(Exception):
pass
class HTTPConnection(typing.Mapping[str, typing.Any]):
"""
A base class for incoming HTTP connections, that is used to provide
any functionality that is common to both `Request` and `WebSocket`.
"""
def __init__(self, scope: Scope, receive: Receive | None = None) -> None:
assert scope["type"] in ("http", "websocket")
self.scope = scope
def __getitem__(self, key: str) -> typing.Any:
return self.scope[key]
def __iter__(self) -> typing.Iterator[str]:
return iter(self.scope)
def __len__(self) -> int:
return len(self.scope)
# Don't use the `abc.Mapping.__eq__` implementation.
# Connection instances should never be considered equal
# unless `self is other`.
__eq__ = object.__eq__
__hash__ = object.__hash__
@property
def app(self) -> typing.Any:
return self.scope["app"]
@property
def url(self) -> URL:
if not hasattr(self, "_url"): # pragma: no branch
self._url = URL(scope=self.scope)
return self._url
@property
def base_url(self) -> URL:
if not hasattr(self, "_base_url"):
base_url_scope = dict(self.scope)
# This is used by request.url_for, it might be used inside a Mount which
# would have its own child scope with its own root_path, but the base URL
# for url_for should still be the top level app root path.
app_root_path = base_url_scope.get("app_root_path", base_url_scope.get("root_path", ""))
path = app_root_path
if not path.endswith("/"):
path += "/"
base_url_scope["path"] = path
base_url_scope["query_string"] = b""
base_url_scope["root_path"] = app_root_path
self._base_url = URL(scope=base_url_scope)
return self._base_url
@property
def headers(self) -> Headers:
if not hasattr(self, "_headers"):
self._headers = Headers(scope=self.scope)
return self._headers
@property
def query_params(self) -> QueryParams:
if not hasattr(self, "_query_params"): # pragma: no branch
self._query_params = QueryParams(self.scope["query_string"])
return self._query_params
@property
def path_params(self) -> dict[str, typing.Any]:
return self.scope.get("path_params", {})
@property
def cookies(self) -> dict[str, str]:
if not hasattr(self, "_cookies"):
cookies: dict[str, str] = {}
cookie_header = self.headers.get("cookie")
if cookie_header:
cookies = cookie_parser(cookie_header)
self._cookies = cookies
return self._cookies
@property
def client(self) -> Address | None:
# client is a 2 item tuple of (host, port), None if missing
host_port = self.scope.get("client")
if host_port is not None:
return Address(*host_port)
return None
@property
def session(self) -> dict[str, typing.Any]:
assert "session" in self.scope, "SessionMiddleware must be installed to access request.session"
return self.scope["session"] # type: ignore[no-any-return]
@property
def auth(self) -> typing.Any:
assert "auth" in self.scope, "AuthenticationMiddleware must be installed to access request.auth"
return self.scope["auth"]
@property
def user(self) -> typing.Any:
assert "user" in self.scope, "AuthenticationMiddleware must be installed to access request.user"
return self.scope["user"]
@property
def state(self) -> State:
if not hasattr(self, "_state"):
# Ensure 'state' has an empty dict if it's not already populated.
self.scope.setdefault("state", {})
# Create a state instance with a reference to the dict in which it should
# store info
self._state = State(self.scope["state"])
return self._state
def url_for(self, name: str, /, **path_params: typing.Any) -> URL:
url_path_provider: Router | Starlette | None = self.scope.get("router") or self.scope.get("app")
if url_path_provider is None:
raise RuntimeError("The `url_for` method can only be used inside a Starlette application or with a router.")
url_path = url_path_provider.url_path_for(name, **path_params)
return url_path.make_absolute_url(base_url=self.base_url)
async def empty_receive() -> typing.NoReturn:
raise RuntimeError("Receive channel has not been made available")
async def empty_send(message: Message) -> typing.NoReturn:
raise RuntimeError("Send channel has not been made available")
class Request(HTTPConnection):
_form: FormData | None
def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send):
super().__init__(scope)
assert scope["type"] == "http"
self._receive = receive
self._send = send
self._stream_consumed = False
self._is_disconnected = False
self._form = None
@property
def method(self) -> str:
return typing.cast(str, self.scope["method"])
@property
def receive(self) -> Receive:
return self._receive
async def stream(self) -> typing.AsyncGenerator[bytes, None]:
if hasattr(self, "_body"):
yield self._body
yield b""
return
if self._stream_consumed:
raise RuntimeError("Stream consumed")
while not self._stream_consumed:
message = await self._receive()
if message["type"] == "http.request":
body = message.get("body", b"")
if not message.get("more_body", False):
self._stream_consumed = True
if body:
yield body
elif message["type"] == "http.disconnect": # pragma: no branch
self._is_disconnected = True
raise ClientDisconnect()
yield b""
async def body(self) -> bytes:
if not hasattr(self, "_body"):
chunks: list[bytes] = []
async for chunk in self.stream():
chunks.append(chunk)
self._body = b"".join(chunks)
return self._body
async def json(self) -> typing.Any:
if not hasattr(self, "_json"): # pragma: no branch
body = await self.body()
self._json = json.loads(body)
return self._json
async def _get_form(
self,
*,
max_files: int | float = 1000,
max_fields: int | float = 1000,
max_part_size: int = 1024 * 1024,
) -> FormData:
if self._form is None: # pragma: no branch
assert parse_options_header is not None, (
"The `python-multipart` library must be installed to use form parsing."
)
content_type_header = self.headers.get("Content-Type")
content_type: bytes
content_type, _ = parse_options_header(content_type_header)
if content_type == b"multipart/form-data":
try:
multipart_parser = MultiPartParser(
self.headers,
self.stream(),
max_files=max_files,
max_fields=max_fields,
max_part_size=max_part_size,
)
self._form = await multipart_parser.parse()
except MultiPartException as exc:
if "app" in self.scope:
raise HTTPException(status_code=400, detail=exc.message)
raise exc
elif content_type == b"application/x-www-form-urlencoded":
form_parser = FormParser(self.headers, self.stream())
self._form = await form_parser.parse()
else:
self._form = FormData()
return self._form
def form(
self,
*,
max_files: int | float = 1000,
max_fields: int | float = 1000,
max_part_size: int = 1024 * 1024,
) -> AwaitableOrContextManager[FormData]:
return AwaitableOrContextManagerWrapper(
self._get_form(max_files=max_files, max_fields=max_fields, max_part_size=max_part_size)
)
async def close(self) -> None:
if self._form is not None: # pragma: no branch
await self._form.close()
async def is_disconnected(self) -> bool:
if not self._is_disconnected:
message: Message = {}
# If message isn't immediately available, move on
with anyio.CancelScope() as cs:
cs.cancel()
message = await self._receive()
if message.get("type") == "http.disconnect":
self._is_disconnected = True
return self._is_disconnected
async def send_push_promise(self, path: str) -> None:
if "http.response.push" in self.scope.get("extensions", {}):
raw_headers: list[tuple[bytes, bytes]] = []
for name in SERVER_PUSH_HEADERS_TO_COPY:
for value in self.headers.getlist(name):
raw_headers.append((name.encode("latin-1"), value.encode("latin-1")))
await self._send({"type": "http.response.push", "path": path, "headers": raw_headers})

View File

@ -0,0 +1,536 @@
from __future__ import annotations
import hashlib
import http.cookies
import json
import os
import re
import stat
import typing
import warnings
from datetime import datetime
from email.utils import format_datetime, formatdate
from functools import partial
from mimetypes import guess_type
from secrets import token_hex
from urllib.parse import quote
import anyio
import anyio.to_thread
from starlette._utils import collapse_excgroups
from starlette.background import BackgroundTask
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import URL, Headers, MutableHeaders
from starlette.requests import ClientDisconnect
from starlette.types import Receive, Scope, Send
class Response:
media_type = None
charset = "utf-8"
def __init__(
self,
content: typing.Any = None,
status_code: int = 200,
headers: typing.Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
) -> None:
self.status_code = status_code
if media_type is not None:
self.media_type = media_type
self.background = background
self.body = self.render(content)
self.init_headers(headers)
def render(self, content: typing.Any) -> bytes | memoryview:
if content is None:
return b""
if isinstance(content, (bytes, memoryview)):
return content
return content.encode(self.charset) # type: ignore
def init_headers(self, headers: typing.Mapping[str, str] | None = None) -> None:
if headers is None:
raw_headers: list[tuple[bytes, bytes]] = []
populate_content_length = True
populate_content_type = True
else:
raw_headers = [(k.lower().encode("latin-1"), v.encode("latin-1")) for k, v in headers.items()]
keys = [h[0] for h in raw_headers]
populate_content_length = b"content-length" not in keys
populate_content_type = b"content-type" not in keys
body = getattr(self, "body", None)
if (
body is not None
and populate_content_length
and not (self.status_code < 200 or self.status_code in (204, 304))
):
content_length = str(len(body))
raw_headers.append((b"content-length", content_length.encode("latin-1")))
content_type = self.media_type
if content_type is not None and populate_content_type:
if content_type.startswith("text/") and "charset=" not in content_type.lower():
content_type += "; charset=" + self.charset
raw_headers.append((b"content-type", content_type.encode("latin-1")))
self.raw_headers = raw_headers
@property
def headers(self) -> MutableHeaders:
if not hasattr(self, "_headers"):
self._headers = MutableHeaders(raw=self.raw_headers)
return self._headers
def set_cookie(
self,
key: str,
value: str = "",
max_age: int | None = None,
expires: datetime | str | int | None = None,
path: str | None = "/",
domain: str | None = None,
secure: bool = False,
httponly: bool = False,
samesite: typing.Literal["lax", "strict", "none"] | None = "lax",
) -> None:
cookie: http.cookies.BaseCookie[str] = http.cookies.SimpleCookie()
cookie[key] = value
if max_age is not None:
cookie[key]["max-age"] = max_age
if expires is not None:
if isinstance(expires, datetime):
cookie[key]["expires"] = format_datetime(expires, usegmt=True)
else:
cookie[key]["expires"] = expires
if path is not None:
cookie[key]["path"] = path
if domain is not None:
cookie[key]["domain"] = domain
if secure:
cookie[key]["secure"] = True
if httponly:
cookie[key]["httponly"] = True
if samesite is not None:
assert samesite.lower() in [
"strict",
"lax",
"none",
], "samesite must be either 'strict', 'lax' or 'none'"
cookie[key]["samesite"] = samesite
cookie_val = cookie.output(header="").strip()
self.raw_headers.append((b"set-cookie", cookie_val.encode("latin-1")))
def delete_cookie(
self,
key: str,
path: str = "/",
domain: str | None = None,
secure: bool = False,
httponly: bool = False,
samesite: typing.Literal["lax", "strict", "none"] | None = "lax",
) -> None:
self.set_cookie(
key,
max_age=0,
expires=0,
path=path,
domain=domain,
secure=secure,
httponly=httponly,
samesite=samesite,
)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
prefix = "websocket." if scope["type"] == "websocket" else ""
await send(
{
"type": prefix + "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
}
)
await send({"type": prefix + "http.response.body", "body": self.body})
if self.background is not None:
await self.background()
class HTMLResponse(Response):
media_type = "text/html"
class PlainTextResponse(Response):
media_type = "text/plain"
class JSONResponse(Response):
media_type = "application/json"
def __init__(
self,
content: typing.Any,
status_code: int = 200,
headers: typing.Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
) -> None:
super().__init__(content, status_code, headers, media_type, background)
def render(self, content: typing.Any) -> bytes:
return json.dumps(
content,
ensure_ascii=False,
allow_nan=False,
indent=None,
separators=(",", ":"),
).encode("utf-8")
class RedirectResponse(Response):
def __init__(
self,
url: str | URL,
status_code: int = 307,
headers: typing.Mapping[str, str] | None = None,
background: BackgroundTask | None = None,
) -> None:
super().__init__(content=b"", status_code=status_code, headers=headers, background=background)
self.headers["location"] = quote(str(url), safe=":/%#?=@[]!$&'()*+,;")
Content = typing.Union[str, bytes, memoryview]
SyncContentStream = typing.Iterable[Content]
AsyncContentStream = typing.AsyncIterable[Content]
ContentStream = typing.Union[AsyncContentStream, SyncContentStream]
class StreamingResponse(Response):
body_iterator: AsyncContentStream
def __init__(
self,
content: ContentStream,
status_code: int = 200,
headers: typing.Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
) -> None:
if isinstance(content, typing.AsyncIterable):
self.body_iterator = content
else:
self.body_iterator = iterate_in_threadpool(content)
self.status_code = status_code
self.media_type = self.media_type if media_type is None else media_type
self.background = background
self.init_headers(headers)
async def listen_for_disconnect(self, receive: Receive) -> None:
while True:
message = await receive()
if message["type"] == "http.disconnect":
break
async def stream_response(self, send: Send) -> None:
await send(
{
"type": "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
}
)
async for chunk in self.body_iterator:
if not isinstance(chunk, (bytes, memoryview)):
chunk = chunk.encode(self.charset)
await send({"type": "http.response.body", "body": chunk, "more_body": True})
await send({"type": "http.response.body", "body": b"", "more_body": False})
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
spec_version = tuple(map(int, scope.get("asgi", {}).get("spec_version", "2.0").split(".")))
if spec_version >= (2, 4):
try:
await self.stream_response(send)
except OSError:
raise ClientDisconnect()
else:
with collapse_excgroups():
async with anyio.create_task_group() as task_group:
async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
await func()
task_group.cancel_scope.cancel()
task_group.start_soon(wrap, partial(self.stream_response, send))
await wrap(partial(self.listen_for_disconnect, receive))
if self.background is not None:
await self.background()
class MalformedRangeHeader(Exception):
def __init__(self, content: str = "Malformed range header.") -> None:
self.content = content
class RangeNotSatisfiable(Exception):
def __init__(self, max_size: int) -> None:
self.max_size = max_size
_RANGE_PATTERN = re.compile(r"(\d*)-(\d*)")
class FileResponse(Response):
chunk_size = 64 * 1024
def __init__(
self,
path: str | os.PathLike[str],
status_code: int = 200,
headers: typing.Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
filename: str | None = None,
stat_result: os.stat_result | None = None,
method: str | None = None,
content_disposition_type: str = "attachment",
) -> None:
self.path = path
self.status_code = status_code
self.filename = filename
if method is not None:
warnings.warn(
"The 'method' parameter is not used, and it will be removed.",
DeprecationWarning,
)
if media_type is None:
media_type = guess_type(filename or path)[0] or "text/plain"
self.media_type = media_type
self.background = background
self.init_headers(headers)
self.headers.setdefault("accept-ranges", "bytes")
if self.filename is not None:
content_disposition_filename = quote(self.filename)
if content_disposition_filename != self.filename:
content_disposition = f"{content_disposition_type}; filename*=utf-8''{content_disposition_filename}"
else:
content_disposition = f'{content_disposition_type}; filename="{self.filename}"'
self.headers.setdefault("content-disposition", content_disposition)
self.stat_result = stat_result
if stat_result is not None:
self.set_stat_headers(stat_result)
def set_stat_headers(self, stat_result: os.stat_result) -> None:
content_length = str(stat_result.st_size)
last_modified = formatdate(stat_result.st_mtime, usegmt=True)
etag_base = str(stat_result.st_mtime) + "-" + str(stat_result.st_size)
etag = f'"{hashlib.md5(etag_base.encode(), usedforsecurity=False).hexdigest()}"'
self.headers.setdefault("content-length", content_length)
self.headers.setdefault("last-modified", last_modified)
self.headers.setdefault("etag", etag)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
send_header_only: bool = scope["method"].upper() == "HEAD"
if self.stat_result is None:
try:
stat_result = await anyio.to_thread.run_sync(os.stat, self.path)
self.set_stat_headers(stat_result)
except FileNotFoundError:
raise RuntimeError(f"File at path {self.path} does not exist.")
else:
mode = stat_result.st_mode
if not stat.S_ISREG(mode):
raise RuntimeError(f"File at path {self.path} is not a file.")
else:
stat_result = self.stat_result
headers = Headers(scope=scope)
http_range = headers.get("range")
http_if_range = headers.get("if-range")
if http_range is None or (http_if_range is not None and not self._should_use_range(http_if_range)):
await self._handle_simple(send, send_header_only)
else:
try:
ranges = self._parse_range_header(http_range, stat_result.st_size)
except MalformedRangeHeader as exc:
return await PlainTextResponse(exc.content, status_code=400)(scope, receive, send)
except RangeNotSatisfiable as exc:
response = PlainTextResponse(status_code=416, headers={"Content-Range": f"*/{exc.max_size}"})
return await response(scope, receive, send)
if len(ranges) == 1:
start, end = ranges[0]
await self._handle_single_range(send, start, end, stat_result.st_size, send_header_only)
else:
await self._handle_multiple_ranges(send, ranges, stat_result.st_size, send_header_only)
if self.background is not None:
await self.background()
async def _handle_simple(self, send: Send, send_header_only: bool) -> None:
await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers})
if send_header_only:
await send({"type": "http.response.body", "body": b"", "more_body": False})
else:
async with await anyio.open_file(self.path, mode="rb") as file:
more_body = True
while more_body:
chunk = await file.read(self.chunk_size)
more_body = len(chunk) == self.chunk_size
await send({"type": "http.response.body", "body": chunk, "more_body": more_body})
async def _handle_single_range(
self, send: Send, start: int, end: int, file_size: int, send_header_only: bool
) -> None:
self.headers["content-range"] = f"bytes {start}-{end - 1}/{file_size}"
self.headers["content-length"] = str(end - start)
await send({"type": "http.response.start", "status": 206, "headers": self.raw_headers})
if send_header_only:
await send({"type": "http.response.body", "body": b"", "more_body": False})
else:
async with await anyio.open_file(self.path, mode="rb") as file:
await file.seek(start)
more_body = True
while more_body:
chunk = await file.read(min(self.chunk_size, end - start))
start += len(chunk)
more_body = len(chunk) == self.chunk_size and start < end
await send({"type": "http.response.body", "body": chunk, "more_body": more_body})
async def _handle_multiple_ranges(
self,
send: Send,
ranges: list[tuple[int, int]],
file_size: int,
send_header_only: bool,
) -> None:
# In firefox and chrome, they use boundary with 95-96 bits entropy (that's roughly 13 bytes).
boundary = token_hex(13)
content_length, header_generator = self.generate_multipart(
ranges, boundary, file_size, self.headers["content-type"]
)
self.headers["content-range"] = f"multipart/byteranges; boundary={boundary}"
self.headers["content-length"] = str(content_length)
await send({"type": "http.response.start", "status": 206, "headers": self.raw_headers})
if send_header_only:
await send({"type": "http.response.body", "body": b"", "more_body": False})
else:
async with await anyio.open_file(self.path, mode="rb") as file:
for start, end in ranges:
await send({"type": "http.response.body", "body": header_generator(start, end), "more_body": True})
await file.seek(start)
while start < end:
chunk = await file.read(min(self.chunk_size, end - start))
start += len(chunk)
await send({"type": "http.response.body", "body": chunk, "more_body": True})
await send({"type": "http.response.body", "body": b"\n", "more_body": True})
await send(
{
"type": "http.response.body",
"body": f"\n--{boundary}--\n".encode("latin-1"),
"more_body": False,
}
)
def _should_use_range(self, http_if_range: str) -> bool:
return http_if_range == self.headers["last-modified"] or http_if_range == self.headers["etag"]
@staticmethod
def _parse_range_header(http_range: str, file_size: int) -> list[tuple[int, int]]:
ranges: list[tuple[int, int]] = []
try:
units, range_ = http_range.split("=", 1)
except ValueError:
raise MalformedRangeHeader()
units = units.strip().lower()
if units != "bytes":
raise MalformedRangeHeader("Only support bytes range")
ranges = [
(
int(_[0]) if _[0] else file_size - int(_[1]),
int(_[1]) + 1 if _[0] and _[1] and int(_[1]) < file_size else file_size,
)
for _ in _RANGE_PATTERN.findall(range_)
if _ != ("", "")
]
if len(ranges) == 0:
raise MalformedRangeHeader("Range header: range must be requested")
if any(not (0 <= start < file_size) for start, _ in ranges):
raise RangeNotSatisfiable(file_size)
if any(start > end for start, end in ranges):
raise MalformedRangeHeader("Range header: start must be less than end")
if len(ranges) == 1:
return ranges
# Merge ranges
result: list[tuple[int, int]] = []
for start, end in ranges:
for p in range(len(result)):
p_start, p_end = result[p]
if start > p_end:
continue
elif end < p_start:
result.insert(p, (start, end)) # THIS IS NOT REACHED!
break
else:
result[p] = (min(start, p_start), max(end, p_end))
break
else:
result.append((start, end))
return result
def generate_multipart(
self,
ranges: typing.Sequence[tuple[int, int]],
boundary: str,
max_size: int,
content_type: str,
) -> tuple[int, typing.Callable[[int, int], bytes]]:
r"""
Multipart response headers generator.
```
--{boundary}\n
Content-Type: {content_type}\n
Content-Range: bytes {start}-{end-1}/{max_size}\n
\n
..........content...........\n
--{boundary}\n
Content-Type: {content_type}\n
Content-Range: bytes {start}-{end-1}/{max_size}\n
\n
..........content...........\n
--{boundary}--\n
```
"""
boundary_len = len(boundary)
static_header_part_len = 44 + boundary_len + len(content_type) + len(str(max_size))
content_length = sum(
(len(str(start)) + len(str(end - 1)) + static_header_part_len) # Headers
+ (end - start) # Content
for start, end in ranges
) + (
5 + boundary_len # --boundary--\n
)
return (
content_length,
lambda start, end: (
f"--{boundary}\nContent-Type: {content_type}\nContent-Range: bytes {start}-{end - 1}/{max_size}\n\n"
).encode("latin-1"),
)

View File

@ -0,0 +1,874 @@
from __future__ import annotations
import contextlib
import functools
import inspect
import re
import traceback
import types
import typing
import warnings
from contextlib import asynccontextmanager
from enum import Enum
from starlette._exception_handler import wrap_app_handling_exceptions
from starlette._utils import get_route_path, is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.convertors import CONVERTOR_TYPES, Convertor
from starlette.datastructures import URL, Headers, URLPath
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse, RedirectResponse, Response
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketClose
class NoMatchFound(Exception):
"""
Raised by `.url_for(name, **path_params)` and `.url_path_for(name, **path_params)`
if no matching route exists.
"""
def __init__(self, name: str, path_params: dict[str, typing.Any]) -> None:
params = ", ".join(list(path_params.keys()))
super().__init__(f'No route exists for name "{name}" and params "{params}".')
class Match(Enum):
NONE = 0
PARTIAL = 1
FULL = 2
def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover
"""
Correctly determines if an object is a coroutine function,
including those wrapped in functools.partial objects.
"""
warnings.warn(
"iscoroutinefunction_or_partial is deprecated, and will be removed in a future release.",
DeprecationWarning,
)
while isinstance(obj, functools.partial):
obj = obj.func
return inspect.iscoroutinefunction(obj)
def request_response(
func: typing.Callable[[Request], typing.Awaitable[Response] | Response],
) -> ASGIApp:
"""
Takes a function or coroutine `func(request) -> response`,
and returns an ASGI application.
"""
f: typing.Callable[[Request], typing.Awaitable[Response]] = (
func if is_async_callable(func) else functools.partial(run_in_threadpool, func) # type:ignore
)
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive, send)
async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = await f(request)
await response(scope, receive, send)
await wrap_app_handling_exceptions(app, request)(scope, receive, send)
return app
def websocket_session(
func: typing.Callable[[WebSocket], typing.Awaitable[None]],
) -> ASGIApp:
"""
Takes a coroutine `func(session)`, and returns an ASGI application.
"""
# assert asyncio.iscoroutinefunction(func), "WebSocket endpoints must be async"
async def app(scope: Scope, receive: Receive, send: Send) -> None:
session = WebSocket(scope, receive=receive, send=send)
async def app(scope: Scope, receive: Receive, send: Send) -> None:
await func(session)
await wrap_app_handling_exceptions(app, session)(scope, receive, send)
return app
def get_name(endpoint: typing.Callable[..., typing.Any]) -> str:
return getattr(endpoint, "__name__", endpoint.__class__.__name__)
def replace_params(
path: str,
param_convertors: dict[str, Convertor[typing.Any]],
path_params: dict[str, str],
) -> tuple[str, dict[str, str]]:
for key, value in list(path_params.items()):
if "{" + key + "}" in path:
convertor = param_convertors[key]
value = convertor.to_string(value)
path = path.replace("{" + key + "}", value)
path_params.pop(key)
return path, path_params
# Match parameters in URL paths, eg. '{param}', and '{param:int}'
PARAM_REGEX = re.compile("{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}")
def compile_path(
path: str,
) -> tuple[typing.Pattern[str], str, dict[str, Convertor[typing.Any]]]:
"""
Given a path string, like: "/{username:str}",
or a host string, like: "{subdomain}.mydomain.org", return a three-tuple
of (regex, format, {param_name:convertor}).
regex: "/(?P<username>[^/]+)"
format: "/{username}"
convertors: {"username": StringConvertor()}
"""
is_host = not path.startswith("/")
path_regex = "^"
path_format = ""
duplicated_params = set()
idx = 0
param_convertors = {}
for match in PARAM_REGEX.finditer(path):
param_name, convertor_type = match.groups("str")
convertor_type = convertor_type.lstrip(":")
assert convertor_type in CONVERTOR_TYPES, f"Unknown path convertor '{convertor_type}'"
convertor = CONVERTOR_TYPES[convertor_type]
path_regex += re.escape(path[idx : match.start()])
path_regex += f"(?P<{param_name}>{convertor.regex})"
path_format += path[idx : match.start()]
path_format += "{%s}" % param_name
if param_name in param_convertors:
duplicated_params.add(param_name)
param_convertors[param_name] = convertor
idx = match.end()
if duplicated_params:
names = ", ".join(sorted(duplicated_params))
ending = "s" if len(duplicated_params) > 1 else ""
raise ValueError(f"Duplicated param name{ending} {names} at path {path}")
if is_host:
# Align with `Host.matches()` behavior, which ignores port.
hostname = path[idx:].split(":")[0]
path_regex += re.escape(hostname) + "$"
else:
path_regex += re.escape(path[idx:]) + "$"
path_format += path[idx:]
return re.compile(path_regex), path_format, param_convertors
class BaseRoute:
def matches(self, scope: Scope) -> tuple[Match, Scope]:
raise NotImplementedError() # pragma: no cover
def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
raise NotImplementedError() # pragma: no cover
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
raise NotImplementedError() # pragma: no cover
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""
A route may be used in isolation as a stand-alone ASGI app.
This is a somewhat contrived case, as they'll almost always be used
within a Router, but could be useful for some tooling and minimal apps.
"""
match, child_scope = self.matches(scope)
if match == Match.NONE:
if scope["type"] == "http":
response = PlainTextResponse("Not Found", status_code=404)
await response(scope, receive, send)
elif scope["type"] == "websocket": # pragma: no branch
websocket_close = WebSocketClose()
await websocket_close(scope, receive, send)
return
scope.update(child_scope)
await self.handle(scope, receive, send)
class Route(BaseRoute):
def __init__(
self,
path: str,
endpoint: typing.Callable[..., typing.Any],
*,
methods: list[str] | None = None,
name: str | None = None,
include_in_schema: bool = True,
middleware: typing.Sequence[Middleware] | None = None,
) -> None:
assert path.startswith("/"), "Routed paths must start with '/'"
self.path = path
self.endpoint = endpoint
self.name = get_name(endpoint) if name is None else name
self.include_in_schema = include_in_schema
endpoint_handler = endpoint
while isinstance(endpoint_handler, functools.partial):
endpoint_handler = endpoint_handler.func
if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler):
# Endpoint is function or method. Treat it as `func(request) -> response`.
self.app = request_response(endpoint)
if methods is None:
methods = ["GET"]
else:
# Endpoint is a class. Treat it as ASGI.
self.app = endpoint
if middleware is not None:
for cls, args, kwargs in reversed(middleware):
self.app = cls(self.app, *args, **kwargs)
if methods is None:
self.methods = None
else:
self.methods = {method.upper() for method in methods}
if "GET" in self.methods:
self.methods.add("HEAD")
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
def matches(self, scope: Scope) -> tuple[Match, Scope]:
path_params: dict[str, typing.Any]
if scope["type"] == "http":
route_path = get_route_path(scope)
match = self.path_regex.match(route_path)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
matched_params[key] = self.param_convertors[key].convert(value)
path_params = dict(scope.get("path_params", {}))
path_params.update(matched_params)
child_scope = {"endpoint": self.endpoint, "path_params": path_params}
if self.methods and scope["method"] not in self.methods:
return Match.PARTIAL, child_scope
else:
return Match.FULL, child_scope
return Match.NONE, {}
def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
seen_params = set(path_params.keys())
expected_params = set(self.param_convertors.keys())
if name != self.name or seen_params != expected_params:
raise NoMatchFound(name, path_params)
path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
assert not remaining_params
return URLPath(path=path, protocol="http")
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.methods and scope["method"] not in self.methods:
headers = {"Allow": ", ".join(self.methods)}
if "app" in scope:
raise HTTPException(status_code=405, headers=headers)
else:
response = PlainTextResponse("Method Not Allowed", status_code=405, headers=headers)
await response(scope, receive, send)
else:
await self.app(scope, receive, send)
def __eq__(self, other: typing.Any) -> bool:
return (
isinstance(other, Route)
and self.path == other.path
and self.endpoint == other.endpoint
and self.methods == other.methods
)
def __repr__(self) -> str:
class_name = self.__class__.__name__
methods = sorted(self.methods or [])
path, name = self.path, self.name
return f"{class_name}(path={path!r}, name={name!r}, methods={methods!r})"
class WebSocketRoute(BaseRoute):
def __init__(
self,
path: str,
endpoint: typing.Callable[..., typing.Any],
*,
name: str | None = None,
middleware: typing.Sequence[Middleware] | None = None,
) -> None:
assert path.startswith("/"), "Routed paths must start with '/'"
self.path = path
self.endpoint = endpoint
self.name = get_name(endpoint) if name is None else name
endpoint_handler = endpoint
while isinstance(endpoint_handler, functools.partial):
endpoint_handler = endpoint_handler.func
if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler):
# Endpoint is function or method. Treat it as `func(websocket)`.
self.app = websocket_session(endpoint)
else:
# Endpoint is a class. Treat it as ASGI.
self.app = endpoint
if middleware is not None:
for cls, args, kwargs in reversed(middleware):
self.app = cls(self.app, *args, **kwargs)
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
def matches(self, scope: Scope) -> tuple[Match, Scope]:
path_params: dict[str, typing.Any]
if scope["type"] == "websocket":
route_path = get_route_path(scope)
match = self.path_regex.match(route_path)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
matched_params[key] = self.param_convertors[key].convert(value)
path_params = dict(scope.get("path_params", {}))
path_params.update(matched_params)
child_scope = {"endpoint": self.endpoint, "path_params": path_params}
return Match.FULL, child_scope
return Match.NONE, {}
def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
seen_params = set(path_params.keys())
expected_params = set(self.param_convertors.keys())
if name != self.name or seen_params != expected_params:
raise NoMatchFound(name, path_params)
path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
assert not remaining_params
return URLPath(path=path, protocol="websocket")
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
def __eq__(self, other: typing.Any) -> bool:
return isinstance(other, WebSocketRoute) and self.path == other.path and self.endpoint == other.endpoint
def __repr__(self) -> str:
return f"{self.__class__.__name__}(path={self.path!r}, name={self.name!r})"
class Mount(BaseRoute):
def __init__(
self,
path: str,
app: ASGIApp | None = None,
routes: typing.Sequence[BaseRoute] | None = None,
name: str | None = None,
*,
middleware: typing.Sequence[Middleware] | None = None,
) -> None:
assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
assert app is not None or routes is not None, "Either 'app=...', or 'routes=' must be specified"
self.path = path.rstrip("/")
if app is not None:
self._base_app: ASGIApp = app
else:
self._base_app = Router(routes=routes)
self.app = self._base_app
if middleware is not None:
for cls, args, kwargs in reversed(middleware):
self.app = cls(self.app, *args, **kwargs)
self.name = name
self.path_regex, self.path_format, self.param_convertors = compile_path(self.path + "/{path:path}")
@property
def routes(self) -> list[BaseRoute]:
return getattr(self._base_app, "routes", [])
def matches(self, scope: Scope) -> tuple[Match, Scope]:
path_params: dict[str, typing.Any]
if scope["type"] in ("http", "websocket"): # pragma: no branch
root_path = scope.get("root_path", "")
route_path = get_route_path(scope)
match = self.path_regex.match(route_path)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
matched_params[key] = self.param_convertors[key].convert(value)
remaining_path = "/" + matched_params.pop("path")
matched_path = route_path[: -len(remaining_path)]
path_params = dict(scope.get("path_params", {}))
path_params.update(matched_params)
child_scope = {
"path_params": path_params,
# app_root_path will only be set at the top level scope,
# initialized with the (optional) value of a root_path
# set above/before Starlette. And even though any
# mount will have its own child scope with its own respective
# root_path, the app_root_path will always be available in all
# the child scopes with the same top level value because it's
# set only once here with a default, any other child scope will
# just inherit that app_root_path default value stored in the
# scope. All this is needed to support Request.url_for(), as it
# uses the app_root_path to build the URL path.
"app_root_path": scope.get("app_root_path", root_path),
"root_path": root_path + matched_path,
"endpoint": self.app,
}
return Match.FULL, child_scope
return Match.NONE, {}
def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
if self.name is not None and name == self.name and "path" in path_params:
# 'name' matches "<mount_name>".
path_params["path"] = path_params["path"].lstrip("/")
path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
if not remaining_params:
return URLPath(path=path)
elif self.name is None or name.startswith(self.name + ":"):
if self.name is None:
# No mount name.
remaining_name = name
else:
# 'name' matches "<mount_name>:<child_name>".
remaining_name = name[len(self.name) + 1 :]
path_kwarg = path_params.get("path")
path_params["path"] = ""
path_prefix, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
if path_kwarg is not None:
remaining_params["path"] = path_kwarg
for route in self.routes or []:
try:
url = route.url_path_for(remaining_name, **remaining_params)
return URLPath(path=path_prefix.rstrip("/") + str(url), protocol=url.protocol)
except NoMatchFound:
pass
raise NoMatchFound(name, path_params)
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
def __eq__(self, other: typing.Any) -> bool:
return isinstance(other, Mount) and self.path == other.path and self.app == other.app
def __repr__(self) -> str:
class_name = self.__class__.__name__
name = self.name or ""
return f"{class_name}(path={self.path!r}, name={name!r}, app={self.app!r})"
class Host(BaseRoute):
def __init__(self, host: str, app: ASGIApp, name: str | None = None) -> None:
assert not host.startswith("/"), "Host must not start with '/'"
self.host = host
self.app = app
self.name = name
self.host_regex, self.host_format, self.param_convertors = compile_path(host)
@property
def routes(self) -> list[BaseRoute]:
return getattr(self.app, "routes", [])
def matches(self, scope: Scope) -> tuple[Match, Scope]:
if scope["type"] in ("http", "websocket"): # pragma:no branch
headers = Headers(scope=scope)
host = headers.get("host", "").split(":")[0]
match = self.host_regex.match(host)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
matched_params[key] = self.param_convertors[key].convert(value)
path_params = dict(scope.get("path_params", {}))
path_params.update(matched_params)
child_scope = {"path_params": path_params, "endpoint": self.app}
return Match.FULL, child_scope
return Match.NONE, {}
def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
if self.name is not None and name == self.name and "path" in path_params:
# 'name' matches "<mount_name>".
path = path_params.pop("path")
host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params)
if not remaining_params:
return URLPath(path=path, host=host)
elif self.name is None or name.startswith(self.name + ":"):
if self.name is None:
# No mount name.
remaining_name = name
else:
# 'name' matches "<mount_name>:<child_name>".
remaining_name = name[len(self.name) + 1 :]
host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params)
for route in self.routes or []:
try:
url = route.url_path_for(remaining_name, **remaining_params)
return URLPath(path=str(url), protocol=url.protocol, host=host)
except NoMatchFound:
pass
raise NoMatchFound(name, path_params)
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
def __eq__(self, other: typing.Any) -> bool:
return isinstance(other, Host) and self.host == other.host and self.app == other.app
def __repr__(self) -> str:
class_name = self.__class__.__name__
name = self.name or ""
return f"{class_name}(host={self.host!r}, name={name!r}, app={self.app!r})"
_T = typing.TypeVar("_T")
class _AsyncLiftContextManager(typing.AsyncContextManager[_T]):
def __init__(self, cm: typing.ContextManager[_T]):
self._cm = cm
async def __aenter__(self) -> _T:
return self._cm.__enter__()
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> bool | None:
return self._cm.__exit__(exc_type, exc_value, traceback)
def _wrap_gen_lifespan_context(
lifespan_context: typing.Callable[[typing.Any], typing.Generator[typing.Any, typing.Any, typing.Any]],
) -> typing.Callable[[typing.Any], typing.AsyncContextManager[typing.Any]]:
cmgr = contextlib.contextmanager(lifespan_context)
@functools.wraps(cmgr)
def wrapper(app: typing.Any) -> _AsyncLiftContextManager[typing.Any]:
return _AsyncLiftContextManager(cmgr(app))
return wrapper
class _DefaultLifespan:
def __init__(self, router: Router):
self._router = router
async def __aenter__(self) -> None:
await self._router.startup()
async def __aexit__(self, *exc_info: object) -> None:
await self._router.shutdown()
def __call__(self: _T, app: object) -> _T:
return self
class Router:
def __init__(
self,
routes: typing.Sequence[BaseRoute] | None = None,
redirect_slashes: bool = True,
default: ASGIApp | None = None,
on_startup: typing.Sequence[typing.Callable[[], typing.Any]] | None = None,
on_shutdown: typing.Sequence[typing.Callable[[], typing.Any]] | None = None,
# the generic to Lifespan[AppType] is the type of the top level application
# which the router cannot know statically, so we use typing.Any
lifespan: Lifespan[typing.Any] | None = None,
*,
middleware: typing.Sequence[Middleware] | None = None,
) -> None:
self.routes = [] if routes is None else list(routes)
self.redirect_slashes = redirect_slashes
self.default = self.not_found if default is None else default
self.on_startup = [] if on_startup is None else list(on_startup)
self.on_shutdown = [] if on_shutdown is None else list(on_shutdown)
if on_startup or on_shutdown:
warnings.warn(
"The on_startup and on_shutdown parameters are deprecated, and they "
"will be removed on version 1.0. Use the lifespan parameter instead. "
"See more about it on https://www.starlette.io/lifespan/.",
DeprecationWarning,
)
if lifespan:
warnings.warn(
"The `lifespan` parameter cannot be used with `on_startup` or "
"`on_shutdown`. Both `on_startup` and `on_shutdown` will be "
"ignored."
)
if lifespan is None:
self.lifespan_context: Lifespan[typing.Any] = _DefaultLifespan(self)
elif inspect.isasyncgenfunction(lifespan):
warnings.warn(
"async generator function lifespans are deprecated, "
"use an @contextlib.asynccontextmanager function instead",
DeprecationWarning,
)
self.lifespan_context = asynccontextmanager(
lifespan,
)
elif inspect.isgeneratorfunction(lifespan):
warnings.warn(
"generator function lifespans are deprecated, use an @contextlib.asynccontextmanager function instead",
DeprecationWarning,
)
self.lifespan_context = _wrap_gen_lifespan_context(
lifespan,
)
else:
self.lifespan_context = lifespan
self.middleware_stack = self.app
if middleware:
for cls, args, kwargs in reversed(middleware):
self.middleware_stack = cls(self.middleware_stack, *args, **kwargs)
async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "websocket":
websocket_close = WebSocketClose()
await websocket_close(scope, receive, send)
return
# If we're running inside a starlette application then raise an
# exception, so that the configurable exception handler can deal with
# returning the response. For plain ASGI apps, just return the response.
if "app" in scope:
raise HTTPException(status_code=404)
else:
response = PlainTextResponse("Not Found", status_code=404)
await response(scope, receive, send)
def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
for route in self.routes:
try:
return route.url_path_for(name, **path_params)
except NoMatchFound:
pass
raise NoMatchFound(name, path_params)
async def startup(self) -> None:
"""
Run any `.on_startup` event handlers.
"""
for handler in self.on_startup:
if is_async_callable(handler):
await handler()
else:
handler()
async def shutdown(self) -> None:
"""
Run any `.on_shutdown` event handlers.
"""
for handler in self.on_shutdown:
if is_async_callable(handler):
await handler()
else:
handler()
async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
"""
Handle ASGI lifespan messages, which allows us to manage application
startup and shutdown events.
"""
started = False
app: typing.Any = scope.get("app")
await receive()
try:
async with self.lifespan_context(app) as maybe_state:
if maybe_state is not None:
if "state" not in scope:
raise RuntimeError('The server does not support "state" in the lifespan scope.')
scope["state"].update(maybe_state)
await send({"type": "lifespan.startup.complete"})
started = True
await receive()
except BaseException:
exc_text = traceback.format_exc()
if started:
await send({"type": "lifespan.shutdown.failed", "message": exc_text})
else:
await send({"type": "lifespan.startup.failed", "message": exc_text})
raise
else:
await send({"type": "lifespan.shutdown.complete"})
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""
The main entry point to the Router class.
"""
await self.middleware_stack(scope, receive, send)
async def app(self, scope: Scope, receive: Receive, send: Send) -> None:
assert scope["type"] in ("http", "websocket", "lifespan")
if "router" not in scope:
scope["router"] = self
if scope["type"] == "lifespan":
await self.lifespan(scope, receive, send)
return
partial = None
for route in self.routes:
# Determine if any route matches the incoming scope,
# and hand over to the matching route if found.
match, child_scope = route.matches(scope)
if match == Match.FULL:
scope.update(child_scope)
await route.handle(scope, receive, send)
return
elif match == Match.PARTIAL and partial is None:
partial = route
partial_scope = child_scope
if partial is not None:
#  Handle partial matches. These are cases where an endpoint is
# able to handle the request, but is not a preferred option.
# We use this in particular to deal with "405 Method Not Allowed".
scope.update(partial_scope)
await partial.handle(scope, receive, send)
return
route_path = get_route_path(scope)
if scope["type"] == "http" and self.redirect_slashes and route_path != "/":
redirect_scope = dict(scope)
if route_path.endswith("/"):
redirect_scope["path"] = redirect_scope["path"].rstrip("/")
else:
redirect_scope["path"] = redirect_scope["path"] + "/"
for route in self.routes:
match, child_scope = route.matches(redirect_scope)
if match != Match.NONE:
redirect_url = URL(scope=redirect_scope)
response = RedirectResponse(url=str(redirect_url))
await response(scope, receive, send)
return
await self.default(scope, receive, send)
def __eq__(self, other: typing.Any) -> bool:
return isinstance(other, Router) and self.routes == other.routes
def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None: # pragma: no cover
route = Mount(path, app=app, name=name)
self.routes.append(route)
def host(self, host: str, app: ASGIApp, name: str | None = None) -> None: # pragma: no cover
route = Host(host, app=app, name=name)
self.routes.append(route)
def add_route(
self,
path: str,
endpoint: typing.Callable[[Request], typing.Awaitable[Response] | Response],
methods: list[str] | None = None,
name: str | None = None,
include_in_schema: bool = True,
) -> None: # pragma: no cover
route = Route(
path,
endpoint=endpoint,
methods=methods,
name=name,
include_in_schema=include_in_schema,
)
self.routes.append(route)
def add_websocket_route(
self,
path: str,
endpoint: typing.Callable[[WebSocket], typing.Awaitable[None]],
name: str | None = None,
) -> None: # pragma: no cover
route = WebSocketRoute(path, endpoint=endpoint, name=name)
self.routes.append(route)
def route(
self,
path: str,
methods: list[str] | None = None,
name: str | None = None,
include_in_schema: bool = True,
) -> typing.Callable: # type: ignore[type-arg]
"""
We no longer document this decorator style API, and its usage is discouraged.
Instead you should use the following approach:
>>> routes = [Route(path, endpoint=...), ...]
>>> app = Starlette(routes=routes)
"""
warnings.warn(
"The `route` decorator is deprecated, and will be removed in version 1.0.0."
"Refer to https://www.starlette.io/routing/#http-routing for the recommended approach.",
DeprecationWarning,
)
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
self.add_route(
path,
func,
methods=methods,
name=name,
include_in_schema=include_in_schema,
)
return func
return decorator
def websocket_route(self, path: str, name: str | None = None) -> typing.Callable: # type: ignore[type-arg]
"""
We no longer document this decorator style API, and its usage is discouraged.
Instead you should use the following approach:
>>> routes = [WebSocketRoute(path, endpoint=...), ...]
>>> app = Starlette(routes=routes)
"""
warnings.warn(
"The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. Refer to "
"https://www.starlette.io/routing/#websocket-routing for the recommended approach.",
DeprecationWarning,
)
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
self.add_websocket_route(path, func, name=name)
return func
return decorator
def add_event_handler(self, event_type: str, func: typing.Callable[[], typing.Any]) -> None: # pragma: no cover
assert event_type in ("startup", "shutdown")
if event_type == "startup":
self.on_startup.append(func)
else:
self.on_shutdown.append(func)
def on_event(self, event_type: str) -> typing.Callable: # type: ignore[type-arg]
warnings.warn(
"The `on_event` decorator is deprecated, and will be removed in version 1.0.0. "
"Refer to https://www.starlette.io/lifespan/ for recommended approach.",
DeprecationWarning,
)
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
self.add_event_handler(event_type, func)
return func
return decorator

View File

@ -0,0 +1,147 @@
from __future__ import annotations
import inspect
import re
import typing
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import BaseRoute, Host, Mount, Route
try:
import yaml
except ModuleNotFoundError: # pragma: no cover
yaml = None # type: ignore[assignment]
class OpenAPIResponse(Response):
media_type = "application/vnd.oai.openapi"
def render(self, content: typing.Any) -> bytes:
assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse."
assert isinstance(content, dict), "The schema passed to OpenAPIResponse should be a dictionary."
return yaml.dump(content, default_flow_style=False).encode("utf-8")
class EndpointInfo(typing.NamedTuple):
path: str
http_method: str
func: typing.Callable[..., typing.Any]
_remove_converter_pattern = re.compile(r":\w+}")
class BaseSchemaGenerator:
def get_schema(self, routes: list[BaseRoute]) -> dict[str, typing.Any]:
raise NotImplementedError() # pragma: no cover
def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]:
"""
Given the routes, yields the following information:
- path
eg: /users/
- http_method
one of 'get', 'post', 'put', 'patch', 'delete', 'options'
- func
method ready to extract the docstring
"""
endpoints_info: list[EndpointInfo] = []
for route in routes:
if isinstance(route, (Mount, Host)):
routes = route.routes or []
if isinstance(route, Mount):
path = self._remove_converter(route.path)
else:
path = ""
sub_endpoints = [
EndpointInfo(
path="".join((path, sub_endpoint.path)),
http_method=sub_endpoint.http_method,
func=sub_endpoint.func,
)
for sub_endpoint in self.get_endpoints(routes)
]
endpoints_info.extend(sub_endpoints)
elif not isinstance(route, Route) or not route.include_in_schema:
continue
elif inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint):
path = self._remove_converter(route.path)
for method in route.methods or ["GET"]:
if method == "HEAD":
continue
endpoints_info.append(EndpointInfo(path, method.lower(), route.endpoint))
else:
path = self._remove_converter(route.path)
for method in ["get", "post", "put", "patch", "delete", "options"]:
if not hasattr(route.endpoint, method):
continue
func = getattr(route.endpoint, method)
endpoints_info.append(EndpointInfo(path, method.lower(), func))
return endpoints_info
def _remove_converter(self, path: str) -> str:
"""
Remove the converter from the path.
For example, a route like this:
Route("/users/{id:int}", endpoint=get_user, methods=["GET"])
Should be represented as `/users/{id}` in the OpenAPI schema.
"""
return _remove_converter_pattern.sub("}", path)
def parse_docstring(self, func_or_method: typing.Callable[..., typing.Any]) -> dict[str, typing.Any]:
"""
Given a function, parse the docstring as YAML and return a dictionary of info.
"""
docstring = func_or_method.__doc__
if not docstring:
return {}
assert yaml is not None, "`pyyaml` must be installed to use parse_docstring."
# We support having regular docstrings before the schema
# definition. Here we return just the schema part from
# the docstring.
docstring = docstring.split("---")[-1]
parsed = yaml.safe_load(docstring)
if not isinstance(parsed, dict):
# A regular docstring (not yaml formatted) can return
# a simple string here, which wouldn't follow the schema.
return {}
return parsed
def OpenAPIResponse(self, request: Request) -> Response:
routes = request.app.routes
schema = self.get_schema(routes=routes)
return OpenAPIResponse(schema)
class SchemaGenerator(BaseSchemaGenerator):
def __init__(self, base_schema: dict[str, typing.Any]) -> None:
self.base_schema = base_schema
def get_schema(self, routes: list[BaseRoute]) -> dict[str, typing.Any]:
schema = dict(self.base_schema)
schema.setdefault("paths", {})
endpoints_info = self.get_endpoints(routes)
for endpoint in endpoints_info:
parsed = self.parse_docstring(endpoint.func)
if not parsed:
continue
if endpoint.path not in schema["paths"]:
schema["paths"][endpoint.path] = {}
schema["paths"][endpoint.path][endpoint.http_method] = parsed
return schema

View File

@ -0,0 +1,220 @@
from __future__ import annotations
import errno
import importlib.util
import os
import stat
import typing
from email.utils import parsedate
import anyio
import anyio.to_thread
from starlette._utils import get_route_path
from starlette.datastructures import URL, Headers
from starlette.exceptions import HTTPException
from starlette.responses import FileResponse, RedirectResponse, Response
from starlette.types import Receive, Scope, Send
PathLike = typing.Union[str, "os.PathLike[str]"]
class NotModifiedResponse(Response):
NOT_MODIFIED_HEADERS = (
"cache-control",
"content-location",
"date",
"etag",
"expires",
"vary",
)
def __init__(self, headers: Headers):
super().__init__(
status_code=304,
headers={name: value for name, value in headers.items() if name in self.NOT_MODIFIED_HEADERS},
)
class StaticFiles:
def __init__(
self,
*,
directory: PathLike | None = None,
packages: list[str | tuple[str, str]] | None = None,
html: bool = False,
check_dir: bool = True,
follow_symlink: bool = False,
) -> None:
self.directory = directory
self.packages = packages
self.all_directories = self.get_directories(directory, packages)
self.html = html
self.config_checked = False
self.follow_symlink = follow_symlink
if check_dir and directory is not None and not os.path.isdir(directory):
raise RuntimeError(f"Directory '{directory}' does not exist")
def get_directories(
self,
directory: PathLike | None = None,
packages: list[str | tuple[str, str]] | None = None,
) -> list[PathLike]:
"""
Given `directory` and `packages` arguments, return a list of all the
directories that should be used for serving static files from.
"""
directories = []
if directory is not None:
directories.append(directory)
for package in packages or []:
if isinstance(package, tuple):
package, statics_dir = package
else:
statics_dir = "statics"
spec = importlib.util.find_spec(package)
assert spec is not None, f"Package {package!r} could not be found."
assert spec.origin is not None, f"Package {package!r} could not be found."
package_directory = os.path.normpath(os.path.join(spec.origin, "..", statics_dir))
assert os.path.isdir(package_directory), (
f"Directory '{statics_dir!r}' in package {package!r} could not be found."
)
directories.append(package_directory)
return directories
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""
The ASGI entry point.
"""
assert scope["type"] == "http"
if not self.config_checked:
await self.check_config()
self.config_checked = True
path = self.get_path(scope)
response = await self.get_response(path, scope)
await response(scope, receive, send)
def get_path(self, scope: Scope) -> str:
"""
Given the ASGI scope, return the `path` string to serve up,
with OS specific path separators, and any '..', '.' components removed.
"""
route_path = get_route_path(scope)
return os.path.normpath(os.path.join(*route_path.split("/")))
async def get_response(self, path: str, scope: Scope) -> Response:
"""
Returns an HTTP response, given the incoming path, method and request headers.
"""
if scope["method"] not in ("GET", "HEAD"):
raise HTTPException(status_code=405)
try:
full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, path)
except PermissionError:
raise HTTPException(status_code=401)
except OSError as exc:
# Filename is too long, so it can't be a valid static file.
if exc.errno == errno.ENAMETOOLONG:
raise HTTPException(status_code=404)
raise exc
if stat_result and stat.S_ISREG(stat_result.st_mode):
# We have a static file to serve.
return self.file_response(full_path, stat_result, scope)
elif stat_result and stat.S_ISDIR(stat_result.st_mode) and self.html:
# We're in HTML mode, and have got a directory URL.
# Check if we have 'index.html' file to serve.
index_path = os.path.join(path, "index.html")
full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, index_path)
if stat_result is not None and stat.S_ISREG(stat_result.st_mode):
if not scope["path"].endswith("/"):
# Directory URLs should redirect to always end in "/".
url = URL(scope=scope)
url = url.replace(path=url.path + "/")
return RedirectResponse(url=url)
return self.file_response(full_path, stat_result, scope)
if self.html:
# Check for '404.html' if we're in HTML mode.
full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, "404.html")
if stat_result and stat.S_ISREG(stat_result.st_mode):
return FileResponse(full_path, stat_result=stat_result, status_code=404)
raise HTTPException(status_code=404)
def lookup_path(self, path: str) -> tuple[str, os.stat_result | None]:
for directory in self.all_directories:
joined_path = os.path.join(directory, path)
if self.follow_symlink:
full_path = os.path.abspath(joined_path)
directory = os.path.abspath(directory)
else:
full_path = os.path.realpath(joined_path)
directory = os.path.realpath(directory)
if os.path.commonpath([full_path, directory]) != str(directory):
# Don't allow misbehaving clients to break out of the static files directory.
continue
try:
return full_path, os.stat(full_path)
except (FileNotFoundError, NotADirectoryError):
continue
return "", None
def file_response(
self,
full_path: PathLike,
stat_result: os.stat_result,
scope: Scope,
status_code: int = 200,
) -> Response:
request_headers = Headers(scope=scope)
response = FileResponse(full_path, status_code=status_code, stat_result=stat_result)
if self.is_not_modified(response.headers, request_headers):
return NotModifiedResponse(response.headers)
return response
async def check_config(self) -> None:
"""
Perform a one-off configuration check that StaticFiles is actually
pointed at a directory, so that we can raise loud errors rather than
just returning 404 responses.
"""
if self.directory is None:
return
try:
stat_result = await anyio.to_thread.run_sync(os.stat, self.directory)
except FileNotFoundError:
raise RuntimeError(f"StaticFiles directory '{self.directory}' does not exist.")
if not (stat.S_ISDIR(stat_result.st_mode) or stat.S_ISLNK(stat_result.st_mode)):
raise RuntimeError(f"StaticFiles path '{self.directory}' is not a directory.")
def is_not_modified(self, response_headers: Headers, request_headers: Headers) -> bool:
"""
Given the request and response headers, return `True` if an HTTP
"Not Modified" response could be returned instead.
"""
try:
if_none_match = request_headers["if-none-match"]
etag = response_headers["etag"]
if etag in [tag.strip(" W/") for tag in if_none_match.split(",")]:
return True
except KeyError:
pass
try:
if_modified_since = parsedate(request_headers["if-modified-since"])
last_modified = parsedate(response_headers["last-modified"])
if if_modified_since is not None and last_modified is not None and if_modified_since >= last_modified:
return True
except KeyError:
pass
return False

View File

@ -0,0 +1,95 @@
"""
HTTP codes
See HTTP Status Code Registry:
https://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml
And RFC 2324 - https://tools.ietf.org/html/rfc2324
"""
from __future__ import annotations
HTTP_100_CONTINUE = 100
HTTP_101_SWITCHING_PROTOCOLS = 101
HTTP_102_PROCESSING = 102
HTTP_103_EARLY_HINTS = 103
HTTP_200_OK = 200
HTTP_201_CREATED = 201
HTTP_202_ACCEPTED = 202
HTTP_203_NON_AUTHORITATIVE_INFORMATION = 203
HTTP_204_NO_CONTENT = 204
HTTP_205_RESET_CONTENT = 205
HTTP_206_PARTIAL_CONTENT = 206
HTTP_207_MULTI_STATUS = 207
HTTP_208_ALREADY_REPORTED = 208
HTTP_226_IM_USED = 226
HTTP_300_MULTIPLE_CHOICES = 300
HTTP_301_MOVED_PERMANENTLY = 301
HTTP_302_FOUND = 302
HTTP_303_SEE_OTHER = 303
HTTP_304_NOT_MODIFIED = 304
HTTP_305_USE_PROXY = 305
HTTP_306_RESERVED = 306
HTTP_307_TEMPORARY_REDIRECT = 307
HTTP_308_PERMANENT_REDIRECT = 308
HTTP_400_BAD_REQUEST = 400
HTTP_401_UNAUTHORIZED = 401
HTTP_402_PAYMENT_REQUIRED = 402
HTTP_403_FORBIDDEN = 403
HTTP_404_NOT_FOUND = 404
HTTP_405_METHOD_NOT_ALLOWED = 405
HTTP_406_NOT_ACCEPTABLE = 406
HTTP_407_PROXY_AUTHENTICATION_REQUIRED = 407
HTTP_408_REQUEST_TIMEOUT = 408
HTTP_409_CONFLICT = 409
HTTP_410_GONE = 410
HTTP_411_LENGTH_REQUIRED = 411
HTTP_412_PRECONDITION_FAILED = 412
HTTP_413_REQUEST_ENTITY_TOO_LARGE = 413
HTTP_414_REQUEST_URI_TOO_LONG = 414
HTTP_415_UNSUPPORTED_MEDIA_TYPE = 415
HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE = 416
HTTP_417_EXPECTATION_FAILED = 417
HTTP_418_IM_A_TEAPOT = 418
HTTP_421_MISDIRECTED_REQUEST = 421
HTTP_422_UNPROCESSABLE_ENTITY = 422
HTTP_423_LOCKED = 423
HTTP_424_FAILED_DEPENDENCY = 424
HTTP_425_TOO_EARLY = 425
HTTP_426_UPGRADE_REQUIRED = 426
HTTP_428_PRECONDITION_REQUIRED = 428
HTTP_429_TOO_MANY_REQUESTS = 429
HTTP_431_REQUEST_HEADER_FIELDS_TOO_LARGE = 431
HTTP_451_UNAVAILABLE_FOR_LEGAL_REASONS = 451
HTTP_500_INTERNAL_SERVER_ERROR = 500
HTTP_501_NOT_IMPLEMENTED = 501
HTTP_502_BAD_GATEWAY = 502
HTTP_503_SERVICE_UNAVAILABLE = 503
HTTP_504_GATEWAY_TIMEOUT = 504
HTTP_505_HTTP_VERSION_NOT_SUPPORTED = 505
HTTP_506_VARIANT_ALSO_NEGOTIATES = 506
HTTP_507_INSUFFICIENT_STORAGE = 507
HTTP_508_LOOP_DETECTED = 508
HTTP_510_NOT_EXTENDED = 510
HTTP_511_NETWORK_AUTHENTICATION_REQUIRED = 511
"""
WebSocket codes
https://www.iana.org/assignments/websocket/websocket.xml#close-code-number
https://developer.mozilla.org/en-US/docs/Web/API/CloseEvent
"""
WS_1000_NORMAL_CLOSURE = 1000
WS_1001_GOING_AWAY = 1001
WS_1002_PROTOCOL_ERROR = 1002
WS_1003_UNSUPPORTED_DATA = 1003
WS_1005_NO_STATUS_RCVD = 1005
WS_1006_ABNORMAL_CLOSURE = 1006
WS_1007_INVALID_FRAME_PAYLOAD_DATA = 1007
WS_1008_POLICY_VIOLATION = 1008
WS_1009_MESSAGE_TOO_BIG = 1009
WS_1010_MANDATORY_EXT = 1010
WS_1011_INTERNAL_ERROR = 1011
WS_1012_SERVICE_RESTART = 1012
WS_1013_TRY_AGAIN_LATER = 1013
WS_1014_BAD_GATEWAY = 1014
WS_1015_TLS_HANDSHAKE = 1015

View File

@ -0,0 +1,216 @@
from __future__ import annotations
import typing
import warnings
from os import PathLike
from starlette.background import BackgroundTask
from starlette.datastructures import URL
from starlette.requests import Request
from starlette.responses import HTMLResponse
from starlette.types import Receive, Scope, Send
try:
import jinja2
# @contextfunction was renamed to @pass_context in Jinja 3.0, and was removed in 3.1
# hence we try to get pass_context (most installs will be >=3.1)
# and fall back to contextfunction,
# adding a type ignore for mypy to let us access an attribute that may not exist
if hasattr(jinja2, "pass_context"):
pass_context = jinja2.pass_context
else: # pragma: no cover
pass_context = jinja2.contextfunction # type: ignore[attr-defined]
except ModuleNotFoundError: # pragma: no cover
jinja2 = None # type: ignore[assignment]
class _TemplateResponse(HTMLResponse):
def __init__(
self,
template: typing.Any,
context: dict[str, typing.Any],
status_code: int = 200,
headers: typing.Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
):
self.template = template
self.context = context
content = template.render(context)
super().__init__(content, status_code, headers, media_type, background)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
request = self.context.get("request", {})
extensions = request.get("extensions", {})
if "http.response.debug" in extensions: # pragma: no branch
await send(
{
"type": "http.response.debug",
"info": {
"template": self.template,
"context": self.context,
},
}
)
await super().__call__(scope, receive, send)
class Jinja2Templates:
"""
templates = Jinja2Templates("templates")
return templates.TemplateResponse("index.html", {"request": request})
"""
@typing.overload
def __init__(
self,
directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]],
*,
context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None,
**env_options: typing.Any,
) -> None: ...
@typing.overload
def __init__(
self,
*,
env: jinja2.Environment,
context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None,
) -> None: ...
def __init__(
self,
directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]] | None = None,
*,
context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None,
env: jinja2.Environment | None = None,
**env_options: typing.Any,
) -> None:
if env_options:
warnings.warn(
"Extra environment options are deprecated. Use a preconfigured jinja2.Environment instead.",
DeprecationWarning,
)
assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates"
assert bool(directory) ^ bool(env), "either 'directory' or 'env' arguments must be passed"
self.context_processors = context_processors or []
if directory is not None:
self.env = self._create_env(directory, **env_options)
elif env is not None: # pragma: no branch
self.env = env
self._setup_env_defaults(self.env)
def _create_env(
self,
directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]],
**env_options: typing.Any,
) -> jinja2.Environment:
loader = jinja2.FileSystemLoader(directory)
env_options.setdefault("loader", loader)
env_options.setdefault("autoescape", True)
return jinja2.Environment(**env_options)
def _setup_env_defaults(self, env: jinja2.Environment) -> None:
@pass_context
def url_for(
context: dict[str, typing.Any],
name: str,
/,
**path_params: typing.Any,
) -> URL:
request: Request = context["request"]
return request.url_for(name, **path_params)
env.globals.setdefault("url_for", url_for)
def get_template(self, name: str) -> jinja2.Template:
return self.env.get_template(name)
@typing.overload
def TemplateResponse(
self,
request: Request,
name: str,
context: dict[str, typing.Any] | None = None,
status_code: int = 200,
headers: typing.Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
) -> _TemplateResponse: ...
@typing.overload
def TemplateResponse(
self,
name: str,
context: dict[str, typing.Any] | None = None,
status_code: int = 200,
headers: typing.Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
) -> _TemplateResponse:
# Deprecated usage
...
def TemplateResponse(self, *args: typing.Any, **kwargs: typing.Any) -> _TemplateResponse:
if args:
if isinstance(args[0], str): # the first argument is template name (old style)
warnings.warn(
"The `name` is not the first parameter anymore. "
"The first parameter should be the `Request` instance.\n"
'Replace `TemplateResponse(name, {"request": request})` by `TemplateResponse(request, name)`.',
DeprecationWarning,
)
name = args[0]
context = args[1] if len(args) > 1 else kwargs.get("context", {})
status_code = args[2] if len(args) > 2 else kwargs.get("status_code", 200)
headers = args[2] if len(args) > 2 else kwargs.get("headers")
media_type = args[3] if len(args) > 3 else kwargs.get("media_type")
background = args[4] if len(args) > 4 else kwargs.get("background")
if "request" not in context:
raise ValueError('context must include a "request" key')
request = context["request"]
else: # the first argument is a request instance (new style)
request = args[0]
name = args[1] if len(args) > 1 else kwargs["name"]
context = args[2] if len(args) > 2 else kwargs.get("context", {})
status_code = args[3] if len(args) > 3 else kwargs.get("status_code", 200)
headers = args[4] if len(args) > 4 else kwargs.get("headers")
media_type = args[5] if len(args) > 5 else kwargs.get("media_type")
background = args[6] if len(args) > 6 else kwargs.get("background")
else: # all arguments are kwargs
if "request" not in kwargs:
warnings.warn(
"The `TemplateResponse` now requires the `request` argument.\n"
'Replace `TemplateResponse(name, {"context": context})` by `TemplateResponse(request, name)`.',
DeprecationWarning,
)
if "request" not in kwargs.get("context", {}):
raise ValueError('context must include a "request" key')
context = kwargs.get("context", {})
request = kwargs.get("request", context.get("request"))
name = typing.cast(str, kwargs["name"])
status_code = kwargs.get("status_code", 200)
headers = kwargs.get("headers")
media_type = kwargs.get("media_type")
background = kwargs.get("background")
context.setdefault("request", request)
for context_processor in self.context_processors:
context.update(context_processor(request))
template = self.get_template(name)
return _TemplateResponse(
template,
context,
status_code=status_code,
headers=headers,
media_type=media_type,
background=background,
)

View File

@ -0,0 +1,731 @@
from __future__ import annotations
import contextlib
import inspect
import io
import json
import math
import sys
import typing
import warnings
from concurrent.futures import Future
from types import GeneratorType
from urllib.parse import unquote, urljoin
import anyio
import anyio.abc
import anyio.from_thread
from anyio.streams.stapled import StapledObjectStream
from starlette._utils import is_async_callable
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocketDisconnect
if sys.version_info >= (3, 10): # pragma: no cover
from typing import TypeGuard
else: # pragma: no cover
from typing_extensions import TypeGuard
try:
import httpx
except ModuleNotFoundError: # pragma: no cover
raise RuntimeError(
"The starlette.testclient module requires the httpx package to be installed.\n"
"You can install this with:\n"
" $ pip install httpx\n"
)
_PortalFactoryType = typing.Callable[[], typing.ContextManager[anyio.abc.BlockingPortal]]
ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]]
ASGI2App = typing.Callable[[Scope], ASGIInstance]
ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
_RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str], bytes]]
def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]:
if inspect.isclass(app):
return hasattr(app, "__await__")
return is_async_callable(app)
class _WrapASGI2:
"""
Provide an ASGI3 interface onto an ASGI2 app.
"""
def __init__(self, app: ASGI2App) -> None:
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
instance = self.app(scope)
await instance(receive, send)
class _AsyncBackend(typing.TypedDict):
backend: str
backend_options: dict[str, typing.Any]
class _Upgrade(Exception):
def __init__(self, session: WebSocketTestSession) -> None:
self.session = session
class WebSocketDenialResponse( # type: ignore[misc]
httpx.Response,
WebSocketDisconnect,
):
"""
A special case of `WebSocketDisconnect`, raised in the `TestClient` if the
`WebSocket` is closed before being accepted with a `send_denial_response()`.
"""
class WebSocketTestSession:
def __init__(
self,
app: ASGI3App,
scope: Scope,
portal_factory: _PortalFactoryType,
) -> None:
self.app = app
self.scope = scope
self.accepted_subprotocol = None
self.portal_factory = portal_factory
self.extra_headers = None
def __enter__(self) -> WebSocketTestSession:
with contextlib.ExitStack() as stack:
self.portal = portal = stack.enter_context(self.portal_factory())
fut, cs = portal.start_task(self._run)
stack.callback(fut.result)
stack.callback(portal.call, cs.cancel)
self.send({"type": "websocket.connect"})
message = self.receive()
self._raise_on_close(message)
self.accepted_subprotocol = message.get("subprotocol", None)
self.extra_headers = message.get("headers", None)
stack.callback(self.close, 1000)
self.exit_stack = stack.pop_all()
return self
def __exit__(self, *args: typing.Any) -> bool | None:
return self.exit_stack.__exit__(*args)
async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None:
"""
The sub-thread in which the websocket session runs.
"""
send: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
send_tx, send_rx = send
receive: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
receive_tx, receive_rx = receive
with send_tx, send_rx, receive_tx, receive_rx, anyio.CancelScope() as cs:
self._receive_tx = receive_tx
self._send_rx = send_rx
task_status.started(cs)
await self.app(self.scope, receive_rx.receive, send_tx.send)
# wait for cs.cancel to be called before closing streams
await anyio.sleep_forever()
def _raise_on_close(self, message: Message) -> None:
if message["type"] == "websocket.close":
raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", ""))
elif message["type"] == "websocket.http.response.start":
status_code: int = message["status"]
headers: list[tuple[bytes, bytes]] = message["headers"]
body: list[bytes] = []
while True:
message = self.receive()
assert message["type"] == "websocket.http.response.body"
body.append(message["body"])
if not message.get("more_body", False):
break
raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body))
def send(self, message: Message) -> None:
self.portal.call(self._receive_tx.send, message)
def send_text(self, data: str) -> None:
self.send({"type": "websocket.receive", "text": data})
def send_bytes(self, data: bytes) -> None:
self.send({"type": "websocket.receive", "bytes": data})
def send_json(self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text") -> None:
text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
if mode == "text":
self.send({"type": "websocket.receive", "text": text})
else:
self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})
def close(self, code: int = 1000, reason: str | None = None) -> None:
self.send({"type": "websocket.disconnect", "code": code, "reason": reason})
def receive(self) -> Message:
return self.portal.call(self._send_rx.receive)
def receive_text(self) -> str:
message = self.receive()
self._raise_on_close(message)
return typing.cast(str, message["text"])
def receive_bytes(self) -> bytes:
message = self.receive()
self._raise_on_close(message)
return typing.cast(bytes, message["bytes"])
def receive_json(self, mode: typing.Literal["text", "binary"] = "text") -> typing.Any:
message = self.receive()
self._raise_on_close(message)
if mode == "text":
text = message["text"]
else:
text = message["bytes"].decode("utf-8")
return json.loads(text)
class _TestClientTransport(httpx.BaseTransport):
def __init__(
self,
app: ASGI3App,
portal_factory: _PortalFactoryType,
raise_server_exceptions: bool = True,
root_path: str = "",
*,
client: tuple[str, int],
app_state: dict[str, typing.Any],
) -> None:
self.app = app
self.raise_server_exceptions = raise_server_exceptions
self.root_path = root_path
self.portal_factory = portal_factory
self.app_state = app_state
self.client = client
def handle_request(self, request: httpx.Request) -> httpx.Response:
scheme = request.url.scheme
netloc = request.url.netloc.decode(encoding="ascii")
path = request.url.path
raw_path = request.url.raw_path
query = request.url.query.decode(encoding="ascii")
default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]
if ":" in netloc:
host, port_string = netloc.split(":", 1)
port = int(port_string)
else:
host = netloc
port = default_port
# Include the 'host' header.
if "host" in request.headers:
headers: list[tuple[bytes, bytes]] = []
elif port == default_port: # pragma: no cover
headers = [(b"host", host.encode())]
else: # pragma: no cover
headers = [(b"host", (f"{host}:{port}").encode())]
# Include other request headers.
headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()]
scope: dict[str, typing.Any]
if scheme in {"ws", "wss"}:
subprotocol = request.headers.get("sec-websocket-protocol", None)
if subprotocol is None:
subprotocols: typing.Sequence[str] = []
else:
subprotocols = [value.strip() for value in subprotocol.split(",")]
scope = {
"type": "websocket",
"path": unquote(path),
"raw_path": raw_path.split(b"?", 1)[0],
"root_path": self.root_path,
"scheme": scheme,
"query_string": query.encode(),
"headers": headers,
"client": self.client,
"server": [host, port],
"subprotocols": subprotocols,
"state": self.app_state.copy(),
"extensions": {"websocket.http.response": {}},
}
session = WebSocketTestSession(self.app, scope, self.portal_factory)
raise _Upgrade(session)
scope = {
"type": "http",
"http_version": "1.1",
"method": request.method,
"path": unquote(path),
"raw_path": raw_path.split(b"?", 1)[0],
"root_path": self.root_path,
"scheme": scheme,
"query_string": query.encode(),
"headers": headers,
"client": self.client,
"server": [host, port],
"extensions": {"http.response.debug": {}},
"state": self.app_state.copy(),
}
request_complete = False
response_started = False
response_complete: anyio.Event
raw_kwargs: dict[str, typing.Any] = {"stream": io.BytesIO()}
template = None
context = None
async def receive() -> Message:
nonlocal request_complete
if request_complete:
if not response_complete.is_set():
await response_complete.wait()
return {"type": "http.disconnect"}
body = request.read()
if isinstance(body, str):
body_bytes: bytes = body.encode("utf-8") # pragma: no cover
elif body is None:
body_bytes = b"" # pragma: no cover
elif isinstance(body, GeneratorType):
try: # pragma: no cover
chunk = body.send(None)
if isinstance(chunk, str):
chunk = chunk.encode("utf-8")
return {"type": "http.request", "body": chunk, "more_body": True}
except StopIteration: # pragma: no cover
request_complete = True
return {"type": "http.request", "body": b""}
else:
body_bytes = body
request_complete = True
return {"type": "http.request", "body": body_bytes}
async def send(message: Message) -> None:
nonlocal raw_kwargs, response_started, template, context
if message["type"] == "http.response.start":
assert not response_started, 'Received multiple "http.response.start" messages.'
raw_kwargs["status_code"] = message["status"]
raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])]
response_started = True
elif message["type"] == "http.response.body":
assert response_started, 'Received "http.response.body" without "http.response.start".'
assert not response_complete.is_set(), 'Received "http.response.body" after response completed.'
body = message.get("body", b"")
more_body = message.get("more_body", False)
if request.method != "HEAD":
raw_kwargs["stream"].write(body)
if not more_body:
raw_kwargs["stream"].seek(0)
response_complete.set()
elif message["type"] == "http.response.debug":
template = message["info"]["template"]
context = message["info"]["context"]
try:
with self.portal_factory() as portal:
response_complete = portal.call(anyio.Event)
portal.call(self.app, scope, receive, send)
except BaseException as exc:
if self.raise_server_exceptions:
raise exc
if self.raise_server_exceptions:
assert response_started, "TestClient did not receive any response."
elif not response_started:
raw_kwargs = {
"status_code": 500,
"headers": [],
"stream": io.BytesIO(),
}
raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].read())
response = httpx.Response(**raw_kwargs, request=request)
if template is not None:
response.template = template # type: ignore[attr-defined]
response.context = context # type: ignore[attr-defined]
return response
class TestClient(httpx.Client):
__test__ = False
task: Future[None]
portal: anyio.abc.BlockingPortal | None = None
def __init__(
self,
app: ASGIApp,
base_url: str = "http://testserver",
raise_server_exceptions: bool = True,
root_path: str = "",
backend: typing.Literal["asyncio", "trio"] = "asyncio",
backend_options: dict[str, typing.Any] | None = None,
cookies: httpx._types.CookieTypes | None = None,
headers: dict[str, str] | None = None,
follow_redirects: bool = True,
client: tuple[str, int] = ("testclient", 50000),
) -> None:
self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {})
if _is_asgi3(app):
asgi_app = app
else:
app = typing.cast(ASGI2App, app) # type: ignore[assignment]
asgi_app = _WrapASGI2(app) # type: ignore[arg-type]
self.app = asgi_app
self.app_state: dict[str, typing.Any] = {}
transport = _TestClientTransport(
self.app,
portal_factory=self._portal_factory,
raise_server_exceptions=raise_server_exceptions,
root_path=root_path,
app_state=self.app_state,
client=client,
)
if headers is None:
headers = {}
headers.setdefault("user-agent", "testclient")
super().__init__(
base_url=base_url,
headers=headers,
transport=transport,
follow_redirects=follow_redirects,
cookies=cookies,
)
@contextlib.contextmanager
def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
if self.portal is not None:
yield self.portal
else:
with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal:
yield portal
def request( # type: ignore[override]
self,
method: str,
url: httpx._types.URLTypes,
*,
content: httpx._types.RequestContent | None = None,
data: _RequestData | None = None,
files: httpx._types.RequestFiles | None = None,
json: typing.Any = None,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
if timeout is not httpx.USE_CLIENT_DEFAULT:
warnings.warn(
"You should not use the 'timeout' argument with the TestClient. "
"See https://github.com/encode/starlette/issues/1108 for more information.",
DeprecationWarning,
)
url = self._merge_url(url)
return super().request(
method,
url,
content=content,
data=data,
files=files,
json=json,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=follow_redirects,
timeout=timeout,
extensions=extensions,
)
def get( # type: ignore[override]
self,
url: httpx._types.URLTypes,
*,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
return super().get(
url,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=follow_redirects,
timeout=timeout,
extensions=extensions,
)
def options( # type: ignore[override]
self,
url: httpx._types.URLTypes,
*,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
return super().options(
url,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=follow_redirects,
timeout=timeout,
extensions=extensions,
)
def head( # type: ignore[override]
self,
url: httpx._types.URLTypes,
*,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
return super().head(
url,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=follow_redirects,
timeout=timeout,
extensions=extensions,
)
def post( # type: ignore[override]
self,
url: httpx._types.URLTypes,
*,
content: httpx._types.RequestContent | None = None,
data: _RequestData | None = None,
files: httpx._types.RequestFiles | None = None,
json: typing.Any = None,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
return super().post(
url,
content=content,
data=data,
files=files,
json=json,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=follow_redirects,
timeout=timeout,
extensions=extensions,
)
def put( # type: ignore[override]
self,
url: httpx._types.URLTypes,
*,
content: httpx._types.RequestContent | None = None,
data: _RequestData | None = None,
files: httpx._types.RequestFiles | None = None,
json: typing.Any = None,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
return super().put(
url,
content=content,
data=data,
files=files,
json=json,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=follow_redirects,
timeout=timeout,
extensions=extensions,
)
def patch( # type: ignore[override]
self,
url: httpx._types.URLTypes,
*,
content: httpx._types.RequestContent | None = None,
data: _RequestData | None = None,
files: httpx._types.RequestFiles | None = None,
json: typing.Any = None,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
return super().patch(
url,
content=content,
data=data,
files=files,
json=json,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=follow_redirects,
timeout=timeout,
extensions=extensions,
)
def delete( # type: ignore[override]
self,
url: httpx._types.URLTypes,
*,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
return super().delete(
url,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=follow_redirects,
timeout=timeout,
extensions=extensions,
)
def websocket_connect(
self,
url: str,
subprotocols: typing.Sequence[str] | None = None,
**kwargs: typing.Any,
) -> WebSocketTestSession:
url = urljoin("ws://testserver", url)
headers = kwargs.get("headers", {})
headers.setdefault("connection", "upgrade")
headers.setdefault("sec-websocket-key", "testserver==")
headers.setdefault("sec-websocket-version", "13")
if subprotocols is not None:
headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
kwargs["headers"] = headers
try:
super().request("GET", url, **kwargs)
except _Upgrade as exc:
session = exc.session
else:
raise RuntimeError("Expected WebSocket upgrade") # pragma: no cover
return session
def __enter__(self) -> TestClient:
with contextlib.ExitStack() as stack:
self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend))
@stack.callback
def reset_portal() -> None:
self.portal = None
send: anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any] | None] = (
anyio.create_memory_object_stream(math.inf)
)
receive: anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any]] = (
anyio.create_memory_object_stream(math.inf)
)
for channel in (*send, *receive):
stack.callback(channel.close)
self.stream_send = StapledObjectStream(*send)
self.stream_receive = StapledObjectStream(*receive)
self.task = portal.start_task_soon(self.lifespan)
portal.call(self.wait_startup)
@stack.callback
def wait_shutdown() -> None:
portal.call(self.wait_shutdown)
self.exit_stack = stack.pop_all()
return self
def __exit__(self, *args: typing.Any) -> None:
self.exit_stack.close()
async def lifespan(self) -> None:
scope = {"type": "lifespan", "state": self.app_state}
try:
await self.app(scope, self.stream_receive.receive, self.stream_send.send)
finally:
await self.stream_send.send(None)
async def wait_startup(self) -> None:
await self.stream_receive.send({"type": "lifespan.startup"})
async def receive() -> typing.Any:
message = await self.stream_send.receive()
if message is None:
self.task.result()
return message
message = await receive()
assert message["type"] in (
"lifespan.startup.complete",
"lifespan.startup.failed",
)
if message["type"] == "lifespan.startup.failed":
await receive()
async def wait_shutdown(self) -> None:
async def receive() -> typing.Any:
message = await self.stream_send.receive()
if message is None:
self.task.result()
return message
await self.stream_receive.send({"type": "lifespan.shutdown"})
message = await receive()
assert message["type"] in (
"lifespan.shutdown.complete",
"lifespan.shutdown.failed",
)
if message["type"] == "lifespan.shutdown.failed":
await receive()

View File

@ -0,0 +1,24 @@
import typing
if typing.TYPE_CHECKING:
from starlette.requests import Request
from starlette.responses import Response
from starlette.websockets import WebSocket
AppType = typing.TypeVar("AppType")
Scope = typing.MutableMapping[str, typing.Any]
Message = typing.MutableMapping[str, typing.Any]
Receive = typing.Callable[[], typing.Awaitable[Message]]
Send = typing.Callable[[Message], typing.Awaitable[None]]
ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
StatelessLifespan = typing.Callable[[AppType], typing.AsyncContextManager[None]]
StatefulLifespan = typing.Callable[[AppType], typing.AsyncContextManager[typing.Mapping[str, typing.Any]]]
Lifespan = typing.Union[StatelessLifespan[AppType], StatefulLifespan[AppType]]
HTTPExceptionHandler = typing.Callable[["Request", Exception], "Response | typing.Awaitable[Response]"]
WebSocketExceptionHandler = typing.Callable[["WebSocket", Exception], typing.Awaitable[None]]
ExceptionHandler = typing.Union[HTTPExceptionHandler, WebSocketExceptionHandler]

View File

@ -0,0 +1,195 @@
from __future__ import annotations
import enum
import json
import typing
from starlette.requests import HTTPConnection
from starlette.responses import Response
from starlette.types import Message, Receive, Scope, Send
class WebSocketState(enum.Enum):
CONNECTING = 0
CONNECTED = 1
DISCONNECTED = 2
RESPONSE = 3
class WebSocketDisconnect(Exception):
def __init__(self, code: int = 1000, reason: str | None = None) -> None:
self.code = code
self.reason = reason or ""
class WebSocket(HTTPConnection):
def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
super().__init__(scope)
assert scope["type"] == "websocket"
self._receive = receive
self._send = send
self.client_state = WebSocketState.CONNECTING
self.application_state = WebSocketState.CONNECTING
async def receive(self) -> Message:
"""
Receive ASGI websocket messages, ensuring valid state transitions.
"""
if self.client_state == WebSocketState.CONNECTING:
message = await self._receive()
message_type = message["type"]
if message_type != "websocket.connect":
raise RuntimeError(f'Expected ASGI message "websocket.connect", but got {message_type!r}')
self.client_state = WebSocketState.CONNECTED
return message
elif self.client_state == WebSocketState.CONNECTED:
message = await self._receive()
message_type = message["type"]
if message_type not in {"websocket.receive", "websocket.disconnect"}:
raise RuntimeError(
f'Expected ASGI message "websocket.receive" or "websocket.disconnect", but got {message_type!r}'
)
if message_type == "websocket.disconnect":
self.client_state = WebSocketState.DISCONNECTED
return message
else:
raise RuntimeError('Cannot call "receive" once a disconnect message has been received.')
async def send(self, message: Message) -> None:
"""
Send ASGI websocket messages, ensuring valid state transitions.
"""
if self.application_state == WebSocketState.CONNECTING:
message_type = message["type"]
if message_type not in {"websocket.accept", "websocket.close", "websocket.http.response.start"}:
raise RuntimeError(
'Expected ASGI message "websocket.accept", "websocket.close" or "websocket.http.response.start", '
f"but got {message_type!r}"
)
if message_type == "websocket.close":
self.application_state = WebSocketState.DISCONNECTED
elif message_type == "websocket.http.response.start":
self.application_state = WebSocketState.RESPONSE
else:
self.application_state = WebSocketState.CONNECTED
await self._send(message)
elif self.application_state == WebSocketState.CONNECTED:
message_type = message["type"]
if message_type not in {"websocket.send", "websocket.close"}:
raise RuntimeError(
f'Expected ASGI message "websocket.send" or "websocket.close", but got {message_type!r}'
)
if message_type == "websocket.close":
self.application_state = WebSocketState.DISCONNECTED
try:
await self._send(message)
except OSError:
self.application_state = WebSocketState.DISCONNECTED
raise WebSocketDisconnect(code=1006)
elif self.application_state == WebSocketState.RESPONSE:
message_type = message["type"]
if message_type != "websocket.http.response.body":
raise RuntimeError(f'Expected ASGI message "websocket.http.response.body", but got {message_type!r}')
if not message.get("more_body", False):
self.application_state = WebSocketState.DISCONNECTED
await self._send(message)
else:
raise RuntimeError('Cannot call "send" once a close message has been sent.')
async def accept(
self,
subprotocol: str | None = None,
headers: typing.Iterable[tuple[bytes, bytes]] | None = None,
) -> None:
headers = headers or []
if self.client_state == WebSocketState.CONNECTING: # pragma: no branch
# If we haven't yet seen the 'connect' message, then wait for it first.
await self.receive()
await self.send({"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers})
def _raise_on_disconnect(self, message: Message) -> None:
if message["type"] == "websocket.disconnect":
raise WebSocketDisconnect(message["code"], message.get("reason"))
async def receive_text(self) -> str:
if self.application_state != WebSocketState.CONNECTED:
raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
message = await self.receive()
self._raise_on_disconnect(message)
return typing.cast(str, message["text"])
async def receive_bytes(self) -> bytes:
if self.application_state != WebSocketState.CONNECTED:
raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
message = await self.receive()
self._raise_on_disconnect(message)
return typing.cast(bytes, message["bytes"])
async def receive_json(self, mode: str = "text") -> typing.Any:
if mode not in {"text", "binary"}:
raise RuntimeError('The "mode" argument should be "text" or "binary".')
if self.application_state != WebSocketState.CONNECTED:
raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
message = await self.receive()
self._raise_on_disconnect(message)
if mode == "text":
text = message["text"]
else:
text = message["bytes"].decode("utf-8")
return json.loads(text)
async def iter_text(self) -> typing.AsyncIterator[str]:
try:
while True:
yield await self.receive_text()
except WebSocketDisconnect:
pass
async def iter_bytes(self) -> typing.AsyncIterator[bytes]:
try:
while True:
yield await self.receive_bytes()
except WebSocketDisconnect:
pass
async def iter_json(self) -> typing.AsyncIterator[typing.Any]:
try:
while True:
yield await self.receive_json()
except WebSocketDisconnect:
pass
async def send_text(self, data: str) -> None:
await self.send({"type": "websocket.send", "text": data})
async def send_bytes(self, data: bytes) -> None:
await self.send({"type": "websocket.send", "bytes": data})
async def send_json(self, data: typing.Any, mode: str = "text") -> None:
if mode not in {"text", "binary"}:
raise RuntimeError('The "mode" argument should be "text" or "binary".')
text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
if mode == "text":
await self.send({"type": "websocket.send", "text": text})
else:
await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})
async def close(self, code: int = 1000, reason: str | None = None) -> None:
await self.send({"type": "websocket.close", "code": code, "reason": reason or ""})
async def send_denial_response(self, response: Response) -> None:
if "websocket.http.response" in self.scope.get("extensions", {}):
await response(self.scope, self.receive, self.send)
else:
raise RuntimeError("The server doesn't support the Websocket Denial Response extension.")
class WebSocketClose:
def __init__(self, code: int = 1000, reason: str | None = None) -> None:
self.code = code
self.reason = reason or ""
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await send({"type": "websocket.close", "code": self.code, "reason": self.reason})