Skip to content

Commit

Permalink
latitude & longitude are added in the aggregate step
Browse files Browse the repository at this point in the history
Signed-off-by: Davis Broda <[email protected]>
  • Loading branch information
DavisBroda committed Aug 29, 2024
1 parent 2736d6a commit b2b0f69
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 17 deletions.
21 changes: 19 additions & 2 deletions src/loader/aggregation_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pandas
from pandas import DataFrame

from common.const import LOGGING_FORMAT, CELL_COL
from common.const import LOGGING_FORMAT, CELL_COL, LONGITUDE_COL, LATITUDE_COL

# Set up logging

Expand Down Expand Up @@ -85,7 +85,24 @@ def run(self, in_df: DataFrame) -> DataFrame:

with_agg = groups.agg(**self.agg_map).reset_index()

return with_agg
with_cell_ll = self._add_cell_centroid_lat_long(with_agg)

return with_cell_ll

def _add_cell_centroid_lat_long(self, in_df: DataFrame) -> DataFrame:
    """Add the centroid latitude and longitude of each row's h3 cell.

    The cell id is read from the 'h3_cell' column and converted with
    ``h3.h3_to_geo``; the resulting pair is written to the LATITUDE_COL
    and LONGITUDE_COL columns.

    :param in_df: frame containing an 'h3_cell' column. It is modified
        in place.
    :return: the same frame object, with the two coordinate columns set.
    """
    # Resolve every centroid exactly once; h3_to_geo returns (lat, long),
    # so a second conversion pass per row is unnecessary. Iterating the
    # column directly (rather than applying a function row by row) also
    # yields plain, positionally ordered lists whatever the frame's size,
    # including zero rows.
    centroids = [h3.h3_to_geo(cell) for cell in in_df['h3_cell']]

    in_df[LATITUDE_COL] = [lat for lat, _ in centroids]
    in_df[LONGITUDE_COL] = [long for _, long in centroids]
    return in_df

def _add_cell_column(self, in_df: DataFrame) -> DataFrame:
def to_cell(row):
Expand Down
9 changes: 5 additions & 4 deletions test/aggregationstep/test_cell_aggreation_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,14 @@ def test_cell_col_in_output(self, agg_df):

assert "h3_cell" in out.columns

def test_lat_long_not_in_output(self, agg_df):
def test_lat_long_in_output(self, agg_df):
agg_step = MinAggregation({})
all_agg = CellAggregationStep([agg_step], 1, ['value1', 'value2'], [])

out = all_agg.run(agg_df)

assert len(out.columns) == 3
assert "latitude" not in out.columns
assert "longitude" not in out.columns
assert "latitude" in out.columns
assert "longitude" in out.columns

def test_output_name_format_multiple_built_in_agg(self, agg_df):
agg_min = MinAggregation({})
Expand All @@ -79,6 +78,8 @@ def test_output_name_format_multiple_built_in_agg(self, agg_df):
out = all_agg.run(agg_df)
all_expected_cols = {
"h3_cell",
"latitude",
"longitude",
"value1_min",
"value2_min",
"value1_max",
Expand Down
60 changes: 49 additions & 11 deletions test/load_pipeline/test_loading_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import List, Tuple, Dict, Set

import duckdb
import h3
import pytest
from pandas import DataFrame

Expand Down Expand Up @@ -190,14 +191,28 @@ def test_read_out_aggregate(self, database_dir):
pipeline.run()

out = read_temp_db(dataset)
def round_latlong(t:Tuple) -> Tuple:
    """Return a copy of the row with elements 5 and 6 rounded to 4 places.

    Those two positions hold the latitude and longitude in the rows this
    test reads back; every other element is passed through untouched.
    """
    values = tuple(t)
    rounded_pair = (round(values[5], 4), round(values[6], 4))
    return values[:5] + rounded_pair + values[7:]

out = list(map(
round_latlong,
out
))

# the same as the raw data in the initial file
cell1_lat, cell1_long = h3.h3_to_geo('8110bffffffffff')
cell2_lat, cell2_long = h3.h3_to_geo('81defffffffffff')
expected = {
('8110bffffffffff', 0, 10, 0, 100),
('81defffffffffff', 0, 10, 0, 100),
('8110bffffffffff', 0, 10, 0, 100,
cell1_lat, cell1_long),
('81defffffffffff', 0, 10, 0, 100,
cell2_lat, cell2_long),
}

assert set(out) == expected
assert round_floats(set(out)) == round_floats(expected)

def test_fail_if_agg_but_no_res(self, database_dir):
parquet_file = data_dir + "/2_cell_agg.parquet"
Expand Down Expand Up @@ -292,16 +307,21 @@ def test_full_pipeline(self, database_dir):

out = read_temp_db(dataset)

cell1_lat, cell1_long = h3.h3_to_geo('8110bffffffffff')
cell2_lat, cell2_long = h3.h3_to_geo('81defffffffffff')

def f(i: int):
    """Return twice the successor of ``i``, i.e. ``(i + 1) * 2``."""
    successor = i + 1
    return successor * 2

# the same as the raw data in the initial file
expected = {
('8110bffffffffff', f(0), f(10), f(0), f(100)),
('81defffffffffff', f(0), f(10), f(0), f(100)),
('8110bffffffffff', f(0), f(10), f(0), f(100),
cell1_lat, cell1_long),
('81defffffffffff', f(0), f(10), f(0), f(100),
cell2_lat, cell2_long),
}

assert round_floats(set(out)) == expected
assert round_floats(set(out)) == round_floats(expected)

def test_additional_key_cols(self, database_dir):
parquet_file = data_dir + "with_company.parquet"
Expand Down Expand Up @@ -332,16 +352,34 @@ def test_additional_key_cols(self, database_dir):
pipeline.run()

out = read_temp_db(dataset)
def round_latlong(t:Tuple) -> Tuple:
    """Round the trailing latitude/longitude pair of a row to 4 places.

    The rows compared in this test start with a company column, which
    makes them eight elements long with the coordinates in the last two
    positions. Fixed offsets 5 and 6 would hit an integer aggregate and
    the latitude while leaving the longitude unrounded, so the pair is
    addressed from the end of the row instead.

    :param t: a result row whose final two elements are latitude and
        longitude.
    :return: a new tuple with those two elements rounded; all leading
        elements are unchanged.
    """
    *leading, lat, long = t
    return (*leading, round(lat, 4), round(long, 4))

out = list(map(
round_latlong,
out
))

# the same as the raw data in the initial file
cell1_lat, cell1_long = h3.h3_to_geo('8110bffffffffff')
cell2_lat, cell2_long = h3.h3_to_geo('81defffffffffff')

# the same as the raw data in the initial file
expected = {
('company1', '8110bffffffffff', 0, 10, 0, 100),
('company2', '8110bffffffffff', 2, 2, 20, 20),
('company1', '81defffffffffff', 0, 10, 0, 100),
('company2', '81defffffffffff', 2, 2, 20, 20)
('company1', '8110bffffffffff', 0, 10, 0, 100,
cell1_lat, cell1_long),
('company2', '8110bffffffffff', 2, 2, 20, 20,
cell1_lat, cell1_long),
('company1', '81defffffffffff', 0, 10, 0, 100,
cell2_lat, cell2_long),
('company2', '81defffffffffff', 2, 2, 20, 20,
cell2_lat, cell2_long)
}

assert set(out) == expected
assert round_floats(set(out)) == round_floats(expected)

def test_metadata_creation(self, database_dir):
parquet_file = data_dir + "/2_cell_agg.parquet"
Expand Down

0 comments on commit b2b0f69

Please sign in to comment.