Update 2025-04-13_16:26:04

This commit is contained in:
root
2025-04-13 16:26:06 +02:00
commit f5d5898dc4
2312 changed files with 422700 additions and 0 deletions

View File

@ -0,0 +1 @@
92b12bc045050b55b848d37167a1a63947c364579889ce1d39788e45e9fac9e5

View File

@ -0,0 +1,101 @@
from typing import cast, List, Type, Union, ValuesView
from .._connection import Connection, NEED_DATA, PAUSED
from .._events import (
ConnectionClosed,
Data,
EndOfMessage,
Event,
InformationalResponse,
Request,
Response,
)
from .._state import CLIENT, CLOSED, DONE, MUST_CLOSE, SERVER
from .._util import Sentinel
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal # type: ignore
def get_all_events(conn: Connection) -> List[Event]:
    """Drain ``conn`` and return every event it currently has to offer.

    Polls ``conn.next_event()`` until it yields ``NEED_DATA`` or ``PAUSED``
    (neither of which is included in the result), or until a
    ``ConnectionClosed`` event has been collected.
    """
    collected = []
    stop_markers = (NEED_DATA, PAUSED)
    while True:
        item = conn.next_event()
        if item in stop_markers:
            return collected
        # Anything that is not one of the sentinels is treated as an Event.
        collected.append(cast(Event, item))
        if type(item) is ConnectionClosed:
            return collected
def receive_and_get(conn: Connection, data: bytes) -> List[Event]:
    """Feed ``data`` to ``conn`` and return the events that become available."""
    conn.receive_data(data)
    events = get_all_events(conn)
    return events
# Merges adjacent Data events, converts payloads to bytestrings, and removes
# chunk boundaries.
def normalize_data_events(in_events: List[Event]) -> List[Event]:
    """Return a copy of ``in_events`` with runs of ``Data`` events coalesced.

    Every ``Data`` payload is converted to ``bytes`` and rebuilt with
    ``chunk_start``/``chunk_end`` set to False; events of any other type are
    passed through unchanged and act as separators between runs.
    """
    merged: List[Event] = []
    for item in in_events:
        if type(item) is not Data:
            merged.append(item)
            continue
        flat = Data(data=bytes(item.data), chunk_start=False, chunk_end=False)
        if merged and type(merged[-1]) is Data:
            # Extend the run in progress by replacing its last element.
            previous = merged[-1]
            merged[-1] = Data(
                data=previous.data + flat.data,
                chunk_start=previous.chunk_start,
                chunk_end=previous.chunk_end,
            )
        else:
            merged.append(flat)
    return merged
# Given that we want to write tests that push some events through a Connection
# and check that its state updates appropriately... we might as well make a
# habit of pushing them through two Connections with a fake network link in
# between.
class ConnectionPair:
    """A CLIENT and a SERVER ``Connection`` joined by an in-memory link."""

    def __init__(self) -> None:
        # conn maps a role to its Connection; other maps a role to its peer.
        self.conn = {CLIENT: Connection(CLIENT), SERVER: Connection(SERVER)}
        self.other = {CLIENT: SERVER, SERVER: CLIENT}

    @property
    def conns(self) -> ValuesView[Connection]:
        """Both connections, as a view over the ``conn`` mapping."""
        return self.conn.values()

    # expect="match" if expect=send_events; expect=[...] to say what expected
    def send(
        self,
        role: Type[Sentinel],
        send_events: Union[List[Event], Event],
        expect: Union[List[Event], Event, Literal["match"]] = "match",
    ) -> bytes:
        """Send events from ``role`` and assert what the peer receives.

        Returns the bytes produced by the sending side. Asserts that the
        events read from the peer equal ``expect`` (or ``send_events`` itself
        when ``expect`` is ``"match"``).
        """
        if not isinstance(send_events, list):
            send_events = [send_events]
        sender = self.conn[role]
        receiver = self.conn[self.other[role]]

        data = b""
        closed = False
        for outgoing in send_events:
            chunk = sender.send(outgoing)
            if chunk is None:
                closed = True
            else:
                data += chunk

        # send uses b"" to mean b"", and None to mean closed
        # receive uses b"" to mean closed, and None to mean "try again"
        # so we have to translate between the two conventions
        if data:
            receiver.receive_data(data)
        if closed:
            receiver.receive_data(b"")

        got_events = get_all_events(receiver)
        if expect == "match":
            expect = send_events
        if not isinstance(expect, list):
            expect = [expect]
        assert got_events == expect
        return data

View File

@ -0,0 +1,115 @@
import json
import os.path
import socket
import socketserver
import threading
from contextlib import closing, contextmanager
from http.server import SimpleHTTPRequestHandler
from typing import Callable, Generator
from urllib.request import urlopen
import h11
@contextmanager
def socket_server(
    handler: Callable[..., socketserver.BaseRequestHandler]
) -> Generator[socketserver.TCPServer, None, None]:
    """Run a ``TCPServer`` on an ephemeral loopback port for the ``with`` body.

    The server is bound to ``127.0.0.1`` on port 0 and served from a daemon
    thread. On exit the serve loop is stopped, the thread is joined, and the
    listening socket is closed so that no socket or thread outlives the
    context.

    Args:
        handler: Request handler class (or factory) passed to ``TCPServer``.

    Yields:
        The running ``socketserver.TCPServer``.
    """
    # TCPServer is itself a context manager whose exit calls server_close(),
    # which releases the listening socket.
    with socketserver.TCPServer(("127.0.0.1", 0), handler) as httpd:
        thread = threading.Thread(
            target=httpd.serve_forever,
            kwargs={"poll_interval": 0.01},
            daemon=True,
        )
        # Start before entering the try: shutdown() waits for the serve loop,
        # so it must only be called once the thread has actually been started.
        thread.start()
        try:
            yield httpd
        finally:
            httpd.shutdown()
            thread.join()
# Path of the fixture file, resolved relative to this module, and its raw
# contents read once at import time.
test_file_path = os.path.join(os.path.dirname(__file__), "data/test-file")
with open(test_file_path, "rb") as f:
    test_file_data = f.read()
class SingleMindedRequestHandler(SimpleHTTPRequestHandler):
    """Static-file handler that maps every request path to the fixture file."""

    def translate_path(self, path: str) -> str:
        # The requested path is ignored; the fixture path is always returned.
        return test_file_path
def test_h11_as_client() -> None:
    """Fetch the fixture file from a stdlib HTTP server with h11 as client."""
    with socket_server(SingleMindedRequestHandler) as httpd:
        with closing(socket.create_connection(httpd.server_address)) as sock:
            client = h11.Connection(h11.CLIENT)

            sock.sendall(
                client.send(  # type: ignore[arg-type]
                    h11.Request(
                        method="GET", target="/foo", headers=[("Host", "localhost")]
                    )
                )
            )
            sock.sendall(client.send(h11.EndOfMessage()))  # type: ignore[arg-type]

            body = bytearray()
            while True:
                event = client.next_event()
                print(event)
                if event is h11.NEED_DATA:
                    # Use a small read buffer to make things more challenging
                    # and exercise more paths :-)
                    client.receive_data(sock.recv(10))
                    continue
                kind = type(event)
                if kind is h11.Response:
                    assert event.status_code == 200
                elif kind is h11.Data:
                    body += event.data
                elif kind is h11.EndOfMessage:
                    break
            assert bytes(body) == test_file_data
class H11RequestHandler(socketserver.BaseRequestHandler):
    """Handler that reads one request via h11 and echoes it back as JSON."""

    def handle(self) -> None:
        """Read a full request, then reply 200 with its method/target/headers."""
        with closing(self.request) as sock:
            server = h11.Connection(h11.SERVER)
            request = None
            while True:
                event = server.next_event()
                if event is h11.NEED_DATA:
                    # Use a small read buffer to make things more challenging
                    # and exercise more paths :-)
                    server.receive_data(sock.recv(10))
                    continue
                if type(event) is h11.Request:
                    request = event
                if type(event) is h11.EndOfMessage:
                    break
            assert request is not None

            # Everything is decoded to str so it can be serialized as JSON.
            summary = {
                "method": request.method.decode("ascii"),
                "target": request.target.decode("ascii"),
            }
            summary["headers"] = {
                name.decode("ascii"): value.decode("ascii")
                for (name, value) in request.headers
            }
            info = json.dumps(summary)

            sock.sendall(server.send(h11.Response(status_code=200, headers=[])))  # type: ignore[arg-type]
            sock.sendall(server.send(h11.Data(data=info.encode("ascii"))))
            sock.sendall(server.send(h11.EndOfMessage()))
def test_h11_as_server() -> None:
    """Query the h11-based handler with urllib and check the echoed request."""
    with socket_server(H11RequestHandler) as httpd:
        host, port = httpd.server_address
        url = f"http://{host}:{port}/some-path"
        with closing(urlopen(url)) as response:
            assert response.getcode() == 200
            raw = response.read()

    info = json.loads(raw.decode("ascii"))
    print(info)
    assert info["method"] == "GET"
    assert info["target"] == "/some-path"
    assert "urllib" in info["headers"]["user-agent"]

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,150 @@
from http import HTTPStatus
import pytest
from .. import _events
from .._events import (
ConnectionClosed,
Data,
EndOfMessage,
Event,
InformationalResponse,
Request,
Response,
)
from .._util import LocalProtocolError
def test_events() -> None:
    """Check validation and normalization performed when constructing events."""
    with pytest.raises(LocalProtocolError):
        # Missing Host:
        Request(method="GET", target="/", headers=[("a", "b")], http_version="1.1")
    # But this is okay (HTTP/1.0)
    req = Request(method="GET", target="/", headers=[("a", "b")], http_version="1.0")
    # fields are normalized
    assert req.method == b"GET"
    assert req.target == b"/"
    assert req.headers == [(b"a", b"b")]
    assert req.http_version == b"1.0"

    # This is also okay -- has a Host (with weird capitalization, which is ok)
    req = Request(
        method="GET",
        target="/",
        headers=[("a", "b"), ("hOSt", "example.com")],
        http_version="1.1",
    )
    # we normalize header capitalization
    assert req.headers == [(b"a", b"b"), (b"host", b"example.com")]

    # Multiple host is bad too, even for HTTP/1.0
    for version in ("1.1", "1.0"):
        with pytest.raises(LocalProtocolError):
            Request(
                method="GET",
                target="/",
                headers=[("Host", "a"), ("Host", "a")],
                http_version=version,
            )

    # Header values are validated
    for bad_char in "\x00\r\n\f\v":
        with pytest.raises(LocalProtocolError):
            Request(
                method="GET",
                target="/",
                headers=[("Host", "a"), ("Foo", "asd" + bad_char)],
                http_version="1.0",
            )

    # But for compatibility we allow non-whitespace control characters, even
    # though they're forbidden by the spec.
    Request(
        method="GET",
        target="/",
        headers=[("Host", "a"), ("Foo", "asd\x01\x02\x7f")],
        http_version="1.0",
    )

    # Request target is validated
    for bad_byte in b"\x00\x20\x7f\xee":
        target = bytearray(b"/")
        target.append(bad_byte)
        with pytest.raises(LocalProtocolError):
            Request(
                method="GET", target=target, headers=[("Host", "a")], http_version="1.1"
            )

    # Request method is validated. A known-good target is used here so that
    # the bad method is the only thing that can trigger the error.
    with pytest.raises(LocalProtocolError):
        Request(
            method="GET / HTTP/1.1",
            target="/",
            headers=[("Host", "a")],
            http_version="1.1",
        )

    ir = InformationalResponse(status_code=100, headers=[("Host", "a")])
    assert ir.status_code == 100
    assert ir.headers == [(b"host", b"a")]
    assert ir.http_version == b"1.1"

    with pytest.raises(LocalProtocolError):
        InformationalResponse(status_code=200, headers=[("Host", "a")])

    resp = Response(status_code=204, headers=[], http_version="1.0")  # type: ignore[arg-type]
    assert resp.status_code == 204
    assert resp.headers == []
    assert resp.http_version == b"1.0"

    with pytest.raises(LocalProtocolError):
        Response(status_code=100, headers=[], http_version="1.0")  # type: ignore[arg-type]

    # Status codes given as str or bytes are rejected
    with pytest.raises(LocalProtocolError):
        Response(status_code="100", headers=[], http_version="1.0")  # type: ignore[arg-type]

    with pytest.raises(LocalProtocolError):
        InformationalResponse(status_code=b"100", headers=[], http_version="1.0")  # type: ignore[arg-type]

    d = Data(data=b"asdf")
    assert d.data == b"asdf"

    eom = EndOfMessage()
    assert eom.headers == []

    cc = ConnectionClosed()
    assert repr(cc) == "ConnectionClosed()"
def test_intenum_status_code() -> None:
    """An ``HTTPStatus`` status code is stored as a plain ``int``."""
    # https://github.com/python-hyper/h11/issues/72
    response = Response(status_code=HTTPStatus.OK, headers=[], http_version="1.0")  # type: ignore[arg-type]
    code = response.status_code
    assert code == HTTPStatus.OK
    assert type(code) is not type(HTTPStatus.OK)
    assert type(code) is int
def test_header_casing() -> None:
    """Headers compare lowercased while ``raw_items()`` keeps original casing."""
    request = Request(
        method="GET",
        target="/",
        headers=[("Host", "example.org"), ("Connection", "keep-alive")],
        http_version="1.1",
    )
    headers = request.headers
    assert len(headers) == 2
    assert headers[0] == (b"host", b"example.org")
    assert headers == [(b"host", b"example.org"), (b"connection", b"keep-alive")]
    assert headers.raw_items() == [
        (b"Host", b"example.org"),
        (b"Connection", b"keep-alive"),
    ]

View File

@ -0,0 +1,157 @@
import pytest
from .._events import Request
from .._headers import (
get_comma_header,
has_expect_100_continue,
Headers,
normalize_and_validate,
set_comma_header,
)
from .._util import LocalProtocolError
def test_normalize_and_validate() -> None:
    """Exercise header name/value validation and special-cased headers."""
    assert normalize_and_validate([("foo", "bar")]) == [(b"foo", b"bar")]
    assert normalize_and_validate([(b"foo", b"bar")]) == [(b"foo", b"bar")]

    # no leading/trailing whitespace in names
    for padded_name in (b"foo ", b" foo"):
        with pytest.raises(LocalProtocolError):
            normalize_and_validate([(padded_name, "bar")])

    # no weird characters in names
    with pytest.raises(LocalProtocolError) as excinfo:
        normalize_and_validate([(b"foo bar", b"baz")])
    assert "foo bar" in str(excinfo.value)
    # In order: NUL, an 8-bit character, and a control character of the kind
    # we allow in values.
    for weird_name in (b"foo\x00bar", b"foo\xffbar", b"foo\x01bar"):
        with pytest.raises(LocalProtocolError):
            normalize_and_validate([(weird_name, b"baz")])

    # no return or NUL characters in values
    with pytest.raises(LocalProtocolError) as excinfo:
        normalize_and_validate([("foo", "bar\rbaz")])
    assert "bar\\rbaz" in str(excinfo.value)
    for bad_value in ("bar\nbaz", "bar\x00baz"):
        with pytest.raises(LocalProtocolError):
            normalize_and_validate([("foo", bad_value)])

    # no leading/trailing whitespace
    for padded_value in ("barbaz ", " barbaz", "barbaz\t", "\tbarbaz"):
        with pytest.raises(LocalProtocolError):
            normalize_and_validate([("foo", padded_value)])

    # content-length
    assert normalize_and_validate([("Content-Length", "1")]) == [
        (b"content-length", b"1")
    ]
    for bad_lengths in (
        [("Content-Length", "asdf")],
        [("Content-Length", "1x")],
        [("Content-Length", "1"), ("Content-Length", "2")],
    ):
        with pytest.raises(LocalProtocolError):
            normalize_and_validate(bad_lengths)
    # Repeated but identical lengths collapse into a single header
    assert normalize_and_validate(
        [("Content-Length", "0"), ("Content-Length", "0")]
    ) == [(b"content-length", b"0")]
    assert normalize_and_validate([("Content-Length", "0 , 0")]) == [
        (b"content-length", b"0")
    ]
    for conflicting_lengths in (
        [("Content-Length", "1"), ("Content-Length", "1"), ("Content-Length", "2")],
        [("Content-Length", "1 , 1,2")],
    ):
        with pytest.raises(LocalProtocolError):
            normalize_and_validate(conflicting_lengths)

    # transfer-encoding
    assert normalize_and_validate([("Transfer-Encoding", "chunked")]) == [
        (b"transfer-encoding", b"chunked")
    ]
    assert normalize_and_validate([("Transfer-Encoding", "cHuNkEd")]) == [
        (b"transfer-encoding", b"chunked")
    ]
    with pytest.raises(LocalProtocolError) as excinfo:
        normalize_and_validate([("Transfer-Encoding", "gzip")])
    assert excinfo.value.error_status_hint == 501  # Not Implemented
    with pytest.raises(LocalProtocolError) as excinfo:
        normalize_and_validate(
            [("Transfer-Encoding", "chunked"), ("Transfer-Encoding", "gzip")]
        )
    assert excinfo.value.error_status_hint == 501  # Not Implemented
def test_get_set_comma_header() -> None:
    """Round-trip comma-separated headers through the get/set helpers."""
    headers = normalize_and_validate(
        [
            ("Connection", "close"),
            ("whatever", "something"),
            ("connectiON", "fOo,, , BAR"),
        ]
    )

    # Values from every matching header are split, lowercased, and empty
    # items are dropped.
    connection_values = get_comma_header(headers, b"connection")
    assert connection_values == [b"close", b"foo", b"bar"]

    headers = set_comma_header(headers, b"newthing", ["a", "b"])  # type: ignore

    # A value with leading whitespace is rejected
    with pytest.raises(LocalProtocolError):
        set_comma_header(headers, b"newthing", [" a", "b"])  # type: ignore

    assert headers == [
        (b"connection", b"close"),
        (b"whatever", b"something"),
        (b"connection", b"fOo,, , BAR"),
        (b"newthing", b"a"),
        (b"newthing", b"b"),
    ]

    # Setting an existing name removes the old entry and appends the new one
    headers = set_comma_header(headers, b"whatever", ["different thing"])  # type: ignore

    assert headers == [
        (b"connection", b"close"),
        (b"connection", b"fOo,, , BAR"),
        (b"newthing", b"a"),
        (b"newthing", b"b"),
        (b"whatever", b"different thing"),
    ]
def test_has_100_continue() -> None:
    """``Expect: 100-continue`` is detected case-insensitively, HTTP/1.1 only."""
    with_expect = Request(
        method="GET",
        target="/",
        headers=[("Host", "example.com"), ("Expect", "100-continue")],
    )
    assert has_expect_100_continue(with_expect)

    without_expect = Request(
        method="GET", target="/", headers=[("Host", "example.com")]
    )
    assert not has_expect_100_continue(without_expect)

    # Case insensitive
    mixed_case = Request(
        method="GET",
        target="/",
        headers=[("Host", "example.com"), ("Expect", "100-Continue")],
    )
    assert has_expect_100_continue(mixed_case)

    # Doesn't work in HTTP/1.0
    http10 = Request(
        method="GET",
        target="/",
        headers=[("Host", "example.com"), ("Expect", "100-continue")],
        http_version="1.0",
    )
    assert not has_expect_100_continue(http10)

View File

@ -0,0 +1,32 @@
from .._events import (
ConnectionClosed,
Data,
EndOfMessage,
Event,
InformationalResponse,
Request,
Response,
)
from .helpers import normalize_data_events
def test_normalize_data_events() -> None:
    """Adjacent ``Data`` events merge; other events separate the runs."""
    raw_events = [
        Data(data=bytearray(b"1")),
        Data(data=b"2"),
        Response(status_code=200, headers=[]),  # type: ignore[arg-type]
        Data(data=b"3"),
        Data(data=b"4"),
        EndOfMessage(),
        Data(data=b"5"),
        Data(data=b"6"),
        Data(data=b"7"),
    ]
    normalized = normalize_data_events(raw_events)
    assert normalized == [
        Data(data=b"12"),
        Response(status_code=200, headers=[]),  # type: ignore[arg-type]
        Data(data=b"34"),
        EndOfMessage(),
        Data(data=b"567"),
    ]

View File

@ -0,0 +1,572 @@
from typing import Any, Callable, Generator, List
import pytest
from .._events import (
ConnectionClosed,
Data,
EndOfMessage,
Event,
InformationalResponse,
Request,
Response,
)
from .._headers import Headers, normalize_and_validate
from .._readers import (
_obsolete_line_fold,
ChunkedReader,
ContentLengthReader,
Http10Reader,
READERS,
)
from .._receivebuffer import ReceiveBuffer
from .._state import (
CLIENT,
CLOSED,
DONE,
IDLE,
MIGHT_SWITCH_PROTOCOL,
MUST_CLOSE,
SEND_BODY,
SEND_RESPONSE,
SERVER,
SWITCHED_PROTOCOL,
)
from .._util import LocalProtocolError
from .._writers import (
ChunkedWriter,
ContentLengthWriter,
Http10Writer,
write_any_response,
write_headers,
write_request,
WRITERS,
)
from .helpers import normalize_data_events
# Each entry is ((role, state), event, wire bytes). The same table is used in
# both directions: the writer tests check event -> bytes, and the reader
# tests check bytes -> event.
SIMPLE_CASES = [
    (
        (CLIENT, IDLE),
        Request(
            method="GET",
            target="/a",
            headers=[("Host", "foo"), ("Connection", "close")],
        ),
        b"GET /a HTTP/1.1\r\nHost: foo\r\nConnection: close\r\n\r\n",
    ),
    (
        (SERVER, SEND_RESPONSE),
        Response(status_code=200, headers=[("Connection", "close")], reason=b"OK"),
        b"HTTP/1.1 200 OK\r\nConnection: close\r\n\r\n",
    ),
    (
        (SERVER, SEND_RESPONSE),
        Response(status_code=200, headers=[], reason=b"OK"),  # type: ignore[arg-type]
        b"HTTP/1.1 200 OK\r\n\r\n",
    ),
    (
        (SERVER, SEND_RESPONSE),
        InformationalResponse(
            status_code=101, headers=[("Upgrade", "websocket")], reason=b"Upgrade"
        ),
        b"HTTP/1.1 101 Upgrade\r\nUpgrade: websocket\r\n\r\n",
    ),
    (
        (SERVER, SEND_RESPONSE),
        InformationalResponse(status_code=101, headers=[], reason=b"Upgrade"),  # type: ignore[arg-type]
        b"HTTP/1.1 101 Upgrade\r\n\r\n",
    ),
]
def dowrite(writer: Callable[..., None], obj: Any) -> bytes:
    """Run ``writer`` on ``obj`` and return everything it wrote, concatenated.

    ``writer`` is called as ``writer(obj, write)`` where ``write`` accepts one
    bytes chunk per call.
    """
    chunks: List[bytes] = []
    writer(obj, chunks.append)
    return b"".join(chunks)
def tw(writer: Any, obj: Any, expected: Any) -> None:
    """Assert that ``writer`` serializes ``obj`` to exactly ``expected``."""
    produced = dowrite(writer, obj)
    assert produced == expected
def makebuf(data: bytes) -> ReceiveBuffer:
    """Return a fresh ``ReceiveBuffer`` pre-loaded with ``data``."""
    loaded = ReceiveBuffer()
    loaded += data
    return loaded
def tr(reader: Any, data: bytes, expected: Any) -> None:
    """Check that ``reader`` parses ``data`` into ``expected`` in three ways.

    The reader is run against the whole input at once, against a buffer that
    grows one byte at a time, and against the input followed by trailing
    bytes that must be left in the buffer.
    """

    def check(got: Any) -> None:
        assert got == expected
        # Headers should always be returned as bytes, not e.g. bytearray
        # https://github.com/python-hyper/wsproto/pull/54#issuecomment-377709478
        for name, value in getattr(got, "headers", []):
            assert type(name) is bytes
            assert type(value) is bytes

    # Simple: consume whole thing
    whole = makebuf(data)
    check(reader(whole))
    assert not whole

    # Incrementally growing buffer
    growing = ReceiveBuffer()
    for offset in range(len(data)):
        # Nothing may be produced until the final byte has arrived.
        assert reader(growing) is None
        growing += data[offset : offset + 1]
    check(reader(growing))

    # Trailing data
    with_extra = makebuf(data)
    with_extra += b"trailing"
    check(reader(with_extra))
    assert bytes(with_extra) == b"trailing"
def test_writers_simple() -> None:
    """Every SIMPLE_CASES event serializes to its expected wire bytes."""
    for (role, state), event, wire in SIMPLE_CASES:
        writer = WRITERS[role, state]
        tw(writer, event, wire)
def test_readers_simple() -> None:
    """Every SIMPLE_CASES wire form parses back into its expected event."""
    for (role, state), event, wire in SIMPLE_CASES:
        reader = READERS[role, state]
        tr(reader, wire, event)
def test_writers_unusual() -> None:
    """Cover ``write_headers`` directly and the refusal to emit HTTP/1.0."""
    # Simple test of the write_headers utility routine
    two_headers = normalize_and_validate([("foo", "bar"), ("baz", "quux")])
    tw(write_headers, two_headers, b"foo: bar\r\nbaz: quux\r\n\r\n")
    tw(write_headers, Headers([]), b"\r\n")

    # We understand HTTP/1.0, but we don't speak it
    http10_request = Request(
        method="GET",
        target="/",
        headers=[("Host", "foo"), ("Connection", "close")],
        http_version="1.0",
    )
    with pytest.raises(LocalProtocolError):
        tw(write_request, http10_request, None)

    http10_response = Response(
        status_code=200, headers=[("Connection", "close")], http_version="1.0"
    )
    with pytest.raises(LocalProtocolError):
        tw(write_any_response, http10_response, None)
def test_readers_unusual() -> None:
    """Cover less common inputs accepted or rejected by the start-line readers.

    Includes HTTP/1.0 messages, empty and whitespace-only header values, a
    missing reason phrase, bare-LF line endings, obsolete line folding, and
    several malformed header lines that must raise ``LocalProtocolError``.
    """
    # Reading HTTP/1.0
    tr(
        READERS[CLIENT, IDLE],
        b"HEAD /foo HTTP/1.0\r\nSome: header\r\n\r\n",
        Request(
            method="HEAD",
            target="/foo",
            headers=[("Some", "header")],
            http_version="1.0",
        ),
    )

    # check no-headers, since it's only legal with HTTP/1.0
    tr(
        READERS[CLIENT, IDLE],
        b"HEAD /foo HTTP/1.0\r\n\r\n",
        Request(method="HEAD", target="/foo", headers=[], http_version="1.0"),  # type: ignore[arg-type]
    )

    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.0 200 OK\r\nSome: header\r\n\r\n",
        Response(
            status_code=200,
            headers=[("Some", "header")],
            http_version="1.0",
            reason=b"OK",
        ),
    )

    # single-character header values (actually disallowed by the ABNF in RFC
    # 7230 -- this is a bug in the standard that we originally copied...)
    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.0 200 OK\r\n" b"Foo: a a a a a \r\n\r\n",
        Response(
            status_code=200,
            headers=[("Foo", "a a a a a")],
            http_version="1.0",
            reason=b"OK",
        ),
    )

    # Empty headers -- also legal
    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.0 200 OK\r\n" b"Foo:\r\n\r\n",
        Response(
            status_code=200, headers=[("Foo", "")], http_version="1.0", reason=b"OK"
        ),
    )

    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.0 200 OK\r\n" b"Foo: \t \t \r\n\r\n",
        Response(
            status_code=200, headers=[("Foo", "")], http_version="1.0", reason=b"OK"
        ),
    )

    # Tolerate broken servers that leave off the reason phrase
    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.0 200\r\n" b"Foo: bar\r\n\r\n",
        Response(
            status_code=200, headers=[("Foo", "bar")], http_version="1.0", reason=b""
        ),
    )

    # Tolerate both header line endings (\r\n and \n)
    # \n\r\n between headers and body
    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.1 200 OK\r\nSomeHeader: val\n\r\n",
        Response(
            status_code=200,
            headers=[("SomeHeader", "val")],
            http_version="1.1",
            reason="OK",
        ),
    )

    # delimited only with \n
    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.1 200 OK\nSomeHeader1: val1\nSomeHeader2: val2\n\n",
        Response(
            status_code=200,
            headers=[("SomeHeader1", "val1"), ("SomeHeader2", "val2")],
            http_version="1.1",
            reason="OK",
        ),
    )

    # mixed \r\n and \n
    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.1 200 OK\r\nSomeHeader1: val1\nSomeHeader2: val2\n\r\n",
        Response(
            status_code=200,
            headers=[("SomeHeader1", "val1"), ("SomeHeader2", "val2")],
            http_version="1.1",
            reason="OK",
        ),
    )

    # obsolete line folding
    tr(
        READERS[CLIENT, IDLE],
        b"HEAD /foo HTTP/1.1\r\n"
        b"Host: example.com\r\n"
        b"Some: multi-line\r\n"
        b" header\r\n"
        b"\tnonsense\r\n"
        b" \t \t\tI guess\r\n"
        b"Connection: close\r\n"
        b"More-nonsense: in the\r\n"
        b" last header \r\n\r\n",
        Request(
            method="HEAD",
            target="/foo",
            headers=[
                ("Host", "example.com"),
                ("Some", "multi-line header nonsense I guess"),
                ("Connection", "close"),
                ("More-nonsense", "in the last header"),
            ],
        ),
    )

    # A continuation line with no header before it is rejected
    with pytest.raises(LocalProtocolError):
        tr(
            READERS[CLIENT, IDLE],
            b"HEAD /foo HTTP/1.1\r\n" b" folded: line\r\n\r\n",
            None,
        )

    # Whitespace (space or tab) between the header name and the colon is
    # rejected
    with pytest.raises(LocalProtocolError):
        tr(
            READERS[CLIENT, IDLE],
            b"HEAD /foo HTTP/1.1\r\n" b"foo : line\r\n\r\n",
            None,
        )
    with pytest.raises(LocalProtocolError):
        tr(
            READERS[CLIENT, IDLE],
            b"HEAD /foo HTTP/1.1\r\n" b"foo\t: line\r\n\r\n",
            None,
        )

    # An empty header name is rejected
    with pytest.raises(LocalProtocolError):
        tr(READERS[CLIENT, IDLE], b"HEAD /foo HTTP/1.1\r\n" b": line\r\n\r\n", None)
def test__obsolete_line_fold_bytes() -> None:
    """Feed plain ``bytes`` lines through ``_obsolete_line_fold``."""
    # _obsolete_line_fold has a defensive cast to bytearray, which is
    # necessary to protect against O(n^2) behavior in case anyone ever passes
    # in regular bytestrings... but right now we never pass in regular
    # bytestrings. so this test just exists to get some coverage on that
    # defensive cast.
    folded = list(_obsolete_line_fold([b"aaa", b"bbb", b" ccc", b"ddd"]))
    assert folded == [
        b"aaa",
        bytearray(b"bbb ccc"),
        b"ddd",
    ]
def _run_reader_iter(
    reader: Any, buf: bytes, do_eof: bool
) -> Generator[Any, None, None]:
    """Yield events from ``reader(buf)`` until it stalls or the message ends.

    Stops when the reader returns ``None`` or right after it yields an
    ``EndOfMessage``. If ``do_eof`` is true, the buffer must then be empty
    and the result of ``reader.read_eof()`` is yielded last.

    NOTE(review): ``buf`` is annotated ``bytes`` but ``t_body_reader`` passes
    a ``ReceiveBuffer`` -- confirm the intended annotation.
    """
    event = reader(buf)
    while event is not None:
        yield event
        # body readers have undefined behavior after returning EndOfMessage,
        # because this changes the state so they don't get called again
        if type(event) is EndOfMessage:
            break
        event = reader(buf)
    if do_eof:
        assert not buf
        yield reader.read_eof()
def _run_reader(*args: Any) -> List[Event]:
    """Collect the events from ``_run_reader_iter(*args)``, normalized."""
    return normalize_data_events(list(_run_reader_iter(*args)))
def t_body_reader(thunk: Any, data: bytes, expected: Any, do_eof: bool = False) -> None:
    """Check that a body reader turns ``data`` into ``expected`` events.

    ``thunk`` builds a fresh reader for each scenario: the whole input at
    once, the input arriving one byte at a time, and (only for complete
    messages without ``do_eof``) the input followed by trailing bytes.
    """
    # Simple: consume whole thing
    print("Test 1")
    whole = makebuf(data)
    assert _run_reader(thunk(), whole, do_eof) == expected

    # Incrementally growing buffer
    print("Test 2")
    reader = thunk()
    growing = ReceiveBuffer()
    seen = []
    for offset in range(len(data)):
        seen += _run_reader(reader, growing, False)
        growing += data[offset : offset + 1]
    seen += _run_reader(reader, growing, do_eof)
    assert normalize_data_events(seen) == expected

    # Trailing bytes are only tried when the expected events contain an
    # EndOfMessage and no EOF is being simulated.
    is_complete = any(type(event) is EndOfMessage for event in expected)
    if is_complete and not do_eof:
        with_extra = makebuf(data + b"trailing")
        assert _run_reader(thunk(), with_extra, False) == expected
def test_ContentLengthReader() -> None:
    """Content-Length bodies of zero and non-zero size."""
    # Zero-length body: the message ends immediately.
    t_body_reader(lambda: ContentLengthReader(0), b"", [EndOfMessage()])

    ten_bytes = b"0123456789"
    t_body_reader(
        lambda: ContentLengthReader(10),
        ten_bytes,
        [Data(data=b"0123456789"), EndOfMessage()],
    )
def test_Http10Reader() -> None:
    """HTTP/1.0 bodies only end when EOF is signalled."""
    t_body_reader(Http10Reader, b"", [EndOfMessage()], do_eof=True)
    # Without EOF there is data but no end-of-message.
    t_body_reader(Http10Reader, b"asdf", [Data(data=b"asdf")], do_eof=False)
    t_body_reader(
        Http10Reader, b"asdf", [Data(data=b"asdf"), EndOfMessage()], do_eof=True
    )
def test_ChunkedReader() -> None:
    """Chunked bodies: trailers, hex sizes, extensions and invalid sizes."""
    # Terminating chunk only, without and with a trailer header
    t_body_reader(ChunkedReader, b"0\r\n\r\n", [EndOfMessage()])

    t_body_reader(
        ChunkedReader,
        b"0\r\nSome: header\r\n\r\n",
        [EndOfMessage(headers=[("Some", "header")])],
    )

    # Two data chunks, with and without a trailer header
    t_body_reader(
        ChunkedReader,
        b"5\r\n01234\r\n"
        + b"10\r\n0123456789abcdef\r\n"
        + b"0\r\n"
        + b"Some: header\r\n\r\n",
        [
            Data(data=b"012340123456789abcdef"),
            EndOfMessage(headers=[("Some", "header")]),
        ],
    )

    t_body_reader(
        ChunkedReader,
        b"5\r\n01234\r\n" + b"10\r\n0123456789abcdef\r\n" + b"0\r\n\r\n",
        [Data(data=b"012340123456789abcdef"), EndOfMessage()],
    )

    # handles upper and lowercase hex
    t_body_reader(
        ChunkedReader,
        b"aA\r\n" + b"x" * 0xAA + b"\r\n" + b"0\r\n\r\n",
        [Data(data=b"x" * 0xAA), EndOfMessage()],
    )

    # refuses arbitrarily long chunk integers
    with pytest.raises(LocalProtocolError):
        # Technically this is legal HTTP/1.1, but we refuse to process chunk
        # sizes that don't fit into 20 characters of hex
        t_body_reader(ChunkedReader, b"9" * 100 + b"\r\nxxx", [Data(data=b"xxx")])

    # refuses garbage in the chunk count
    with pytest.raises(LocalProtocolError):
        t_body_reader(ChunkedReader, b"10\x00\r\nxxx", None)

    # handles (and discards) "chunk extensions" omg wtf
    t_body_reader(
        ChunkedReader,
        b"5; hello=there\r\n"
        + b"xxxxx"
        + b"\r\n"
        + b'0; random="junk"; some=more; canbe=lonnnnngg\r\n\r\n',
        [Data(data=b"xxxxx"), EndOfMessage()],
    )

    # whitespace after the chunk size is accepted
    t_body_reader(
        ChunkedReader,
        b"5 \r\n01234\r\n" + b"0\r\n\r\n",
        [Data(data=b"01234"), EndOfMessage()],
    )
def test_ContentLengthWriter() -> None:
    """ContentLengthWriter passes data through and enforces the length."""
    # Exactly the declared length, split over two writes
    w = ContentLengthWriter(5)
    assert dowrite(w, Data(data=b"123")) == b"123"
    assert dowrite(w, Data(data=b"45")) == b"45"
    assert dowrite(w, EndOfMessage()) == b""

    # Too much data in a single write
    w = ContentLengthWriter(5)
    with pytest.raises(LocalProtocolError):
        dowrite(w, Data(data=b"123456"))

    # Too much data across two writes
    w = ContentLengthWriter(5)
    dowrite(w, Data(data=b"123"))
    with pytest.raises(LocalProtocolError):
        dowrite(w, Data(data=b"456"))

    # Ending the message before the declared length was written
    w = ContentLengthWriter(5)
    dowrite(w, Data(data=b"123"))
    with pytest.raises(LocalProtocolError):
        dowrite(w, EndOfMessage())

    # Trailing headers on EndOfMessage are rejected. The two data writes are
    # asserted so that their output is actually checked.
    w = ContentLengthWriter(5)
    assert dowrite(w, Data(data=b"123")) == b"123"
    assert dowrite(w, Data(data=b"45")) == b"45"
    with pytest.raises(LocalProtocolError):
        dowrite(w, EndOfMessage(headers=[("Etag", "asdf")]))
def test_ChunkedWriter() -> None:
    """ChunkedWriter frames data as hex-sized chunks and writes trailers."""
    writer = ChunkedWriter()
    assert dowrite(writer, Data(data=b"aaa")) == b"3\r\naaa\r\n"
    # 20 bytes -> chunk size 0x14
    assert dowrite(writer, Data(data=b"a" * 20)) == b"14\r\n" + b"a" * 20 + b"\r\n"

    # Empty data produces no output at all
    assert dowrite(writer, Data(data=b"")) == b""

    assert dowrite(writer, EndOfMessage()) == b"0\r\n\r\n"

    with_trailers = dowrite(
        writer, EndOfMessage(headers=[("Etag", "asdf"), ("a", "b")])
    )
    assert with_trailers == b"0\r\nEtag: asdf\r\na: b\r\n\r\n"
def test_Http10Writer() -> None:
    """Http10Writer passes data through and rejects trailing headers."""
    writer = Http10Writer()
    assert dowrite(writer, Data(data=b"1234")) == b"1234"
    assert dowrite(writer, EndOfMessage()) == b""

    with pytest.raises(LocalProtocolError):
        dowrite(writer, EndOfMessage(headers=[("Etag", "asdf")]))
def test_reject_garbage_after_request_line() -> None:
    """Garbage after the start line must raise ``LocalProtocolError``."""
    # NOTE(review): despite the name, the input below is a status line
    # ("HTTP/1.0 200 OK...") parsed with the SERVER/SEND_RESPONSE reader.
    garbled = b"HTTP/1.0 200 OK\x00xxxx\r\n\r\n"
    with pytest.raises(LocalProtocolError):
        tr(READERS[SERVER, SEND_RESPONSE], garbled, None)
def test_reject_garbage_after_response_line() -> None:
    """Extra text after the HTTP version must raise ``LocalProtocolError``."""
    # NOTE(review): despite the name, the input below is a request line
    # ("HEAD /foo HTTP/1.1 ...") parsed with the CLIENT/IDLE reader.
    garbled = b"HEAD /foo HTTP/1.1 xxxxxx\r\n" b"Host: a\r\n\r\n"
    with pytest.raises(LocalProtocolError):
        tr(READERS[CLIENT, IDLE], garbled, None)
def test_reject_garbage_in_header_line() -> None:
    """A NUL byte inside a header value must raise ``LocalProtocolError``."""
    garbled = b"HEAD /foo HTTP/1.1\r\n" b"Host: foo\x00bar\r\n\r\n"
    with pytest.raises(LocalProtocolError):
        tr(READERS[CLIENT, IDLE], garbled, None)
def test_reject_non_vchar_in_path() -> None:
    """Each listed byte appended to the request target must be rejected."""
    for bad_byte in b"\x00\x20\x7f\xee":
        # Assemble "HEAD /<bad byte> HTTP/1.1" followed by a Host header.
        message = bytearray(b"HEAD /")
        message.append(bad_byte)
        message.extend(b" HTTP/1.1\r\nHost: foobar\r\n\r\n")
        with pytest.raises(LocalProtocolError):
            tr(READERS[CLIENT, IDLE], message, None)
# https://github.com/python-hyper/h11/issues/57
def test_allow_some_garbage_in_cookies() -> None:
    """A \\x01 control byte inside a header value is accepted by the reader."""
    wire = (
        b"HEAD /foo HTTP/1.1\r\n"
        b"Host: foo\r\n"
        b"Set-Cookie: ___utmvafIumyLc=kUd\x01UpAt; path=/; Max-Age=900\r\n"
        b"\r\n"
    )
    expected = Request(
        method="HEAD",
        target="/foo",
        headers=[
            ("Host", "foo"),
            ("Set-Cookie", "___utmvafIumyLc=kUd\x01UpAt; path=/; Max-Age=900"),
        ],
    )
    tr(READERS[CLIENT, IDLE], wire, expected)
def test_host_comes_first() -> None:
    """``write_headers`` emits Host first even when it was supplied last."""
    headers = normalize_and_validate([("foo", "bar"), ("Host", "example.com")])
    tw(write_headers, headers, b"Host: example.com\r\nfoo: bar\r\n\r\n")

View File

@ -0,0 +1,135 @@
import re
from typing import Tuple
import pytest
from .._receivebuffer import ReceiveBuffer
def test_receivebuffer() -> None:
    """Walk ``ReceiveBuffer`` through its three extraction methods."""
    buf = ReceiveBuffer()
    assert not buf
    assert len(buf) == 0
    assert bytes(buf) == b""

    buf += b"123"
    assert buf
    assert len(buf) == 3
    # Converting to bytes twice shows that doing so does not consume data.
    assert bytes(buf) == b"123"
    assert bytes(buf) == b"123"

    assert buf.maybe_extract_at_most(2) == b"12"
    assert buf
    assert len(buf) == 1
    assert bytes(buf) == b"3"
    assert bytes(buf) == b"3"

    assert buf.maybe_extract_at_most(10) == b"3"
    assert bytes(buf) == b""

    assert buf.maybe_extract_at_most(10) is None
    assert not buf

    ################################################################
    # maybe_extract_next_line
    ################################################################

    # Only \r\n terminates a line; the bare \n stays inside the first line.
    buf += b"123\n456\r\n789\r\n"

    assert buf.maybe_extract_next_line() == b"123\n456\r\n"
    assert bytes(buf) == b"789\r\n"

    assert buf.maybe_extract_next_line() == b"789\r\n"
    assert bytes(buf) == b""

    buf += b"12\r"
    assert buf.maybe_extract_next_line() is None
    assert bytes(buf) == b"12\r"

    buf += b"345\n\r"
    assert buf.maybe_extract_next_line() is None
    assert bytes(buf) == b"12\r345\n\r"

    # here we stopped at the middle of b"\r\n" delimiter

    buf += b"\n6789aaa123\r\n"
    assert buf.maybe_extract_next_line() == b"12\r345\n\r\n"
    assert buf.maybe_extract_next_line() == b"6789aaa123\r\n"
    assert buf.maybe_extract_next_line() is None
    assert bytes(buf) == b""

    ################################################################
    # maybe_extract_lines
    ################################################################

    buf += b"123\r\na: b\r\nfoo:bar\r\n\r\ntrailing"
    header_lines = buf.maybe_extract_lines()
    assert header_lines == [b"123", b"a: b", b"foo:bar"]
    assert bytes(buf) == b"trailing"

    assert buf.maybe_extract_lines() is None

    buf += b"\r\n\r"
    assert buf.maybe_extract_lines() is None

    assert buf.maybe_extract_at_most(100) == b"trailing\r\n\r"
    assert not buf

    # Empty body case (as happens at the end of chunked encoding if there are
    # no trailing headers, e.g.)
    buf += b"\r\ntrailing"
    assert buf.maybe_extract_lines() == []
    assert bytes(buf) == b"trailing"
@pytest.mark.parametrize(
    "data",
    [
        pytest.param(
            (
                b"HTTP/1.1 200 OK\r\n",
                b"Content-type: text/plain\r\n",
                b"Connection: close\r\n",
                b"\r\n",
                b"Some body",
            ),
            id="with_crlf_delimiter",
        ),
        pytest.param(
            (
                b"HTTP/1.1 200 OK\n",
                b"Content-type: text/plain\n",
                b"Connection: close\n",
                b"\n",
                b"Some body",
            ),
            id="with_lf_only_delimiter",
        ),
        pytest.param(
            (
                b"HTTP/1.1 200 OK\n",
                b"Content-type: text/plain\r\n",
                b"Connection: close\n",
                b"\n",
                b"Some body",
            ),
            id="with_mixed_crlf_and_lf",
        ),
    ],
)
def test_receivebuffer_for_invalid_delimiter(data: Tuple[bytes, ...]) -> None:
    """``maybe_extract_lines`` accepts CRLF, bare LF, and a mix of both.

    Each parameter set is a tuple of byte fragments that are appended to the
    buffer one by one before the header lines are extracted; the body bytes
    must be left in the buffer afterwards.
    """
    b = ReceiveBuffer()

    for line in data:
        b += line

    lines = b.maybe_extract_lines()

    assert lines == [
        b"HTTP/1.1 200 OK",
        b"Content-type: text/plain",
        b"Connection: close",
    ]
    assert bytes(b) == b"Some body"

View File

@ -0,0 +1,271 @@
import pytest
from .._events import (
ConnectionClosed,
Data,
EndOfMessage,
Event,
InformationalResponse,
Request,
Response,
)
from .._state import (
_SWITCH_CONNECT,
_SWITCH_UPGRADE,
CLIENT,
CLOSED,
ConnectionState,
DONE,
IDLE,
MIGHT_SWITCH_PROTOCOL,
MUST_CLOSE,
SEND_BODY,
SEND_RESPONSE,
SERVER,
SWITCHED_PROTOCOL,
)
from .._util import LocalProtocolError
def test_ConnectionState() -> None:
    """Walk a single request/response cycle and check both roles' states."""
    machine = ConnectionState()

    # Basic event-triggered transitions
    assert machine.states == {CLIENT: IDLE, SERVER: IDLE}

    machine.process_event(CLIENT, Request)
    # The SERVER-Request special case: the client's Request also moves the
    # server out of IDLE.
    request_sent = {CLIENT: SEND_BODY, SERVER: SEND_RESPONSE}
    assert machine.states == request_sent

    # Illegal transitions raise an error and nothing happens
    with pytest.raises(LocalProtocolError):
        machine.process_event(CLIENT, Request)
    assert machine.states == request_sent

    # An informational response leaves the server in SEND_RESPONSE.
    machine.process_event(SERVER, InformationalResponse)
    assert machine.states == request_sent

    machine.process_event(SERVER, Response)
    assert machine.states == {CLIENT: SEND_BODY, SERVER: SEND_BODY}

    for role in (CLIENT, SERVER):
        machine.process_event(role, EndOfMessage)
    assert machine.states == {CLIENT: DONE, SERVER: DONE}

    # State-triggered transition: the server closing also changes the
    # client's state.
    machine.process_event(SERVER, ConnectionClosed)
    assert machine.states == {CLIENT: MUST_CLOSE, SERVER: CLOSED}
def test_ConnectionState_keep_alive() -> None:
    """With keep-alive disabled, finishing a message lands in MUST_CLOSE."""
    state = ConnectionState()
    state.process_event(CLIENT, Request)
    # keep_alive = False
    state.process_keep_alive_disabled()
    state.process_event(CLIENT, EndOfMessage)
    assert state.states == {CLIENT: MUST_CLOSE, SERVER: SEND_RESPONSE}

    for server_event in (Response, EndOfMessage):
        state.process_event(SERVER, server_event)
    assert state.states == {CLIENT: MUST_CLOSE, SERVER: MUST_CLOSE}
def test_ConnectionState_keep_alive_in_DONE() -> None:
    """Disabling keep-alive while the CLIENT is already DONE still closes it.

    Turning keep-alive off at that point must be enough on its own to
    trigger the DONE -> MUST_CLOSE transition immediately.
    """
    state = ConnectionState()
    for client_event in (Request, EndOfMessage):
        state.process_event(CLIENT, client_event)
    assert state.states[CLIENT] is DONE

    state.process_keep_alive_disabled()
    assert state.states[CLIENT] is MUST_CLOSE
def test_ConnectionState_switch_denied() -> None:
    """A denied switch proposal is cleared and the cycle proceeds normally.

    Both proposal kinds are exercised, with the server's denying Response
    arriving either before or after the client finishes its request.
    """
    sending_request = {CLIENT: SEND_BODY, SERVER: SEND_RESPONSE}
    awaiting_verdict = {CLIENT: MIGHT_SWITCH_PROTOCOL, SERVER: SEND_RESPONSE}
    denied = {CLIENT: DONE, SERVER: SEND_BODY}

    for switch_type in (_SWITCH_CONNECT, _SWITCH_UPGRADE):
        for deny_early in (True, False):
            state = ConnectionState()
            state.process_client_switch_proposal(switch_type)
            state.process_event(CLIENT, Request)
            state.process_event(CLIENT, Data)
            assert state.states == sending_request

            assert switch_type in state.pending_switch_proposals

            if deny_early:
                # before client reaches DONE
                state.process_event(SERVER, Response)
                assert not state.pending_switch_proposals

            state.process_event(CLIENT, EndOfMessage)

            if deny_early:
                assert state.states == denied
                continue

            # Late denial: the client waits in MIGHT_SWITCH_PROTOCOL, and an
            # informational response does not settle the proposal.
            assert state.states == awaiting_verdict

            state.process_event(SERVER, InformationalResponse)
            assert state.states == awaiting_verdict

            state.process_event(SERVER, Response)
            assert state.states == denied
            assert not state.pending_switch_proposals
# Event type the tests send from the SERVER together with each switch value:
# an Upgrade goes with an InformationalResponse, a CONNECT with a Response,
# and None (no switch) with a plain Response.
_response_type_for_switch = {
    _SWITCH_UPGRADE: InformationalResponse,
    _SWITCH_CONNECT: Response,
    None: Response,
}
def test_ConnectionState_protocol_switch_accepted() -> None:
    """An accepted Upgrade or CONNECT puts both roles in SWITCHED_PROTOCOL."""
    proposal_pending = {CLIENT: MIGHT_SWITCH_PROTOCOL, SERVER: SEND_RESPONSE}
    switched = {CLIENT: SWITCHED_PROTOCOL, SERVER: SWITCHED_PROTOCOL}

    for switch_event in (_SWITCH_UPGRADE, _SWITCH_CONNECT):
        state = ConnectionState()
        state.process_client_switch_proposal(switch_event)
        state.process_event(CLIENT, Request)
        state.process_event(CLIENT, Data)
        assert state.states == {CLIENT: SEND_BODY, SERVER: SEND_RESPONSE}

        state.process_event(CLIENT, EndOfMessage)
        assert state.states == proposal_pending

        # An informational response sent without a switch leaves the
        # proposal pending.
        state.process_event(SERVER, InformationalResponse)
        assert state.states == proposal_pending

        accepting_event = _response_type_for_switch[switch_event]
        state.process_event(SERVER, accepting_event, switch_event)
        assert state.states == switched
def test_ConnectionState_double_protocol_switch() -> None:
    """The client proposes both switches; the server picks one or neither."""
    # CONNECT + Upgrade is legal! Very silly, but legal. So we support
    # it. Because sometimes doing the silly thing is easier than not.
    for server_switch in (None, _SWITCH_UPGRADE, _SWITCH_CONNECT):
        state = ConnectionState()
        for proposal in (_SWITCH_UPGRADE, _SWITCH_CONNECT):
            state.process_client_switch_proposal(proposal)
        state.process_event(CLIENT, Request)
        state.process_event(CLIENT, EndOfMessage)
        assert state.states == {
            CLIENT: MIGHT_SWITCH_PROTOCOL,
            SERVER: SEND_RESPONSE,
        }

        reply_type = _response_type_for_switch[server_switch]
        state.process_event(SERVER, reply_type, server_switch)

        if server_switch is None:
            # Neither proposal was accepted: an ordinary response follows.
            expected = {CLIENT: DONE, SERVER: SEND_BODY}
        else:
            expected = {CLIENT: SWITCHED_PROTOCOL, SERVER: SWITCHED_PROTOCOL}
        assert state.states == expected
def test_ConnectionState_inconsistent_protocol_switch() -> None:
    """The server cannot accept a switch that the client did not propose."""
    # Each case pairs the client's proposals with a server switch that is
    # not among them.
    for client_switches, server_switch in [
        ([], _SWITCH_CONNECT),
        ([], _SWITCH_UPGRADE),
        ([_SWITCH_UPGRADE], _SWITCH_CONNECT),
        ([_SWITCH_CONNECT], _SWITCH_UPGRADE),
    ]:
        state = ConnectionState()
        for client_switch in client_switches:  # type: ignore[attr-defined]
            state.process_client_switch_proposal(client_switch)
        state.process_event(CLIENT, Request)

        with pytest.raises(LocalProtocolError):
            state.process_event(SERVER, Response, server_switch)
def test_ConnectionState_keepalive_protocol_switch_interaction() -> None:
    """Combine keep_alive=False with a pending switch proposal."""
    state = ConnectionState()
    state.process_client_switch_proposal(_SWITCH_UPGRADE)
    state.process_event(CLIENT, Request)
    state.process_keep_alive_disabled()
    state.process_event(CLIENT, Data)
    assert state.states == {CLIENT: SEND_BODY, SERVER: SEND_RESPONSE}

    # the protocol switch "wins": the client waits for the verdict rather
    # than going to MUST_CLOSE
    state.process_event(CLIENT, EndOfMessage)
    assert state.states == {
        CLIENT: MIGHT_SWITCH_PROTOCOL,
        SERVER: SEND_RESPONSE,
    }

    # but when the server denies the request, keep_alive comes back into play
    state.process_event(SERVER, Response)
    assert state.states == {CLIENT: MUST_CLOSE, SERVER: SEND_BODY}
def test_ConnectionState_reuse() -> None:
    """Check when start_next_cycle() succeeds and when it raises."""
    both_idle = {CLIENT: IDLE, SERVER: IDLE}

    state = ConnectionState()

    # Fresh connection: no cycle has completed yet.
    with pytest.raises(LocalProtocolError):
        state.start_next_cycle()

    for client_event in (Request, EndOfMessage):
        state.process_event(CLIENT, client_event)

    # Only the client has finished.
    with pytest.raises(LocalProtocolError):
        state.start_next_cycle()

    for server_event in (Response, EndOfMessage):
        state.process_event(SERVER, server_event)

    state.start_next_cycle()
    assert state.states == both_idle

    # No keepalive

    state.process_event(CLIENT, Request)
    state.process_keep_alive_disabled()
    state.process_event(CLIENT, EndOfMessage)
    for server_event in (Response, EndOfMessage):
        state.process_event(SERVER, server_event)

    with pytest.raises(LocalProtocolError):
        state.start_next_cycle()

    # One side closed

    state = ConnectionState()
    for client_event in (Request, EndOfMessage, ConnectionClosed):
        state.process_event(CLIENT, client_event)
    for server_event in (Response, EndOfMessage):
        state.process_event(SERVER, server_event)

    with pytest.raises(LocalProtocolError):
        state.start_next_cycle()

    # Successful protocol switch

    state = ConnectionState()
    state.process_client_switch_proposal(_SWITCH_UPGRADE)
    for client_event in (Request, EndOfMessage):
        state.process_event(CLIENT, client_event)
    state.process_event(SERVER, InformationalResponse, _SWITCH_UPGRADE)

    with pytest.raises(LocalProtocolError):
        state.start_next_cycle()

    # Failed protocol switch

    state = ConnectionState()
    state.process_client_switch_proposal(_SWITCH_UPGRADE)
    for client_event in (Request, EndOfMessage):
        state.process_event(CLIENT, client_event)
    for server_event in (Response, EndOfMessage):
        state.process_event(SERVER, server_event)

    state.start_next_cycle()
    assert state.states == both_idle
def test_server_request_is_illegal() -> None:
    """A Request event from the SERVER role must be rejected."""
    # There used to be a bug in how we handled the Request special case that
    # made this allowed...
    state = ConnectionState()

    with pytest.raises(LocalProtocolError):
        state.process_event(SERVER, Request)

View File

@ -0,0 +1,112 @@
import re
import sys
import traceback
from typing import NoReturn
import pytest
from .._util import (
bytesify,
LocalProtocolError,
ProtocolError,
RemoteProtocolError,
Sentinel,
validate,
)
def test_ProtocolError() -> None:
    """Instantiating ProtocolError directly raises TypeError."""
    message = "abstract base class"

    with pytest.raises(TypeError):
        ProtocolError(message)
def test_LocalProtocolError() -> None:
    """Check the status hint and the re-raise as RemoteProtocolError."""
    # Without an explicit hint the status is 400.
    with pytest.raises(LocalProtocolError) as default_info:
        raise LocalProtocolError("foo")
    assert str(default_info.value) == "foo"
    assert default_info.value.error_status_hint == 400

    # An explicit hint is kept as given.
    with pytest.raises(LocalProtocolError) as custom_info:
        raise LocalProtocolError("foo", error_status_hint=418)
    assert str(custom_info.value) == "foo"
    assert custom_info.value.error_status_hint == 418

    def thunk() -> NoReturn:
        raise LocalProtocolError("a", error_status_hint=420)

    # Re-raising as a remote error must keep the message and the hint, and
    # the new traceback must end with the original one.
    try:
        try:
            thunk()
        except LocalProtocolError as local_exc:
            tb_before = "".join(traceback.format_tb(sys.exc_info()[2]))
            local_exc._reraise_as_remote_protocol_error()
    except RemoteProtocolError as remote_exc:
        assert type(remote_exc) is RemoteProtocolError
        assert remote_exc.args == ("a",)
        assert remote_exc.error_status_hint == 420
        tb_after = "".join(traceback.format_tb(sys.exc_info()[2]))
        assert tb_after.endswith(tb_before)
def test_validate() -> None:
    """validate() returns the named groups, but only for a full match."""
    pattern = re.compile(rb"(?P<group1>[0-9]+)\.(?P<group2>[0-9]+)")

    with pytest.raises(LocalProtocolError):
        validate(pattern, b"0.")

    matched = validate(pattern, b"0.1")
    assert matched == {"group1": b"0", "group2": b"1"}

    # successful partial matches are an error - must match whole string
    for trailing_junk in (b"0.1xx", b"0.1\n"):
        with pytest.raises(LocalProtocolError):
            validate(pattern, trailing_junk)
def test_validate_formatting() -> None:
    """The error message is formatted only when format arguments are given."""
    pattern = re.compile(rb"foo")

    # (message and format arguments passed to validate, text expected in
    # the resulting error)
    cases = [
        (("oops",), "oops"),
        (("oops {}",), "oops {}"),
        (("oops {} xx", 10), "oops 10 xx"),
    ]

    for message_args, expected_text in cases:
        with pytest.raises(LocalProtocolError) as excinfo:
            validate(pattern, b"", *message_args)
        assert expected_text in str(excinfo.value)
def test_make_sentinel() -> None:
    """Sentinel classes repr as their name and act as their own type."""

    class S(Sentinel, metaclass=Sentinel):
        pass

    # A sentinel is its own type, equal to itself and hashable.
    assert type(S) is S
    assert type(S).__name__ == "S"
    assert repr(S) == "S"
    assert S == S
    assert S in {S}

    class S2(Sentinel, metaclass=Sentinel):
        pass

    # Distinct sentinels never compare equal or share a type.
    assert repr(S2) == "S2"
    assert S != S2
    assert S not in {S2}
    assert type(S) is not type(S2)
def test_bytesify() -> None:
    """bytesify() accepts bytes, bytearray and ASCII str; rejects the rest."""
    for accepted in (b"123", bytearray(b"123"), "123"):
        assert bytesify(accepted) == b"123"

    # A str that cannot be encoded is an encoding error ...
    with pytest.raises(UnicodeEncodeError):
        bytesify("\u1234")

    # ... while an int is a type error.
    with pytest.raises(TypeError):
        bytesify(10)