diff --git a/src/loader/aggregation_step.py b/src/loader/aggregation_step.py index 714fcca..e55c97f 100644 --- a/src/loader/aggregation_step.py +++ b/src/loader/aggregation_step.py @@ -15,7 +15,7 @@ import pandas from pandas import DataFrame -from common.const import LOGGING_FORMAT, CELL_COL +from common.const import LOGGING_FORMAT, CELL_COL, LONGITUDE_COL, LATITUDE_COL # Set up logging @@ -85,7 +85,24 @@ def run(self, in_df: DataFrame) -> DataFrame: with_agg = groups.agg(**self.agg_map).reset_index() - return with_agg + with_cell_ll = self._add_cell_centroid_lat_long(with_agg) + + return with_cell_ll + + def _add_cell_centroid_lat_long(self, in_df: DataFrame) -> DataFrame: + def cell_to_lat(row): + cell = row['h3_cell'] + lat, long = h3.h3_to_geo(cell) + return lat + + def cell_to_long(row): + cell = row['h3_cell'] + lat, long = h3.h3_to_geo(cell) + + return long + in_df[LATITUDE_COL] = in_df.apply(cell_to_lat, axis='columns') + in_df[LONGITUDE_COL] = in_df.apply(cell_to_long, axis='columns') + return in_df def _add_cell_column(self, in_df: DataFrame) -> DataFrame: def to_cell(row): diff --git a/test/aggregationstep/test_cell_aggreation_step.py b/test/aggregationstep/test_cell_aggreation_step.py index be3ce4c..5821b51 100644 --- a/test/aggregationstep/test_cell_aggreation_step.py +++ b/test/aggregationstep/test_cell_aggreation_step.py @@ -58,15 +58,14 @@ def test_cell_col_in_output(self, agg_df): assert "h3_cell" in out.columns - def test_lat_long_not_in_output(self, agg_df): + def test_lat_long_in_output(self, agg_df): agg_step = MinAggregation({}) all_agg = CellAggregationStep([agg_step], 1, ['value1', 'value2'], []) out = all_agg.run(agg_df) - assert len(out.columns) == 3 - assert "latitude" not in out.columns - assert "longitude" not in out.columns + assert "latitude" in out.columns + assert "longitude" in out.columns def test_output_name_format_multiple_built_in_agg(self, agg_df): agg_min = MinAggregation({}) @@ -79,6 +78,8 @@ def test_output_name_format_multiple_built_in_agg(self, agg_df): out = all_agg.run(agg_df) all_expected_cols = { "h3_cell", + "latitude", + "longitude", "value1_min", "value2_min", "value1_max", diff --git a/test/load_pipeline/test_loading_pipeline.py b/test/load_pipeline/test_loading_pipeline.py index af513b6..c08bcd5 100644 --- a/test/load_pipeline/test_loading_pipeline.py +++ b/test/load_pipeline/test_loading_pipeline.py @@ -5,6 +5,7 @@ from typing import List, Tuple, Dict, Set import duckdb +import h3 import pytest from pandas import DataFrame @@ -190,14 +191,28 @@ def test_read_out_aggregate(self, database_dir): pipeline.run() out = read_temp_db(dataset) + def round_latlong(t:Tuple) -> Tuple: + as_l = list(t) + as_l[5] = round(as_l[5], 4) + as_l[6] = round(as_l[6], 4) + return tuple(as_l) + + out = list(map( + round_latlong, + out + )) # the same as the raw data in the initial file + cell1_lat, cell1_long = h3.h3_to_geo('8110bffffffffff') + cell2_lat, cell2_long = h3.h3_to_geo('81defffffffffff') expected = { - ('8110bffffffffff', 0, 10, 0, 100), - ('81defffffffffff', 0, 10, 0, 100), + ('8110bffffffffff', 0, 10, 0, 100, + cell1_lat, cell1_long), + ('81defffffffffff', 0, 10, 0, 100, + cell2_lat, cell2_long), } - assert set(out) == expected + assert round_floats(set(out)) == round_floats(expected) def test_fail_if_agg_but_no_res(self, database_dir): parquet_file = data_dir + "/2_cell_agg.parquet" @@ -292,16 +307,21 @@ def test_full_pipeline(self, database_dir): out = read_temp_db(dataset) + cell1_lat, cell1_long = h3.h3_to_geo('8110bffffffffff') + cell2_lat, cell2_long = h3.h3_to_geo('81defffffffffff') + def f(i: int): return (i + 1) * 2 # the same as the raw data in the initial file expected = { - ('8110bffffffffff', f(0), f(10), f(0), f(100)), - ('81defffffffffff', f(0), f(10), f(0), f(100)), + ('8110bffffffffff', f(0), f(10), f(0), f(100), + cell1_lat, cell1_long), + ('81defffffffffff', f(0), f(10), f(0), f(100), + cell2_lat, cell2_long), } - assert round_floats(set(out)) == expected + assert round_floats(set(out)) == round_floats(expected) def test_additional_key_cols(self, database_dir): parquet_file = data_dir + "with_company.parquet" @@ -332,16 +352,34 @@ def test_additional_key_cols(self, database_dir): pipeline.run() out = read_temp_db(dataset) + def round_latlong(t:Tuple) -> Tuple: + as_l = list(t) + as_l[5] = round(as_l[5], 4) + as_l[6] = round(as_l[6], 4) + return tuple(as_l) + + out = list(map( + round_latlong, + out + )) + + # the same as the raw data in the initial file + cell1_lat, cell1_long = h3.h3_to_geo('8110bffffffffff') + cell2_lat, cell2_long = h3.h3_to_geo('81defffffffffff') # the same as the raw data in the initial file expected = { - ('company1', '8110bffffffffff', 0, 10, 0, 100), - ('company2', '8110bffffffffff', 2, 2, 20, 20), - ('company1', '81defffffffffff', 0, 10, 0, 100), - ('company2', '81defffffffffff', 2, 2, 20, 20) + ('company1', '8110bffffffffff', 0, 10, 0, 100, + cell1_lat, cell1_long), + ('company2', '8110bffffffffff', 2, 2, 20, 20, + cell1_lat, cell1_long), + ('company1', '81defffffffffff', 0, 10, 0, 100, + cell2_lat, cell2_long), + ('company2', '81defffffffffff', 2, 2, 20, 20, + cell2_lat, cell2_long) } - assert set(out) == expected + assert round_floats(set(out)) == round_floats(expected) def test_metadata_creation(self, database_dir): parquet_file = data_dir + "/2_cell_agg.parquet"