diff --git a/arpav_ppcv/database.py b/arpav_ppcv/database.py index dea9a3d6..7e92f7c4 100644 --- a/arpav_ppcv/database.py +++ b/arpav_ppcv/database.py @@ -143,7 +143,10 @@ def create_station( """Create a new station.""" geom = shapely.io.from_geojson(station_create.geom.model_dump_json()) wkbelement = from_shape(geom) - db_station = observations.Station(geom=wkbelement, code=station_create.code) + db_station = observations.Station( + **station_create.model_dump(exclude={"geom"}), + geom=wkbelement, + ) session.add(db_station) try: session.commit() @@ -205,8 +208,10 @@ def update_station( station_update: observations.StationUpdate, ) -> observations.Station: """Update a station.""" - data_ = station_update.model_dump(exclude_unset=True) - for key, value in data_.items(): + geom = from_shape(shapely.io.from_geojson(station_update.geom.model_dump_json())) + other_data = station_update.model_dump(exclude={"geom"}, exclude_unset=True) + data = {**other_data, "geom": geom} + for key, value in data.items(): setattr(db_station, key, value) session.add(db_station) session.commit() diff --git a/arpav_ppcv/webapp/admin/app.py b/arpav_ppcv/webapp/admin/app.py index b6e48532..f627a982 100644 --- a/arpav_ppcv/webapp/admin/app.py +++ b/arpav_ppcv/webapp/admin/app.py @@ -5,18 +5,25 @@ from starlette.middleware.sessions import SessionMiddleware from starlette.exceptions import HTTPException from starlette_admin.contrib.sqlmodel import Admin -from starlette_admin.views import Link +from starlette_admin.views import ( + DropDown, + Link, +) from ... import ( config, database, ) -from ...schemas import coverages -from . import ( - auth, - views, +from ...schemas import ( + coverages, + observations, ) +from . import auth from .middlewares import SqlModelDbSessionMiddleware +from .views import ( + coverages as coverage_views, + observations as observations_views, +) logger = logging.getLogger(__name__) @@ -51,8 +58,31 @@ def create_admin(settings: config.ArpavPpcvSettings) -> ArpavPpcvAdmin: Middleware(SqlModelDbSessionMiddleware, engine=engine), ], ) - admin.add_view(views.ConfigurationParameterView(coverages.ConfigurationParameter)) - admin.add_view(views.CoverageConfigurationView(coverages.CoverageConfiguration)) + admin.add_view( + coverage_views.ConfigurationParameterView(coverages.ConfigurationParameter) + ) + admin.add_view( + coverage_views.CoverageConfigurationView(coverages.CoverageConfiguration) + ) + admin.add_view(observations_views.VariableView(observations.Variable)) + admin.add_view(observations_views.StationView(observations.Station)) + admin.add_view( + DropDown( + "Measurements", + icon="fa-solid fa-vials", + views=[ + observations_views.MonthlyMeasurementView( + observations.MonthlyMeasurement + ), + observations_views.SeasonalMeasurementView( + observations.SeasonalMeasurement + ), + observations_views.YearlyMeasurementView( + observations.YearlyMeasurement + ), + ], + ) + ) admin.add_view( Link( "V2 API docs", diff --git a/arpav_ppcv/webapp/admin/schemas.py b/arpav_ppcv/webapp/admin/schemas.py index 0c0bceb3..4b0be258 100644 --- a/arpav_ppcv/webapp/admin/schemas.py +++ b/arpav_ppcv/webapp/admin/schemas.py @@ -1,9 +1,13 @@ +import datetime as dt from typing import Optional import uuid import sqlmodel -from ...schemas.base import ObservationAggregationType +from ...schemas.base import ( + ObservationAggregationType, + Season, +) class ConfigurationParameterValueRead(sqlmodel.SQLModel): @@ -59,3 +63,44 @@ class ObservationVariableRead(sqlmodel.SQLModel): class CoverageConfigurationReadListItem(sqlmodel.SQLModel): id: uuid.UUID name: str + + +class VariableRead(sqlmodel.SQLModel): + id: uuid.UUID + name: str + description: str + unit: Optional[str] + + +class StationRead(sqlmodel.SQLModel): + id: uuid.UUID + name: str + code: str + type: str + longitude: float + latitude: float + active_since: Optional[dt.date] + active_until: Optional[dt.date] + altitude_m: Optional[float] + + +class MonthlyMeasurementRead(sqlmodel.SQLModel): + station: str + variable: str + date: dt.date + value: float + + +class SeasonalMeasurementRead(sqlmodel.SQLModel): + station: str + variable: str + year: int + season: Season + value: float + + +class YearlyMeasurementRead(sqlmodel.SQLModel): + station: str + variable: str + year: int + value: float diff --git a/arpav_ppcv/webapp/admin/views/__init__.py b/arpav_ppcv/webapp/admin/views/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/arpav_ppcv/webapp/admin/views.py b/arpav_ppcv/webapp/admin/views/coverages.py similarity index 98% rename from arpav_ppcv/webapp/admin/views.py rename to arpav_ppcv/webapp/admin/views/coverages.py index 4770bbe6..53dc61c8 100644 --- a/arpav_ppcv/webapp/admin/views.py +++ b/arpav_ppcv/webapp/admin/views/coverages.py @@ -1,4 +1,4 @@ -"""Views for the admin app. +"""Views for the admin app's coverages. The classes contained in this module are derived from starlette_admin.contrib.sqlmodel.ModelView. This is done mostly for two reasons: @@ -23,12 +23,12 @@ from starlette.requests import Request from starlette_admin.contrib.sqlmodel import ModelView -from ... import database -from ...schemas import ( +from .... import database +from ....schemas import ( coverages, base, ) -from . import ( +from .. import ( fields, schemas as read_schemas, ) @@ -116,6 +116,10 @@ class ConfigurationParameterView(ModelView): ), ) + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.icon = "fa-solid fa-quote-left" + async def get_pk_value(self, request: Request, obj: Any) -> Any: # note: we need to cast the value, which is a uuid.UUID, to a string # because starlette_admin just assumes that the value of a model's @@ -313,6 +317,10 @@ class CoverageConfigurationView(ModelView): ) exclude_fields_from_edit = ("coverage_id_pattern",) + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.icon = "fa-solid fa-map" + async def get_pk_value(self, request: Request, obj: Any) -> Any: # note: we need to cast the value, which is a uuid.UUID, to a string # because starlette_admin just assumes that the value of a model's diff --git a/arpav_ppcv/webapp/admin/views/observations.py b/arpav_ppcv/webapp/admin/views/observations.py new file mode 100644 index 00000000..572765b7 --- /dev/null +++ b/arpav_ppcv/webapp/admin/views/observations.py @@ -0,0 +1,443 @@ +import functools +import logging +from typing import ( + Any, + Optional, + Sequence, + Union, +) + +import anyio +from geoalchemy2.shape import from_shape +import geojson_pydantic +import shapely.io +import starlette_admin +from starlette.requests import Request +from starlette_admin.contrib.sqlmodel import ModelView +from starlette_admin.exceptions import FormValidationError + +from .... import database as db +from ....schemas import ( + base, + observations, +) +from .. import fields +from .. import schemas as read_schemas + +logger = logging.getLogger(__name__) + + +class MonthlyMeasurementView(ModelView): + identity = "monthly measurements" + name = "Monthly Measurements" + label = "Monthly Measurements" + pk_attr = "id" + + fields = ( + starlette_admin.StringField("station", required=True), + starlette_admin.StringField("variable", required=True), + starlette_admin.DateField("date", required=True), + starlette_admin.FloatField("value", required=True), + ) + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.icon = "fa-regular fa-calendar-days" + + def can_create(self, request: Request) -> bool: + return False + + def can_edit(self, request: Request) -> bool: + return False + + def can_view_details(self, request: Request) -> bool: + return False + + @staticmethod + def _serialize_instance( + instance: observations.MonthlyMeasurement, + ) -> read_schemas.MonthlyMeasurementRead: + return read_schemas.MonthlyMeasurementRead( + **instance.model_dump(), + station=instance.station.code, + variable=instance.variable.name, + ) + + async def find_all( + self, + request: Request, + skip: int = 0, + limit: int = 100, + where: Union[dict[str, Any], str, None] = None, + order_by: Optional[list[str]] = None, + ) -> Sequence[read_schemas.MonthlyMeasurementRead]: + list_measurements = functools.partial( + db.list_monthly_measurements, + limit=limit, + offset=skip, + include_total=False, + ) + db_measurements, _ = await anyio.to_thread.run_sync( + list_measurements, request.state.session + ) + return [self._serialize_instance(item) for item in db_measurements] + + +class SeasonalMeasurementView(ModelView): + identity = "seasonal measurements" + name = "Seasonal Measurements" + label = "Seasonal Measurements" + icon = "fa fa-blog" + pk_attr = "id" + + fields = ( + starlette_admin.StringField("station", required=True), + starlette_admin.StringField("variable", required=True), + starlette_admin.IntegerField("year", required=True), + starlette_admin.EnumField("season", enum=base.Season, required=True), + starlette_admin.FloatField("value", required=True), + ) + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.icon = "fa-regular fa-calendar-days" + + def can_create(self, request: Request) -> bool: + return False + + def can_edit(self, request: Request) -> bool: + return False + + def can_view_details(self, request: Request) -> bool: + return False + + @staticmethod + def _serialize_instance( + instance: observations.SeasonalMeasurement, + ) -> read_schemas.SeasonalMeasurementRead: + return read_schemas.SeasonalMeasurementRead( + **instance.model_dump(), + station=instance.station.code, + variable=instance.variable.name, + ) + + async def find_all( + self, + request: Request, + skip: int = 0, + limit: int = 100, + where: Union[dict[str, Any], str, None] = None, + order_by: Optional[list[str]] = None, + ) -> Sequence[read_schemas.SeasonalMeasurementRead]: + list_measurements = functools.partial( + db.list_seasonal_measurements, + limit=limit, + offset=skip, + include_total=False, + ) + db_measurements, _ = await anyio.to_thread.run_sync( + list_measurements, request.state.session + ) + return [self._serialize_instance(item) for item in db_measurements] + + +class YearlyMeasurementView(ModelView): + identity = "yearly measurements" + name = "Yearly Measurements" + label = "Yearly Measurements" + icon = "fa fa-blog" + pk_attr = "id" + + fields = ( + starlette_admin.StringField("station", required=True), + starlette_admin.StringField("variable", required=True), + starlette_admin.IntegerField("year", required=True), + starlette_admin.FloatField("value", required=True), + ) + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.icon = "fa-regular fa-calendar-days" + + def can_create(self, request: Request) -> bool: + return False + + def can_edit(self, request: Request) -> bool: + return False + + def can_view_details(self, request: Request) -> bool: + return False + + @staticmethod + def _serialize_instance( + instance: observations.YearlyMeasurement, + ) -> read_schemas.YearlyMeasurementRead: + return read_schemas.YearlyMeasurementRead( + **instance.model_dump(), + station=instance.station.code, + variable=instance.variable.name, + ) + + async def find_all( + self, + request: Request, + skip: int = 0, + limit: int = 100, + where: Union[dict[str, Any], str, None] = None, + order_by: Optional[list[str]] = None, + ) -> Sequence[read_schemas.YearlyMeasurementRead]: + list_measurements = functools.partial( + db.list_yearly_measurements, + limit=limit, + offset=skip, + include_total=False, + ) + db_measurements, _ = await anyio.to_thread.run_sync( + list_measurements, request.state.session + ) + return [self._serialize_instance(item) for item in db_measurements] + + +class VariableView(ModelView): + identity = "variables" + name = "Variable" + label = "Variables" + icon = "fa fa-blog" + pk_attr = "id" + + exclude_fields_from_list = ("id",) + exclude_fields_from_detail = ("id",) + + fields = ( + fields.UuidField("id"), + starlette_admin.StringField("name", required=True), + starlette_admin.StringField("description", required=True), + starlette_admin.StringField("unit"), + ) + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.icon = "fa-solid fa-cloud-sun-rain" + + @staticmethod + def _serialize_instance( + instance: observations.Variable, + ) -> read_schemas.VariableRead: + return read_schemas.VariableRead(**instance.model_dump()) + + async def get_pk_value(self, request: Request, obj: Any) -> str: + # note: we need to cast the value, which is a uuid.UUID, to a string + # because starlette_admin just assumes that the value of a model's + # pk attribute is always JSON serializable so it doesn't bother with + # calling the respective field's `serialize_value()` method + result = await super().get_pk_value(request, obj) + return str(result) + + async def create( + self, request: Request, data: dict[str, Any] + ) -> Optional[read_schemas.VariableRead]: + try: + data = await self._arrange_data(request, data) + await self.validate(request, data) + var_create = observations.VariableCreate(**data) + db_variable = await anyio.to_thread.run_sync( + db.create_variable, + request.state.session, + var_create, + ) + return self._serialize_instance(db_variable) + except Exception as e: + return self.handle_exception(e) + + async def edit( + self, request: Request, pk: Any, data: dict[str, Any] + ) -> Optional[read_schemas.VariableRead]: + try: + data = await self._arrange_data(request, data, True) + await self.validate(request, data) + var_update = observations.VariableUpdate(**data) + db_var = await anyio.to_thread.run_sync( + db.get_variable, request.state.session, pk + ) + db_var = await anyio.to_thread.run_sync( + db.update_variable, request.state.session, db_var, var_update + ) + return self._serialize_instance(db_var) + except Exception as e: + logger.exception("something went wrong") + self.handle_exception(e) + + async def find_by_pk(self, request: Request, pk: Any) -> read_schemas.VariableRead: + db_var = await anyio.to_thread.run_sync( + db.get_variable, request.state.session, pk + ) + return self._serialize_instance(db_var) + + async def find_all( + self, + request: Request, + skip: int = 0, + limit: int = 100, + where: Union[dict[str, Any], str, None] = None, + order_by: Optional[list[str]] = None, + ) -> Sequence[read_schemas.VariableRead]: + list_variables = functools.partial( + db.list_variables, + limit=limit, + offset=skip, + include_total=False, + ) + db_vars, _ = await anyio.to_thread.run_sync( + list_variables, request.state.session + ) + return [self._serialize_instance(db_var) for db_var in db_vars] + + +class StationView(ModelView): + identity = "stations" + name = "Station" + label = "Stations" + icon = "fa fa-blog" + pk_attr = "id" + + exclude_fields_from_list = ("id",) + exclude_fields_from_detail = ("id",) + + fields = ( + fields.UuidField("id"), + starlette_admin.StringField("name", required=True), + starlette_admin.StringField("code", required=True), + starlette_admin.StringField("type", required=True), + starlette_admin.FloatField("longitude", required=True), + starlette_admin.FloatField("latitude", required=True), + starlette_admin.DateField("active_since"), + starlette_admin.DateField("active_until"), + starlette_admin.FloatField("altitude_m"), + ) + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.icon = "fa-solid fa-tower-observation" + + @staticmethod + def _serialize_instance(instance: observations.Station) -> read_schemas.StationRead: + geom = shapely.io.from_wkb(bytes(instance.geom.data)) + return read_schemas.StationRead( + **instance.model_dump(exclude={"geom", "type_"}), + type=instance.type_, + longitude=geom.x, + latitude=geom.y, + ) + + async def get_pk_value(self, request: Request, obj: Any) -> str: + # note: we need to cast the value, which is a uuid.UUID, to a string + # because starlette_admin just assumes that the value of a model's + # pk attribute is always JSON serializable so it doesn't bother with + # calling the respective field's `serialize_value()` method + result = await super().get_pk_value(request, obj) + return str(result) + + async def validate(self, request: Request, data: dict[str, Any]) -> None: + """Validate data without file fields relation fields""" + errors: dict[str, str] = {} + if (lat := data["latitude"]) < -90 or lat > 90: + errors["latitude"] = "Invalid value" + if (lon := data["longitude"]) < -180 or lon > 180: + errors["longitude"] = "Invalid longitude" + if len(errors) > 0: + raise FormValidationError(errors) + else: + data_to_validate = data.copy() + data_to_validate["geom"] = from_shape(shapely.Point(lon, lat)) + del data_to_validate["longitude"] + del data_to_validate["latitude"] + fields_to_exclude = [ + f.name + for f in self.get_fields_list(request, request.state.action) + if isinstance( + f, (starlette_admin.FileField, starlette_admin.RelationField) + ) + ] + ["latitude", "longitude"] + self.model.validate( + { + k: v + for k, v in data_to_validate.items() + if k not in fields_to_exclude + } + ) + + async def create( + self, request: Request, data: dict[str, Any] + ) -> Optional[read_schemas.StationRead]: + try: + data = await self._arrange_data(request, data) + await self.validate(request, data) + geojson_geom = geojson_pydantic.Point( + type="Point", coordinates=(data.pop("longitude"), data.pop("latitude")) + ) + station_create = observations.StationCreate( + type_=data.pop("type"), + geom=geojson_geom, + **data, + ) + db_station = await anyio.to_thread.run_sync( + db.create_station, + request.state.session, + station_create, + ) + return self._serialize_instance(db_station) + except Exception as e: + logger.exception("could not create") + return self.handle_exception(e) + + async def edit( + self, request: Request, pk: Any, data: dict[str, Any] + ) -> Optional[read_schemas.StationRead]: + try: + data = await self._arrange_data(request, data, True) + await self.validate(request, data) + lon = data.pop("longitude", None) + lat = data.pop("latitude", None) + kwargs = {} + if all((lon, lat)): + kwargs["geom"] = geojson_pydantic.Point( + type="Point", coordinates=(lon, lat) + ) + if (type_ := data.pop("type", None)) is not None: + kwargs["type_"] = type_ + station_update = observations.StationUpdate(**data, **kwargs) + db_station = await anyio.to_thread.run_sync( + db.get_station, request.state.session, pk + ) + db_station = await anyio.to_thread.run_sync( + db.update_station, request.state.session, db_station, station_update + ) + return self._serialize_instance(db_station) + except Exception as e: + logger.exception("something went wrong") + self.handle_exception(e) + + async def find_by_pk(self, request: Request, pk: Any) -> read_schemas.StationRead: + db_station = await anyio.to_thread.run_sync( + db.get_station, request.state.session, pk + ) + return self._serialize_instance(db_station) + + async def find_all( + self, + request: Request, + skip: int = 0, + limit: int = 100, + where: Union[dict[str, Any], str, None] = None, + order_by: Optional[list[str]] = None, + ) -> Sequence[read_schemas.StationRead]: + list_stations = functools.partial( + db.list_stations, + limit=limit, + offset=skip, + include_total=False, + ) + db_stations, _ = await anyio.to_thread.run_sync( + list_stations, request.state.session + ) + return [self._serialize_instance(db_station) for db_station in db_stations] diff --git a/tests/notebooks/generic.ipynb b/tests/notebooks/generic.ipynb index bd99633b..5e5f4b4a 100644 --- a/tests/notebooks/generic.ipynb +++ b/tests/notebooks/generic.ipynb @@ -52,21 +52,21 @@ }, { "cell_type": "code", - "execution_count": 6, - "id": "b09d371f-2f1f-44b4-966a-e33808e0d9fb", + "execution_count": 8, + "id": "dc76a07a-af67-43bf-b66f-c6c7310416d9", "metadata": {}, "source": [ - "station.seasonal_variables" + "station.model_dump()" ], "outputs": [] }, { "cell_type": "code", - "execution_count": 7, - "id": "4eb9b279-9f7a-4031-a9c3-ccc15d32ec09", + "execution_count": 9, + "id": "2753c89b-2aa2-40ea-aa64-a30fc3b7f359", "metadata": {}, "source": [ - "db.collect_station_variables(session, station, ObservationAggregationType.YEARLY)" + "float(\"b\")" ], "outputs": [] }