Skip to content

Commit

Permalink
Merge pull request #124 from Colin-b/handle_custom_transport
Browse files Browse the repository at this point in the history
Handle custom transport
  • Loading branch information
Colin-b authored Nov 13, 2023
2 parents c686de6 + 7ac04bc commit cce8116
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 50 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Fixed
- Custom HTTP transport are now handled (parent call to `handle_async_request` or `handle_request`).

### Changed
- Only HTTP transport are now mocked, this should not have any impact, however if it does, please feel free to open an issue describing your use case.

## [0.26.0] - 2023-09-18
### Added
Expand Down
44 changes: 27 additions & 17 deletions pytest_httpx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,7 @@
import pytest
from pytest import MonkeyPatch

from pytest_httpx._httpx_mock import (
HTTPXMock,
_PytestSyncTransport,
_PytestAsyncTransport,
)
from pytest_httpx._httpx_mock import HTTPXMock
from pytest_httpx._httpx_internals import IteratorStream
from pytest_httpx.version import __version__

Expand Down Expand Up @@ -45,22 +41,36 @@ def httpx_mock(
mock = HTTPXMock()

# Mock synchronous requests
real_sync_transport = httpx.Client._transport_for_url
real_handle_request = httpx.HTTPTransport.handle_request

def mocked_handle_request(
transport: httpx.HTTPTransport, request: httpx.Request
) -> httpx.Response:
if request.url.host in non_mocked_hosts:
return real_handle_request(transport, request)
return mock._handle_request(transport, request)

monkeypatch.setattr(
httpx.Client,
"_transport_for_url",
lambda self, url: real_sync_transport(self, url)
if url.host in non_mocked_hosts
else _PytestSyncTransport(real_sync_transport(self, url), mock),
httpx.HTTPTransport,
"handle_request",
mocked_handle_request,
)

# Mock asynchronous requests
real_async_transport = httpx.AsyncClient._transport_for_url
real_handle_async_request = httpx.AsyncHTTPTransport.handle_async_request

async def mocked_handle_async_request(
transport: httpx.AsyncHTTPTransport, request: httpx.Request
) -> httpx.Response:
if request.url.host in non_mocked_hosts:
return await real_handle_async_request(transport, request)
return await mock._handle_async_request(transport, request)

monkeypatch.setattr(
httpx.AsyncClient,
"_transport_for_url",
lambda self, url: real_async_transport(self, url)
if url.host in non_mocked_hosts
else _PytestAsyncTransport(real_async_transport(self, url), mock),
httpx.AsyncHTTPTransport,
"handle_async_request",
mocked_handle_async_request,
)

yield mock
mock.reset(assert_all_responses_were_requested)
13 changes: 5 additions & 8 deletions pytest_httpx/_httpx_internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,9 @@ def _to_httpx_url(url: httpcore.URL, headers: list[tuple[bytes, bytes]]) -> http


def _proxy_url(
real_transport: Union[httpx.BaseTransport, httpx.AsyncBaseTransport]
real_transport: Union[httpx.HTTPTransport, httpx.AsyncHTTPTransport]
) -> Optional[httpx.URL]:
if isinstance(real_transport, httpx.HTTPTransport):
if isinstance(real_pool := real_transport._pool, httpcore.HTTPProxy):
return _to_httpx_url(real_pool._proxy_url, real_pool._proxy_headers)

if isinstance(real_transport, httpx.AsyncHTTPTransport):
if isinstance(real_pool := real_transport._pool, httpcore.AsyncHTTPProxy):
return _to_httpx_url(real_pool._proxy_url, real_pool._proxy_headers)
if isinstance(
real_pool := real_transport._pool, (httpcore.HTTPProxy, httpcore.AsyncHTTPProxy)
):
return _to_httpx_url(real_pool._proxy_url, real_pool._proxy_headers)
26 changes: 4 additions & 22 deletions pytest_httpx/_httpx_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
class HTTPXMock:
def __init__(self) -> None:
self._requests: list[
tuple[Union[httpx.BaseTransport, httpx.AsyncBaseTransport], httpx.Request]
tuple[Union[httpx.HTTPTransport, httpx.AsyncHTTPTransport], httpx.Request]
] = []
self._callbacks: list[
tuple[
Expand Down Expand Up @@ -123,7 +123,7 @@ def exception_callback(request: httpx.Request) -> None:

def _handle_request(
self,
real_transport: httpx.BaseTransport,
real_transport: httpx.HTTPTransport,
request: httpx.Request,
) -> httpx.Response:
self._requests.append((real_transport, request))
Expand All @@ -142,7 +142,7 @@ def _handle_request(

async def _handle_async_request(
self,
real_transport: httpx.AsyncBaseTransport,
real_transport: httpx.AsyncHTTPTransport,
request: httpx.Request,
) -> httpx.Response:
self._requests.append((real_transport, request))
Expand Down Expand Up @@ -178,7 +178,7 @@ def _explain_that_no_response_was_found(

def _get_callback(
self,
real_transport: Union[httpx.BaseTransport, httpx.AsyncBaseTransport],
real_transport: Union[httpx.HTTPTransport, httpx.AsyncHTTPTransport],
request: httpx.Request,
) -> Optional[
Callable[
Expand Down Expand Up @@ -266,24 +266,6 @@ def _reset_callbacks(self) -> list[_RequestMatcher]:
return callbacks_not_executed


class _PytestSyncTransport(httpx.BaseTransport):
def __init__(self, real_transport: httpx.BaseTransport, mock: HTTPXMock):
self._real_transport = real_transport
self._mock = mock

def handle_request(self, request: httpx.Request) -> httpx.Response:
return self._mock._handle_request(self._real_transport, request)


class _PytestAsyncTransport(httpx.AsyncBaseTransport):
def __init__(self, real_transport: httpx.AsyncBaseTransport, mock: HTTPXMock):
self._real_transport = real_transport
self._mock = mock

async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
return await self._mock._handle_async_request(self._real_transport, request)


def _unread(response: httpx.Response) -> httpx.Response:
# Allow to read the response on client side
response.is_stream_consumed = False
Expand Down
4 changes: 2 additions & 2 deletions pytest_httpx/_request_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(

def match(
self,
real_transport: Union[httpx.BaseTransport, httpx.AsyncBaseTransport],
real_transport: Union[httpx.HTTPTransport, httpx.AsyncHTTPTransport],
request: httpx.Request,
) -> bool:
return (
Expand Down Expand Up @@ -106,7 +106,7 @@ def _content_match(self, request: httpx.Request) -> bool:
return False

def _proxy_match(
self, real_transport: Union[httpx.BaseTransport, httpx.AsyncBaseTransport]
self, real_transport: Union[httpx.HTTPTransport, httpx.AsyncHTTPTransport]
) -> bool:
if not self.proxy_url:
return True
Expand Down
23 changes: 23 additions & 0 deletions tests/test_httpx_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -2015,3 +2015,26 @@ async def test_streams_are_not_cascading_resulting_in_maximum_recursion(
tasks = [client.get("https://example.com/") for _ in range(950)]
await asyncio.gather(*tasks)
# No need to assert anything, this test case ensure that no error was raised by the gather


@pytest.mark.asyncio
async def test_custom_transport(httpx_mock: HTTPXMock) -> None:
class CustomTransport(httpx.AsyncHTTPTransport):
def __init__(self, prefix: str, *args, **kwargs):
super().__init__(*args, **kwargs)
self.prefix = prefix

async def handle_async_request(
self,
request: httpx.Request,
) -> httpx.Response:
httpx_response = await super().handle_async_request(request)
httpx_response.headers["x-prefix"] = self.prefix
return httpx_response

httpx_mock.add_response()

async with httpx.AsyncClient(transport=CustomTransport(prefix="test")) as client:
response = await client.post("https://test_url", content=b"This is the body")
assert response.read() == b""
assert response.headers["x-prefix"] == "test"
23 changes: 22 additions & 1 deletion tests/test_httpx_sync.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import re
from typing import Any
from unittest.mock import ANY

import httpx
Expand Down Expand Up @@ -1706,3 +1705,25 @@ def test_mutating_json(httpx_mock: HTTPXMock) -> None:

response = client.get("https://test_url")
assert response.json() == {"content": "request 2"}


def test_custom_transport(httpx_mock: HTTPXMock) -> None:
class CustomTransport(httpx.HTTPTransport):
def __init__(self, prefix: str, *args, **kwargs):
super().__init__(*args, **kwargs)
self.prefix = prefix

def handle_request(
self,
request: httpx.Request,
) -> httpx.Response:
httpx_response = super().handle_request(request)
httpx_response.headers["x-prefix"] = self.prefix
return httpx_response

httpx_mock.add_response()

with httpx.Client(transport=CustomTransport(prefix="test")) as client:
response = client.post("https://test_url", content=b"This is the body")
assert response.read() == b""
assert response.headers["x-prefix"] == "test"

0 comments on commit cce8116

Please sign in to comment.