Update 2025-04-13_16:26:04
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
AutoHTTPProtocol: type[asyncio.Protocol]
|
||||
try:
|
||||
import httptools # noqa
|
||||
except ImportError: # pragma: no cover
|
||||
from uvicorn.protocols.http.h11_impl import H11Protocol
|
||||
|
||||
AutoHTTPProtocol = H11Protocol
|
||||
else: # pragma: no cover
|
||||
from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol
|
||||
|
||||
AutoHTTPProtocol = HttpToolsProtocol
|
@ -0,0 +1,54 @@
|
||||
import asyncio
|
||||
|
||||
from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope
|
||||
|
||||
CLOSE_HEADER = (b"connection", b"close")
|
||||
|
||||
HIGH_WATER_LIMIT = 65536
|
||||
|
||||
|
||||
class FlowControl:
|
||||
def __init__(self, transport: asyncio.Transport) -> None:
|
||||
self._transport = transport
|
||||
self.read_paused = False
|
||||
self.write_paused = False
|
||||
self._is_writable_event = asyncio.Event()
|
||||
self._is_writable_event.set()
|
||||
|
||||
async def drain(self) -> None:
|
||||
await self._is_writable_event.wait() # pragma: full coverage
|
||||
|
||||
def pause_reading(self) -> None:
|
||||
if not self.read_paused:
|
||||
self.read_paused = True
|
||||
self._transport.pause_reading()
|
||||
|
||||
def resume_reading(self) -> None:
|
||||
if self.read_paused:
|
||||
self.read_paused = False
|
||||
self._transport.resume_reading()
|
||||
|
||||
def pause_writing(self) -> None:
|
||||
if not self.write_paused: # pragma: full coverage
|
||||
self.write_paused = True
|
||||
self._is_writable_event.clear()
|
||||
|
||||
def resume_writing(self) -> None:
|
||||
if self.write_paused: # pragma: full coverage
|
||||
self.write_paused = False
|
||||
self._is_writable_event.set()
|
||||
|
||||
|
||||
async def service_unavailable(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 503,
|
||||
"headers": [
|
||||
(b"content-type", b"text/plain; charset=utf-8"),
|
||||
(b"content-length", b"19"),
|
||||
(b"connection", b"close"),
|
||||
],
|
||||
}
|
||||
)
|
||||
await send({"type": "http.response.body", "body": b"Service Unavailable", "more_body": False})
|
@ -0,0 +1,543 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import http
|
||||
import logging
|
||||
from typing import Any, Callable, Literal, cast
|
||||
from urllib.parse import unquote
|
||||
|
||||
import h11
|
||||
from h11._connection import DEFAULT_MAX_INCOMPLETE_EVENT_SIZE
|
||||
|
||||
from uvicorn._types import (
|
||||
ASGI3Application,
|
||||
ASGIReceiveEvent,
|
||||
ASGISendEvent,
|
||||
HTTPRequestEvent,
|
||||
HTTPResponseBodyEvent,
|
||||
HTTPResponseStartEvent,
|
||||
HTTPScope,
|
||||
)
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.logging import TRACE_LOG_LEVEL
|
||||
from uvicorn.protocols.http.flow_control import CLOSE_HEADER, HIGH_WATER_LIMIT, FlowControl, service_unavailable
|
||||
from uvicorn.protocols.utils import get_client_addr, get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl
|
||||
from uvicorn.server import ServerState
|
||||
|
||||
|
||||
def _get_status_phrase(status_code: int) -> bytes:
|
||||
try:
|
||||
return http.HTTPStatus(status_code).phrase.encode()
|
||||
except ValueError:
|
||||
return b""
|
||||
|
||||
|
||||
STATUS_PHRASES = {status_code: _get_status_phrase(status_code) for status_code in range(100, 600)}
|
||||
|
||||
|
||||
class H11Protocol(asyncio.Protocol):
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
server_state: ServerState,
|
||||
app_state: dict[str, Any],
|
||||
_loop: asyncio.AbstractEventLoop | None = None,
|
||||
) -> None:
|
||||
if not config.loaded:
|
||||
config.load()
|
||||
|
||||
self.config = config
|
||||
self.app = config.loaded_app
|
||||
self.loop = _loop or asyncio.get_event_loop()
|
||||
self.logger = logging.getLogger("uvicorn.error")
|
||||
self.access_logger = logging.getLogger("uvicorn.access")
|
||||
self.access_log = self.access_logger.hasHandlers()
|
||||
self.conn = h11.Connection(
|
||||
h11.SERVER,
|
||||
config.h11_max_incomplete_event_size
|
||||
if config.h11_max_incomplete_event_size is not None
|
||||
else DEFAULT_MAX_INCOMPLETE_EVENT_SIZE,
|
||||
)
|
||||
self.ws_protocol_class = config.ws_protocol_class
|
||||
self.root_path = config.root_path
|
||||
self.limit_concurrency = config.limit_concurrency
|
||||
self.app_state = app_state
|
||||
|
||||
# Timeouts
|
||||
self.timeout_keep_alive_task: asyncio.TimerHandle | None = None
|
||||
self.timeout_keep_alive = config.timeout_keep_alive
|
||||
|
||||
# Shared server state
|
||||
self.server_state = server_state
|
||||
self.connections = server_state.connections
|
||||
self.tasks = server_state.tasks
|
||||
|
||||
# Per-connection state
|
||||
self.transport: asyncio.Transport = None # type: ignore[assignment]
|
||||
self.flow: FlowControl = None # type: ignore[assignment]
|
||||
self.server: tuple[str, int] | None = None
|
||||
self.client: tuple[str, int] | None = None
|
||||
self.scheme: Literal["http", "https"] | None = None
|
||||
|
||||
# Per-request state
|
||||
self.scope: HTTPScope = None # type: ignore[assignment]
|
||||
self.headers: list[tuple[bytes, bytes]] = None # type: ignore[assignment]
|
||||
self.cycle: RequestResponseCycle = None # type: ignore[assignment]
|
||||
|
||||
# Protocol interface
|
||||
def connection_made( # type: ignore[override]
|
||||
self, transport: asyncio.Transport
|
||||
) -> None:
|
||||
self.connections.add(self)
|
||||
|
||||
self.transport = transport
|
||||
self.flow = FlowControl(transport)
|
||||
self.server = get_local_addr(transport)
|
||||
self.client = get_remote_addr(transport)
|
||||
self.scheme = "https" if is_ssl(transport) else "http"
|
||||
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix)
|
||||
|
||||
def connection_lost(self, exc: Exception | None) -> None:
|
||||
self.connections.discard(self)
|
||||
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection lost", prefix)
|
||||
|
||||
if self.cycle and not self.cycle.response_complete:
|
||||
self.cycle.disconnected = True
|
||||
if self.conn.our_state != h11.ERROR:
|
||||
event = h11.ConnectionClosed()
|
||||
try:
|
||||
self.conn.send(event)
|
||||
except h11.LocalProtocolError:
|
||||
# Premature client disconnect
|
||||
pass
|
||||
|
||||
if self.cycle is not None:
|
||||
self.cycle.message_event.set()
|
||||
if self.flow is not None:
|
||||
self.flow.resume_writing()
|
||||
if exc is None:
|
||||
self.transport.close()
|
||||
self._unset_keepalive_if_required()
|
||||
|
||||
def eof_received(self) -> None:
|
||||
pass
|
||||
|
||||
def _unset_keepalive_if_required(self) -> None:
|
||||
if self.timeout_keep_alive_task is not None:
|
||||
self.timeout_keep_alive_task.cancel()
|
||||
self.timeout_keep_alive_task = None
|
||||
|
||||
def _get_upgrade(self) -> bytes | None:
|
||||
connection = []
|
||||
upgrade = None
|
||||
for name, value in self.headers:
|
||||
if name == b"connection":
|
||||
connection = [token.lower().strip() for token in value.split(b",")]
|
||||
if name == b"upgrade":
|
||||
upgrade = value.lower()
|
||||
if b"upgrade" in connection:
|
||||
return upgrade
|
||||
return None
|
||||
|
||||
def _should_upgrade_to_ws(self) -> bool:
|
||||
if self.ws_protocol_class is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _unsupported_upgrade_warning(self) -> None:
|
||||
msg = "Unsupported upgrade request."
|
||||
self.logger.warning(msg)
|
||||
if not self._should_upgrade_to_ws():
|
||||
msg = "No supported WebSocket library detected. Please use \"pip install 'uvicorn[standard]'\", or install 'websockets' or 'wsproto' manually." # noqa: E501
|
||||
self.logger.warning(msg)
|
||||
|
||||
def _should_upgrade(self) -> bool:
|
||||
upgrade = self._get_upgrade()
|
||||
if upgrade == b"websocket" and self._should_upgrade_to_ws():
|
||||
return True
|
||||
if upgrade is not None:
|
||||
self._unsupported_upgrade_warning()
|
||||
return False
|
||||
|
||||
def data_received(self, data: bytes) -> None:
|
||||
self._unset_keepalive_if_required()
|
||||
|
||||
self.conn.receive_data(data)
|
||||
self.handle_events()
|
||||
|
||||
def handle_events(self) -> None:
|
||||
while True:
|
||||
try:
|
||||
event = self.conn.next_event()
|
||||
except h11.RemoteProtocolError:
|
||||
msg = "Invalid HTTP request received."
|
||||
self.logger.warning(msg)
|
||||
self.send_400_response(msg)
|
||||
return
|
||||
|
||||
if event is h11.NEED_DATA:
|
||||
break
|
||||
|
||||
elif event is h11.PAUSED:
|
||||
# This case can occur in HTTP pipelining, so we need to
|
||||
# stop reading any more data, and ensure that at the end
|
||||
# of the active request/response cycle we handle any
|
||||
# events that have been buffered up.
|
||||
self.flow.pause_reading()
|
||||
break
|
||||
|
||||
elif isinstance(event, h11.Request):
|
||||
self.headers = [(key.lower(), value) for key, value in event.headers]
|
||||
raw_path, _, query_string = event.target.partition(b"?")
|
||||
path = unquote(raw_path.decode("ascii"))
|
||||
full_path = self.root_path + path
|
||||
full_raw_path = self.root_path.encode("ascii") + raw_path
|
||||
self.scope = {
|
||||
"type": "http",
|
||||
"asgi": {"version": self.config.asgi_version, "spec_version": "2.3"},
|
||||
"http_version": event.http_version.decode("ascii"),
|
||||
"server": self.server,
|
||||
"client": self.client,
|
||||
"scheme": self.scheme, # type: ignore[typeddict-item]
|
||||
"method": event.method.decode("ascii"),
|
||||
"root_path": self.root_path,
|
||||
"path": full_path,
|
||||
"raw_path": full_raw_path,
|
||||
"query_string": query_string,
|
||||
"headers": self.headers,
|
||||
"state": self.app_state.copy(),
|
||||
}
|
||||
if self._should_upgrade():
|
||||
self.handle_websocket_upgrade(event)
|
||||
return
|
||||
|
||||
# Handle 503 responses when 'limit_concurrency' is exceeded.
|
||||
if self.limit_concurrency is not None and (
|
||||
len(self.connections) >= self.limit_concurrency or len(self.tasks) >= self.limit_concurrency
|
||||
):
|
||||
app = service_unavailable
|
||||
message = "Exceeded concurrency limit."
|
||||
self.logger.warning(message)
|
||||
else:
|
||||
app = self.app
|
||||
|
||||
# When starting to process a request, disable the keep-alive
|
||||
# timeout. Normally we disable this when receiving data from
|
||||
# client and set back when finishing processing its request.
|
||||
# However, for pipelined requests processing finishes after
|
||||
# already receiving the next request and thus the timer may
|
||||
# be set here, which we don't want.
|
||||
self._unset_keepalive_if_required()
|
||||
|
||||
self.cycle = RequestResponseCycle(
|
||||
scope=self.scope,
|
||||
conn=self.conn,
|
||||
transport=self.transport,
|
||||
flow=self.flow,
|
||||
logger=self.logger,
|
||||
access_logger=self.access_logger,
|
||||
access_log=self.access_log,
|
||||
default_headers=self.server_state.default_headers,
|
||||
message_event=asyncio.Event(),
|
||||
on_response=self.on_response_complete,
|
||||
)
|
||||
task = self.loop.create_task(self.cycle.run_asgi(app))
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
self.tasks.add(task)
|
||||
|
||||
elif isinstance(event, h11.Data):
|
||||
if self.conn.our_state is h11.DONE:
|
||||
continue
|
||||
self.cycle.body += event.data
|
||||
if len(self.cycle.body) > HIGH_WATER_LIMIT:
|
||||
self.flow.pause_reading()
|
||||
self.cycle.message_event.set()
|
||||
|
||||
elif isinstance(event, h11.EndOfMessage):
|
||||
if self.conn.our_state is h11.DONE:
|
||||
self.transport.resume_reading()
|
||||
self.conn.start_next_cycle()
|
||||
continue
|
||||
self.cycle.more_body = False
|
||||
self.cycle.message_event.set()
|
||||
if self.conn.their_state == h11.MUST_CLOSE:
|
||||
break
|
||||
|
||||
def handle_websocket_upgrade(self, event: h11.Request) -> None:
|
||||
if self.logger.level <= TRACE_LOG_LEVEL: # pragma: full coverage
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix)
|
||||
|
||||
self.connections.discard(self)
|
||||
output = [event.method, b" ", event.target, b" HTTP/1.1\r\n"]
|
||||
for name, value in self.headers:
|
||||
output += [name, b": ", value, b"\r\n"]
|
||||
output.append(b"\r\n")
|
||||
protocol = self.ws_protocol_class( # type: ignore[call-arg, misc]
|
||||
config=self.config,
|
||||
server_state=self.server_state,
|
||||
app_state=self.app_state,
|
||||
)
|
||||
protocol.connection_made(self.transport)
|
||||
protocol.data_received(b"".join(output))
|
||||
self.transport.set_protocol(protocol)
|
||||
|
||||
def send_400_response(self, msg: str) -> None:
|
||||
reason = STATUS_PHRASES[400]
|
||||
headers: list[tuple[bytes, bytes]] = [
|
||||
(b"content-type", b"text/plain; charset=utf-8"),
|
||||
(b"connection", b"close"),
|
||||
]
|
||||
event = h11.Response(status_code=400, headers=headers, reason=reason)
|
||||
output = self.conn.send(event)
|
||||
self.transport.write(output)
|
||||
|
||||
output = self.conn.send(event=h11.Data(data=msg.encode("ascii")))
|
||||
self.transport.write(output)
|
||||
|
||||
output = self.conn.send(event=h11.EndOfMessage())
|
||||
self.transport.write(output)
|
||||
|
||||
self.transport.close()
|
||||
|
||||
def on_response_complete(self) -> None:
|
||||
self.server_state.total_requests += 1
|
||||
|
||||
if self.transport.is_closing():
|
||||
return
|
||||
|
||||
# Set a short Keep-Alive timeout.
|
||||
self._unset_keepalive_if_required()
|
||||
|
||||
self.timeout_keep_alive_task = self.loop.call_later(self.timeout_keep_alive, self.timeout_keep_alive_handler)
|
||||
|
||||
# Unpause data reads if needed.
|
||||
self.flow.resume_reading()
|
||||
|
||||
# Unblock any pipelined events.
|
||||
if self.conn.our_state is h11.DONE and self.conn.their_state is h11.DONE:
|
||||
self.conn.start_next_cycle()
|
||||
self.handle_events()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""
|
||||
Called by the server to commence a graceful shutdown.
|
||||
"""
|
||||
if self.cycle is None or self.cycle.response_complete:
|
||||
event = h11.ConnectionClosed()
|
||||
self.conn.send(event)
|
||||
self.transport.close()
|
||||
else:
|
||||
self.cycle.keep_alive = False
|
||||
|
||||
def pause_writing(self) -> None:
|
||||
"""
|
||||
Called by the transport when the write buffer exceeds the high water mark.
|
||||
"""
|
||||
self.flow.pause_writing() # pragma: full coverage
|
||||
|
||||
def resume_writing(self) -> None:
|
||||
"""
|
||||
Called by the transport when the write buffer drops below the low water mark.
|
||||
"""
|
||||
self.flow.resume_writing() # pragma: full coverage
|
||||
|
||||
def timeout_keep_alive_handler(self) -> None:
|
||||
"""
|
||||
Called on a keep-alive connection if no new data is received after a short
|
||||
delay.
|
||||
"""
|
||||
if not self.transport.is_closing():
|
||||
event = h11.ConnectionClosed()
|
||||
self.conn.send(event)
|
||||
self.transport.close()
|
||||
|
||||
|
||||
class RequestResponseCycle:
|
||||
def __init__(
|
||||
self,
|
||||
scope: HTTPScope,
|
||||
conn: h11.Connection,
|
||||
transport: asyncio.Transport,
|
||||
flow: FlowControl,
|
||||
logger: logging.Logger,
|
||||
access_logger: logging.Logger,
|
||||
access_log: bool,
|
||||
default_headers: list[tuple[bytes, bytes]],
|
||||
message_event: asyncio.Event,
|
||||
on_response: Callable[..., None],
|
||||
) -> None:
|
||||
self.scope = scope
|
||||
self.conn = conn
|
||||
self.transport = transport
|
||||
self.flow = flow
|
||||
self.logger = logger
|
||||
self.access_logger = access_logger
|
||||
self.access_log = access_log
|
||||
self.default_headers = default_headers
|
||||
self.message_event = message_event
|
||||
self.on_response = on_response
|
||||
|
||||
# Connection state
|
||||
self.disconnected = False
|
||||
self.keep_alive = True
|
||||
self.waiting_for_100_continue = conn.they_are_waiting_for_100_continue
|
||||
|
||||
# Request state
|
||||
self.body = b""
|
||||
self.more_body = True
|
||||
|
||||
# Response state
|
||||
self.response_started = False
|
||||
self.response_complete = False
|
||||
|
||||
# ASGI exception wrapper
|
||||
async def run_asgi(self, app: ASGI3Application) -> None:
|
||||
try:
|
||||
result = await app( # type: ignore[func-returns-value]
|
||||
self.scope, self.receive, self.send
|
||||
)
|
||||
except BaseException as exc:
|
||||
msg = "Exception in ASGI application\n"
|
||||
self.logger.error(msg, exc_info=exc)
|
||||
if not self.response_started:
|
||||
await self.send_500_response()
|
||||
else:
|
||||
self.transport.close()
|
||||
else:
|
||||
if result is not None:
|
||||
msg = "ASGI callable should return None, but returned '%s'."
|
||||
self.logger.error(msg, result)
|
||||
self.transport.close()
|
||||
elif not self.response_started and not self.disconnected:
|
||||
msg = "ASGI callable returned without starting response."
|
||||
self.logger.error(msg)
|
||||
await self.send_500_response()
|
||||
elif not self.response_complete and not self.disconnected:
|
||||
msg = "ASGI callable returned without completing response."
|
||||
self.logger.error(msg)
|
||||
self.transport.close()
|
||||
finally:
|
||||
self.on_response = lambda: None
|
||||
|
||||
async def send_500_response(self) -> None:
|
||||
response_start_event: HTTPResponseStartEvent = {
|
||||
"type": "http.response.start",
|
||||
"status": 500,
|
||||
"headers": [
|
||||
(b"content-type", b"text/plain; charset=utf-8"),
|
||||
(b"connection", b"close"),
|
||||
],
|
||||
}
|
||||
await self.send(response_start_event)
|
||||
response_body_event: HTTPResponseBodyEvent = {
|
||||
"type": "http.response.body",
|
||||
"body": b"Internal Server Error",
|
||||
"more_body": False,
|
||||
}
|
||||
await self.send(response_body_event)
|
||||
|
||||
# ASGI interface
|
||||
async def send(self, message: ASGISendEvent) -> None:
|
||||
message_type = message["type"]
|
||||
|
||||
if self.flow.write_paused and not self.disconnected:
|
||||
await self.flow.drain() # pragma: full coverage
|
||||
|
||||
if self.disconnected:
|
||||
return # pragma: full coverage
|
||||
|
||||
if not self.response_started:
|
||||
# Sending response status line and headers
|
||||
if message_type != "http.response.start":
|
||||
msg = "Expected ASGI message 'http.response.start', but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
message = cast("HTTPResponseStartEvent", message)
|
||||
|
||||
self.response_started = True
|
||||
self.waiting_for_100_continue = False
|
||||
|
||||
status = message["status"]
|
||||
headers = self.default_headers + list(message.get("headers", []))
|
||||
|
||||
if CLOSE_HEADER in self.scope["headers"] and CLOSE_HEADER not in headers:
|
||||
headers = headers + [CLOSE_HEADER]
|
||||
|
||||
if self.access_log:
|
||||
self.access_logger.info(
|
||||
'%s - "%s %s HTTP/%s" %d',
|
||||
get_client_addr(self.scope),
|
||||
self.scope["method"],
|
||||
get_path_with_query_string(self.scope),
|
||||
self.scope["http_version"],
|
||||
status,
|
||||
)
|
||||
|
||||
# Write response status line and headers
|
||||
reason = STATUS_PHRASES[status]
|
||||
response = h11.Response(status_code=status, headers=headers, reason=reason)
|
||||
output = self.conn.send(event=response)
|
||||
self.transport.write(output)
|
||||
|
||||
elif not self.response_complete:
|
||||
# Sending response body
|
||||
if message_type != "http.response.body":
|
||||
msg = "Expected ASGI message 'http.response.body', but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
message = cast("HTTPResponseBodyEvent", message)
|
||||
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
|
||||
# Write response body
|
||||
data = b"" if self.scope["method"] == "HEAD" else body
|
||||
output = self.conn.send(event=h11.Data(data=data))
|
||||
self.transport.write(output)
|
||||
|
||||
# Handle response completion
|
||||
if not more_body:
|
||||
self.response_complete = True
|
||||
self.message_event.set()
|
||||
output = self.conn.send(event=h11.EndOfMessage())
|
||||
self.transport.write(output)
|
||||
|
||||
else:
|
||||
# Response already sent
|
||||
msg = "Unexpected ASGI message '%s' sent, after response already completed."
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
if self.response_complete:
|
||||
if self.conn.our_state is h11.MUST_CLOSE or not self.keep_alive:
|
||||
self.conn.send(event=h11.ConnectionClosed())
|
||||
self.transport.close()
|
||||
self.on_response()
|
||||
|
||||
async def receive(self) -> ASGIReceiveEvent:
|
||||
if self.waiting_for_100_continue and not self.transport.is_closing():
|
||||
headers: list[tuple[str, str]] = []
|
||||
event = h11.InformationalResponse(status_code=100, headers=headers, reason="Continue")
|
||||
output = self.conn.send(event=event)
|
||||
self.transport.write(output)
|
||||
self.waiting_for_100_continue = False
|
||||
|
||||
if not self.disconnected and not self.response_complete:
|
||||
self.flow.resume_reading()
|
||||
await self.message_event.wait()
|
||||
self.message_event.clear()
|
||||
|
||||
if self.disconnected or self.response_complete:
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
message: HTTPRequestEvent = {
|
||||
"type": "http.request",
|
||||
"body": self.body,
|
||||
"more_body": self.more_body,
|
||||
}
|
||||
self.body = b""
|
||||
return message
|
@ -0,0 +1,570 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import http
|
||||
import logging
|
||||
import re
|
||||
import urllib
|
||||
from asyncio.events import TimerHandle
|
||||
from collections import deque
|
||||
from typing import Any, Callable, Literal, cast
|
||||
|
||||
import httptools
|
||||
|
||||
from uvicorn._types import (
|
||||
ASGI3Application,
|
||||
ASGIReceiveEvent,
|
||||
ASGISendEvent,
|
||||
HTTPRequestEvent,
|
||||
HTTPResponseStartEvent,
|
||||
HTTPScope,
|
||||
)
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.logging import TRACE_LOG_LEVEL
|
||||
from uvicorn.protocols.http.flow_control import CLOSE_HEADER, HIGH_WATER_LIMIT, FlowControl, service_unavailable
|
||||
from uvicorn.protocols.utils import get_client_addr, get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl
|
||||
from uvicorn.server import ServerState
|
||||
|
||||
HEADER_RE = re.compile(b'[\x00-\x1f\x7f()<>@,;:[]={} \t\\"]')
|
||||
HEADER_VALUE_RE = re.compile(b"[\x00-\x08\x0a-\x1f\x7f]")
|
||||
|
||||
|
||||
def _get_status_line(status_code: int) -> bytes:
|
||||
try:
|
||||
phrase = http.HTTPStatus(status_code).phrase.encode()
|
||||
except ValueError:
|
||||
phrase = b""
|
||||
return b"".join([b"HTTP/1.1 ", str(status_code).encode(), b" ", phrase, b"\r\n"])
|
||||
|
||||
|
||||
STATUS_LINE = {status_code: _get_status_line(status_code) for status_code in range(100, 600)}
|
||||
|
||||
|
||||
class HttpToolsProtocol(asyncio.Protocol):
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
server_state: ServerState,
|
||||
app_state: dict[str, Any],
|
||||
_loop: asyncio.AbstractEventLoop | None = None,
|
||||
) -> None:
|
||||
if not config.loaded:
|
||||
config.load()
|
||||
|
||||
self.config = config
|
||||
self.app = config.loaded_app
|
||||
self.loop = _loop or asyncio.get_event_loop()
|
||||
self.logger = logging.getLogger("uvicorn.error")
|
||||
self.access_logger = logging.getLogger("uvicorn.access")
|
||||
self.access_log = self.access_logger.hasHandlers()
|
||||
self.parser = httptools.HttpRequestParser(self)
|
||||
|
||||
try:
|
||||
# Enable dangerous leniencies to allow server to a response on the first request from a pipelined request.
|
||||
self.parser.set_dangerous_leniencies(lenient_data_after_close=True)
|
||||
except AttributeError: # pragma: no cover
|
||||
# httptools < 0.6.3
|
||||
pass
|
||||
|
||||
self.ws_protocol_class = config.ws_protocol_class
|
||||
self.root_path = config.root_path
|
||||
self.limit_concurrency = config.limit_concurrency
|
||||
self.app_state = app_state
|
||||
|
||||
# Timeouts
|
||||
self.timeout_keep_alive_task: TimerHandle | None = None
|
||||
self.timeout_keep_alive = config.timeout_keep_alive
|
||||
|
||||
# Global state
|
||||
self.server_state = server_state
|
||||
self.connections = server_state.connections
|
||||
self.tasks = server_state.tasks
|
||||
|
||||
# Per-connection state
|
||||
self.transport: asyncio.Transport = None # type: ignore[assignment]
|
||||
self.flow: FlowControl = None # type: ignore[assignment]
|
||||
self.server: tuple[str, int] | None = None
|
||||
self.client: tuple[str, int] | None = None
|
||||
self.scheme: Literal["http", "https"] | None = None
|
||||
self.pipeline: deque[tuple[RequestResponseCycle, ASGI3Application]] = deque()
|
||||
|
||||
# Per-request state
|
||||
self.scope: HTTPScope = None # type: ignore[assignment]
|
||||
self.headers: list[tuple[bytes, bytes]] = None # type: ignore[assignment]
|
||||
self.expect_100_continue = False
|
||||
self.cycle: RequestResponseCycle = None # type: ignore[assignment]
|
||||
|
||||
# Protocol interface
|
||||
def connection_made( # type: ignore[override]
|
||||
self, transport: asyncio.Transport
|
||||
) -> None:
|
||||
self.connections.add(self)
|
||||
|
||||
self.transport = transport
|
||||
self.flow = FlowControl(transport)
|
||||
self.server = get_local_addr(transport)
|
||||
self.client = get_remote_addr(transport)
|
||||
self.scheme = "https" if is_ssl(transport) else "http"
|
||||
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix)
|
||||
|
||||
def connection_lost(self, exc: Exception | None) -> None:
|
||||
self.connections.discard(self)
|
||||
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection lost", prefix)
|
||||
|
||||
if self.cycle and not self.cycle.response_complete:
|
||||
self.cycle.disconnected = True
|
||||
if self.cycle is not None:
|
||||
self.cycle.message_event.set()
|
||||
if self.flow is not None:
|
||||
self.flow.resume_writing()
|
||||
if exc is None:
|
||||
self.transport.close()
|
||||
self._unset_keepalive_if_required()
|
||||
|
||||
self.parser = None
|
||||
|
||||
def eof_received(self) -> None:
|
||||
pass
|
||||
|
||||
def _unset_keepalive_if_required(self) -> None:
|
||||
if self.timeout_keep_alive_task is not None:
|
||||
self.timeout_keep_alive_task.cancel()
|
||||
self.timeout_keep_alive_task = None
|
||||
|
||||
def _get_upgrade(self) -> bytes | None:
|
||||
connection = []
|
||||
upgrade = None
|
||||
for name, value in self.headers:
|
||||
if name == b"connection":
|
||||
connection = [token.lower().strip() for token in value.split(b",")]
|
||||
if name == b"upgrade":
|
||||
upgrade = value.lower()
|
||||
if b"upgrade" in connection:
|
||||
return upgrade
|
||||
return None # pragma: full coverage
|
||||
|
||||
def _should_upgrade_to_ws(self) -> bool:
|
||||
if self.ws_protocol_class is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _unsupported_upgrade_warning(self) -> None:
|
||||
self.logger.warning("Unsupported upgrade request.")
|
||||
if not self._should_upgrade_to_ws():
|
||||
msg = "No supported WebSocket library detected. Please use \"pip install 'uvicorn[standard]'\", or install 'websockets' or 'wsproto' manually." # noqa: E501
|
||||
self.logger.warning(msg)
|
||||
|
||||
def _should_upgrade(self) -> bool:
|
||||
upgrade = self._get_upgrade()
|
||||
return upgrade == b"websocket" and self._should_upgrade_to_ws()
|
||||
|
||||
def data_received(self, data: bytes) -> None:
|
||||
self._unset_keepalive_if_required()
|
||||
|
||||
try:
|
||||
self.parser.feed_data(data)
|
||||
except httptools.HttpParserError:
|
||||
msg = "Invalid HTTP request received."
|
||||
self.logger.warning(msg)
|
||||
self.send_400_response(msg)
|
||||
return
|
||||
except httptools.HttpParserUpgrade:
|
||||
if self._should_upgrade():
|
||||
self.handle_websocket_upgrade()
|
||||
else:
|
||||
self._unsupported_upgrade_warning()
|
||||
|
||||
def handle_websocket_upgrade(self) -> None:
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix)
|
||||
|
||||
self.connections.discard(self)
|
||||
method = self.scope["method"].encode()
|
||||
output = [method, b" ", self.url, b" HTTP/1.1\r\n"]
|
||||
for name, value in self.scope["headers"]:
|
||||
output += [name, b": ", value, b"\r\n"]
|
||||
output.append(b"\r\n")
|
||||
protocol = self.ws_protocol_class( # type: ignore[call-arg, misc]
|
||||
config=self.config,
|
||||
server_state=self.server_state,
|
||||
app_state=self.app_state,
|
||||
)
|
||||
protocol.connection_made(self.transport)
|
||||
protocol.data_received(b"".join(output))
|
||||
self.transport.set_protocol(protocol)
|
||||
|
||||
def send_400_response(self, msg: str) -> None:
|
||||
content = [STATUS_LINE[400]]
|
||||
for name, value in self.server_state.default_headers:
|
||||
content.extend([name, b": ", value, b"\r\n"]) # pragma: full coverage
|
||||
content.extend(
|
||||
[
|
||||
b"content-type: text/plain; charset=utf-8\r\n",
|
||||
b"content-length: " + str(len(msg)).encode("ascii") + b"\r\n",
|
||||
b"connection: close\r\n",
|
||||
b"\r\n",
|
||||
msg.encode("ascii"),
|
||||
]
|
||||
)
|
||||
self.transport.write(b"".join(content))
|
||||
self.transport.close()
|
||||
|
||||
def on_message_begin(self) -> None:
|
||||
self.url = b""
|
||||
self.expect_100_continue = False
|
||||
self.headers = []
|
||||
self.scope = { # type: ignore[typeddict-item]
|
||||
"type": "http",
|
||||
"asgi": {"version": self.config.asgi_version, "spec_version": "2.3"},
|
||||
"http_version": "1.1",
|
||||
"server": self.server,
|
||||
"client": self.client,
|
||||
"scheme": self.scheme, # type: ignore[typeddict-item]
|
||||
"root_path": self.root_path,
|
||||
"headers": self.headers,
|
||||
"state": self.app_state.copy(),
|
||||
}
|
||||
|
||||
# Parser callbacks
|
||||
def on_url(self, url: bytes) -> None:
|
||||
self.url += url
|
||||
|
||||
def on_header(self, name: bytes, value: bytes) -> None:
|
||||
name = name.lower()
|
||||
if name == b"expect" and value.lower() == b"100-continue":
|
||||
self.expect_100_continue = True
|
||||
self.headers.append((name, value))
|
||||
|
||||
def on_headers_complete(self) -> None:
|
||||
http_version = self.parser.get_http_version()
|
||||
method = self.parser.get_method()
|
||||
self.scope["method"] = method.decode("ascii")
|
||||
if http_version != "1.1":
|
||||
self.scope["http_version"] = http_version
|
||||
if self.parser.should_upgrade() and self._should_upgrade():
|
||||
return
|
||||
parsed_url = httptools.parse_url(self.url)
|
||||
raw_path = parsed_url.path
|
||||
path = raw_path.decode("ascii")
|
||||
if "%" in path:
|
||||
path = urllib.parse.unquote(path)
|
||||
full_path = self.root_path + path
|
||||
full_raw_path = self.root_path.encode("ascii") + raw_path
|
||||
self.scope["path"] = full_path
|
||||
self.scope["raw_path"] = full_raw_path
|
||||
self.scope["query_string"] = parsed_url.query or b""
|
||||
|
||||
# Handle 503 responses when 'limit_concurrency' is exceeded.
|
||||
if self.limit_concurrency is not None and (
|
||||
len(self.connections) >= self.limit_concurrency or len(self.tasks) >= self.limit_concurrency
|
||||
):
|
||||
app = service_unavailable
|
||||
message = "Exceeded concurrency limit."
|
||||
self.logger.warning(message)
|
||||
else:
|
||||
app = self.app
|
||||
|
||||
existing_cycle = self.cycle
|
||||
self.cycle = RequestResponseCycle(
|
||||
scope=self.scope,
|
||||
transport=self.transport,
|
||||
flow=self.flow,
|
||||
logger=self.logger,
|
||||
access_logger=self.access_logger,
|
||||
access_log=self.access_log,
|
||||
default_headers=self.server_state.default_headers,
|
||||
message_event=asyncio.Event(),
|
||||
expect_100_continue=self.expect_100_continue,
|
||||
keep_alive=http_version != "1.0",
|
||||
on_response=self.on_response_complete,
|
||||
)
|
||||
if existing_cycle is None or existing_cycle.response_complete:
|
||||
# Standard case - start processing the request.
|
||||
task = self.loop.create_task(self.cycle.run_asgi(app))
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
self.tasks.add(task)
|
||||
else:
|
||||
# Pipelined HTTP requests need to be queued up.
|
||||
self.flow.pause_reading()
|
||||
self.pipeline.appendleft((self.cycle, app))
|
||||
|
||||
def on_body(self, body: bytes) -> None:
|
||||
if (self.parser.should_upgrade() and self._should_upgrade()) or self.cycle.response_complete:
|
||||
return
|
||||
self.cycle.body += body
|
||||
if len(self.cycle.body) > HIGH_WATER_LIMIT:
|
||||
self.flow.pause_reading()
|
||||
self.cycle.message_event.set()
|
||||
|
||||
def on_message_complete(self) -> None:
|
||||
if (self.parser.should_upgrade() and self._should_upgrade()) or self.cycle.response_complete:
|
||||
return
|
||||
self.cycle.more_body = False
|
||||
self.cycle.message_event.set()
|
||||
|
||||
def on_response_complete(self) -> None:
|
||||
# Callback for pipelined HTTP requests to be started.
|
||||
self.server_state.total_requests += 1
|
||||
|
||||
if self.transport.is_closing():
|
||||
return
|
||||
|
||||
self._unset_keepalive_if_required()
|
||||
|
||||
# Unpause data reads if needed.
|
||||
self.flow.resume_reading()
|
||||
|
||||
# Unblock any pipelined events. If there are none, arm the
|
||||
# Keep-Alive timeout instead.
|
||||
if self.pipeline:
|
||||
cycle, app = self.pipeline.pop()
|
||||
task = self.loop.create_task(cycle.run_asgi(app))
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
self.tasks.add(task)
|
||||
else:
|
||||
self.timeout_keep_alive_task = self.loop.call_later(
|
||||
self.timeout_keep_alive, self.timeout_keep_alive_handler
|
||||
)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""
|
||||
Called by the server to commence a graceful shutdown.
|
||||
"""
|
||||
if self.cycle is None or self.cycle.response_complete:
|
||||
self.transport.close()
|
||||
else:
|
||||
self.cycle.keep_alive = False
|
||||
|
||||
def pause_writing(self) -> None:
|
||||
"""
|
||||
Called by the transport when the write buffer exceeds the high water mark.
|
||||
"""
|
||||
self.flow.pause_writing() # pragma: full coverage
|
||||
|
||||
def resume_writing(self) -> None:
|
||||
"""
|
||||
Called by the transport when the write buffer drops below the low water mark.
|
||||
"""
|
||||
self.flow.resume_writing() # pragma: full coverage
|
||||
|
||||
def timeout_keep_alive_handler(self) -> None:
|
||||
"""
|
||||
Called on a keep-alive connection if no new data is received after a short
|
||||
delay.
|
||||
"""
|
||||
if not self.transport.is_closing():
|
||||
self.transport.close()
|
||||
|
||||
|
||||
class RequestResponseCycle:
|
||||
def __init__(
|
||||
self,
|
||||
scope: HTTPScope,
|
||||
transport: asyncio.Transport,
|
||||
flow: FlowControl,
|
||||
logger: logging.Logger,
|
||||
access_logger: logging.Logger,
|
||||
access_log: bool,
|
||||
default_headers: list[tuple[bytes, bytes]],
|
||||
message_event: asyncio.Event,
|
||||
expect_100_continue: bool,
|
||||
keep_alive: bool,
|
||||
on_response: Callable[..., None],
|
||||
):
|
||||
self.scope = scope
|
||||
self.transport = transport
|
||||
self.flow = flow
|
||||
self.logger = logger
|
||||
self.access_logger = access_logger
|
||||
self.access_log = access_log
|
||||
self.default_headers = default_headers
|
||||
self.message_event = message_event
|
||||
self.on_response = on_response
|
||||
|
||||
# Connection state
|
||||
self.disconnected = False
|
||||
self.keep_alive = keep_alive
|
||||
self.waiting_for_100_continue = expect_100_continue
|
||||
|
||||
# Request state
|
||||
self.body = b""
|
||||
self.more_body = True
|
||||
|
||||
# Response state
|
||||
self.response_started = False
|
||||
self.response_complete = False
|
||||
self.chunked_encoding: bool | None = None
|
||||
self.expected_content_length = 0
|
||||
|
||||
# ASGI exception wrapper
|
||||
async def run_asgi(self, app: ASGI3Application) -> None:
|
||||
try:
|
||||
result = await app( # type: ignore[func-returns-value]
|
||||
self.scope, self.receive, self.send
|
||||
)
|
||||
except BaseException as exc:
|
||||
msg = "Exception in ASGI application\n"
|
||||
self.logger.error(msg, exc_info=exc)
|
||||
if not self.response_started:
|
||||
await self.send_500_response()
|
||||
else:
|
||||
self.transport.close()
|
||||
else:
|
||||
if result is not None:
|
||||
msg = "ASGI callable should return None, but returned '%s'."
|
||||
self.logger.error(msg, result)
|
||||
self.transport.close()
|
||||
elif not self.response_started and not self.disconnected:
|
||||
msg = "ASGI callable returned without starting response."
|
||||
self.logger.error(msg)
|
||||
await self.send_500_response()
|
||||
elif not self.response_complete and not self.disconnected:
|
||||
msg = "ASGI callable returned without completing response."
|
||||
self.logger.error(msg)
|
||||
self.transport.close()
|
||||
finally:
|
||||
self.on_response = lambda: None
|
||||
|
||||
async def send_500_response(self) -> None:
|
||||
await self.send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 500,
|
||||
"headers": [
|
||||
(b"content-type", b"text/plain; charset=utf-8"),
|
||||
(b"content-length", b"21"),
|
||||
(b"connection", b"close"),
|
||||
],
|
||||
}
|
||||
)
|
||||
await self.send({"type": "http.response.body", "body": b"Internal Server Error", "more_body": False})
|
||||
|
||||
# ASGI interface
|
||||
async def send(self, message: ASGISendEvent) -> None:
|
||||
message_type = message["type"]
|
||||
|
||||
if self.flow.write_paused and not self.disconnected:
|
||||
await self.flow.drain() # pragma: full coverage
|
||||
|
||||
if self.disconnected:
|
||||
return # pragma: full coverage
|
||||
|
||||
if not self.response_started:
|
||||
# Sending response status line and headers
|
||||
if message_type != "http.response.start":
|
||||
msg = "Expected ASGI message 'http.response.start', but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
message = cast("HTTPResponseStartEvent", message)
|
||||
|
||||
self.response_started = True
|
||||
self.waiting_for_100_continue = False
|
||||
|
||||
status_code = message["status"]
|
||||
headers = self.default_headers + list(message.get("headers", []))
|
||||
|
||||
if CLOSE_HEADER in self.scope["headers"] and CLOSE_HEADER not in headers:
|
||||
headers = headers + [CLOSE_HEADER]
|
||||
|
||||
if self.access_log:
|
||||
self.access_logger.info(
|
||||
'%s - "%s %s HTTP/%s" %d',
|
||||
get_client_addr(self.scope),
|
||||
self.scope["method"],
|
||||
get_path_with_query_string(self.scope),
|
||||
self.scope["http_version"],
|
||||
status_code,
|
||||
)
|
||||
|
||||
# Write response status line and headers
|
||||
content = [STATUS_LINE[status_code]]
|
||||
|
||||
for name, value in headers:
|
||||
if HEADER_RE.search(name):
|
||||
raise RuntimeError("Invalid HTTP header name.") # pragma: full coverage
|
||||
if HEADER_VALUE_RE.search(value):
|
||||
raise RuntimeError("Invalid HTTP header value.")
|
||||
|
||||
name = name.lower()
|
||||
if name == b"content-length" and self.chunked_encoding is None:
|
||||
self.expected_content_length = int(value.decode())
|
||||
self.chunked_encoding = False
|
||||
elif name == b"transfer-encoding" and value.lower() == b"chunked":
|
||||
self.expected_content_length = 0
|
||||
self.chunked_encoding = True
|
||||
elif name == b"connection" and value.lower() == b"close":
|
||||
self.keep_alive = False
|
||||
content.extend([name, b": ", value, b"\r\n"])
|
||||
|
||||
if self.chunked_encoding is None and self.scope["method"] != "HEAD" and status_code not in (204, 304):
|
||||
# Neither content-length nor transfer-encoding specified
|
||||
self.chunked_encoding = True
|
||||
content.append(b"transfer-encoding: chunked\r\n")
|
||||
|
||||
content.append(b"\r\n")
|
||||
self.transport.write(b"".join(content))
|
||||
|
||||
elif not self.response_complete:
|
||||
# Sending response body
|
||||
if message_type != "http.response.body":
|
||||
msg = "Expected ASGI message 'http.response.body', but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
body = cast(bytes, message.get("body", b""))
|
||||
more_body = message.get("more_body", False)
|
||||
|
||||
# Write response body
|
||||
if self.scope["method"] == "HEAD":
|
||||
self.expected_content_length = 0
|
||||
elif self.chunked_encoding:
|
||||
if body:
|
||||
content = [b"%x\r\n" % len(body), body, b"\r\n"]
|
||||
else:
|
||||
content = []
|
||||
if not more_body:
|
||||
content.append(b"0\r\n\r\n")
|
||||
self.transport.write(b"".join(content))
|
||||
else:
|
||||
num_bytes = len(body)
|
||||
if num_bytes > self.expected_content_length:
|
||||
raise RuntimeError("Response content longer than Content-Length")
|
||||
else:
|
||||
self.expected_content_length -= num_bytes
|
||||
self.transport.write(body)
|
||||
|
||||
# Handle response completion
|
||||
if not more_body:
|
||||
if self.expected_content_length != 0:
|
||||
raise RuntimeError("Response content shorter than Content-Length")
|
||||
self.response_complete = True
|
||||
self.message_event.set()
|
||||
if not self.keep_alive:
|
||||
self.transport.close()
|
||||
self.on_response()
|
||||
|
||||
else:
|
||||
# Response already sent
|
||||
msg = "Unexpected ASGI message '%s' sent, after response already completed."
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
async def receive(self) -> ASGIReceiveEvent:
|
||||
if self.waiting_for_100_continue and not self.transport.is_closing():
|
||||
self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
|
||||
self.waiting_for_100_continue = False
|
||||
|
||||
if not self.disconnected and not self.response_complete:
|
||||
self.flow.resume_reading()
|
||||
await self.message_event.wait()
|
||||
self.message_event.clear()
|
||||
|
||||
if self.disconnected or self.response_complete:
|
||||
return {"type": "http.disconnect"}
|
||||
message: HTTPRequestEvent = {"type": "http.request", "body": self.body, "more_body": self.more_body}
|
||||
self.body = b""
|
||||
return message
|
56
venv/lib/python3.11/site-packages/uvicorn/protocols/utils.py
Normal file
56
venv/lib/python3.11/site-packages/uvicorn/protocols/utils.py
Normal file
@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import urllib.parse
|
||||
|
||||
from uvicorn._types import WWWScope
|
||||
|
||||
|
||||
class ClientDisconnected(OSError): ...
|
||||
|
||||
|
||||
def get_remote_addr(transport: asyncio.Transport) -> tuple[str, int] | None:
|
||||
socket_info = transport.get_extra_info("socket")
|
||||
if socket_info is not None:
|
||||
try:
|
||||
info = socket_info.getpeername()
|
||||
return (str(info[0]), int(info[1])) if isinstance(info, tuple) else None
|
||||
except OSError: # pragma: no cover
|
||||
# This case appears to inconsistently occur with uvloop
|
||||
# bound to a unix domain socket.
|
||||
return None
|
||||
|
||||
info = transport.get_extra_info("peername")
|
||||
if info is not None and isinstance(info, (list, tuple)) and len(info) == 2:
|
||||
return (str(info[0]), int(info[1]))
|
||||
return None
|
||||
|
||||
|
||||
def get_local_addr(transport: asyncio.Transport) -> tuple[str, int] | None:
|
||||
socket_info = transport.get_extra_info("socket")
|
||||
if socket_info is not None:
|
||||
info = socket_info.getsockname()
|
||||
|
||||
return (str(info[0]), int(info[1])) if isinstance(info, tuple) else None
|
||||
info = transport.get_extra_info("sockname")
|
||||
if info is not None and isinstance(info, (list, tuple)) and len(info) == 2:
|
||||
return (str(info[0]), int(info[1]))
|
||||
return None
|
||||
|
||||
|
||||
def is_ssl(transport: asyncio.Transport) -> bool:
|
||||
return bool(transport.get_extra_info("sslcontext"))
|
||||
|
||||
|
||||
def get_client_addr(scope: WWWScope) -> str:
|
||||
client = scope.get("client")
|
||||
if not client:
|
||||
return ""
|
||||
return "%s:%d" % client
|
||||
|
||||
|
||||
def get_path_with_query_string(scope: WWWScope) -> str:
|
||||
path_with_query_string = urllib.parse.quote(scope["path"])
|
||||
if scope["query_string"]:
|
||||
path_with_query_string = "{}?{}".format(path_with_query_string, scope["query_string"].decode("ascii"))
|
||||
return path_with_query_string
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import typing
|
||||
|
||||
AutoWebSocketsProtocol: typing.Callable[..., asyncio.Protocol] | None
|
||||
try:
|
||||
import websockets # noqa
|
||||
except ImportError: # pragma: no cover
|
||||
try:
|
||||
import wsproto # noqa
|
||||
except ImportError:
|
||||
AutoWebSocketsProtocol = None
|
||||
else:
|
||||
from uvicorn.protocols.websockets.wsproto_impl import WSProtocol
|
||||
|
||||
AutoWebSocketsProtocol = WSProtocol
|
||||
else:
|
||||
from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
|
||||
|
||||
AutoWebSocketsProtocol = WebSocketProtocol
|
@ -0,0 +1,386 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import http
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal, Optional, cast
|
||||
from urllib.parse import unquote
|
||||
|
||||
import websockets
|
||||
import websockets.legacy.handshake
|
||||
from websockets.datastructures import Headers
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
from websockets.extensions.base import ServerExtensionFactory
|
||||
from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory
|
||||
from websockets.legacy.server import HTTPResponse
|
||||
from websockets.server import WebSocketServerProtocol
|
||||
from websockets.typing import Subprotocol
|
||||
|
||||
from uvicorn._types import (
|
||||
ASGI3Application,
|
||||
ASGISendEvent,
|
||||
WebSocketAcceptEvent,
|
||||
WebSocketCloseEvent,
|
||||
WebSocketConnectEvent,
|
||||
WebSocketDisconnectEvent,
|
||||
WebSocketReceiveEvent,
|
||||
WebSocketResponseBodyEvent,
|
||||
WebSocketResponseStartEvent,
|
||||
WebSocketScope,
|
||||
WebSocketSendEvent,
|
||||
)
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.logging import TRACE_LOG_LEVEL
|
||||
from uvicorn.protocols.utils import (
|
||||
ClientDisconnected,
|
||||
get_local_addr,
|
||||
get_path_with_query_string,
|
||||
get_remote_addr,
|
||||
is_ssl,
|
||||
)
|
||||
from uvicorn.server import ServerState
|
||||
|
||||
|
||||
class Server:
|
||||
closing = False
|
||||
|
||||
def register(self, ws: WebSocketServerProtocol) -> None:
|
||||
pass
|
||||
|
||||
def unregister(self, ws: WebSocketServerProtocol) -> None:
|
||||
pass
|
||||
|
||||
def is_serving(self) -> bool:
|
||||
return not self.closing
|
||||
|
||||
|
||||
class WebSocketProtocol(WebSocketServerProtocol):
|
||||
extra_headers: list[tuple[str, str]]
|
||||
logger: logging.Logger | logging.LoggerAdapter[Any]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
server_state: ServerState,
|
||||
app_state: dict[str, Any],
|
||||
_loop: asyncio.AbstractEventLoop | None = None,
|
||||
):
|
||||
if not config.loaded:
|
||||
config.load()
|
||||
|
||||
self.config = config
|
||||
self.app = cast(ASGI3Application, config.loaded_app)
|
||||
self.loop = _loop or asyncio.get_event_loop()
|
||||
self.root_path = config.root_path
|
||||
self.app_state = app_state
|
||||
|
||||
# Shared server state
|
||||
self.connections = server_state.connections
|
||||
self.tasks = server_state.tasks
|
||||
|
||||
# Connection state
|
||||
self.transport: asyncio.Transport = None # type: ignore[assignment]
|
||||
self.server: tuple[str, int] | None = None
|
||||
self.client: tuple[str, int] | None = None
|
||||
self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment]
|
||||
|
||||
# Connection events
|
||||
self.scope: WebSocketScope
|
||||
self.handshake_started_event = asyncio.Event()
|
||||
self.handshake_completed_event = asyncio.Event()
|
||||
self.closed_event = asyncio.Event()
|
||||
self.initial_response: HTTPResponse | None = None
|
||||
self.connect_sent = False
|
||||
self.lost_connection_before_handshake = False
|
||||
self.accepted_subprotocol: Subprotocol | None = None
|
||||
|
||||
self.ws_server: Server = Server() # type: ignore[assignment]
|
||||
|
||||
extensions: list[ServerExtensionFactory] = []
|
||||
if self.config.ws_per_message_deflate:
|
||||
extensions.append(ServerPerMessageDeflateFactory())
|
||||
|
||||
super().__init__(
|
||||
ws_handler=self.ws_handler,
|
||||
ws_server=self.ws_server, # type: ignore[arg-type]
|
||||
max_size=self.config.ws_max_size,
|
||||
max_queue=self.config.ws_max_queue,
|
||||
ping_interval=self.config.ws_ping_interval,
|
||||
ping_timeout=self.config.ws_ping_timeout,
|
||||
extensions=extensions,
|
||||
logger=logging.getLogger("uvicorn.error"),
|
||||
)
|
||||
self.server_header = None
|
||||
self.extra_headers = [
|
||||
(name.decode("latin-1"), value.decode("latin-1")) for name, value in server_state.default_headers
|
||||
]
|
||||
|
||||
def connection_made( # type: ignore[override]
|
||||
self, transport: asyncio.Transport
|
||||
) -> None:
|
||||
self.connections.add(self)
|
||||
self.transport = transport
|
||||
self.server = get_local_addr(transport)
|
||||
self.client = get_remote_addr(transport)
|
||||
self.scheme = "wss" if is_ssl(transport) else "ws"
|
||||
|
||||
if self.logger.isEnabledFor(TRACE_LOG_LEVEL):
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)
|
||||
|
||||
super().connection_made(transport)
|
||||
|
||||
def connection_lost(self, exc: Exception | None) -> None:
|
||||
self.connections.remove(self)
|
||||
|
||||
if self.logger.isEnabledFor(TRACE_LOG_LEVEL):
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)
|
||||
|
||||
self.lost_connection_before_handshake = not self.handshake_completed_event.is_set()
|
||||
self.handshake_completed_event.set()
|
||||
super().connection_lost(exc)
|
||||
if exc is None:
|
||||
self.transport.close()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self.ws_server.closing = True
|
||||
if self.handshake_completed_event.is_set():
|
||||
self.fail_connection(1012)
|
||||
else:
|
||||
self.send_500_response()
|
||||
self.transport.close()
|
||||
|
||||
def on_task_complete(self, task: asyncio.Task[None]) -> None:
|
||||
self.tasks.discard(task)
|
||||
|
||||
async def process_request(self, path: str, request_headers: Headers) -> HTTPResponse | None:
|
||||
"""
|
||||
This hook is called to determine if the websocket should return
|
||||
an HTTP response and close.
|
||||
|
||||
Our behavior here is to start the ASGI application, and then wait
|
||||
for either `accept` or `close` in order to determine if we should
|
||||
close the connection.
|
||||
"""
|
||||
path_portion, _, query_string = path.partition("?")
|
||||
|
||||
websockets.legacy.handshake.check_request(request_headers)
|
||||
|
||||
subprotocols: list[str] = []
|
||||
for header in request_headers.get_all("Sec-WebSocket-Protocol"):
|
||||
subprotocols.extend([token.strip() for token in header.split(",")])
|
||||
|
||||
asgi_headers = [
|
||||
(name.encode("ascii"), value.encode("ascii", errors="surrogateescape"))
|
||||
for name, value in request_headers.raw_items()
|
||||
]
|
||||
path = unquote(path_portion)
|
||||
full_path = self.root_path + path
|
||||
full_raw_path = self.root_path.encode("ascii") + path_portion.encode("ascii")
|
||||
|
||||
self.scope = {
|
||||
"type": "websocket",
|
||||
"asgi": {"version": self.config.asgi_version, "spec_version": "2.4"},
|
||||
"http_version": "1.1",
|
||||
"scheme": self.scheme,
|
||||
"server": self.server,
|
||||
"client": self.client,
|
||||
"root_path": self.root_path,
|
||||
"path": full_path,
|
||||
"raw_path": full_raw_path,
|
||||
"query_string": query_string.encode("ascii"),
|
||||
"headers": asgi_headers,
|
||||
"subprotocols": subprotocols,
|
||||
"state": self.app_state.copy(),
|
||||
"extensions": {"websocket.http.response": {}},
|
||||
}
|
||||
task = self.loop.create_task(self.run_asgi())
|
||||
task.add_done_callback(self.on_task_complete)
|
||||
self.tasks.add(task)
|
||||
await self.handshake_started_event.wait()
|
||||
return self.initial_response
|
||||
|
||||
def process_subprotocol(
|
||||
self, headers: Headers, available_subprotocols: Sequence[Subprotocol] | None
|
||||
) -> Subprotocol | None:
|
||||
"""
|
||||
We override the standard 'process_subprotocol' behavior here so that
|
||||
we return whatever subprotocol is sent in the 'accept' message.
|
||||
"""
|
||||
return self.accepted_subprotocol
|
||||
|
||||
def send_500_response(self) -> None:
|
||||
msg = b"Internal Server Error"
|
||||
content = [
|
||||
b"HTTP/1.1 500 Internal Server Error\r\n" b"content-type: text/plain; charset=utf-8\r\n",
|
||||
b"content-length: " + str(len(msg)).encode("ascii") + b"\r\n",
|
||||
b"connection: close\r\n",
|
||||
b"\r\n",
|
||||
msg,
|
||||
]
|
||||
self.transport.write(b"".join(content))
|
||||
# Allow handler task to terminate cleanly, as websockets doesn't cancel it by
|
||||
# itself (see https://github.com/encode/uvicorn/issues/920)
|
||||
self.handshake_started_event.set()
|
||||
|
||||
async def ws_handler(self, protocol: WebSocketServerProtocol, path: str) -> Any: # type: ignore[override]
|
||||
"""
|
||||
This is the main handler function for the 'websockets' implementation
|
||||
to call into. We just wait for close then return, and instead allow
|
||||
'send' and 'receive' events to drive the flow.
|
||||
"""
|
||||
self.handshake_completed_event.set()
|
||||
await self.wait_closed()
|
||||
|
||||
async def run_asgi(self) -> None:
|
||||
"""
|
||||
Wrapper around the ASGI callable, handling exceptions and unexpected
|
||||
termination states.
|
||||
"""
|
||||
try:
|
||||
result = await self.app(self.scope, self.asgi_receive, self.asgi_send) # type: ignore[func-returns-value]
|
||||
except ClientDisconnected: # pragma: full coverage
|
||||
self.closed_event.set()
|
||||
self.transport.close()
|
||||
except BaseException:
|
||||
self.closed_event.set()
|
||||
self.logger.exception("Exception in ASGI application\n")
|
||||
if not self.handshake_started_event.is_set():
|
||||
self.send_500_response()
|
||||
else:
|
||||
await self.handshake_completed_event.wait()
|
||||
self.transport.close()
|
||||
else:
|
||||
self.closed_event.set()
|
||||
if not self.handshake_started_event.is_set():
|
||||
self.logger.error("ASGI callable returned without sending handshake.")
|
||||
self.send_500_response()
|
||||
self.transport.close()
|
||||
elif result is not None:
|
||||
self.logger.error("ASGI callable should return None, but returned '%s'.", result)
|
||||
await self.handshake_completed_event.wait()
|
||||
self.transport.close()
|
||||
|
||||
async def asgi_send(self, message: ASGISendEvent) -> None:
|
||||
message_type = message["type"]
|
||||
|
||||
if not self.handshake_started_event.is_set():
|
||||
if message_type == "websocket.accept":
|
||||
message = cast("WebSocketAcceptEvent", message)
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" [accepted]',
|
||||
self.scope["client"],
|
||||
get_path_with_query_string(self.scope),
|
||||
)
|
||||
self.initial_response = None
|
||||
self.accepted_subprotocol = cast(Optional[Subprotocol], message.get("subprotocol"))
|
||||
if "headers" in message:
|
||||
self.extra_headers.extend(
|
||||
# ASGI spec requires bytes
|
||||
# But for compatibility we need to convert it to strings
|
||||
(name.decode("latin-1"), value.decode("latin-1"))
|
||||
for name, value in message["headers"]
|
||||
)
|
||||
self.handshake_started_event.set()
|
||||
|
||||
elif message_type == "websocket.close":
|
||||
message = cast("WebSocketCloseEvent", message)
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" 403',
|
||||
self.scope["client"],
|
||||
get_path_with_query_string(self.scope),
|
||||
)
|
||||
self.initial_response = (http.HTTPStatus.FORBIDDEN, [], b"")
|
||||
self.handshake_started_event.set()
|
||||
self.closed_event.set()
|
||||
|
||||
elif message_type == "websocket.http.response.start":
|
||||
message = cast("WebSocketResponseStartEvent", message)
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" %d',
|
||||
self.scope["client"],
|
||||
get_path_with_query_string(self.scope),
|
||||
message["status"],
|
||||
)
|
||||
# websockets requires the status to be an enum. look it up.
|
||||
status = http.HTTPStatus(message["status"])
|
||||
headers = [
|
||||
(name.decode("latin-1"), value.decode("latin-1")) for name, value in message.get("headers", [])
|
||||
]
|
||||
self.initial_response = (status, headers, b"")
|
||||
self.handshake_started_event.set()
|
||||
|
||||
else:
|
||||
msg = (
|
||||
"Expected ASGI message 'websocket.accept', 'websocket.close', "
|
||||
"or 'websocket.http.response.start' but got '%s'."
|
||||
)
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
elif not self.closed_event.is_set() and self.initial_response is None:
|
||||
await self.handshake_completed_event.wait()
|
||||
|
||||
try:
|
||||
if message_type == "websocket.send":
|
||||
message = cast("WebSocketSendEvent", message)
|
||||
bytes_data = message.get("bytes")
|
||||
text_data = message.get("text")
|
||||
data = text_data if bytes_data is None else bytes_data
|
||||
await self.send(data) # type: ignore[arg-type]
|
||||
|
||||
elif message_type == "websocket.close":
|
||||
message = cast("WebSocketCloseEvent", message)
|
||||
code = message.get("code", 1000)
|
||||
reason = message.get("reason", "") or ""
|
||||
await self.close(code, reason)
|
||||
self.closed_event.set()
|
||||
|
||||
else:
|
||||
msg = "Expected ASGI message 'websocket.send' or 'websocket.close'," " but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
except ConnectionClosed as exc:
|
||||
raise ClientDisconnected from exc
|
||||
|
||||
elif self.initial_response is not None:
|
||||
if message_type == "websocket.http.response.body":
|
||||
message = cast("WebSocketResponseBodyEvent", message)
|
||||
body = self.initial_response[2] + message["body"]
|
||||
self.initial_response = self.initial_response[:2] + (body,)
|
||||
if not message.get("more_body", False):
|
||||
self.closed_event.set()
|
||||
else:
|
||||
msg = "Expected ASGI message 'websocket.http.response.body' " "but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
else:
|
||||
msg = "Unexpected ASGI message '%s', after sending 'websocket.close' " "or response already completed."
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
async def asgi_receive(self) -> WebSocketDisconnectEvent | WebSocketConnectEvent | WebSocketReceiveEvent:
|
||||
if not self.connect_sent:
|
||||
self.connect_sent = True
|
||||
return {"type": "websocket.connect"}
|
||||
|
||||
await self.handshake_completed_event.wait()
|
||||
|
||||
if self.lost_connection_before_handshake:
|
||||
# If the handshake failed or the app closed before handshake completion,
|
||||
# use 1006 Abnormal Closure.
|
||||
return {"type": "websocket.disconnect", "code": 1006}
|
||||
|
||||
if self.closed_event.is_set():
|
||||
return {"type": "websocket.disconnect", "code": 1005}
|
||||
|
||||
try:
|
||||
data = await self.recv()
|
||||
except ConnectionClosed:
|
||||
self.closed_event.set()
|
||||
if self.ws_server.closing:
|
||||
return {"type": "websocket.disconnect", "code": 1012}
|
||||
return {"type": "websocket.disconnect", "code": self.close_code or 1005, "reason": self.close_reason}
|
||||
|
||||
if isinstance(data, str):
|
||||
return {"type": "websocket.receive", "text": data}
|
||||
return {"type": "websocket.receive", "bytes": data}
|
@ -0,0 +1,377 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import typing
|
||||
from typing import Literal, cast
|
||||
from urllib.parse import unquote
|
||||
|
||||
import wsproto
|
||||
from wsproto import ConnectionType, events
|
||||
from wsproto.connection import ConnectionState
|
||||
from wsproto.extensions import Extension, PerMessageDeflate
|
||||
from wsproto.utilities import LocalProtocolError, RemoteProtocolError
|
||||
|
||||
from uvicorn._types import (
|
||||
ASGI3Application,
|
||||
ASGISendEvent,
|
||||
WebSocketAcceptEvent,
|
||||
WebSocketCloseEvent,
|
||||
WebSocketEvent,
|
||||
WebSocketResponseBodyEvent,
|
||||
WebSocketResponseStartEvent,
|
||||
WebSocketScope,
|
||||
WebSocketSendEvent,
|
||||
)
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.logging import TRACE_LOG_LEVEL
|
||||
from uvicorn.protocols.utils import (
|
||||
ClientDisconnected,
|
||||
get_local_addr,
|
||||
get_path_with_query_string,
|
||||
get_remote_addr,
|
||||
is_ssl,
|
||||
)
|
||||
from uvicorn.server import ServerState
|
||||
|
||||
|
||||
class WSProtocol(asyncio.Protocol):
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
server_state: ServerState,
|
||||
app_state: dict[str, typing.Any],
|
||||
_loop: asyncio.AbstractEventLoop | None = None,
|
||||
) -> None:
|
||||
if not config.loaded:
|
||||
config.load() # pragma: full coverage
|
||||
|
||||
self.config = config
|
||||
self.app = cast(ASGI3Application, config.loaded_app)
|
||||
self.loop = _loop or asyncio.get_event_loop()
|
||||
self.logger = logging.getLogger("uvicorn.error")
|
||||
self.root_path = config.root_path
|
||||
self.app_state = app_state
|
||||
|
||||
# Shared server state
|
||||
self.connections = server_state.connections
|
||||
self.tasks = server_state.tasks
|
||||
self.default_headers = server_state.default_headers
|
||||
|
||||
# Connection state
|
||||
self.transport: asyncio.Transport = None # type: ignore[assignment]
|
||||
self.server: tuple[str, int] | None = None
|
||||
self.client: tuple[str, int] | None = None
|
||||
self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment]
|
||||
|
||||
# WebSocket state
|
||||
self.queue: asyncio.Queue[WebSocketEvent] = asyncio.Queue()
|
||||
self.handshake_complete = False
|
||||
self.close_sent = False
|
||||
|
||||
# Rejection state
|
||||
self.response_started = False
|
||||
|
||||
self.conn = wsproto.WSConnection(connection_type=ConnectionType.SERVER)
|
||||
|
||||
self.read_paused = False
|
||||
self.writable = asyncio.Event()
|
||||
self.writable.set()
|
||||
|
||||
# Buffers
|
||||
self.bytes = b""
|
||||
self.text = ""
|
||||
|
||||
# Protocol interface
|
||||
|
||||
def connection_made( # type: ignore[override]
|
||||
self, transport: asyncio.Transport
|
||||
) -> None:
|
||||
self.connections.add(self)
|
||||
self.transport = transport
|
||||
self.server = get_local_addr(transport)
|
||||
self.client = get_remote_addr(transport)
|
||||
self.scheme = "wss" if is_ssl(transport) else "ws"
|
||||
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)
|
||||
|
||||
def connection_lost(self, exc: Exception | None) -> None:
|
||||
code = 1005 if self.handshake_complete else 1006
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
|
||||
self.connections.remove(self)
|
||||
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)
|
||||
|
||||
self.handshake_complete = True
|
||||
if exc is None:
|
||||
self.transport.close()
|
||||
|
||||
def eof_received(self) -> None:
|
||||
pass
|
||||
|
||||
def data_received(self, data: bytes) -> None:
|
||||
try:
|
||||
self.conn.receive_data(data)
|
||||
except RemoteProtocolError as err:
|
||||
# TODO: Remove `type: ignore` when wsproto fixes the type annotation.
|
||||
self.transport.write(self.conn.send(err.event_hint)) # type: ignore[arg-type] # noqa: E501
|
||||
self.transport.close()
|
||||
else:
|
||||
self.handle_events()
|
||||
|
||||
def handle_events(self) -> None:
|
||||
for event in self.conn.events():
|
||||
if isinstance(event, events.Request):
|
||||
self.handle_connect(event)
|
||||
elif isinstance(event, events.TextMessage):
|
||||
self.handle_text(event)
|
||||
elif isinstance(event, events.BytesMessage):
|
||||
self.handle_bytes(event)
|
||||
elif isinstance(event, events.CloseConnection):
|
||||
self.handle_close(event)
|
||||
elif isinstance(event, events.Ping):
|
||||
self.handle_ping(event)
|
||||
|
||||
def pause_writing(self) -> None:
|
||||
"""
|
||||
Called by the transport when the write buffer exceeds the high water mark.
|
||||
"""
|
||||
self.writable.clear() # pragma: full coverage
|
||||
|
||||
def resume_writing(self) -> None:
|
||||
"""
|
||||
Called by the transport when the write buffer drops below the low water mark.
|
||||
"""
|
||||
self.writable.set() # pragma: full coverage
|
||||
|
||||
def shutdown(self) -> None:
|
||||
if self.handshake_complete:
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012})
|
||||
output = self.conn.send(wsproto.events.CloseConnection(code=1012))
|
||||
self.transport.write(output)
|
||||
else:
|
||||
self.send_500_response()
|
||||
self.transport.close()
|
||||
|
||||
def on_task_complete(self, task: asyncio.Task[None]) -> None:
|
||||
self.tasks.discard(task)
|
||||
|
||||
# Event handlers
|
||||
|
||||
def handle_connect(self, event: events.Request) -> None:
|
||||
headers = [(b"host", event.host.encode())]
|
||||
headers += [(key.lower(), value) for key, value in event.extra_headers]
|
||||
raw_path, _, query_string = event.target.partition("?")
|
||||
path = unquote(raw_path)
|
||||
full_path = self.root_path + path
|
||||
full_raw_path = self.root_path.encode("ascii") + raw_path.encode("ascii")
|
||||
self.scope: WebSocketScope = {
|
||||
"type": "websocket",
|
||||
"asgi": {"version": self.config.asgi_version, "spec_version": "2.4"},
|
||||
"http_version": "1.1",
|
||||
"scheme": self.scheme,
|
||||
"server": self.server,
|
||||
"client": self.client,
|
||||
"root_path": self.root_path,
|
||||
"path": full_path,
|
||||
"raw_path": full_raw_path,
|
||||
"query_string": query_string.encode("ascii"),
|
||||
"headers": headers,
|
||||
"subprotocols": event.subprotocols,
|
||||
"state": self.app_state.copy(),
|
||||
"extensions": {"websocket.http.response": {}},
|
||||
}
|
||||
self.queue.put_nowait({"type": "websocket.connect"})
|
||||
task = self.loop.create_task(self.run_asgi())
|
||||
task.add_done_callback(self.on_task_complete)
|
||||
self.tasks.add(task)
|
||||
|
||||
def handle_text(self, event: events.TextMessage) -> None:
|
||||
self.text += event.data
|
||||
if event.message_finished:
|
||||
self.queue.put_nowait({"type": "websocket.receive", "text": self.text})
|
||||
self.text = ""
|
||||
if not self.read_paused:
|
||||
self.read_paused = True
|
||||
self.transport.pause_reading()
|
||||
|
||||
def handle_bytes(self, event: events.BytesMessage) -> None:
|
||||
self.bytes += event.data
|
||||
# todo: we may want to guard the size of self.bytes and self.text
|
||||
if event.message_finished:
|
||||
self.queue.put_nowait({"type": "websocket.receive", "bytes": self.bytes})
|
||||
self.bytes = b""
|
||||
if not self.read_paused:
|
||||
self.read_paused = True
|
||||
self.transport.pause_reading()
|
||||
|
||||
def handle_close(self, event: events.CloseConnection) -> None:
|
||||
if self.conn.state == ConnectionState.REMOTE_CLOSING:
|
||||
self.transport.write(self.conn.send(event.response()))
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": event.code, "reason": event.reason})
|
||||
self.transport.close()
|
||||
|
||||
def handle_ping(self, event: events.Ping) -> None:
|
||||
self.transport.write(self.conn.send(event.response()))
|
||||
|
||||
def send_500_response(self) -> None:
|
||||
if self.response_started or self.handshake_complete:
|
||||
return # we cannot send responses anymore
|
||||
headers: list[tuple[bytes, bytes]] = [
|
||||
(b"content-type", b"text/plain; charset=utf-8"),
|
||||
(b"connection", b"close"),
|
||||
(b"content-length", b"21"),
|
||||
]
|
||||
output = self.conn.send(wsproto.events.RejectConnection(status_code=500, headers=headers, has_body=True))
|
||||
output += self.conn.send(wsproto.events.RejectData(data=b"Internal Server Error"))
|
||||
self.transport.write(output)
|
||||
|
||||
async def run_asgi(self) -> None:
|
||||
try:
|
||||
result = await self.app(self.scope, self.receive, self.send) # type: ignore[func-returns-value]
|
||||
except ClientDisconnected:
|
||||
self.transport.close() # pragma: full coverage
|
||||
except BaseException:
|
||||
self.logger.exception("Exception in ASGI application\n")
|
||||
self.send_500_response()
|
||||
self.transport.close()
|
||||
else:
|
||||
if not self.handshake_complete:
|
||||
self.logger.error("ASGI callable returned without completing handshake.")
|
||||
self.send_500_response()
|
||||
self.transport.close()
|
||||
elif result is not None:
|
||||
self.logger.error("ASGI callable should return None, but returned '%s'.", result)
|
||||
self.transport.close()
|
||||
|
||||
async def send(self, message: ASGISendEvent) -> None:
|
||||
await self.writable.wait()
|
||||
|
||||
message_type = message["type"]
|
||||
|
||||
if not self.handshake_complete:
|
||||
if message_type == "websocket.accept":
|
||||
message = typing.cast(WebSocketAcceptEvent, message)
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" [accepted]',
|
||||
self.scope["client"],
|
||||
get_path_with_query_string(self.scope),
|
||||
)
|
||||
subprotocol = message.get("subprotocol")
|
||||
extra_headers = self.default_headers + list(message.get("headers", []))
|
||||
extensions: list[Extension] = []
|
||||
if self.config.ws_per_message_deflate:
|
||||
extensions.append(PerMessageDeflate())
|
||||
if not self.transport.is_closing():
|
||||
self.handshake_complete = True
|
||||
output = self.conn.send(
|
||||
wsproto.events.AcceptConnection(
|
||||
subprotocol=subprotocol,
|
||||
extensions=extensions,
|
||||
extra_headers=extra_headers,
|
||||
)
|
||||
)
|
||||
self.transport.write(output)
|
||||
|
||||
elif message_type == "websocket.close":
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" 403',
|
||||
self.scope["client"],
|
||||
get_path_with_query_string(self.scope),
|
||||
)
|
||||
self.handshake_complete = True
|
||||
self.close_sent = True
|
||||
event = events.RejectConnection(status_code=403, headers=[])
|
||||
output = self.conn.send(event)
|
||||
self.transport.write(output)
|
||||
self.transport.close()
|
||||
|
||||
elif message_type == "websocket.http.response.start":
|
||||
message = typing.cast(WebSocketResponseStartEvent, message)
|
||||
# ensure status code is in the valid range
|
||||
if not (100 <= message["status"] < 600):
|
||||
msg = "Invalid HTTP status code '%d' in response."
|
||||
raise RuntimeError(msg % message["status"])
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" %d',
|
||||
self.scope["client"],
|
||||
get_path_with_query_string(self.scope),
|
||||
message["status"],
|
||||
)
|
||||
self.handshake_complete = True
|
||||
event = events.RejectConnection(
|
||||
status_code=message["status"],
|
||||
headers=list(message["headers"]),
|
||||
has_body=True,
|
||||
)
|
||||
output = self.conn.send(event)
|
||||
self.transport.write(output)
|
||||
self.response_started = True
|
||||
|
||||
else:
|
||||
msg = (
|
||||
"Expected ASGI message 'websocket.accept', 'websocket.close' "
|
||||
"or 'websocket.http.response.start' "
|
||||
"but got '%s'."
|
||||
)
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
elif not self.close_sent and not self.response_started:
|
||||
try:
|
||||
if message_type == "websocket.send":
|
||||
message = typing.cast(WebSocketSendEvent, message)
|
||||
bytes_data = message.get("bytes")
|
||||
text_data = message.get("text")
|
||||
data = text_data if bytes_data is None else bytes_data
|
||||
output = self.conn.send(wsproto.events.Message(data=data)) # type: ignore
|
||||
if not self.transport.is_closing():
|
||||
self.transport.write(output)
|
||||
|
||||
elif message_type == "websocket.close":
|
||||
message = typing.cast(WebSocketCloseEvent, message)
|
||||
self.close_sent = True
|
||||
code = message.get("code", 1000)
|
||||
reason = message.get("reason", "") or ""
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": code, "reason": reason})
|
||||
output = self.conn.send(wsproto.events.CloseConnection(code=code, reason=reason))
|
||||
if not self.transport.is_closing():
|
||||
self.transport.write(output)
|
||||
self.transport.close()
|
||||
|
||||
else:
|
||||
msg = "Expected ASGI message 'websocket.send' or 'websocket.close'," " but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
except LocalProtocolError as exc:
|
||||
raise ClientDisconnected from exc
|
||||
elif self.response_started:
|
||||
if message_type == "websocket.http.response.body":
|
||||
message = typing.cast("WebSocketResponseBodyEvent", message)
|
||||
body_finished = not message.get("more_body", False)
|
||||
reject_data = events.RejectData(data=message["body"], body_finished=body_finished)
|
||||
output = self.conn.send(reject_data)
|
||||
self.transport.write(output)
|
||||
|
||||
if body_finished:
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
|
||||
self.close_sent = True
|
||||
self.transport.close()
|
||||
|
||||
else:
|
||||
msg = "Expected ASGI message 'websocket.http.response.body' " "but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
else:
|
||||
msg = "Unexpected ASGI message '%s', after sending 'websocket.close'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
async def receive(self) -> WebSocketEvent:
|
||||
message = await self.queue.get()
|
||||
if self.read_paused and self.queue.empty():
|
||||
self.read_paused = False
|
||||
self.transport.resume_reading()
|
||||
return message
|
Reference in New Issue
Block a user