Update 2025-04-13_16:25:39
This commit is contained in:
5
venv/lib/python3.11/site-packages/uvicorn/__init__.py
Normal file
5
venv/lib/python3.11/site-packages/uvicorn/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.main import Server, main, run
|
||||
|
||||
__version__ = "0.34.0"
|
||||
__all__ = ["main", "run", "Config", "Server"]
|
4
venv/lib/python3.11/site-packages/uvicorn/__main__.py
Normal file
4
venv/lib/python3.11/site-packages/uvicorn/__main__.py
Normal file
@ -0,0 +1,4 @@
|
||||
import uvicorn
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.main()
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
84
venv/lib/python3.11/site-packages/uvicorn/_subprocess.py
Normal file
84
venv/lib/python3.11/site-packages/uvicorn/_subprocess.py
Normal file
@ -0,0 +1,84 @@
|
||||
"""
|
||||
Some light wrappers around Python's multiprocessing, to deal with cleanly
|
||||
starting child processes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import multiprocessing
|
||||
import os
|
||||
import sys
|
||||
from multiprocessing.context import SpawnProcess
|
||||
from socket import socket
|
||||
from typing import Callable
|
||||
|
||||
from uvicorn.config import Config
|
||||
|
||||
multiprocessing.allow_connection_pickling()
|
||||
spawn = multiprocessing.get_context("spawn")
|
||||
|
||||
|
||||
def get_subprocess(
|
||||
config: Config,
|
||||
target: Callable[..., None],
|
||||
sockets: list[socket],
|
||||
) -> SpawnProcess:
|
||||
"""
|
||||
Called in the parent process, to instantiate a new child process instance.
|
||||
The child is not yet started at this point.
|
||||
|
||||
* config - The Uvicorn configuration instance.
|
||||
* target - A callable that accepts a list of sockets. In practice this will
|
||||
be the `Server.run()` method.
|
||||
* sockets - A list of sockets to pass to the server. Sockets are bound once
|
||||
by the parent process, and then passed to the child processes.
|
||||
"""
|
||||
# We pass across the stdin fileno, and reopen it in the child process.
|
||||
# This is required for some debugging environments.
|
||||
try:
|
||||
stdin_fileno = sys.stdin.fileno()
|
||||
# The `sys.stdin` can be `None`, see https://docs.python.org/3/library/sys.html#sys.__stdin__.
|
||||
except (AttributeError, OSError):
|
||||
stdin_fileno = None
|
||||
|
||||
kwargs = {
|
||||
"config": config,
|
||||
"target": target,
|
||||
"sockets": sockets,
|
||||
"stdin_fileno": stdin_fileno,
|
||||
}
|
||||
|
||||
return spawn.Process(target=subprocess_started, kwargs=kwargs)
|
||||
|
||||
|
||||
def subprocess_started(
|
||||
config: Config,
|
||||
target: Callable[..., None],
|
||||
sockets: list[socket],
|
||||
stdin_fileno: int | None,
|
||||
) -> None:
|
||||
"""
|
||||
Called when the child process starts.
|
||||
|
||||
* config - The Uvicorn configuration instance.
|
||||
* target - A callable that accepts a list of sockets. In practice this will
|
||||
be the `Server.run()` method.
|
||||
* sockets - A list of sockets to pass to the server. Sockets are bound once
|
||||
by the parent process, and then passed to the child processes.
|
||||
* stdin_fileno - The file number of sys.stdin, so that it can be reattached
|
||||
to the child process.
|
||||
"""
|
||||
# Re-open stdin.
|
||||
if stdin_fileno is not None:
|
||||
sys.stdin = os.fdopen(stdin_fileno) # pragma: full coverage
|
||||
|
||||
# Logging needs to be setup again for each child.
|
||||
config.configure_logging()
|
||||
|
||||
try:
|
||||
# Now we can call into `Server.run(sockets=sockets)`
|
||||
target(sockets=sockets)
|
||||
except KeyboardInterrupt: # pragma: no cover
|
||||
# supress the exception to avoid a traceback from subprocess.Popen
|
||||
# the parent already expects us to end, so no vital information is lost
|
||||
pass
|
281
venv/lib/python3.11/site-packages/uvicorn/_types.py
Normal file
281
venv/lib/python3.11/site-packages/uvicorn/_types.py
Normal file
@ -0,0 +1,281 @@
|
||||
"""
|
||||
Copyright (c) Django Software Foundation and individual contributors.
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification,
|
||||
are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice,
|
||||
this list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of Django nor the names of its contributors may be used
|
||||
to endorse or promote products derived from this software without
|
||||
specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
|
||||
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
|
||||
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import types
|
||||
from collections.abc import Awaitable, Iterable, MutableMapping
|
||||
from typing import Any, Callable, Literal, Optional, Protocol, TypedDict, Union
|
||||
|
||||
if sys.version_info >= (3, 11): # pragma: py-lt-311
|
||||
from typing import NotRequired
|
||||
else: # pragma: py-gte-311
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
# WSGI
|
||||
Environ = MutableMapping[str, Any]
|
||||
ExcInfo = tuple[type[BaseException], BaseException, Optional[types.TracebackType]]
|
||||
StartResponse = Callable[[str, Iterable[tuple[str, str]], Optional[ExcInfo]], None]
|
||||
WSGIApp = Callable[[Environ, StartResponse], Union[Iterable[bytes], BaseException]]
|
||||
|
||||
|
||||
# ASGI
|
||||
class ASGIVersions(TypedDict):
|
||||
spec_version: str
|
||||
version: Literal["2.0"] | Literal["3.0"]
|
||||
|
||||
|
||||
class HTTPScope(TypedDict):
|
||||
type: Literal["http"]
|
||||
asgi: ASGIVersions
|
||||
http_version: str
|
||||
method: str
|
||||
scheme: str
|
||||
path: str
|
||||
raw_path: bytes
|
||||
query_string: bytes
|
||||
root_path: str
|
||||
headers: Iterable[tuple[bytes, bytes]]
|
||||
client: tuple[str, int] | None
|
||||
server: tuple[str, int | None] | None
|
||||
state: NotRequired[dict[str, Any]]
|
||||
extensions: NotRequired[dict[str, dict[object, object]]]
|
||||
|
||||
|
||||
class WebSocketScope(TypedDict):
|
||||
type: Literal["websocket"]
|
||||
asgi: ASGIVersions
|
||||
http_version: str
|
||||
scheme: str
|
||||
path: str
|
||||
raw_path: bytes
|
||||
query_string: bytes
|
||||
root_path: str
|
||||
headers: Iterable[tuple[bytes, bytes]]
|
||||
client: tuple[str, int] | None
|
||||
server: tuple[str, int | None] | None
|
||||
subprotocols: Iterable[str]
|
||||
state: NotRequired[dict[str, Any]]
|
||||
extensions: NotRequired[dict[str, dict[object, object]]]
|
||||
|
||||
|
||||
class LifespanScope(TypedDict):
|
||||
type: Literal["lifespan"]
|
||||
asgi: ASGIVersions
|
||||
state: NotRequired[dict[str, Any]]
|
||||
|
||||
|
||||
WWWScope = Union[HTTPScope, WebSocketScope]
|
||||
Scope = Union[HTTPScope, WebSocketScope, LifespanScope]
|
||||
|
||||
|
||||
class HTTPRequestEvent(TypedDict):
|
||||
type: Literal["http.request"]
|
||||
body: bytes
|
||||
more_body: bool
|
||||
|
||||
|
||||
class HTTPResponseDebugEvent(TypedDict):
|
||||
type: Literal["http.response.debug"]
|
||||
info: dict[str, object]
|
||||
|
||||
|
||||
class HTTPResponseStartEvent(TypedDict):
|
||||
type: Literal["http.response.start"]
|
||||
status: int
|
||||
headers: NotRequired[Iterable[tuple[bytes, bytes]]]
|
||||
trailers: NotRequired[bool]
|
||||
|
||||
|
||||
class HTTPResponseBodyEvent(TypedDict):
|
||||
type: Literal["http.response.body"]
|
||||
body: bytes
|
||||
more_body: NotRequired[bool]
|
||||
|
||||
|
||||
class HTTPResponseTrailersEvent(TypedDict):
|
||||
type: Literal["http.response.trailers"]
|
||||
headers: Iterable[tuple[bytes, bytes]]
|
||||
more_trailers: bool
|
||||
|
||||
|
||||
class HTTPServerPushEvent(TypedDict):
|
||||
type: Literal["http.response.push"]
|
||||
path: str
|
||||
headers: Iterable[tuple[bytes, bytes]]
|
||||
|
||||
|
||||
class HTTPDisconnectEvent(TypedDict):
|
||||
type: Literal["http.disconnect"]
|
||||
|
||||
|
||||
class WebSocketConnectEvent(TypedDict):
|
||||
type: Literal["websocket.connect"]
|
||||
|
||||
|
||||
class WebSocketAcceptEvent(TypedDict):
|
||||
type: Literal["websocket.accept"]
|
||||
subprotocol: NotRequired[str | None]
|
||||
headers: NotRequired[Iterable[tuple[bytes, bytes]]]
|
||||
|
||||
|
||||
class _WebSocketReceiveEventBytes(TypedDict):
|
||||
type: Literal["websocket.receive"]
|
||||
bytes: bytes
|
||||
text: NotRequired[None]
|
||||
|
||||
|
||||
class _WebSocketReceiveEventText(TypedDict):
|
||||
type: Literal["websocket.receive"]
|
||||
bytes: NotRequired[None]
|
||||
text: str
|
||||
|
||||
|
||||
WebSocketReceiveEvent = Union[_WebSocketReceiveEventBytes, _WebSocketReceiveEventText]
|
||||
|
||||
|
||||
class _WebSocketSendEventBytes(TypedDict):
|
||||
type: Literal["websocket.send"]
|
||||
bytes: bytes
|
||||
text: NotRequired[None]
|
||||
|
||||
|
||||
class _WebSocketSendEventText(TypedDict):
|
||||
type: Literal["websocket.send"]
|
||||
bytes: NotRequired[None]
|
||||
text: str
|
||||
|
||||
|
||||
WebSocketSendEvent = Union[_WebSocketSendEventBytes, _WebSocketSendEventText]
|
||||
|
||||
|
||||
class WebSocketResponseStartEvent(TypedDict):
|
||||
type: Literal["websocket.http.response.start"]
|
||||
status: int
|
||||
headers: Iterable[tuple[bytes, bytes]]
|
||||
|
||||
|
||||
class WebSocketResponseBodyEvent(TypedDict):
|
||||
type: Literal["websocket.http.response.body"]
|
||||
body: bytes
|
||||
more_body: NotRequired[bool]
|
||||
|
||||
|
||||
class WebSocketDisconnectEvent(TypedDict):
|
||||
type: Literal["websocket.disconnect"]
|
||||
code: int
|
||||
reason: NotRequired[str | None]
|
||||
|
||||
|
||||
class WebSocketCloseEvent(TypedDict):
|
||||
type: Literal["websocket.close"]
|
||||
code: NotRequired[int]
|
||||
reason: NotRequired[str | None]
|
||||
|
||||
|
||||
class LifespanStartupEvent(TypedDict):
|
||||
type: Literal["lifespan.startup"]
|
||||
|
||||
|
||||
class LifespanShutdownEvent(TypedDict):
|
||||
type: Literal["lifespan.shutdown"]
|
||||
|
||||
|
||||
class LifespanStartupCompleteEvent(TypedDict):
|
||||
type: Literal["lifespan.startup.complete"]
|
||||
|
||||
|
||||
class LifespanStartupFailedEvent(TypedDict):
|
||||
type: Literal["lifespan.startup.failed"]
|
||||
message: str
|
||||
|
||||
|
||||
class LifespanShutdownCompleteEvent(TypedDict):
|
||||
type: Literal["lifespan.shutdown.complete"]
|
||||
|
||||
|
||||
class LifespanShutdownFailedEvent(TypedDict):
|
||||
type: Literal["lifespan.shutdown.failed"]
|
||||
message: str
|
||||
|
||||
|
||||
WebSocketEvent = Union[WebSocketReceiveEvent, WebSocketDisconnectEvent, WebSocketConnectEvent]
|
||||
|
||||
|
||||
ASGIReceiveEvent = Union[
|
||||
HTTPRequestEvent,
|
||||
HTTPDisconnectEvent,
|
||||
WebSocketConnectEvent,
|
||||
WebSocketReceiveEvent,
|
||||
WebSocketDisconnectEvent,
|
||||
LifespanStartupEvent,
|
||||
LifespanShutdownEvent,
|
||||
]
|
||||
|
||||
|
||||
ASGISendEvent = Union[
|
||||
HTTPResponseStartEvent,
|
||||
HTTPResponseBodyEvent,
|
||||
HTTPResponseTrailersEvent,
|
||||
HTTPServerPushEvent,
|
||||
HTTPDisconnectEvent,
|
||||
WebSocketAcceptEvent,
|
||||
WebSocketSendEvent,
|
||||
WebSocketResponseStartEvent,
|
||||
WebSocketResponseBodyEvent,
|
||||
WebSocketCloseEvent,
|
||||
LifespanStartupCompleteEvent,
|
||||
LifespanStartupFailedEvent,
|
||||
LifespanShutdownCompleteEvent,
|
||||
LifespanShutdownFailedEvent,
|
||||
]
|
||||
|
||||
|
||||
ASGIReceiveCallable = Callable[[], Awaitable[ASGIReceiveEvent]]
|
||||
ASGISendCallable = Callable[[ASGISendEvent], Awaitable[None]]
|
||||
|
||||
|
||||
class ASGI2Protocol(Protocol):
|
||||
def __init__(self, scope: Scope) -> None: ... # pragma: no cover
|
||||
|
||||
async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: ... # pragma: no cover
|
||||
|
||||
|
||||
ASGI2Application = type[ASGI2Protocol]
|
||||
ASGI3Application = Callable[
|
||||
[
|
||||
Scope,
|
||||
ASGIReceiveCallable,
|
||||
ASGISendCallable,
|
||||
],
|
||||
Awaitable[None],
|
||||
]
|
||||
ASGIApplication = Union[ASGI2Application, ASGI3Application]
|
530
venv/lib/python3.11/site-packages/uvicorn/config.py
Normal file
530
venv/lib/python3.11/site-packages/uvicorn/config.py
Normal file
@ -0,0 +1,530 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import logging.config
|
||||
import os
|
||||
import socket
|
||||
import ssl
|
||||
import sys
|
||||
from collections.abc import Awaitable
|
||||
from configparser import RawConfigParser
|
||||
from pathlib import Path
|
||||
from typing import IO, Any, Callable, Literal
|
||||
|
||||
import click
|
||||
|
||||
from uvicorn._types import ASGIApplication
|
||||
from uvicorn.importer import ImportFromStringError, import_from_string
|
||||
from uvicorn.logging import TRACE_LOG_LEVEL
|
||||
from uvicorn.middleware.asgi2 import ASGI2Middleware
|
||||
from uvicorn.middleware.message_logger import MessageLoggerMiddleware
|
||||
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
|
||||
from uvicorn.middleware.wsgi import WSGIMiddleware
|
||||
|
||||
HTTPProtocolType = Literal["auto", "h11", "httptools"]
|
||||
WSProtocolType = Literal["auto", "none", "websockets", "wsproto"]
|
||||
LifespanType = Literal["auto", "on", "off"]
|
||||
LoopSetupType = Literal["none", "auto", "asyncio", "uvloop"]
|
||||
InterfaceType = Literal["auto", "asgi3", "asgi2", "wsgi"]
|
||||
|
||||
LOG_LEVELS: dict[str, int] = {
|
||||
"critical": logging.CRITICAL,
|
||||
"error": logging.ERROR,
|
||||
"warning": logging.WARNING,
|
||||
"info": logging.INFO,
|
||||
"debug": logging.DEBUG,
|
||||
"trace": TRACE_LOG_LEVEL,
|
||||
}
|
||||
HTTP_PROTOCOLS: dict[HTTPProtocolType, str] = {
|
||||
"auto": "uvicorn.protocols.http.auto:AutoHTTPProtocol",
|
||||
"h11": "uvicorn.protocols.http.h11_impl:H11Protocol",
|
||||
"httptools": "uvicorn.protocols.http.httptools_impl:HttpToolsProtocol",
|
||||
}
|
||||
WS_PROTOCOLS: dict[WSProtocolType, str | None] = {
|
||||
"auto": "uvicorn.protocols.websockets.auto:AutoWebSocketsProtocol",
|
||||
"none": None,
|
||||
"websockets": "uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol",
|
||||
"wsproto": "uvicorn.protocols.websockets.wsproto_impl:WSProtocol",
|
||||
}
|
||||
LIFESPAN: dict[LifespanType, str] = {
|
||||
"auto": "uvicorn.lifespan.on:LifespanOn",
|
||||
"on": "uvicorn.lifespan.on:LifespanOn",
|
||||
"off": "uvicorn.lifespan.off:LifespanOff",
|
||||
}
|
||||
LOOP_SETUPS: dict[LoopSetupType, str | None] = {
|
||||
"none": None,
|
||||
"auto": "uvicorn.loops.auto:auto_loop_setup",
|
||||
"asyncio": "uvicorn.loops.asyncio:asyncio_setup",
|
||||
"uvloop": "uvicorn.loops.uvloop:uvloop_setup",
|
||||
}
|
||||
INTERFACES: list[InterfaceType] = ["auto", "asgi3", "asgi2", "wsgi"]
|
||||
|
||||
SSL_PROTOCOL_VERSION: int = ssl.PROTOCOL_TLS_SERVER
|
||||
|
||||
LOGGING_CONFIG: dict[str, Any] = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"default": {
|
||||
"()": "uvicorn.logging.DefaultFormatter",
|
||||
"fmt": "%(levelprefix)s %(message)s",
|
||||
"use_colors": None,
|
||||
},
|
||||
"access": {
|
||||
"()": "uvicorn.logging.AccessFormatter",
|
||||
"fmt": '%(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s', # noqa: E501
|
||||
},
|
||||
},
|
||||
"handlers": {
|
||||
"default": {
|
||||
"formatter": "default",
|
||||
"class": "logging.StreamHandler",
|
||||
"stream": "ext://sys.stderr",
|
||||
},
|
||||
"access": {
|
||||
"formatter": "access",
|
||||
"class": "logging.StreamHandler",
|
||||
"stream": "ext://sys.stdout",
|
||||
},
|
||||
},
|
||||
"loggers": {
|
||||
"uvicorn": {"handlers": ["default"], "level": "INFO", "propagate": False},
|
||||
"uvicorn.error": {"level": "INFO"},
|
||||
"uvicorn.access": {"handlers": ["access"], "level": "INFO", "propagate": False},
|
||||
},
|
||||
}
|
||||
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
|
||||
|
||||
def create_ssl_context(
|
||||
certfile: str | os.PathLike[str],
|
||||
keyfile: str | os.PathLike[str] | None,
|
||||
password: str | None,
|
||||
ssl_version: int,
|
||||
cert_reqs: int,
|
||||
ca_certs: str | os.PathLike[str] | None,
|
||||
ciphers: str | None,
|
||||
) -> ssl.SSLContext:
|
||||
ctx = ssl.SSLContext(ssl_version)
|
||||
get_password = (lambda: password) if password else None
|
||||
ctx.load_cert_chain(certfile, keyfile, get_password)
|
||||
ctx.verify_mode = ssl.VerifyMode(cert_reqs)
|
||||
if ca_certs:
|
||||
ctx.load_verify_locations(ca_certs)
|
||||
if ciphers:
|
||||
ctx.set_ciphers(ciphers)
|
||||
return ctx
|
||||
|
||||
|
||||
def is_dir(path: Path) -> bool:
|
||||
try:
|
||||
if not path.is_absolute():
|
||||
path = path.resolve()
|
||||
return path.is_dir()
|
||||
except OSError: # pragma: full coverage
|
||||
return False
|
||||
|
||||
|
||||
def resolve_reload_patterns(patterns_list: list[str], directories_list: list[str]) -> tuple[list[str], list[Path]]:
|
||||
directories: list[Path] = list(set(map(Path, directories_list.copy())))
|
||||
patterns: list[str] = patterns_list.copy()
|
||||
|
||||
current_working_directory = Path.cwd()
|
||||
for pattern in patterns_list:
|
||||
# Special case for the .* pattern, otherwise this would only match
|
||||
# hidden directories which is probably undesired
|
||||
if pattern == ".*":
|
||||
continue # pragma: py-darwin
|
||||
patterns.append(pattern)
|
||||
if is_dir(Path(pattern)):
|
||||
directories.append(Path(pattern))
|
||||
else:
|
||||
for match in current_working_directory.glob(pattern):
|
||||
if is_dir(match):
|
||||
directories.append(match)
|
||||
|
||||
directories = list(set(directories))
|
||||
directories = list(map(Path, directories))
|
||||
directories = list(map(lambda x: x.resolve(), directories))
|
||||
directories = list({reload_path for reload_path in directories if is_dir(reload_path)})
|
||||
|
||||
children = []
|
||||
for j in range(len(directories)):
|
||||
for k in range(j + 1, len(directories)): # pragma: full coverage
|
||||
if directories[j] in directories[k].parents:
|
||||
children.append(directories[k])
|
||||
elif directories[k] in directories[j].parents:
|
||||
children.append(directories[j])
|
||||
|
||||
directories = list(set(directories).difference(set(children)))
|
||||
|
||||
return list(set(patterns)), directories
|
||||
|
||||
|
||||
def _normalize_dirs(dirs: list[str] | str | None) -> list[str]:
|
||||
if dirs is None:
|
||||
return []
|
||||
if isinstance(dirs, str):
|
||||
return [dirs]
|
||||
return list(set(dirs))
|
||||
|
||||
|
||||
class Config:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApplication | Callable[..., Any] | str,
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 8000,
|
||||
uds: str | None = None,
|
||||
fd: int | None = None,
|
||||
loop: LoopSetupType = "auto",
|
||||
http: type[asyncio.Protocol] | HTTPProtocolType = "auto",
|
||||
ws: type[asyncio.Protocol] | WSProtocolType = "auto",
|
||||
ws_max_size: int = 16 * 1024 * 1024,
|
||||
ws_max_queue: int = 32,
|
||||
ws_ping_interval: float | None = 20.0,
|
||||
ws_ping_timeout: float | None = 20.0,
|
||||
ws_per_message_deflate: bool = True,
|
||||
lifespan: LifespanType = "auto",
|
||||
env_file: str | os.PathLike[str] | None = None,
|
||||
log_config: dict[str, Any] | str | RawConfigParser | IO[Any] | None = LOGGING_CONFIG,
|
||||
log_level: str | int | None = None,
|
||||
access_log: bool = True,
|
||||
use_colors: bool | None = None,
|
||||
interface: InterfaceType = "auto",
|
||||
reload: bool = False,
|
||||
reload_dirs: list[str] | str | None = None,
|
||||
reload_delay: float = 0.25,
|
||||
reload_includes: list[str] | str | None = None,
|
||||
reload_excludes: list[str] | str | None = None,
|
||||
workers: int | None = None,
|
||||
proxy_headers: bool = True,
|
||||
server_header: bool = True,
|
||||
date_header: bool = True,
|
||||
forwarded_allow_ips: list[str] | str | None = None,
|
||||
root_path: str = "",
|
||||
limit_concurrency: int | None = None,
|
||||
limit_max_requests: int | None = None,
|
||||
backlog: int = 2048,
|
||||
timeout_keep_alive: int = 5,
|
||||
timeout_notify: int = 30,
|
||||
timeout_graceful_shutdown: int | None = None,
|
||||
callback_notify: Callable[..., Awaitable[None]] | None = None,
|
||||
ssl_keyfile: str | os.PathLike[str] | None = None,
|
||||
ssl_certfile: str | os.PathLike[str] | None = None,
|
||||
ssl_keyfile_password: str | None = None,
|
||||
ssl_version: int = SSL_PROTOCOL_VERSION,
|
||||
ssl_cert_reqs: int = ssl.CERT_NONE,
|
||||
ssl_ca_certs: str | None = None,
|
||||
ssl_ciphers: str = "TLSv1",
|
||||
headers: list[tuple[str, str]] | None = None,
|
||||
factory: bool = False,
|
||||
h11_max_incomplete_event_size: int | None = None,
|
||||
):
|
||||
self.app = app
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.uds = uds
|
||||
self.fd = fd
|
||||
self.loop = loop
|
||||
self.http = http
|
||||
self.ws = ws
|
||||
self.ws_max_size = ws_max_size
|
||||
self.ws_max_queue = ws_max_queue
|
||||
self.ws_ping_interval = ws_ping_interval
|
||||
self.ws_ping_timeout = ws_ping_timeout
|
||||
self.ws_per_message_deflate = ws_per_message_deflate
|
||||
self.lifespan = lifespan
|
||||
self.log_config = log_config
|
||||
self.log_level = log_level
|
||||
self.access_log = access_log
|
||||
self.use_colors = use_colors
|
||||
self.interface = interface
|
||||
self.reload = reload
|
||||
self.reload_delay = reload_delay
|
||||
self.workers = workers or 1
|
||||
self.proxy_headers = proxy_headers
|
||||
self.server_header = server_header
|
||||
self.date_header = date_header
|
||||
self.root_path = root_path
|
||||
self.limit_concurrency = limit_concurrency
|
||||
self.limit_max_requests = limit_max_requests
|
||||
self.backlog = backlog
|
||||
self.timeout_keep_alive = timeout_keep_alive
|
||||
self.timeout_notify = timeout_notify
|
||||
self.timeout_graceful_shutdown = timeout_graceful_shutdown
|
||||
self.callback_notify = callback_notify
|
||||
self.ssl_keyfile = ssl_keyfile
|
||||
self.ssl_certfile = ssl_certfile
|
||||
self.ssl_keyfile_password = ssl_keyfile_password
|
||||
self.ssl_version = ssl_version
|
||||
self.ssl_cert_reqs = ssl_cert_reqs
|
||||
self.ssl_ca_certs = ssl_ca_certs
|
||||
self.ssl_ciphers = ssl_ciphers
|
||||
self.headers: list[tuple[str, str]] = headers or []
|
||||
self.encoded_headers: list[tuple[bytes, bytes]] = []
|
||||
self.factory = factory
|
||||
self.h11_max_incomplete_event_size = h11_max_incomplete_event_size
|
||||
|
||||
self.loaded = False
|
||||
self.configure_logging()
|
||||
|
||||
self.reload_dirs: list[Path] = []
|
||||
self.reload_dirs_excludes: list[Path] = []
|
||||
self.reload_includes: list[str] = []
|
||||
self.reload_excludes: list[str] = []
|
||||
|
||||
if (reload_dirs or reload_includes or reload_excludes) and not self.should_reload:
|
||||
logger.warning(
|
||||
"Current configuration will not reload as not all conditions are met, " "please refer to documentation."
|
||||
)
|
||||
|
||||
if self.should_reload:
|
||||
reload_dirs = _normalize_dirs(reload_dirs)
|
||||
reload_includes = _normalize_dirs(reload_includes)
|
||||
reload_excludes = _normalize_dirs(reload_excludes)
|
||||
|
||||
self.reload_includes, self.reload_dirs = resolve_reload_patterns(reload_includes, reload_dirs)
|
||||
|
||||
self.reload_excludes, self.reload_dirs_excludes = resolve_reload_patterns(reload_excludes, [])
|
||||
|
||||
reload_dirs_tmp = self.reload_dirs.copy()
|
||||
|
||||
for directory in self.reload_dirs_excludes:
|
||||
for reload_directory in reload_dirs_tmp:
|
||||
if directory == reload_directory or directory in reload_directory.parents:
|
||||
try:
|
||||
self.reload_dirs.remove(reload_directory)
|
||||
except ValueError: # pragma: full coverage
|
||||
pass
|
||||
|
||||
for pattern in self.reload_excludes:
|
||||
if pattern in self.reload_includes:
|
||||
self.reload_includes.remove(pattern) # pragma: full coverage
|
||||
|
||||
if not self.reload_dirs:
|
||||
if reload_dirs:
|
||||
logger.warning(
|
||||
"Provided reload directories %s did not contain valid "
|
||||
+ "directories, watching current working directory.",
|
||||
reload_dirs,
|
||||
)
|
||||
self.reload_dirs = [Path(os.getcwd())]
|
||||
|
||||
logger.info(
|
||||
"Will watch for changes in these directories: %s",
|
||||
sorted(list(map(str, self.reload_dirs))),
|
||||
)
|
||||
|
||||
if env_file is not None:
|
||||
from dotenv import load_dotenv
|
||||
|
||||
logger.info("Loading environment from '%s'", env_file)
|
||||
load_dotenv(dotenv_path=env_file)
|
||||
|
||||
if workers is None and "WEB_CONCURRENCY" in os.environ:
|
||||
self.workers = int(os.environ["WEB_CONCURRENCY"])
|
||||
|
||||
self.forwarded_allow_ips: list[str] | str
|
||||
if forwarded_allow_ips is None:
|
||||
self.forwarded_allow_ips = os.environ.get("FORWARDED_ALLOW_IPS", "127.0.0.1")
|
||||
else:
|
||||
self.forwarded_allow_ips = forwarded_allow_ips # pragma: full coverage
|
||||
|
||||
if self.reload and self.workers > 1:
|
||||
logger.warning('"workers" flag is ignored when reloading is enabled.')
|
||||
|
||||
@property
|
||||
def asgi_version(self) -> Literal["2.0", "3.0"]:
|
||||
mapping: dict[str, Literal["2.0", "3.0"]] = {
|
||||
"asgi2": "2.0",
|
||||
"asgi3": "3.0",
|
||||
"wsgi": "3.0",
|
||||
}
|
||||
return mapping[self.interface]
|
||||
|
||||
@property
|
||||
def is_ssl(self) -> bool:
|
||||
return bool(self.ssl_keyfile or self.ssl_certfile)
|
||||
|
||||
@property
|
||||
def use_subprocess(self) -> bool:
|
||||
return bool(self.reload or self.workers > 1)
|
||||
|
||||
def configure_logging(self) -> None:
|
||||
logging.addLevelName(TRACE_LOG_LEVEL, "TRACE")
|
||||
|
||||
if self.log_config is not None:
|
||||
if isinstance(self.log_config, dict):
|
||||
if self.use_colors in (True, False):
|
||||
self.log_config["formatters"]["default"]["use_colors"] = self.use_colors
|
||||
self.log_config["formatters"]["access"]["use_colors"] = self.use_colors
|
||||
logging.config.dictConfig(self.log_config)
|
||||
elif isinstance(self.log_config, str) and self.log_config.endswith(".json"):
|
||||
with open(self.log_config) as file:
|
||||
loaded_config = json.load(file)
|
||||
logging.config.dictConfig(loaded_config)
|
||||
elif isinstance(self.log_config, str) and self.log_config.endswith((".yaml", ".yml")):
|
||||
# Install the PyYAML package or the uvicorn[standard] optional
|
||||
# dependencies to enable this functionality.
|
||||
import yaml
|
||||
|
||||
with open(self.log_config) as file:
|
||||
loaded_config = yaml.safe_load(file)
|
||||
logging.config.dictConfig(loaded_config)
|
||||
else:
|
||||
# See the note about fileConfig() here:
|
||||
# https://docs.python.org/3/library/logging.config.html#configuration-file-format
|
||||
logging.config.fileConfig(self.log_config, disable_existing_loggers=False)
|
||||
|
||||
if self.log_level is not None:
|
||||
if isinstance(self.log_level, str):
|
||||
log_level = LOG_LEVELS[self.log_level]
|
||||
else:
|
||||
log_level = self.log_level
|
||||
logging.getLogger("uvicorn.error").setLevel(log_level)
|
||||
logging.getLogger("uvicorn.access").setLevel(log_level)
|
||||
logging.getLogger("uvicorn.asgi").setLevel(log_level)
|
||||
if self.access_log is False:
|
||||
logging.getLogger("uvicorn.access").handlers = []
|
||||
logging.getLogger("uvicorn.access").propagate = False
|
||||
|
||||
def load(self) -> None:
|
||||
assert not self.loaded
|
||||
|
||||
if self.is_ssl:
|
||||
assert self.ssl_certfile
|
||||
self.ssl: ssl.SSLContext | None = create_ssl_context(
|
||||
keyfile=self.ssl_keyfile,
|
||||
certfile=self.ssl_certfile,
|
||||
password=self.ssl_keyfile_password,
|
||||
ssl_version=self.ssl_version,
|
||||
cert_reqs=self.ssl_cert_reqs,
|
||||
ca_certs=self.ssl_ca_certs,
|
||||
ciphers=self.ssl_ciphers,
|
||||
)
|
||||
else:
|
||||
self.ssl = None
|
||||
|
||||
encoded_headers = [(key.lower().encode("latin1"), value.encode("latin1")) for key, value in self.headers]
|
||||
self.encoded_headers = (
|
||||
[(b"server", b"uvicorn")] + encoded_headers
|
||||
if b"server" not in dict(encoded_headers) and self.server_header
|
||||
else encoded_headers
|
||||
)
|
||||
|
||||
if isinstance(self.http, str):
|
||||
http_protocol_class = import_from_string(HTTP_PROTOCOLS[self.http])
|
||||
self.http_protocol_class: type[asyncio.Protocol] = http_protocol_class
|
||||
else:
|
||||
self.http_protocol_class = self.http
|
||||
|
||||
if isinstance(self.ws, str):
|
||||
ws_protocol_class = import_from_string(WS_PROTOCOLS[self.ws])
|
||||
self.ws_protocol_class: type[asyncio.Protocol] | None = ws_protocol_class
|
||||
else:
|
||||
self.ws_protocol_class = self.ws
|
||||
|
||||
self.lifespan_class = import_from_string(LIFESPAN[self.lifespan])
|
||||
|
||||
try:
|
||||
self.loaded_app = import_from_string(self.app)
|
||||
except ImportFromStringError as exc:
|
||||
logger.error("Error loading ASGI app. %s" % exc)
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
self.loaded_app = self.loaded_app()
|
||||
except TypeError as exc:
|
||||
if self.factory:
|
||||
logger.error("Error loading ASGI app factory: %s", exc)
|
||||
sys.exit(1)
|
||||
else:
|
||||
if not self.factory:
|
||||
logger.warning(
|
||||
"ASGI app factory detected. Using it, " "but please consider setting the --factory flag explicitly."
|
||||
)
|
||||
|
||||
if self.interface == "auto":
|
||||
if inspect.isclass(self.loaded_app):
|
||||
use_asgi_3 = hasattr(self.loaded_app, "__await__")
|
||||
elif inspect.isfunction(self.loaded_app):
|
||||
use_asgi_3 = asyncio.iscoroutinefunction(self.loaded_app)
|
||||
else:
|
||||
call = getattr(self.loaded_app, "__call__", None)
|
||||
use_asgi_3 = asyncio.iscoroutinefunction(call)
|
||||
self.interface = "asgi3" if use_asgi_3 else "asgi2"
|
||||
|
||||
if self.interface == "wsgi":
|
||||
self.loaded_app = WSGIMiddleware(self.loaded_app)
|
||||
self.ws_protocol_class = None
|
||||
elif self.interface == "asgi2":
|
||||
self.loaded_app = ASGI2Middleware(self.loaded_app)
|
||||
|
||||
if logger.getEffectiveLevel() <= TRACE_LOG_LEVEL:
|
||||
self.loaded_app = MessageLoggerMiddleware(self.loaded_app)
|
||||
if self.proxy_headers:
|
||||
self.loaded_app = ProxyHeadersMiddleware(self.loaded_app, trusted_hosts=self.forwarded_allow_ips)
|
||||
|
||||
self.loaded = True
|
||||
|
||||
def setup_event_loop(self) -> None:
|
||||
loop_setup: Callable | None = import_from_string(LOOP_SETUPS[self.loop])
|
||||
if loop_setup is not None:
|
||||
loop_setup(use_subprocess=self.use_subprocess)
|
||||
|
||||
def bind_socket(self) -> socket.socket:
|
||||
logger_args: list[str | int]
|
||||
if self.uds: # pragma: py-win32
|
||||
path = self.uds
|
||||
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
try:
|
||||
sock.bind(path)
|
||||
uds_perms = 0o666
|
||||
os.chmod(self.uds, uds_perms)
|
||||
except OSError as exc: # pragma: full coverage
|
||||
logger.error(exc)
|
||||
sys.exit(1)
|
||||
|
||||
message = "Uvicorn running on unix socket %s (Press CTRL+C to quit)"
|
||||
sock_name_format = "%s"
|
||||
color_message = "Uvicorn running on " + click.style(sock_name_format, bold=True) + " (Press CTRL+C to quit)"
|
||||
logger_args = [self.uds]
|
||||
elif self.fd: # pragma: py-win32
|
||||
sock = socket.fromfd(self.fd, socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
message = "Uvicorn running on socket %s (Press CTRL+C to quit)"
|
||||
fd_name_format = "%s"
|
||||
color_message = "Uvicorn running on " + click.style(fd_name_format, bold=True) + " (Press CTRL+C to quit)"
|
||||
logger_args = [sock.getsockname()]
|
||||
else:
|
||||
family = socket.AF_INET
|
||||
addr_format = "%s://%s:%d"
|
||||
|
||||
if self.host and ":" in self.host: # pragma: full coverage
|
||||
# It's an IPv6 address.
|
||||
family = socket.AF_INET6
|
||||
addr_format = "%s://[%s]:%d"
|
||||
|
||||
sock = socket.socket(family=family)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
try:
|
||||
sock.bind((self.host, self.port))
|
||||
except OSError as exc: # pragma: full coverage
|
||||
logger.error(exc)
|
||||
sys.exit(1)
|
||||
|
||||
message = f"Uvicorn running on {addr_format} (Press CTRL+C to quit)"
|
||||
color_message = "Uvicorn running on " + click.style(addr_format, bold=True) + " (Press CTRL+C to quit)"
|
||||
protocol_name = "https" if self.is_ssl else "http"
|
||||
logger_args = [protocol_name, self.host, sock.getsockname()[1]]
|
||||
logger.info(message, *logger_args, extra={"color_message": color_message})
|
||||
sock.set_inheritable(True)
|
||||
return sock
|
||||
|
||||
@property
|
||||
def should_reload(self) -> bool:
|
||||
return isinstance(self.app, str) and self.reload
|
34
venv/lib/python3.11/site-packages/uvicorn/importer.py
Normal file
34
venv/lib/python3.11/site-packages/uvicorn/importer.py
Normal file
@ -0,0 +1,34 @@
|
||||
import importlib
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ImportFromStringError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def import_from_string(import_str: Any) -> Any:
|
||||
if not isinstance(import_str, str):
|
||||
return import_str
|
||||
|
||||
module_str, _, attrs_str = import_str.partition(":")
|
||||
if not module_str or not attrs_str:
|
||||
message = 'Import string "{import_str}" must be in format "<module>:<attribute>".'
|
||||
raise ImportFromStringError(message.format(import_str=import_str))
|
||||
|
||||
try:
|
||||
module = importlib.import_module(module_str)
|
||||
except ModuleNotFoundError as exc:
|
||||
if exc.name != module_str:
|
||||
raise exc from None
|
||||
message = 'Could not import module "{module_str}".'
|
||||
raise ImportFromStringError(message.format(module_str=module_str))
|
||||
|
||||
instance = module
|
||||
try:
|
||||
for attr_str in attrs_str.split("."):
|
||||
instance = getattr(instance, attr_str)
|
||||
except AttributeError:
|
||||
message = 'Attribute "{attrs_str}" not found in module "{module_str}".'
|
||||
raise ImportFromStringError(message.format(attrs_str=attrs_str, module_str=module_str))
|
||||
|
||||
return instance
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
17
venv/lib/python3.11/site-packages/uvicorn/lifespan/off.py
Normal file
17
venv/lib/python3.11/site-packages/uvicorn/lifespan/off.py
Normal file
@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from uvicorn import Config
|
||||
|
||||
|
||||
class LifespanOff:
|
||||
def __init__(self, config: Config) -> None:
|
||||
self.should_exit = False
|
||||
self.state: dict[str, Any] = {}
|
||||
|
||||
async def startup(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
137
venv/lib/python3.11/site-packages/uvicorn/lifespan/on.py
Normal file
137
venv/lib/python3.11/site-packages/uvicorn/lifespan/on.py
Normal file
@ -0,0 +1,137 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from asyncio import Queue
|
||||
from typing import Any, Union
|
||||
|
||||
from uvicorn import Config
|
||||
from uvicorn._types import (
|
||||
LifespanScope,
|
||||
LifespanShutdownCompleteEvent,
|
||||
LifespanShutdownEvent,
|
||||
LifespanShutdownFailedEvent,
|
||||
LifespanStartupCompleteEvent,
|
||||
LifespanStartupEvent,
|
||||
LifespanStartupFailedEvent,
|
||||
)
|
||||
|
||||
LifespanReceiveMessage = Union[LifespanStartupEvent, LifespanShutdownEvent]
|
||||
LifespanSendMessage = Union[
|
||||
LifespanStartupFailedEvent,
|
||||
LifespanShutdownFailedEvent,
|
||||
LifespanStartupCompleteEvent,
|
||||
LifespanShutdownCompleteEvent,
|
||||
]
|
||||
|
||||
|
||||
STATE_TRANSITION_ERROR = "Got invalid state transition on lifespan protocol."
|
||||
|
||||
|
||||
class LifespanOn:
|
||||
def __init__(self, config: Config) -> None:
|
||||
if not config.loaded:
|
||||
config.load()
|
||||
|
||||
self.config = config
|
||||
self.logger = logging.getLogger("uvicorn.error")
|
||||
self.startup_event = asyncio.Event()
|
||||
self.shutdown_event = asyncio.Event()
|
||||
self.receive_queue: Queue[LifespanReceiveMessage] = asyncio.Queue()
|
||||
self.error_occured = False
|
||||
self.startup_failed = False
|
||||
self.shutdown_failed = False
|
||||
self.should_exit = False
|
||||
self.state: dict[str, Any] = {}
|
||||
|
||||
async def startup(self) -> None:
|
||||
self.logger.info("Waiting for application startup.")
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
main_lifespan_task = loop.create_task(self.main()) # noqa: F841
|
||||
# Keep a hard reference to prevent garbage collection
|
||||
# See https://github.com/encode/uvicorn/pull/972
|
||||
startup_event: LifespanStartupEvent = {"type": "lifespan.startup"}
|
||||
await self.receive_queue.put(startup_event)
|
||||
await self.startup_event.wait()
|
||||
|
||||
if self.startup_failed or (self.error_occured and self.config.lifespan == "on"):
|
||||
self.logger.error("Application startup failed. Exiting.")
|
||||
self.should_exit = True
|
||||
else:
|
||||
self.logger.info("Application startup complete.")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self.error_occured:
|
||||
return
|
||||
self.logger.info("Waiting for application shutdown.")
|
||||
shutdown_event: LifespanShutdownEvent = {"type": "lifespan.shutdown"}
|
||||
await self.receive_queue.put(shutdown_event)
|
||||
await self.shutdown_event.wait()
|
||||
|
||||
if self.shutdown_failed or (self.error_occured and self.config.lifespan == "on"):
|
||||
self.logger.error("Application shutdown failed. Exiting.")
|
||||
self.should_exit = True
|
||||
else:
|
||||
self.logger.info("Application shutdown complete.")
|
||||
|
||||
async def main(self) -> None:
|
||||
try:
|
||||
app = self.config.loaded_app
|
||||
scope: LifespanScope = {
|
||||
"type": "lifespan",
|
||||
"asgi": {"version": self.config.asgi_version, "spec_version": "2.0"},
|
||||
"state": self.state,
|
||||
}
|
||||
await app(scope, self.receive, self.send)
|
||||
except BaseException as exc:
|
||||
self.asgi = None
|
||||
self.error_occured = True
|
||||
if self.startup_failed or self.shutdown_failed:
|
||||
return
|
||||
if self.config.lifespan == "auto":
|
||||
msg = "ASGI 'lifespan' protocol appears unsupported."
|
||||
self.logger.info(msg)
|
||||
else:
|
||||
msg = "Exception in 'lifespan' protocol\n"
|
||||
self.logger.error(msg, exc_info=exc)
|
||||
finally:
|
||||
self.startup_event.set()
|
||||
self.shutdown_event.set()
|
||||
|
||||
async def send(self, message: LifespanSendMessage) -> None:
|
||||
assert message["type"] in (
|
||||
"lifespan.startup.complete",
|
||||
"lifespan.startup.failed",
|
||||
"lifespan.shutdown.complete",
|
||||
"lifespan.shutdown.failed",
|
||||
)
|
||||
|
||||
if message["type"] == "lifespan.startup.complete":
|
||||
assert not self.startup_event.is_set(), STATE_TRANSITION_ERROR
|
||||
assert not self.shutdown_event.is_set(), STATE_TRANSITION_ERROR
|
||||
self.startup_event.set()
|
||||
|
||||
elif message["type"] == "lifespan.startup.failed":
|
||||
assert not self.startup_event.is_set(), STATE_TRANSITION_ERROR
|
||||
assert not self.shutdown_event.is_set(), STATE_TRANSITION_ERROR
|
||||
self.startup_event.set()
|
||||
self.startup_failed = True
|
||||
if message.get("message"):
|
||||
self.logger.error(message["message"])
|
||||
|
||||
elif message["type"] == "lifespan.shutdown.complete":
|
||||
assert self.startup_event.is_set(), STATE_TRANSITION_ERROR
|
||||
assert not self.shutdown_event.is_set(), STATE_TRANSITION_ERROR
|
||||
self.shutdown_event.set()
|
||||
|
||||
elif message["type"] == "lifespan.shutdown.failed":
|
||||
assert self.startup_event.is_set(), STATE_TRANSITION_ERROR
|
||||
assert not self.shutdown_event.is_set(), STATE_TRANSITION_ERROR
|
||||
self.shutdown_event.set()
|
||||
self.shutdown_failed = True
|
||||
if message.get("message"):
|
||||
self.logger.error(message["message"])
|
||||
|
||||
async def receive(self) -> LifespanReceiveMessage:
|
||||
return await self.receive_queue.get()
|
117
venv/lib/python3.11/site-packages/uvicorn/logging.py
Normal file
117
venv/lib/python3.11/site-packages/uvicorn/logging.py
Normal file
@ -0,0 +1,117 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import http
|
||||
import logging
|
||||
import sys
|
||||
from copy import copy
|
||||
from typing import Literal
|
||||
|
||||
import click
|
||||
|
||||
TRACE_LOG_LEVEL = 5
|
||||
|
||||
|
||||
class ColourizedFormatter(logging.Formatter):
|
||||
"""
|
||||
A custom log formatter class that:
|
||||
|
||||
* Outputs the LOG_LEVEL with an appropriate color.
|
||||
* If a log call includes an `extra={"color_message": ...}` it will be used
|
||||
for formatting the output, instead of the plain text message.
|
||||
"""
|
||||
|
||||
level_name_colors = {
|
||||
TRACE_LOG_LEVEL: lambda level_name: click.style(str(level_name), fg="blue"),
|
||||
logging.DEBUG: lambda level_name: click.style(str(level_name), fg="cyan"),
|
||||
logging.INFO: lambda level_name: click.style(str(level_name), fg="green"),
|
||||
logging.WARNING: lambda level_name: click.style(str(level_name), fg="yellow"),
|
||||
logging.ERROR: lambda level_name: click.style(str(level_name), fg="red"),
|
||||
logging.CRITICAL: lambda level_name: click.style(str(level_name), fg="bright_red"),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fmt: str | None = None,
|
||||
datefmt: str | None = None,
|
||||
style: Literal["%", "{", "$"] = "%",
|
||||
use_colors: bool | None = None,
|
||||
):
|
||||
if use_colors in (True, False):
|
||||
self.use_colors = use_colors
|
||||
else:
|
||||
self.use_colors = sys.stdout.isatty()
|
||||
super().__init__(fmt=fmt, datefmt=datefmt, style=style)
|
||||
|
||||
def color_level_name(self, level_name: str, level_no: int) -> str:
|
||||
def default(level_name: str) -> str:
|
||||
return str(level_name) # pragma: no cover
|
||||
|
||||
func = self.level_name_colors.get(level_no, default)
|
||||
return func(level_name)
|
||||
|
||||
def should_use_colors(self) -> bool:
|
||||
return True # pragma: no cover
|
||||
|
||||
def formatMessage(self, record: logging.LogRecord) -> str:
|
||||
recordcopy = copy(record)
|
||||
levelname = recordcopy.levelname
|
||||
seperator = " " * (8 - len(recordcopy.levelname))
|
||||
if self.use_colors:
|
||||
levelname = self.color_level_name(levelname, recordcopy.levelno)
|
||||
if "color_message" in recordcopy.__dict__:
|
||||
recordcopy.msg = recordcopy.__dict__["color_message"]
|
||||
recordcopy.__dict__["message"] = recordcopy.getMessage()
|
||||
recordcopy.__dict__["levelprefix"] = levelname + ":" + seperator
|
||||
return super().formatMessage(recordcopy)
|
||||
|
||||
|
||||
class DefaultFormatter(ColourizedFormatter):
|
||||
def should_use_colors(self) -> bool:
|
||||
return sys.stderr.isatty() # pragma: no cover
|
||||
|
||||
|
||||
class AccessFormatter(ColourizedFormatter):
|
||||
status_code_colours = {
|
||||
1: lambda code: click.style(str(code), fg="bright_white"),
|
||||
2: lambda code: click.style(str(code), fg="green"),
|
||||
3: lambda code: click.style(str(code), fg="yellow"),
|
||||
4: lambda code: click.style(str(code), fg="red"),
|
||||
5: lambda code: click.style(str(code), fg="bright_red"),
|
||||
}
|
||||
|
||||
def get_status_code(self, status_code: int) -> str:
|
||||
try:
|
||||
status_phrase = http.HTTPStatus(status_code).phrase
|
||||
except ValueError:
|
||||
status_phrase = ""
|
||||
status_and_phrase = f"{status_code} {status_phrase}"
|
||||
if self.use_colors:
|
||||
|
||||
def default(code: int) -> str:
|
||||
return status_and_phrase # pragma: no cover
|
||||
|
||||
func = self.status_code_colours.get(status_code // 100, default)
|
||||
return func(status_and_phrase)
|
||||
return status_and_phrase
|
||||
|
||||
def formatMessage(self, record: logging.LogRecord) -> str:
|
||||
recordcopy = copy(record)
|
||||
(
|
||||
client_addr,
|
||||
method,
|
||||
full_path,
|
||||
http_version,
|
||||
status_code,
|
||||
) = recordcopy.args # type: ignore[misc]
|
||||
status_code = self.get_status_code(int(status_code)) # type: ignore[arg-type]
|
||||
request_line = f"{method} {full_path} HTTP/{http_version}"
|
||||
if self.use_colors:
|
||||
request_line = click.style(request_line, bold=True)
|
||||
recordcopy.__dict__.update(
|
||||
{
|
||||
"client_addr": client_addr,
|
||||
"request_line": request_line,
|
||||
"status_code": status_code,
|
||||
}
|
||||
)
|
||||
return super().formatMessage(recordcopy)
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
10
venv/lib/python3.11/site-packages/uvicorn/loops/asyncio.py
Normal file
10
venv/lib/python3.11/site-packages/uvicorn/loops/asyncio.py
Normal file
@ -0,0 +1,10 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
|
||||
|
||||
def asyncio_setup(use_subprocess: bool = False) -> None:
|
||||
if sys.platform == "win32" and use_subprocess:
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) # pragma: full coverage
|
11
venv/lib/python3.11/site-packages/uvicorn/loops/auto.py
Normal file
11
venv/lib/python3.11/site-packages/uvicorn/loops/auto.py
Normal file
@ -0,0 +1,11 @@
|
||||
def auto_loop_setup(use_subprocess: bool = False) -> None:
|
||||
try:
|
||||
import uvloop # noqa
|
||||
except ImportError: # pragma: no cover
|
||||
from uvicorn.loops.asyncio import asyncio_setup as loop_setup
|
||||
|
||||
loop_setup(use_subprocess=use_subprocess)
|
||||
else: # pragma: no cover
|
||||
from uvicorn.loops.uvloop import uvloop_setup
|
||||
|
||||
uvloop_setup(use_subprocess=use_subprocess)
|
@ -0,0 +1,7 @@
|
||||
import asyncio
|
||||
|
||||
import uvloop
|
||||
|
||||
|
||||
def uvloop_setup(use_subprocess: bool = False) -> None:
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
591
venv/lib/python3.11/site-packages/uvicorn/main.py
Normal file
591
venv/lib/python3.11/site-packages/uvicorn/main.py
Normal file
@ -0,0 +1,591 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import ssl
|
||||
import sys
|
||||
from configparser import RawConfigParser
|
||||
from typing import IO, Any, Callable
|
||||
|
||||
import click
|
||||
|
||||
import uvicorn
|
||||
from uvicorn._types import ASGIApplication
|
||||
from uvicorn.config import (
|
||||
HTTP_PROTOCOLS,
|
||||
INTERFACES,
|
||||
LIFESPAN,
|
||||
LOG_LEVELS,
|
||||
LOGGING_CONFIG,
|
||||
LOOP_SETUPS,
|
||||
SSL_PROTOCOL_VERSION,
|
||||
WS_PROTOCOLS,
|
||||
Config,
|
||||
HTTPProtocolType,
|
||||
InterfaceType,
|
||||
LifespanType,
|
||||
LoopSetupType,
|
||||
WSProtocolType,
|
||||
)
|
||||
from uvicorn.server import Server, ServerState # noqa: F401 # Used to be defined here.
|
||||
from uvicorn.supervisors import ChangeReload, Multiprocess
|
||||
|
||||
LEVEL_CHOICES = click.Choice(list(LOG_LEVELS.keys()))
|
||||
HTTP_CHOICES = click.Choice(list(HTTP_PROTOCOLS.keys()))
|
||||
WS_CHOICES = click.Choice(list(WS_PROTOCOLS.keys()))
|
||||
LIFESPAN_CHOICES = click.Choice(list(LIFESPAN.keys()))
|
||||
LOOP_CHOICES = click.Choice([key for key in LOOP_SETUPS.keys() if key != "none"])
|
||||
INTERFACE_CHOICES = click.Choice(INTERFACES)
|
||||
|
||||
STARTUP_FAILURE = 3
|
||||
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
|
||||
|
||||
def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> None:
|
||||
if not value or ctx.resilient_parsing:
|
||||
return
|
||||
click.echo(
|
||||
"Running uvicorn {version} with {py_implementation} {py_version} on {system}".format( # noqa: UP032
|
||||
version=uvicorn.__version__,
|
||||
py_implementation=platform.python_implementation(),
|
||||
py_version=platform.python_version(),
|
||||
system=platform.system(),
|
||||
)
|
||||
)
|
||||
ctx.exit()
|
||||
|
||||
|
||||
@click.command(context_settings={"auto_envvar_prefix": "UVICORN"})
|
||||
@click.argument("app", envvar="UVICORN_APP")
|
||||
@click.option(
|
||||
"--host",
|
||||
type=str,
|
||||
default="127.0.0.1",
|
||||
help="Bind socket to this host.",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
"--port",
|
||||
type=int,
|
||||
default=8000,
|
||||
help="Bind socket to this port. If 0, an available port will be picked.",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option("--uds", type=str, default=None, help="Bind to a UNIX domain socket.")
|
||||
@click.option("--fd", type=int, default=None, help="Bind to socket from this file descriptor.")
|
||||
@click.option("--reload", is_flag=True, default=False, help="Enable auto-reload.")
|
||||
@click.option(
|
||||
"--reload-dir",
|
||||
"reload_dirs",
|
||||
multiple=True,
|
||||
help="Set reload directories explicitly, instead of using the current working" " directory.",
|
||||
type=click.Path(exists=True),
|
||||
)
|
||||
@click.option(
|
||||
"--reload-include",
|
||||
"reload_includes",
|
||||
multiple=True,
|
||||
help="Set glob patterns to include while watching for files. Includes '*.py' "
|
||||
"by default; these defaults can be overridden with `--reload-exclude`. "
|
||||
"This option has no effect unless watchfiles is installed.",
|
||||
)
|
||||
@click.option(
|
||||
"--reload-exclude",
|
||||
"reload_excludes",
|
||||
multiple=True,
|
||||
help="Set glob patterns to exclude while watching for files. Includes "
|
||||
"'.*, .py[cod], .sw.*, ~*' by default; these defaults can be overridden "
|
||||
"with `--reload-include`. This option has no effect unless watchfiles is "
|
||||
"installed.",
|
||||
)
|
||||
@click.option(
|
||||
"--reload-delay",
|
||||
type=float,
|
||||
default=0.25,
|
||||
show_default=True,
|
||||
help="Delay between previous and next check if application needs to be." " Defaults to 0.25s.",
|
||||
)
|
||||
@click.option(
|
||||
"--workers",
|
||||
default=None,
|
||||
type=int,
|
||||
help="Number of worker processes. Defaults to the $WEB_CONCURRENCY environment"
|
||||
" variable if available, or 1. Not valid with --reload.",
|
||||
)
|
||||
@click.option(
|
||||
"--loop",
|
||||
type=LOOP_CHOICES,
|
||||
default="auto",
|
||||
help="Event loop implementation.",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
"--http",
|
||||
type=HTTP_CHOICES,
|
||||
default="auto",
|
||||
help="HTTP protocol implementation.",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
"--ws",
|
||||
type=WS_CHOICES,
|
||||
default="auto",
|
||||
help="WebSocket protocol implementation.",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
"--ws-max-size",
|
||||
type=int,
|
||||
default=16777216,
|
||||
help="WebSocket max size message in bytes",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
"--ws-max-queue",
|
||||
type=int,
|
||||
default=32,
|
||||
help="The maximum length of the WebSocket message queue.",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
"--ws-ping-interval",
|
||||
type=float,
|
||||
default=20.0,
|
||||
help="WebSocket ping interval in seconds.",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
"--ws-ping-timeout",
|
||||
type=float,
|
||||
default=20.0,
|
||||
help="WebSocket ping timeout in seconds.",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
"--ws-per-message-deflate",
|
||||
type=bool,
|
||||
default=True,
|
||||
help="WebSocket per-message-deflate compression",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
"--lifespan",
|
||||
type=LIFESPAN_CHOICES,
|
||||
default="auto",
|
||||
help="Lifespan implementation.",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
"--interface",
|
||||
type=INTERFACE_CHOICES,
|
||||
default="auto",
|
||||
help="Select ASGI3, ASGI2, or WSGI as the application interface.",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
"--env-file",
|
||||
type=click.Path(exists=True),
|
||||
default=None,
|
||||
help="Environment configuration file.",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
"--log-config",
|
||||
type=click.Path(exists=True),
|
||||
default=None,
|
||||
help="Logging configuration file. Supported formats: .ini, .json, .yaml.",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
"--log-level",
|
||||
type=LEVEL_CHOICES,
|
||||
default=None,
|
||||
help="Log level. [default: info]",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
"--access-log/--no-access-log",
|
||||
is_flag=True,
|
||||
default=True,
|
||||
help="Enable/Disable access log.",
|
||||
)
|
||||
@click.option(
|
||||
"--use-colors/--no-use-colors",
|
||||
is_flag=True,
|
||||
default=None,
|
||||
help="Enable/Disable colorized logging.",
|
||||
)
|
||||
@click.option(
|
||||
"--proxy-headers/--no-proxy-headers",
|
||||
is_flag=True,
|
||||
default=True,
|
||||
help="Enable/Disable X-Forwarded-Proto, X-Forwarded-For, X-Forwarded-Port to " "populate remote address info.",
|
||||
)
|
||||
@click.option(
|
||||
"--server-header/--no-server-header",
|
||||
is_flag=True,
|
||||
default=True,
|
||||
help="Enable/Disable default Server header.",
|
||||
)
|
||||
@click.option(
|
||||
"--date-header/--no-date-header",
|
||||
is_flag=True,
|
||||
default=True,
|
||||
help="Enable/Disable default Date header.",
|
||||
)
|
||||
@click.option(
|
||||
"--forwarded-allow-ips",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Comma separated list of IP Addresses, IP Networks, or literals "
|
||||
"(e.g. UNIX Socket path) to trust with proxy headers. Defaults to the "
|
||||
"$FORWARDED_ALLOW_IPS environment variable if available, or '127.0.0.1'. "
|
||||
"The literal '*' means trust everything.",
|
||||
)
|
||||
@click.option(
|
||||
"--root-path",
|
||||
type=str,
|
||||
default="",
|
||||
help="Set the ASGI 'root_path' for applications submounted below a given URL path.",
|
||||
)
|
||||
@click.option(
|
||||
"--limit-concurrency",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum number of concurrent connections or tasks to allow, before issuing" " HTTP 503 responses.",
|
||||
)
|
||||
@click.option(
|
||||
"--backlog",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Maximum number of connections to hold in backlog",
|
||||
)
|
||||
@click.option(
|
||||
"--limit-max-requests",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum number of requests to service before terminating the process.",
|
||||
)
|
||||
@click.option(
|
||||
"--timeout-keep-alive",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Close Keep-Alive connections if no new data is received within this timeout.",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
"--timeout-graceful-shutdown",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum number of seconds to wait for graceful shutdown.",
|
||||
)
|
||||
@click.option("--ssl-keyfile", type=str, default=None, help="SSL key file", show_default=True)
|
||||
@click.option(
|
||||
"--ssl-certfile",
|
||||
type=str,
|
||||
default=None,
|
||||
help="SSL certificate file",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
"--ssl-keyfile-password",
|
||||
type=str,
|
||||
default=None,
|
||||
help="SSL keyfile password",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
"--ssl-version",
|
||||
type=int,
|
||||
default=int(SSL_PROTOCOL_VERSION),
|
||||
help="SSL version to use (see stdlib ssl module's)",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
"--ssl-cert-reqs",
|
||||
type=int,
|
||||
default=int(ssl.CERT_NONE),
|
||||
help="Whether client certificate is required (see stdlib ssl module's)",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
"--ssl-ca-certs",
|
||||
type=str,
|
||||
default=None,
|
||||
help="CA certificates file",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
"--ssl-ciphers",
|
||||
type=str,
|
||||
default="TLSv1",
|
||||
help="Ciphers to use (see stdlib ssl module's)",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
"--header",
|
||||
"headers",
|
||||
multiple=True,
|
||||
help="Specify custom default HTTP response headers as a Name:Value pair",
|
||||
)
|
||||
@click.option(
|
||||
"--version",
|
||||
is_flag=True,
|
||||
callback=print_version,
|
||||
expose_value=False,
|
||||
is_eager=True,
|
||||
help="Display the uvicorn version and exit.",
|
||||
)
|
||||
@click.option(
|
||||
"--app-dir",
|
||||
default="",
|
||||
show_default=True,
|
||||
help="Look for APP in the specified directory, by adding this to the PYTHONPATH."
|
||||
" Defaults to the current working directory.",
|
||||
)
|
||||
@click.option(
|
||||
"--h11-max-incomplete-event-size",
|
||||
"h11_max_incomplete_event_size",
|
||||
type=int,
|
||||
default=None,
|
||||
help="For h11, the maximum number of bytes to buffer of an incomplete event.",
|
||||
)
|
||||
@click.option(
|
||||
"--factory",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Treat APP as an application factory, i.e. a () -> <ASGI app> callable.",
|
||||
show_default=True,
|
||||
)
|
||||
def main(
|
||||
app: str,
|
||||
host: str,
|
||||
port: int,
|
||||
uds: str,
|
||||
fd: int,
|
||||
loop: LoopSetupType,
|
||||
http: HTTPProtocolType,
|
||||
ws: WSProtocolType,
|
||||
ws_max_size: int,
|
||||
ws_max_queue: int,
|
||||
ws_ping_interval: float,
|
||||
ws_ping_timeout: float,
|
||||
ws_per_message_deflate: bool,
|
||||
lifespan: LifespanType,
|
||||
interface: InterfaceType,
|
||||
reload: bool,
|
||||
reload_dirs: list[str],
|
||||
reload_includes: list[str],
|
||||
reload_excludes: list[str],
|
||||
reload_delay: float,
|
||||
workers: int,
|
||||
env_file: str,
|
||||
log_config: str,
|
||||
log_level: str,
|
||||
access_log: bool,
|
||||
proxy_headers: bool,
|
||||
server_header: bool,
|
||||
date_header: bool,
|
||||
forwarded_allow_ips: str,
|
||||
root_path: str,
|
||||
limit_concurrency: int,
|
||||
backlog: int,
|
||||
limit_max_requests: int,
|
||||
timeout_keep_alive: int,
|
||||
timeout_graceful_shutdown: int | None,
|
||||
ssl_keyfile: str,
|
||||
ssl_certfile: str,
|
||||
ssl_keyfile_password: str,
|
||||
ssl_version: int,
|
||||
ssl_cert_reqs: int,
|
||||
ssl_ca_certs: str,
|
||||
ssl_ciphers: str,
|
||||
headers: list[str],
|
||||
use_colors: bool,
|
||||
app_dir: str,
|
||||
h11_max_incomplete_event_size: int | None,
|
||||
factory: bool,
|
||||
) -> None:
|
||||
run(
|
||||
app,
|
||||
host=host,
|
||||
port=port,
|
||||
uds=uds,
|
||||
fd=fd,
|
||||
loop=loop,
|
||||
http=http,
|
||||
ws=ws,
|
||||
ws_max_size=ws_max_size,
|
||||
ws_max_queue=ws_max_queue,
|
||||
ws_ping_interval=ws_ping_interval,
|
||||
ws_ping_timeout=ws_ping_timeout,
|
||||
ws_per_message_deflate=ws_per_message_deflate,
|
||||
lifespan=lifespan,
|
||||
env_file=env_file,
|
||||
log_config=LOGGING_CONFIG if log_config is None else log_config,
|
||||
log_level=log_level,
|
||||
access_log=access_log,
|
||||
interface=interface,
|
||||
reload=reload,
|
||||
reload_dirs=reload_dirs or None,
|
||||
reload_includes=reload_includes or None,
|
||||
reload_excludes=reload_excludes or None,
|
||||
reload_delay=reload_delay,
|
||||
workers=workers,
|
||||
proxy_headers=proxy_headers,
|
||||
server_header=server_header,
|
||||
date_header=date_header,
|
||||
forwarded_allow_ips=forwarded_allow_ips,
|
||||
root_path=root_path,
|
||||
limit_concurrency=limit_concurrency,
|
||||
backlog=backlog,
|
||||
limit_max_requests=limit_max_requests,
|
||||
timeout_keep_alive=timeout_keep_alive,
|
||||
timeout_graceful_shutdown=timeout_graceful_shutdown,
|
||||
ssl_keyfile=ssl_keyfile,
|
||||
ssl_certfile=ssl_certfile,
|
||||
ssl_keyfile_password=ssl_keyfile_password,
|
||||
ssl_version=ssl_version,
|
||||
ssl_cert_reqs=ssl_cert_reqs,
|
||||
ssl_ca_certs=ssl_ca_certs,
|
||||
ssl_ciphers=ssl_ciphers,
|
||||
headers=[header.split(":", 1) for header in headers], # type: ignore[misc]
|
||||
use_colors=use_colors,
|
||||
factory=factory,
|
||||
app_dir=app_dir,
|
||||
h11_max_incomplete_event_size=h11_max_incomplete_event_size,
|
||||
)
|
||||
|
||||
|
||||
def run(
|
||||
app: ASGIApplication | Callable[..., Any] | str,
|
||||
*,
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 8000,
|
||||
uds: str | None = None,
|
||||
fd: int | None = None,
|
||||
loop: LoopSetupType = "auto",
|
||||
http: type[asyncio.Protocol] | HTTPProtocolType = "auto",
|
||||
ws: type[asyncio.Protocol] | WSProtocolType = "auto",
|
||||
ws_max_size: int = 16777216,
|
||||
ws_max_queue: int = 32,
|
||||
ws_ping_interval: float | None = 20.0,
|
||||
ws_ping_timeout: float | None = 20.0,
|
||||
ws_per_message_deflate: bool = True,
|
||||
lifespan: LifespanType = "auto",
|
||||
interface: InterfaceType = "auto",
|
||||
reload: bool = False,
|
||||
reload_dirs: list[str] | str | None = None,
|
||||
reload_includes: list[str] | str | None = None,
|
||||
reload_excludes: list[str] | str | None = None,
|
||||
reload_delay: float = 0.25,
|
||||
workers: int | None = None,
|
||||
env_file: str | os.PathLike[str] | None = None,
|
||||
log_config: dict[str, Any] | str | RawConfigParser | IO[Any] | None = LOGGING_CONFIG,
|
||||
log_level: str | int | None = None,
|
||||
access_log: bool = True,
|
||||
proxy_headers: bool = True,
|
||||
server_header: bool = True,
|
||||
date_header: bool = True,
|
||||
forwarded_allow_ips: list[str] | str | None = None,
|
||||
root_path: str = "",
|
||||
limit_concurrency: int | None = None,
|
||||
backlog: int = 2048,
|
||||
limit_max_requests: int | None = None,
|
||||
timeout_keep_alive: int = 5,
|
||||
timeout_graceful_shutdown: int | None = None,
|
||||
ssl_keyfile: str | os.PathLike[str] | None = None,
|
||||
ssl_certfile: str | os.PathLike[str] | None = None,
|
||||
ssl_keyfile_password: str | None = None,
|
||||
ssl_version: int = SSL_PROTOCOL_VERSION,
|
||||
ssl_cert_reqs: int = ssl.CERT_NONE,
|
||||
ssl_ca_certs: str | None = None,
|
||||
ssl_ciphers: str = "TLSv1",
|
||||
headers: list[tuple[str, str]] | None = None,
|
||||
use_colors: bool | None = None,
|
||||
app_dir: str | None = None,
|
||||
factory: bool = False,
|
||||
h11_max_incomplete_event_size: int | None = None,
|
||||
) -> None:
|
||||
if app_dir is not None:
|
||||
sys.path.insert(0, app_dir)
|
||||
|
||||
config = Config(
|
||||
app,
|
||||
host=host,
|
||||
port=port,
|
||||
uds=uds,
|
||||
fd=fd,
|
||||
loop=loop,
|
||||
http=http,
|
||||
ws=ws,
|
||||
ws_max_size=ws_max_size,
|
||||
ws_max_queue=ws_max_queue,
|
||||
ws_ping_interval=ws_ping_interval,
|
||||
ws_ping_timeout=ws_ping_timeout,
|
||||
ws_per_message_deflate=ws_per_message_deflate,
|
||||
lifespan=lifespan,
|
||||
interface=interface,
|
||||
reload=reload,
|
||||
reload_dirs=reload_dirs,
|
||||
reload_includes=reload_includes,
|
||||
reload_excludes=reload_excludes,
|
||||
reload_delay=reload_delay,
|
||||
workers=workers,
|
||||
env_file=env_file,
|
||||
log_config=log_config,
|
||||
log_level=log_level,
|
||||
access_log=access_log,
|
||||
proxy_headers=proxy_headers,
|
||||
server_header=server_header,
|
||||
date_header=date_header,
|
||||
forwarded_allow_ips=forwarded_allow_ips,
|
||||
root_path=root_path,
|
||||
limit_concurrency=limit_concurrency,
|
||||
backlog=backlog,
|
||||
limit_max_requests=limit_max_requests,
|
||||
timeout_keep_alive=timeout_keep_alive,
|
||||
timeout_graceful_shutdown=timeout_graceful_shutdown,
|
||||
ssl_keyfile=ssl_keyfile,
|
||||
ssl_certfile=ssl_certfile,
|
||||
ssl_keyfile_password=ssl_keyfile_password,
|
||||
ssl_version=ssl_version,
|
||||
ssl_cert_reqs=ssl_cert_reqs,
|
||||
ssl_ca_certs=ssl_ca_certs,
|
||||
ssl_ciphers=ssl_ciphers,
|
||||
headers=headers,
|
||||
use_colors=use_colors,
|
||||
factory=factory,
|
||||
h11_max_incomplete_event_size=h11_max_incomplete_event_size,
|
||||
)
|
||||
server = Server(config=config)
|
||||
|
||||
if (config.reload or config.workers > 1) and not isinstance(app, str):
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
logger.warning("You must pass the application as an import string to enable 'reload' or " "'workers'.")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
if config.should_reload:
|
||||
sock = config.bind_socket()
|
||||
ChangeReload(config, target=server.run, sockets=[sock]).run()
|
||||
elif config.workers > 1:
|
||||
sock = config.bind_socket()
|
||||
Multiprocess(config, target=server.run, sockets=[sock]).run()
|
||||
else:
|
||||
server.run()
|
||||
except KeyboardInterrupt:
|
||||
pass # pragma: full coverage
|
||||
finally:
|
||||
if config.uds and os.path.exists(config.uds):
|
||||
os.remove(config.uds) # pragma: py-win32
|
||||
|
||||
if not server.started and not config.should_reload and config.workers == 1:
|
||||
sys.exit(STARTUP_FAILURE)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main() # pragma: no cover
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,15 @@
|
||||
from uvicorn._types import (
|
||||
ASGI2Application,
|
||||
ASGIReceiveCallable,
|
||||
ASGISendCallable,
|
||||
Scope,
|
||||
)
|
||||
|
||||
|
||||
class ASGI2Middleware:
|
||||
def __init__(self, app: "ASGI2Application"):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable") -> None:
|
||||
instance = self.app(scope)
|
||||
await instance(receive, send)
|
@ -0,0 +1,87 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from uvicorn._types import (
|
||||
ASGI3Application,
|
||||
ASGIReceiveCallable,
|
||||
ASGIReceiveEvent,
|
||||
ASGISendCallable,
|
||||
ASGISendEvent,
|
||||
WWWScope,
|
||||
)
|
||||
from uvicorn.logging import TRACE_LOG_LEVEL
|
||||
|
||||
PLACEHOLDER_FORMAT = {
|
||||
"body": "<{length} bytes>",
|
||||
"bytes": "<{length} bytes>",
|
||||
"text": "<{length} chars>",
|
||||
"headers": "<...>",
|
||||
}
|
||||
|
||||
|
||||
def message_with_placeholders(message: Any) -> Any:
|
||||
"""
|
||||
Return an ASGI message, with any body-type content omitted and replaced
|
||||
with a placeholder.
|
||||
"""
|
||||
new_message = message.copy()
|
||||
for attr in PLACEHOLDER_FORMAT.keys():
|
||||
if message.get(attr) is not None:
|
||||
content = message[attr]
|
||||
placeholder = PLACEHOLDER_FORMAT[attr].format(length=len(content))
|
||||
new_message[attr] = placeholder
|
||||
return new_message
|
||||
|
||||
|
||||
class MessageLoggerMiddleware:
|
||||
def __init__(self, app: "ASGI3Application"):
|
||||
self.task_counter = 0
|
||||
self.app = app
|
||||
self.logger = logging.getLogger("uvicorn.asgi")
|
||||
|
||||
def trace(message: Any, *args: Any, **kwargs: Any) -> None:
|
||||
self.logger.log(TRACE_LOG_LEVEL, message, *args, **kwargs)
|
||||
|
||||
self.logger.trace = trace # type: ignore
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
scope: "WWWScope",
|
||||
receive: "ASGIReceiveCallable",
|
||||
send: "ASGISendCallable",
|
||||
) -> None:
|
||||
self.task_counter += 1
|
||||
|
||||
task_counter = self.task_counter
|
||||
client = scope.get("client")
|
||||
prefix = "%s:%d - ASGI" % (client[0], client[1]) if client else "ASGI"
|
||||
|
||||
async def inner_receive() -> "ASGIReceiveEvent":
|
||||
message = await receive()
|
||||
logged_message = message_with_placeholders(message)
|
||||
log_text = "%s [%d] Receive %s"
|
||||
self.logger.trace( # type: ignore
|
||||
log_text, prefix, task_counter, logged_message
|
||||
)
|
||||
return message
|
||||
|
||||
async def inner_send(message: "ASGISendEvent") -> None:
|
||||
logged_message = message_with_placeholders(message)
|
||||
log_text = "%s [%d] Send %s"
|
||||
self.logger.trace( # type: ignore
|
||||
log_text, prefix, task_counter, logged_message
|
||||
)
|
||||
await send(message)
|
||||
|
||||
logged_scope = message_with_placeholders(scope)
|
||||
log_text = "%s [%d] Started scope=%s"
|
||||
self.logger.trace(log_text, prefix, task_counter, logged_scope) # type: ignore
|
||||
try:
|
||||
await self.app(scope, inner_receive, inner_send)
|
||||
except BaseException as exc:
|
||||
log_text = "%s [%d] Raised exception"
|
||||
self.logger.trace(log_text, prefix, task_counter) # type: ignore
|
||||
raise exc from None
|
||||
else:
|
||||
log_text = "%s [%d] Completed"
|
||||
self.logger.trace(log_text, prefix, task_counter) # type: ignore
|
@ -0,0 +1,142 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
|
||||
from uvicorn._types import ASGI3Application, ASGIReceiveCallable, ASGISendCallable, Scope
|
||||
|
||||
|
||||
class ProxyHeadersMiddleware:
|
||||
"""Middleware for handling known proxy headers
|
||||
|
||||
This middleware can be used when a known proxy is fronting the application,
|
||||
and is trusted to be properly setting the `X-Forwarded-Proto` and
|
||||
`X-Forwarded-For` headers with the connecting client information.
|
||||
|
||||
Modifies the `client` and `scheme` information so that they reference
|
||||
the connecting client, rather that the connecting proxy.
|
||||
|
||||
References:
|
||||
- <https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers#Proxies>
|
||||
- <https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For>
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGI3Application, trusted_hosts: list[str] | str = "127.0.0.1") -> None:
|
||||
self.app = app
|
||||
self.trusted_hosts = _TrustedHosts(trusted_hosts)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
|
||||
if scope["type"] == "lifespan":
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
client_addr = scope.get("client")
|
||||
client_host = client_addr[0] if client_addr else None
|
||||
|
||||
if client_host in self.trusted_hosts:
|
||||
headers = dict(scope["headers"])
|
||||
|
||||
if b"x-forwarded-proto" in headers:
|
||||
x_forwarded_proto = headers[b"x-forwarded-proto"].decode("latin1").strip()
|
||||
|
||||
if x_forwarded_proto in {"http", "https", "ws", "wss"}:
|
||||
if scope["type"] == "websocket":
|
||||
scope["scheme"] = x_forwarded_proto.replace("http", "ws")
|
||||
else:
|
||||
scope["scheme"] = x_forwarded_proto
|
||||
|
||||
if b"x-forwarded-for" in headers:
|
||||
x_forwarded_for = headers[b"x-forwarded-for"].decode("latin1")
|
||||
host = self.trusted_hosts.get_trusted_client_host(x_forwarded_for)
|
||||
|
||||
if host:
|
||||
# If the x-forwarded-for header is empty then host is an empty string.
|
||||
# Only set the client if we actually got something usable.
|
||||
# See: https://github.com/encode/uvicorn/issues/1068
|
||||
|
||||
# We've lost the connecting client's port information by now,
|
||||
# so only include the host.
|
||||
port = 0
|
||||
scope["client"] = (host, port)
|
||||
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
|
||||
def _parse_raw_hosts(value: str) -> list[str]:
|
||||
return [item.strip() for item in value.split(",")]
|
||||
|
||||
|
||||
class _TrustedHosts:
|
||||
"""Container for trusted hosts and networks"""
|
||||
|
||||
def __init__(self, trusted_hosts: list[str] | str) -> None:
|
||||
self.always_trust: bool = trusted_hosts in ("*", ["*"])
|
||||
|
||||
self.trusted_literals: set[str] = set()
|
||||
self.trusted_hosts: set[ipaddress.IPv4Address | ipaddress.IPv6Address] = set()
|
||||
self.trusted_networks: set[ipaddress.IPv4Network | ipaddress.IPv6Network] = set()
|
||||
|
||||
# Notes:
|
||||
# - We separate hosts from literals as there are many ways to write
|
||||
# an IPv6 Address so we need to compare by object.
|
||||
# - We don't convert IP Address to single host networks (e.g. /32 / 128) as
|
||||
# it more efficient to do an address lookup in a set than check for
|
||||
# membership in each network.
|
||||
# - We still allow literals as it might be possible that we receive a
|
||||
# something that isn't an IP Address e.g. a unix socket.
|
||||
|
||||
if not self.always_trust:
|
||||
if isinstance(trusted_hosts, str):
|
||||
trusted_hosts = _parse_raw_hosts(trusted_hosts)
|
||||
|
||||
for host in trusted_hosts:
|
||||
# Note: because we always convert invalid IP types to literals it
|
||||
# is not possible for the user to know they provided a malformed IP
|
||||
# type - this may lead to unexpected / difficult to debug behaviour.
|
||||
|
||||
if "/" in host:
|
||||
# Looks like a network
|
||||
try:
|
||||
self.trusted_networks.add(ipaddress.ip_network(host))
|
||||
except ValueError:
|
||||
# Was not a valid IP Network
|
||||
self.trusted_literals.add(host)
|
||||
else:
|
||||
try:
|
||||
self.trusted_hosts.add(ipaddress.ip_address(host))
|
||||
except ValueError:
|
||||
# Was not a valid IP Address
|
||||
self.trusted_literals.add(host)
|
||||
|
||||
def __contains__(self, host: str | None) -> bool:
|
||||
if self.always_trust:
|
||||
return True
|
||||
|
||||
if not host:
|
||||
return False
|
||||
|
||||
try:
|
||||
ip = ipaddress.ip_address(host)
|
||||
if ip in self.trusted_hosts:
|
||||
return True
|
||||
return any(ip in net for net in self.trusted_networks)
|
||||
|
||||
except ValueError:
|
||||
return host in self.trusted_literals
|
||||
|
||||
def get_trusted_client_host(self, x_forwarded_for: str) -> str:
|
||||
"""Extract the client host from x_forwarded_for header
|
||||
|
||||
In general this is the first "untrusted" host in the forwarded for list.
|
||||
"""
|
||||
x_forwarded_for_hosts = _parse_raw_hosts(x_forwarded_for)
|
||||
|
||||
if self.always_trust:
|
||||
return x_forwarded_for_hosts[0]
|
||||
|
||||
# Note: each proxy appends to the header list so check it in reverse order
|
||||
for host in reversed(x_forwarded_for_hosts):
|
||||
if host not in self:
|
||||
return host
|
||||
|
||||
# All hosts are trusted meaning that the client was also a trusted proxy
|
||||
# See https://github.com/encode/uvicorn/issues/1068#issuecomment-855371576
|
||||
return x_forwarded_for_hosts[0]
|
200
venv/lib/python3.11/site-packages/uvicorn/middleware/wsgi.py
Normal file
200
venv/lib/python3.11/site-packages/uvicorn/middleware/wsgi.py
Normal file
@ -0,0 +1,200 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import io
|
||||
import sys
|
||||
import warnings
|
||||
from collections import deque
|
||||
from collections.abc import Iterable
|
||||
|
||||
from uvicorn._types import (
|
||||
ASGIReceiveCallable,
|
||||
ASGIReceiveEvent,
|
||||
ASGISendCallable,
|
||||
ASGISendEvent,
|
||||
Environ,
|
||||
ExcInfo,
|
||||
HTTPRequestEvent,
|
||||
HTTPResponseBodyEvent,
|
||||
HTTPResponseStartEvent,
|
||||
HTTPScope,
|
||||
StartResponse,
|
||||
WSGIApp,
|
||||
)
|
||||
|
||||
|
||||
def build_environ(scope: HTTPScope, message: ASGIReceiveEvent, body: io.BytesIO) -> Environ:
|
||||
"""
|
||||
Builds a scope and request message into a WSGI environ object.
|
||||
"""
|
||||
script_name = scope.get("root_path", "").encode("utf8").decode("latin1")
|
||||
path_info = scope["path"].encode("utf8").decode("latin1")
|
||||
if path_info.startswith(script_name):
|
||||
path_info = path_info[len(script_name) :]
|
||||
environ = {
|
||||
"REQUEST_METHOD": scope["method"],
|
||||
"SCRIPT_NAME": script_name,
|
||||
"PATH_INFO": path_info,
|
||||
"QUERY_STRING": scope["query_string"].decode("ascii"),
|
||||
"SERVER_PROTOCOL": "HTTP/%s" % scope["http_version"],
|
||||
"wsgi.version": (1, 0),
|
||||
"wsgi.url_scheme": scope.get("scheme", "http"),
|
||||
"wsgi.input": body,
|
||||
"wsgi.errors": sys.stdout,
|
||||
"wsgi.multithread": True,
|
||||
"wsgi.multiprocess": True,
|
||||
"wsgi.run_once": False,
|
||||
}
|
||||
|
||||
# Get server name and port - required in WSGI, not in ASGI
|
||||
server = scope.get("server")
|
||||
if server is None:
|
||||
server = ("localhost", 80)
|
||||
environ["SERVER_NAME"] = server[0]
|
||||
environ["SERVER_PORT"] = server[1]
|
||||
|
||||
# Get client IP address
|
||||
client = scope.get("client")
|
||||
if client is not None:
|
||||
environ["REMOTE_ADDR"] = client[0]
|
||||
|
||||
# Go through headers and make them into environ entries
|
||||
for name, value in scope.get("headers", []):
|
||||
name_str: str = name.decode("latin1")
|
||||
if name_str == "content-length":
|
||||
corrected_name = "CONTENT_LENGTH"
|
||||
elif name_str == "content-type":
|
||||
corrected_name = "CONTENT_TYPE"
|
||||
else:
|
||||
corrected_name = "HTTP_%s" % name_str.upper().replace("-", "_")
|
||||
# HTTPbis say only ASCII chars are allowed in headers, but we latin1
|
||||
# just in case
|
||||
value_str: str = value.decode("latin1")
|
||||
if corrected_name in environ:
|
||||
corrected_name_environ = environ[corrected_name]
|
||||
assert isinstance(corrected_name_environ, str)
|
||||
value_str = corrected_name_environ + "," + value_str
|
||||
environ[corrected_name] = value_str
|
||||
return environ
|
||||
|
||||
|
||||
class _WSGIMiddleware:
|
||||
def __init__(self, app: WSGIApp, workers: int = 10):
|
||||
warnings.warn(
|
||||
"Uvicorn's native WSGI implementation is deprecated, you "
|
||||
"should switch to a2wsgi (`pip install a2wsgi`).",
|
||||
DeprecationWarning,
|
||||
)
|
||||
self.app = app
|
||||
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=workers)
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
scope: HTTPScope,
|
||||
receive: ASGIReceiveCallable,
|
||||
send: ASGISendCallable,
|
||||
) -> None:
|
||||
assert scope["type"] == "http"
|
||||
instance = WSGIResponder(self.app, self.executor, scope)
|
||||
await instance(receive, send)
|
||||
|
||||
|
||||
class WSGIResponder:
|
||||
def __init__(
|
||||
self,
|
||||
app: WSGIApp,
|
||||
executor: concurrent.futures.ThreadPoolExecutor,
|
||||
scope: HTTPScope,
|
||||
):
|
||||
self.app = app
|
||||
self.executor = executor
|
||||
self.scope = scope
|
||||
self.status = None
|
||||
self.response_headers = None
|
||||
self.send_event = asyncio.Event()
|
||||
self.send_queue: deque[ASGISendEvent | None] = deque()
|
||||
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
|
||||
self.response_started = False
|
||||
self.exc_info: ExcInfo | None = None
|
||||
|
||||
async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
|
||||
message: HTTPRequestEvent = await receive() # type: ignore[assignment]
|
||||
body = io.BytesIO(message.get("body", b""))
|
||||
more_body = message.get("more_body", False)
|
||||
if more_body:
|
||||
body.seek(0, io.SEEK_END)
|
||||
while more_body:
|
||||
body_message: HTTPRequestEvent = (
|
||||
await receive() # type: ignore[assignment]
|
||||
)
|
||||
body.write(body_message.get("body", b""))
|
||||
more_body = body_message.get("more_body", False)
|
||||
body.seek(0)
|
||||
environ = build_environ(self.scope, message, body)
|
||||
self.loop = asyncio.get_event_loop()
|
||||
wsgi = self.loop.run_in_executor(self.executor, self.wsgi, environ, self.start_response)
|
||||
sender = self.loop.create_task(self.sender(send))
|
||||
try:
|
||||
await asyncio.wait_for(wsgi, None)
|
||||
finally:
|
||||
self.send_queue.append(None)
|
||||
self.send_event.set()
|
||||
await asyncio.wait_for(sender, None)
|
||||
if self.exc_info is not None:
|
||||
raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2])
|
||||
|
||||
async def sender(self, send: ASGISendCallable) -> None:
|
||||
while True:
|
||||
if self.send_queue:
|
||||
message = self.send_queue.popleft()
|
||||
if message is None:
|
||||
return
|
||||
await send(message)
|
||||
else:
|
||||
await self.send_event.wait()
|
||||
self.send_event.clear()
|
||||
|
||||
def start_response(
|
||||
self,
|
||||
status: str,
|
||||
response_headers: Iterable[tuple[str, str]],
|
||||
exc_info: ExcInfo | None = None,
|
||||
) -> None:
|
||||
self.exc_info = exc_info
|
||||
if not self.response_started:
|
||||
self.response_started = True
|
||||
status_code_str, _ = status.split(" ", 1)
|
||||
status_code = int(status_code_str)
|
||||
headers = [(name.encode("ascii"), value.encode("ascii")) for name, value in response_headers]
|
||||
http_response_start_event: HTTPResponseStartEvent = {
|
||||
"type": "http.response.start",
|
||||
"status": status_code,
|
||||
"headers": headers,
|
||||
}
|
||||
self.send_queue.append(http_response_start_event)
|
||||
self.loop.call_soon_threadsafe(self.send_event.set)
|
||||
|
||||
def wsgi(self, environ: Environ, start_response: StartResponse) -> None:
|
||||
for chunk in self.app(environ, start_response): # type: ignore
|
||||
response_body: HTTPResponseBodyEvent = {
|
||||
"type": "http.response.body",
|
||||
"body": chunk,
|
||||
"more_body": True,
|
||||
}
|
||||
self.send_queue.append(response_body)
|
||||
self.loop.call_soon_threadsafe(self.send_event.set)
|
||||
|
||||
empty_body: HTTPResponseBodyEvent = {
|
||||
"type": "http.response.body",
|
||||
"body": b"",
|
||||
"more_body": False,
|
||||
}
|
||||
self.send_queue.append(empty_body)
|
||||
self.loop.call_soon_threadsafe(self.send_event.set)
|
||||
|
||||
|
||||
try:
|
||||
from a2wsgi import WSGIMiddleware
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
WSGIMiddleware = _WSGIMiddleware # type: ignore[misc, assignment]
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
AutoHTTPProtocol: type[asyncio.Protocol]
|
||||
try:
|
||||
import httptools # noqa
|
||||
except ImportError: # pragma: no cover
|
||||
from uvicorn.protocols.http.h11_impl import H11Protocol
|
||||
|
||||
AutoHTTPProtocol = H11Protocol
|
||||
else: # pragma: no cover
|
||||
from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol
|
||||
|
||||
AutoHTTPProtocol = HttpToolsProtocol
|
@ -0,0 +1,54 @@
|
||||
import asyncio
|
||||
|
||||
from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope
|
||||
|
||||
CLOSE_HEADER = (b"connection", b"close")
|
||||
|
||||
HIGH_WATER_LIMIT = 65536
|
||||
|
||||
|
||||
class FlowControl:
|
||||
def __init__(self, transport: asyncio.Transport) -> None:
|
||||
self._transport = transport
|
||||
self.read_paused = False
|
||||
self.write_paused = False
|
||||
self._is_writable_event = asyncio.Event()
|
||||
self._is_writable_event.set()
|
||||
|
||||
async def drain(self) -> None:
|
||||
await self._is_writable_event.wait() # pragma: full coverage
|
||||
|
||||
def pause_reading(self) -> None:
|
||||
if not self.read_paused:
|
||||
self.read_paused = True
|
||||
self._transport.pause_reading()
|
||||
|
||||
def resume_reading(self) -> None:
|
||||
if self.read_paused:
|
||||
self.read_paused = False
|
||||
self._transport.resume_reading()
|
||||
|
||||
def pause_writing(self) -> None:
|
||||
if not self.write_paused: # pragma: full coverage
|
||||
self.write_paused = True
|
||||
self._is_writable_event.clear()
|
||||
|
||||
def resume_writing(self) -> None:
|
||||
if self.write_paused: # pragma: full coverage
|
||||
self.write_paused = False
|
||||
self._is_writable_event.set()
|
||||
|
||||
|
||||
async def service_unavailable(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 503,
|
||||
"headers": [
|
||||
(b"content-type", b"text/plain; charset=utf-8"),
|
||||
(b"content-length", b"19"),
|
||||
(b"connection", b"close"),
|
||||
],
|
||||
}
|
||||
)
|
||||
await send({"type": "http.response.body", "body": b"Service Unavailable", "more_body": False})
|
@ -0,0 +1,543 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import http
|
||||
import logging
|
||||
from typing import Any, Callable, Literal, cast
|
||||
from urllib.parse import unquote
|
||||
|
||||
import h11
|
||||
from h11._connection import DEFAULT_MAX_INCOMPLETE_EVENT_SIZE
|
||||
|
||||
from uvicorn._types import (
|
||||
ASGI3Application,
|
||||
ASGIReceiveEvent,
|
||||
ASGISendEvent,
|
||||
HTTPRequestEvent,
|
||||
HTTPResponseBodyEvent,
|
||||
HTTPResponseStartEvent,
|
||||
HTTPScope,
|
||||
)
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.logging import TRACE_LOG_LEVEL
|
||||
from uvicorn.protocols.http.flow_control import CLOSE_HEADER, HIGH_WATER_LIMIT, FlowControl, service_unavailable
|
||||
from uvicorn.protocols.utils import get_client_addr, get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl
|
||||
from uvicorn.server import ServerState
|
||||
|
||||
|
||||
def _get_status_phrase(status_code: int) -> bytes:
|
||||
try:
|
||||
return http.HTTPStatus(status_code).phrase.encode()
|
||||
except ValueError:
|
||||
return b""
|
||||
|
||||
|
||||
STATUS_PHRASES = {status_code: _get_status_phrase(status_code) for status_code in range(100, 600)}
|
||||
|
||||
|
||||
class H11Protocol(asyncio.Protocol):
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
server_state: ServerState,
|
||||
app_state: dict[str, Any],
|
||||
_loop: asyncio.AbstractEventLoop | None = None,
|
||||
) -> None:
|
||||
if not config.loaded:
|
||||
config.load()
|
||||
|
||||
self.config = config
|
||||
self.app = config.loaded_app
|
||||
self.loop = _loop or asyncio.get_event_loop()
|
||||
self.logger = logging.getLogger("uvicorn.error")
|
||||
self.access_logger = logging.getLogger("uvicorn.access")
|
||||
self.access_log = self.access_logger.hasHandlers()
|
||||
self.conn = h11.Connection(
|
||||
h11.SERVER,
|
||||
config.h11_max_incomplete_event_size
|
||||
if config.h11_max_incomplete_event_size is not None
|
||||
else DEFAULT_MAX_INCOMPLETE_EVENT_SIZE,
|
||||
)
|
||||
self.ws_protocol_class = config.ws_protocol_class
|
||||
self.root_path = config.root_path
|
||||
self.limit_concurrency = config.limit_concurrency
|
||||
self.app_state = app_state
|
||||
|
||||
# Timeouts
|
||||
self.timeout_keep_alive_task: asyncio.TimerHandle | None = None
|
||||
self.timeout_keep_alive = config.timeout_keep_alive
|
||||
|
||||
# Shared server state
|
||||
self.server_state = server_state
|
||||
self.connections = server_state.connections
|
||||
self.tasks = server_state.tasks
|
||||
|
||||
# Per-connection state
|
||||
self.transport: asyncio.Transport = None # type: ignore[assignment]
|
||||
self.flow: FlowControl = None # type: ignore[assignment]
|
||||
self.server: tuple[str, int] | None = None
|
||||
self.client: tuple[str, int] | None = None
|
||||
self.scheme: Literal["http", "https"] | None = None
|
||||
|
||||
# Per-request state
|
||||
self.scope: HTTPScope = None # type: ignore[assignment]
|
||||
self.headers: list[tuple[bytes, bytes]] = None # type: ignore[assignment]
|
||||
self.cycle: RequestResponseCycle = None # type: ignore[assignment]
|
||||
|
||||
# Protocol interface
|
||||
def connection_made( # type: ignore[override]
|
||||
self, transport: asyncio.Transport
|
||||
) -> None:
|
||||
self.connections.add(self)
|
||||
|
||||
self.transport = transport
|
||||
self.flow = FlowControl(transport)
|
||||
self.server = get_local_addr(transport)
|
||||
self.client = get_remote_addr(transport)
|
||||
self.scheme = "https" if is_ssl(transport) else "http"
|
||||
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix)
|
||||
|
||||
def connection_lost(self, exc: Exception | None) -> None:
|
||||
self.connections.discard(self)
|
||||
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection lost", prefix)
|
||||
|
||||
if self.cycle and not self.cycle.response_complete:
|
||||
self.cycle.disconnected = True
|
||||
if self.conn.our_state != h11.ERROR:
|
||||
event = h11.ConnectionClosed()
|
||||
try:
|
||||
self.conn.send(event)
|
||||
except h11.LocalProtocolError:
|
||||
# Premature client disconnect
|
||||
pass
|
||||
|
||||
if self.cycle is not None:
|
||||
self.cycle.message_event.set()
|
||||
if self.flow is not None:
|
||||
self.flow.resume_writing()
|
||||
if exc is None:
|
||||
self.transport.close()
|
||||
self._unset_keepalive_if_required()
|
||||
|
||||
def eof_received(self) -> None:
|
||||
pass
|
||||
|
||||
def _unset_keepalive_if_required(self) -> None:
|
||||
if self.timeout_keep_alive_task is not None:
|
||||
self.timeout_keep_alive_task.cancel()
|
||||
self.timeout_keep_alive_task = None
|
||||
|
||||
def _get_upgrade(self) -> bytes | None:
|
||||
connection = []
|
||||
upgrade = None
|
||||
for name, value in self.headers:
|
||||
if name == b"connection":
|
||||
connection = [token.lower().strip() for token in value.split(b",")]
|
||||
if name == b"upgrade":
|
||||
upgrade = value.lower()
|
||||
if b"upgrade" in connection:
|
||||
return upgrade
|
||||
return None
|
||||
|
||||
def _should_upgrade_to_ws(self) -> bool:
|
||||
if self.ws_protocol_class is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _unsupported_upgrade_warning(self) -> None:
|
||||
msg = "Unsupported upgrade request."
|
||||
self.logger.warning(msg)
|
||||
if not self._should_upgrade_to_ws():
|
||||
msg = "No supported WebSocket library detected. Please use \"pip install 'uvicorn[standard]'\", or install 'websockets' or 'wsproto' manually." # noqa: E501
|
||||
self.logger.warning(msg)
|
||||
|
||||
def _should_upgrade(self) -> bool:
|
||||
upgrade = self._get_upgrade()
|
||||
if upgrade == b"websocket" and self._should_upgrade_to_ws():
|
||||
return True
|
||||
if upgrade is not None:
|
||||
self._unsupported_upgrade_warning()
|
||||
return False
|
||||
|
||||
def data_received(self, data: bytes) -> None:
|
||||
self._unset_keepalive_if_required()
|
||||
|
||||
self.conn.receive_data(data)
|
||||
self.handle_events()
|
||||
|
||||
def handle_events(self) -> None:
|
||||
while True:
|
||||
try:
|
||||
event = self.conn.next_event()
|
||||
except h11.RemoteProtocolError:
|
||||
msg = "Invalid HTTP request received."
|
||||
self.logger.warning(msg)
|
||||
self.send_400_response(msg)
|
||||
return
|
||||
|
||||
if event is h11.NEED_DATA:
|
||||
break
|
||||
|
||||
elif event is h11.PAUSED:
|
||||
# This case can occur in HTTP pipelining, so we need to
|
||||
# stop reading any more data, and ensure that at the end
|
||||
# of the active request/response cycle we handle any
|
||||
# events that have been buffered up.
|
||||
self.flow.pause_reading()
|
||||
break
|
||||
|
||||
elif isinstance(event, h11.Request):
|
||||
self.headers = [(key.lower(), value) for key, value in event.headers]
|
||||
raw_path, _, query_string = event.target.partition(b"?")
|
||||
path = unquote(raw_path.decode("ascii"))
|
||||
full_path = self.root_path + path
|
||||
full_raw_path = self.root_path.encode("ascii") + raw_path
|
||||
self.scope = {
|
||||
"type": "http",
|
||||
"asgi": {"version": self.config.asgi_version, "spec_version": "2.3"},
|
||||
"http_version": event.http_version.decode("ascii"),
|
||||
"server": self.server,
|
||||
"client": self.client,
|
||||
"scheme": self.scheme, # type: ignore[typeddict-item]
|
||||
"method": event.method.decode("ascii"),
|
||||
"root_path": self.root_path,
|
||||
"path": full_path,
|
||||
"raw_path": full_raw_path,
|
||||
"query_string": query_string,
|
||||
"headers": self.headers,
|
||||
"state": self.app_state.copy(),
|
||||
}
|
||||
if self._should_upgrade():
|
||||
self.handle_websocket_upgrade(event)
|
||||
return
|
||||
|
||||
# Handle 503 responses when 'limit_concurrency' is exceeded.
|
||||
if self.limit_concurrency is not None and (
|
||||
len(self.connections) >= self.limit_concurrency or len(self.tasks) >= self.limit_concurrency
|
||||
):
|
||||
app = service_unavailable
|
||||
message = "Exceeded concurrency limit."
|
||||
self.logger.warning(message)
|
||||
else:
|
||||
app = self.app
|
||||
|
||||
# When starting to process a request, disable the keep-alive
|
||||
# timeout. Normally we disable this when receiving data from
|
||||
# client and set back when finishing processing its request.
|
||||
# However, for pipelined requests processing finishes after
|
||||
# already receiving the next request and thus the timer may
|
||||
# be set here, which we don't want.
|
||||
self._unset_keepalive_if_required()
|
||||
|
||||
self.cycle = RequestResponseCycle(
|
||||
scope=self.scope,
|
||||
conn=self.conn,
|
||||
transport=self.transport,
|
||||
flow=self.flow,
|
||||
logger=self.logger,
|
||||
access_logger=self.access_logger,
|
||||
access_log=self.access_log,
|
||||
default_headers=self.server_state.default_headers,
|
||||
message_event=asyncio.Event(),
|
||||
on_response=self.on_response_complete,
|
||||
)
|
||||
task = self.loop.create_task(self.cycle.run_asgi(app))
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
self.tasks.add(task)
|
||||
|
||||
elif isinstance(event, h11.Data):
|
||||
if self.conn.our_state is h11.DONE:
|
||||
continue
|
||||
self.cycle.body += event.data
|
||||
if len(self.cycle.body) > HIGH_WATER_LIMIT:
|
||||
self.flow.pause_reading()
|
||||
self.cycle.message_event.set()
|
||||
|
||||
elif isinstance(event, h11.EndOfMessage):
|
||||
if self.conn.our_state is h11.DONE:
|
||||
self.transport.resume_reading()
|
||||
self.conn.start_next_cycle()
|
||||
continue
|
||||
self.cycle.more_body = False
|
||||
self.cycle.message_event.set()
|
||||
if self.conn.their_state == h11.MUST_CLOSE:
|
||||
break
|
||||
|
||||
def handle_websocket_upgrade(self, event: h11.Request) -> None:
|
||||
if self.logger.level <= TRACE_LOG_LEVEL: # pragma: full coverage
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix)
|
||||
|
||||
self.connections.discard(self)
|
||||
output = [event.method, b" ", event.target, b" HTTP/1.1\r\n"]
|
||||
for name, value in self.headers:
|
||||
output += [name, b": ", value, b"\r\n"]
|
||||
output.append(b"\r\n")
|
||||
protocol = self.ws_protocol_class( # type: ignore[call-arg, misc]
|
||||
config=self.config,
|
||||
server_state=self.server_state,
|
||||
app_state=self.app_state,
|
||||
)
|
||||
protocol.connection_made(self.transport)
|
||||
protocol.data_received(b"".join(output))
|
||||
self.transport.set_protocol(protocol)
|
||||
|
||||
def send_400_response(self, msg: str) -> None:
|
||||
reason = STATUS_PHRASES[400]
|
||||
headers: list[tuple[bytes, bytes]] = [
|
||||
(b"content-type", b"text/plain; charset=utf-8"),
|
||||
(b"connection", b"close"),
|
||||
]
|
||||
event = h11.Response(status_code=400, headers=headers, reason=reason)
|
||||
output = self.conn.send(event)
|
||||
self.transport.write(output)
|
||||
|
||||
output = self.conn.send(event=h11.Data(data=msg.encode("ascii")))
|
||||
self.transport.write(output)
|
||||
|
||||
output = self.conn.send(event=h11.EndOfMessage())
|
||||
self.transport.write(output)
|
||||
|
||||
self.transport.close()
|
||||
|
||||
def on_response_complete(self) -> None:
|
||||
self.server_state.total_requests += 1
|
||||
|
||||
if self.transport.is_closing():
|
||||
return
|
||||
|
||||
# Set a short Keep-Alive timeout.
|
||||
self._unset_keepalive_if_required()
|
||||
|
||||
self.timeout_keep_alive_task = self.loop.call_later(self.timeout_keep_alive, self.timeout_keep_alive_handler)
|
||||
|
||||
# Unpause data reads if needed.
|
||||
self.flow.resume_reading()
|
||||
|
||||
# Unblock any pipelined events.
|
||||
if self.conn.our_state is h11.DONE and self.conn.their_state is h11.DONE:
|
||||
self.conn.start_next_cycle()
|
||||
self.handle_events()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""
|
||||
Called by the server to commence a graceful shutdown.
|
||||
"""
|
||||
if self.cycle is None or self.cycle.response_complete:
|
||||
event = h11.ConnectionClosed()
|
||||
self.conn.send(event)
|
||||
self.transport.close()
|
||||
else:
|
||||
self.cycle.keep_alive = False
|
||||
|
||||
def pause_writing(self) -> None:
|
||||
"""
|
||||
Called by the transport when the write buffer exceeds the high water mark.
|
||||
"""
|
||||
self.flow.pause_writing() # pragma: full coverage
|
||||
|
||||
def resume_writing(self) -> None:
|
||||
"""
|
||||
Called by the transport when the write buffer drops below the low water mark.
|
||||
"""
|
||||
self.flow.resume_writing() # pragma: full coverage
|
||||
|
||||
def timeout_keep_alive_handler(self) -> None:
|
||||
"""
|
||||
Called on a keep-alive connection if no new data is received after a short
|
||||
delay.
|
||||
"""
|
||||
if not self.transport.is_closing():
|
||||
event = h11.ConnectionClosed()
|
||||
self.conn.send(event)
|
||||
self.transport.close()
|
||||
|
||||
|
||||
class RequestResponseCycle:
|
||||
def __init__(
|
||||
self,
|
||||
scope: HTTPScope,
|
||||
conn: h11.Connection,
|
||||
transport: asyncio.Transport,
|
||||
flow: FlowControl,
|
||||
logger: logging.Logger,
|
||||
access_logger: logging.Logger,
|
||||
access_log: bool,
|
||||
default_headers: list[tuple[bytes, bytes]],
|
||||
message_event: asyncio.Event,
|
||||
on_response: Callable[..., None],
|
||||
) -> None:
|
||||
self.scope = scope
|
||||
self.conn = conn
|
||||
self.transport = transport
|
||||
self.flow = flow
|
||||
self.logger = logger
|
||||
self.access_logger = access_logger
|
||||
self.access_log = access_log
|
||||
self.default_headers = default_headers
|
||||
self.message_event = message_event
|
||||
self.on_response = on_response
|
||||
|
||||
# Connection state
|
||||
self.disconnected = False
|
||||
self.keep_alive = True
|
||||
self.waiting_for_100_continue = conn.they_are_waiting_for_100_continue
|
||||
|
||||
# Request state
|
||||
self.body = b""
|
||||
self.more_body = True
|
||||
|
||||
# Response state
|
||||
self.response_started = False
|
||||
self.response_complete = False
|
||||
|
||||
# ASGI exception wrapper
|
||||
async def run_asgi(self, app: ASGI3Application) -> None:
|
||||
try:
|
||||
result = await app( # type: ignore[func-returns-value]
|
||||
self.scope, self.receive, self.send
|
||||
)
|
||||
except BaseException as exc:
|
||||
msg = "Exception in ASGI application\n"
|
||||
self.logger.error(msg, exc_info=exc)
|
||||
if not self.response_started:
|
||||
await self.send_500_response()
|
||||
else:
|
||||
self.transport.close()
|
||||
else:
|
||||
if result is not None:
|
||||
msg = "ASGI callable should return None, but returned '%s'."
|
||||
self.logger.error(msg, result)
|
||||
self.transport.close()
|
||||
elif not self.response_started and not self.disconnected:
|
||||
msg = "ASGI callable returned without starting response."
|
||||
self.logger.error(msg)
|
||||
await self.send_500_response()
|
||||
elif not self.response_complete and not self.disconnected:
|
||||
msg = "ASGI callable returned without completing response."
|
||||
self.logger.error(msg)
|
||||
self.transport.close()
|
||||
finally:
|
||||
self.on_response = lambda: None
|
||||
|
||||
async def send_500_response(self) -> None:
|
||||
response_start_event: HTTPResponseStartEvent = {
|
||||
"type": "http.response.start",
|
||||
"status": 500,
|
||||
"headers": [
|
||||
(b"content-type", b"text/plain; charset=utf-8"),
|
||||
(b"connection", b"close"),
|
||||
],
|
||||
}
|
||||
await self.send(response_start_event)
|
||||
response_body_event: HTTPResponseBodyEvent = {
|
||||
"type": "http.response.body",
|
||||
"body": b"Internal Server Error",
|
||||
"more_body": False,
|
||||
}
|
||||
await self.send(response_body_event)
|
||||
|
||||
# ASGI interface
|
||||
async def send(self, message: ASGISendEvent) -> None:
|
||||
message_type = message["type"]
|
||||
|
||||
if self.flow.write_paused and not self.disconnected:
|
||||
await self.flow.drain() # pragma: full coverage
|
||||
|
||||
if self.disconnected:
|
||||
return # pragma: full coverage
|
||||
|
||||
if not self.response_started:
|
||||
# Sending response status line and headers
|
||||
if message_type != "http.response.start":
|
||||
msg = "Expected ASGI message 'http.response.start', but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
message = cast("HTTPResponseStartEvent", message)
|
||||
|
||||
self.response_started = True
|
||||
self.waiting_for_100_continue = False
|
||||
|
||||
status = message["status"]
|
||||
headers = self.default_headers + list(message.get("headers", []))
|
||||
|
||||
if CLOSE_HEADER in self.scope["headers"] and CLOSE_HEADER not in headers:
|
||||
headers = headers + [CLOSE_HEADER]
|
||||
|
||||
if self.access_log:
|
||||
self.access_logger.info(
|
||||
'%s - "%s %s HTTP/%s" %d',
|
||||
get_client_addr(self.scope),
|
||||
self.scope["method"],
|
||||
get_path_with_query_string(self.scope),
|
||||
self.scope["http_version"],
|
||||
status,
|
||||
)
|
||||
|
||||
# Write response status line and headers
|
||||
reason = STATUS_PHRASES[status]
|
||||
response = h11.Response(status_code=status, headers=headers, reason=reason)
|
||||
output = self.conn.send(event=response)
|
||||
self.transport.write(output)
|
||||
|
||||
elif not self.response_complete:
|
||||
# Sending response body
|
||||
if message_type != "http.response.body":
|
||||
msg = "Expected ASGI message 'http.response.body', but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
message = cast("HTTPResponseBodyEvent", message)
|
||||
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
|
||||
# Write response body
|
||||
data = b"" if self.scope["method"] == "HEAD" else body
|
||||
output = self.conn.send(event=h11.Data(data=data))
|
||||
self.transport.write(output)
|
||||
|
||||
# Handle response completion
|
||||
if not more_body:
|
||||
self.response_complete = True
|
||||
self.message_event.set()
|
||||
output = self.conn.send(event=h11.EndOfMessage())
|
||||
self.transport.write(output)
|
||||
|
||||
else:
|
||||
# Response already sent
|
||||
msg = "Unexpected ASGI message '%s' sent, after response already completed."
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
if self.response_complete:
|
||||
if self.conn.our_state is h11.MUST_CLOSE or not self.keep_alive:
|
||||
self.conn.send(event=h11.ConnectionClosed())
|
||||
self.transport.close()
|
||||
self.on_response()
|
||||
|
||||
async def receive(self) -> ASGIReceiveEvent:
|
||||
if self.waiting_for_100_continue and not self.transport.is_closing():
|
||||
headers: list[tuple[str, str]] = []
|
||||
event = h11.InformationalResponse(status_code=100, headers=headers, reason="Continue")
|
||||
output = self.conn.send(event=event)
|
||||
self.transport.write(output)
|
||||
self.waiting_for_100_continue = False
|
||||
|
||||
if not self.disconnected and not self.response_complete:
|
||||
self.flow.resume_reading()
|
||||
await self.message_event.wait()
|
||||
self.message_event.clear()
|
||||
|
||||
if self.disconnected or self.response_complete:
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
message: HTTPRequestEvent = {
|
||||
"type": "http.request",
|
||||
"body": self.body,
|
||||
"more_body": self.more_body,
|
||||
}
|
||||
self.body = b""
|
||||
return message
|
@ -0,0 +1,570 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import http
|
||||
import logging
|
||||
import re
|
||||
import urllib
|
||||
from asyncio.events import TimerHandle
|
||||
from collections import deque
|
||||
from typing import Any, Callable, Literal, cast
|
||||
|
||||
import httptools
|
||||
|
||||
from uvicorn._types import (
|
||||
ASGI3Application,
|
||||
ASGIReceiveEvent,
|
||||
ASGISendEvent,
|
||||
HTTPRequestEvent,
|
||||
HTTPResponseStartEvent,
|
||||
HTTPScope,
|
||||
)
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.logging import TRACE_LOG_LEVEL
|
||||
from uvicorn.protocols.http.flow_control import CLOSE_HEADER, HIGH_WATER_LIMIT, FlowControl, service_unavailable
|
||||
from uvicorn.protocols.utils import get_client_addr, get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl
|
||||
from uvicorn.server import ServerState
|
||||
|
||||
HEADER_RE = re.compile(b'[\x00-\x1f\x7f()<>@,;:[]={} \t\\"]')
|
||||
HEADER_VALUE_RE = re.compile(b"[\x00-\x08\x0a-\x1f\x7f]")
|
||||
|
||||
|
||||
def _get_status_line(status_code: int) -> bytes:
|
||||
try:
|
||||
phrase = http.HTTPStatus(status_code).phrase.encode()
|
||||
except ValueError:
|
||||
phrase = b""
|
||||
return b"".join([b"HTTP/1.1 ", str(status_code).encode(), b" ", phrase, b"\r\n"])
|
||||
|
||||
|
||||
STATUS_LINE = {status_code: _get_status_line(status_code) for status_code in range(100, 600)}
|
||||
|
||||
|
||||
class HttpToolsProtocol(asyncio.Protocol):
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
server_state: ServerState,
|
||||
app_state: dict[str, Any],
|
||||
_loop: asyncio.AbstractEventLoop | None = None,
|
||||
) -> None:
|
||||
if not config.loaded:
|
||||
config.load()
|
||||
|
||||
self.config = config
|
||||
self.app = config.loaded_app
|
||||
self.loop = _loop or asyncio.get_event_loop()
|
||||
self.logger = logging.getLogger("uvicorn.error")
|
||||
self.access_logger = logging.getLogger("uvicorn.access")
|
||||
self.access_log = self.access_logger.hasHandlers()
|
||||
self.parser = httptools.HttpRequestParser(self)
|
||||
|
||||
try:
|
||||
# Enable dangerous leniencies to allow server to a response on the first request from a pipelined request.
|
||||
self.parser.set_dangerous_leniencies(lenient_data_after_close=True)
|
||||
except AttributeError: # pragma: no cover
|
||||
# httptools < 0.6.3
|
||||
pass
|
||||
|
||||
self.ws_protocol_class = config.ws_protocol_class
|
||||
self.root_path = config.root_path
|
||||
self.limit_concurrency = config.limit_concurrency
|
||||
self.app_state = app_state
|
||||
|
||||
# Timeouts
|
||||
self.timeout_keep_alive_task: TimerHandle | None = None
|
||||
self.timeout_keep_alive = config.timeout_keep_alive
|
||||
|
||||
# Global state
|
||||
self.server_state = server_state
|
||||
self.connections = server_state.connections
|
||||
self.tasks = server_state.tasks
|
||||
|
||||
# Per-connection state
|
||||
self.transport: asyncio.Transport = None # type: ignore[assignment]
|
||||
self.flow: FlowControl = None # type: ignore[assignment]
|
||||
self.server: tuple[str, int] | None = None
|
||||
self.client: tuple[str, int] | None = None
|
||||
self.scheme: Literal["http", "https"] | None = None
|
||||
self.pipeline: deque[tuple[RequestResponseCycle, ASGI3Application]] = deque()
|
||||
|
||||
# Per-request state
|
||||
self.scope: HTTPScope = None # type: ignore[assignment]
|
||||
self.headers: list[tuple[bytes, bytes]] = None # type: ignore[assignment]
|
||||
self.expect_100_continue = False
|
||||
self.cycle: RequestResponseCycle = None # type: ignore[assignment]
|
||||
|
||||
# Protocol interface
|
||||
def connection_made( # type: ignore[override]
|
||||
self, transport: asyncio.Transport
|
||||
) -> None:
|
||||
self.connections.add(self)
|
||||
|
||||
self.transport = transport
|
||||
self.flow = FlowControl(transport)
|
||||
self.server = get_local_addr(transport)
|
||||
self.client = get_remote_addr(transport)
|
||||
self.scheme = "https" if is_ssl(transport) else "http"
|
||||
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix)
|
||||
|
||||
def connection_lost(self, exc: Exception | None) -> None:
|
||||
self.connections.discard(self)
|
||||
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection lost", prefix)
|
||||
|
||||
if self.cycle and not self.cycle.response_complete:
|
||||
self.cycle.disconnected = True
|
||||
if self.cycle is not None:
|
||||
self.cycle.message_event.set()
|
||||
if self.flow is not None:
|
||||
self.flow.resume_writing()
|
||||
if exc is None:
|
||||
self.transport.close()
|
||||
self._unset_keepalive_if_required()
|
||||
|
||||
self.parser = None
|
||||
|
||||
def eof_received(self) -> None:
|
||||
pass
|
||||
|
||||
def _unset_keepalive_if_required(self) -> None:
|
||||
if self.timeout_keep_alive_task is not None:
|
||||
self.timeout_keep_alive_task.cancel()
|
||||
self.timeout_keep_alive_task = None
|
||||
|
||||
def _get_upgrade(self) -> bytes | None:
|
||||
connection = []
|
||||
upgrade = None
|
||||
for name, value in self.headers:
|
||||
if name == b"connection":
|
||||
connection = [token.lower().strip() for token in value.split(b",")]
|
||||
if name == b"upgrade":
|
||||
upgrade = value.lower()
|
||||
if b"upgrade" in connection:
|
||||
return upgrade
|
||||
return None # pragma: full coverage
|
||||
|
||||
def _should_upgrade_to_ws(self) -> bool:
|
||||
if self.ws_protocol_class is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _unsupported_upgrade_warning(self) -> None:
|
||||
self.logger.warning("Unsupported upgrade request.")
|
||||
if not self._should_upgrade_to_ws():
|
||||
msg = "No supported WebSocket library detected. Please use \"pip install 'uvicorn[standard]'\", or install 'websockets' or 'wsproto' manually." # noqa: E501
|
||||
self.logger.warning(msg)
|
||||
|
||||
def _should_upgrade(self) -> bool:
|
||||
upgrade = self._get_upgrade()
|
||||
return upgrade == b"websocket" and self._should_upgrade_to_ws()
|
||||
|
||||
def data_received(self, data: bytes) -> None:
|
||||
self._unset_keepalive_if_required()
|
||||
|
||||
try:
|
||||
self.parser.feed_data(data)
|
||||
except httptools.HttpParserError:
|
||||
msg = "Invalid HTTP request received."
|
||||
self.logger.warning(msg)
|
||||
self.send_400_response(msg)
|
||||
return
|
||||
except httptools.HttpParserUpgrade:
|
||||
if self._should_upgrade():
|
||||
self.handle_websocket_upgrade()
|
||||
else:
|
||||
self._unsupported_upgrade_warning()
|
||||
|
||||
def handle_websocket_upgrade(self) -> None:
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix)
|
||||
|
||||
self.connections.discard(self)
|
||||
method = self.scope["method"].encode()
|
||||
output = [method, b" ", self.url, b" HTTP/1.1\r\n"]
|
||||
for name, value in self.scope["headers"]:
|
||||
output += [name, b": ", value, b"\r\n"]
|
||||
output.append(b"\r\n")
|
||||
protocol = self.ws_protocol_class( # type: ignore[call-arg, misc]
|
||||
config=self.config,
|
||||
server_state=self.server_state,
|
||||
app_state=self.app_state,
|
||||
)
|
||||
protocol.connection_made(self.transport)
|
||||
protocol.data_received(b"".join(output))
|
||||
self.transport.set_protocol(protocol)
|
||||
|
||||
def send_400_response(self, msg: str) -> None:
|
||||
content = [STATUS_LINE[400]]
|
||||
for name, value in self.server_state.default_headers:
|
||||
content.extend([name, b": ", value, b"\r\n"]) # pragma: full coverage
|
||||
content.extend(
|
||||
[
|
||||
b"content-type: text/plain; charset=utf-8\r\n",
|
||||
b"content-length: " + str(len(msg)).encode("ascii") + b"\r\n",
|
||||
b"connection: close\r\n",
|
||||
b"\r\n",
|
||||
msg.encode("ascii"),
|
||||
]
|
||||
)
|
||||
self.transport.write(b"".join(content))
|
||||
self.transport.close()
|
||||
|
||||
def on_message_begin(self) -> None:
|
||||
self.url = b""
|
||||
self.expect_100_continue = False
|
||||
self.headers = []
|
||||
self.scope = { # type: ignore[typeddict-item]
|
||||
"type": "http",
|
||||
"asgi": {"version": self.config.asgi_version, "spec_version": "2.3"},
|
||||
"http_version": "1.1",
|
||||
"server": self.server,
|
||||
"client": self.client,
|
||||
"scheme": self.scheme, # type: ignore[typeddict-item]
|
||||
"root_path": self.root_path,
|
||||
"headers": self.headers,
|
||||
"state": self.app_state.copy(),
|
||||
}
|
||||
|
||||
# Parser callbacks
|
||||
def on_url(self, url: bytes) -> None:
|
||||
self.url += url
|
||||
|
||||
def on_header(self, name: bytes, value: bytes) -> None:
|
||||
name = name.lower()
|
||||
if name == b"expect" and value.lower() == b"100-continue":
|
||||
self.expect_100_continue = True
|
||||
self.headers.append((name, value))
|
||||
|
||||
def on_headers_complete(self) -> None:
|
||||
http_version = self.parser.get_http_version()
|
||||
method = self.parser.get_method()
|
||||
self.scope["method"] = method.decode("ascii")
|
||||
if http_version != "1.1":
|
||||
self.scope["http_version"] = http_version
|
||||
if self.parser.should_upgrade() and self._should_upgrade():
|
||||
return
|
||||
parsed_url = httptools.parse_url(self.url)
|
||||
raw_path = parsed_url.path
|
||||
path = raw_path.decode("ascii")
|
||||
if "%" in path:
|
||||
path = urllib.parse.unquote(path)
|
||||
full_path = self.root_path + path
|
||||
full_raw_path = self.root_path.encode("ascii") + raw_path
|
||||
self.scope["path"] = full_path
|
||||
self.scope["raw_path"] = full_raw_path
|
||||
self.scope["query_string"] = parsed_url.query or b""
|
||||
|
||||
# Handle 503 responses when 'limit_concurrency' is exceeded.
|
||||
if self.limit_concurrency is not None and (
|
||||
len(self.connections) >= self.limit_concurrency or len(self.tasks) >= self.limit_concurrency
|
||||
):
|
||||
app = service_unavailable
|
||||
message = "Exceeded concurrency limit."
|
||||
self.logger.warning(message)
|
||||
else:
|
||||
app = self.app
|
||||
|
||||
existing_cycle = self.cycle
|
||||
self.cycle = RequestResponseCycle(
|
||||
scope=self.scope,
|
||||
transport=self.transport,
|
||||
flow=self.flow,
|
||||
logger=self.logger,
|
||||
access_logger=self.access_logger,
|
||||
access_log=self.access_log,
|
||||
default_headers=self.server_state.default_headers,
|
||||
message_event=asyncio.Event(),
|
||||
expect_100_continue=self.expect_100_continue,
|
||||
keep_alive=http_version != "1.0",
|
||||
on_response=self.on_response_complete,
|
||||
)
|
||||
if existing_cycle is None or existing_cycle.response_complete:
|
||||
# Standard case - start processing the request.
|
||||
task = self.loop.create_task(self.cycle.run_asgi(app))
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
self.tasks.add(task)
|
||||
else:
|
||||
# Pipelined HTTP requests need to be queued up.
|
||||
self.flow.pause_reading()
|
||||
self.pipeline.appendleft((self.cycle, app))
|
||||
|
||||
def on_body(self, body: bytes) -> None:
|
||||
if (self.parser.should_upgrade() and self._should_upgrade()) or self.cycle.response_complete:
|
||||
return
|
||||
self.cycle.body += body
|
||||
if len(self.cycle.body) > HIGH_WATER_LIMIT:
|
||||
self.flow.pause_reading()
|
||||
self.cycle.message_event.set()
|
||||
|
||||
def on_message_complete(self) -> None:
|
||||
if (self.parser.should_upgrade() and self._should_upgrade()) or self.cycle.response_complete:
|
||||
return
|
||||
self.cycle.more_body = False
|
||||
self.cycle.message_event.set()
|
||||
|
||||
def on_response_complete(self) -> None:
|
||||
# Callback for pipelined HTTP requests to be started.
|
||||
self.server_state.total_requests += 1
|
||||
|
||||
if self.transport.is_closing():
|
||||
return
|
||||
|
||||
self._unset_keepalive_if_required()
|
||||
|
||||
# Unpause data reads if needed.
|
||||
self.flow.resume_reading()
|
||||
|
||||
# Unblock any pipelined events. If there are none, arm the
|
||||
# Keep-Alive timeout instead.
|
||||
if self.pipeline:
|
||||
cycle, app = self.pipeline.pop()
|
||||
task = self.loop.create_task(cycle.run_asgi(app))
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
self.tasks.add(task)
|
||||
else:
|
||||
self.timeout_keep_alive_task = self.loop.call_later(
|
||||
self.timeout_keep_alive, self.timeout_keep_alive_handler
|
||||
)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""
|
||||
Called by the server to commence a graceful shutdown.
|
||||
"""
|
||||
if self.cycle is None or self.cycle.response_complete:
|
||||
self.transport.close()
|
||||
else:
|
||||
self.cycle.keep_alive = False
|
||||
|
||||
def pause_writing(self) -> None:
|
||||
"""
|
||||
Called by the transport when the write buffer exceeds the high water mark.
|
||||
"""
|
||||
self.flow.pause_writing() # pragma: full coverage
|
||||
|
||||
def resume_writing(self) -> None:
|
||||
"""
|
||||
Called by the transport when the write buffer drops below the low water mark.
|
||||
"""
|
||||
self.flow.resume_writing() # pragma: full coverage
|
||||
|
||||
def timeout_keep_alive_handler(self) -> None:
|
||||
"""
|
||||
Called on a keep-alive connection if no new data is received after a short
|
||||
delay.
|
||||
"""
|
||||
if not self.transport.is_closing():
|
||||
self.transport.close()
|
||||
|
||||
|
||||
class RequestResponseCycle:
|
||||
def __init__(
|
||||
self,
|
||||
scope: HTTPScope,
|
||||
transport: asyncio.Transport,
|
||||
flow: FlowControl,
|
||||
logger: logging.Logger,
|
||||
access_logger: logging.Logger,
|
||||
access_log: bool,
|
||||
default_headers: list[tuple[bytes, bytes]],
|
||||
message_event: asyncio.Event,
|
||||
expect_100_continue: bool,
|
||||
keep_alive: bool,
|
||||
on_response: Callable[..., None],
|
||||
):
|
||||
self.scope = scope
|
||||
self.transport = transport
|
||||
self.flow = flow
|
||||
self.logger = logger
|
||||
self.access_logger = access_logger
|
||||
self.access_log = access_log
|
||||
self.default_headers = default_headers
|
||||
self.message_event = message_event
|
||||
self.on_response = on_response
|
||||
|
||||
# Connection state
|
||||
self.disconnected = False
|
||||
self.keep_alive = keep_alive
|
||||
self.waiting_for_100_continue = expect_100_continue
|
||||
|
||||
# Request state
|
||||
self.body = b""
|
||||
self.more_body = True
|
||||
|
||||
# Response state
|
||||
self.response_started = False
|
||||
self.response_complete = False
|
||||
self.chunked_encoding: bool | None = None
|
||||
self.expected_content_length = 0
|
||||
|
||||
# ASGI exception wrapper
|
||||
async def run_asgi(self, app: ASGI3Application) -> None:
|
||||
try:
|
||||
result = await app( # type: ignore[func-returns-value]
|
||||
self.scope, self.receive, self.send
|
||||
)
|
||||
except BaseException as exc:
|
||||
msg = "Exception in ASGI application\n"
|
||||
self.logger.error(msg, exc_info=exc)
|
||||
if not self.response_started:
|
||||
await self.send_500_response()
|
||||
else:
|
||||
self.transport.close()
|
||||
else:
|
||||
if result is not None:
|
||||
msg = "ASGI callable should return None, but returned '%s'."
|
||||
self.logger.error(msg, result)
|
||||
self.transport.close()
|
||||
elif not self.response_started and not self.disconnected:
|
||||
msg = "ASGI callable returned without starting response."
|
||||
self.logger.error(msg)
|
||||
await self.send_500_response()
|
||||
elif not self.response_complete and not self.disconnected:
|
||||
msg = "ASGI callable returned without completing response."
|
||||
self.logger.error(msg)
|
||||
self.transport.close()
|
||||
finally:
|
||||
self.on_response = lambda: None
|
||||
|
||||
async def send_500_response(self) -> None:
|
||||
await self.send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 500,
|
||||
"headers": [
|
||||
(b"content-type", b"text/plain; charset=utf-8"),
|
||||
(b"content-length", b"21"),
|
||||
(b"connection", b"close"),
|
||||
],
|
||||
}
|
||||
)
|
||||
await self.send({"type": "http.response.body", "body": b"Internal Server Error", "more_body": False})
|
||||
|
||||
# ASGI interface
|
||||
async def send(self, message: ASGISendEvent) -> None:
|
||||
message_type = message["type"]
|
||||
|
||||
if self.flow.write_paused and not self.disconnected:
|
||||
await self.flow.drain() # pragma: full coverage
|
||||
|
||||
if self.disconnected:
|
||||
return # pragma: full coverage
|
||||
|
||||
if not self.response_started:
|
||||
# Sending response status line and headers
|
||||
if message_type != "http.response.start":
|
||||
msg = "Expected ASGI message 'http.response.start', but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
message = cast("HTTPResponseStartEvent", message)
|
||||
|
||||
self.response_started = True
|
||||
self.waiting_for_100_continue = False
|
||||
|
||||
status_code = message["status"]
|
||||
headers = self.default_headers + list(message.get("headers", []))
|
||||
|
||||
if CLOSE_HEADER in self.scope["headers"] and CLOSE_HEADER not in headers:
|
||||
headers = headers + [CLOSE_HEADER]
|
||||
|
||||
if self.access_log:
|
||||
self.access_logger.info(
|
||||
'%s - "%s %s HTTP/%s" %d',
|
||||
get_client_addr(self.scope),
|
||||
self.scope["method"],
|
||||
get_path_with_query_string(self.scope),
|
||||
self.scope["http_version"],
|
||||
status_code,
|
||||
)
|
||||
|
||||
# Write response status line and headers
|
||||
content = [STATUS_LINE[status_code]]
|
||||
|
||||
for name, value in headers:
|
||||
if HEADER_RE.search(name):
|
||||
raise RuntimeError("Invalid HTTP header name.") # pragma: full coverage
|
||||
if HEADER_VALUE_RE.search(value):
|
||||
raise RuntimeError("Invalid HTTP header value.")
|
||||
|
||||
name = name.lower()
|
||||
if name == b"content-length" and self.chunked_encoding is None:
|
||||
self.expected_content_length = int(value.decode())
|
||||
self.chunked_encoding = False
|
||||
elif name == b"transfer-encoding" and value.lower() == b"chunked":
|
||||
self.expected_content_length = 0
|
||||
self.chunked_encoding = True
|
||||
elif name == b"connection" and value.lower() == b"close":
|
||||
self.keep_alive = False
|
||||
content.extend([name, b": ", value, b"\r\n"])
|
||||
|
||||
if self.chunked_encoding is None and self.scope["method"] != "HEAD" and status_code not in (204, 304):
|
||||
# Neither content-length nor transfer-encoding specified
|
||||
self.chunked_encoding = True
|
||||
content.append(b"transfer-encoding: chunked\r\n")
|
||||
|
||||
content.append(b"\r\n")
|
||||
self.transport.write(b"".join(content))
|
||||
|
||||
elif not self.response_complete:
|
||||
# Sending response body
|
||||
if message_type != "http.response.body":
|
||||
msg = "Expected ASGI message 'http.response.body', but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
body = cast(bytes, message.get("body", b""))
|
||||
more_body = message.get("more_body", False)
|
||||
|
||||
# Write response body
|
||||
if self.scope["method"] == "HEAD":
|
||||
self.expected_content_length = 0
|
||||
elif self.chunked_encoding:
|
||||
if body:
|
||||
content = [b"%x\r\n" % len(body), body, b"\r\n"]
|
||||
else:
|
||||
content = []
|
||||
if not more_body:
|
||||
content.append(b"0\r\n\r\n")
|
||||
self.transport.write(b"".join(content))
|
||||
else:
|
||||
num_bytes = len(body)
|
||||
if num_bytes > self.expected_content_length:
|
||||
raise RuntimeError("Response content longer than Content-Length")
|
||||
else:
|
||||
self.expected_content_length -= num_bytes
|
||||
self.transport.write(body)
|
||||
|
||||
# Handle response completion
|
||||
if not more_body:
|
||||
if self.expected_content_length != 0:
|
||||
raise RuntimeError("Response content shorter than Content-Length")
|
||||
self.response_complete = True
|
||||
self.message_event.set()
|
||||
if not self.keep_alive:
|
||||
self.transport.close()
|
||||
self.on_response()
|
||||
|
||||
else:
|
||||
# Response already sent
|
||||
msg = "Unexpected ASGI message '%s' sent, after response already completed."
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
async def receive(self) -> ASGIReceiveEvent:
|
||||
if self.waiting_for_100_continue and not self.transport.is_closing():
|
||||
self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
|
||||
self.waiting_for_100_continue = False
|
||||
|
||||
if not self.disconnected and not self.response_complete:
|
||||
self.flow.resume_reading()
|
||||
await self.message_event.wait()
|
||||
self.message_event.clear()
|
||||
|
||||
if self.disconnected or self.response_complete:
|
||||
return {"type": "http.disconnect"}
|
||||
message: HTTPRequestEvent = {"type": "http.request", "body": self.body, "more_body": self.more_body}
|
||||
self.body = b""
|
||||
return message
|
56
venv/lib/python3.11/site-packages/uvicorn/protocols/utils.py
Normal file
56
venv/lib/python3.11/site-packages/uvicorn/protocols/utils.py
Normal file
@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import urllib.parse
|
||||
|
||||
from uvicorn._types import WWWScope
|
||||
|
||||
|
||||
class ClientDisconnected(OSError): ...
|
||||
|
||||
|
||||
def get_remote_addr(transport: asyncio.Transport) -> tuple[str, int] | None:
|
||||
socket_info = transport.get_extra_info("socket")
|
||||
if socket_info is not None:
|
||||
try:
|
||||
info = socket_info.getpeername()
|
||||
return (str(info[0]), int(info[1])) if isinstance(info, tuple) else None
|
||||
except OSError: # pragma: no cover
|
||||
# This case appears to inconsistently occur with uvloop
|
||||
# bound to a unix domain socket.
|
||||
return None
|
||||
|
||||
info = transport.get_extra_info("peername")
|
||||
if info is not None and isinstance(info, (list, tuple)) and len(info) == 2:
|
||||
return (str(info[0]), int(info[1]))
|
||||
return None
|
||||
|
||||
|
||||
def get_local_addr(transport: asyncio.Transport) -> tuple[str, int] | None:
|
||||
socket_info = transport.get_extra_info("socket")
|
||||
if socket_info is not None:
|
||||
info = socket_info.getsockname()
|
||||
|
||||
return (str(info[0]), int(info[1])) if isinstance(info, tuple) else None
|
||||
info = transport.get_extra_info("sockname")
|
||||
if info is not None and isinstance(info, (list, tuple)) and len(info) == 2:
|
||||
return (str(info[0]), int(info[1]))
|
||||
return None
|
||||
|
||||
|
||||
def is_ssl(transport: asyncio.Transport) -> bool:
|
||||
return bool(transport.get_extra_info("sslcontext"))
|
||||
|
||||
|
||||
def get_client_addr(scope: WWWScope) -> str:
|
||||
client = scope.get("client")
|
||||
if not client:
|
||||
return ""
|
||||
return "%s:%d" % client
|
||||
|
||||
|
||||
def get_path_with_query_string(scope: WWWScope) -> str:
|
||||
path_with_query_string = urllib.parse.quote(scope["path"])
|
||||
if scope["query_string"]:
|
||||
path_with_query_string = "{}?{}".format(path_with_query_string, scope["query_string"].decode("ascii"))
|
||||
return path_with_query_string
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import typing
|
||||
|
||||
AutoWebSocketsProtocol: typing.Callable[..., asyncio.Protocol] | None
|
||||
try:
|
||||
import websockets # noqa
|
||||
except ImportError: # pragma: no cover
|
||||
try:
|
||||
import wsproto # noqa
|
||||
except ImportError:
|
||||
AutoWebSocketsProtocol = None
|
||||
else:
|
||||
from uvicorn.protocols.websockets.wsproto_impl import WSProtocol
|
||||
|
||||
AutoWebSocketsProtocol = WSProtocol
|
||||
else:
|
||||
from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
|
||||
|
||||
AutoWebSocketsProtocol = WebSocketProtocol
|
@ -0,0 +1,386 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import http
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal, Optional, cast
|
||||
from urllib.parse import unquote
|
||||
|
||||
import websockets
|
||||
import websockets.legacy.handshake
|
||||
from websockets.datastructures import Headers
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
from websockets.extensions.base import ServerExtensionFactory
|
||||
from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory
|
||||
from websockets.legacy.server import HTTPResponse
|
||||
from websockets.server import WebSocketServerProtocol
|
||||
from websockets.typing import Subprotocol
|
||||
|
||||
from uvicorn._types import (
|
||||
ASGI3Application,
|
||||
ASGISendEvent,
|
||||
WebSocketAcceptEvent,
|
||||
WebSocketCloseEvent,
|
||||
WebSocketConnectEvent,
|
||||
WebSocketDisconnectEvent,
|
||||
WebSocketReceiveEvent,
|
||||
WebSocketResponseBodyEvent,
|
||||
WebSocketResponseStartEvent,
|
||||
WebSocketScope,
|
||||
WebSocketSendEvent,
|
||||
)
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.logging import TRACE_LOG_LEVEL
|
||||
from uvicorn.protocols.utils import (
|
||||
ClientDisconnected,
|
||||
get_local_addr,
|
||||
get_path_with_query_string,
|
||||
get_remote_addr,
|
||||
is_ssl,
|
||||
)
|
||||
from uvicorn.server import ServerState
|
||||
|
||||
|
||||
class Server:
|
||||
closing = False
|
||||
|
||||
def register(self, ws: WebSocketServerProtocol) -> None:
|
||||
pass
|
||||
|
||||
def unregister(self, ws: WebSocketServerProtocol) -> None:
|
||||
pass
|
||||
|
||||
def is_serving(self) -> bool:
|
||||
return not self.closing
|
||||
|
||||
|
||||
class WebSocketProtocol(WebSocketServerProtocol):
|
||||
extra_headers: list[tuple[str, str]]
|
||||
logger: logging.Logger | logging.LoggerAdapter[Any]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
server_state: ServerState,
|
||||
app_state: dict[str, Any],
|
||||
_loop: asyncio.AbstractEventLoop | None = None,
|
||||
):
|
||||
if not config.loaded:
|
||||
config.load()
|
||||
|
||||
self.config = config
|
||||
self.app = cast(ASGI3Application, config.loaded_app)
|
||||
self.loop = _loop or asyncio.get_event_loop()
|
||||
self.root_path = config.root_path
|
||||
self.app_state = app_state
|
||||
|
||||
# Shared server state
|
||||
self.connections = server_state.connections
|
||||
self.tasks = server_state.tasks
|
||||
|
||||
# Connection state
|
||||
self.transport: asyncio.Transport = None # type: ignore[assignment]
|
||||
self.server: tuple[str, int] | None = None
|
||||
self.client: tuple[str, int] | None = None
|
||||
self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment]
|
||||
|
||||
# Connection events
|
||||
self.scope: WebSocketScope
|
||||
self.handshake_started_event = asyncio.Event()
|
||||
self.handshake_completed_event = asyncio.Event()
|
||||
self.closed_event = asyncio.Event()
|
||||
self.initial_response: HTTPResponse | None = None
|
||||
self.connect_sent = False
|
||||
self.lost_connection_before_handshake = False
|
||||
self.accepted_subprotocol: Subprotocol | None = None
|
||||
|
||||
self.ws_server: Server = Server() # type: ignore[assignment]
|
||||
|
||||
extensions: list[ServerExtensionFactory] = []
|
||||
if self.config.ws_per_message_deflate:
|
||||
extensions.append(ServerPerMessageDeflateFactory())
|
||||
|
||||
super().__init__(
|
||||
ws_handler=self.ws_handler,
|
||||
ws_server=self.ws_server, # type: ignore[arg-type]
|
||||
max_size=self.config.ws_max_size,
|
||||
max_queue=self.config.ws_max_queue,
|
||||
ping_interval=self.config.ws_ping_interval,
|
||||
ping_timeout=self.config.ws_ping_timeout,
|
||||
extensions=extensions,
|
||||
logger=logging.getLogger("uvicorn.error"),
|
||||
)
|
||||
self.server_header = None
|
||||
self.extra_headers = [
|
||||
(name.decode("latin-1"), value.decode("latin-1")) for name, value in server_state.default_headers
|
||||
]
|
||||
|
||||
def connection_made( # type: ignore[override]
|
||||
self, transport: asyncio.Transport
|
||||
) -> None:
|
||||
self.connections.add(self)
|
||||
self.transport = transport
|
||||
self.server = get_local_addr(transport)
|
||||
self.client = get_remote_addr(transport)
|
||||
self.scheme = "wss" if is_ssl(transport) else "ws"
|
||||
|
||||
if self.logger.isEnabledFor(TRACE_LOG_LEVEL):
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)
|
||||
|
||||
super().connection_made(transport)
|
||||
|
||||
def connection_lost(self, exc: Exception | None) -> None:
|
||||
self.connections.remove(self)
|
||||
|
||||
if self.logger.isEnabledFor(TRACE_LOG_LEVEL):
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)
|
||||
|
||||
self.lost_connection_before_handshake = not self.handshake_completed_event.is_set()
|
||||
self.handshake_completed_event.set()
|
||||
super().connection_lost(exc)
|
||||
if exc is None:
|
||||
self.transport.close()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self.ws_server.closing = True
|
||||
if self.handshake_completed_event.is_set():
|
||||
self.fail_connection(1012)
|
||||
else:
|
||||
self.send_500_response()
|
||||
self.transport.close()
|
||||
|
||||
def on_task_complete(self, task: asyncio.Task[None]) -> None:
|
||||
self.tasks.discard(task)
|
||||
|
||||
async def process_request(self, path: str, request_headers: Headers) -> HTTPResponse | None:
|
||||
"""
|
||||
This hook is called to determine if the websocket should return
|
||||
an HTTP response and close.
|
||||
|
||||
Our behavior here is to start the ASGI application, and then wait
|
||||
for either `accept` or `close` in order to determine if we should
|
||||
close the connection.
|
||||
"""
|
||||
path_portion, _, query_string = path.partition("?")
|
||||
|
||||
websockets.legacy.handshake.check_request(request_headers)
|
||||
|
||||
subprotocols: list[str] = []
|
||||
for header in request_headers.get_all("Sec-WebSocket-Protocol"):
|
||||
subprotocols.extend([token.strip() for token in header.split(",")])
|
||||
|
||||
asgi_headers = [
|
||||
(name.encode("ascii"), value.encode("ascii", errors="surrogateescape"))
|
||||
for name, value in request_headers.raw_items()
|
||||
]
|
||||
path = unquote(path_portion)
|
||||
full_path = self.root_path + path
|
||||
full_raw_path = self.root_path.encode("ascii") + path_portion.encode("ascii")
|
||||
|
||||
self.scope = {
|
||||
"type": "websocket",
|
||||
"asgi": {"version": self.config.asgi_version, "spec_version": "2.4"},
|
||||
"http_version": "1.1",
|
||||
"scheme": self.scheme,
|
||||
"server": self.server,
|
||||
"client": self.client,
|
||||
"root_path": self.root_path,
|
||||
"path": full_path,
|
||||
"raw_path": full_raw_path,
|
||||
"query_string": query_string.encode("ascii"),
|
||||
"headers": asgi_headers,
|
||||
"subprotocols": subprotocols,
|
||||
"state": self.app_state.copy(),
|
||||
"extensions": {"websocket.http.response": {}},
|
||||
}
|
||||
task = self.loop.create_task(self.run_asgi())
|
||||
task.add_done_callback(self.on_task_complete)
|
||||
self.tasks.add(task)
|
||||
await self.handshake_started_event.wait()
|
||||
return self.initial_response
|
||||
|
||||
def process_subprotocol(
|
||||
self, headers: Headers, available_subprotocols: Sequence[Subprotocol] | None
|
||||
) -> Subprotocol | None:
|
||||
"""
|
||||
We override the standard 'process_subprotocol' behavior here so that
|
||||
we return whatever subprotocol is sent in the 'accept' message.
|
||||
"""
|
||||
return self.accepted_subprotocol
|
||||
|
||||
def send_500_response(self) -> None:
|
||||
msg = b"Internal Server Error"
|
||||
content = [
|
||||
b"HTTP/1.1 500 Internal Server Error\r\n" b"content-type: text/plain; charset=utf-8\r\n",
|
||||
b"content-length: " + str(len(msg)).encode("ascii") + b"\r\n",
|
||||
b"connection: close\r\n",
|
||||
b"\r\n",
|
||||
msg,
|
||||
]
|
||||
self.transport.write(b"".join(content))
|
||||
# Allow handler task to terminate cleanly, as websockets doesn't cancel it by
|
||||
# itself (see https://github.com/encode/uvicorn/issues/920)
|
||||
self.handshake_started_event.set()
|
||||
|
||||
async def ws_handler(self, protocol: WebSocketServerProtocol, path: str) -> Any: # type: ignore[override]
|
||||
"""
|
||||
This is the main handler function for the 'websockets' implementation
|
||||
to call into. We just wait for close then return, and instead allow
|
||||
'send' and 'receive' events to drive the flow.
|
||||
"""
|
||||
self.handshake_completed_event.set()
|
||||
await self.wait_closed()
|
||||
|
||||
async def run_asgi(self) -> None:
|
||||
"""
|
||||
Wrapper around the ASGI callable, handling exceptions and unexpected
|
||||
termination states.
|
||||
"""
|
||||
try:
|
||||
result = await self.app(self.scope, self.asgi_receive, self.asgi_send) # type: ignore[func-returns-value]
|
||||
except ClientDisconnected: # pragma: full coverage
|
||||
self.closed_event.set()
|
||||
self.transport.close()
|
||||
except BaseException:
|
||||
self.closed_event.set()
|
||||
self.logger.exception("Exception in ASGI application\n")
|
||||
if not self.handshake_started_event.is_set():
|
||||
self.send_500_response()
|
||||
else:
|
||||
await self.handshake_completed_event.wait()
|
||||
self.transport.close()
|
||||
else:
|
||||
self.closed_event.set()
|
||||
if not self.handshake_started_event.is_set():
|
||||
self.logger.error("ASGI callable returned without sending handshake.")
|
||||
self.send_500_response()
|
||||
self.transport.close()
|
||||
elif result is not None:
|
||||
self.logger.error("ASGI callable should return None, but returned '%s'.", result)
|
||||
await self.handshake_completed_event.wait()
|
||||
self.transport.close()
|
||||
|
||||
async def asgi_send(self, message: ASGISendEvent) -> None:
|
||||
message_type = message["type"]
|
||||
|
||||
if not self.handshake_started_event.is_set():
|
||||
if message_type == "websocket.accept":
|
||||
message = cast("WebSocketAcceptEvent", message)
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" [accepted]',
|
||||
self.scope["client"],
|
||||
get_path_with_query_string(self.scope),
|
||||
)
|
||||
self.initial_response = None
|
||||
self.accepted_subprotocol = cast(Optional[Subprotocol], message.get("subprotocol"))
|
||||
if "headers" in message:
|
||||
self.extra_headers.extend(
|
||||
# ASGI spec requires bytes
|
||||
# But for compatibility we need to convert it to strings
|
||||
(name.decode("latin-1"), value.decode("latin-1"))
|
||||
for name, value in message["headers"]
|
||||
)
|
||||
self.handshake_started_event.set()
|
||||
|
||||
elif message_type == "websocket.close":
|
||||
message = cast("WebSocketCloseEvent", message)
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" 403',
|
||||
self.scope["client"],
|
||||
get_path_with_query_string(self.scope),
|
||||
)
|
||||
self.initial_response = (http.HTTPStatus.FORBIDDEN, [], b"")
|
||||
self.handshake_started_event.set()
|
||||
self.closed_event.set()
|
||||
|
||||
elif message_type == "websocket.http.response.start":
|
||||
message = cast("WebSocketResponseStartEvent", message)
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" %d',
|
||||
self.scope["client"],
|
||||
get_path_with_query_string(self.scope),
|
||||
message["status"],
|
||||
)
|
||||
# websockets requires the status to be an enum. look it up.
|
||||
status = http.HTTPStatus(message["status"])
|
||||
headers = [
|
||||
(name.decode("latin-1"), value.decode("latin-1")) for name, value in message.get("headers", [])
|
||||
]
|
||||
self.initial_response = (status, headers, b"")
|
||||
self.handshake_started_event.set()
|
||||
|
||||
else:
|
||||
msg = (
|
||||
"Expected ASGI message 'websocket.accept', 'websocket.close', "
|
||||
"or 'websocket.http.response.start' but got '%s'."
|
||||
)
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
elif not self.closed_event.is_set() and self.initial_response is None:
|
||||
await self.handshake_completed_event.wait()
|
||||
|
||||
try:
|
||||
if message_type == "websocket.send":
|
||||
message = cast("WebSocketSendEvent", message)
|
||||
bytes_data = message.get("bytes")
|
||||
text_data = message.get("text")
|
||||
data = text_data if bytes_data is None else bytes_data
|
||||
await self.send(data) # type: ignore[arg-type]
|
||||
|
||||
elif message_type == "websocket.close":
|
||||
message = cast("WebSocketCloseEvent", message)
|
||||
code = message.get("code", 1000)
|
||||
reason = message.get("reason", "") or ""
|
||||
await self.close(code, reason)
|
||||
self.closed_event.set()
|
||||
|
||||
else:
|
||||
msg = "Expected ASGI message 'websocket.send' or 'websocket.close'," " but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
except ConnectionClosed as exc:
|
||||
raise ClientDisconnected from exc
|
||||
|
||||
elif self.initial_response is not None:
|
||||
if message_type == "websocket.http.response.body":
|
||||
message = cast("WebSocketResponseBodyEvent", message)
|
||||
body = self.initial_response[2] + message["body"]
|
||||
self.initial_response = self.initial_response[:2] + (body,)
|
||||
if not message.get("more_body", False):
|
||||
self.closed_event.set()
|
||||
else:
|
||||
msg = "Expected ASGI message 'websocket.http.response.body' " "but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
else:
|
||||
msg = "Unexpected ASGI message '%s', after sending 'websocket.close' " "or response already completed."
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
async def asgi_receive(self) -> WebSocketDisconnectEvent | WebSocketConnectEvent | WebSocketReceiveEvent:
|
||||
if not self.connect_sent:
|
||||
self.connect_sent = True
|
||||
return {"type": "websocket.connect"}
|
||||
|
||||
await self.handshake_completed_event.wait()
|
||||
|
||||
if self.lost_connection_before_handshake:
|
||||
# If the handshake failed or the app closed before handshake completion,
|
||||
# use 1006 Abnormal Closure.
|
||||
return {"type": "websocket.disconnect", "code": 1006}
|
||||
|
||||
if self.closed_event.is_set():
|
||||
return {"type": "websocket.disconnect", "code": 1005}
|
||||
|
||||
try:
|
||||
data = await self.recv()
|
||||
except ConnectionClosed:
|
||||
self.closed_event.set()
|
||||
if self.ws_server.closing:
|
||||
return {"type": "websocket.disconnect", "code": 1012}
|
||||
return {"type": "websocket.disconnect", "code": self.close_code or 1005, "reason": self.close_reason}
|
||||
|
||||
if isinstance(data, str):
|
||||
return {"type": "websocket.receive", "text": data}
|
||||
return {"type": "websocket.receive", "bytes": data}
|
@ -0,0 +1,377 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import typing
|
||||
from typing import Literal, cast
|
||||
from urllib.parse import unquote
|
||||
|
||||
import wsproto
|
||||
from wsproto import ConnectionType, events
|
||||
from wsproto.connection import ConnectionState
|
||||
from wsproto.extensions import Extension, PerMessageDeflate
|
||||
from wsproto.utilities import LocalProtocolError, RemoteProtocolError
|
||||
|
||||
from uvicorn._types import (
|
||||
ASGI3Application,
|
||||
ASGISendEvent,
|
||||
WebSocketAcceptEvent,
|
||||
WebSocketCloseEvent,
|
||||
WebSocketEvent,
|
||||
WebSocketResponseBodyEvent,
|
||||
WebSocketResponseStartEvent,
|
||||
WebSocketScope,
|
||||
WebSocketSendEvent,
|
||||
)
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.logging import TRACE_LOG_LEVEL
|
||||
from uvicorn.protocols.utils import (
|
||||
ClientDisconnected,
|
||||
get_local_addr,
|
||||
get_path_with_query_string,
|
||||
get_remote_addr,
|
||||
is_ssl,
|
||||
)
|
||||
from uvicorn.server import ServerState
|
||||
|
||||
|
||||
class WSProtocol(asyncio.Protocol):
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
server_state: ServerState,
|
||||
app_state: dict[str, typing.Any],
|
||||
_loop: asyncio.AbstractEventLoop | None = None,
|
||||
) -> None:
|
||||
if not config.loaded:
|
||||
config.load() # pragma: full coverage
|
||||
|
||||
self.config = config
|
||||
self.app = cast(ASGI3Application, config.loaded_app)
|
||||
self.loop = _loop or asyncio.get_event_loop()
|
||||
self.logger = logging.getLogger("uvicorn.error")
|
||||
self.root_path = config.root_path
|
||||
self.app_state = app_state
|
||||
|
||||
# Shared server state
|
||||
self.connections = server_state.connections
|
||||
self.tasks = server_state.tasks
|
||||
self.default_headers = server_state.default_headers
|
||||
|
||||
# Connection state
|
||||
self.transport: asyncio.Transport = None # type: ignore[assignment]
|
||||
self.server: tuple[str, int] | None = None
|
||||
self.client: tuple[str, int] | None = None
|
||||
self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment]
|
||||
|
||||
# WebSocket state
|
||||
self.queue: asyncio.Queue[WebSocketEvent] = asyncio.Queue()
|
||||
self.handshake_complete = False
|
||||
self.close_sent = False
|
||||
|
||||
# Rejection state
|
||||
self.response_started = False
|
||||
|
||||
self.conn = wsproto.WSConnection(connection_type=ConnectionType.SERVER)
|
||||
|
||||
self.read_paused = False
|
||||
self.writable = asyncio.Event()
|
||||
self.writable.set()
|
||||
|
||||
# Buffers
|
||||
self.bytes = b""
|
||||
self.text = ""
|
||||
|
||||
# Protocol interface
|
||||
|
||||
def connection_made( # type: ignore[override]
|
||||
self, transport: asyncio.Transport
|
||||
) -> None:
|
||||
self.connections.add(self)
|
||||
self.transport = transport
|
||||
self.server = get_local_addr(transport)
|
||||
self.client = get_remote_addr(transport)
|
||||
self.scheme = "wss" if is_ssl(transport) else "ws"
|
||||
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)
|
||||
|
||||
def connection_lost(self, exc: Exception | None) -> None:
|
||||
code = 1005 if self.handshake_complete else 1006
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
|
||||
self.connections.remove(self)
|
||||
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)
|
||||
|
||||
self.handshake_complete = True
|
||||
if exc is None:
|
||||
self.transport.close()
|
||||
|
||||
def eof_received(self) -> None:
|
||||
pass
|
||||
|
||||
def data_received(self, data: bytes) -> None:
|
||||
try:
|
||||
self.conn.receive_data(data)
|
||||
except RemoteProtocolError as err:
|
||||
# TODO: Remove `type: ignore` when wsproto fixes the type annotation.
|
||||
self.transport.write(self.conn.send(err.event_hint)) # type: ignore[arg-type] # noqa: E501
|
||||
self.transport.close()
|
||||
else:
|
||||
self.handle_events()
|
||||
|
||||
def handle_events(self) -> None:
|
||||
for event in self.conn.events():
|
||||
if isinstance(event, events.Request):
|
||||
self.handle_connect(event)
|
||||
elif isinstance(event, events.TextMessage):
|
||||
self.handle_text(event)
|
||||
elif isinstance(event, events.BytesMessage):
|
||||
self.handle_bytes(event)
|
||||
elif isinstance(event, events.CloseConnection):
|
||||
self.handle_close(event)
|
||||
elif isinstance(event, events.Ping):
|
||||
self.handle_ping(event)
|
||||
|
||||
def pause_writing(self) -> None:
|
||||
"""
|
||||
Called by the transport when the write buffer exceeds the high water mark.
|
||||
"""
|
||||
self.writable.clear() # pragma: full coverage
|
||||
|
||||
def resume_writing(self) -> None:
|
||||
"""
|
||||
Called by the transport when the write buffer drops below the low water mark.
|
||||
"""
|
||||
self.writable.set() # pragma: full coverage
|
||||
|
||||
def shutdown(self) -> None:
|
||||
if self.handshake_complete:
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012})
|
||||
output = self.conn.send(wsproto.events.CloseConnection(code=1012))
|
||||
self.transport.write(output)
|
||||
else:
|
||||
self.send_500_response()
|
||||
self.transport.close()
|
||||
|
||||
def on_task_complete(self, task: asyncio.Task[None]) -> None:
|
||||
self.tasks.discard(task)
|
||||
|
||||
# Event handlers
|
||||
|
||||
def handle_connect(self, event: events.Request) -> None:
|
||||
headers = [(b"host", event.host.encode())]
|
||||
headers += [(key.lower(), value) for key, value in event.extra_headers]
|
||||
raw_path, _, query_string = event.target.partition("?")
|
||||
path = unquote(raw_path)
|
||||
full_path = self.root_path + path
|
||||
full_raw_path = self.root_path.encode("ascii") + raw_path.encode("ascii")
|
||||
self.scope: WebSocketScope = {
|
||||
"type": "websocket",
|
||||
"asgi": {"version": self.config.asgi_version, "spec_version": "2.4"},
|
||||
"http_version": "1.1",
|
||||
"scheme": self.scheme,
|
||||
"server": self.server,
|
||||
"client": self.client,
|
||||
"root_path": self.root_path,
|
||||
"path": full_path,
|
||||
"raw_path": full_raw_path,
|
||||
"query_string": query_string.encode("ascii"),
|
||||
"headers": headers,
|
||||
"subprotocols": event.subprotocols,
|
||||
"state": self.app_state.copy(),
|
||||
"extensions": {"websocket.http.response": {}},
|
||||
}
|
||||
self.queue.put_nowait({"type": "websocket.connect"})
|
||||
task = self.loop.create_task(self.run_asgi())
|
||||
task.add_done_callback(self.on_task_complete)
|
||||
self.tasks.add(task)
|
||||
|
||||
def handle_text(self, event: events.TextMessage) -> None:
|
||||
self.text += event.data
|
||||
if event.message_finished:
|
||||
self.queue.put_nowait({"type": "websocket.receive", "text": self.text})
|
||||
self.text = ""
|
||||
if not self.read_paused:
|
||||
self.read_paused = True
|
||||
self.transport.pause_reading()
|
||||
|
||||
def handle_bytes(self, event: events.BytesMessage) -> None:
|
||||
self.bytes += event.data
|
||||
# todo: we may want to guard the size of self.bytes and self.text
|
||||
if event.message_finished:
|
||||
self.queue.put_nowait({"type": "websocket.receive", "bytes": self.bytes})
|
||||
self.bytes = b""
|
||||
if not self.read_paused:
|
||||
self.read_paused = True
|
||||
self.transport.pause_reading()
|
||||
|
||||
def handle_close(self, event: events.CloseConnection) -> None:
|
||||
if self.conn.state == ConnectionState.REMOTE_CLOSING:
|
||||
self.transport.write(self.conn.send(event.response()))
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": event.code, "reason": event.reason})
|
||||
self.transport.close()
|
||||
|
||||
def handle_ping(self, event: events.Ping) -> None:
|
||||
self.transport.write(self.conn.send(event.response()))
|
||||
|
||||
def send_500_response(self) -> None:
|
||||
if self.response_started or self.handshake_complete:
|
||||
return # we cannot send responses anymore
|
||||
headers: list[tuple[bytes, bytes]] = [
|
||||
(b"content-type", b"text/plain; charset=utf-8"),
|
||||
(b"connection", b"close"),
|
||||
(b"content-length", b"21"),
|
||||
]
|
||||
output = self.conn.send(wsproto.events.RejectConnection(status_code=500, headers=headers, has_body=True))
|
||||
output += self.conn.send(wsproto.events.RejectData(data=b"Internal Server Error"))
|
||||
self.transport.write(output)
|
||||
|
||||
async def run_asgi(self) -> None:
|
||||
try:
|
||||
result = await self.app(self.scope, self.receive, self.send) # type: ignore[func-returns-value]
|
||||
except ClientDisconnected:
|
||||
self.transport.close() # pragma: full coverage
|
||||
except BaseException:
|
||||
self.logger.exception("Exception in ASGI application\n")
|
||||
self.send_500_response()
|
||||
self.transport.close()
|
||||
else:
|
||||
if not self.handshake_complete:
|
||||
self.logger.error("ASGI callable returned without completing handshake.")
|
||||
self.send_500_response()
|
||||
self.transport.close()
|
||||
elif result is not None:
|
||||
self.logger.error("ASGI callable should return None, but returned '%s'.", result)
|
||||
self.transport.close()
|
||||
|
||||
async def send(self, message: ASGISendEvent) -> None:
|
||||
await self.writable.wait()
|
||||
|
||||
message_type = message["type"]
|
||||
|
||||
if not self.handshake_complete:
|
||||
if message_type == "websocket.accept":
|
||||
message = typing.cast(WebSocketAcceptEvent, message)
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" [accepted]',
|
||||
self.scope["client"],
|
||||
get_path_with_query_string(self.scope),
|
||||
)
|
||||
subprotocol = message.get("subprotocol")
|
||||
extra_headers = self.default_headers + list(message.get("headers", []))
|
||||
extensions: list[Extension] = []
|
||||
if self.config.ws_per_message_deflate:
|
||||
extensions.append(PerMessageDeflate())
|
||||
if not self.transport.is_closing():
|
||||
self.handshake_complete = True
|
||||
output = self.conn.send(
|
||||
wsproto.events.AcceptConnection(
|
||||
subprotocol=subprotocol,
|
||||
extensions=extensions,
|
||||
extra_headers=extra_headers,
|
||||
)
|
||||
)
|
||||
self.transport.write(output)
|
||||
|
||||
elif message_type == "websocket.close":
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" 403',
|
||||
self.scope["client"],
|
||||
get_path_with_query_string(self.scope),
|
||||
)
|
||||
self.handshake_complete = True
|
||||
self.close_sent = True
|
||||
event = events.RejectConnection(status_code=403, headers=[])
|
||||
output = self.conn.send(event)
|
||||
self.transport.write(output)
|
||||
self.transport.close()
|
||||
|
||||
elif message_type == "websocket.http.response.start":
|
||||
message = typing.cast(WebSocketResponseStartEvent, message)
|
||||
# ensure status code is in the valid range
|
||||
if not (100 <= message["status"] < 600):
|
||||
msg = "Invalid HTTP status code '%d' in response."
|
||||
raise RuntimeError(msg % message["status"])
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" %d',
|
||||
self.scope["client"],
|
||||
get_path_with_query_string(self.scope),
|
||||
message["status"],
|
||||
)
|
||||
self.handshake_complete = True
|
||||
event = events.RejectConnection(
|
||||
status_code=message["status"],
|
||||
headers=list(message["headers"]),
|
||||
has_body=True,
|
||||
)
|
||||
output = self.conn.send(event)
|
||||
self.transport.write(output)
|
||||
self.response_started = True
|
||||
|
||||
else:
|
||||
msg = (
|
||||
"Expected ASGI message 'websocket.accept', 'websocket.close' "
|
||||
"or 'websocket.http.response.start' "
|
||||
"but got '%s'."
|
||||
)
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
elif not self.close_sent and not self.response_started:
|
||||
try:
|
||||
if message_type == "websocket.send":
|
||||
message = typing.cast(WebSocketSendEvent, message)
|
||||
bytes_data = message.get("bytes")
|
||||
text_data = message.get("text")
|
||||
data = text_data if bytes_data is None else bytes_data
|
||||
output = self.conn.send(wsproto.events.Message(data=data)) # type: ignore
|
||||
if not self.transport.is_closing():
|
||||
self.transport.write(output)
|
||||
|
||||
elif message_type == "websocket.close":
|
||||
message = typing.cast(WebSocketCloseEvent, message)
|
||||
self.close_sent = True
|
||||
code = message.get("code", 1000)
|
||||
reason = message.get("reason", "") or ""
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": code, "reason": reason})
|
||||
output = self.conn.send(wsproto.events.CloseConnection(code=code, reason=reason))
|
||||
if not self.transport.is_closing():
|
||||
self.transport.write(output)
|
||||
self.transport.close()
|
||||
|
||||
else:
|
||||
msg = "Expected ASGI message 'websocket.send' or 'websocket.close'," " but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
except LocalProtocolError as exc:
|
||||
raise ClientDisconnected from exc
|
||||
elif self.response_started:
|
||||
if message_type == "websocket.http.response.body":
|
||||
message = typing.cast("WebSocketResponseBodyEvent", message)
|
||||
body_finished = not message.get("more_body", False)
|
||||
reject_data = events.RejectData(data=message["body"], body_finished=body_finished)
|
||||
output = self.conn.send(reject_data)
|
||||
self.transport.write(output)
|
||||
|
||||
if body_finished:
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
|
||||
self.close_sent = True
|
||||
self.transport.close()
|
||||
|
||||
else:
|
||||
msg = "Expected ASGI message 'websocket.http.response.body' " "but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
else:
|
||||
msg = "Unexpected ASGI message '%s', after sending 'websocket.close'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
async def receive(self) -> WebSocketEvent:
|
||||
message = await self.queue.get()
|
||||
if self.read_paused and self.queue.empty():
|
||||
self.read_paused = False
|
||||
self.transport.resume_reading()
|
||||
return message
|
1
venv/lib/python3.11/site-packages/uvicorn/py.typed
Normal file
1
venv/lib/python3.11/site-packages/uvicorn/py.typed
Normal file
@ -0,0 +1 @@
|
||||
|
337
venv/lib/python3.11/site-packages/uvicorn/server.py
Normal file
337
venv/lib/python3.11/site-packages/uvicorn/server.py
Normal file
@ -0,0 +1,337 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import signal
|
||||
import socket
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Generator, Sequence
|
||||
from email.utils import formatdate
|
||||
from types import FrameType
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
import click
|
||||
|
||||
from uvicorn.config import Config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from uvicorn.protocols.http.h11_impl import H11Protocol
|
||||
from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol
|
||||
from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
|
||||
from uvicorn.protocols.websockets.wsproto_impl import WSProtocol
|
||||
|
||||
Protocols = Union[H11Protocol, HttpToolsProtocol, WSProtocol, WebSocketProtocol]
|
||||
|
||||
HANDLED_SIGNALS = (
|
||||
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
|
||||
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
|
||||
)
|
||||
if sys.platform == "win32": # pragma: py-not-win32
|
||||
HANDLED_SIGNALS += (signal.SIGBREAK,) # Windows signal 21. Sent by Ctrl+Break.
|
||||
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
|
||||
|
||||
class ServerState:
|
||||
"""
|
||||
Shared servers state that is available between all protocol instances.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.total_requests = 0
|
||||
self.connections: set[Protocols] = set()
|
||||
self.tasks: set[asyncio.Task[None]] = set()
|
||||
self.default_headers: list[tuple[bytes, bytes]] = []
|
||||
|
||||
|
||||
class Server:
|
||||
def __init__(self, config: Config) -> None:
|
||||
self.config = config
|
||||
self.server_state = ServerState()
|
||||
|
||||
self.started = False
|
||||
self.should_exit = False
|
||||
self.force_exit = False
|
||||
self.last_notified = 0.0
|
||||
|
||||
self._captured_signals: list[int] = []
|
||||
|
||||
def run(self, sockets: list[socket.socket] | None = None) -> None:
|
||||
self.config.setup_event_loop()
|
||||
return asyncio.run(self.serve(sockets=sockets))
|
||||
|
||||
async def serve(self, sockets: list[socket.socket] | None = None) -> None:
|
||||
with self.capture_signals():
|
||||
await self._serve(sockets)
|
||||
|
||||
async def _serve(self, sockets: list[socket.socket] | None = None) -> None:
|
||||
process_id = os.getpid()
|
||||
|
||||
config = self.config
|
||||
if not config.loaded:
|
||||
config.load()
|
||||
|
||||
self.lifespan = config.lifespan_class(config)
|
||||
|
||||
message = "Started server process [%d]"
|
||||
color_message = "Started server process [" + click.style("%d", fg="cyan") + "]"
|
||||
logger.info(message, process_id, extra={"color_message": color_message})
|
||||
|
||||
await self.startup(sockets=sockets)
|
||||
if self.should_exit:
|
||||
return
|
||||
await self.main_loop()
|
||||
await self.shutdown(sockets=sockets)
|
||||
|
||||
message = "Finished server process [%d]"
|
||||
color_message = "Finished server process [" + click.style("%d", fg="cyan") + "]"
|
||||
logger.info(message, process_id, extra={"color_message": color_message})
|
||||
|
||||
async def startup(self, sockets: list[socket.socket] | None = None) -> None:
|
||||
await self.lifespan.startup()
|
||||
if self.lifespan.should_exit:
|
||||
self.should_exit = True
|
||||
return
|
||||
|
||||
config = self.config
|
||||
|
||||
def create_protocol(
|
||||
_loop: asyncio.AbstractEventLoop | None = None,
|
||||
) -> asyncio.Protocol:
|
||||
return config.http_protocol_class( # type: ignore[call-arg]
|
||||
config=config,
|
||||
server_state=self.server_state,
|
||||
app_state=self.lifespan.state,
|
||||
_loop=_loop,
|
||||
)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
listeners: Sequence[socket.SocketType]
|
||||
if sockets is not None: # pragma: full coverage
|
||||
# Explicitly passed a list of open sockets.
|
||||
# We use this when the server is run from a Gunicorn worker.
|
||||
|
||||
def _share_socket(
|
||||
sock: socket.SocketType,
|
||||
) -> socket.SocketType: # pragma py-linux pragma: py-darwin
|
||||
# Windows requires the socket be explicitly shared across
|
||||
# multiple workers (processes).
|
||||
from socket import fromshare # type: ignore[attr-defined]
|
||||
|
||||
sock_data = sock.share(os.getpid()) # type: ignore[attr-defined]
|
||||
return fromshare(sock_data)
|
||||
|
||||
self.servers: list[asyncio.base_events.Server] = []
|
||||
for sock in sockets:
|
||||
is_windows = platform.system() == "Windows"
|
||||
if config.workers > 1 and is_windows: # pragma: py-not-win32
|
||||
sock = _share_socket(sock) # type: ignore[assignment]
|
||||
server = await loop.create_server(create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog)
|
||||
self.servers.append(server)
|
||||
listeners = sockets
|
||||
|
||||
elif config.fd is not None: # pragma: py-win32
|
||||
# Use an existing socket, from a file descriptor.
|
||||
sock = socket.fromfd(config.fd, socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
server = await loop.create_server(create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog)
|
||||
assert server.sockets is not None # mypy
|
||||
listeners = server.sockets
|
||||
self.servers = [server]
|
||||
|
||||
elif config.uds is not None: # pragma: py-win32
|
||||
# Create a socket using UNIX domain socket.
|
||||
uds_perms = 0o666
|
||||
if os.path.exists(config.uds):
|
||||
uds_perms = os.stat(config.uds).st_mode # pragma: full coverage
|
||||
server = await loop.create_unix_server(
|
||||
create_protocol, path=config.uds, ssl=config.ssl, backlog=config.backlog
|
||||
)
|
||||
os.chmod(config.uds, uds_perms)
|
||||
assert server.sockets is not None # mypy
|
||||
listeners = server.sockets
|
||||
self.servers = [server]
|
||||
|
||||
else:
|
||||
# Standard case. Create a socket from a host/port pair.
|
||||
try:
|
||||
server = await loop.create_server(
|
||||
create_protocol,
|
||||
host=config.host,
|
||||
port=config.port,
|
||||
ssl=config.ssl,
|
||||
backlog=config.backlog,
|
||||
)
|
||||
except OSError as exc:
|
||||
logger.error(exc)
|
||||
await self.lifespan.shutdown()
|
||||
sys.exit(1)
|
||||
|
||||
assert server.sockets is not None
|
||||
listeners = server.sockets
|
||||
self.servers = [server]
|
||||
|
||||
if sockets is None:
|
||||
self._log_started_message(listeners)
|
||||
else:
|
||||
# We're most likely running multiple workers, so a message has already been
|
||||
# logged by `config.bind_socket()`.
|
||||
pass # pragma: full coverage
|
||||
|
||||
self.started = True
|
||||
|
||||
def _log_started_message(self, listeners: Sequence[socket.SocketType]) -> None:
|
||||
config = self.config
|
||||
|
||||
if config.fd is not None: # pragma: py-win32
|
||||
sock = listeners[0]
|
||||
logger.info(
|
||||
"Uvicorn running on socket %s (Press CTRL+C to quit)",
|
||||
sock.getsockname(),
|
||||
)
|
||||
|
||||
elif config.uds is not None: # pragma: py-win32
|
||||
logger.info("Uvicorn running on unix socket %s (Press CTRL+C to quit)", config.uds)
|
||||
|
||||
else:
|
||||
addr_format = "%s://%s:%d"
|
||||
host = "0.0.0.0" if config.host is None else config.host
|
||||
if ":" in host:
|
||||
# It's an IPv6 address.
|
||||
addr_format = "%s://[%s]:%d"
|
||||
|
||||
port = config.port
|
||||
if port == 0:
|
||||
port = listeners[0].getsockname()[1]
|
||||
|
||||
protocol_name = "https" if config.ssl else "http"
|
||||
message = f"Uvicorn running on {addr_format} (Press CTRL+C to quit)"
|
||||
color_message = "Uvicorn running on " + click.style(addr_format, bold=True) + " (Press CTRL+C to quit)"
|
||||
logger.info(
|
||||
message,
|
||||
protocol_name,
|
||||
host,
|
||||
port,
|
||||
extra={"color_message": color_message},
|
||||
)
|
||||
|
||||
async def main_loop(self) -> None:
|
||||
counter = 0
|
||||
should_exit = await self.on_tick(counter)
|
||||
while not should_exit:
|
||||
counter += 1
|
||||
counter = counter % 864000
|
||||
await asyncio.sleep(0.1)
|
||||
should_exit = await self.on_tick(counter)
|
||||
|
||||
async def on_tick(self, counter: int) -> bool:
|
||||
# Update the default headers, once per second.
|
||||
if counter % 10 == 0:
|
||||
current_time = time.time()
|
||||
current_date = formatdate(current_time, usegmt=True).encode()
|
||||
|
||||
if self.config.date_header:
|
||||
date_header = [(b"date", current_date)]
|
||||
else:
|
||||
date_header = []
|
||||
|
||||
self.server_state.default_headers = date_header + self.config.encoded_headers
|
||||
|
||||
# Callback to `callback_notify` once every `timeout_notify` seconds.
|
||||
if self.config.callback_notify is not None:
|
||||
if current_time - self.last_notified > self.config.timeout_notify: # pragma: full coverage
|
||||
self.last_notified = current_time
|
||||
await self.config.callback_notify()
|
||||
|
||||
# Determine if we should exit.
|
||||
if self.should_exit:
|
||||
return True
|
||||
|
||||
max_requests = self.config.limit_max_requests
|
||||
if max_requests is not None and self.server_state.total_requests >= max_requests:
|
||||
logger.warning(f"Maximum request limit of {max_requests} exceeded. Terminating process.")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def shutdown(self, sockets: list[socket.socket] | None = None) -> None:
|
||||
logger.info("Shutting down")
|
||||
|
||||
# Stop accepting new connections.
|
||||
for server in self.servers:
|
||||
server.close()
|
||||
for sock in sockets or []:
|
||||
sock.close() # pragma: full coverage
|
||||
|
||||
# Request shutdown on all existing connections.
|
||||
for connection in list(self.server_state.connections):
|
||||
connection.shutdown()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# When 3.10 is not supported anymore, use `async with asyncio.timeout(...):`.
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._wait_tasks_to_complete(),
|
||||
timeout=self.config.timeout_graceful_shutdown,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
"Cancel %s running task(s), timeout graceful shutdown exceeded",
|
||||
len(self.server_state.tasks),
|
||||
)
|
||||
for t in self.server_state.tasks:
|
||||
t.cancel(msg="Task cancelled, timeout graceful shutdown exceeded")
|
||||
|
||||
# Send the lifespan shutdown event, and wait for application shutdown.
|
||||
if not self.force_exit:
|
||||
await self.lifespan.shutdown()
|
||||
|
||||
async def _wait_tasks_to_complete(self) -> None:
|
||||
# Wait for existing connections to finish sending responses.
|
||||
if self.server_state.connections and not self.force_exit:
|
||||
msg = "Waiting for connections to close. (CTRL+C to force quit)"
|
||||
logger.info(msg)
|
||||
while self.server_state.connections and not self.force_exit:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Wait for existing tasks to complete.
|
||||
if self.server_state.tasks and not self.force_exit:
|
||||
msg = "Waiting for background tasks to complete. (CTRL+C to force quit)"
|
||||
logger.info(msg)
|
||||
while self.server_state.tasks and not self.force_exit:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
for server in self.servers:
|
||||
await server.wait_closed()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def capture_signals(self) -> Generator[None, None, None]:
|
||||
# Signals can only be listened to from the main thread.
|
||||
if threading.current_thread() is not threading.main_thread():
|
||||
yield
|
||||
return
|
||||
# always use signal.signal, even if loop.add_signal_handler is available
|
||||
# this allows to restore previous signal handlers later on
|
||||
original_handlers = {sig: signal.signal(sig, self.handle_exit) for sig in HANDLED_SIGNALS}
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for sig, handler in original_handlers.items():
|
||||
signal.signal(sig, handler)
|
||||
# If we did gracefully shut down due to a signal, try to
|
||||
# trigger the expected behaviour now; multiple signals would be
|
||||
# done LIFO, see https://stackoverflow.com/questions/48434964
|
||||
for captured_signal in reversed(self._captured_signals):
|
||||
signal.raise_signal(captured_signal)
|
||||
|
||||
def handle_exit(self, sig: int, frame: FrameType | None) -> None:
|
||||
self._captured_signals.append(sig)
|
||||
if self.should_exit and sig == signal.SIGINT:
|
||||
self.force_exit = True # pragma: full coverage
|
||||
else:
|
||||
self.should_exit = True
|
@ -0,0 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from uvicorn.supervisors.basereload import BaseReload
|
||||
from uvicorn.supervisors.multiprocess import Multiprocess
|
||||
|
||||
if TYPE_CHECKING:
|
||||
ChangeReload: type[BaseReload]
|
||||
else:
|
||||
try:
|
||||
from uvicorn.supervisors.watchfilesreload import WatchFilesReload as ChangeReload
|
||||
except ImportError: # pragma: no cover
|
||||
from uvicorn.supervisors.statreload import StatReload as ChangeReload
|
||||
|
||||
__all__ = ["Multiprocess", "ChangeReload"]
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,122 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from socket import socket
|
||||
from types import FrameType
|
||||
from typing import Callable
|
||||
|
||||
import click
|
||||
|
||||
from uvicorn._subprocess import get_subprocess
|
||||
from uvicorn.config import Config
|
||||
|
||||
HANDLED_SIGNALS = (
|
||||
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
|
||||
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
|
||||
)
|
||||
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
|
||||
|
||||
class BaseReload:
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
target: Callable[[list[socket] | None], None],
|
||||
sockets: list[socket],
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.target = target
|
||||
self.sockets = sockets
|
||||
self.should_exit = threading.Event()
|
||||
self.pid = os.getpid()
|
||||
self.is_restarting = False
|
||||
self.reloader_name: str | None = None
|
||||
|
||||
def signal_handler(self, sig: int, frame: FrameType | None) -> None: # pragma: full coverage
|
||||
"""
|
||||
A signal handler that is registered with the parent process.
|
||||
"""
|
||||
if sys.platform == "win32" and self.is_restarting:
|
||||
self.is_restarting = False
|
||||
else:
|
||||
self.should_exit.set()
|
||||
|
||||
def run(self) -> None:
|
||||
self.startup()
|
||||
for changes in self:
|
||||
if changes:
|
||||
logger.warning(
|
||||
"%s detected changes in %s. Reloading...",
|
||||
self.reloader_name,
|
||||
", ".join(map(_display_path, changes)),
|
||||
)
|
||||
self.restart()
|
||||
|
||||
self.shutdown()
|
||||
|
||||
def pause(self) -> None:
|
||||
if self.should_exit.wait(self.config.reload_delay):
|
||||
raise StopIteration()
|
||||
|
||||
def __iter__(self) -> Iterator[list[Path] | None]:
|
||||
return self
|
||||
|
||||
def __next__(self) -> list[Path] | None:
|
||||
return self.should_restart()
|
||||
|
||||
def startup(self) -> None:
|
||||
message = f"Started reloader process [{self.pid}] using {self.reloader_name}"
|
||||
color_message = "Started reloader process [{}] using {}".format(
|
||||
click.style(str(self.pid), fg="cyan", bold=True),
|
||||
click.style(str(self.reloader_name), fg="cyan", bold=True),
|
||||
)
|
||||
logger.info(message, extra={"color_message": color_message})
|
||||
|
||||
for sig in HANDLED_SIGNALS:
|
||||
signal.signal(sig, self.signal_handler)
|
||||
|
||||
self.process = get_subprocess(config=self.config, target=self.target, sockets=self.sockets)
|
||||
self.process.start()
|
||||
|
||||
def restart(self) -> None:
|
||||
if sys.platform == "win32": # pragma: py-not-win32
|
||||
self.is_restarting = True
|
||||
assert self.process.pid is not None
|
||||
os.kill(self.process.pid, signal.CTRL_C_EVENT)
|
||||
else: # pragma: py-win32
|
||||
self.process.terminate()
|
||||
self.process.join()
|
||||
|
||||
self.process = get_subprocess(config=self.config, target=self.target, sockets=self.sockets)
|
||||
self.process.start()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
if sys.platform == "win32":
|
||||
self.should_exit.set() # pragma: py-not-win32
|
||||
else:
|
||||
self.process.terminate() # pragma: py-win32
|
||||
self.process.join()
|
||||
|
||||
for sock in self.sockets:
|
||||
sock.close()
|
||||
|
||||
message = f"Stopping reloader process [{str(self.pid)}]"
|
||||
color_message = "Stopping reloader process [{}]".format(click.style(str(self.pid), fg="cyan", bold=True))
|
||||
logger.info(message, extra={"color_message": color_message})
|
||||
|
||||
def should_restart(self) -> list[Path] | None:
|
||||
raise NotImplementedError("Reload strategies should override should_restart()")
|
||||
|
||||
|
||||
def _display_path(path: Path) -> str:
|
||||
try:
|
||||
return f"'{path.relative_to(Path.cwd())}'"
|
||||
except ValueError:
|
||||
return f"'{path}'"
|
@ -0,0 +1,222 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import threading
|
||||
from multiprocessing import Pipe
|
||||
from socket import socket
|
||||
from typing import Any, Callable
|
||||
|
||||
import click
|
||||
|
||||
from uvicorn._subprocess import get_subprocess
|
||||
from uvicorn.config import Config
|
||||
|
||||
SIGNALS = {
|
||||
getattr(signal, f"SIG{x}"): x
|
||||
for x in "INT TERM BREAK HUP QUIT TTIN TTOU USR1 USR2 WINCH".split()
|
||||
if hasattr(signal, f"SIG{x}")
|
||||
}
|
||||
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
|
||||
|
||||
class Process:
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
target: Callable[[list[socket] | None], None],
|
||||
sockets: list[socket],
|
||||
) -> None:
|
||||
self.real_target = target
|
||||
|
||||
self.parent_conn, self.child_conn = Pipe()
|
||||
self.process = get_subprocess(config, self.target, sockets)
|
||||
|
||||
def ping(self, timeout: float = 5) -> bool:
|
||||
self.parent_conn.send(b"ping")
|
||||
if self.parent_conn.poll(timeout):
|
||||
self.parent_conn.recv()
|
||||
return True
|
||||
return False
|
||||
|
||||
def pong(self) -> None:
|
||||
self.child_conn.recv()
|
||||
self.child_conn.send(b"pong")
|
||||
|
||||
def always_pong(self) -> None:
|
||||
while True:
|
||||
self.pong()
|
||||
|
||||
def target(self, sockets: list[socket] | None = None) -> Any: # pragma: no cover
|
||||
if os.name == "nt": # pragma: py-not-win32
|
||||
# Windows doesn't support SIGTERM, so we use SIGBREAK instead.
|
||||
# And then we raise SIGTERM when SIGBREAK is received.
|
||||
# https://learn.microsoft.com/zh-cn/cpp/c-runtime-library/reference/signal?view=msvc-170
|
||||
signal.signal(
|
||||
signal.SIGBREAK, # type: ignore[attr-defined]
|
||||
lambda sig, frame: signal.raise_signal(signal.SIGTERM),
|
||||
)
|
||||
|
||||
threading.Thread(target=self.always_pong, daemon=True).start()
|
||||
return self.real_target(sockets)
|
||||
|
||||
def is_alive(self, timeout: float = 5) -> bool:
|
||||
if not self.process.is_alive():
|
||||
return False # pragma: full coverage
|
||||
|
||||
return self.ping(timeout)
|
||||
|
||||
def start(self) -> None:
|
||||
self.process.start()
|
||||
|
||||
def terminate(self) -> None:
|
||||
if self.process.exitcode is None: # Process is still running
|
||||
assert self.process.pid is not None
|
||||
if os.name == "nt": # pragma: py-not-win32
|
||||
# Windows doesn't support SIGTERM.
|
||||
# So send SIGBREAK, and then in process raise SIGTERM.
|
||||
os.kill(self.process.pid, signal.CTRL_BREAK_EVENT) # type: ignore[attr-defined]
|
||||
else:
|
||||
os.kill(self.process.pid, signal.SIGTERM)
|
||||
logger.info(f"Terminated child process [{self.process.pid}]")
|
||||
|
||||
self.parent_conn.close()
|
||||
self.child_conn.close()
|
||||
|
||||
def kill(self) -> None:
|
||||
# In Windows, the method will call `TerminateProcess` to kill the process.
|
||||
# In Unix, the method will send SIGKILL to the process.
|
||||
self.process.kill()
|
||||
|
||||
def join(self) -> None:
|
||||
logger.info(f"Waiting for child process [{self.process.pid}]")
|
||||
self.process.join()
|
||||
|
||||
@property
|
||||
def pid(self) -> int | None:
|
||||
return self.process.pid
|
||||
|
||||
|
||||
class Multiprocess:
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
target: Callable[[list[socket] | None], None],
|
||||
sockets: list[socket],
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.target = target
|
||||
self.sockets = sockets
|
||||
|
||||
self.processes_num = config.workers
|
||||
self.processes: list[Process] = []
|
||||
|
||||
self.should_exit = threading.Event()
|
||||
|
||||
self.signal_queue: list[int] = []
|
||||
for sig in SIGNALS:
|
||||
signal.signal(sig, lambda sig, frame: self.signal_queue.append(sig))
|
||||
|
||||
def init_processes(self) -> None:
|
||||
for _ in range(self.processes_num):
|
||||
process = Process(self.config, self.target, self.sockets)
|
||||
process.start()
|
||||
self.processes.append(process)
|
||||
|
||||
def terminate_all(self) -> None:
|
||||
for process in self.processes:
|
||||
process.terminate()
|
||||
|
||||
def join_all(self) -> None:
|
||||
for process in self.processes:
|
||||
process.join()
|
||||
|
||||
def restart_all(self) -> None:
|
||||
for idx, process in enumerate(self.processes):
|
||||
process.terminate()
|
||||
process.join()
|
||||
new_process = Process(self.config, self.target, self.sockets)
|
||||
new_process.start()
|
||||
self.processes[idx] = new_process
|
||||
|
||||
def run(self) -> None:
|
||||
message = f"Started parent process [{os.getpid()}]"
|
||||
color_message = "Started parent process [{}]".format(click.style(str(os.getpid()), fg="cyan", bold=True))
|
||||
logger.info(message, extra={"color_message": color_message})
|
||||
|
||||
self.init_processes()
|
||||
|
||||
while not self.should_exit.wait(0.5):
|
||||
self.handle_signals()
|
||||
self.keep_subprocess_alive()
|
||||
|
||||
self.terminate_all()
|
||||
self.join_all()
|
||||
|
||||
message = f"Stopping parent process [{os.getpid()}]"
|
||||
color_message = "Stopping parent process [{}]".format(click.style(str(os.getpid()), fg="cyan", bold=True))
|
||||
logger.info(message, extra={"color_message": color_message})
|
||||
|
||||
def keep_subprocess_alive(self) -> None:
|
||||
if self.should_exit.is_set():
|
||||
return # parent process is exiting, no need to keep subprocess alive
|
||||
|
||||
for idx, process in enumerate(self.processes):
|
||||
if process.is_alive():
|
||||
continue
|
||||
|
||||
process.kill() # process is hung, kill it
|
||||
process.join()
|
||||
|
||||
if self.should_exit.is_set():
|
||||
return # pragma: full coverage
|
||||
|
||||
logger.info(f"Child process [{process.pid}] died")
|
||||
process = Process(self.config, self.target, self.sockets)
|
||||
process.start()
|
||||
self.processes[idx] = process
|
||||
|
||||
def handle_signals(self) -> None:
|
||||
for sig in tuple(self.signal_queue):
|
||||
self.signal_queue.remove(sig)
|
||||
sig_name = SIGNALS[sig]
|
||||
sig_handler = getattr(self, f"handle_{sig_name.lower()}", None)
|
||||
if sig_handler is not None:
|
||||
sig_handler()
|
||||
else: # pragma: no cover
|
||||
logger.debug(f"Received signal {sig_name}, but no handler is defined for it.")
|
||||
|
||||
def handle_int(self) -> None:
|
||||
logger.info("Received SIGINT, exiting.")
|
||||
self.should_exit.set()
|
||||
|
||||
def handle_term(self) -> None:
|
||||
logger.info("Received SIGTERM, exiting.")
|
||||
self.should_exit.set()
|
||||
|
||||
def handle_break(self) -> None: # pragma: py-not-win32
|
||||
logger.info("Received SIGBREAK, exiting.")
|
||||
self.should_exit.set()
|
||||
|
||||
def handle_hup(self) -> None: # pragma: py-win32
|
||||
logger.info("Received SIGHUP, restarting processes.")
|
||||
self.restart_all()
|
||||
|
||||
def handle_ttin(self) -> None: # pragma: py-win32
|
||||
logger.info("Received SIGTTIN, increasing the number of processes.")
|
||||
self.processes_num += 1
|
||||
process = Process(self.config, self.target, self.sockets)
|
||||
process.start()
|
||||
self.processes.append(process)
|
||||
|
||||
def handle_ttou(self) -> None: # pragma: py-win32
|
||||
logger.info("Received SIGTTOU, decreasing number of processes.")
|
||||
if self.processes_num <= 1:
|
||||
logger.info("Already reached one process, cannot decrease the number of processes anymore.")
|
||||
return
|
||||
self.processes_num -= 1
|
||||
process = self.processes.pop()
|
||||
process.terminate()
|
||||
process.join()
|
@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from socket import socket
|
||||
from typing import Callable
|
||||
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.supervisors.basereload import BaseReload
|
||||
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
|
||||
|
||||
class StatReload(BaseReload):
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
target: Callable[[list[socket] | None], None],
|
||||
sockets: list[socket],
|
||||
) -> None:
|
||||
super().__init__(config, target, sockets)
|
||||
self.reloader_name = "StatReload"
|
||||
self.mtimes: dict[Path, float] = {}
|
||||
|
||||
if config.reload_excludes or config.reload_includes:
|
||||
logger.warning("--reload-include and --reload-exclude have no effect unless " "watchfiles is installed.")
|
||||
|
||||
def should_restart(self) -> list[Path] | None:
|
||||
self.pause()
|
||||
|
||||
for file in self.iter_py_files():
|
||||
try:
|
||||
mtime = file.stat().st_mtime
|
||||
except OSError: # pragma: nocover
|
||||
continue
|
||||
|
||||
old_time = self.mtimes.get(file)
|
||||
if old_time is None:
|
||||
self.mtimes[file] = mtime
|
||||
continue
|
||||
elif mtime > old_time:
|
||||
return [file]
|
||||
return None
|
||||
|
||||
def restart(self) -> None:
|
||||
self.mtimes = {}
|
||||
return super().restart()
|
||||
|
||||
def iter_py_files(self) -> Iterator[Path]:
|
||||
for reload_dir in self.config.reload_dirs:
|
||||
for path in list(reload_dir.rglob("*.py")):
|
||||
yield path.resolve()
|
@ -0,0 +1,88 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from socket import socket
|
||||
from typing import Callable
|
||||
|
||||
from watchfiles import watch
|
||||
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.supervisors.basereload import BaseReload
|
||||
|
||||
|
||||
class FileFilter:
|
||||
def __init__(self, config: Config):
|
||||
default_includes = ["*.py"]
|
||||
self.includes = [default for default in default_includes if default not in config.reload_excludes]
|
||||
self.includes.extend(config.reload_includes)
|
||||
self.includes = list(set(self.includes))
|
||||
|
||||
default_excludes = [".*", ".py[cod]", ".sw.*", "~*"]
|
||||
self.excludes = [default for default in default_excludes if default not in config.reload_includes]
|
||||
self.exclude_dirs = []
|
||||
for e in config.reload_excludes:
|
||||
p = Path(e)
|
||||
try:
|
||||
is_dir = p.is_dir()
|
||||
except OSError: # pragma: no cover
|
||||
# gets raised on Windows for values like "*.py"
|
||||
is_dir = False
|
||||
|
||||
if is_dir:
|
||||
self.exclude_dirs.append(p)
|
||||
else:
|
||||
self.excludes.append(e) # pragma: full coverage
|
||||
self.excludes = list(set(self.excludes))
|
||||
|
||||
def __call__(self, path: Path) -> bool:
|
||||
for include_pattern in self.includes:
|
||||
if path.match(include_pattern):
|
||||
if str(path).endswith(include_pattern):
|
||||
return True # pragma: full coverage
|
||||
|
||||
for exclude_dir in self.exclude_dirs:
|
||||
if exclude_dir in path.parents:
|
||||
return False
|
||||
|
||||
for exclude_pattern in self.excludes:
|
||||
if path.match(exclude_pattern):
|
||||
return False # pragma: full coverage
|
||||
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class WatchFilesReload(BaseReload):
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
target: Callable[[list[socket] | None], None],
|
||||
sockets: list[socket],
|
||||
) -> None:
|
||||
super().__init__(config, target, sockets)
|
||||
self.reloader_name = "WatchFiles"
|
||||
self.reload_dirs = []
|
||||
for directory in config.reload_dirs:
|
||||
if Path.cwd() not in directory.parents:
|
||||
self.reload_dirs.append(directory)
|
||||
if Path.cwd() not in self.reload_dirs:
|
||||
self.reload_dirs.append(Path.cwd())
|
||||
|
||||
self.watch_filter = FileFilter(config)
|
||||
self.watcher = watch(
|
||||
*self.reload_dirs,
|
||||
watch_filter=None,
|
||||
stop_event=self.should_exit,
|
||||
# using yield_on_timeout here mostly to make sure tests don't
|
||||
# hang forever, won't affect the class's behavior
|
||||
yield_on_timeout=True,
|
||||
)
|
||||
|
||||
def should_restart(self) -> list[Path] | None:
|
||||
self.pause()
|
||||
|
||||
changes = next(self.watcher)
|
||||
if changes:
|
||||
unique_paths = {Path(c[1]) for c in changes}
|
||||
return [p for p in unique_paths if self.watch_filter(p)]
|
||||
return None
|
114
venv/lib/python3.11/site-packages/uvicorn/workers.py
Normal file
114
venv/lib/python3.11/site-packages/uvicorn/workers.py
Normal file
@ -0,0 +1,114 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
from gunicorn.arbiter import Arbiter
|
||||
from gunicorn.workers.base import Worker
|
||||
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.server import Server
|
||||
|
||||
warnings.warn(
|
||||
"The `uvicorn.workers` module is deprecated. Please use `uvicorn-worker` package instead.\n"
|
||||
"For more details, see https://github.com/Kludex/uvicorn-worker.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
|
||||
class UvicornWorker(Worker):
|
||||
"""
|
||||
A worker class for Gunicorn that interfaces with an ASGI consumer callable,
|
||||
rather than a WSGI callable.
|
||||
"""
|
||||
|
||||
CONFIG_KWARGS: dict[str, Any] = {"loop": "auto", "http": "auto"}
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
logger.handlers = self.log.error_log.handlers
|
||||
logger.setLevel(self.log.error_log.level)
|
||||
logger.propagate = False
|
||||
|
||||
logger = logging.getLogger("uvicorn.access")
|
||||
logger.handlers = self.log.access_log.handlers
|
||||
logger.setLevel(self.log.access_log.level)
|
||||
logger.propagate = False
|
||||
|
||||
config_kwargs: dict = {
|
||||
"app": None,
|
||||
"log_config": None,
|
||||
"timeout_keep_alive": self.cfg.keepalive,
|
||||
"timeout_notify": self.timeout,
|
||||
"callback_notify": self.callback_notify,
|
||||
"limit_max_requests": self.max_requests,
|
||||
"forwarded_allow_ips": self.cfg.forwarded_allow_ips,
|
||||
}
|
||||
|
||||
if self.cfg.is_ssl:
|
||||
ssl_kwargs = {
|
||||
"ssl_keyfile": self.cfg.ssl_options.get("keyfile"),
|
||||
"ssl_certfile": self.cfg.ssl_options.get("certfile"),
|
||||
"ssl_keyfile_password": self.cfg.ssl_options.get("password"),
|
||||
"ssl_version": self.cfg.ssl_options.get("ssl_version"),
|
||||
"ssl_cert_reqs": self.cfg.ssl_options.get("cert_reqs"),
|
||||
"ssl_ca_certs": self.cfg.ssl_options.get("ca_certs"),
|
||||
"ssl_ciphers": self.cfg.ssl_options.get("ciphers"),
|
||||
}
|
||||
config_kwargs.update(ssl_kwargs)
|
||||
|
||||
if self.cfg.settings["backlog"].value:
|
||||
config_kwargs["backlog"] = self.cfg.settings["backlog"].value
|
||||
|
||||
config_kwargs.update(self.CONFIG_KWARGS)
|
||||
|
||||
self.config = Config(**config_kwargs)
|
||||
|
||||
def init_process(self) -> None:
|
||||
self.config.setup_event_loop()
|
||||
super().init_process()
|
||||
|
||||
def init_signals(self) -> None:
|
||||
# Reset signals so Gunicorn doesn't swallow subprocess return codes
|
||||
# other signals are set up by Server.install_signal_handlers()
|
||||
# See: https://github.com/encode/uvicorn/issues/894
|
||||
for s in self.SIGNALS:
|
||||
signal.signal(s, signal.SIG_DFL)
|
||||
|
||||
signal.signal(signal.SIGUSR1, self.handle_usr1)
|
||||
# Don't let SIGUSR1 disturb active requests by interrupting system calls
|
||||
signal.siginterrupt(signal.SIGUSR1, False)
|
||||
|
||||
def _install_sigquit_handler(self) -> None:
|
||||
"""Install a SIGQUIT handler on workers.
|
||||
|
||||
- https://github.com/encode/uvicorn/issues/1116
|
||||
- https://github.com/benoitc/gunicorn/issues/2604
|
||||
"""
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.add_signal_handler(signal.SIGQUIT, self.handle_exit, signal.SIGQUIT, None)
|
||||
|
||||
async def _serve(self) -> None:
|
||||
self.config.app = self.wsgi
|
||||
server = Server(config=self.config)
|
||||
self._install_sigquit_handler()
|
||||
await server.serve(sockets=self.sockets)
|
||||
if not server.started:
|
||||
sys.exit(Arbiter.WORKER_BOOT_ERROR)
|
||||
|
||||
def run(self) -> None:
|
||||
return asyncio.run(self._serve())
|
||||
|
||||
async def callback_notify(self) -> None:
|
||||
self.notify()
|
||||
|
||||
|
||||
class UvicornH11Worker(UvicornWorker):
|
||||
CONFIG_KWARGS = {"loop": "asyncio", "http": "h11"}
|
Reference in New Issue
Block a user