diff --git a/application/data_access/entity_queries.py b/application/data_access/entity_queries.py index 9c150565..bad2e9ab 100644 --- a/application/data_access/entity_queries.py +++ b/application/data_access/entity_queries.py @@ -63,25 +63,27 @@ def get_entities(session, dataset: str, limit: int) -> List[EntityModel]: def get_entity_search(session: Session, parameters: dict): params = normalised_params(parameters) + count: int entities: list[EntityModel] - # Create a single base query - base_query = session.query(EntityOrm, func.count().over().label("total_count")) - - base_query = _apply_base_filters(base_query, params) - base_query = _apply_date_filters(base_query, params) - base_query = _apply_location_filters(session, base_query, params) - base_query = _apply_period_option_filter(base_query, params) - base_query = _apply_limit_and_pagination_filters(base_query, params) - - result = base_query.all() - if result: - count = result[0].total_count - else: - count = 0 - - entities = [entity_factory(row.EntityOrm) for row in result] - + # get count + subquery = session.query(EntityOrm.entity) + subquery = _apply_base_filters(subquery, params) + subquery = _apply_date_filters(subquery, params) + subquery = _apply_location_filters(session, subquery, params) + subquery = _apply_period_option_filter(subquery, params).subquery() + count_query = session.query(func.count()).select_from(subquery) + count = count_query.scalar() + # get entities + query_args = [EntityOrm] + query = session.query(*query_args) + query = _apply_base_filters(query, params) + query = _apply_date_filters(query, params) + query = _apply_location_filters(session, query, params) + query = _apply_period_option_filter(query, params) + query = _apply_limit_and_pagination_filters(query, params) + entities = query.all() + entities = [entity_factory(entity_orm) for entity_orm in entities] return {"params": params, "count": count, "entities": entities} @@ -156,11 +158,15 @@ def _apply_date_filters(query, params): def _apply_location_filters(session, query, params): point = get_point(params) + + geometry_is_valid = func.ST_IsValid(EntityOrm.geometry) + point_is_valid = func.ST_IsValid(EntityOrm.point) + if point is not None: query = query.filter( and_( EntityOrm.geometry.is_not(None), - func.ST_IsValid(EntityOrm.geometry), + geometry_is_valid, func.ST_Contains(EntityOrm.geometry, func.ST_GeomFromText(point, 4326)), ) ) @@ -171,24 +177,28 @@ def _apply_location_filters(session, query, params): clauses = [] for geometry in params.get("geometry", []): + simplified_geom = func.ST_Envelope(func.ST_GeomFromText(geometry, 4326)) clauses.append( or_( and_( EntityOrm.geometry.is_not(None), - func.ST_IsValid(EntityOrm.geometry), - spatial_function( - EntityOrm.geometry, func.ST_GeomFromText(geometry, 4326) - ), + geometry_is_valid, + EntityOrm.geometry.op("&&")( + simplified_geom + ), # Using && operator to trigger index + spatial_function(EntityOrm.geometry, simplified_geom), ), and_( EntityOrm.point.is_not(None), - func.ST_IsValid(EntityOrm.point), - spatial_function( - EntityOrm.point, func.ST_GeomFromText(geometry, 4326) - ), + point_is_valid, + EntityOrm.point.op("&&")( + simplified_geom + ), # Using && operator to trigger index + spatial_function(EntityOrm.point, simplified_geom), ), ) ) + if clauses: query = query.filter(or_(*clauses)) @@ -206,7 +216,7 @@ def _apply_location_filters(session, query, params): or_( and_( EntityOrm.geometry.is_not(None), - func.ST_IsValid(EntityOrm.geometry), + geometry_is_valid, func.ST_IsValid(intersecting_entities_query.c.geometry), spatial_function( EntityOrm.geometry, @@ -262,7 +272,7 @@ def _apply_location_filters(session, query, params): or_( and_( EntityOrm.geometry.is_not(None), - func.ST_IsValid(EntityOrm.geometry), + geometry_is_valid, func.ST_IsValid(curie_query.c.geometry), spatial_function(EntityOrm.geometry, curie_query.c.geometry), ),