Update 2025-04-13_16:26:34
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,15 @@
|
||||
from uvicorn._types import (
|
||||
ASGI2Application,
|
||||
ASGIReceiveCallable,
|
||||
ASGISendCallable,
|
||||
Scope,
|
||||
)
|
||||
|
||||
|
||||
class ASGI2Middleware:
|
||||
def __init__(self, app: "ASGI2Application"):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable") -> None:
|
||||
instance = self.app(scope)
|
||||
await instance(receive, send)
|
@ -0,0 +1,87 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from uvicorn._types import (
|
||||
ASGI3Application,
|
||||
ASGIReceiveCallable,
|
||||
ASGIReceiveEvent,
|
||||
ASGISendCallable,
|
||||
ASGISendEvent,
|
||||
WWWScope,
|
||||
)
|
||||
from uvicorn.logging import TRACE_LOG_LEVEL
|
||||
|
||||
PLACEHOLDER_FORMAT = {
|
||||
"body": "<{length} bytes>",
|
||||
"bytes": "<{length} bytes>",
|
||||
"text": "<{length} chars>",
|
||||
"headers": "<...>",
|
||||
}
|
||||
|
||||
|
||||
def message_with_placeholders(message: Any) -> Any:
|
||||
"""
|
||||
Return an ASGI message, with any body-type content omitted and replaced
|
||||
with a placeholder.
|
||||
"""
|
||||
new_message = message.copy()
|
||||
for attr in PLACEHOLDER_FORMAT.keys():
|
||||
if message.get(attr) is not None:
|
||||
content = message[attr]
|
||||
placeholder = PLACEHOLDER_FORMAT[attr].format(length=len(content))
|
||||
new_message[attr] = placeholder
|
||||
return new_message
|
||||
|
||||
|
||||
class MessageLoggerMiddleware:
|
||||
def __init__(self, app: "ASGI3Application"):
|
||||
self.task_counter = 0
|
||||
self.app = app
|
||||
self.logger = logging.getLogger("uvicorn.asgi")
|
||||
|
||||
def trace(message: Any, *args: Any, **kwargs: Any) -> None:
|
||||
self.logger.log(TRACE_LOG_LEVEL, message, *args, **kwargs)
|
||||
|
||||
self.logger.trace = trace # type: ignore
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
scope: "WWWScope",
|
||||
receive: "ASGIReceiveCallable",
|
||||
send: "ASGISendCallable",
|
||||
) -> None:
|
||||
self.task_counter += 1
|
||||
|
||||
task_counter = self.task_counter
|
||||
client = scope.get("client")
|
||||
prefix = "%s:%d - ASGI" % (client[0], client[1]) if client else "ASGI"
|
||||
|
||||
async def inner_receive() -> "ASGIReceiveEvent":
|
||||
message = await receive()
|
||||
logged_message = message_with_placeholders(message)
|
||||
log_text = "%s [%d] Receive %s"
|
||||
self.logger.trace( # type: ignore
|
||||
log_text, prefix, task_counter, logged_message
|
||||
)
|
||||
return message
|
||||
|
||||
async def inner_send(message: "ASGISendEvent") -> None:
|
||||
logged_message = message_with_placeholders(message)
|
||||
log_text = "%s [%d] Send %s"
|
||||
self.logger.trace( # type: ignore
|
||||
log_text, prefix, task_counter, logged_message
|
||||
)
|
||||
await send(message)
|
||||
|
||||
logged_scope = message_with_placeholders(scope)
|
||||
log_text = "%s [%d] Started scope=%s"
|
||||
self.logger.trace(log_text, prefix, task_counter, logged_scope) # type: ignore
|
||||
try:
|
||||
await self.app(scope, inner_receive, inner_send)
|
||||
except BaseException as exc:
|
||||
log_text = "%s [%d] Raised exception"
|
||||
self.logger.trace(log_text, prefix, task_counter) # type: ignore
|
||||
raise exc from None
|
||||
else:
|
||||
log_text = "%s [%d] Completed"
|
||||
self.logger.trace(log_text, prefix, task_counter) # type: ignore
|
@ -0,0 +1,142 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
|
||||
from uvicorn._types import ASGI3Application, ASGIReceiveCallable, ASGISendCallable, Scope
|
||||
|
||||
|
||||
class ProxyHeadersMiddleware:
|
||||
"""Middleware for handling known proxy headers
|
||||
|
||||
This middleware can be used when a known proxy is fronting the application,
|
||||
and is trusted to be properly setting the `X-Forwarded-Proto` and
|
||||
`X-Forwarded-For` headers with the connecting client information.
|
||||
|
||||
Modifies the `client` and `scheme` information so that they reference
|
||||
the connecting client, rather that the connecting proxy.
|
||||
|
||||
References:
|
||||
- <https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers#Proxies>
|
||||
- <https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For>
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGI3Application, trusted_hosts: list[str] | str = "127.0.0.1") -> None:
|
||||
self.app = app
|
||||
self.trusted_hosts = _TrustedHosts(trusted_hosts)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
|
||||
if scope["type"] == "lifespan":
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
client_addr = scope.get("client")
|
||||
client_host = client_addr[0] if client_addr else None
|
||||
|
||||
if client_host in self.trusted_hosts:
|
||||
headers = dict(scope["headers"])
|
||||
|
||||
if b"x-forwarded-proto" in headers:
|
||||
x_forwarded_proto = headers[b"x-forwarded-proto"].decode("latin1").strip()
|
||||
|
||||
if x_forwarded_proto in {"http", "https", "ws", "wss"}:
|
||||
if scope["type"] == "websocket":
|
||||
scope["scheme"] = x_forwarded_proto.replace("http", "ws")
|
||||
else:
|
||||
scope["scheme"] = x_forwarded_proto
|
||||
|
||||
if b"x-forwarded-for" in headers:
|
||||
x_forwarded_for = headers[b"x-forwarded-for"].decode("latin1")
|
||||
host = self.trusted_hosts.get_trusted_client_host(x_forwarded_for)
|
||||
|
||||
if host:
|
||||
# If the x-forwarded-for header is empty then host is an empty string.
|
||||
# Only set the client if we actually got something usable.
|
||||
# See: https://github.com/encode/uvicorn/issues/1068
|
||||
|
||||
# We've lost the connecting client's port information by now,
|
||||
# so only include the host.
|
||||
port = 0
|
||||
scope["client"] = (host, port)
|
||||
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
|
||||
def _parse_raw_hosts(value: str) -> list[str]:
|
||||
return [item.strip() for item in value.split(",")]
|
||||
|
||||
|
||||
class _TrustedHosts:
|
||||
"""Container for trusted hosts and networks"""
|
||||
|
||||
def __init__(self, trusted_hosts: list[str] | str) -> None:
|
||||
self.always_trust: bool = trusted_hosts in ("*", ["*"])
|
||||
|
||||
self.trusted_literals: set[str] = set()
|
||||
self.trusted_hosts: set[ipaddress.IPv4Address | ipaddress.IPv6Address] = set()
|
||||
self.trusted_networks: set[ipaddress.IPv4Network | ipaddress.IPv6Network] = set()
|
||||
|
||||
# Notes:
|
||||
# - We separate hosts from literals as there are many ways to write
|
||||
# an IPv6 Address so we need to compare by object.
|
||||
# - We don't convert IP Address to single host networks (e.g. /32 / 128) as
|
||||
# it more efficient to do an address lookup in a set than check for
|
||||
# membership in each network.
|
||||
# - We still allow literals as it might be possible that we receive a
|
||||
# something that isn't an IP Address e.g. a unix socket.
|
||||
|
||||
if not self.always_trust:
|
||||
if isinstance(trusted_hosts, str):
|
||||
trusted_hosts = _parse_raw_hosts(trusted_hosts)
|
||||
|
||||
for host in trusted_hosts:
|
||||
# Note: because we always convert invalid IP types to literals it
|
||||
# is not possible for the user to know they provided a malformed IP
|
||||
# type - this may lead to unexpected / difficult to debug behaviour.
|
||||
|
||||
if "/" in host:
|
||||
# Looks like a network
|
||||
try:
|
||||
self.trusted_networks.add(ipaddress.ip_network(host))
|
||||
except ValueError:
|
||||
# Was not a valid IP Network
|
||||
self.trusted_literals.add(host)
|
||||
else:
|
||||
try:
|
||||
self.trusted_hosts.add(ipaddress.ip_address(host))
|
||||
except ValueError:
|
||||
# Was not a valid IP Address
|
||||
self.trusted_literals.add(host)
|
||||
|
||||
def __contains__(self, host: str | None) -> bool:
|
||||
if self.always_trust:
|
||||
return True
|
||||
|
||||
if not host:
|
||||
return False
|
||||
|
||||
try:
|
||||
ip = ipaddress.ip_address(host)
|
||||
if ip in self.trusted_hosts:
|
||||
return True
|
||||
return any(ip in net for net in self.trusted_networks)
|
||||
|
||||
except ValueError:
|
||||
return host in self.trusted_literals
|
||||
|
||||
def get_trusted_client_host(self, x_forwarded_for: str) -> str:
|
||||
"""Extract the client host from x_forwarded_for header
|
||||
|
||||
In general this is the first "untrusted" host in the forwarded for list.
|
||||
"""
|
||||
x_forwarded_for_hosts = _parse_raw_hosts(x_forwarded_for)
|
||||
|
||||
if self.always_trust:
|
||||
return x_forwarded_for_hosts[0]
|
||||
|
||||
# Note: each proxy appends to the header list so check it in reverse order
|
||||
for host in reversed(x_forwarded_for_hosts):
|
||||
if host not in self:
|
||||
return host
|
||||
|
||||
# All hosts are trusted meaning that the client was also a trusted proxy
|
||||
# See https://github.com/encode/uvicorn/issues/1068#issuecomment-855371576
|
||||
return x_forwarded_for_hosts[0]
|
200
venv/lib/python3.11/site-packages/uvicorn/middleware/wsgi.py
Normal file
200
venv/lib/python3.11/site-packages/uvicorn/middleware/wsgi.py
Normal file
@ -0,0 +1,200 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import io
|
||||
import sys
|
||||
import warnings
|
||||
from collections import deque
|
||||
from collections.abc import Iterable
|
||||
|
||||
from uvicorn._types import (
|
||||
ASGIReceiveCallable,
|
||||
ASGIReceiveEvent,
|
||||
ASGISendCallable,
|
||||
ASGISendEvent,
|
||||
Environ,
|
||||
ExcInfo,
|
||||
HTTPRequestEvent,
|
||||
HTTPResponseBodyEvent,
|
||||
HTTPResponseStartEvent,
|
||||
HTTPScope,
|
||||
StartResponse,
|
||||
WSGIApp,
|
||||
)
|
||||
|
||||
|
||||
def build_environ(scope: HTTPScope, message: ASGIReceiveEvent, body: io.BytesIO) -> Environ:
|
||||
"""
|
||||
Builds a scope and request message into a WSGI environ object.
|
||||
"""
|
||||
script_name = scope.get("root_path", "").encode("utf8").decode("latin1")
|
||||
path_info = scope["path"].encode("utf8").decode("latin1")
|
||||
if path_info.startswith(script_name):
|
||||
path_info = path_info[len(script_name) :]
|
||||
environ = {
|
||||
"REQUEST_METHOD": scope["method"],
|
||||
"SCRIPT_NAME": script_name,
|
||||
"PATH_INFO": path_info,
|
||||
"QUERY_STRING": scope["query_string"].decode("ascii"),
|
||||
"SERVER_PROTOCOL": "HTTP/%s" % scope["http_version"],
|
||||
"wsgi.version": (1, 0),
|
||||
"wsgi.url_scheme": scope.get("scheme", "http"),
|
||||
"wsgi.input": body,
|
||||
"wsgi.errors": sys.stdout,
|
||||
"wsgi.multithread": True,
|
||||
"wsgi.multiprocess": True,
|
||||
"wsgi.run_once": False,
|
||||
}
|
||||
|
||||
# Get server name and port - required in WSGI, not in ASGI
|
||||
server = scope.get("server")
|
||||
if server is None:
|
||||
server = ("localhost", 80)
|
||||
environ["SERVER_NAME"] = server[0]
|
||||
environ["SERVER_PORT"] = server[1]
|
||||
|
||||
# Get client IP address
|
||||
client = scope.get("client")
|
||||
if client is not None:
|
||||
environ["REMOTE_ADDR"] = client[0]
|
||||
|
||||
# Go through headers and make them into environ entries
|
||||
for name, value in scope.get("headers", []):
|
||||
name_str: str = name.decode("latin1")
|
||||
if name_str == "content-length":
|
||||
corrected_name = "CONTENT_LENGTH"
|
||||
elif name_str == "content-type":
|
||||
corrected_name = "CONTENT_TYPE"
|
||||
else:
|
||||
corrected_name = "HTTP_%s" % name_str.upper().replace("-", "_")
|
||||
# HTTPbis say only ASCII chars are allowed in headers, but we latin1
|
||||
# just in case
|
||||
value_str: str = value.decode("latin1")
|
||||
if corrected_name in environ:
|
||||
corrected_name_environ = environ[corrected_name]
|
||||
assert isinstance(corrected_name_environ, str)
|
||||
value_str = corrected_name_environ + "," + value_str
|
||||
environ[corrected_name] = value_str
|
||||
return environ
|
||||
|
||||
|
||||
class _WSGIMiddleware:
|
||||
def __init__(self, app: WSGIApp, workers: int = 10):
|
||||
warnings.warn(
|
||||
"Uvicorn's native WSGI implementation is deprecated, you "
|
||||
"should switch to a2wsgi (`pip install a2wsgi`).",
|
||||
DeprecationWarning,
|
||||
)
|
||||
self.app = app
|
||||
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=workers)
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
scope: HTTPScope,
|
||||
receive: ASGIReceiveCallable,
|
||||
send: ASGISendCallable,
|
||||
) -> None:
|
||||
assert scope["type"] == "http"
|
||||
instance = WSGIResponder(self.app, self.executor, scope)
|
||||
await instance(receive, send)
|
||||
|
||||
|
||||
class WSGIResponder:
|
||||
def __init__(
|
||||
self,
|
||||
app: WSGIApp,
|
||||
executor: concurrent.futures.ThreadPoolExecutor,
|
||||
scope: HTTPScope,
|
||||
):
|
||||
self.app = app
|
||||
self.executor = executor
|
||||
self.scope = scope
|
||||
self.status = None
|
||||
self.response_headers = None
|
||||
self.send_event = asyncio.Event()
|
||||
self.send_queue: deque[ASGISendEvent | None] = deque()
|
||||
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
|
||||
self.response_started = False
|
||||
self.exc_info: ExcInfo | None = None
|
||||
|
||||
async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
|
||||
message: HTTPRequestEvent = await receive() # type: ignore[assignment]
|
||||
body = io.BytesIO(message.get("body", b""))
|
||||
more_body = message.get("more_body", False)
|
||||
if more_body:
|
||||
body.seek(0, io.SEEK_END)
|
||||
while more_body:
|
||||
body_message: HTTPRequestEvent = (
|
||||
await receive() # type: ignore[assignment]
|
||||
)
|
||||
body.write(body_message.get("body", b""))
|
||||
more_body = body_message.get("more_body", False)
|
||||
body.seek(0)
|
||||
environ = build_environ(self.scope, message, body)
|
||||
self.loop = asyncio.get_event_loop()
|
||||
wsgi = self.loop.run_in_executor(self.executor, self.wsgi, environ, self.start_response)
|
||||
sender = self.loop.create_task(self.sender(send))
|
||||
try:
|
||||
await asyncio.wait_for(wsgi, None)
|
||||
finally:
|
||||
self.send_queue.append(None)
|
||||
self.send_event.set()
|
||||
await asyncio.wait_for(sender, None)
|
||||
if self.exc_info is not None:
|
||||
raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2])
|
||||
|
||||
async def sender(self, send: ASGISendCallable) -> None:
|
||||
while True:
|
||||
if self.send_queue:
|
||||
message = self.send_queue.popleft()
|
||||
if message is None:
|
||||
return
|
||||
await send(message)
|
||||
else:
|
||||
await self.send_event.wait()
|
||||
self.send_event.clear()
|
||||
|
||||
def start_response(
|
||||
self,
|
||||
status: str,
|
||||
response_headers: Iterable[tuple[str, str]],
|
||||
exc_info: ExcInfo | None = None,
|
||||
) -> None:
|
||||
self.exc_info = exc_info
|
||||
if not self.response_started:
|
||||
self.response_started = True
|
||||
status_code_str, _ = status.split(" ", 1)
|
||||
status_code = int(status_code_str)
|
||||
headers = [(name.encode("ascii"), value.encode("ascii")) for name, value in response_headers]
|
||||
http_response_start_event: HTTPResponseStartEvent = {
|
||||
"type": "http.response.start",
|
||||
"status": status_code,
|
||||
"headers": headers,
|
||||
}
|
||||
self.send_queue.append(http_response_start_event)
|
||||
self.loop.call_soon_threadsafe(self.send_event.set)
|
||||
|
||||
def wsgi(self, environ: Environ, start_response: StartResponse) -> None:
|
||||
for chunk in self.app(environ, start_response): # type: ignore
|
||||
response_body: HTTPResponseBodyEvent = {
|
||||
"type": "http.response.body",
|
||||
"body": chunk,
|
||||
"more_body": True,
|
||||
}
|
||||
self.send_queue.append(response_body)
|
||||
self.loop.call_soon_threadsafe(self.send_event.set)
|
||||
|
||||
empty_body: HTTPResponseBodyEvent = {
|
||||
"type": "http.response.body",
|
||||
"body": b"",
|
||||
"more_body": False,
|
||||
}
|
||||
self.send_queue.append(empty_body)
|
||||
self.loop.call_soon_threadsafe(self.send_event.set)
|
||||
|
||||
|
||||
try:
|
||||
from a2wsgi import WSGIMiddleware
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
WSGIMiddleware = _WSGIMiddleware # type: ignore[misc, assignment]
|
Reference in New Issue
Block a user