Update 2025-04-24_11:44:19
This commit is contained in:
271
venv/lib/python3.11/site-packages/urllib3/util/ssltransport.py
Normal file
271
venv/lib/python3.11/site-packages/urllib3/util/ssltransport.py
Normal file
@ -0,0 +1,271 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import socket
|
||||
import ssl
|
||||
import typing
|
||||
|
||||
from ..exceptions import ProxySchemeUnsupported
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from typing_extensions import Self
|
||||
|
||||
from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT
|
||||
|
||||
|
||||
_WriteBuffer = typing.Union[bytearray, memoryview]
|
||||
_ReturnValue = typing.TypeVar("_ReturnValue")
|
||||
|
||||
SSL_BLOCKSIZE = 16384
|
||||
|
||||
|
||||
class SSLTransport:
|
||||
"""
|
||||
The SSLTransport wraps an existing socket and establishes an SSL connection.
|
||||
|
||||
Contrary to Python's implementation of SSLSocket, it allows you to chain
|
||||
multiple TLS connections together. It's particularly useful if you need to
|
||||
implement TLS within TLS.
|
||||
|
||||
The class supports most of the socket API operations.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _validate_ssl_context_for_tls_in_tls(ssl_context: ssl.SSLContext) -> None:
|
||||
"""
|
||||
Raises a ProxySchemeUnsupported if the provided ssl_context can't be used
|
||||
for TLS in TLS.
|
||||
|
||||
The only requirement is that the ssl_context provides the 'wrap_bio'
|
||||
methods.
|
||||
"""
|
||||
|
||||
if not hasattr(ssl_context, "wrap_bio"):
|
||||
raise ProxySchemeUnsupported(
|
||||
"TLS in TLS requires SSLContext.wrap_bio() which isn't "
|
||||
"available on non-native SSLContext"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
socket: socket.socket,
|
||||
ssl_context: ssl.SSLContext,
|
||||
server_hostname: str | None = None,
|
||||
suppress_ragged_eofs: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Create an SSLTransport around socket using the provided ssl_context.
|
||||
"""
|
||||
self.incoming = ssl.MemoryBIO()
|
||||
self.outgoing = ssl.MemoryBIO()
|
||||
|
||||
self.suppress_ragged_eofs = suppress_ragged_eofs
|
||||
self.socket = socket
|
||||
|
||||
self.sslobj = ssl_context.wrap_bio(
|
||||
self.incoming, self.outgoing, server_hostname=server_hostname
|
||||
)
|
||||
|
||||
# Perform initial handshake.
|
||||
self._ssl_io_loop(self.sslobj.do_handshake)
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __exit__(self, *_: typing.Any) -> None:
|
||||
self.close()
|
||||
|
||||
def fileno(self) -> int:
|
||||
return self.socket.fileno()
|
||||
|
||||
def read(self, len: int = 1024, buffer: typing.Any | None = None) -> int | bytes:
|
||||
return self._wrap_ssl_read(len, buffer)
|
||||
|
||||
def recv(self, buflen: int = 1024, flags: int = 0) -> int | bytes:
|
||||
if flags != 0:
|
||||
raise ValueError("non-zero flags not allowed in calls to recv")
|
||||
return self._wrap_ssl_read(buflen)
|
||||
|
||||
def recv_into(
|
||||
self,
|
||||
buffer: _WriteBuffer,
|
||||
nbytes: int | None = None,
|
||||
flags: int = 0,
|
||||
) -> None | int | bytes:
|
||||
if flags != 0:
|
||||
raise ValueError("non-zero flags not allowed in calls to recv_into")
|
||||
if nbytes is None:
|
||||
nbytes = len(buffer)
|
||||
return self.read(nbytes, buffer)
|
||||
|
||||
def sendall(self, data: bytes, flags: int = 0) -> None:
|
||||
if flags != 0:
|
||||
raise ValueError("non-zero flags not allowed in calls to sendall")
|
||||
count = 0
|
||||
with memoryview(data) as view, view.cast("B") as byte_view:
|
||||
amount = len(byte_view)
|
||||
while count < amount:
|
||||
v = self.send(byte_view[count:])
|
||||
count += v
|
||||
|
||||
def send(self, data: bytes, flags: int = 0) -> int:
|
||||
if flags != 0:
|
||||
raise ValueError("non-zero flags not allowed in calls to send")
|
||||
return self._ssl_io_loop(self.sslobj.write, data)
|
||||
|
||||
def makefile(
|
||||
self,
|
||||
mode: str,
|
||||
buffering: int | None = None,
|
||||
*,
|
||||
encoding: str | None = None,
|
||||
errors: str | None = None,
|
||||
newline: str | None = None,
|
||||
) -> typing.BinaryIO | typing.TextIO | socket.SocketIO:
|
||||
"""
|
||||
Python's httpclient uses makefile and buffered io when reading HTTP
|
||||
messages and we need to support it.
|
||||
|
||||
This is unfortunately a copy and paste of socket.py makefile with small
|
||||
changes to point to the socket directly.
|
||||
"""
|
||||
if not set(mode) <= {"r", "w", "b"}:
|
||||
raise ValueError(f"invalid mode {mode!r} (only r, w, b allowed)")
|
||||
|
||||
writing = "w" in mode
|
||||
reading = "r" in mode or not writing
|
||||
assert reading or writing
|
||||
binary = "b" in mode
|
||||
rawmode = ""
|
||||
if reading:
|
||||
rawmode += "r"
|
||||
if writing:
|
||||
rawmode += "w"
|
||||
raw = socket.SocketIO(self, rawmode) # type: ignore[arg-type]
|
||||
self.socket._io_refs += 1 # type: ignore[attr-defined]
|
||||
if buffering is None:
|
||||
buffering = -1
|
||||
if buffering < 0:
|
||||
buffering = io.DEFAULT_BUFFER_SIZE
|
||||
if buffering == 0:
|
||||
if not binary:
|
||||
raise ValueError("unbuffered streams must be binary")
|
||||
return raw
|
||||
buffer: typing.BinaryIO
|
||||
if reading and writing:
|
||||
buffer = io.BufferedRWPair(raw, raw, buffering) # type: ignore[assignment]
|
||||
elif reading:
|
||||
buffer = io.BufferedReader(raw, buffering)
|
||||
else:
|
||||
assert writing
|
||||
buffer = io.BufferedWriter(raw, buffering)
|
||||
if binary:
|
||||
return buffer
|
||||
text = io.TextIOWrapper(buffer, encoding, errors, newline)
|
||||
text.mode = mode # type: ignore[misc]
|
||||
return text
|
||||
|
||||
def unwrap(self) -> None:
|
||||
self._ssl_io_loop(self.sslobj.unwrap)
|
||||
|
||||
def close(self) -> None:
|
||||
self.socket.close()
|
||||
|
||||
@typing.overload
|
||||
def getpeercert(
|
||||
self, binary_form: typing.Literal[False] = ...
|
||||
) -> _TYPE_PEER_CERT_RET_DICT | None: ...
|
||||
|
||||
@typing.overload
|
||||
def getpeercert(self, binary_form: typing.Literal[True]) -> bytes | None: ...
|
||||
|
||||
def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET:
|
||||
return self.sslobj.getpeercert(binary_form) # type: ignore[return-value]
|
||||
|
||||
def version(self) -> str | None:
|
||||
return self.sslobj.version()
|
||||
|
||||
def cipher(self) -> tuple[str, str, int] | None:
|
||||
return self.sslobj.cipher()
|
||||
|
||||
def selected_alpn_protocol(self) -> str | None:
|
||||
return self.sslobj.selected_alpn_protocol()
|
||||
|
||||
def shared_ciphers(self) -> list[tuple[str, str, int]] | None:
|
||||
return self.sslobj.shared_ciphers()
|
||||
|
||||
def compression(self) -> str | None:
|
||||
return self.sslobj.compression()
|
||||
|
||||
def settimeout(self, value: float | None) -> None:
|
||||
self.socket.settimeout(value)
|
||||
|
||||
def gettimeout(self) -> float | None:
|
||||
return self.socket.gettimeout()
|
||||
|
||||
def _decref_socketios(self) -> None:
|
||||
self.socket._decref_socketios() # type: ignore[attr-defined]
|
||||
|
||||
def _wrap_ssl_read(self, len: int, buffer: bytearray | None = None) -> int | bytes:
|
||||
try:
|
||||
return self._ssl_io_loop(self.sslobj.read, len, buffer)
|
||||
except ssl.SSLError as e:
|
||||
if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs:
|
||||
return 0 # eof, return 0.
|
||||
else:
|
||||
raise
|
||||
|
||||
# func is sslobj.do_handshake or sslobj.unwrap
|
||||
@typing.overload
|
||||
def _ssl_io_loop(self, func: typing.Callable[[], None]) -> None: ...
|
||||
|
||||
# func is sslobj.write, arg1 is data
|
||||
@typing.overload
|
||||
def _ssl_io_loop(self, func: typing.Callable[[bytes], int], arg1: bytes) -> int: ...
|
||||
|
||||
# func is sslobj.read, arg1 is len, arg2 is buffer
|
||||
@typing.overload
|
||||
def _ssl_io_loop(
|
||||
self,
|
||||
func: typing.Callable[[int, bytearray | None], bytes],
|
||||
arg1: int,
|
||||
arg2: bytearray | None,
|
||||
) -> bytes: ...
|
||||
|
||||
def _ssl_io_loop(
|
||||
self,
|
||||
func: typing.Callable[..., _ReturnValue],
|
||||
arg1: None | bytes | int = None,
|
||||
arg2: bytearray | None = None,
|
||||
) -> _ReturnValue:
|
||||
"""Performs an I/O loop between incoming/outgoing and the socket."""
|
||||
should_loop = True
|
||||
ret = None
|
||||
|
||||
while should_loop:
|
||||
errno = None
|
||||
try:
|
||||
if arg1 is None and arg2 is None:
|
||||
ret = func()
|
||||
elif arg2 is None:
|
||||
ret = func(arg1)
|
||||
else:
|
||||
ret = func(arg1, arg2)
|
||||
except ssl.SSLError as e:
|
||||
if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
|
||||
# WANT_READ, and WANT_WRITE are expected, others are not.
|
||||
raise e
|
||||
errno = e.errno
|
||||
|
||||
buf = self.outgoing.read()
|
||||
self.socket.sendall(buf)
|
||||
|
||||
if errno is None:
|
||||
should_loop = False
|
||||
elif errno == ssl.SSL_ERROR_WANT_READ:
|
||||
buf = self.socket.recv(SSL_BLOCKSIZE)
|
||||
if buf:
|
||||
self.incoming.write(buf)
|
||||
else:
|
||||
self.incoming.write_eof()
|
||||
return typing.cast(_ReturnValue, ret)
|
Reference in New Issue
Block a user