Update 2025-04-24_11:44:19
This commit is contained in:
42
venv/lib/python3.11/site-packages/urllib3/util/__init__.py
Normal file
42
venv/lib/python3.11/site-packages/urllib3/util/__init__.py
Normal file
@ -0,0 +1,42 @@
|
||||
# For backwards compatibility, provide imports that used to be here.
|
||||
from __future__ import annotations
|
||||
|
||||
from .connection import is_connection_dropped
|
||||
from .request import SKIP_HEADER, SKIPPABLE_HEADERS, make_headers
|
||||
from .response import is_fp_closed
|
||||
from .retry import Retry
|
||||
from .ssl_ import (
|
||||
ALPN_PROTOCOLS,
|
||||
IS_PYOPENSSL,
|
||||
SSLContext,
|
||||
assert_fingerprint,
|
||||
create_urllib3_context,
|
||||
resolve_cert_reqs,
|
||||
resolve_ssl_version,
|
||||
ssl_wrap_socket,
|
||||
)
|
||||
from .timeout import Timeout
|
||||
from .url import Url, parse_url
|
||||
from .wait import wait_for_read, wait_for_write
|
||||
|
||||
__all__ = (
|
||||
"IS_PYOPENSSL",
|
||||
"SSLContext",
|
||||
"ALPN_PROTOCOLS",
|
||||
"Retry",
|
||||
"Timeout",
|
||||
"Url",
|
||||
"assert_fingerprint",
|
||||
"create_urllib3_context",
|
||||
"is_connection_dropped",
|
||||
"is_fp_closed",
|
||||
"parse_url",
|
||||
"make_headers",
|
||||
"resolve_cert_reqs",
|
||||
"resolve_ssl_version",
|
||||
"ssl_wrap_socket",
|
||||
"wait_for_read",
|
||||
"wait_for_write",
|
||||
"SKIP_HEADER",
|
||||
"SKIPPABLE_HEADERS",
|
||||
)
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
137
venv/lib/python3.11/site-packages/urllib3/util/connection.py
Normal file
137
venv/lib/python3.11/site-packages/urllib3/util/connection.py
Normal file
@ -0,0 +1,137 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
|
||||
from ..exceptions import LocationParseError
|
||||
from .timeout import _DEFAULT_TIMEOUT, _TYPE_TIMEOUT
|
||||
|
||||
_TYPE_SOCKET_OPTIONS = list[tuple[int, int, typing.Union[int, bytes]]]
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from .._base_connection import BaseHTTPConnection
|
||||
|
||||
|
||||
def is_connection_dropped(conn: BaseHTTPConnection) -> bool: # Platform-specific
|
||||
"""
|
||||
Returns True if the connection is dropped and should be closed.
|
||||
:param conn: :class:`urllib3.connection.HTTPConnection` object.
|
||||
"""
|
||||
return not conn.is_connected
|
||||
|
||||
|
||||
# This function is copied from socket.py in the Python 2.7 standard
|
||||
# library test suite. Added to its signature is only `socket_options`.
|
||||
# One additional modification is that we avoid binding to IPv6 servers
|
||||
# discovered in DNS if the system doesn't have IPv6 functionality.
|
||||
def create_connection(
|
||||
address: tuple[str, int],
|
||||
timeout: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT,
|
||||
source_address: tuple[str, int] | None = None,
|
||||
socket_options: _TYPE_SOCKET_OPTIONS | None = None,
|
||||
) -> socket.socket:
|
||||
"""Connect to *address* and return the socket object.
|
||||
|
||||
Convenience function. Connect to *address* (a 2-tuple ``(host,
|
||||
port)``) and return the socket object. Passing the optional
|
||||
*timeout* parameter will set the timeout on the socket instance
|
||||
before attempting to connect. If no *timeout* is supplied, the
|
||||
global default timeout setting returned by :func:`socket.getdefaulttimeout`
|
||||
is used. If *source_address* is set it must be a tuple of (host, port)
|
||||
for the socket to bind as a source address before making the connection.
|
||||
An host of '' or port 0 tells the OS to use the default.
|
||||
"""
|
||||
|
||||
host, port = address
|
||||
if host.startswith("["):
|
||||
host = host.strip("[]")
|
||||
err = None
|
||||
|
||||
# Using the value from allowed_gai_family() in the context of getaddrinfo lets
|
||||
# us select whether to work with IPv4 DNS records, IPv6 records, or both.
|
||||
# The original create_connection function always returns all records.
|
||||
family = allowed_gai_family()
|
||||
|
||||
try:
|
||||
host.encode("idna")
|
||||
except UnicodeError:
|
||||
raise LocationParseError(f"'{host}', label empty or too long") from None
|
||||
|
||||
for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM):
|
||||
af, socktype, proto, canonname, sa = res
|
||||
sock = None
|
||||
try:
|
||||
sock = socket.socket(af, socktype, proto)
|
||||
|
||||
# If provided, set socket level options before connecting.
|
||||
_set_socket_options(sock, socket_options)
|
||||
|
||||
if timeout is not _DEFAULT_TIMEOUT:
|
||||
sock.settimeout(timeout)
|
||||
if source_address:
|
||||
sock.bind(source_address)
|
||||
sock.connect(sa)
|
||||
# Break explicitly a reference cycle
|
||||
err = None
|
||||
return sock
|
||||
|
||||
except OSError as _:
|
||||
err = _
|
||||
if sock is not None:
|
||||
sock.close()
|
||||
|
||||
if err is not None:
|
||||
try:
|
||||
raise err
|
||||
finally:
|
||||
# Break explicitly a reference cycle
|
||||
err = None
|
||||
else:
|
||||
raise OSError("getaddrinfo returns an empty list")
|
||||
|
||||
|
||||
def _set_socket_options(
|
||||
sock: socket.socket, options: _TYPE_SOCKET_OPTIONS | None
|
||||
) -> None:
|
||||
if options is None:
|
||||
return
|
||||
|
||||
for opt in options:
|
||||
sock.setsockopt(*opt)
|
||||
|
||||
|
||||
def allowed_gai_family() -> socket.AddressFamily:
|
||||
"""This function is designed to work in the context of
|
||||
getaddrinfo, where family=socket.AF_UNSPEC is the default and
|
||||
will perform a DNS search for both IPv6 and IPv4 records."""
|
||||
|
||||
family = socket.AF_INET
|
||||
if HAS_IPV6:
|
||||
family = socket.AF_UNSPEC
|
||||
return family
|
||||
|
||||
|
||||
def _has_ipv6(host: str) -> bool:
|
||||
"""Returns True if the system can bind an IPv6 address."""
|
||||
sock = None
|
||||
has_ipv6 = False
|
||||
|
||||
if socket.has_ipv6:
|
||||
# has_ipv6 returns true if cPython was compiled with IPv6 support.
|
||||
# It does not tell us if the system has IPv6 support enabled. To
|
||||
# determine that we must bind to an IPv6 address.
|
||||
# https://github.com/urllib3/urllib3/pull/611
|
||||
# https://bugs.python.org/issue658327
|
||||
try:
|
||||
sock = socket.socket(socket.AF_INET6)
|
||||
sock.bind((host, 0))
|
||||
has_ipv6 = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if sock:
|
||||
sock.close()
|
||||
return has_ipv6
|
||||
|
||||
|
||||
HAS_IPV6 = _has_ipv6("::1")
|
43
venv/lib/python3.11/site-packages/urllib3/util/proxy.py
Normal file
43
venv/lib/python3.11/site-packages/urllib3/util/proxy.py
Normal file
@ -0,0 +1,43 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .url import Url
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ..connection import ProxyConfig
|
||||
|
||||
|
||||
def connection_requires_http_tunnel(
|
||||
proxy_url: Url | None = None,
|
||||
proxy_config: ProxyConfig | None = None,
|
||||
destination_scheme: str | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Returns True if the connection requires an HTTP CONNECT through the proxy.
|
||||
|
||||
:param URL proxy_url:
|
||||
URL of the proxy.
|
||||
:param ProxyConfig proxy_config:
|
||||
Proxy configuration from poolmanager.py
|
||||
:param str destination_scheme:
|
||||
The scheme of the destination. (i.e https, http, etc)
|
||||
"""
|
||||
# If we're not using a proxy, no way to use a tunnel.
|
||||
if proxy_url is None:
|
||||
return False
|
||||
|
||||
# HTTP destinations never require tunneling, we always forward.
|
||||
if destination_scheme == "http":
|
||||
return False
|
||||
|
||||
# Support for forwarding with HTTPS proxies and HTTPS destinations.
|
||||
if (
|
||||
proxy_url.scheme == "https"
|
||||
and proxy_config
|
||||
and proxy_config.use_forwarding_for_https
|
||||
):
|
||||
return False
|
||||
|
||||
# Otherwise always use a tunnel.
|
||||
return True
|
258
venv/lib/python3.11/site-packages/urllib3/util/request.py
Normal file
258
venv/lib/python3.11/site-packages/urllib3/util/request.py
Normal file
@ -0,0 +1,258 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import typing
|
||||
from base64 import b64encode
|
||||
from enum import Enum
|
||||
|
||||
from ..exceptions import UnrewindableBodyError
|
||||
from .util import to_bytes
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from typing import Final
|
||||
|
||||
# Pass as a value within ``headers`` to skip
|
||||
# emitting some HTTP headers that are added automatically.
|
||||
# The only headers that are supported are ``Accept-Encoding``,
|
||||
# ``Host``, and ``User-Agent``.
|
||||
SKIP_HEADER = "@@@SKIP_HEADER@@@"
|
||||
SKIPPABLE_HEADERS = frozenset(["accept-encoding", "host", "user-agent"])
|
||||
|
||||
ACCEPT_ENCODING = "gzip,deflate"
|
||||
try:
|
||||
try:
|
||||
import brotlicffi as _unused_module_brotli # type: ignore[import-not-found] # noqa: F401
|
||||
except ImportError:
|
||||
import brotli as _unused_module_brotli # type: ignore[import-not-found] # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
ACCEPT_ENCODING += ",br"
|
||||
try:
|
||||
import zstandard as _unused_module_zstd # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
ACCEPT_ENCODING += ",zstd"
|
||||
|
||||
|
||||
class _TYPE_FAILEDTELL(Enum):
|
||||
token = 0
|
||||
|
||||
|
||||
_FAILEDTELL: Final[_TYPE_FAILEDTELL] = _TYPE_FAILEDTELL.token
|
||||
|
||||
_TYPE_BODY_POSITION = typing.Union[int, _TYPE_FAILEDTELL]
|
||||
|
||||
# When sending a request with these methods we aren't expecting
|
||||
# a body so don't need to set an explicit 'Content-Length: 0'
|
||||
# The reason we do this in the negative instead of tracking methods
|
||||
# which 'should' have a body is because unknown methods should be
|
||||
# treated as if they were 'POST' which *does* expect a body.
|
||||
_METHODS_NOT_EXPECTING_BODY = {"GET", "HEAD", "DELETE", "TRACE", "OPTIONS", "CONNECT"}
|
||||
|
||||
|
||||
def make_headers(
|
||||
keep_alive: bool | None = None,
|
||||
accept_encoding: bool | list[str] | str | None = None,
|
||||
user_agent: str | None = None,
|
||||
basic_auth: str | None = None,
|
||||
proxy_basic_auth: str | None = None,
|
||||
disable_cache: bool | None = None,
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Shortcuts for generating request headers.
|
||||
|
||||
:param keep_alive:
|
||||
If ``True``, adds 'connection: keep-alive' header.
|
||||
|
||||
:param accept_encoding:
|
||||
Can be a boolean, list, or string.
|
||||
``True`` translates to 'gzip,deflate'. If the dependencies for
|
||||
Brotli (either the ``brotli`` or ``brotlicffi`` package) and/or Zstandard
|
||||
(the ``zstandard`` package) algorithms are installed, then their encodings are
|
||||
included in the string ('br' and 'zstd', respectively).
|
||||
List will get joined by comma.
|
||||
String will be used as provided.
|
||||
|
||||
:param user_agent:
|
||||
String representing the user-agent you want, such as
|
||||
"python-urllib3/0.6"
|
||||
|
||||
:param basic_auth:
|
||||
Colon-separated username:password string for 'authorization: basic ...'
|
||||
auth header.
|
||||
|
||||
:param proxy_basic_auth:
|
||||
Colon-separated username:password string for 'proxy-authorization: basic ...'
|
||||
auth header.
|
||||
|
||||
:param disable_cache:
|
||||
If ``True``, adds 'cache-control: no-cache' header.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import urllib3
|
||||
|
||||
print(urllib3.util.make_headers(keep_alive=True, user_agent="Batman/1.0"))
|
||||
# {'connection': 'keep-alive', 'user-agent': 'Batman/1.0'}
|
||||
print(urllib3.util.make_headers(accept_encoding=True))
|
||||
# {'accept-encoding': 'gzip,deflate'}
|
||||
"""
|
||||
headers: dict[str, str] = {}
|
||||
if accept_encoding:
|
||||
if isinstance(accept_encoding, str):
|
||||
pass
|
||||
elif isinstance(accept_encoding, list):
|
||||
accept_encoding = ",".join(accept_encoding)
|
||||
else:
|
||||
accept_encoding = ACCEPT_ENCODING
|
||||
headers["accept-encoding"] = accept_encoding
|
||||
|
||||
if user_agent:
|
||||
headers["user-agent"] = user_agent
|
||||
|
||||
if keep_alive:
|
||||
headers["connection"] = "keep-alive"
|
||||
|
||||
if basic_auth:
|
||||
headers["authorization"] = (
|
||||
f"Basic {b64encode(basic_auth.encode('latin-1')).decode()}"
|
||||
)
|
||||
|
||||
if proxy_basic_auth:
|
||||
headers["proxy-authorization"] = (
|
||||
f"Basic {b64encode(proxy_basic_auth.encode('latin-1')).decode()}"
|
||||
)
|
||||
|
||||
if disable_cache:
|
||||
headers["cache-control"] = "no-cache"
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
def set_file_position(
|
||||
body: typing.Any, pos: _TYPE_BODY_POSITION | None
|
||||
) -> _TYPE_BODY_POSITION | None:
|
||||
"""
|
||||
If a position is provided, move file to that point.
|
||||
Otherwise, we'll attempt to record a position for future use.
|
||||
"""
|
||||
if pos is not None:
|
||||
rewind_body(body, pos)
|
||||
elif getattr(body, "tell", None) is not None:
|
||||
try:
|
||||
pos = body.tell()
|
||||
except OSError:
|
||||
# This differentiates from None, allowing us to catch
|
||||
# a failed `tell()` later when trying to rewind the body.
|
||||
pos = _FAILEDTELL
|
||||
|
||||
return pos
|
||||
|
||||
|
||||
def rewind_body(body: typing.IO[typing.AnyStr], body_pos: _TYPE_BODY_POSITION) -> None:
|
||||
"""
|
||||
Attempt to rewind body to a certain position.
|
||||
Primarily used for request redirects and retries.
|
||||
|
||||
:param body:
|
||||
File-like object that supports seek.
|
||||
|
||||
:param int pos:
|
||||
Position to seek to in file.
|
||||
"""
|
||||
body_seek = getattr(body, "seek", None)
|
||||
if body_seek is not None and isinstance(body_pos, int):
|
||||
try:
|
||||
body_seek(body_pos)
|
||||
except OSError as e:
|
||||
raise UnrewindableBodyError(
|
||||
"An error occurred when rewinding request body for redirect/retry."
|
||||
) from e
|
||||
elif body_pos is _FAILEDTELL:
|
||||
raise UnrewindableBodyError(
|
||||
"Unable to record file position for rewinding "
|
||||
"request body during a redirect/retry."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"body_pos must be of type integer, instead it was {type(body_pos)}."
|
||||
)
|
||||
|
||||
|
||||
class ChunksAndContentLength(typing.NamedTuple):
|
||||
chunks: typing.Iterable[bytes] | None
|
||||
content_length: int | None
|
||||
|
||||
|
||||
def body_to_chunks(
|
||||
body: typing.Any | None, method: str, blocksize: int
|
||||
) -> ChunksAndContentLength:
|
||||
"""Takes the HTTP request method, body, and blocksize and
|
||||
transforms them into an iterable of chunks to pass to
|
||||
socket.sendall() and an optional 'Content-Length' header.
|
||||
|
||||
A 'Content-Length' of 'None' indicates the length of the body
|
||||
can't be determined so should use 'Transfer-Encoding: chunked'
|
||||
for framing instead.
|
||||
"""
|
||||
|
||||
chunks: typing.Iterable[bytes] | None
|
||||
content_length: int | None
|
||||
|
||||
# No body, we need to make a recommendation on 'Content-Length'
|
||||
# based on whether that request method is expected to have
|
||||
# a body or not.
|
||||
if body is None:
|
||||
chunks = None
|
||||
if method.upper() not in _METHODS_NOT_EXPECTING_BODY:
|
||||
content_length = 0
|
||||
else:
|
||||
content_length = None
|
||||
|
||||
# Bytes or strings become bytes
|
||||
elif isinstance(body, (str, bytes)):
|
||||
chunks = (to_bytes(body),)
|
||||
content_length = len(chunks[0])
|
||||
|
||||
# File-like object, TODO: use seek() and tell() for length?
|
||||
elif hasattr(body, "read"):
|
||||
|
||||
def chunk_readable() -> typing.Iterable[bytes]:
|
||||
nonlocal body, blocksize
|
||||
encode = isinstance(body, io.TextIOBase)
|
||||
while True:
|
||||
datablock = body.read(blocksize)
|
||||
if not datablock:
|
||||
break
|
||||
if encode:
|
||||
datablock = datablock.encode("utf-8")
|
||||
yield datablock
|
||||
|
||||
chunks = chunk_readable()
|
||||
content_length = None
|
||||
|
||||
# Otherwise we need to start checking via duck-typing.
|
||||
else:
|
||||
try:
|
||||
# Check if the body implements the buffer API.
|
||||
mv = memoryview(body)
|
||||
except TypeError:
|
||||
try:
|
||||
# Check if the body is an iterable
|
||||
chunks = iter(body)
|
||||
content_length = None
|
||||
except TypeError:
|
||||
raise TypeError(
|
||||
f"'body' must be a bytes-like object, file-like "
|
||||
f"object, or iterable. Instead was {body!r}"
|
||||
) from None
|
||||
else:
|
||||
# Since it implements the buffer API can be passed directly to socket.sendall()
|
||||
chunks = (body,)
|
||||
content_length = mv.nbytes
|
||||
|
||||
return ChunksAndContentLength(chunks=chunks, content_length=content_length)
|
101
venv/lib/python3.11/site-packages/urllib3/util/response.py
Normal file
101
venv/lib/python3.11/site-packages/urllib3/util/response.py
Normal file
@ -0,0 +1,101 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import http.client as httplib
|
||||
from email.errors import MultipartInvariantViolationDefect, StartBoundaryNotFoundDefect
|
||||
|
||||
from ..exceptions import HeaderParsingError
|
||||
|
||||
|
||||
def is_fp_closed(obj: object) -> bool:
|
||||
"""
|
||||
Checks whether a given file-like object is closed.
|
||||
|
||||
:param obj:
|
||||
The file-like object to check.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Check `isclosed()` first, in case Python3 doesn't set `closed`.
|
||||
# GH Issue #928
|
||||
return obj.isclosed() # type: ignore[no-any-return, attr-defined]
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
try:
|
||||
# Check via the official file-like-object way.
|
||||
return obj.closed # type: ignore[no-any-return, attr-defined]
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
try:
|
||||
# Check if the object is a container for another file-like object that
|
||||
# gets released on exhaustion (e.g. HTTPResponse).
|
||||
return obj.fp is None # type: ignore[attr-defined]
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
raise ValueError("Unable to determine whether fp is closed.")
|
||||
|
||||
|
||||
def assert_header_parsing(headers: httplib.HTTPMessage) -> None:
|
||||
"""
|
||||
Asserts whether all headers have been successfully parsed.
|
||||
Extracts encountered errors from the result of parsing headers.
|
||||
|
||||
Only works on Python 3.
|
||||
|
||||
:param http.client.HTTPMessage headers: Headers to verify.
|
||||
|
||||
:raises urllib3.exceptions.HeaderParsingError:
|
||||
If parsing errors are found.
|
||||
"""
|
||||
|
||||
# This will fail silently if we pass in the wrong kind of parameter.
|
||||
# To make debugging easier add an explicit check.
|
||||
if not isinstance(headers, httplib.HTTPMessage):
|
||||
raise TypeError(f"expected httplib.Message, got {type(headers)}.")
|
||||
|
||||
unparsed_data = None
|
||||
|
||||
# get_payload is actually email.message.Message.get_payload;
|
||||
# we're only interested in the result if it's not a multipart message
|
||||
if not headers.is_multipart():
|
||||
payload = headers.get_payload()
|
||||
|
||||
if isinstance(payload, (bytes, str)):
|
||||
unparsed_data = payload
|
||||
|
||||
# httplib is assuming a response body is available
|
||||
# when parsing headers even when httplib only sends
|
||||
# header data to parse_headers() This results in
|
||||
# defects on multipart responses in particular.
|
||||
# See: https://github.com/urllib3/urllib3/issues/800
|
||||
|
||||
# So we ignore the following defects:
|
||||
# - StartBoundaryNotFoundDefect:
|
||||
# The claimed start boundary was never found.
|
||||
# - MultipartInvariantViolationDefect:
|
||||
# A message claimed to be a multipart but no subparts were found.
|
||||
defects = [
|
||||
defect
|
||||
for defect in headers.defects
|
||||
if not isinstance(
|
||||
defect, (StartBoundaryNotFoundDefect, MultipartInvariantViolationDefect)
|
||||
)
|
||||
]
|
||||
|
||||
if defects or unparsed_data:
|
||||
raise HeaderParsingError(defects=defects, unparsed_data=unparsed_data)
|
||||
|
||||
|
||||
def is_response_to_head(response: httplib.HTTPResponse) -> bool:
|
||||
"""
|
||||
Checks whether the request of a response has been a HEAD-request.
|
||||
|
||||
:param http.client.HTTPResponse response:
|
||||
Response to check if the originating request
|
||||
used 'HEAD' as a method.
|
||||
"""
|
||||
# FIXME: Can we do this somehow without accessing private httplib _method?
|
||||
method_str = response._method # type: str # type: ignore[attr-defined]
|
||||
return method_str.upper() == "HEAD"
|
533
venv/lib/python3.11/site-packages/urllib3/util/retry.py
Normal file
533
venv/lib/python3.11/site-packages/urllib3/util/retry.py
Normal file
@ -0,0 +1,533 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import email
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import typing
|
||||
from itertools import takewhile
|
||||
from types import TracebackType
|
||||
|
||||
from ..exceptions import (
|
||||
ConnectTimeoutError,
|
||||
InvalidHeader,
|
||||
MaxRetryError,
|
||||
ProtocolError,
|
||||
ProxyError,
|
||||
ReadTimeoutError,
|
||||
ResponseError,
|
||||
)
|
||||
from .util import reraise
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..connectionpool import ConnectionPool
|
||||
from ..response import BaseHTTPResponse
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Data structure for representing the metadata of requests that result in a retry.
|
||||
class RequestHistory(typing.NamedTuple):
|
||||
method: str | None
|
||||
url: str | None
|
||||
error: Exception | None
|
||||
status: int | None
|
||||
redirect_location: str | None
|
||||
|
||||
|
||||
class Retry:
|
||||
"""Retry configuration.
|
||||
|
||||
Each retry attempt will create a new Retry object with updated values, so
|
||||
they can be safely reused.
|
||||
|
||||
Retries can be defined as a default for a pool:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
retries = Retry(connect=5, read=2, redirect=5)
|
||||
http = PoolManager(retries=retries)
|
||||
response = http.request("GET", "https://example.com/")
|
||||
|
||||
Or per-request (which overrides the default for the pool):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
response = http.request("GET", "https://example.com/", retries=Retry(10))
|
||||
|
||||
Retries can be disabled by passing ``False``:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
response = http.request("GET", "https://example.com/", retries=False)
|
||||
|
||||
Errors will be wrapped in :class:`~urllib3.exceptions.MaxRetryError` unless
|
||||
retries are disabled, in which case the causing exception will be raised.
|
||||
|
||||
:param int total:
|
||||
Total number of retries to allow. Takes precedence over other counts.
|
||||
|
||||
Set to ``None`` to remove this constraint and fall back on other
|
||||
counts.
|
||||
|
||||
Set to ``0`` to fail on the first retry.
|
||||
|
||||
Set to ``False`` to disable and imply ``raise_on_redirect=False``.
|
||||
|
||||
:param int connect:
|
||||
How many connection-related errors to retry on.
|
||||
|
||||
These are errors raised before the request is sent to the remote server,
|
||||
which we assume has not triggered the server to process the request.
|
||||
|
||||
Set to ``0`` to fail on the first retry of this type.
|
||||
|
||||
:param int read:
|
||||
How many times to retry on read errors.
|
||||
|
||||
These errors are raised after the request was sent to the server, so the
|
||||
request may have side-effects.
|
||||
|
||||
Set to ``0`` to fail on the first retry of this type.
|
||||
|
||||
:param int redirect:
|
||||
How many redirects to perform. Limit this to avoid infinite redirect
|
||||
loops.
|
||||
|
||||
A redirect is a HTTP response with a status code 301, 302, 303, 307 or
|
||||
308.
|
||||
|
||||
Set to ``0`` to fail on the first retry of this type.
|
||||
|
||||
Set to ``False`` to disable and imply ``raise_on_redirect=False``.
|
||||
|
||||
:param int status:
|
||||
How many times to retry on bad status codes.
|
||||
|
||||
These are retries made on responses, where status code matches
|
||||
``status_forcelist``.
|
||||
|
||||
Set to ``0`` to fail on the first retry of this type.
|
||||
|
||||
:param int other:
|
||||
How many times to retry on other errors.
|
||||
|
||||
Other errors are errors that are not connect, read, redirect or status errors.
|
||||
These errors might be raised after the request was sent to the server, so the
|
||||
request might have side-effects.
|
||||
|
||||
Set to ``0`` to fail on the first retry of this type.
|
||||
|
||||
If ``total`` is not set, it's a good idea to set this to 0 to account
|
||||
for unexpected edge cases and avoid infinite retry loops.
|
||||
|
||||
:param Collection allowed_methods:
|
||||
Set of uppercased HTTP method verbs that we should retry on.
|
||||
|
||||
By default, we only retry on methods which are considered to be
|
||||
idempotent (multiple requests with the same parameters end with the
|
||||
same state). See :attr:`Retry.DEFAULT_ALLOWED_METHODS`.
|
||||
|
||||
Set to a ``None`` value to retry on any verb.
|
||||
|
||||
:param Collection status_forcelist:
|
||||
A set of integer HTTP status codes that we should force a retry on.
|
||||
A retry is initiated if the request method is in ``allowed_methods``
|
||||
and the response status code is in ``status_forcelist``.
|
||||
|
||||
By default, this is disabled with ``None``.
|
||||
|
||||
:param float backoff_factor:
|
||||
A backoff factor to apply between attempts after the second try
|
||||
(most errors are resolved immediately by a second try without a
|
||||
delay). urllib3 will sleep for::
|
||||
|
||||
{backoff factor} * (2 ** ({number of previous retries}))
|
||||
|
||||
seconds. If `backoff_jitter` is non-zero, this sleep is extended by::
|
||||
|
||||
random.uniform(0, {backoff jitter})
|
||||
|
||||
seconds. For example, if the backoff_factor is 0.1, then :func:`Retry.sleep` will
|
||||
sleep for [0.0s, 0.2s, 0.4s, 0.8s, ...] between retries. No backoff will ever
|
||||
be longer than `backoff_max`.
|
||||
|
||||
By default, backoff is disabled (factor set to 0).
|
||||
|
||||
:param bool raise_on_redirect: Whether, if the number of redirects is
|
||||
exhausted, to raise a MaxRetryError, or to return a response with a
|
||||
response code in the 3xx range.
|
||||
|
||||
:param bool raise_on_status: Similar meaning to ``raise_on_redirect``:
|
||||
whether we should raise an exception, or return a response,
|
||||
if status falls in ``status_forcelist`` range and retries have
|
||||
been exhausted.
|
||||
|
||||
:param tuple history: The history of the request encountered during
|
||||
each call to :meth:`~Retry.increment`. The list is in the order
|
||||
the requests occurred. Each list item is of class :class:`RequestHistory`.
|
||||
|
||||
:param bool respect_retry_after_header:
|
||||
Whether to respect Retry-After header on status codes defined as
|
||||
:attr:`Retry.RETRY_AFTER_STATUS_CODES` or not.
|
||||
|
||||
:param Collection remove_headers_on_redirect:
|
||||
Sequence of headers to remove from the request when a response
|
||||
indicating a redirect is returned before firing off the redirected
|
||||
request.
|
||||
"""
|
||||
|
||||
#: Default methods to be used for ``allowed_methods``
|
||||
DEFAULT_ALLOWED_METHODS = frozenset(
|
||||
["HEAD", "GET", "PUT", "DELETE", "OPTIONS", "TRACE"]
|
||||
)
|
||||
|
||||
#: Default status codes to be used for ``status_forcelist``
|
||||
RETRY_AFTER_STATUS_CODES = frozenset([413, 429, 503])
|
||||
|
||||
#: Default headers to be used for ``remove_headers_on_redirect``
|
||||
DEFAULT_REMOVE_HEADERS_ON_REDIRECT = frozenset(
|
||||
["Cookie", "Authorization", "Proxy-Authorization"]
|
||||
)
|
||||
|
||||
#: Default maximum backoff time.
|
||||
DEFAULT_BACKOFF_MAX = 120
|
||||
|
||||
# Backward compatibility; assigned outside of the class.
|
||||
DEFAULT: typing.ClassVar[Retry]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
total: bool | int | None = 10,
|
||||
connect: int | None = None,
|
||||
read: int | None = None,
|
||||
redirect: bool | int | None = None,
|
||||
status: int | None = None,
|
||||
other: int | None = None,
|
||||
allowed_methods: typing.Collection[str] | None = DEFAULT_ALLOWED_METHODS,
|
||||
status_forcelist: typing.Collection[int] | None = None,
|
||||
backoff_factor: float = 0,
|
||||
backoff_max: float = DEFAULT_BACKOFF_MAX,
|
||||
raise_on_redirect: bool = True,
|
||||
raise_on_status: bool = True,
|
||||
history: tuple[RequestHistory, ...] | None = None,
|
||||
respect_retry_after_header: bool = True,
|
||||
remove_headers_on_redirect: typing.Collection[
|
||||
str
|
||||
] = DEFAULT_REMOVE_HEADERS_ON_REDIRECT,
|
||||
backoff_jitter: float = 0.0,
|
||||
) -> None:
|
||||
self.total = total
|
||||
self.connect = connect
|
||||
self.read = read
|
||||
self.status = status
|
||||
self.other = other
|
||||
|
||||
if redirect is False or total is False:
|
||||
redirect = 0
|
||||
raise_on_redirect = False
|
||||
|
||||
self.redirect = redirect
|
||||
self.status_forcelist = status_forcelist or set()
|
||||
self.allowed_methods = allowed_methods
|
||||
self.backoff_factor = backoff_factor
|
||||
self.backoff_max = backoff_max
|
||||
self.raise_on_redirect = raise_on_redirect
|
||||
self.raise_on_status = raise_on_status
|
||||
self.history = history or ()
|
||||
self.respect_retry_after_header = respect_retry_after_header
|
||||
self.remove_headers_on_redirect = frozenset(
|
||||
h.lower() for h in remove_headers_on_redirect
|
||||
)
|
||||
self.backoff_jitter = backoff_jitter
|
||||
|
||||
def new(self, **kw: typing.Any) -> Self:
|
||||
params = dict(
|
||||
total=self.total,
|
||||
connect=self.connect,
|
||||
read=self.read,
|
||||
redirect=self.redirect,
|
||||
status=self.status,
|
||||
other=self.other,
|
||||
allowed_methods=self.allowed_methods,
|
||||
status_forcelist=self.status_forcelist,
|
||||
backoff_factor=self.backoff_factor,
|
||||
backoff_max=self.backoff_max,
|
||||
raise_on_redirect=self.raise_on_redirect,
|
||||
raise_on_status=self.raise_on_status,
|
||||
history=self.history,
|
||||
remove_headers_on_redirect=self.remove_headers_on_redirect,
|
||||
respect_retry_after_header=self.respect_retry_after_header,
|
||||
backoff_jitter=self.backoff_jitter,
|
||||
)
|
||||
|
||||
params.update(kw)
|
||||
return type(self)(**params) # type: ignore[arg-type]
|
||||
|
||||
@classmethod
|
||||
def from_int(
|
||||
cls,
|
||||
retries: Retry | bool | int | None,
|
||||
redirect: bool | int | None = True,
|
||||
default: Retry | bool | int | None = None,
|
||||
) -> Retry:
|
||||
"""Backwards-compatibility for the old retries format."""
|
||||
if retries is None:
|
||||
retries = default if default is not None else cls.DEFAULT
|
||||
|
||||
if isinstance(retries, Retry):
|
||||
return retries
|
||||
|
||||
redirect = bool(redirect) and None
|
||||
new_retries = cls(retries, redirect=redirect)
|
||||
log.debug("Converted retries value: %r -> %r", retries, new_retries)
|
||||
return new_retries
|
||||
|
||||
def get_backoff_time(self) -> float:
|
||||
"""Formula for computing the current backoff
|
||||
|
||||
:rtype: float
|
||||
"""
|
||||
# We want to consider only the last consecutive errors sequence (Ignore redirects).
|
||||
consecutive_errors_len = len(
|
||||
list(
|
||||
takewhile(lambda x: x.redirect_location is None, reversed(self.history))
|
||||
)
|
||||
)
|
||||
if consecutive_errors_len <= 1:
|
||||
return 0
|
||||
|
||||
backoff_value = self.backoff_factor * (2 ** (consecutive_errors_len - 1))
|
||||
if self.backoff_jitter != 0.0:
|
||||
backoff_value += random.random() * self.backoff_jitter
|
||||
return float(max(0, min(self.backoff_max, backoff_value)))
|
||||
|
||||
def parse_retry_after(self, retry_after: str) -> float:
|
||||
seconds: float
|
||||
# Whitespace: https://tools.ietf.org/html/rfc7230#section-3.2.4
|
||||
if re.match(r"^\s*[0-9]+\s*$", retry_after):
|
||||
seconds = int(retry_after)
|
||||
else:
|
||||
retry_date_tuple = email.utils.parsedate_tz(retry_after)
|
||||
if retry_date_tuple is None:
|
||||
raise InvalidHeader(f"Invalid Retry-After header: {retry_after}")
|
||||
|
||||
retry_date = email.utils.mktime_tz(retry_date_tuple)
|
||||
seconds = retry_date - time.time()
|
||||
|
||||
seconds = max(seconds, 0)
|
||||
|
||||
return seconds
|
||||
|
||||
def get_retry_after(self, response: BaseHTTPResponse) -> float | None:
|
||||
"""Get the value of Retry-After in seconds."""
|
||||
|
||||
retry_after = response.headers.get("Retry-After")
|
||||
|
||||
if retry_after is None:
|
||||
return None
|
||||
|
||||
return self.parse_retry_after(retry_after)
|
||||
|
||||
def sleep_for_retry(self, response: BaseHTTPResponse) -> bool:
|
||||
retry_after = self.get_retry_after(response)
|
||||
if retry_after:
|
||||
time.sleep(retry_after)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _sleep_backoff(self) -> None:
|
||||
backoff = self.get_backoff_time()
|
||||
if backoff <= 0:
|
||||
return
|
||||
time.sleep(backoff)
|
||||
|
||||
def sleep(self, response: BaseHTTPResponse | None = None) -> None:
|
||||
"""Sleep between retry attempts.
|
||||
|
||||
This method will respect a server's ``Retry-After`` response header
|
||||
and sleep the duration of the time requested. If that is not present, it
|
||||
will use an exponential backoff. By default, the backoff factor is 0 and
|
||||
this method will return immediately.
|
||||
"""
|
||||
|
||||
if self.respect_retry_after_header and response:
|
||||
slept = self.sleep_for_retry(response)
|
||||
if slept:
|
||||
return
|
||||
|
||||
self._sleep_backoff()
|
||||
|
||||
def _is_connection_error(self, err: Exception) -> bool:
|
||||
"""Errors when we're fairly sure that the server did not receive the
|
||||
request, so it should be safe to retry.
|
||||
"""
|
||||
if isinstance(err, ProxyError):
|
||||
err = err.original_error
|
||||
return isinstance(err, ConnectTimeoutError)
|
||||
|
||||
def _is_read_error(self, err: Exception) -> bool:
|
||||
"""Errors that occur after the request has been started, so we should
|
||||
assume that the server began processing it.
|
||||
"""
|
||||
return isinstance(err, (ReadTimeoutError, ProtocolError))
|
||||
|
||||
def _is_method_retryable(self, method: str) -> bool:
|
||||
"""Checks if a given HTTP method should be retried upon, depending if
|
||||
it is included in the allowed_methods
|
||||
"""
|
||||
if self.allowed_methods and method.upper() not in self.allowed_methods:
|
||||
return False
|
||||
return True
|
||||
|
||||
def is_retry(
|
||||
self, method: str, status_code: int, has_retry_after: bool = False
|
||||
) -> bool:
|
||||
"""Is this method/status code retryable? (Based on allowlists and control
|
||||
variables such as the number of total retries to allow, whether to
|
||||
respect the Retry-After header, whether this header is present, and
|
||||
whether the returned status code is on the list of status codes to
|
||||
be retried upon on the presence of the aforementioned header)
|
||||
"""
|
||||
if not self._is_method_retryable(method):
|
||||
return False
|
||||
|
||||
if self.status_forcelist and status_code in self.status_forcelist:
|
||||
return True
|
||||
|
||||
return bool(
|
||||
self.total
|
||||
and self.respect_retry_after_header
|
||||
and has_retry_after
|
||||
and (status_code in self.RETRY_AFTER_STATUS_CODES)
|
||||
)
|
||||
|
||||
def is_exhausted(self) -> bool:
|
||||
"""Are we out of retries?"""
|
||||
retry_counts = [
|
||||
x
|
||||
for x in (
|
||||
self.total,
|
||||
self.connect,
|
||||
self.read,
|
||||
self.redirect,
|
||||
self.status,
|
||||
self.other,
|
||||
)
|
||||
if x
|
||||
]
|
||||
if not retry_counts:
|
||||
return False
|
||||
|
||||
return min(retry_counts) < 0
|
||||
|
||||
def increment(
|
||||
self,
|
||||
method: str | None = None,
|
||||
url: str | None = None,
|
||||
response: BaseHTTPResponse | None = None,
|
||||
error: Exception | None = None,
|
||||
_pool: ConnectionPool | None = None,
|
||||
_stacktrace: TracebackType | None = None,
|
||||
) -> Self:
|
||||
"""Return a new Retry object with incremented retry counters.
|
||||
|
||||
:param response: A response object, or None, if the server did not
|
||||
return a response.
|
||||
:type response: :class:`~urllib3.response.BaseHTTPResponse`
|
||||
:param Exception error: An error encountered during the request, or
|
||||
None if the response was received successfully.
|
||||
|
||||
:return: A new ``Retry`` object.
|
||||
"""
|
||||
if self.total is False and error:
|
||||
# Disabled, indicate to re-raise the error.
|
||||
raise reraise(type(error), error, _stacktrace)
|
||||
|
||||
total = self.total
|
||||
if total is not None:
|
||||
total -= 1
|
||||
|
||||
connect = self.connect
|
||||
read = self.read
|
||||
redirect = self.redirect
|
||||
status_count = self.status
|
||||
other = self.other
|
||||
cause = "unknown"
|
||||
status = None
|
||||
redirect_location = None
|
||||
|
||||
if error and self._is_connection_error(error):
|
||||
# Connect retry?
|
||||
if connect is False:
|
||||
raise reraise(type(error), error, _stacktrace)
|
||||
elif connect is not None:
|
||||
connect -= 1
|
||||
|
||||
elif error and self._is_read_error(error):
|
||||
# Read retry?
|
||||
if read is False or method is None or not self._is_method_retryable(method):
|
||||
raise reraise(type(error), error, _stacktrace)
|
||||
elif read is not None:
|
||||
read -= 1
|
||||
|
||||
elif error:
|
||||
# Other retry?
|
||||
if other is not None:
|
||||
other -= 1
|
||||
|
||||
elif response and response.get_redirect_location():
|
||||
# Redirect retry?
|
||||
if redirect is not None:
|
||||
redirect -= 1
|
||||
cause = "too many redirects"
|
||||
response_redirect_location = response.get_redirect_location()
|
||||
if response_redirect_location:
|
||||
redirect_location = response_redirect_location
|
||||
status = response.status
|
||||
|
||||
else:
|
||||
# Incrementing because of a server error like a 500 in
|
||||
# status_forcelist and the given method is in the allowed_methods
|
||||
cause = ResponseError.GENERIC_ERROR
|
||||
if response and response.status:
|
||||
if status_count is not None:
|
||||
status_count -= 1
|
||||
cause = ResponseError.SPECIFIC_ERROR.format(status_code=response.status)
|
||||
status = response.status
|
||||
|
||||
history = self.history + (
|
||||
RequestHistory(method, url, error, status, redirect_location),
|
||||
)
|
||||
|
||||
new_retry = self.new(
|
||||
total=total,
|
||||
connect=connect,
|
||||
read=read,
|
||||
redirect=redirect,
|
||||
status=status_count,
|
||||
other=other,
|
||||
history=history,
|
||||
)
|
||||
|
||||
if new_retry.is_exhausted():
|
||||
reason = error or ResponseError(cause)
|
||||
raise MaxRetryError(_pool, url, reason) from reason # type: ignore[arg-type]
|
||||
|
||||
log.debug("Incremented Retry for (url='%s'): %r", url, new_retry)
|
||||
|
||||
return new_retry
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"{type(self).__name__}(total={self.total}, connect={self.connect}, "
|
||||
f"read={self.read}, redirect={self.redirect}, status={self.status})"
|
||||
)
|
||||
|
||||
|
||||
# For backwards compatibility (equivalent to pre-v1.9):
|
||||
Retry.DEFAULT = Retry(3)
|
524
venv/lib/python3.11/site-packages/urllib3/util/ssl_.py
Normal file
524
venv/lib/python3.11/site-packages/urllib3/util/ssl_.py
Normal file
@ -0,0 +1,524 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
from binascii import unhexlify
|
||||
|
||||
from ..exceptions import ProxySchemeUnsupported, SSLError
|
||||
from .url import _BRACELESS_IPV6_ADDRZ_RE, _IPV4_RE
|
||||
|
||||
SSLContext = None
|
||||
SSLTransport = None
|
||||
HAS_NEVER_CHECK_COMMON_NAME = False
|
||||
IS_PYOPENSSL = False
|
||||
ALPN_PROTOCOLS = ["http/1.1"]
|
||||
|
||||
_TYPE_VERSION_INFO = tuple[int, int, int, str, int]
|
||||
|
||||
# Maps the length of a digest to a possible hash function producing this digest
|
||||
HASHFUNC_MAP = {
|
||||
length: getattr(hashlib, algorithm, None)
|
||||
for length, algorithm in ((32, "md5"), (40, "sha1"), (64, "sha256"))
|
||||
}
|
||||
|
||||
|
||||
def _is_bpo_43522_fixed(
|
||||
implementation_name: str,
|
||||
version_info: _TYPE_VERSION_INFO,
|
||||
pypy_version_info: _TYPE_VERSION_INFO | None,
|
||||
) -> bool:
|
||||
"""Return True for CPython 3.9.3+ or 3.10+ and PyPy 7.3.8+ where
|
||||
setting SSLContext.hostname_checks_common_name to False works.
|
||||
|
||||
Outside of CPython and PyPy we don't know which implementations work
|
||||
or not so we conservatively use our hostname matching as we know that works
|
||||
on all implementations.
|
||||
|
||||
https://github.com/urllib3/urllib3/issues/2192#issuecomment-821832963
|
||||
https://foss.heptapod.net/pypy/pypy/-/issues/3539
|
||||
"""
|
||||
if implementation_name == "pypy":
|
||||
# https://foss.heptapod.net/pypy/pypy/-/issues/3129
|
||||
return pypy_version_info >= (7, 3, 8) # type: ignore[operator]
|
||||
elif implementation_name == "cpython":
|
||||
major_minor = version_info[:2]
|
||||
micro = version_info[2]
|
||||
return (major_minor == (3, 9) and micro >= 3) or major_minor >= (3, 10)
|
||||
else: # Defensive:
|
||||
return False
|
||||
|
||||
|
||||
def _is_has_never_check_common_name_reliable(
|
||||
openssl_version: str,
|
||||
openssl_version_number: int,
|
||||
implementation_name: str,
|
||||
version_info: _TYPE_VERSION_INFO,
|
||||
pypy_version_info: _TYPE_VERSION_INFO | None,
|
||||
) -> bool:
|
||||
# As of May 2023, all released versions of LibreSSL fail to reject certificates with
|
||||
# only common names, see https://github.com/urllib3/urllib3/pull/3024
|
||||
is_openssl = openssl_version.startswith("OpenSSL ")
|
||||
# Before fixing OpenSSL issue #14579, the SSL_new() API was not copying hostflags
|
||||
# like X509_CHECK_FLAG_NEVER_CHECK_SUBJECT, which tripped up CPython.
|
||||
# https://github.com/openssl/openssl/issues/14579
|
||||
# This was released in OpenSSL 1.1.1l+ (>=0x101010cf)
|
||||
is_openssl_issue_14579_fixed = openssl_version_number >= 0x101010CF
|
||||
|
||||
return is_openssl and (
|
||||
is_openssl_issue_14579_fixed
|
||||
or _is_bpo_43522_fixed(implementation_name, version_info, pypy_version_info)
|
||||
)
|
||||
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ssl import VerifyMode
|
||||
from typing import TypedDict
|
||||
|
||||
from .ssltransport import SSLTransport as SSLTransportType
|
||||
|
||||
class _TYPE_PEER_CERT_RET_DICT(TypedDict, total=False):
|
||||
subjectAltName: tuple[tuple[str, str], ...]
|
||||
subject: tuple[tuple[tuple[str, str], ...], ...]
|
||||
serialNumber: str
|
||||
|
||||
|
||||
# Mapping from 'ssl.PROTOCOL_TLSX' to 'TLSVersion.X'
|
||||
_SSL_VERSION_TO_TLS_VERSION: dict[int, int] = {}
|
||||
|
||||
try: # Do we have ssl at all?
|
||||
import ssl
|
||||
from ssl import ( # type: ignore[assignment]
|
||||
CERT_REQUIRED,
|
||||
HAS_NEVER_CHECK_COMMON_NAME,
|
||||
OP_NO_COMPRESSION,
|
||||
OP_NO_TICKET,
|
||||
OPENSSL_VERSION,
|
||||
OPENSSL_VERSION_NUMBER,
|
||||
PROTOCOL_TLS,
|
||||
PROTOCOL_TLS_CLIENT,
|
||||
VERIFY_X509_STRICT,
|
||||
OP_NO_SSLv2,
|
||||
OP_NO_SSLv3,
|
||||
SSLContext,
|
||||
TLSVersion,
|
||||
)
|
||||
|
||||
PROTOCOL_SSLv23 = PROTOCOL_TLS
|
||||
|
||||
# Needed for Python 3.9 which does not define this
|
||||
VERIFY_X509_PARTIAL_CHAIN = getattr(ssl, "VERIFY_X509_PARTIAL_CHAIN", 0x80000)
|
||||
|
||||
# Setting SSLContext.hostname_checks_common_name = False didn't work before CPython
|
||||
# 3.9.3, and 3.10 (but OK on PyPy) or OpenSSL 1.1.1l+
|
||||
if HAS_NEVER_CHECK_COMMON_NAME and not _is_has_never_check_common_name_reliable(
|
||||
OPENSSL_VERSION,
|
||||
OPENSSL_VERSION_NUMBER,
|
||||
sys.implementation.name,
|
||||
sys.version_info,
|
||||
sys.pypy_version_info if sys.implementation.name == "pypy" else None, # type: ignore[attr-defined]
|
||||
): # Defensive: for Python < 3.9.3
|
||||
HAS_NEVER_CHECK_COMMON_NAME = False
|
||||
|
||||
# Need to be careful here in case old TLS versions get
|
||||
# removed in future 'ssl' module implementations.
|
||||
for attr in ("TLSv1", "TLSv1_1", "TLSv1_2"):
|
||||
try:
|
||||
_SSL_VERSION_TO_TLS_VERSION[getattr(ssl, f"PROTOCOL_{attr}")] = getattr(
|
||||
TLSVersion, attr
|
||||
)
|
||||
except AttributeError: # Defensive:
|
||||
continue
|
||||
|
||||
from .ssltransport import SSLTransport # type: ignore[assignment]
|
||||
except ImportError:
|
||||
OP_NO_COMPRESSION = 0x20000 # type: ignore[assignment]
|
||||
OP_NO_TICKET = 0x4000 # type: ignore[assignment]
|
||||
OP_NO_SSLv2 = 0x1000000 # type: ignore[assignment]
|
||||
OP_NO_SSLv3 = 0x2000000 # type: ignore[assignment]
|
||||
PROTOCOL_SSLv23 = PROTOCOL_TLS = 2 # type: ignore[assignment]
|
||||
PROTOCOL_TLS_CLIENT = 16 # type: ignore[assignment]
|
||||
VERIFY_X509_PARTIAL_CHAIN = 0x80000
|
||||
VERIFY_X509_STRICT = 0x20 # type: ignore[assignment]
|
||||
|
||||
|
||||
_TYPE_PEER_CERT_RET = typing.Union["_TYPE_PEER_CERT_RET_DICT", bytes, None]
|
||||
|
||||
|
||||
def assert_fingerprint(cert: bytes | None, fingerprint: str) -> None:
|
||||
"""
|
||||
Checks if given fingerprint matches the supplied certificate.
|
||||
|
||||
:param cert:
|
||||
Certificate as bytes object.
|
||||
:param fingerprint:
|
||||
Fingerprint as string of hexdigits, can be interspersed by colons.
|
||||
"""
|
||||
|
||||
if cert is None:
|
||||
raise SSLError("No certificate for the peer.")
|
||||
|
||||
fingerprint = fingerprint.replace(":", "").lower()
|
||||
digest_length = len(fingerprint)
|
||||
if digest_length not in HASHFUNC_MAP:
|
||||
raise SSLError(f"Fingerprint of invalid length: {fingerprint}")
|
||||
hashfunc = HASHFUNC_MAP.get(digest_length)
|
||||
if hashfunc is None:
|
||||
raise SSLError(
|
||||
f"Hash function implementation unavailable for fingerprint length: {digest_length}"
|
||||
)
|
||||
|
||||
# We need encode() here for py32; works on py2 and p33.
|
||||
fingerprint_bytes = unhexlify(fingerprint.encode())
|
||||
|
||||
cert_digest = hashfunc(cert).digest()
|
||||
|
||||
if not hmac.compare_digest(cert_digest, fingerprint_bytes):
|
||||
raise SSLError(
|
||||
f'Fingerprints did not match. Expected "{fingerprint}", got "{cert_digest.hex()}"'
|
||||
)
|
||||
|
||||
|
||||
def resolve_cert_reqs(candidate: None | int | str) -> VerifyMode:
|
||||
"""
|
||||
Resolves the argument to a numeric constant, which can be passed to
|
||||
the wrap_socket function/method from the ssl module.
|
||||
Defaults to :data:`ssl.CERT_REQUIRED`.
|
||||
If given a string it is assumed to be the name of the constant in the
|
||||
:mod:`ssl` module or its abbreviation.
|
||||
(So you can specify `REQUIRED` instead of `CERT_REQUIRED`.
|
||||
If it's neither `None` nor a string we assume it is already the numeric
|
||||
constant which can directly be passed to wrap_socket.
|
||||
"""
|
||||
if candidate is None:
|
||||
return CERT_REQUIRED
|
||||
|
||||
if isinstance(candidate, str):
|
||||
res = getattr(ssl, candidate, None)
|
||||
if res is None:
|
||||
res = getattr(ssl, "CERT_" + candidate)
|
||||
return res # type: ignore[no-any-return]
|
||||
|
||||
return candidate # type: ignore[return-value]
|
||||
|
||||
|
||||
def resolve_ssl_version(candidate: None | int | str) -> int:
|
||||
"""
|
||||
like resolve_cert_reqs
|
||||
"""
|
||||
if candidate is None:
|
||||
return PROTOCOL_TLS
|
||||
|
||||
if isinstance(candidate, str):
|
||||
res = getattr(ssl, candidate, None)
|
||||
if res is None:
|
||||
res = getattr(ssl, "PROTOCOL_" + candidate)
|
||||
return typing.cast(int, res)
|
||||
|
||||
return candidate
|
||||
|
||||
|
||||
def create_urllib3_context(
|
||||
ssl_version: int | None = None,
|
||||
cert_reqs: int | None = None,
|
||||
options: int | None = None,
|
||||
ciphers: str | None = None,
|
||||
ssl_minimum_version: int | None = None,
|
||||
ssl_maximum_version: int | None = None,
|
||||
verify_flags: int | None = None,
|
||||
) -> ssl.SSLContext:
|
||||
"""Creates and configures an :class:`ssl.SSLContext` instance for use with urllib3.
|
||||
|
||||
:param ssl_version:
|
||||
The desired protocol version to use. This will default to
|
||||
PROTOCOL_SSLv23 which will negotiate the highest protocol that both
|
||||
the server and your installation of OpenSSL support.
|
||||
|
||||
This parameter is deprecated instead use 'ssl_minimum_version'.
|
||||
:param ssl_minimum_version:
|
||||
The minimum version of TLS to be used. Use the 'ssl.TLSVersion' enum for specifying the value.
|
||||
:param ssl_maximum_version:
|
||||
The maximum version of TLS to be used. Use the 'ssl.TLSVersion' enum for specifying the value.
|
||||
Not recommended to set to anything other than 'ssl.TLSVersion.MAXIMUM_SUPPORTED' which is the
|
||||
default value.
|
||||
:param cert_reqs:
|
||||
Whether to require the certificate verification. This defaults to
|
||||
``ssl.CERT_REQUIRED``.
|
||||
:param options:
|
||||
Specific OpenSSL options. These default to ``ssl.OP_NO_SSLv2``,
|
||||
``ssl.OP_NO_SSLv3``, ``ssl.OP_NO_COMPRESSION``, and ``ssl.OP_NO_TICKET``.
|
||||
:param ciphers:
|
||||
Which cipher suites to allow the server to select. Defaults to either system configured
|
||||
ciphers if OpenSSL 1.1.1+, otherwise uses a secure default set of ciphers.
|
||||
:param verify_flags:
|
||||
The flags for certificate verification operations. These default to
|
||||
``ssl.VERIFY_X509_PARTIAL_CHAIN`` and ``ssl.VERIFY_X509_STRICT`` for Python 3.13+.
|
||||
:returns:
|
||||
Constructed SSLContext object with specified options
|
||||
:rtype: SSLContext
|
||||
"""
|
||||
if SSLContext is None:
|
||||
raise TypeError("Can't create an SSLContext object without an ssl module")
|
||||
|
||||
# This means 'ssl_version' was specified as an exact value.
|
||||
if ssl_version not in (None, PROTOCOL_TLS, PROTOCOL_TLS_CLIENT):
|
||||
# Disallow setting 'ssl_version' and 'ssl_minimum|maximum_version'
|
||||
# to avoid conflicts.
|
||||
if ssl_minimum_version is not None or ssl_maximum_version is not None:
|
||||
raise ValueError(
|
||||
"Can't specify both 'ssl_version' and either "
|
||||
"'ssl_minimum_version' or 'ssl_maximum_version'"
|
||||
)
|
||||
|
||||
# 'ssl_version' is deprecated and will be removed in the future.
|
||||
else:
|
||||
# Use 'ssl_minimum_version' and 'ssl_maximum_version' instead.
|
||||
ssl_minimum_version = _SSL_VERSION_TO_TLS_VERSION.get(
|
||||
ssl_version, TLSVersion.MINIMUM_SUPPORTED
|
||||
)
|
||||
ssl_maximum_version = _SSL_VERSION_TO_TLS_VERSION.get(
|
||||
ssl_version, TLSVersion.MAXIMUM_SUPPORTED
|
||||
)
|
||||
|
||||
# This warning message is pushing users to use 'ssl_minimum_version'
|
||||
# instead of both min/max. Best practice is to only set the minimum version and
|
||||
# keep the maximum version to be it's default value: 'TLSVersion.MAXIMUM_SUPPORTED'
|
||||
warnings.warn(
|
||||
"'ssl_version' option is deprecated and will be "
|
||||
"removed in urllib3 v2.1.0. Instead use 'ssl_minimum_version'",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# PROTOCOL_TLS is deprecated in Python 3.10 so we always use PROTOCOL_TLS_CLIENT
|
||||
context = SSLContext(PROTOCOL_TLS_CLIENT)
|
||||
|
||||
if ssl_minimum_version is not None:
|
||||
context.minimum_version = ssl_minimum_version
|
||||
else: # Python <3.10 defaults to 'MINIMUM_SUPPORTED' so explicitly set TLSv1.2 here
|
||||
context.minimum_version = TLSVersion.TLSv1_2
|
||||
|
||||
if ssl_maximum_version is not None:
|
||||
context.maximum_version = ssl_maximum_version
|
||||
|
||||
# Unless we're given ciphers defer to either system ciphers in
|
||||
# the case of OpenSSL 1.1.1+ or use our own secure default ciphers.
|
||||
if ciphers:
|
||||
context.set_ciphers(ciphers)
|
||||
|
||||
# Setting the default here, as we may have no ssl module on import
|
||||
cert_reqs = ssl.CERT_REQUIRED if cert_reqs is None else cert_reqs
|
||||
|
||||
if options is None:
|
||||
options = 0
|
||||
# SSLv2 is easily broken and is considered harmful and dangerous
|
||||
options |= OP_NO_SSLv2
|
||||
# SSLv3 has several problems and is now dangerous
|
||||
options |= OP_NO_SSLv3
|
||||
# Disable compression to prevent CRIME attacks for OpenSSL 1.0+
|
||||
# (issue #309)
|
||||
options |= OP_NO_COMPRESSION
|
||||
# TLSv1.2 only. Unless set explicitly, do not request tickets.
|
||||
# This may save some bandwidth on wire, and although the ticket is encrypted,
|
||||
# there is a risk associated with it being on wire,
|
||||
# if the server is not rotating its ticketing keys properly.
|
||||
options |= OP_NO_TICKET
|
||||
|
||||
context.options |= options
|
||||
|
||||
if verify_flags is None:
|
||||
verify_flags = 0
|
||||
# In Python 3.13+ ssl.create_default_context() sets VERIFY_X509_PARTIAL_CHAIN
|
||||
# and VERIFY_X509_STRICT so we do the same
|
||||
if sys.version_info >= (3, 13):
|
||||
verify_flags |= VERIFY_X509_PARTIAL_CHAIN
|
||||
verify_flags |= VERIFY_X509_STRICT
|
||||
|
||||
context.verify_flags |= verify_flags
|
||||
|
||||
# Enable post-handshake authentication for TLS 1.3, see GH #1634. PHA is
|
||||
# necessary for conditional client cert authentication with TLS 1.3.
|
||||
# The attribute is None for OpenSSL <= 1.1.0 or does not exist when using
|
||||
# an SSLContext created by pyOpenSSL.
|
||||
if getattr(context, "post_handshake_auth", None) is not None:
|
||||
context.post_handshake_auth = True
|
||||
|
||||
# The order of the below lines setting verify_mode and check_hostname
|
||||
# matter due to safe-guards SSLContext has to prevent an SSLContext with
|
||||
# check_hostname=True, verify_mode=NONE/OPTIONAL.
|
||||
# We always set 'check_hostname=False' for pyOpenSSL so we rely on our own
|
||||
# 'ssl.match_hostname()' implementation.
|
||||
if cert_reqs == ssl.CERT_REQUIRED and not IS_PYOPENSSL:
|
||||
context.verify_mode = cert_reqs
|
||||
context.check_hostname = True
|
||||
else:
|
||||
context.check_hostname = False
|
||||
context.verify_mode = cert_reqs
|
||||
|
||||
try:
|
||||
context.hostname_checks_common_name = False
|
||||
except AttributeError: # Defensive: for CPython < 3.9.3; for PyPy < 7.3.8
|
||||
pass
|
||||
|
||||
sslkeylogfile = os.environ.get("SSLKEYLOGFILE")
|
||||
if sslkeylogfile:
|
||||
context.keylog_filename = sslkeylogfile
|
||||
|
||||
return context
|
||||
|
||||
|
||||
@typing.overload
|
||||
def ssl_wrap_socket(
|
||||
sock: socket.socket,
|
||||
keyfile: str | None = ...,
|
||||
certfile: str | None = ...,
|
||||
cert_reqs: int | None = ...,
|
||||
ca_certs: str | None = ...,
|
||||
server_hostname: str | None = ...,
|
||||
ssl_version: int | None = ...,
|
||||
ciphers: str | None = ...,
|
||||
ssl_context: ssl.SSLContext | None = ...,
|
||||
ca_cert_dir: str | None = ...,
|
||||
key_password: str | None = ...,
|
||||
ca_cert_data: None | str | bytes = ...,
|
||||
tls_in_tls: typing.Literal[False] = ...,
|
||||
) -> ssl.SSLSocket: ...
|
||||
|
||||
|
||||
@typing.overload
|
||||
def ssl_wrap_socket(
|
||||
sock: socket.socket,
|
||||
keyfile: str | None = ...,
|
||||
certfile: str | None = ...,
|
||||
cert_reqs: int | None = ...,
|
||||
ca_certs: str | None = ...,
|
||||
server_hostname: str | None = ...,
|
||||
ssl_version: int | None = ...,
|
||||
ciphers: str | None = ...,
|
||||
ssl_context: ssl.SSLContext | None = ...,
|
||||
ca_cert_dir: str | None = ...,
|
||||
key_password: str | None = ...,
|
||||
ca_cert_data: None | str | bytes = ...,
|
||||
tls_in_tls: bool = ...,
|
||||
) -> ssl.SSLSocket | SSLTransportType: ...
|
||||
|
||||
|
||||
def ssl_wrap_socket(
|
||||
sock: socket.socket,
|
||||
keyfile: str | None = None,
|
||||
certfile: str | None = None,
|
||||
cert_reqs: int | None = None,
|
||||
ca_certs: str | None = None,
|
||||
server_hostname: str | None = None,
|
||||
ssl_version: int | None = None,
|
||||
ciphers: str | None = None,
|
||||
ssl_context: ssl.SSLContext | None = None,
|
||||
ca_cert_dir: str | None = None,
|
||||
key_password: str | None = None,
|
||||
ca_cert_data: None | str | bytes = None,
|
||||
tls_in_tls: bool = False,
|
||||
) -> ssl.SSLSocket | SSLTransportType:
|
||||
"""
|
||||
All arguments except for server_hostname, ssl_context, tls_in_tls, ca_cert_data and
|
||||
ca_cert_dir have the same meaning as they do when using
|
||||
:func:`ssl.create_default_context`, :meth:`ssl.SSLContext.load_cert_chain`,
|
||||
:meth:`ssl.SSLContext.set_ciphers` and :meth:`ssl.SSLContext.wrap_socket`.
|
||||
|
||||
:param server_hostname:
|
||||
When SNI is supported, the expected hostname of the certificate
|
||||
:param ssl_context:
|
||||
A pre-made :class:`SSLContext` object. If none is provided, one will
|
||||
be created using :func:`create_urllib3_context`.
|
||||
:param ciphers:
|
||||
A string of ciphers we wish the client to support.
|
||||
:param ca_cert_dir:
|
||||
A directory containing CA certificates in multiple separate files, as
|
||||
supported by OpenSSL's -CApath flag or the capath argument to
|
||||
SSLContext.load_verify_locations().
|
||||
:param key_password:
|
||||
Optional password if the keyfile is encrypted.
|
||||
:param ca_cert_data:
|
||||
Optional string containing CA certificates in PEM format suitable for
|
||||
passing as the cadata parameter to SSLContext.load_verify_locations()
|
||||
:param tls_in_tls:
|
||||
Use SSLTransport to wrap the existing socket.
|
||||
"""
|
||||
context = ssl_context
|
||||
if context is None:
|
||||
# Note: This branch of code and all the variables in it are only used in tests.
|
||||
# We should consider deprecating and removing this code.
|
||||
context = create_urllib3_context(ssl_version, cert_reqs, ciphers=ciphers)
|
||||
|
||||
if ca_certs or ca_cert_dir or ca_cert_data:
|
||||
try:
|
||||
context.load_verify_locations(ca_certs, ca_cert_dir, ca_cert_data)
|
||||
except OSError as e:
|
||||
raise SSLError(e) from e
|
||||
|
||||
elif ssl_context is None and hasattr(context, "load_default_certs"):
|
||||
# try to load OS default certs; works well on Windows.
|
||||
context.load_default_certs()
|
||||
|
||||
# Attempt to detect if we get the goofy behavior of the
|
||||
# keyfile being encrypted and OpenSSL asking for the
|
||||
# passphrase via the terminal and instead error out.
|
||||
if keyfile and key_password is None and _is_key_file_encrypted(keyfile):
|
||||
raise SSLError("Client private key is encrypted, password is required")
|
||||
|
||||
if certfile:
|
||||
if key_password is None:
|
||||
context.load_cert_chain(certfile, keyfile)
|
||||
else:
|
||||
context.load_cert_chain(certfile, keyfile, key_password)
|
||||
|
||||
context.set_alpn_protocols(ALPN_PROTOCOLS)
|
||||
|
||||
ssl_sock = _ssl_wrap_socket_impl(sock, context, tls_in_tls, server_hostname)
|
||||
return ssl_sock
|
||||
|
||||
|
||||
def is_ipaddress(hostname: str | bytes) -> bool:
|
||||
"""Detects whether the hostname given is an IPv4 or IPv6 address.
|
||||
Also detects IPv6 addresses with Zone IDs.
|
||||
|
||||
:param str hostname: Hostname to examine.
|
||||
:return: True if the hostname is an IP address, False otherwise.
|
||||
"""
|
||||
if isinstance(hostname, bytes):
|
||||
# IDN A-label bytes are ASCII compatible.
|
||||
hostname = hostname.decode("ascii")
|
||||
return bool(_IPV4_RE.match(hostname) or _BRACELESS_IPV6_ADDRZ_RE.match(hostname))
|
||||
|
||||
|
||||
def _is_key_file_encrypted(key_file: str) -> bool:
|
||||
"""Detects if a key file is encrypted or not."""
|
||||
with open(key_file) as f:
|
||||
for line in f:
|
||||
# Look for Proc-Type: 4,ENCRYPTED
|
||||
if "ENCRYPTED" in line:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _ssl_wrap_socket_impl(
|
||||
sock: socket.socket,
|
||||
ssl_context: ssl.SSLContext,
|
||||
tls_in_tls: bool,
|
||||
server_hostname: str | None = None,
|
||||
) -> ssl.SSLSocket | SSLTransportType:
|
||||
if tls_in_tls:
|
||||
if not SSLTransport:
|
||||
# Import error, ssl is not available.
|
||||
raise ProxySchemeUnsupported(
|
||||
"TLS in TLS requires support for the 'ssl' module"
|
||||
)
|
||||
|
||||
SSLTransport._validate_ssl_context_for_tls_in_tls(ssl_context)
|
||||
return SSLTransport(sock, ssl_context, server_hostname)
|
||||
|
||||
return ssl_context.wrap_socket(sock, server_hostname=server_hostname)
|
@ -0,0 +1,159 @@
|
||||
"""The match_hostname() function from Python 3.5, essential when using SSL."""
|
||||
|
||||
# Note: This file is under the PSF license as the code comes from the python
|
||||
# stdlib. http://docs.python.org/3/license.html
|
||||
# It is modified to remove commonName support.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import re
|
||||
import typing
|
||||
from ipaddress import IPv4Address, IPv6Address
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from .ssl_ import _TYPE_PEER_CERT_RET_DICT
|
||||
|
||||
__version__ = "3.5.0.1"
|
||||
|
||||
|
||||
class CertificateError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
def _dnsname_match(
|
||||
dn: typing.Any, hostname: str, max_wildcards: int = 1
|
||||
) -> typing.Match[str] | None | bool:
|
||||
"""Matching according to RFC 6125, section 6.4.3
|
||||
|
||||
http://tools.ietf.org/html/rfc6125#section-6.4.3
|
||||
"""
|
||||
pats = []
|
||||
if not dn:
|
||||
return False
|
||||
|
||||
# Ported from python3-syntax:
|
||||
# leftmost, *remainder = dn.split(r'.')
|
||||
parts = dn.split(r".")
|
||||
leftmost = parts[0]
|
||||
remainder = parts[1:]
|
||||
|
||||
wildcards = leftmost.count("*")
|
||||
if wildcards > max_wildcards:
|
||||
# Issue #17980: avoid denials of service by refusing more
|
||||
# than one wildcard per fragment. A survey of established
|
||||
# policy among SSL implementations showed it to be a
|
||||
# reasonable choice.
|
||||
raise CertificateError(
|
||||
"too many wildcards in certificate DNS name: " + repr(dn)
|
||||
)
|
||||
|
||||
# speed up common case w/o wildcards
|
||||
if not wildcards:
|
||||
return bool(dn.lower() == hostname.lower())
|
||||
|
||||
# RFC 6125, section 6.4.3, subitem 1.
|
||||
# The client SHOULD NOT attempt to match a presented identifier in which
|
||||
# the wildcard character comprises a label other than the left-most label.
|
||||
if leftmost == "*":
|
||||
# When '*' is a fragment by itself, it matches a non-empty dotless
|
||||
# fragment.
|
||||
pats.append("[^.]+")
|
||||
elif leftmost.startswith("xn--") or hostname.startswith("xn--"):
|
||||
# RFC 6125, section 6.4.3, subitem 3.
|
||||
# The client SHOULD NOT attempt to match a presented identifier
|
||||
# where the wildcard character is embedded within an A-label or
|
||||
# U-label of an internationalized domain name.
|
||||
pats.append(re.escape(leftmost))
|
||||
else:
|
||||
# Otherwise, '*' matches any dotless string, e.g. www*
|
||||
pats.append(re.escape(leftmost).replace(r"\*", "[^.]*"))
|
||||
|
||||
# add the remaining fragments, ignore any wildcards
|
||||
for frag in remainder:
|
||||
pats.append(re.escape(frag))
|
||||
|
||||
pat = re.compile(r"\A" + r"\.".join(pats) + r"\Z", re.IGNORECASE)
|
||||
return pat.match(hostname)
|
||||
|
||||
|
||||
def _ipaddress_match(ipname: str, host_ip: IPv4Address | IPv6Address) -> bool:
|
||||
"""Exact matching of IP addresses.
|
||||
|
||||
RFC 9110 section 4.3.5: "A reference identity of IP-ID contains the decoded
|
||||
bytes of the IP address. An IP version 4 address is 4 octets, and an IP
|
||||
version 6 address is 16 octets. [...] A reference identity of type IP-ID
|
||||
matches if the address is identical to an iPAddress value of the
|
||||
subjectAltName extension of the certificate."
|
||||
"""
|
||||
# OpenSSL may add a trailing newline to a subjectAltName's IP address
|
||||
# Divergence from upstream: ipaddress can't handle byte str
|
||||
ip = ipaddress.ip_address(ipname.rstrip())
|
||||
return bool(ip.packed == host_ip.packed)
|
||||
|
||||
|
||||
def match_hostname(
|
||||
cert: _TYPE_PEER_CERT_RET_DICT | None,
|
||||
hostname: str,
|
||||
hostname_checks_common_name: bool = False,
|
||||
) -> None:
|
||||
"""Verify that *cert* (in decoded format as returned by
|
||||
SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 and RFC 6125
|
||||
rules are followed, but IP addresses are not accepted for *hostname*.
|
||||
|
||||
CertificateError is raised on failure. On success, the function
|
||||
returns nothing.
|
||||
"""
|
||||
if not cert:
|
||||
raise ValueError(
|
||||
"empty or no certificate, match_hostname needs a "
|
||||
"SSL socket or SSL context with either "
|
||||
"CERT_OPTIONAL or CERT_REQUIRED"
|
||||
)
|
||||
try:
|
||||
# Divergence from upstream: ipaddress can't handle byte str
|
||||
#
|
||||
# The ipaddress module shipped with Python < 3.9 does not support
|
||||
# scoped IPv6 addresses so we unconditionally strip the Zone IDs for
|
||||
# now. Once we drop support for Python 3.9 we can remove this branch.
|
||||
if "%" in hostname:
|
||||
host_ip = ipaddress.ip_address(hostname[: hostname.rfind("%")])
|
||||
else:
|
||||
host_ip = ipaddress.ip_address(hostname)
|
||||
|
||||
except ValueError:
|
||||
# Not an IP address (common case)
|
||||
host_ip = None
|
||||
dnsnames = []
|
||||
san: tuple[tuple[str, str], ...] = cert.get("subjectAltName", ())
|
||||
key: str
|
||||
value: str
|
||||
for key, value in san:
|
||||
if key == "DNS":
|
||||
if host_ip is None and _dnsname_match(value, hostname):
|
||||
return
|
||||
dnsnames.append(value)
|
||||
elif key == "IP Address":
|
||||
if host_ip is not None and _ipaddress_match(value, host_ip):
|
||||
return
|
||||
dnsnames.append(value)
|
||||
|
||||
# We only check 'commonName' if it's enabled and we're not verifying
|
||||
# an IP address. IP addresses aren't valid within 'commonName'.
|
||||
if hostname_checks_common_name and host_ip is None and not dnsnames:
|
||||
for sub in cert.get("subject", ()):
|
||||
for key, value in sub:
|
||||
if key == "commonName":
|
||||
if _dnsname_match(value, hostname):
|
||||
return
|
||||
dnsnames.append(value) # Defensive: for Python < 3.9.3
|
||||
|
||||
if len(dnsnames) > 1:
|
||||
raise CertificateError(
|
||||
"hostname %r "
|
||||
"doesn't match either of %s" % (hostname, ", ".join(map(repr, dnsnames)))
|
||||
)
|
||||
elif len(dnsnames) == 1:
|
||||
raise CertificateError(f"hostname {hostname!r} doesn't match {dnsnames[0]!r}")
|
||||
else:
|
||||
raise CertificateError("no appropriate subjectAltName fields were found")
|
271
venv/lib/python3.11/site-packages/urllib3/util/ssltransport.py
Normal file
271
venv/lib/python3.11/site-packages/urllib3/util/ssltransport.py
Normal file
@ -0,0 +1,271 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import socket
|
||||
import ssl
|
||||
import typing
|
||||
|
||||
from ..exceptions import ProxySchemeUnsupported
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from typing_extensions import Self
|
||||
|
||||
from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT
|
||||
|
||||
|
||||
_WriteBuffer = typing.Union[bytearray, memoryview]
|
||||
_ReturnValue = typing.TypeVar("_ReturnValue")
|
||||
|
||||
SSL_BLOCKSIZE = 16384
|
||||
|
||||
|
||||
class SSLTransport:
|
||||
"""
|
||||
The SSLTransport wraps an existing socket and establishes an SSL connection.
|
||||
|
||||
Contrary to Python's implementation of SSLSocket, it allows you to chain
|
||||
multiple TLS connections together. It's particularly useful if you need to
|
||||
implement TLS within TLS.
|
||||
|
||||
The class supports most of the socket API operations.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _validate_ssl_context_for_tls_in_tls(ssl_context: ssl.SSLContext) -> None:
|
||||
"""
|
||||
Raises a ProxySchemeUnsupported if the provided ssl_context can't be used
|
||||
for TLS in TLS.
|
||||
|
||||
The only requirement is that the ssl_context provides the 'wrap_bio'
|
||||
methods.
|
||||
"""
|
||||
|
||||
if not hasattr(ssl_context, "wrap_bio"):
|
||||
raise ProxySchemeUnsupported(
|
||||
"TLS in TLS requires SSLContext.wrap_bio() which isn't "
|
||||
"available on non-native SSLContext"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
socket: socket.socket,
|
||||
ssl_context: ssl.SSLContext,
|
||||
server_hostname: str | None = None,
|
||||
suppress_ragged_eofs: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Create an SSLTransport around socket using the provided ssl_context.
|
||||
"""
|
||||
self.incoming = ssl.MemoryBIO()
|
||||
self.outgoing = ssl.MemoryBIO()
|
||||
|
||||
self.suppress_ragged_eofs = suppress_ragged_eofs
|
||||
self.socket = socket
|
||||
|
||||
self.sslobj = ssl_context.wrap_bio(
|
||||
self.incoming, self.outgoing, server_hostname=server_hostname
|
||||
)
|
||||
|
||||
# Perform initial handshake.
|
||||
self._ssl_io_loop(self.sslobj.do_handshake)
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __exit__(self, *_: typing.Any) -> None:
|
||||
self.close()
|
||||
|
||||
def fileno(self) -> int:
|
||||
return self.socket.fileno()
|
||||
|
||||
def read(self, len: int = 1024, buffer: typing.Any | None = None) -> int | bytes:
|
||||
return self._wrap_ssl_read(len, buffer)
|
||||
|
||||
def recv(self, buflen: int = 1024, flags: int = 0) -> int | bytes:
|
||||
if flags != 0:
|
||||
raise ValueError("non-zero flags not allowed in calls to recv")
|
||||
return self._wrap_ssl_read(buflen)
|
||||
|
||||
def recv_into(
|
||||
self,
|
||||
buffer: _WriteBuffer,
|
||||
nbytes: int | None = None,
|
||||
flags: int = 0,
|
||||
) -> None | int | bytes:
|
||||
if flags != 0:
|
||||
raise ValueError("non-zero flags not allowed in calls to recv_into")
|
||||
if nbytes is None:
|
||||
nbytes = len(buffer)
|
||||
return self.read(nbytes, buffer)
|
||||
|
||||
def sendall(self, data: bytes, flags: int = 0) -> None:
|
||||
if flags != 0:
|
||||
raise ValueError("non-zero flags not allowed in calls to sendall")
|
||||
count = 0
|
||||
with memoryview(data) as view, view.cast("B") as byte_view:
|
||||
amount = len(byte_view)
|
||||
while count < amount:
|
||||
v = self.send(byte_view[count:])
|
||||
count += v
|
||||
|
||||
def send(self, data: bytes, flags: int = 0) -> int:
|
||||
if flags != 0:
|
||||
raise ValueError("non-zero flags not allowed in calls to send")
|
||||
return self._ssl_io_loop(self.sslobj.write, data)
|
||||
|
||||
def makefile(
|
||||
self,
|
||||
mode: str,
|
||||
buffering: int | None = None,
|
||||
*,
|
||||
encoding: str | None = None,
|
||||
errors: str | None = None,
|
||||
newline: str | None = None,
|
||||
) -> typing.BinaryIO | typing.TextIO | socket.SocketIO:
|
||||
"""
|
||||
Python's httpclient uses makefile and buffered io when reading HTTP
|
||||
messages and we need to support it.
|
||||
|
||||
This is unfortunately a copy and paste of socket.py makefile with small
|
||||
changes to point to the socket directly.
|
||||
"""
|
||||
if not set(mode) <= {"r", "w", "b"}:
|
||||
raise ValueError(f"invalid mode {mode!r} (only r, w, b allowed)")
|
||||
|
||||
writing = "w" in mode
|
||||
reading = "r" in mode or not writing
|
||||
assert reading or writing
|
||||
binary = "b" in mode
|
||||
rawmode = ""
|
||||
if reading:
|
||||
rawmode += "r"
|
||||
if writing:
|
||||
rawmode += "w"
|
||||
raw = socket.SocketIO(self, rawmode) # type: ignore[arg-type]
|
||||
self.socket._io_refs += 1 # type: ignore[attr-defined]
|
||||
if buffering is None:
|
||||
buffering = -1
|
||||
if buffering < 0:
|
||||
buffering = io.DEFAULT_BUFFER_SIZE
|
||||
if buffering == 0:
|
||||
if not binary:
|
||||
raise ValueError("unbuffered streams must be binary")
|
||||
return raw
|
||||
buffer: typing.BinaryIO
|
||||
if reading and writing:
|
||||
buffer = io.BufferedRWPair(raw, raw, buffering) # type: ignore[assignment]
|
||||
elif reading:
|
||||
buffer = io.BufferedReader(raw, buffering)
|
||||
else:
|
||||
assert writing
|
||||
buffer = io.BufferedWriter(raw, buffering)
|
||||
if binary:
|
||||
return buffer
|
||||
text = io.TextIOWrapper(buffer, encoding, errors, newline)
|
||||
text.mode = mode # type: ignore[misc]
|
||||
return text
|
||||
|
||||
def unwrap(self) -> None:
|
||||
self._ssl_io_loop(self.sslobj.unwrap)
|
||||
|
||||
def close(self) -> None:
|
||||
self.socket.close()
|
||||
|
||||
@typing.overload
|
||||
def getpeercert(
|
||||
self, binary_form: typing.Literal[False] = ...
|
||||
) -> _TYPE_PEER_CERT_RET_DICT | None: ...
|
||||
|
||||
@typing.overload
|
||||
def getpeercert(self, binary_form: typing.Literal[True]) -> bytes | None: ...
|
||||
|
||||
def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET:
|
||||
return self.sslobj.getpeercert(binary_form) # type: ignore[return-value]
|
||||
|
||||
def version(self) -> str | None:
|
||||
return self.sslobj.version()
|
||||
|
||||
def cipher(self) -> tuple[str, str, int] | None:
|
||||
return self.sslobj.cipher()
|
||||
|
||||
def selected_alpn_protocol(self) -> str | None:
|
||||
return self.sslobj.selected_alpn_protocol()
|
||||
|
||||
def shared_ciphers(self) -> list[tuple[str, str, int]] | None:
|
||||
return self.sslobj.shared_ciphers()
|
||||
|
||||
def compression(self) -> str | None:
|
||||
return self.sslobj.compression()
|
||||
|
||||
def settimeout(self, value: float | None) -> None:
|
||||
self.socket.settimeout(value)
|
||||
|
||||
def gettimeout(self) -> float | None:
|
||||
return self.socket.gettimeout()
|
||||
|
||||
def _decref_socketios(self) -> None:
|
||||
self.socket._decref_socketios() # type: ignore[attr-defined]
|
||||
|
||||
def _wrap_ssl_read(self, len: int, buffer: bytearray | None = None) -> int | bytes:
|
||||
try:
|
||||
return self._ssl_io_loop(self.sslobj.read, len, buffer)
|
||||
except ssl.SSLError as e:
|
||||
if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs:
|
||||
return 0 # eof, return 0.
|
||||
else:
|
||||
raise
|
||||
|
||||
# func is sslobj.do_handshake or sslobj.unwrap
|
||||
@typing.overload
|
||||
def _ssl_io_loop(self, func: typing.Callable[[], None]) -> None: ...
|
||||
|
||||
# func is sslobj.write, arg1 is data
|
||||
@typing.overload
|
||||
def _ssl_io_loop(self, func: typing.Callable[[bytes], int], arg1: bytes) -> int: ...
|
||||
|
||||
# func is sslobj.read, arg1 is len, arg2 is buffer
|
||||
@typing.overload
|
||||
def _ssl_io_loop(
|
||||
self,
|
||||
func: typing.Callable[[int, bytearray | None], bytes],
|
||||
arg1: int,
|
||||
arg2: bytearray | None,
|
||||
) -> bytes: ...
|
||||
|
||||
def _ssl_io_loop(
|
||||
self,
|
||||
func: typing.Callable[..., _ReturnValue],
|
||||
arg1: None | bytes | int = None,
|
||||
arg2: bytearray | None = None,
|
||||
) -> _ReturnValue:
|
||||
"""Performs an I/O loop between incoming/outgoing and the socket."""
|
||||
should_loop = True
|
||||
ret = None
|
||||
|
||||
while should_loop:
|
||||
errno = None
|
||||
try:
|
||||
if arg1 is None and arg2 is None:
|
||||
ret = func()
|
||||
elif arg2 is None:
|
||||
ret = func(arg1)
|
||||
else:
|
||||
ret = func(arg1, arg2)
|
||||
except ssl.SSLError as e:
|
||||
if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
|
||||
# WANT_READ, and WANT_WRITE are expected, others are not.
|
||||
raise e
|
||||
errno = e.errno
|
||||
|
||||
buf = self.outgoing.read()
|
||||
self.socket.sendall(buf)
|
||||
|
||||
if errno is None:
|
||||
should_loop = False
|
||||
elif errno == ssl.SSL_ERROR_WANT_READ:
|
||||
buf = self.socket.recv(SSL_BLOCKSIZE)
|
||||
if buf:
|
||||
self.incoming.write(buf)
|
||||
else:
|
||||
self.incoming.write_eof()
|
||||
return typing.cast(_ReturnValue, ret)
|
275
venv/lib/python3.11/site-packages/urllib3/util/timeout.py
Normal file
275
venv/lib/python3.11/site-packages/urllib3/util/timeout.py
Normal file
@ -0,0 +1,275 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import typing
|
||||
from enum import Enum
|
||||
from socket import getdefaulttimeout
|
||||
|
||||
from ..exceptions import TimeoutStateError
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from typing import Final
|
||||
|
||||
|
||||
class _TYPE_DEFAULT(Enum):
|
||||
# This value should never be passed to socket.settimeout() so for safety we use a -1.
|
||||
# socket.settimout() raises a ValueError for negative values.
|
||||
token = -1
|
||||
|
||||
|
||||
_DEFAULT_TIMEOUT: Final[_TYPE_DEFAULT] = _TYPE_DEFAULT.token
|
||||
|
||||
_TYPE_TIMEOUT = typing.Optional[typing.Union[float, _TYPE_DEFAULT]]
|
||||
|
||||
|
||||
class Timeout:
|
||||
"""Timeout configuration.
|
||||
|
||||
Timeouts can be defined as a default for a pool:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import urllib3
|
||||
|
||||
timeout = urllib3.util.Timeout(connect=2.0, read=7.0)
|
||||
|
||||
http = urllib3.PoolManager(timeout=timeout)
|
||||
|
||||
resp = http.request("GET", "https://example.com/")
|
||||
|
||||
print(resp.status)
|
||||
|
||||
Or per-request (which overrides the default for the pool):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
response = http.request("GET", "https://example.com/", timeout=Timeout(10))
|
||||
|
||||
Timeouts can be disabled by setting all the parameters to ``None``:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
no_timeout = Timeout(connect=None, read=None)
|
||||
response = http.request("GET", "https://example.com/", timeout=no_timeout)
|
||||
|
||||
|
||||
:param total:
|
||||
This combines the connect and read timeouts into one; the read timeout
|
||||
will be set to the time leftover from the connect attempt. In the
|
||||
event that both a connect timeout and a total are specified, or a read
|
||||
timeout and a total are specified, the shorter timeout will be applied.
|
||||
|
||||
Defaults to None.
|
||||
|
||||
:type total: int, float, or None
|
||||
|
||||
:param connect:
|
||||
The maximum amount of time (in seconds) to wait for a connection
|
||||
attempt to a server to succeed. Omitting the parameter will default the
|
||||
connect timeout to the system default, probably `the global default
|
||||
timeout in socket.py
|
||||
<http://hg.python.org/cpython/file/603b4d593758/Lib/socket.py#l535>`_.
|
||||
None will set an infinite timeout for connection attempts.
|
||||
|
||||
:type connect: int, float, or None
|
||||
|
||||
:param read:
|
||||
The maximum amount of time (in seconds) to wait between consecutive
|
||||
read operations for a response from the server. Omitting the parameter
|
||||
will default the read timeout to the system default, probably `the
|
||||
global default timeout in socket.py
|
||||
<http://hg.python.org/cpython/file/603b4d593758/Lib/socket.py#l535>`_.
|
||||
None will set an infinite timeout.
|
||||
|
||||
:type read: int, float, or None
|
||||
|
||||
.. note::
|
||||
|
||||
Many factors can affect the total amount of time for urllib3 to return
|
||||
an HTTP response.
|
||||
|
||||
For example, Python's DNS resolver does not obey the timeout specified
|
||||
on the socket. Other factors that can affect total request time include
|
||||
high CPU load, high swap, the program running at a low priority level,
|
||||
or other behaviors.
|
||||
|
||||
In addition, the read and total timeouts only measure the time between
|
||||
read operations on the socket connecting the client and the server,
|
||||
not the total amount of time for the request to return a complete
|
||||
response. For most requests, the timeout is raised because the server
|
||||
has not sent the first byte in the specified time. This is not always
|
||||
the case; if a server streams one byte every fifteen seconds, a timeout
|
||||
of 20 seconds will not trigger, even though the request will take
|
||||
several minutes to complete.
|
||||
"""
|
||||
|
||||
#: A sentinel object representing the default timeout value
|
||||
DEFAULT_TIMEOUT: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
total: _TYPE_TIMEOUT = None,
|
||||
connect: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT,
|
||||
read: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT,
|
||||
) -> None:
|
||||
self._connect = self._validate_timeout(connect, "connect")
|
||||
self._read = self._validate_timeout(read, "read")
|
||||
self.total = self._validate_timeout(total, "total")
|
||||
self._start_connect: float | None = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{type(self).__name__}(connect={self._connect!r}, read={self._read!r}, total={self.total!r})"
|
||||
|
||||
# __str__ provided for backwards compatibility
|
||||
__str__ = __repr__
|
||||
|
||||
@staticmethod
|
||||
def resolve_default_timeout(timeout: _TYPE_TIMEOUT) -> float | None:
|
||||
return getdefaulttimeout() if timeout is _DEFAULT_TIMEOUT else timeout
|
||||
|
||||
@classmethod
|
||||
def _validate_timeout(cls, value: _TYPE_TIMEOUT, name: str) -> _TYPE_TIMEOUT:
|
||||
"""Check that a timeout attribute is valid.
|
||||
|
||||
:param value: The timeout value to validate
|
||||
:param name: The name of the timeout attribute to validate. This is
|
||||
used to specify in error messages.
|
||||
:return: The validated and casted version of the given value.
|
||||
:raises ValueError: If it is a numeric value less than or equal to
|
||||
zero, or the type is not an integer, float, or None.
|
||||
"""
|
||||
if value is None or value is _DEFAULT_TIMEOUT:
|
||||
return value
|
||||
|
||||
if isinstance(value, bool):
|
||||
raise ValueError(
|
||||
"Timeout cannot be a boolean value. It must "
|
||||
"be an int, float or None."
|
||||
)
|
||||
try:
|
||||
float(value)
|
||||
except (TypeError, ValueError):
|
||||
raise ValueError(
|
||||
"Timeout value %s was %s, but it must be an "
|
||||
"int, float or None." % (name, value)
|
||||
) from None
|
||||
|
||||
try:
|
||||
if value <= 0:
|
||||
raise ValueError(
|
||||
"Attempted to set %s timeout to %s, but the "
|
||||
"timeout cannot be set to a value less "
|
||||
"than or equal to 0." % (name, value)
|
||||
)
|
||||
except TypeError:
|
||||
raise ValueError(
|
||||
"Timeout value %s was %s, but it must be an "
|
||||
"int, float or None." % (name, value)
|
||||
) from None
|
||||
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def from_float(cls, timeout: _TYPE_TIMEOUT) -> Timeout:
|
||||
"""Create a new Timeout from a legacy timeout value.
|
||||
|
||||
The timeout value used by httplib.py sets the same timeout on the
|
||||
connect(), and recv() socket requests. This creates a :class:`Timeout`
|
||||
object that sets the individual timeouts to the ``timeout`` value
|
||||
passed to this function.
|
||||
|
||||
:param timeout: The legacy timeout value.
|
||||
:type timeout: integer, float, :attr:`urllib3.util.Timeout.DEFAULT_TIMEOUT`, or None
|
||||
:return: Timeout object
|
||||
:rtype: :class:`Timeout`
|
||||
"""
|
||||
return Timeout(read=timeout, connect=timeout)
|
||||
|
||||
def clone(self) -> Timeout:
|
||||
"""Create a copy of the timeout object
|
||||
|
||||
Timeout properties are stored per-pool but each request needs a fresh
|
||||
Timeout object to ensure each one has its own start/stop configured.
|
||||
|
||||
:return: a copy of the timeout object
|
||||
:rtype: :class:`Timeout`
|
||||
"""
|
||||
# We can't use copy.deepcopy because that will also create a new object
|
||||
# for _GLOBAL_DEFAULT_TIMEOUT, which socket.py uses as a sentinel to
|
||||
# detect the user default.
|
||||
return Timeout(connect=self._connect, read=self._read, total=self.total)
|
||||
|
||||
def start_connect(self) -> float:
|
||||
"""Start the timeout clock, used during a connect() attempt
|
||||
|
||||
:raises urllib3.exceptions.TimeoutStateError: if you attempt
|
||||
to start a timer that has been started already.
|
||||
"""
|
||||
if self._start_connect is not None:
|
||||
raise TimeoutStateError("Timeout timer has already been started.")
|
||||
self._start_connect = time.monotonic()
|
||||
return self._start_connect
|
||||
|
||||
def get_connect_duration(self) -> float:
|
||||
"""Gets the time elapsed since the call to :meth:`start_connect`.
|
||||
|
||||
:return: Elapsed time in seconds.
|
||||
:rtype: float
|
||||
:raises urllib3.exceptions.TimeoutStateError: if you attempt
|
||||
to get duration for a timer that hasn't been started.
|
||||
"""
|
||||
if self._start_connect is None:
|
||||
raise TimeoutStateError(
|
||||
"Can't get connect duration for timer that has not started."
|
||||
)
|
||||
return time.monotonic() - self._start_connect
|
||||
|
||||
@property
|
||||
def connect_timeout(self) -> _TYPE_TIMEOUT:
|
||||
"""Get the value to use when setting a connection timeout.
|
||||
|
||||
This will be a positive float or integer, the value None
|
||||
(never timeout), or the default system timeout.
|
||||
|
||||
:return: Connect timeout.
|
||||
:rtype: int, float, :attr:`Timeout.DEFAULT_TIMEOUT` or None
|
||||
"""
|
||||
if self.total is None:
|
||||
return self._connect
|
||||
|
||||
if self._connect is None or self._connect is _DEFAULT_TIMEOUT:
|
||||
return self.total
|
||||
|
||||
return min(self._connect, self.total) # type: ignore[type-var]
|
||||
|
||||
@property
|
||||
def read_timeout(self) -> float | None:
|
||||
"""Get the value for the read timeout.
|
||||
|
||||
This assumes some time has elapsed in the connection timeout and
|
||||
computes the read timeout appropriately.
|
||||
|
||||
If self.total is set, the read timeout is dependent on the amount of
|
||||
time taken by the connect timeout. If the connection time has not been
|
||||
established, a :exc:`~urllib3.exceptions.TimeoutStateError` will be
|
||||
raised.
|
||||
|
||||
:return: Value to use for the read timeout.
|
||||
:rtype: int, float or None
|
||||
:raises urllib3.exceptions.TimeoutStateError: If :meth:`start_connect`
|
||||
has not yet been called on this object.
|
||||
"""
|
||||
if (
|
||||
self.total is not None
|
||||
and self.total is not _DEFAULT_TIMEOUT
|
||||
and self._read is not None
|
||||
and self._read is not _DEFAULT_TIMEOUT
|
||||
):
|
||||
# In case the connect timeout has not yet been established.
|
||||
if self._start_connect is None:
|
||||
return self._read
|
||||
return max(0, min(self.total - self.get_connect_duration(), self._read))
|
||||
elif self.total is not None and self.total is not _DEFAULT_TIMEOUT:
|
||||
return max(0, self.total - self.get_connect_duration())
|
||||
else:
|
||||
return self.resolve_default_timeout(self._read)
|
469
venv/lib/python3.11/site-packages/urllib3/util/url.py
Normal file
469
venv/lib/python3.11/site-packages/urllib3/util/url.py
Normal file
@ -0,0 +1,469 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import typing
|
||||
|
||||
from ..exceptions import LocationParseError
|
||||
from .util import to_str
|
||||
|
||||
# We only want to normalize urls with an HTTP(S) scheme.
|
||||
# urllib3 infers URLs without a scheme (None) to be http.
|
||||
_NORMALIZABLE_SCHEMES = ("http", "https", None)
|
||||
|
||||
# Almost all of these patterns were derived from the
|
||||
# 'rfc3986' module: https://github.com/python-hyper/rfc3986
|
||||
_PERCENT_RE = re.compile(r"%[a-fA-F0-9]{2}")
|
||||
_SCHEME_RE = re.compile(r"^(?:[a-zA-Z][a-zA-Z0-9+-]*:|/)")
|
||||
_URI_RE = re.compile(
|
||||
r"^(?:([a-zA-Z][a-zA-Z0-9+.-]*):)?"
|
||||
r"(?://([^\\/?#]*))?"
|
||||
r"([^?#]*)"
|
||||
r"(?:\?([^#]*))?"
|
||||
r"(?:#(.*))?$",
|
||||
re.UNICODE | re.DOTALL,
|
||||
)
|
||||
|
||||
_IPV4_PAT = r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}"
|
||||
_HEX_PAT = "[0-9A-Fa-f]{1,4}"
|
||||
_LS32_PAT = "(?:{hex}:{hex}|{ipv4})".format(hex=_HEX_PAT, ipv4=_IPV4_PAT)
|
||||
_subs = {"hex": _HEX_PAT, "ls32": _LS32_PAT}
|
||||
_variations = [
|
||||
# 6( h16 ":" ) ls32
|
||||
"(?:%(hex)s:){6}%(ls32)s",
|
||||
# "::" 5( h16 ":" ) ls32
|
||||
"::(?:%(hex)s:){5}%(ls32)s",
|
||||
# [ h16 ] "::" 4( h16 ":" ) ls32
|
||||
"(?:%(hex)s)?::(?:%(hex)s:){4}%(ls32)s",
|
||||
# [ *1( h16 ":" ) h16 ] "::" 3( h16 ":" ) ls32
|
||||
"(?:(?:%(hex)s:)?%(hex)s)?::(?:%(hex)s:){3}%(ls32)s",
|
||||
# [ *2( h16 ":" ) h16 ] "::" 2( h16 ":" ) ls32
|
||||
"(?:(?:%(hex)s:){0,2}%(hex)s)?::(?:%(hex)s:){2}%(ls32)s",
|
||||
# [ *3( h16 ":" ) h16 ] "::" h16 ":" ls32
|
||||
"(?:(?:%(hex)s:){0,3}%(hex)s)?::%(hex)s:%(ls32)s",
|
||||
# [ *4( h16 ":" ) h16 ] "::" ls32
|
||||
"(?:(?:%(hex)s:){0,4}%(hex)s)?::%(ls32)s",
|
||||
# [ *5( h16 ":" ) h16 ] "::" h16
|
||||
"(?:(?:%(hex)s:){0,5}%(hex)s)?::%(hex)s",
|
||||
# [ *6( h16 ":" ) h16 ] "::"
|
||||
"(?:(?:%(hex)s:){0,6}%(hex)s)?::",
|
||||
]
|
||||
|
||||
_UNRESERVED_PAT = r"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._\-~"
|
||||
_IPV6_PAT = "(?:" + "|".join([x % _subs for x in _variations]) + ")"
|
||||
_ZONE_ID_PAT = "(?:%25|%)(?:[" + _UNRESERVED_PAT + "]|%[a-fA-F0-9]{2})+"
|
||||
_IPV6_ADDRZ_PAT = r"\[" + _IPV6_PAT + r"(?:" + _ZONE_ID_PAT + r")?\]"
|
||||
_REG_NAME_PAT = r"(?:[^\[\]%:/?#]|%[a-fA-F0-9]{2})*"
|
||||
_TARGET_RE = re.compile(r"^(/[^?#]*)(?:\?([^#]*))?(?:#.*)?$")
|
||||
|
||||
_IPV4_RE = re.compile("^" + _IPV4_PAT + "$")
|
||||
_IPV6_RE = re.compile("^" + _IPV6_PAT + "$")
|
||||
_IPV6_ADDRZ_RE = re.compile("^" + _IPV6_ADDRZ_PAT + "$")
|
||||
_BRACELESS_IPV6_ADDRZ_RE = re.compile("^" + _IPV6_ADDRZ_PAT[2:-2] + "$")
|
||||
_ZONE_ID_RE = re.compile("(" + _ZONE_ID_PAT + r")\]$")
|
||||
|
||||
_HOST_PORT_PAT = ("^(%s|%s|%s)(?::0*?(|0|[1-9][0-9]{0,4}))?$") % (
|
||||
_REG_NAME_PAT,
|
||||
_IPV4_PAT,
|
||||
_IPV6_ADDRZ_PAT,
|
||||
)
|
||||
_HOST_PORT_RE = re.compile(_HOST_PORT_PAT, re.UNICODE | re.DOTALL)
|
||||
|
||||
_UNRESERVED_CHARS = set(
|
||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._-~"
|
||||
)
|
||||
_SUB_DELIM_CHARS = set("!$&'()*+,;=")
|
||||
_USERINFO_CHARS = _UNRESERVED_CHARS | _SUB_DELIM_CHARS | {":"}
|
||||
_PATH_CHARS = _USERINFO_CHARS | {"@", "/"}
|
||||
_QUERY_CHARS = _FRAGMENT_CHARS = _PATH_CHARS | {"?"}
|
||||
|
||||
|
||||
class Url(
|
||||
typing.NamedTuple(
|
||||
"Url",
|
||||
[
|
||||
("scheme", typing.Optional[str]),
|
||||
("auth", typing.Optional[str]),
|
||||
("host", typing.Optional[str]),
|
||||
("port", typing.Optional[int]),
|
||||
("path", typing.Optional[str]),
|
||||
("query", typing.Optional[str]),
|
||||
("fragment", typing.Optional[str]),
|
||||
],
|
||||
)
|
||||
):
|
||||
"""
|
||||
Data structure for representing an HTTP URL. Used as a return value for
|
||||
:func:`parse_url`. Both the scheme and host are normalized as they are
|
||||
both case-insensitive according to RFC 3986.
|
||||
"""
|
||||
|
||||
def __new__( # type: ignore[no-untyped-def]
|
||||
cls,
|
||||
scheme: str | None = None,
|
||||
auth: str | None = None,
|
||||
host: str | None = None,
|
||||
port: int | None = None,
|
||||
path: str | None = None,
|
||||
query: str | None = None,
|
||||
fragment: str | None = None,
|
||||
):
|
||||
if path and not path.startswith("/"):
|
||||
path = "/" + path
|
||||
if scheme is not None:
|
||||
scheme = scheme.lower()
|
||||
return super().__new__(cls, scheme, auth, host, port, path, query, fragment)
|
||||
|
||||
@property
|
||||
def hostname(self) -> str | None:
|
||||
"""For backwards-compatibility with urlparse. We're nice like that."""
|
||||
return self.host
|
||||
|
||||
@property
|
||||
def request_uri(self) -> str:
|
||||
"""Absolute path including the query string."""
|
||||
uri = self.path or "/"
|
||||
|
||||
if self.query is not None:
|
||||
uri += "?" + self.query
|
||||
|
||||
return uri
|
||||
|
||||
@property
|
||||
def authority(self) -> str | None:
|
||||
"""
|
||||
Authority component as defined in RFC 3986 3.2.
|
||||
This includes userinfo (auth), host and port.
|
||||
|
||||
i.e.
|
||||
userinfo@host:port
|
||||
"""
|
||||
userinfo = self.auth
|
||||
netloc = self.netloc
|
||||
if netloc is None or userinfo is None:
|
||||
return netloc
|
||||
else:
|
||||
return f"{userinfo}@{netloc}"
|
||||
|
||||
@property
|
||||
def netloc(self) -> str | None:
|
||||
"""
|
||||
Network location including host and port.
|
||||
|
||||
If you need the equivalent of urllib.parse's ``netloc``,
|
||||
use the ``authority`` property instead.
|
||||
"""
|
||||
if self.host is None:
|
||||
return None
|
||||
if self.port:
|
||||
return f"{self.host}:{self.port}"
|
||||
return self.host
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
"""
|
||||
Convert self into a url
|
||||
|
||||
This function should more or less round-trip with :func:`.parse_url`. The
|
||||
returned url may not be exactly the same as the url inputted to
|
||||
:func:`.parse_url`, but it should be equivalent by the RFC (e.g., urls
|
||||
with a blank port will have : removed).
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import urllib3
|
||||
|
||||
U = urllib3.util.parse_url("https://google.com/mail/")
|
||||
|
||||
print(U.url)
|
||||
# "https://google.com/mail/"
|
||||
|
||||
print( urllib3.util.Url("https", "username:password",
|
||||
"host.com", 80, "/path", "query", "fragment"
|
||||
).url
|
||||
)
|
||||
# "https://username:password@host.com:80/path?query#fragment"
|
||||
"""
|
||||
scheme, auth, host, port, path, query, fragment = self
|
||||
url = ""
|
||||
|
||||
# We use "is not None" we want things to happen with empty strings (or 0 port)
|
||||
if scheme is not None:
|
||||
url += scheme + "://"
|
||||
if auth is not None:
|
||||
url += auth + "@"
|
||||
if host is not None:
|
||||
url += host
|
||||
if port is not None:
|
||||
url += ":" + str(port)
|
||||
if path is not None:
|
||||
url += path
|
||||
if query is not None:
|
||||
url += "?" + query
|
||||
if fragment is not None:
|
||||
url += "#" + fragment
|
||||
|
||||
return url
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.url
|
||||
|
||||
|
||||
@typing.overload
|
||||
def _encode_invalid_chars(
|
||||
component: str, allowed_chars: typing.Container[str]
|
||||
) -> str: # Abstract
|
||||
...
|
||||
|
||||
|
||||
@typing.overload
|
||||
def _encode_invalid_chars(
|
||||
component: None, allowed_chars: typing.Container[str]
|
||||
) -> None: # Abstract
|
||||
...
|
||||
|
||||
|
||||
def _encode_invalid_chars(
|
||||
component: str | None, allowed_chars: typing.Container[str]
|
||||
) -> str | None:
|
||||
"""Percent-encodes a URI component without reapplying
|
||||
onto an already percent-encoded component.
|
||||
"""
|
||||
if component is None:
|
||||
return component
|
||||
|
||||
component = to_str(component)
|
||||
|
||||
# Normalize existing percent-encoded bytes.
|
||||
# Try to see if the component we're encoding is already percent-encoded
|
||||
# so we can skip all '%' characters but still encode all others.
|
||||
component, percent_encodings = _PERCENT_RE.subn(
|
||||
lambda match: match.group(0).upper(), component
|
||||
)
|
||||
|
||||
uri_bytes = component.encode("utf-8", "surrogatepass")
|
||||
is_percent_encoded = percent_encodings == uri_bytes.count(b"%")
|
||||
encoded_component = bytearray()
|
||||
|
||||
for i in range(0, len(uri_bytes)):
|
||||
# Will return a single character bytestring
|
||||
byte = uri_bytes[i : i + 1]
|
||||
byte_ord = ord(byte)
|
||||
if (is_percent_encoded and byte == b"%") or (
|
||||
byte_ord < 128 and byte.decode() in allowed_chars
|
||||
):
|
||||
encoded_component += byte
|
||||
continue
|
||||
encoded_component.extend(b"%" + (hex(byte_ord)[2:].encode().zfill(2).upper()))
|
||||
|
||||
return encoded_component.decode()
|
||||
|
||||
|
||||
def _remove_path_dot_segments(path: str) -> str:
|
||||
# See http://tools.ietf.org/html/rfc3986#section-5.2.4 for pseudo-code
|
||||
segments = path.split("/") # Turn the path into a list of segments
|
||||
output = [] # Initialize the variable to use to store output
|
||||
|
||||
for segment in segments:
|
||||
# '.' is the current directory, so ignore it, it is superfluous
|
||||
if segment == ".":
|
||||
continue
|
||||
# Anything other than '..', should be appended to the output
|
||||
if segment != "..":
|
||||
output.append(segment)
|
||||
# In this case segment == '..', if we can, we should pop the last
|
||||
# element
|
||||
elif output:
|
||||
output.pop()
|
||||
|
||||
# If the path starts with '/' and the output is empty or the first string
|
||||
# is non-empty
|
||||
if path.startswith("/") and (not output or output[0]):
|
||||
output.insert(0, "")
|
||||
|
||||
# If the path starts with '/.' or '/..' ensure we add one more empty
|
||||
# string to add a trailing '/'
|
||||
if path.endswith(("/.", "/..")):
|
||||
output.append("")
|
||||
|
||||
return "/".join(output)
|
||||
|
||||
|
||||
@typing.overload
|
||||
def _normalize_host(host: None, scheme: str | None) -> None: ...
|
||||
|
||||
|
||||
@typing.overload
|
||||
def _normalize_host(host: str, scheme: str | None) -> str: ...
|
||||
|
||||
|
||||
def _normalize_host(host: str | None, scheme: str | None) -> str | None:
|
||||
if host:
|
||||
if scheme in _NORMALIZABLE_SCHEMES:
|
||||
is_ipv6 = _IPV6_ADDRZ_RE.match(host)
|
||||
if is_ipv6:
|
||||
# IPv6 hosts of the form 'a::b%zone' are encoded in a URL as
|
||||
# such per RFC 6874: 'a::b%25zone'. Unquote the ZoneID
|
||||
# separator as necessary to return a valid RFC 4007 scoped IP.
|
||||
match = _ZONE_ID_RE.search(host)
|
||||
if match:
|
||||
start, end = match.span(1)
|
||||
zone_id = host[start:end]
|
||||
|
||||
if zone_id.startswith("%25") and zone_id != "%25":
|
||||
zone_id = zone_id[3:]
|
||||
else:
|
||||
zone_id = zone_id[1:]
|
||||
zone_id = _encode_invalid_chars(zone_id, _UNRESERVED_CHARS)
|
||||
return f"{host[:start].lower()}%{zone_id}{host[end:]}"
|
||||
else:
|
||||
return host.lower()
|
||||
elif not _IPV4_RE.match(host):
|
||||
return to_str(
|
||||
b".".join([_idna_encode(label) for label in host.split(".")]),
|
||||
"ascii",
|
||||
)
|
||||
return host
|
||||
|
||||
|
||||
def _idna_encode(name: str) -> bytes:
|
||||
if not name.isascii():
|
||||
try:
|
||||
import idna
|
||||
except ImportError:
|
||||
raise LocationParseError(
|
||||
"Unable to parse URL without the 'idna' module"
|
||||
) from None
|
||||
|
||||
try:
|
||||
return idna.encode(name.lower(), strict=True, std3_rules=True)
|
||||
except idna.IDNAError:
|
||||
raise LocationParseError(
|
||||
f"Name '{name}' is not a valid IDNA label"
|
||||
) from None
|
||||
|
||||
return name.lower().encode("ascii")
|
||||
|
||||
|
||||
def _encode_target(target: str) -> str:
|
||||
"""Percent-encodes a request target so that there are no invalid characters
|
||||
|
||||
Pre-condition for this function is that 'target' must start with '/'.
|
||||
If that is the case then _TARGET_RE will always produce a match.
|
||||
"""
|
||||
match = _TARGET_RE.match(target)
|
||||
if not match: # Defensive:
|
||||
raise LocationParseError(f"{target!r} is not a valid request URI")
|
||||
|
||||
path, query = match.groups()
|
||||
encoded_target = _encode_invalid_chars(path, _PATH_CHARS)
|
||||
if query is not None:
|
||||
query = _encode_invalid_chars(query, _QUERY_CHARS)
|
||||
encoded_target += "?" + query
|
||||
return encoded_target
|
||||
|
||||
|
||||
def parse_url(url: str) -> Url:
|
||||
"""
|
||||
Given a url, return a parsed :class:`.Url` namedtuple. Best-effort is
|
||||
performed to parse incomplete urls. Fields not provided will be None.
|
||||
This parser is RFC 3986 and RFC 6874 compliant.
|
||||
|
||||
The parser logic and helper functions are based heavily on
|
||||
work done in the ``rfc3986`` module.
|
||||
|
||||
:param str url: URL to parse into a :class:`.Url` namedtuple.
|
||||
|
||||
Partly backwards-compatible with :mod:`urllib.parse`.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import urllib3
|
||||
|
||||
print( urllib3.util.parse_url('http://google.com/mail/'))
|
||||
# Url(scheme='http', host='google.com', port=None, path='/mail/', ...)
|
||||
|
||||
print( urllib3.util.parse_url('google.com:80'))
|
||||
# Url(scheme=None, host='google.com', port=80, path=None, ...)
|
||||
|
||||
print( urllib3.util.parse_url('/foo?bar'))
|
||||
# Url(scheme=None, host=None, port=None, path='/foo', query='bar', ...)
|
||||
"""
|
||||
if not url:
|
||||
# Empty
|
||||
return Url()
|
||||
|
||||
source_url = url
|
||||
if not _SCHEME_RE.search(url):
|
||||
url = "//" + url
|
||||
|
||||
scheme: str | None
|
||||
authority: str | None
|
||||
auth: str | None
|
||||
host: str | None
|
||||
port: str | None
|
||||
port_int: int | None
|
||||
path: str | None
|
||||
query: str | None
|
||||
fragment: str | None
|
||||
|
||||
try:
|
||||
scheme, authority, path, query, fragment = _URI_RE.match(url).groups() # type: ignore[union-attr]
|
||||
normalize_uri = scheme is None or scheme.lower() in _NORMALIZABLE_SCHEMES
|
||||
|
||||
if scheme:
|
||||
scheme = scheme.lower()
|
||||
|
||||
if authority:
|
||||
auth, _, host_port = authority.rpartition("@")
|
||||
auth = auth or None
|
||||
host, port = _HOST_PORT_RE.match(host_port).groups() # type: ignore[union-attr]
|
||||
if auth and normalize_uri:
|
||||
auth = _encode_invalid_chars(auth, _USERINFO_CHARS)
|
||||
if port == "":
|
||||
port = None
|
||||
else:
|
||||
auth, host, port = None, None, None
|
||||
|
||||
if port is not None:
|
||||
port_int = int(port)
|
||||
if not (0 <= port_int <= 65535):
|
||||
raise LocationParseError(url)
|
||||
else:
|
||||
port_int = None
|
||||
|
||||
host = _normalize_host(host, scheme)
|
||||
|
||||
if normalize_uri and path:
|
||||
path = _remove_path_dot_segments(path)
|
||||
path = _encode_invalid_chars(path, _PATH_CHARS)
|
||||
if normalize_uri and query:
|
||||
query = _encode_invalid_chars(query, _QUERY_CHARS)
|
||||
if normalize_uri and fragment:
|
||||
fragment = _encode_invalid_chars(fragment, _FRAGMENT_CHARS)
|
||||
|
||||
except (ValueError, AttributeError) as e:
|
||||
raise LocationParseError(source_url) from e
|
||||
|
||||
# For the sake of backwards compatibility we put empty
|
||||
# string values for path if there are any defined values
|
||||
# beyond the path in the URL.
|
||||
# TODO: Remove this when we break backwards compatibility.
|
||||
if not path:
|
||||
if query is not None or fragment is not None:
|
||||
path = ""
|
||||
else:
|
||||
path = None
|
||||
|
||||
return Url(
|
||||
scheme=scheme,
|
||||
auth=auth,
|
||||
host=host,
|
||||
port=port_int,
|
||||
path=path,
|
||||
query=query,
|
||||
fragment=fragment,
|
||||
)
|
42
venv/lib/python3.11/site-packages/urllib3/util/util.py
Normal file
42
venv/lib/python3.11/site-packages/urllib3/util/util.py
Normal file
@ -0,0 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from types import TracebackType
|
||||
|
||||
|
||||
def to_bytes(
|
||||
x: str | bytes, encoding: str | None = None, errors: str | None = None
|
||||
) -> bytes:
|
||||
if isinstance(x, bytes):
|
||||
return x
|
||||
elif not isinstance(x, str):
|
||||
raise TypeError(f"not expecting type {type(x).__name__}")
|
||||
if encoding or errors:
|
||||
return x.encode(encoding or "utf-8", errors=errors or "strict")
|
||||
return x.encode()
|
||||
|
||||
|
||||
def to_str(
|
||||
x: str | bytes, encoding: str | None = None, errors: str | None = None
|
||||
) -> str:
|
||||
if isinstance(x, str):
|
||||
return x
|
||||
elif not isinstance(x, bytes):
|
||||
raise TypeError(f"not expecting type {type(x).__name__}")
|
||||
if encoding or errors:
|
||||
return x.decode(encoding or "utf-8", errors=errors or "strict")
|
||||
return x.decode()
|
||||
|
||||
|
||||
def reraise(
|
||||
tp: type[BaseException] | None,
|
||||
value: BaseException,
|
||||
tb: TracebackType | None = None,
|
||||
) -> typing.NoReturn:
|
||||
try:
|
||||
if value.__traceback__ is not tb:
|
||||
raise value.with_traceback(tb)
|
||||
raise value
|
||||
finally:
|
||||
value = None # type: ignore[assignment]
|
||||
tb = None
|
124
venv/lib/python3.11/site-packages/urllib3/util/wait.py
Normal file
124
venv/lib/python3.11/site-packages/urllib3/util/wait.py
Normal file
@ -0,0 +1,124 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import select
|
||||
import socket
|
||||
from functools import partial
|
||||
|
||||
__all__ = ["wait_for_read", "wait_for_write"]
|
||||
|
||||
|
||||
# How should we wait on sockets?
|
||||
#
|
||||
# There are two types of APIs you can use for waiting on sockets: the fancy
|
||||
# modern stateful APIs like epoll/kqueue, and the older stateless APIs like
|
||||
# select/poll. The stateful APIs are more efficient when you have a lots of
|
||||
# sockets to keep track of, because you can set them up once and then use them
|
||||
# lots of times. But we only ever want to wait on a single socket at a time
|
||||
# and don't want to keep track of state, so the stateless APIs are actually
|
||||
# more efficient. So we want to use select() or poll().
|
||||
#
|
||||
# Now, how do we choose between select() and poll()? On traditional Unixes,
|
||||
# select() has a strange calling convention that makes it slow, or fail
|
||||
# altogether, for high-numbered file descriptors. The point of poll() is to fix
|
||||
# that, so on Unixes, we prefer poll().
|
||||
#
|
||||
# On Windows, there is no poll() (or at least Python doesn't provide a wrapper
|
||||
# for it), but that's OK, because on Windows, select() doesn't have this
|
||||
# strange calling convention; plain select() works fine.
|
||||
#
|
||||
# So: on Windows we use select(), and everywhere else we use poll(). We also
|
||||
# fall back to select() in case poll() is somehow broken or missing.
|
||||
|
||||
|
||||
def select_wait_for_socket(
|
||||
sock: socket.socket,
|
||||
read: bool = False,
|
||||
write: bool = False,
|
||||
timeout: float | None = None,
|
||||
) -> bool:
|
||||
if not read and not write:
|
||||
raise RuntimeError("must specify at least one of read=True, write=True")
|
||||
rcheck = []
|
||||
wcheck = []
|
||||
if read:
|
||||
rcheck.append(sock)
|
||||
if write:
|
||||
wcheck.append(sock)
|
||||
# When doing a non-blocking connect, most systems signal success by
|
||||
# marking the socket writable. Windows, though, signals success by marked
|
||||
# it as "exceptional". We paper over the difference by checking the write
|
||||
# sockets for both conditions. (The stdlib selectors module does the same
|
||||
# thing.)
|
||||
fn = partial(select.select, rcheck, wcheck, wcheck)
|
||||
rready, wready, xready = fn(timeout)
|
||||
return bool(rready or wready or xready)
|
||||
|
||||
|
||||
def poll_wait_for_socket(
|
||||
sock: socket.socket,
|
||||
read: bool = False,
|
||||
write: bool = False,
|
||||
timeout: float | None = None,
|
||||
) -> bool:
|
||||
if not read and not write:
|
||||
raise RuntimeError("must specify at least one of read=True, write=True")
|
||||
mask = 0
|
||||
if read:
|
||||
mask |= select.POLLIN
|
||||
if write:
|
||||
mask |= select.POLLOUT
|
||||
poll_obj = select.poll()
|
||||
poll_obj.register(sock, mask)
|
||||
|
||||
# For some reason, poll() takes timeout in milliseconds
|
||||
def do_poll(t: float | None) -> list[tuple[int, int]]:
|
||||
if t is not None:
|
||||
t *= 1000
|
||||
return poll_obj.poll(t)
|
||||
|
||||
return bool(do_poll(timeout))
|
||||
|
||||
|
||||
def _have_working_poll() -> bool:
|
||||
# Apparently some systems have a select.poll that fails as soon as you try
|
||||
# to use it, either due to strange configuration or broken monkeypatching
|
||||
# from libraries like eventlet/greenlet.
|
||||
try:
|
||||
poll_obj = select.poll()
|
||||
poll_obj.poll(0)
|
||||
except (AttributeError, OSError):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def wait_for_socket(
|
||||
sock: socket.socket,
|
||||
read: bool = False,
|
||||
write: bool = False,
|
||||
timeout: float | None = None,
|
||||
) -> bool:
|
||||
# We delay choosing which implementation to use until the first time we're
|
||||
# called. We could do it at import time, but then we might make the wrong
|
||||
# decision if someone goes wild with monkeypatching select.poll after
|
||||
# we're imported.
|
||||
global wait_for_socket
|
||||
if _have_working_poll():
|
||||
wait_for_socket = poll_wait_for_socket
|
||||
elif hasattr(select, "select"):
|
||||
wait_for_socket = select_wait_for_socket
|
||||
return wait_for_socket(sock, read, write, timeout)
|
||||
|
||||
|
||||
def wait_for_read(sock: socket.socket, timeout: float | None = None) -> bool:
|
||||
"""Waits for reading to be available on a given socket.
|
||||
Returns True if the socket is readable, or False if the timeout expired.
|
||||
"""
|
||||
return wait_for_socket(sock, read=True, timeout=timeout)
|
||||
|
||||
|
||||
def wait_for_write(sock: socket.socket, timeout: float | None = None) -> bool:
|
||||
"""Waits for writing to be available on a given socket.
|
||||
Returns True if the socket is readable, or False if the timeout expired.
|
||||
"""
|
||||
return wait_for_socket(sock, write=True, timeout=timeout)
|
Reference in New Issue
Block a user