from __future__ import annotations import contextlib import inspect import io import json import math import sys import typing import warnings from concurrent.futures import Future from types import GeneratorType from urllib.parse import unquote, urljoin import anyio import anyio.abc import anyio.from_thread from anyio.streams.stapled import StapledObjectStream from starlette._utils import is_async_callable from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocketDisconnect if sys.version_info >= (3, 10): # pragma: no cover from typing import TypeGuard else: # pragma: no cover from typing_extensions import TypeGuard try: import httpx except ModuleNotFoundError: # pragma: no cover raise RuntimeError( "The starlette.testclient module requires the httpx package to be installed.\n" "You can install this with:\n" " $ pip install httpx\n" ) _PortalFactoryType = typing.Callable[[], typing.ContextManager[anyio.abc.BlockingPortal]] ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]] ASGI2App = typing.Callable[[Scope], ASGIInstance] ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]] _RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str], bytes]] def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]: if inspect.isclass(app): return hasattr(app, "__await__") return is_async_callable(app) class _WrapASGI2: """ Provide an ASGI3 interface onto an ASGI2 app. """ def __init__(self, app: ASGI2App) -> None: self.app = app async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: instance = self.app(scope) await instance(receive, send) class _AsyncBackend(typing.TypedDict): backend: str backend_options: dict[str, typing.Any] class _Upgrade(Exception): def __init__(self, session: WebSocketTestSession) -> None: self.session = session class WebSocketDenialResponse( # type: ignore[misc] httpx.Response, WebSocketDisconnect, ): """ A special case of `WebSocketDisconnect`, raised in the `TestClient` if the `WebSocket` is closed before being accepted with a `send_denial_response()`. """ class WebSocketTestSession: def __init__( self, app: ASGI3App, scope: Scope, portal_factory: _PortalFactoryType, ) -> None: self.app = app self.scope = scope self.accepted_subprotocol = None self.portal_factory = portal_factory self.extra_headers = None def __enter__(self) -> WebSocketTestSession: with contextlib.ExitStack() as stack: self.portal = portal = stack.enter_context(self.portal_factory()) fut, cs = portal.start_task(self._run) stack.callback(fut.result) stack.callback(portal.call, cs.cancel) self.send({"type": "websocket.connect"}) message = self.receive() self._raise_on_close(message) self.accepted_subprotocol = message.get("subprotocol", None) self.extra_headers = message.get("headers", None) stack.callback(self.close, 1000) self.exit_stack = stack.pop_all() return self def __exit__(self, *args: typing.Any) -> bool | None: return self.exit_stack.__exit__(*args) async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None: """ The sub-thread in which the websocket session runs. """ send: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf) send_tx, send_rx = send receive: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf) receive_tx, receive_rx = receive with send_tx, send_rx, receive_tx, receive_rx, anyio.CancelScope() as cs: self._receive_tx = receive_tx self._send_rx = send_rx task_status.started(cs) await self.app(self.scope, receive_rx.receive, send_tx.send) # wait for cs.cancel to be called before closing streams await anyio.sleep_forever() def _raise_on_close(self, message: Message) -> None: if message["type"] == "websocket.close": raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", "")) elif message["type"] == "websocket.http.response.start": status_code: int = message["status"] headers: list[tuple[bytes, bytes]] = message["headers"] body: list[bytes] = [] while True: message = self.receive() assert message["type"] == "websocket.http.response.body" body.append(message["body"]) if not message.get("more_body", False): break raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body)) def send(self, message: Message) -> None: self.portal.call(self._receive_tx.send, message) def send_text(self, data: str) -> None: self.send({"type": "websocket.receive", "text": data}) def send_bytes(self, data: bytes) -> None: self.send({"type": "websocket.receive", "bytes": data}) def send_json(self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text") -> None: text = json.dumps(data, separators=(",", ":"), ensure_ascii=False) if mode == "text": self.send({"type": "websocket.receive", "text": text}) else: self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")}) def close(self, code: int = 1000, reason: str | None = None) -> None: self.send({"type": "websocket.disconnect", "code": code, "reason": reason}) def receive(self) -> Message: return self.portal.call(self._send_rx.receive) def receive_text(self) -> str: message = self.receive() self._raise_on_close(message) return typing.cast(str, message["text"]) def receive_bytes(self) -> bytes: message = self.receive() self._raise_on_close(message) return typing.cast(bytes, message["bytes"]) def receive_json(self, mode: typing.Literal["text", "binary"] = "text") -> typing.Any: message = self.receive() self._raise_on_close(message) if mode == "text": text = message["text"] else: text = message["bytes"].decode("utf-8") return json.loads(text) class _TestClientTransport(httpx.BaseTransport): def __init__( self, app: ASGI3App, portal_factory: _PortalFactoryType, raise_server_exceptions: bool = True, root_path: str = "", *, client: tuple[str, int], app_state: dict[str, typing.Any], ) -> None: self.app = app self.raise_server_exceptions = raise_server_exceptions self.root_path = root_path self.portal_factory = portal_factory self.app_state = app_state self.client = client def handle_request(self, request: httpx.Request) -> httpx.Response: scheme = request.url.scheme netloc = request.url.netloc.decode(encoding="ascii") path = request.url.path raw_path = request.url.raw_path query = request.url.query.decode(encoding="ascii") default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme] if ":" in netloc: host, port_string = netloc.split(":", 1) port = int(port_string) else: host = netloc port = default_port # Include the 'host' header. if "host" in request.headers: headers: list[tuple[bytes, bytes]] = [] elif port == default_port: # pragma: no cover headers = [(b"host", host.encode())] else: # pragma: no cover headers = [(b"host", (f"{host}:{port}").encode())] # Include other request headers. headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()] scope: dict[str, typing.Any] if scheme in {"ws", "wss"}: subprotocol = request.headers.get("sec-websocket-protocol", None) if subprotocol is None: subprotocols: typing.Sequence[str] = [] else: subprotocols = [value.strip() for value in subprotocol.split(",")] scope = { "type": "websocket", "path": unquote(path), "raw_path": raw_path.split(b"?", 1)[0], "root_path": self.root_path, "scheme": scheme, "query_string": query.encode(), "headers": headers, "client": self.client, "server": [host, port], "subprotocols": subprotocols, "state": self.app_state.copy(), "extensions": {"websocket.http.response": {}}, } session = WebSocketTestSession(self.app, scope, self.portal_factory) raise _Upgrade(session) scope = { "type": "http", "http_version": "1.1", "method": request.method, "path": unquote(path), "raw_path": raw_path.split(b"?", 1)[0], "root_path": self.root_path, "scheme": scheme, "query_string": query.encode(), "headers": headers, "client": self.client, "server": [host, port], "extensions": {"http.response.debug": {}}, "state": self.app_state.copy(), } request_complete = False response_started = False response_complete: anyio.Event raw_kwargs: dict[str, typing.Any] = {"stream": io.BytesIO()} template = None context = None async def receive() -> Message: nonlocal request_complete if request_complete: if not response_complete.is_set(): await response_complete.wait() return {"type": "http.disconnect"} body = request.read() if isinstance(body, str): body_bytes: bytes = body.encode("utf-8") # pragma: no cover elif body is None: body_bytes = b"" # pragma: no cover elif isinstance(body, GeneratorType): try: # pragma: no cover chunk = body.send(None) if isinstance(chunk, str): chunk = chunk.encode("utf-8") return {"type": "http.request", "body": chunk, "more_body": True} except StopIteration: # pragma: no cover request_complete = True return {"type": "http.request", "body": b""} else: body_bytes = body request_complete = True return {"type": "http.request", "body": body_bytes} async def send(message: Message) -> None: nonlocal raw_kwargs, response_started, template, context if message["type"] == "http.response.start": assert not response_started, 'Received multiple "http.response.start" messages.' raw_kwargs["status_code"] = message["status"] raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])] response_started = True elif message["type"] == "http.response.body": assert response_started, 'Received "http.response.body" without "http.response.start".' assert not response_complete.is_set(), 'Received "http.response.body" after response completed.' body = message.get("body", b"") more_body = message.get("more_body", False) if request.method != "HEAD": raw_kwargs["stream"].write(body) if not more_body: raw_kwargs["stream"].seek(0) response_complete.set() elif message["type"] == "http.response.debug": template = message["info"]["template"] context = message["info"]["context"] try: with self.portal_factory() as portal: response_complete = portal.call(anyio.Event) portal.call(self.app, scope, receive, send) except BaseException as exc: if self.raise_server_exceptions: raise exc if self.raise_server_exceptions: assert response_started, "TestClient did not receive any response." elif not response_started: raw_kwargs = { "status_code": 500, "headers": [], "stream": io.BytesIO(), } raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].read()) response = httpx.Response(**raw_kwargs, request=request) if template is not None: response.template = template # type: ignore[attr-defined] response.context = context # type: ignore[attr-defined] return response class TestClient(httpx.Client): __test__ = False task: Future[None] portal: anyio.abc.BlockingPortal | None = None def __init__( self, app: ASGIApp, base_url: str = "http://testserver", raise_server_exceptions: bool = True, root_path: str = "", backend: typing.Literal["asyncio", "trio"] = "asyncio", backend_options: dict[str, typing.Any] | None = None, cookies: httpx._types.CookieTypes | None = None, headers: dict[str, str] | None = None, follow_redirects: bool = True, client: tuple[str, int] = ("testclient", 50000), ) -> None: self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {}) if _is_asgi3(app): asgi_app = app else: app = typing.cast(ASGI2App, app) # type: ignore[assignment] asgi_app = _WrapASGI2(app) # type: ignore[arg-type] self.app = asgi_app self.app_state: dict[str, typing.Any] = {} transport = _TestClientTransport( self.app, portal_factory=self._portal_factory, raise_server_exceptions=raise_server_exceptions, root_path=root_path, app_state=self.app_state, client=client, ) if headers is None: headers = {} headers.setdefault("user-agent", "testclient") super().__init__( base_url=base_url, headers=headers, transport=transport, follow_redirects=follow_redirects, cookies=cookies, ) @contextlib.contextmanager def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, None]: if self.portal is not None: yield self.portal else: with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal: yield portal def request( # type: ignore[override] self, method: str, url: httpx._types.URLTypes, *, content: httpx._types.RequestContent | None = None, data: _RequestData | None = None, files: httpx._types.RequestFiles | None = None, json: typing.Any = None, params: httpx._types.QueryParamTypes | None = None, headers: httpx._types.HeaderTypes | None = None, cookies: httpx._types.CookieTypes | None = None, auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, extensions: dict[str, typing.Any] | None = None, ) -> httpx.Response: if timeout is not httpx.USE_CLIENT_DEFAULT: warnings.warn( "You should not use the 'timeout' argument with the TestClient. " "See https://github.com/encode/starlette/issues/1108 for more information.", DeprecationWarning, ) url = self._merge_url(url) return super().request( method, url, content=content, data=data, files=files, json=json, params=params, headers=headers, cookies=cookies, auth=auth, follow_redirects=follow_redirects, timeout=timeout, extensions=extensions, ) def get( # type: ignore[override] self, url: httpx._types.URLTypes, *, params: httpx._types.QueryParamTypes | None = None, headers: httpx._types.HeaderTypes | None = None, cookies: httpx._types.CookieTypes | None = None, auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, extensions: dict[str, typing.Any] | None = None, ) -> httpx.Response: return super().get( url, params=params, headers=headers, cookies=cookies, auth=auth, follow_redirects=follow_redirects, timeout=timeout, extensions=extensions, ) def options( # type: ignore[override] self, url: httpx._types.URLTypes, *, params: httpx._types.QueryParamTypes | None = None, headers: httpx._types.HeaderTypes | None = None, cookies: httpx._types.CookieTypes | None = None, auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, extensions: dict[str, typing.Any] | None = None, ) -> httpx.Response: return super().options( url, params=params, headers=headers, cookies=cookies, auth=auth, follow_redirects=follow_redirects, timeout=timeout, extensions=extensions, ) def head( # type: ignore[override] self, url: httpx._types.URLTypes, *, params: httpx._types.QueryParamTypes | None = None, headers: httpx._types.HeaderTypes | None = None, cookies: httpx._types.CookieTypes | None = None, auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, extensions: dict[str, typing.Any] | None = None, ) -> httpx.Response: return super().head( url, params=params, headers=headers, cookies=cookies, auth=auth, follow_redirects=follow_redirects, timeout=timeout, extensions=extensions, ) def post( # type: ignore[override] self, url: httpx._types.URLTypes, *, content: httpx._types.RequestContent | None = None, data: _RequestData | None = None, files: httpx._types.RequestFiles | None = None, json: typing.Any = None, params: httpx._types.QueryParamTypes | None = None, headers: httpx._types.HeaderTypes | None = None, cookies: httpx._types.CookieTypes | None = None, auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, extensions: dict[str, typing.Any] | None = None, ) -> httpx.Response: return super().post( url, content=content, data=data, files=files, json=json, params=params, headers=headers, cookies=cookies, auth=auth, follow_redirects=follow_redirects, timeout=timeout, extensions=extensions, ) def put( # type: ignore[override] self, url: httpx._types.URLTypes, *, content: httpx._types.RequestContent | None = None, data: _RequestData | None = None, files: httpx._types.RequestFiles | None = None, json: typing.Any = None, params: httpx._types.QueryParamTypes | None = None, headers: httpx._types.HeaderTypes | None = None, cookies: httpx._types.CookieTypes | None = None, auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, extensions: dict[str, typing.Any] | None = None, ) -> httpx.Response: return super().put( url, content=content, data=data, files=files, json=json, params=params, headers=headers, cookies=cookies, auth=auth, follow_redirects=follow_redirects, timeout=timeout, extensions=extensions, ) def patch( # type: ignore[override] self, url: httpx._types.URLTypes, *, content: httpx._types.RequestContent | None = None, data: _RequestData | None = None, files: httpx._types.RequestFiles | None = None, json: typing.Any = None, params: httpx._types.QueryParamTypes | None = None, headers: httpx._types.HeaderTypes | None = None, cookies: httpx._types.CookieTypes | None = None, auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, extensions: dict[str, typing.Any] | None = None, ) -> httpx.Response: return super().patch( url, content=content, data=data, files=files, json=json, params=params, headers=headers, cookies=cookies, auth=auth, follow_redirects=follow_redirects, timeout=timeout, extensions=extensions, ) def delete( # type: ignore[override] self, url: httpx._types.URLTypes, *, params: httpx._types.QueryParamTypes | None = None, headers: httpx._types.HeaderTypes | None = None, cookies: httpx._types.CookieTypes | None = None, auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, extensions: dict[str, typing.Any] | None = None, ) -> httpx.Response: return super().delete( url, params=params, headers=headers, cookies=cookies, auth=auth, follow_redirects=follow_redirects, timeout=timeout, extensions=extensions, ) def websocket_connect( self, url: str, subprotocols: typing.Sequence[str] | None = None, **kwargs: typing.Any, ) -> WebSocketTestSession: url = urljoin("ws://testserver", url) headers = kwargs.get("headers", {}) headers.setdefault("connection", "upgrade") headers.setdefault("sec-websocket-key", "testserver==") headers.setdefault("sec-websocket-version", "13") if subprotocols is not None: headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols)) kwargs["headers"] = headers try: super().request("GET", url, **kwargs) except _Upgrade as exc: session = exc.session else: raise RuntimeError("Expected WebSocket upgrade") # pragma: no cover return session def __enter__(self) -> TestClient: with contextlib.ExitStack() as stack: self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend)) @stack.callback def reset_portal() -> None: self.portal = None send: anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any] | None] = ( anyio.create_memory_object_stream(math.inf) ) receive: anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any]] = ( anyio.create_memory_object_stream(math.inf) ) for channel in (*send, *receive): stack.callback(channel.close) self.stream_send = StapledObjectStream(*send) self.stream_receive = StapledObjectStream(*receive) self.task = portal.start_task_soon(self.lifespan) portal.call(self.wait_startup) @stack.callback def wait_shutdown() -> None: portal.call(self.wait_shutdown) self.exit_stack = stack.pop_all() return self def __exit__(self, *args: typing.Any) -> None: self.exit_stack.close() async def lifespan(self) -> None: scope = {"type": "lifespan", "state": self.app_state} try: await self.app(scope, self.stream_receive.receive, self.stream_send.send) finally: await self.stream_send.send(None) async def wait_startup(self) -> None: await self.stream_receive.send({"type": "lifespan.startup"}) async def receive() -> typing.Any: message = await self.stream_send.receive() if message is None: self.task.result() return message message = await receive() assert message["type"] in ( "lifespan.startup.complete", "lifespan.startup.failed", ) if message["type"] == "lifespan.startup.failed": await receive() async def wait_shutdown(self) -> None: async def receive() -> typing.Any: message = await self.stream_send.receive() if message is None: self.task.result() return message await self.stream_receive.send({"type": "lifespan.shutdown"}) message = await receive() assert message["type"] in ( "lifespan.shutdown.complete", "lifespan.shutdown.failed", ) if message["type"] == "lifespan.shutdown.failed": await receive()