diff --git a/CHANGELOG.md b/CHANGELOG.md index cf57b10..b2088ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Fixed +- Custom HTTP transport are now handled (parent call to `handle_async_request` or `handle_request`). + +### Changed +- Only HTTP transport are now mocked, this should not have any impact, however if it does, please feel free to open an issue describing your use case. ## [0.26.0] - 2023-09-18 ### Added diff --git a/pytest_httpx/__init__.py b/pytest_httpx/__init__.py index b579fa9..3fe8071 100644 --- a/pytest_httpx/__init__.py +++ b/pytest_httpx/__init__.py @@ -5,11 +5,7 @@ import pytest from pytest import MonkeyPatch -from pytest_httpx._httpx_mock import ( - HTTPXMock, - _PytestSyncTransport, - _PytestAsyncTransport, -) +from pytest_httpx._httpx_mock import HTTPXMock from pytest_httpx._httpx_internals import IteratorStream from pytest_httpx.version import __version__ @@ -45,22 +41,36 @@ def httpx_mock( mock = HTTPXMock() # Mock synchronous requests - real_sync_transport = httpx.Client._transport_for_url + real_handle_request = httpx.HTTPTransport.handle_request + + def mocked_handle_request( + transport: httpx.HTTPTransport, request: httpx.Request + ) -> httpx.Response: + if request.url.host in non_mocked_hosts: + return real_handle_request(transport, request) + return mock._handle_request(transport, request) + monkeypatch.setattr( - httpx.Client, - "_transport_for_url", - lambda self, url: real_sync_transport(self, url) - if url.host in non_mocked_hosts - else _PytestSyncTransport(real_sync_transport(self, url), mock), + httpx.HTTPTransport, + "handle_request", + mocked_handle_request, ) + # Mock asynchronous requests - real_async_transport = httpx.AsyncClient._transport_for_url + real_handle_async_request = httpx.AsyncHTTPTransport.handle_async_request + + async def mocked_handle_async_request( + transport: httpx.AsyncHTTPTransport, request: httpx.Request + ) -> httpx.Response: + if request.url.host in non_mocked_hosts: + return await real_handle_async_request(transport, request) + return await mock._handle_async_request(transport, request) + monkeypatch.setattr( - httpx.AsyncClient, - "_transport_for_url", - lambda self, url: real_async_transport(self, url) - if url.host in non_mocked_hosts - else _PytestAsyncTransport(real_async_transport(self, url), mock), + httpx.AsyncHTTPTransport, + "handle_async_request", + mocked_handle_async_request, ) + yield mock mock.reset(assert_all_responses_were_requested) diff --git a/pytest_httpx/_httpx_internals.py b/pytest_httpx/_httpx_internals.py index e12b933..da21797 100644 --- a/pytest_httpx/_httpx_internals.py +++ b/pytest_httpx/_httpx_internals.py @@ -61,12 +61,9 @@ def _to_httpx_url(url: httpcore.URL, headers: list[tuple[bytes, bytes]]) -> http def _proxy_url( - real_transport: Union[httpx.BaseTransport, httpx.AsyncBaseTransport] + real_transport: Union[httpx.HTTPTransport, httpx.AsyncHTTPTransport] ) -> Optional[httpx.URL]: - if isinstance(real_transport, httpx.HTTPTransport): - if isinstance(real_pool := real_transport._pool, httpcore.HTTPProxy): - return _to_httpx_url(real_pool._proxy_url, real_pool._proxy_headers) - - if isinstance(real_transport, httpx.AsyncHTTPTransport): - if isinstance(real_pool := real_transport._pool, httpcore.AsyncHTTPProxy): - return _to_httpx_url(real_pool._proxy_url, real_pool._proxy_headers) + if isinstance( + real_pool := real_transport._pool, (httpcore.HTTPProxy, httpcore.AsyncHTTPProxy) + ): + return _to_httpx_url(real_pool._proxy_url, real_pool._proxy_headers) diff --git a/pytest_httpx/_httpx_mock.py b/pytest_httpx/_httpx_mock.py index 425f420..941b050 100644 --- a/pytest_httpx/_httpx_mock.py +++ b/pytest_httpx/_httpx_mock.py @@ -12,7 +12,7 @@ class HTTPXMock: def __init__(self) -> None: self._requests: list[ - tuple[Union[httpx.BaseTransport, httpx.AsyncBaseTransport], httpx.Request] + tuple[Union[httpx.HTTPTransport, httpx.AsyncHTTPTransport], httpx.Request] ] = [] self._callbacks: list[ tuple[ @@ -123,7 +123,7 @@ def exception_callback(request: httpx.Request) -> None: def _handle_request( self, - real_transport: httpx.BaseTransport, + real_transport: httpx.HTTPTransport, request: httpx.Request, ) -> httpx.Response: self._requests.append((real_transport, request)) @@ -142,7 +142,7 @@ def _handle_request( async def _handle_async_request( self, - real_transport: httpx.AsyncBaseTransport, + real_transport: httpx.AsyncHTTPTransport, request: httpx.Request, ) -> httpx.Response: self._requests.append((real_transport, request)) @@ -178,7 +178,7 @@ def _explain_that_no_response_was_found( def _get_callback( self, - real_transport: Union[httpx.BaseTransport, httpx.AsyncBaseTransport], + real_transport: Union[httpx.HTTPTransport, httpx.AsyncHTTPTransport], request: httpx.Request, ) -> Optional[ Callable[ @@ -266,24 +266,6 @@ def _reset_callbacks(self) -> list[_RequestMatcher]: return callbacks_not_executed -class _PytestSyncTransport(httpx.BaseTransport): - def __init__(self, real_transport: httpx.BaseTransport, mock: HTTPXMock): - self._real_transport = real_transport - self._mock = mock - - def handle_request(self, request: httpx.Request) -> httpx.Response: - return self._mock._handle_request(self._real_transport, request) - - -class _PytestAsyncTransport(httpx.AsyncBaseTransport): - def __init__(self, real_transport: httpx.AsyncBaseTransport, mock: HTTPXMock): - self._real_transport = real_transport - self._mock = mock - - async def handle_async_request(self, request: httpx.Request) -> httpx.Response: - return await self._mock._handle_async_request(self._real_transport, request) - - def _unread(response: httpx.Response) -> httpx.Response: # Allow to read the response on client side response.is_stream_consumed = False diff --git a/pytest_httpx/_request_matcher.py b/pytest_httpx/_request_matcher.py index e80df4b..1bb590c 100644 --- a/pytest_httpx/_request_matcher.py +++ b/pytest_httpx/_request_matcher.py @@ -52,7 +52,7 @@ def __init__( def match( self, - real_transport: Union[httpx.BaseTransport, httpx.AsyncBaseTransport], + real_transport: Union[httpx.HTTPTransport, httpx.AsyncHTTPTransport], request: httpx.Request, ) -> bool: return ( @@ -106,7 +106,7 @@ def _content_match(self, request: httpx.Request) -> bool: return False def _proxy_match( - self, real_transport: Union[httpx.BaseTransport, httpx.AsyncBaseTransport] + self, real_transport: Union[httpx.HTTPTransport, httpx.AsyncHTTPTransport] ) -> bool: if not self.proxy_url: return True diff --git a/tests/test_httpx_async.py b/tests/test_httpx_async.py index e9ad034..54eb82a 100644 --- a/tests/test_httpx_async.py +++ b/tests/test_httpx_async.py @@ -2015,3 +2015,26 @@ async def test_streams_are_not_cascading_resulting_in_maximum_recursion( tasks = [client.get("https://example.com/") for _ in range(950)] await asyncio.gather(*tasks) # No need to assert anything, this test case ensure that no error was raised by the gather + + +@pytest.mark.asyncio +async def test_custom_transport(httpx_mock: HTTPXMock) -> None: + class CustomTransport(httpx.AsyncHTTPTransport): + def __init__(self, prefix: str, *args, **kwargs): + super().__init__(*args, **kwargs) + self.prefix = prefix + + async def handle_async_request( + self, + request: httpx.Request, + ) -> httpx.Response: + httpx_response = await super().handle_async_request(request) + httpx_response.headers["x-prefix"] = self.prefix + return httpx_response + + httpx_mock.add_response() + + async with httpx.AsyncClient(transport=CustomTransport(prefix="test")) as client: + response = await client.post("https://test_url", content=b"This is the body") + assert response.read() == b"" + assert response.headers["x-prefix"] == "test" diff --git a/tests/test_httpx_sync.py b/tests/test_httpx_sync.py index 00431c5..e1e4938 100644 --- a/tests/test_httpx_sync.py +++ b/tests/test_httpx_sync.py @@ -1,5 +1,4 @@ import re -from typing import Any from unittest.mock import ANY import httpx @@ -1706,3 +1705,25 @@ def test_mutating_json(httpx_mock: HTTPXMock) -> None: response = client.get("https://test_url") assert response.json() == {"content": "request 2"} + + +def test_custom_transport(httpx_mock: HTTPXMock) -> None: + class CustomTransport(httpx.HTTPTransport): + def __init__(self, prefix: str, *args, **kwargs): + super().__init__(*args, **kwargs) + self.prefix = prefix + + def handle_request( + self, + request: httpx.Request, + ) -> httpx.Response: + httpx_response = super().handle_request(request) + httpx_response.headers["x-prefix"] = self.prefix + return httpx_response + + httpx_mock.add_response() + + with httpx.Client(transport=CustomTransport(prefix="test")) as client: + response = client.post("https://test_url", content=b"This is the body") + assert response.read() == b"" + assert response.headers["x-prefix"] == "test"