diff --git a/pyproject.toml b/pyproject.toml index 99b2c38..be7865e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,8 @@ requires-python = ">=3.7" test = [ "pytest", "coverage[toml]", + "fastapi>=0.111", + "httpx", ] mypy = [ "mypy>=1.4.1", diff --git a/src/extendable_pydantic/_patch.py b/src/extendable_pydantic/_patch.py index e32cc17..d274eb5 100644 --- a/src/extendable_pydantic/_patch.py +++ b/src/extendable_pydantic/_patch.py @@ -24,7 +24,7 @@ def _resolve_model_fields_annotation(model_fields): registry = context.extendable_registry.get() - if registry: + if registry and registry.ready: for field in model_fields: field_info = field.field_info new_type = resolve_annotation(field_info.annotation) @@ -77,3 +77,18 @@ def _create_response_field_wrapper(wrapped, instance, args, kwargs): wrapt.wrap_function_wrapper( utils, "create_model_field", _create_response_field_wrapper ) + + +@wrapt.when_imported("fastapi.dependencies.utils") +def hook_fastapi_dependencies_utils(utils): + def _analyze_param_wrapper(wrapped, instance, args, kwargs): + registry = context.extendable_registry.get() + if registry and registry.ready: + annotation = kwargs.get("annotation") + if annotation: + new_type = resolve_annotation(annotation) + if not all_identical(annotation, new_type): + kwargs["annotation"] = new_type + return wrapped(*args, **kwargs) + + wrapt.wrap_function_wrapper(utils, "analyze_param", _analyze_param_wrapper) diff --git a/tests/conftest.py b/tests/conftest.py index 7e34e0b..71551b2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,13 @@ +from extendable_pydantic import _patch # noqa: F401 import pytest import sys from extendable import context, main, registry +from fastapi import FastAPI, APIRouter +from fastapi.testclient import TestClient +from typing import Annotated + +from fastapi import Depends +from extendable_pydantic import ExtendableBaseModel skip_not_supported_version_for_generics = pytest.mark.skipif( @@ -21,3 +28,55 @@ def test_registry() -> registry.ExtendableClassesRegistry: finally: main._extendable_class_defs_by_module = initial_class_defs context.extendable_registry.reset(token) + + +@pytest.fixture +def test_fastapi(test_registry) -> TestClient: + app = FastAPI() + my_router = APIRouter() + + class TestRequest(ExtendableBaseModel): + name: str = "rqst" + + def get_type(self) -> str: + return "request" + + class TestResponse(ExtendableBaseModel): + name: str = "resp" + + def get_type(self) -> str: + return "response" + + @my_router.get("/") + def get() -> TestResponse: + """Get method.""" + resp = TestResponse(name="World") + assert hasattr(resp, "id") + return resp + + @my_router.post("/") + def post(rqst: TestRequest) -> TestResponse: + """Post method.""" + resp = TestResponse(**rqst.model_dump()) + assert hasattr(resp, "id") + return resp + + @my_router.get("/extended") + def get_with_params(rqst: Annotated[TestRequest, Depends()]) -> TestResponse: + """Get method with parameters.""" + resp = TestResponse(**rqst.model_dump()) + assert hasattr(resp, "id") + return resp + + class ExtendedTestRequest(TestRequest, extends=TestRequest): + id: int = 1 + + class ExtendedTestResponse(TestResponse, extends=TestResponse): + id: int = 2 + + test_registry.init_registry() + + app.include_router(my_router) + + with TestClient(app) as client: + yield client diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py new file mode 100644 index 0000000..0762c25 --- /dev/null +++ b/tests/test_fastapi.py @@ -0,0 +1,57 @@ +"""Test fastapi integration.""" + + +def test_open_api_schema(test_fastapi): + client = test_fastapi + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + schema = response.json() + rqst_schema = schema["components"]["schemas"]["TestRequest"] + assert rqst_schema["properties"] == { + "name": {"title": "Name", "type": "string", "default": "rqst"}, + "id": {"title": "Id", "type": "integer", "default": 1}, + } + resp_schema = schema["components"]["schemas"]["TestResponse"] + assert resp_schema["properties"] == { + "name": {"title": "Name", "type": "string", "default": "resp"}, + "id": {"title": "Id", "type": "integer", "default": 2}, + } + + extended_get_params = schema["paths"]["/extended"]["get"]["parameters"] + assert len(extended_get_params) == 2 + assert extended_get_params[0] == { + "in": "query", + "name": "name", + "required": False, + "schema": {"title": "Name", "type": "string", "default": "rqst"}, + } + assert extended_get_params[1] == { + "in": "query", + "name": "id", + "required": False, + "schema": {"title": "Id", "type": "integer", "default": 1}, + } + + +def test_extended_response(test_fastapi): + """Test extended pydantic model as response.""" + client = test_fastapi + response = client.get("/") + assert response.status_code == 200 + assert response.json() == {"name": "World", "id": 2} + + +def test_extended_request(test_fastapi): + """Test extended pydantic model as json request.""" + client = test_fastapi + response = client.post("/", json={"name": "Hello", "id": 3}) + assert response.status_code == 200 + assert response.json() == {"name": "Hello", "id": 3} + + +def test_extended_request_with_params(test_fastapi): + """Test extended pydantic model as request with parameters.""" + client = test_fastapi + response = client.get("/extended", params={"name": "echo", "id": 3}) + assert response.status_code == 200 + assert response.json() == {"name": "echo", "id": 3}