Update 2025-04-24_11:44:19
This commit is contained in:
337
venv/lib/python3.11/site-packages/uvicorn/server.py
Normal file
337
venv/lib/python3.11/site-packages/uvicorn/server.py
Normal file
@ -0,0 +1,337 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import signal
|
||||
import socket
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Generator, Sequence
|
||||
from email.utils import formatdate
|
||||
from types import FrameType
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
import click
|
||||
|
||||
from uvicorn.config import Config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from uvicorn.protocols.http.h11_impl import H11Protocol
|
||||
from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol
|
||||
from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
|
||||
from uvicorn.protocols.websockets.wsproto_impl import WSProtocol
|
||||
|
||||
Protocols = Union[H11Protocol, HttpToolsProtocol, WSProtocol, WebSocketProtocol]
|
||||
|
||||
HANDLED_SIGNALS = (
|
||||
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
|
||||
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
|
||||
)
|
||||
if sys.platform == "win32": # pragma: py-not-win32
|
||||
HANDLED_SIGNALS += (signal.SIGBREAK,) # Windows signal 21. Sent by Ctrl+Break.
|
||||
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
|
||||
|
||||
class ServerState:
|
||||
"""
|
||||
Shared servers state that is available between all protocol instances.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.total_requests = 0
|
||||
self.connections: set[Protocols] = set()
|
||||
self.tasks: set[asyncio.Task[None]] = set()
|
||||
self.default_headers: list[tuple[bytes, bytes]] = []
|
||||
|
||||
|
||||
class Server:
|
||||
def __init__(self, config: Config) -> None:
|
||||
self.config = config
|
||||
self.server_state = ServerState()
|
||||
|
||||
self.started = False
|
||||
self.should_exit = False
|
||||
self.force_exit = False
|
||||
self.last_notified = 0.0
|
||||
|
||||
self._captured_signals: list[int] = []
|
||||
|
||||
def run(self, sockets: list[socket.socket] | None = None) -> None:
|
||||
self.config.setup_event_loop()
|
||||
return asyncio.run(self.serve(sockets=sockets))
|
||||
|
||||
async def serve(self, sockets: list[socket.socket] | None = None) -> None:
|
||||
with self.capture_signals():
|
||||
await self._serve(sockets)
|
||||
|
||||
async def _serve(self, sockets: list[socket.socket] | None = None) -> None:
|
||||
process_id = os.getpid()
|
||||
|
||||
config = self.config
|
||||
if not config.loaded:
|
||||
config.load()
|
||||
|
||||
self.lifespan = config.lifespan_class(config)
|
||||
|
||||
message = "Started server process [%d]"
|
||||
color_message = "Started server process [" + click.style("%d", fg="cyan") + "]"
|
||||
logger.info(message, process_id, extra={"color_message": color_message})
|
||||
|
||||
await self.startup(sockets=sockets)
|
||||
if self.should_exit:
|
||||
return
|
||||
await self.main_loop()
|
||||
await self.shutdown(sockets=sockets)
|
||||
|
||||
message = "Finished server process [%d]"
|
||||
color_message = "Finished server process [" + click.style("%d", fg="cyan") + "]"
|
||||
logger.info(message, process_id, extra={"color_message": color_message})
|
||||
|
||||
async def startup(self, sockets: list[socket.socket] | None = None) -> None:
|
||||
await self.lifespan.startup()
|
||||
if self.lifespan.should_exit:
|
||||
self.should_exit = True
|
||||
return
|
||||
|
||||
config = self.config
|
||||
|
||||
def create_protocol(
|
||||
_loop: asyncio.AbstractEventLoop | None = None,
|
||||
) -> asyncio.Protocol:
|
||||
return config.http_protocol_class( # type: ignore[call-arg]
|
||||
config=config,
|
||||
server_state=self.server_state,
|
||||
app_state=self.lifespan.state,
|
||||
_loop=_loop,
|
||||
)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
listeners: Sequence[socket.SocketType]
|
||||
if sockets is not None: # pragma: full coverage
|
||||
# Explicitly passed a list of open sockets.
|
||||
# We use this when the server is run from a Gunicorn worker.
|
||||
|
||||
def _share_socket(
|
||||
sock: socket.SocketType,
|
||||
) -> socket.SocketType: # pragma py-not-win32
|
||||
# Windows requires the socket be explicitly shared across
|
||||
# multiple workers (processes).
|
||||
from socket import fromshare # type: ignore[attr-defined]
|
||||
|
||||
sock_data = sock.share(os.getpid()) # type: ignore[attr-defined]
|
||||
return fromshare(sock_data)
|
||||
|
||||
self.servers: list[asyncio.base_events.Server] = []
|
||||
for sock in sockets:
|
||||
is_windows = platform.system() == "Windows"
|
||||
if config.workers > 1 and is_windows: # pragma: py-not-win32
|
||||
sock = _share_socket(sock) # type: ignore[assignment]
|
||||
server = await loop.create_server(create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog)
|
||||
self.servers.append(server)
|
||||
listeners = sockets
|
||||
|
||||
elif config.fd is not None: # pragma: py-win32
|
||||
# Use an existing socket, from a file descriptor.
|
||||
sock = socket.fromfd(config.fd, socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
server = await loop.create_server(create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog)
|
||||
assert server.sockets is not None # mypy
|
||||
listeners = server.sockets
|
||||
self.servers = [server]
|
||||
|
||||
elif config.uds is not None: # pragma: py-win32
|
||||
# Create a socket using UNIX domain socket.
|
||||
uds_perms = 0o666
|
||||
if os.path.exists(config.uds):
|
||||
uds_perms = os.stat(config.uds).st_mode # pragma: full coverage
|
||||
server = await loop.create_unix_server(
|
||||
create_protocol, path=config.uds, ssl=config.ssl, backlog=config.backlog
|
||||
)
|
||||
os.chmod(config.uds, uds_perms)
|
||||
assert server.sockets is not None # mypy
|
||||
listeners = server.sockets
|
||||
self.servers = [server]
|
||||
|
||||
else:
|
||||
# Standard case. Create a socket from a host/port pair.
|
||||
try:
|
||||
server = await loop.create_server(
|
||||
create_protocol,
|
||||
host=config.host,
|
||||
port=config.port,
|
||||
ssl=config.ssl,
|
||||
backlog=config.backlog,
|
||||
)
|
||||
except OSError as exc:
|
||||
logger.error(exc)
|
||||
await self.lifespan.shutdown()
|
||||
sys.exit(1)
|
||||
|
||||
assert server.sockets is not None
|
||||
listeners = server.sockets
|
||||
self.servers = [server]
|
||||
|
||||
if sockets is None:
|
||||
self._log_started_message(listeners)
|
||||
else:
|
||||
# We're most likely running multiple workers, so a message has already been
|
||||
# logged by `config.bind_socket()`.
|
||||
pass # pragma: full coverage
|
||||
|
||||
self.started = True
|
||||
|
||||
def _log_started_message(self, listeners: Sequence[socket.SocketType]) -> None:
|
||||
config = self.config
|
||||
|
||||
if config.fd is not None: # pragma: py-win32
|
||||
sock = listeners[0]
|
||||
logger.info(
|
||||
"Uvicorn running on socket %s (Press CTRL+C to quit)",
|
||||
sock.getsockname(),
|
||||
)
|
||||
|
||||
elif config.uds is not None: # pragma: py-win32
|
||||
logger.info("Uvicorn running on unix socket %s (Press CTRL+C to quit)", config.uds)
|
||||
|
||||
else:
|
||||
addr_format = "%s://%s:%d"
|
||||
host = "0.0.0.0" if config.host is None else config.host
|
||||
if ":" in host:
|
||||
# It's an IPv6 address.
|
||||
addr_format = "%s://[%s]:%d"
|
||||
|
||||
port = config.port
|
||||
if port == 0:
|
||||
port = listeners[0].getsockname()[1]
|
||||
|
||||
protocol_name = "https" if config.ssl else "http"
|
||||
message = f"Uvicorn running on {addr_format} (Press CTRL+C to quit)"
|
||||
color_message = "Uvicorn running on " + click.style(addr_format, bold=True) + " (Press CTRL+C to quit)"
|
||||
logger.info(
|
||||
message,
|
||||
protocol_name,
|
||||
host,
|
||||
port,
|
||||
extra={"color_message": color_message},
|
||||
)
|
||||
|
||||
async def main_loop(self) -> None:
|
||||
counter = 0
|
||||
should_exit = await self.on_tick(counter)
|
||||
while not should_exit:
|
||||
counter += 1
|
||||
counter = counter % 864000
|
||||
await asyncio.sleep(0.1)
|
||||
should_exit = await self.on_tick(counter)
|
||||
|
||||
async def on_tick(self, counter: int) -> bool:
|
||||
# Update the default headers, once per second.
|
||||
if counter % 10 == 0:
|
||||
current_time = time.time()
|
||||
current_date = formatdate(current_time, usegmt=True).encode()
|
||||
|
||||
if self.config.date_header:
|
||||
date_header = [(b"date", current_date)]
|
||||
else:
|
||||
date_header = []
|
||||
|
||||
self.server_state.default_headers = date_header + self.config.encoded_headers
|
||||
|
||||
# Callback to `callback_notify` once every `timeout_notify` seconds.
|
||||
if self.config.callback_notify is not None:
|
||||
if current_time - self.last_notified > self.config.timeout_notify: # pragma: full coverage
|
||||
self.last_notified = current_time
|
||||
await self.config.callback_notify()
|
||||
|
||||
# Determine if we should exit.
|
||||
if self.should_exit:
|
||||
return True
|
||||
|
||||
max_requests = self.config.limit_max_requests
|
||||
if max_requests is not None and self.server_state.total_requests >= max_requests:
|
||||
logger.warning(f"Maximum request limit of {max_requests} exceeded. Terminating process.")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def shutdown(self, sockets: list[socket.socket] | None = None) -> None:
|
||||
logger.info("Shutting down")
|
||||
|
||||
# Stop accepting new connections.
|
||||
for server in self.servers:
|
||||
server.close()
|
||||
for sock in sockets or []:
|
||||
sock.close() # pragma: full coverage
|
||||
|
||||
# Request shutdown on all existing connections.
|
||||
for connection in list(self.server_state.connections):
|
||||
connection.shutdown()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# When 3.10 is not supported anymore, use `async with asyncio.timeout(...):`.
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._wait_tasks_to_complete(),
|
||||
timeout=self.config.timeout_graceful_shutdown,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
"Cancel %s running task(s), timeout graceful shutdown exceeded",
|
||||
len(self.server_state.tasks),
|
||||
)
|
||||
for t in self.server_state.tasks:
|
||||
t.cancel(msg="Task cancelled, timeout graceful shutdown exceeded")
|
||||
|
||||
# Send the lifespan shutdown event, and wait for application shutdown.
|
||||
if not self.force_exit:
|
||||
await self.lifespan.shutdown()
|
||||
|
||||
async def _wait_tasks_to_complete(self) -> None:
|
||||
# Wait for existing connections to finish sending responses.
|
||||
if self.server_state.connections and not self.force_exit:
|
||||
msg = "Waiting for connections to close. (CTRL+C to force quit)"
|
||||
logger.info(msg)
|
||||
while self.server_state.connections and not self.force_exit:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Wait for existing tasks to complete.
|
||||
if self.server_state.tasks and not self.force_exit:
|
||||
msg = "Waiting for background tasks to complete. (CTRL+C to force quit)"
|
||||
logger.info(msg)
|
||||
while self.server_state.tasks and not self.force_exit:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
for server in self.servers:
|
||||
await server.wait_closed()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def capture_signals(self) -> Generator[None, None, None]:
|
||||
# Signals can only be listened to from the main thread.
|
||||
if threading.current_thread() is not threading.main_thread():
|
||||
yield
|
||||
return
|
||||
# always use signal.signal, even if loop.add_signal_handler is available
|
||||
# this allows to restore previous signal handlers later on
|
||||
original_handlers = {sig: signal.signal(sig, self.handle_exit) for sig in HANDLED_SIGNALS}
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for sig, handler in original_handlers.items():
|
||||
signal.signal(sig, handler)
|
||||
# If we did gracefully shut down due to a signal, try to
|
||||
# trigger the expected behaviour now; multiple signals would be
|
||||
# done LIFO, see https://stackoverflow.com/questions/48434964
|
||||
for captured_signal in reversed(self._captured_signals):
|
||||
signal.raise_signal(captured_signal)
|
||||
|
||||
def handle_exit(self, sig: int, frame: FrameType | None) -> None:
|
||||
self._captured_signals.append(sig)
|
||||
if self.should_exit and sig == signal.SIGINT:
|
||||
self.force_exit = True # pragma: full coverage
|
||||
else:
|
||||
self.should_exit = True
|
Reference in New Issue
Block a user