Skip to content

Commit

Permalink
Adding caching
Browse files Browse the repository at this point in the history
  • Loading branch information
ssadhu-sl committed Sep 24, 2024
1 parent 7d0decc commit ed84ef8
Show file tree
Hide file tree
Showing 8 changed files with 333 additions and 76 deletions.
113 changes: 94 additions & 19 deletions application/data_access/entity_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,31 @@
)
from application.db.models import EntityOrm, OldEntityOrm
from application.search.enum import GeometryRelation, PeriodOption
import redis
import json

redis_client = redis.Redis(host="127.0.0.1", port=6379, db=0)
logger = logging.getLogger(__name__)


def get_cached_query_result(key):
    """Fetch a previously cached search result from Redis.

    Returns the JSON-decoded payload stored under ``key``, or ``None``
    when there is no cached entry (cache miss or empty payload).
    """
    payload = redis_client.get(key)
    if not payload:
        return None
    return json.loads(payload)


def cache_query_result(key, result):
    """Store ``result`` in Redis under ``key`` as a JSON string.

    Objects exposing a ``dict()`` method (e.g. pydantic models) are
    serialised via that method; any other non-JSON-native object falls
    back to ``str()``. Entries expire after one hour.
    """

    def _encode(obj):
        # pydantic-style models serialise themselves; everything else is stringified
        return obj.dict() if hasattr(obj, "dict") else str(obj)

    redis_client.set(key, json.dumps(result, default=_encode), ex=3600)  # 1-hour TTL for now


def get_entity_query(
session: Session,
id: int,
Expand Down Expand Up @@ -61,18 +82,61 @@ def get_entities(session, dataset: str, limit: int) -> List[EntityModel]:


def get_entity_search(session: Session, parameters: dict):
    """Run an entity search, caching results in Redis.

    The combined result for the full (normalised) parameter set is cached
    under one key; when the search spans several datasets, each dataset's
    partial result is additionally cached under its own key so later
    searches can reuse it.

    Returns a dict with ``params`` (the normalised parameters), ``count``
    (total matching entities) and ``entities`` (the fetched models).
    """
    original_params = normalised_params(parameters)
    combined_key = str(original_params)

    # Whole-query cache hit: nothing to do.
    cached_result = get_cached_query_result(combined_key)
    if cached_result:
        return cached_result

    datasets = original_params.get("dataset", [])  # assumed to be a list of datasets
    combined_count = 0
    combined_entities = []

    if not datasets:
        # No dataset filter: a single count + fetch over all datasets.
        combined_count = count_entity_method(session, original_params)
        combined_entities = fetch_all_entities(session, original_params)

    for dataset in datasets:
        # Query one dataset at a time so each partial result is cacheable.
        # Copy the params dict: the previous implementation aliased
        # original_params and overwrote its "dataset" list with a single
        # value, corrupting the combined cached result's params.
        params = dict(original_params)
        params["dataset"] = dataset

        query_key = str(params)
        dataset_cached = get_cached_query_result(query_key)
        if dataset_cached:
            combined_count += dataset_cached["count"]
            combined_entities.extend(dataset_cached["entities"])
            continue  # Skip the query if the result is cached

        count = count_entity_method(session, params)
        entities = fetch_all_entities(session, params)

        combined_count += count
        combined_entities.extend(entities)

        # Cache the individual dataset result for reuse.
        cache_query_result(
            query_key, {"params": params, "count": count, "entities": entities}
        )

    final_result = {
        "params": original_params,
        "count": combined_count,
        "entities": combined_entities,
    }
    cache_query_result(combined_key, final_result)
    return final_result


def fetch_all_entities(session, params):
query_args = [EntityOrm]
query = session.query(*query_args)
query = _apply_base_filters(query, params)
Expand All @@ -82,7 +146,18 @@ def get_entity_search(session: Session, parameters: dict):
query = _apply_limit_and_pagination_filters(query, params)
entities = query.all()
entities = [entity_factory(entity_orm) for entity_orm in entities]
return {"params": params, "count": count, "entities": entities}
return entities


def count_entity_method(session, params):
    """Count the entities matching ``params``.

    Builds the same filtered query as the search, wraps it as a subquery
    and returns ``COUNT(*)`` over it.
    """
    filtered = _apply_base_filters(session.query(EntityOrm.entity), params)
    filtered = _apply_date_filters(filtered, params)
    filtered = _apply_location_filters(session, filtered, params)
    inner = _apply_period_option_filter(filtered, params).subquery()
    return session.query(func.count()).select_from(inner).scalar()


def lookup_entity_link(
Expand Down Expand Up @@ -176,14 +251,14 @@ def _apply_location_filters(session, query, params):

for geometry in params.get("geometry", []):
simplified_geom = func.ST_GeomFromText(geometry, 4326)
bbox_filter_geometry = func.ST_Envelope(simplified_geom).op("&&")(
EntityOrm.geometry
)
bbox_filter_point = func.ST_Envelope(simplified_geom).op("&&")(
EntityOrm.geometry
)

if params.get("geometry_relation") == GeometryRelation.intersects.name:
bbox_filter_geometry = func.ST_Envelope(simplified_geom).op("&&")(
EntityOrm.geometry
)
bbox_filter_point = func.ST_Envelope(simplified_geom).op("&&")(
EntityOrm.point
)
clauses.append(
or_(
and_(
Expand Down
31 changes: 22 additions & 9 deletions application/routers/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,21 +60,34 @@ def _get_geojson(


def _get_entity_json(
data: List[EntityModel],
include: Optional[Set] = None,
data: List[Union[EntityModel, dict]], # Allow both EntityModel and dict
include: Optional[Set[str]] = None,
exclude: Optional[List[str]] = None,
):
entities = []
for entity in data:
if include is not None:
# always return at least the entity (id)
include.add("entity")
e = entity.dict(include=include, by_alias=True)

if isinstance(entity, dict):
# Handle the entity as a dict
if include is not None:
include.add("entity")
e = {key: value for key, value in entity.items() if key in include}
else:
exclude = set(exclude) if exclude else set()
exclude.add("geojson") # Always exclude 'geojson'
e = {key: value for key, value in entity.items() if key not in exclude}
else:
exclude = set(exclude) if exclude else set()
exclude.add("geojson") # Always exclude 'geojson'
e = entity.dict(exclude=exclude, by_alias=True)
if include is not None:
# Always return at least the entity (id)
include.add("entity")
e = entity.dict(include=include, by_alias=True)
else:
exclude = set(exclude) if exclude else set()
exclude.add("geojson") # Always exclude 'geojson'
e = entity.dict(exclude=exclude, by_alias=True)

entities.append(e)

return entities


Expand Down
33 changes: 20 additions & 13 deletions tests/acceptance/parameters/test_exclude_field_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
[["organisation-entity"], ["entry-date", "prefix"]],
)
def test_entity_search_exclude_field_for_json_response(
client,
db_session,
test_data: dict,
exclude: list,
client, db_session, test_data: dict, exclude: list, mocker
):
mocker.patch(
"application.data_access.entity_queries.get_cached_query_result",
return_value=None,
)
# Clear the table to avoid duplicates
db_session.query(EntityOrm).delete()
db_session.commit()
Expand Down Expand Up @@ -58,11 +59,12 @@ def test_entity_search_exclude_field_for_json_response(
],
)
def test_entity_search_no_exclude_field_for_json_response(
client,
db_session,
test_data: dict,
expected_fields: set,
client, db_session, test_data: dict, expected_fields: set, mocker
):
mocker.patch(
"application.data_access.entity_queries.get_cached_query_result",
return_value=None,
)
db_session.query(EntityOrm).delete()
db_session.commit()
# Load entities into the database
Expand Down Expand Up @@ -90,8 +92,12 @@ def test_entity_search_no_exclude_field_for_json_response(

@pytest.mark.parametrize("exclude", [["notes"], ["geometry", "notes"]])
def test_entity_search_exclude_field_for_geojson_response(
client, db_session, test_data: dict, exclude: list
client, db_session, test_data: dict, exclude: list, mocker
):
mocker.patch(
"application.data_access.entity_queries.get_cached_query_result",
return_value=None,
)
db_session.query(EntityOrm).delete()
db_session.commit()
# Load entities into the database
Expand Down Expand Up @@ -138,11 +144,12 @@ def test_entity_search_exclude_field_for_geojson_response(
],
)
def test_entity_search_no_exclude_field_for_geojson_response(
client,
db_session,
test_data: dict,
expected_properties: set,
client, db_session, test_data: dict, expected_properties: set, mocker
):
mocker.patch(
"application.data_access.entity_queries.get_cached_query_result",
return_value=None,
)
db_session.query(EntityOrm).delete()
db_session.commit()
# Load entities into the database
Expand Down
8 changes: 6 additions & 2 deletions tests/integration/data_access/test_entity_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from application.db.models import EntityOrm
from application.data_access.entity_queries import get_entity_search


# set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -73,19 +74,22 @@
],
)
def test_get_entity_search_geometry_reference_queries_returns_correct_results(
entities, parameters, expected_count, expected_entities, db_session
entities, parameters, expected_count, expected_entities, db_session, mocker
):
"""
A test to check if the correct results are returned when using the geometry_reference parameter
"""
mocker.patch(
"application.data_access.entity_queries.get_cached_query_result",
return_value=None,
)
# load data points into entity table
# add datasets
for entity in entities:
db_session.add(EntityOrm(**entity))

# run query and get results
results = get_entity_search(db_session, parameters)

# assert count
assert results["count"] == expected_count, results

Expand Down
26 changes: 22 additions & 4 deletions tests/integration/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ def _transform_dataset_to_response(dataset, is_geojson=False):
return dataset


def test_app_returns_valid_geojson_list(client):
def test_app_returns_valid_geojson_list(client, mocker):
mocker.patch(
"application.data_access.entity_queries.get_cached_query_result",
return_value=None,
)
response = client.get("/entity.geojson", headers={"Origin": "localhost"})
data = response.json()
assert "type" in data
Expand All @@ -51,7 +55,11 @@ def test_app_returns_valid_geojson_list(client):
assert [] == data["features"]


def test_app_returns_valid_populated_geojson_list(client, test_data):
def test_app_returns_valid_populated_geojson_list(client, test_data, mocker):
mocker.patch(
"application.data_access.entity_queries.get_cached_query_result",
return_value=None,
)
response = client.get("/entity.geojson", headers={"Origin": "localhost"})
data = response.json()
assert "type" in data
Expand All @@ -68,7 +76,11 @@ def test_app_returns_valid_populated_geojson_list(client, test_data):
) == len(data["features"])


def test_lasso_geo_search_finds_results(client, test_data):
def test_lasso_geo_search_finds_results(client, test_data, mocker):
mocker.patch(
"application.data_access.entity_queries.get_cached_query_result",
return_value=None,
)
params = {
"geometry_relation": "intersects",
"geometry": intersects_with_greenspace_entity,
Expand Down Expand Up @@ -188,7 +200,13 @@ def test_dataset_json_endpoint_returns_as_expected(test_data, client):


@pytest.mark.parametrize("point, expected_status_code", wkt_params)
def test_api_handles_invalid_wkt(point, expected_status_code, client, test_data):
def test_api_handles_invalid_wkt(
point, expected_status_code, client, test_data, mocker
):
mocker.patch(
"application.data_access.entity_queries.get_cached_query_result",
return_value=None,
)
params = {"geometry_relation": "intersects", "geometry": point}
response = client.get("/entity.geojson", params=params)
assert response.status_code == expected_status_code
Expand Down
14 changes: 12 additions & 2 deletions tests/integration/test_entity_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,26 @@
from application.db.models import EntityOrm


def test__lookup_entity_link_returns_nothing_when_the_entity_isnt_found(
    db_session, mocker
):
    """Looking up a link to a non-existent entity returns None.

    (The diff rendering fused the removed one-line signature with the
    added multi-line one; this is the intended final version, with the
    Redis cache lookup stubbed out so the DB is actually queried.)
    """
    mocker.patch(
        "application.data_access.entity_queries.get_cached_query_result",
        return_value=None,
    )
    linked_entity = lookup_entity_link(
        db_session, "a-reference", "article-4-direction", 123
    )
    assert linked_entity is None


def test__lookup_entity_link_returns_the_looked_up_entity_when_the_link_exists(
db_session,
db_session, mocker
):
mocker.patch(
"application.data_access.entity_queries.get_cached_query_result",
return_value=None,
)
lookup_entity = {
"entity": 106,
"name": "A space",
Expand Down
Loading

0 comments on commit ed84ef8

Please sign in to comment.