Skip to content

Commit

Permalink
[REFACTOR] argilla-server: review server startup (#5263)
Browse files Browse the repository at this point in the history
# Description
<!-- Please include a summary of the changes and the related issue.
Please also include relevant motivation and context. List any
dependencies that are required for this change. -->

This PR reviews the server startup in two different ways:

1. Run the simpler setup functions out of the startup event listeners.
2. Use the
[lifespan](https://fastapi.tiangolo.com/advanced/events/#lifespan)
function instead of deprecated startup/shutdown listeners

**Type of change**
<!-- Please delete options that are not relevant. Remember to title the
PR according to the type of change -->

- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)
- Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- Refactor (change restructuring the codebase without changing
functionality)
- Improvement (change adding some improvement to an existing
functionality)
- Documentation update

**How Has This Been Tested**
<!-- Please add some reference about how your feature has been tested.
-->

**Checklist**
<!-- Please go over the list and make sure you've taken everything into
account -->

- I added relevant documentation
- I followed the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm My changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature
works
- I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---------

Co-authored-by: Francisco Aranda <[email protected]>
  • Loading branch information
frascuchon and frascuchon authored Jul 18, 2024
1 parent d857834 commit 934c08b
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 118 deletions.
163 changes: 69 additions & 94 deletions argilla-server/src/argilla_server/_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import backoff
from brotli_asgi import BrotliMiddleware
from fastapi import FastAPI, Request
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse

Expand All @@ -37,13 +38,20 @@
from argilla_server.logging import configure_logging
from argilla_server.models import User
from argilla_server.search_engine import get_search_engine
from argilla_server.security import auth
from argilla_server.settings import settings
from argilla_server.static_rewrite import RewriteStaticFiles

_LOGGER = logging.getLogger("argilla")


@contextlib.asynccontextmanager
async def app_lifespan(app: FastAPI):
# See https://fastapi.tiangolo.com/advanced/events/#lifespan
await configure_database()
await configure_search_engine()
yield


def create_server_app() -> FastAPI:
"""Configure the argilla server"""

Expand All @@ -54,29 +62,17 @@ def create_server_app() -> FastAPI:
redoc_url=None,
redirect_slashes=False,
version=str(argilla_version),
lifespan=app_lifespan,
)

@app.get("/docs", include_in_schema=False)
async def redirect_docs():
return RedirectResponse(url=f"{settings.base_url}api/v1/docs")
configure_logging()
configure_telemetry()
configure_middleware(app)
configure_api_router(app)
configure_app_statics(app)
configure_api_docs(app)

@app.get("/api", include_in_schema=False)
async def redirect_api():
return RedirectResponse(url=f"{settings.base_url}api/v1/docs")

for app_configure in [
configure_app_logging,
configure_database,
configure_search_engine,
configure_telemetry,
configure_middleware,
configure_app_security,
configure_api_router,
configure_app_statics,
]:
app_configure(app)

# This if-else clause is needed to simplify the test dependencies setup. Otherwise we cannot override dependencies
# This if-else clause is needed to simplify the test dependency setup. Otherwise, we cannot override dependencies
# easily. We can review this once we have separate fastapi application for the api and the webapp.
if settings.base_url and settings.base_url != "/":
_app = FastAPI(
Expand All @@ -88,6 +84,16 @@ async def redirect_api():
return app


def configure_api_docs(app: FastAPI):
@app.get("/docs", include_in_schema=False)
async def redirect_docs():
return RedirectResponse(url=f"{settings.base_url}api/v1/docs")

@app.get("/api", include_in_schema=False)
async def redirect_api():
return RedirectResponse(url=f"{settings.base_url}api/v1/docs")


def configure_middleware(app: FastAPI):
"""Configures fastapi middleware"""

Expand Down Expand Up @@ -161,94 +167,63 @@ def _create_statics_folder(path_from):
)


def configure_search_engine(app: FastAPI):
@app.on_event("startup")
async def configure_elasticsearch():
if not settings.search_engine_is_elasticsearch:
return
def configure_telemetry():
message = "\n"
message += inspect.cleandoc(
"Argilla uses telemetry to report anonymous usage and error information. You\n"
"can know more about what information is reported at:\n\n"
" https://docs.argilla.io/en/latest/reference/telemetry.html\n\n"
"Telemetry is currently enabled. If you want to disable it, you can configure\n"
"the environment variable before relaunching the server:\n\n"
f'{"#set ARGILLA_ENABLE_TELEMETRY=0" if os.name == "nt" else "$>export ARGILLA_ENABLE_TELEMETRY=0"}'
)

if settings.enable_telemetry:
_LOGGER.warning(message)


async def configure_database():
async def check_default_user(db: AsyncSession):
def _user_has_default_credentials(user: User):
return user.api_key == DEFAULT_API_KEY or accounts.verify_password(
DEFAULT_PASSWORD, user.password_hash
)

default_user = await accounts.get_user_by_username(db, DEFAULT_USERNAME)
if default_user and _user_has_default_credentials(default_user):
_LOGGER.warning(
f"User {DEFAULT_USERNAME!r} with default credentials has been found in the database. "
"If you are using argilla in a production environment this can be a serious security problem. "
f"We recommend that you create a new admin user and then delete the default {DEFAULT_USERNAME!r} one."
)

async with contextlib.asynccontextmanager(get_async_db)() as db:
await check_default_user(db)


async def configure_search_engine():
if settings.search_engine_is_elasticsearch:
# TODO: Move this to the search engine implementation module
logging.getLogger("elasticsearch").setLevel(logging.ERROR)
logging.getLogger("elastic_transport").setLevel(logging.ERROR)

@app.on_event("startup")
async def configure_opensearch():
if not settings.search_engine_is_opensearch:
return

elif settings.search_engine_is_opensearch:
# TODO: Move this to the search engine implementation module
logging.getLogger("opensearch").setLevel(logging.ERROR)
logging.getLogger("opensearch_transport").setLevel(logging.ERROR)

@app.on_event("startup")
@backoff.on_exception(backoff.expo, ConnectionError, max_time=60)
async def ping_search_engine():
async for search_engine in get_search_engine():
if not await search_engine.ping():
raise ConnectionError(
f"Your {settings.search_engine} endpoint at {settings.obfuscated_elasticsearch()} is not available or not responding.\n"
f"Your {settings.search_engine} is not available or not responding.\n"
f"Please make sure your {settings.search_engine} instance is launched and correctly running and\n"
"you have the necessary access permissions. Once you have verified this, restart the argilla server.\n"
"you have the necessary access permissions. Once you have verified this, restart "
"the argilla server.\n"
)


def configure_app_security(app: FastAPI):
auth.configure_app(app)


def configure_app_logging(app: FastAPI):
"""Configure app logging using"""
app.on_event("startup")(configure_logging)


def configure_telemetry(app: FastAPI):
message = "\n"
message += inspect.cleandoc(
"""
Argilla uses telemetry to report anonymous usage and error information.
You can know more about what information is reported at:
https://docs.argilla.io/en/latest/reference/telemetry.html
Telemetry is currently enabled. If you want to disable it, you can configure
the environment variable before relaunching the server:
"""
)
message += "\n\n "
message += (
"#set ARGILLA_ENABLE_TELEMETRY=0"
if os.name == "nt"
else "$>export ARGILLA_ENABLE_TELEMETRY=0"
)
message += "\n"

@app.on_event("startup")
async def check_telemetry():
if settings.enable_telemetry:
print(message, flush=True)


_get_db_wrapper = contextlib.asynccontextmanager(get_async_db)


def configure_database(app: FastAPI):
def _user_has_default_credentials(user: User):
return user.api_key == DEFAULT_API_KEY or accounts.verify_password(
DEFAULT_PASSWORD, user.password_hash
)

def _log_default_user_warning():
_LOGGER.warning(
f"User {DEFAULT_USERNAME!r} with default credentials has been found in the database. "
"If you are using argilla in a production environment this can be a serious security problem. "
f"We recommend that you create a new admin user and then delete the default {DEFAULT_USERNAME!r} one."
)

@app.on_event("startup")
async def log_default_user_warning_if_present():
async with _get_db_wrapper() as db:
default_user = await accounts.get_user_by_username(db, DEFAULT_USERNAME)
if default_user and _user_has_default_credentials(default_user):
_log_default_user_warning()
await ping_search_engine()


app = create_server_app()
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from typing import ClassVar, List, Optional

from fastapi import Depends, FastAPI
from fastapi import Depends
from fastapi.security import SecurityScopes
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.authentication import AuthenticationBackend
Expand Down Expand Up @@ -43,9 +43,6 @@ class AuthenticationProvider:
def new_instance(cls):
return AuthenticationProvider()

def configure_app(self, app: FastAPI) -> None:
pass

async def get_current_user(
self,
security_scopes: SecurityScopes, # noqa
Expand Down
9 changes: 0 additions & 9 deletions argilla-server/src/argilla_server/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import warnings
from pathlib import Path
from typing import Dict, List, Optional
from urllib.parse import urlparse

from argilla_server.constants import (
DATABASE_SQLITE,
Expand Down Expand Up @@ -286,14 +285,6 @@ def search_engine_is_elasticsearch(self) -> bool:
def search_engine_is_opensearch(self) -> bool:
return self.search_engine == SEARCH_ENGINE_OPENSEARCH

def obfuscated_elasticsearch(self) -> str:
"""Returns configured elasticsearch url obfuscating the provided password, if any"""
parsed = urlparse(self.elasticsearch)
if parsed.password:
return self.elasticsearch.replace(parsed.password, "XXXX")

return self.elasticsearch

class Config:
env_prefix = "ARGILLA_"

Expand Down
24 changes: 13 additions & 11 deletions argilla-server/tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import uuid
from typing import TYPE_CHECKING, Dict, Generator

import pytest
import pytest_asyncio
from httpx import AsyncClient
from opensearchpy import OpenSearch

from argilla_server import telemetry
from argilla_server.api.routes import api_v1
from argilla_server.constants import API_KEY_HEADER_NAME, DEFAULT_API_KEY
Expand All @@ -26,9 +28,6 @@
from argilla_server.search_engine import SearchEngine, get_search_engine
from argilla_server.settings import settings
from argilla_server.telemetry import TelemetryClient
from httpx import AsyncClient
from opensearchpy import OpenSearch

from tests.database import TestSession
from tests.factories import AnnotatorFactory, OwnerFactory, UserFactory

Expand Down Expand Up @@ -89,14 +88,17 @@ async def override_get_async_db():
async def override_get_search_engine():
yield mock_search_engine

mocker.patch(
"argilla_server._app._get_db_wrapper",
wraps=contextlib.asynccontextmanager(override_get_async_db),
)
# TODO: Once the db and search engine are wrapped in high-level dependencies, this code should works.
# Commented for now.
# mocker.patch("argilla_server.database.get_async_db", wraps=override_get_async_db)
# mocker.patch("argilla_server.search_engine.get_search_engine", wraps=override_get_search_engine)

for api in [api_v1]:
api.dependency_overrides[get_async_db] = override_get_async_db
api.dependency_overrides[get_search_engine] = override_get_search_engine
api_v1.dependency_overrides.update(
{
get_async_db: override_get_async_db,
get_search_engine: override_get_search_engine,
}
)

async with AsyncClient(app=app, base_url="http://testserver") as async_client:
yield async_client
Expand Down

0 comments on commit 934c08b

Please sign in to comment.