From 934c08b7e364355745c9a720a96dbe4f7af5bd3b Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Thu, 18 Jul 2024 16:44:53 +0200 Subject: [PATCH] [REFACTOR] `argilla-server`: review server startup (#5263) # Description This PR reviews the server startup in two different ways: 1. Run the simpler setup functions out of the startup event listeners. 2. Use the [lifespan](https://fastapi.tiangolo.com/advanced/events/#lifespan) function instead of deprecated startup/shutdown listeners **Type of change** - Bug fix (non-breaking change which fixes an issue) - New feature (non-breaking change which adds functionality) - Breaking change (fix or feature that would cause existing functionality to not work as expected) - Refactor (change restructuring the codebase without changing functionality) - Improvement (change adding some improvement to an existing functionality) - Documentation update **How Has This Been Tested** **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: Francisco Aranda --- argilla-server/src/argilla_server/_app.py | 163 ++++++++---------- .../security/authentication/provider.py | 5 +- argilla-server/src/argilla_server/settings.py | 9 - argilla-server/tests/unit/conftest.py | 24 +-- 4 files changed, 83 insertions(+), 118 deletions(-) diff --git a/argilla-server/src/argilla_server/_app.py b/argilla-server/src/argilla_server/_app.py index c230de2122..14ce6ead5e 100644 --- a/argilla-server/src/argilla_server/_app.py +++ b/argilla-server/src/argilla_server/_app.py @@ -25,6 +25,7 @@ import backoff from brotli_asgi import BrotliMiddleware from fastapi import FastAPI, Request +from sqlalchemy.ext.asyncio import AsyncSession from starlette.middleware.cors import CORSMiddleware from starlette.responses import RedirectResponse @@ -37,13 +38,20 @@ from argilla_server.logging import configure_logging from argilla_server.models import User from argilla_server.search_engine import get_search_engine -from argilla_server.security import auth from argilla_server.settings import settings from argilla_server.static_rewrite import RewriteStaticFiles _LOGGER = logging.getLogger("argilla") +@contextlib.asynccontextmanager +async def app_lifespan(app: FastAPI): + # See https://fastapi.tiangolo.com/advanced/events/#lifespan + await configure_database() + await configure_search_engine() + yield + + def create_server_app() -> FastAPI: """Configure the argilla server""" @@ -54,29 +62,17 @@ def create_server_app() -> FastAPI: redoc_url=None, redirect_slashes=False, version=str(argilla_version), + lifespan=app_lifespan, ) - @app.get("/docs", include_in_schema=False) - async def redirect_docs(): - return RedirectResponse(url=f"{settings.base_url}api/v1/docs") + configure_logging() + configure_telemetry() + configure_middleware(app) + configure_api_router(app) + configure_app_statics(app) + configure_api_docs(app) - @app.get("/api", include_in_schema=False) - async def redirect_api(): - return RedirectResponse(url=f"{settings.base_url}api/v1/docs") - - for app_configure in [ - configure_app_logging, - configure_database, - configure_search_engine, - configure_telemetry, - configure_middleware, - configure_app_security, - configure_api_router, - configure_app_statics, - ]: - app_configure(app) - - # This if-else clause is needed to simplify the test dependencies setup. Otherwise we cannot override dependencies + # This if-else clause is needed to simplify the test dependency setup. Otherwise, we cannot override dependencies # easily. We can review this once we have separate fastapi application for the api and the webapp. if settings.base_url and settings.base_url != "/": _app = FastAPI( @@ -88,6 +84,16 @@ async def redirect_api(): return app +def configure_api_docs(app: FastAPI): + @app.get("/docs", include_in_schema=False) + async def redirect_docs(): + return RedirectResponse(url=f"{settings.base_url}api/v1/docs") + + @app.get("/api", include_in_schema=False) + async def redirect_api(): + return RedirectResponse(url=f"{settings.base_url}api/v1/docs") + + def configure_middleware(app: FastAPI): """Configures fastapi middleware""" @@ -161,94 +167,63 @@ def _create_statics_folder(path_from): ) -def configure_search_engine(app: FastAPI): - @app.on_event("startup") - async def configure_elasticsearch(): - if not settings.search_engine_is_elasticsearch: - return +def configure_telemetry(): + message = "\n" + message += inspect.cleandoc( + "Argilla uses telemetry to report anonymous usage and error information. You\n" + "can know more about what information is reported at:\n\n" + " https://docs.argilla.io/en/latest/reference/telemetry.html\n\n" + "Telemetry is currently enabled. If you want to disable it, you can configure\n" + "the environment variable before relaunching the server:\n\n" + f'{"#set ARGILLA_ENABLE_TELEMETRY=0" if os.name == "nt" else "$>export ARGILLA_ENABLE_TELEMETRY=0"}' + ) + + if settings.enable_telemetry: + _LOGGER.warning(message) + + +async def configure_database(): + async def check_default_user(db: AsyncSession): + def _user_has_default_credentials(user: User): + return user.api_key == DEFAULT_API_KEY or accounts.verify_password( + DEFAULT_PASSWORD, user.password_hash + ) + + default_user = await accounts.get_user_by_username(db, DEFAULT_USERNAME) + if default_user and _user_has_default_credentials(default_user): + _LOGGER.warning( + f"User {DEFAULT_USERNAME!r} with default credentials has been found in the database. " + "If you are using argilla in a production environment this can be a serious security problem. " + f"We recommend that you create a new admin user and then delete the default {DEFAULT_USERNAME!r} one." + ) + async with contextlib.asynccontextmanager(get_async_db)() as db: + await check_default_user(db) + + +async def configure_search_engine(): + if settings.search_engine_is_elasticsearch: + # TODO: Move this to the search engine implementation module logging.getLogger("elasticsearch").setLevel(logging.ERROR) logging.getLogger("elastic_transport").setLevel(logging.ERROR) - @app.on_event("startup") - async def configure_opensearch(): - if not settings.search_engine_is_opensearch: - return - + elif settings.search_engine_is_opensearch: + # TODO: Move this to the search engine implementation module logging.getLogger("opensearch").setLevel(logging.ERROR) logging.getLogger("opensearch_transport").setLevel(logging.ERROR) - @app.on_event("startup") @backoff.on_exception(backoff.expo, ConnectionError, max_time=60) async def ping_search_engine(): async for search_engine in get_search_engine(): if not await search_engine.ping(): raise ConnectionError( - f"Your {settings.search_engine} endpoint at {settings.obfuscated_elasticsearch()} is not available or not responding.\n" + f"Your {settings.search_engine} is not available or not responding.\n" f"Please make sure your {settings.search_engine} instance is launched and correctly running and\n" - "you have the necessary access permissions. Once you have verified this, restart the argilla server.\n" + "you have the necessary access permissions. Once you have verified this, restart " + "the argilla server.\n" ) - -def configure_app_security(app: FastAPI): - auth.configure_app(app) - - -def configure_app_logging(app: FastAPI): - """Configure app logging using""" - app.on_event("startup")(configure_logging) - - -def configure_telemetry(app: FastAPI): - message = "\n" - message += inspect.cleandoc( - """ - Argilla uses telemetry to report anonymous usage and error information. - - You can know more about what information is reported at: - - https://docs.argilla.io/en/latest/reference/telemetry.html - - Telemetry is currently enabled. If you want to disable it, you can configure - the environment variable before relaunching the server: - """ - ) - message += "\n\n " - message += ( - "#set ARGILLA_ENABLE_TELEMETRY=0" - if os.name == "nt" - else "$>export ARGILLA_ENABLE_TELEMETRY=0" - ) - message += "\n" - - @app.on_event("startup") - async def check_telemetry(): - if settings.enable_telemetry: - print(message, flush=True) - - -_get_db_wrapper = contextlib.asynccontextmanager(get_async_db) - - -def configure_database(app: FastAPI): - def _user_has_default_credentials(user: User): - return user.api_key == DEFAULT_API_KEY or accounts.verify_password( - DEFAULT_PASSWORD, user.password_hash - ) - - def _log_default_user_warning(): - _LOGGER.warning( - f"User {DEFAULT_USERNAME!r} with default credentials has been found in the database. " - "If you are using argilla in a production environment this can be a serious security problem. " - f"We recommend that you create a new admin user and then delete the default {DEFAULT_USERNAME!r} one." - ) - - @app.on_event("startup") - async def log_default_user_warning_if_present(): - async with _get_db_wrapper() as db: - default_user = await accounts.get_user_by_username(db, DEFAULT_USERNAME) - if default_user and _user_has_default_credentials(default_user): - _log_default_user_warning() + await ping_search_engine() app = create_server_app() diff --git a/argilla-server/src/argilla_server/security/authentication/provider.py b/argilla-server/src/argilla_server/security/authentication/provider.py index 5fb1dea257..f6e80cc875 100644 --- a/argilla-server/src/argilla_server/security/authentication/provider.py +++ b/argilla-server/src/argilla_server/security/authentication/provider.py @@ -14,7 +14,7 @@ from typing import ClassVar, List, Optional -from fastapi import Depends, FastAPI +from fastapi import Depends from fastapi.security import SecurityScopes from sqlalchemy.ext.asyncio import AsyncSession from starlette.authentication import AuthenticationBackend @@ -43,9 +43,6 @@ class AuthenticationProvider: def new_instance(cls): return AuthenticationProvider() - def configure_app(self, app: FastAPI) -> None: - pass - async def get_current_user( self, security_scopes: SecurityScopes, # noqa diff --git a/argilla-server/src/argilla_server/settings.py b/argilla-server/src/argilla_server/settings.py index 3a392b86e1..16d7456313 100644 --- a/argilla-server/src/argilla_server/settings.py +++ b/argilla-server/src/argilla_server/settings.py @@ -23,7 +23,6 @@ import warnings from pathlib import Path from typing import Dict, List, Optional -from urllib.parse import urlparse from argilla_server.constants import ( DATABASE_SQLITE, @@ -286,14 +285,6 @@ def search_engine_is_elasticsearch(self) -> bool: def search_engine_is_opensearch(self) -> bool: return self.search_engine == SEARCH_ENGINE_OPENSEARCH - def obfuscated_elasticsearch(self) -> str: - """Returns configured elasticsearch url obfuscating the provided password, if any""" - parsed = urlparse(self.elasticsearch) - if parsed.password: - return self.elasticsearch.replace(parsed.password, "XXXX") - - return self.elasticsearch - class Config: env_prefix = "ARGILLA_" diff --git a/argilla-server/tests/unit/conftest.py b/argilla-server/tests/unit/conftest.py index 08ec7ad34d..8b3cbd7926 100644 --- a/argilla-server/tests/unit/conftest.py +++ b/argilla-server/tests/unit/conftest.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib import uuid from typing import TYPE_CHECKING, Dict, Generator import pytest import pytest_asyncio +from httpx import AsyncClient +from opensearchpy import OpenSearch + from argilla_server import telemetry from argilla_server.api.routes import api_v1 from argilla_server.constants import API_KEY_HEADER_NAME, DEFAULT_API_KEY @@ -26,9 +28,6 @@ from argilla_server.search_engine import SearchEngine, get_search_engine from argilla_server.settings import settings from argilla_server.telemetry import TelemetryClient -from httpx import AsyncClient -from opensearchpy import OpenSearch - from tests.database import TestSession from tests.factories import AnnotatorFactory, OwnerFactory, UserFactory @@ -89,14 +88,17 @@ async def override_get_async_db(): async def override_get_search_engine(): yield mock_search_engine - mocker.patch( - "argilla_server._app._get_db_wrapper", - wraps=contextlib.asynccontextmanager(override_get_async_db), - ) + # TODO: Once the db and search engine are wrapped in high-level dependencies, this code should works. + # Commented for now. + # mocker.patch("argilla_server.database.get_async_db", wraps=override_get_async_db) + # mocker.patch("argilla_server.search_engine.get_search_engine", wraps=override_get_search_engine) - for api in [api_v1]: - api.dependency_overrides[get_async_db] = override_get_async_db - api.dependency_overrides[get_search_engine] = override_get_search_engine + api_v1.dependency_overrides.update( + { + get_async_db: override_get_async_db, + get_search_engine: override_get_search_engine, + } + ) async with AsyncClient(app=app, base_url="http://testserver") as async_client: yield async_client