-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathexample_patch_loading.py
55 lines (45 loc) · 2.67 KB
/
example_patch_loading.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""GeoLifeCLEF23 patch example module.
This module provides an example of how to load and use GeoLifeCLEF2023 patch
datasets using the framework developed for the challenge and incorporated
in malpolon at `malpolon.data.datasets.geolifeclef2023`.
"""
import argparse
import random
from malpolon.data.datasets.geolifeclef2023 import (
JpegPatchProvider, MultipleRasterPatchProvider, PatchesDataset,
PatchesDatasetMultiLabel, RasterPatchProvider)
def main(display: bool = True):
    """Load the GLC23 sample patches and print a few randomly drawn items.

    Four patch providers are built from the sample data folder and handed,
    in the same order, to both a ``PatchesDataset`` and a
    ``PatchesDatasetMultiLabel``. Five indices are then drawn at random
    (repeats are possible) and, for each of them, the tensor type, tensor
    shape, label and multi-label entry are printed.

    Parameters
    ----------
    display : bool
        When true, ``dataset.plot_patch`` is also called for each drawn index.
    """
    data_path = 'dataset/sample_data/'  # root path of the data

    # Locations of the individual inputs, all relative to the data root.
    satellite_dir = data_path + 'SatelliteImages/'
    hfp_detailed_dir = data_path + 'EnvironmentalRasters/HumanFootprint/detailed/'
    bioclim_dir = data_path + 'EnvironmentalRasters/Climate/BioClimatic_Average_1981-2010/'
    hfp_summary_tif = data_path + 'EnvironmentalRasters/HumanFootprint/summarized/HFP2009_WGS84.tif'
    occurrences_csv = data_path + 'Presence_only_occurrences/Presences_only_train_sample.csv'

    # Providers are instantiated in the same sequence as before:
    # satellite imagery, detailed footprint, bioclimatic, summarized footprint.
    satellite = JpegPatchProvider(
        satellite_dir,
        dataset_stats='jpeg_patches_sample_stats.csv',
        id_getitem='patchID',
    )  # all sentinel imagery layers (r, g, b, nir)
    footprint_detailed = MultipleRasterPatchProvider(hfp_detailed_dir)  # every detailed human footprint raster
    bioclimatic = MultipleRasterPatchProvider(
        bioclim_dir,
        select=['bio1', 'bio2'],
    )  # only bio1 and bio2 out of the bioclimatic rasters
    footprint_summary = RasterPatchProvider(hfp_summary_tif)  # the summarized human footprint 2009 raster

    # Each dataset receives its own fresh list objects built from these tuples.
    provider_order = (footprint_detailed, bioclimatic, footprint_summary, satellite)
    columns = ('lat', 'lon', 'patchID')

    dataset = PatchesDataset(
        occurrences=occurrences_csv,
        providers=list(provider_order),
        item_columns=list(columns),
    )
    dataset_multi = PatchesDatasetMultiLabel(
        occurrences=occurrences_csv,
        providers=list(provider_order),
        item_columns=list(columns),
        id_getitem='patchID',
    )

    # Draw five indices (with replacement) and report what each one yields.
    last_index = len(dataset) - 1
    sampled = [random.randint(0, last_index) for _ in range(5)]
    for idx in sampled:
        tensor, label = dataset[idx]
        label_multi = dataset_multi[idx][1]
        report = (f'Tensor type: {type(tensor)}, tensor shape: {tensor.shape}, '
                  f'label: {label}, \nlabel_multi: {label_multi}')
        print(report)
        if not display:
            continue
        dataset.plot_patch(idx)
if __name__ == '__main__':
    # Giving -p/--plot (with or without trailing values) switches plotting on;
    # when the option is absent argparse leaves the attribute at None.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-p", "--plot",
        nargs='*',
        action='store',
        help="Plot patches.",
    )
    cli_args = cli.parse_args()
    main(cli_args.plot is not None)