diff --git a/app/tests/conftest.py b/app/tests/conftest.py index 8380134..e76d1b4 100644 --- a/app/tests/conftest.py +++ b/app/tests/conftest.py @@ -1,6 +1,9 @@ +from unittest.mock import MagicMock + import pytest from fastapi.testclient import TestClient from sqlalchemy import create_engine +from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import Session, sessionmaker from app.core.config import settings @@ -90,3 +93,26 @@ def create_user_model(post: dict): session.commit() post = session.query(Post).all() return post + + +@pytest.fixture +def mock_db_error(): + mock_session = MagicMock() + mock_session.execute.side_effect = SQLAlchemyError("Simulated database error") + return mock_session + + +@pytest.fixture +def client_with_db_error(mock_db_error): + def override_get_db(): + try: + yield mock_db_error + finally: + mock_db_error.close() + + app.dependency_overrides[get_db] = override_get_db + + try: + yield TestClient(app) + finally: + app.dependency_overrides.pop(get_db) diff --git a/app/tests/test_health.py b/app/tests/test_health.py index 54e7839..5be1604 100644 --- a/app/tests/test_health.py +++ b/app/tests/test_health.py @@ -14,3 +14,17 @@ def validate(data): print(data) assert res.status_code == 200 + + +def test_api_health_db_error(client_with_db_error: TestClient): + res = client_with_db_error.get("/health") + + def validate(data): + return APIStatus(**data) + + data = res.json() + validate(data) + print(data) + + assert res.status_code == 200 + assert data["db_status"] == "Unhealthy"