Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
ritviksahajpal committed Nov 20, 2024
1 parent dc0a85a commit 6ddf840
Showing 1 changed file with 192 additions and 58 deletions.
250 changes: 192 additions & 58 deletions geoprepare/eoaccess/eoaccess.py
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@
import logging
import rasterio as rio
import xarray as xr
import numpy as np
import rioxarray as rxr
from pathlib import Path
import earthaccess as ea
@@ -46,10 +47,8 @@ def __post_init__(self):
if isinstance(self.output_dir, str):
self.output_dir = Path(self.output_dir)

# Create output directory if it does not exist
os.makedirs(self.output_dir, exist_ok=True)

# Create logging directory if it does not exist
self.logger = log.Logger(
dir_log=self.output_dir / "logs",
project=self.logging_project,
@@ -63,7 +62,6 @@ def get_bbox_from_shapefile(self):
except:
raise RuntimeError(f"{self.shapefile} could not be read")

# Convert shapefile to EPSG:4326 if not already
if gdf.crs != "EPSG:4326":
gdf = gdf.to_crs("EPSG:4326")

@@ -83,41 +81,36 @@ def stream(self):
fileset = ea.open(self.results)

print(f" Using {type(fileset[0])} filesystem")
# Open each file with rioxarray and store in a list
datasets = [rxr.open_rasterio(file) for file in fileset]

# Concatenate datasets with a progress bar
with tqdm(total=len(datasets), desc="Concatenating datasets") as pbar:
def concat_with_progress(ds_list):
for ds in ds_list:
yield ds
pbar.update(1)

merged_dataset = xr.concat(concat_with_progress(datasets), dim='new_dimension')
breakpoint()

# Merge the datasets
# This example uses concat along a new dimension; adjust as needed
merged_dataset = xr.concat(datasets)

breakpoint()
ds = xr.open_mfdataset(fileset)
ds
breakpoint()

@staticmethod
def download(item):
result, output_dir = item
ea.download([result], str(output_dir))

def download_parallel(self):
# Create a list of tuples containing result and output_dir
combinations = [(result, self.output_dir) for result in self.results]

num_cpu = int(cpu_count() * 0.6)
with Pool(num_cpu) as p:
for i, _ in enumerate(p.imap_unordered(self.download, combinations)):
pass

try:
with Pool(num_cpu) as p:
for i, _ in enumerate(p.imap_unordered(self.download, combinations)):
pass
finally:
p.close()
p.join()

@dataclass
class EarthAccessProcessor:
@@ -130,58 +123,199 @@ class EarthAccessProcessor:
input_dir: str = None

def __post_init__(self):
# Both bbox and shapefile cannot be not None at the same time
if self.bbox and self.shapefile:
raise ValueError("Both bbox and shapefile cannot be specified")

# if shapefile exists then it should have columns specifying start and end dates
# if self.shapefile:
# assert self.start_date_col and self.end_date_col, "Start and end date columns must be specified"
# assert self.start_date_col in self.shapefile.columns, f"{self.start_date_col} not found in shapefile"
# assert self.end_date_col in self.shapefile.columns, f"{self.end_date_col} not found in shapefile"

# Create mosaic directory within output directory
self.mosaic_dir = Path(os.path.join(self.input_dir, "mosaic"))
os.makedirs(self.mosaic_dir, exist_ok=True)

def get_ts(self):
# Loop over shapefile
dg = gpd.read_file(self.shapefile)
for index, row in tqdm(dg.iterrows(), desc="Getting time-series"):
breakpoint()
pass

def mosaic(self):
# Check if self.dataset contains either HLSS30 or HLSL30
if "HLSS30" in self.dataset or "HLSL30" in self.dataset:
def group_files_by_band_and_date():
grouped_files = defaultdict(list)

# Iterate through all files in the directory
for filename in os.listdir(self.input_dir):
if filename.endswith(".tif"):
# Parse filename to extract band and date
parts = filename.split(".")
band = parts[-2] # Spectral band is the second last part
date = parts[3][:-7] # Julian Date of Acquisition

# Group the files
grouped_files[(band, date)].append(filename)
try:
dg = gpd.read_file(self.shapefile)
for index, row in tqdm(dg.iterrows(), desc="Getting time-series"):
pass
except Exception as e:
logging.error(f"Error while processing time-series: {e}")
raise

return grouped_files
def group_files_by_band_and_date(self):
grouped_files = defaultdict(list)

grouped_files = group_files_by_band_and_date()
for filename in os.listdir(self.input_dir):
if filename.endswith(".tif"):
parts = filename.split(".")
band = parts[-2]
date = parts[3][:8]

pbar = tqdm(grouped_files.items())
for key, files in pbar:
band, date = key
grouped_files[(band, date)].append(filename)

pbar.set_description(f"Mosaicing: {band} {date}")
pbar.update()
return grouped_files

mosaic_file = self.mosaic_dir / f"mosaic_{band}_{date}.tif"
if os.path.exists(mosaic_file):
continue
def mosaic(self):
grouped_files = self.group_files_by_band_and_date()

for key, files in tqdm(grouped_files.items(), desc="Mosaicing files"):
band, date = key
mosaic_file = self.mosaic_dir / f"mosaic_{band}_{date}.tif"

if os.path.exists(mosaic_file):
continue

tif_files = [Path(self.input_dir) / file for file in files]

# Use rioxarray to mosaic the files
first_file = rxr.open_rasterio(tif_files[0])
crs = first_file.rio.crs
res = first_file.rio.resolution()

for file in tif_files[1:]:
ds = rxr.open_rasterio(file)
if ds.rio.crs != crs or ds.rio.resolution() != res:
raise ValueError(f"File {file} has different CRS or resolution")

# Call mosaic utility function
utils.mosaic(tif_files, mosaic_file)

import numpy as np
import rioxarray as rxr
from pathlib import Path
from tqdm import tqdm

def create_quality_mask(quality_data, bit_nums: list = [1, 2, 3, 4, 5]):
"""
Uses the Fmask layer and bit numbers to create a binary mask of good pixels.
By default, bits 1-5 are used to remove bad pixels like cloud, shadow, snow.
Parameters:
- quality_data: The Fmask layer data (2D array).
- bit_nums: List of bit numbers to use for masking (default: bits 1-5).
Returns:
- mask_array: A binary mask where 1 indicates bad pixels, 0 indicates good pixels.
"""
mask_array = np.zeros((quality_data.shape[0], quality_data.shape[1]))

# Replace NaNs with 0 and convert to integer
quality_data = np.nan_to_num(quality_data, 0).astype(np.int8)

# Iterate through the bits to generate the mask
for bit in bit_nums:
mask_temp = np.array(quality_data) & 1 << bit > 0
mask_array = np.logical_or(mask_array, mask_temp)

return mask_array

def compute_selected_indices(self, red_band_file, nir_band_file, green_band_file, blue_band_file, fmask_file,
output_dir, selected_indices, swir_band_file=None, red_edge_band_file=None):
"""
Compute selected vegetation indices like NDVI, GCVI, EVI, SAVI, etc., apply scaling, and QA masking.
Avoid recomputation if the file already exists.
Parameters:
- red_band_file, nir_band_file, green_band_file, blue_band_file, fmask_file: Paths to the necessary bands
- output_dir: Directory to save the output indices
- selected_indices: A list of selected indices to compute (e.g., ['NDVI', 'EVI', 'SAVI'])
- swir_band_file, red_edge_band_file: Optional paths to the SWIR and Red Edge bands for specific indices
"""

# Ensure the output directory exists
output_dir = Path(output_dir)
os.makedirs(output_dir, exist_ok=True)

# Open the red, NIR, green, and blue bands with scaling
scale_factor = 0.0001
red_band = rxr.open_rasterio(red_band_file).squeeze() * scale_factor
nir_band = rxr.open_rasterio(nir_band_file).squeeze() * scale_factor
green_band = rxr.open_rasterio(green_band_file).squeeze() * scale_factor
blue_band = rxr.open_rasterio(blue_band_file).squeeze() * scale_factor

# Open the QA mask (Fmask)
fmask = rxr.open_rasterio(fmask_file).squeeze()

# Create the quality mask from Fmask using the create_quality_mask function
bit_nums = [1, 2, 3, 4, 5] # Define which bits to use for masking (can be adjusted)
mask_layer = self.create_quality_mask(fmask.data, bit_nums)

# Apply the QA mask to each band (good pixels are where mask_layer == 0)
red_band = red_band.where(~mask_layer)
nir_band = nir_band.where(~mask_layer)
green_band = green_band.where(~mask_layer)
blue_band = blue_band.where(~mask_layer)

# Available index computation functions
index_functions = {
'NDVI': self.compute_ndvi,
'GCVI': self.compute_gcvi,
'EVI': self.compute_evi,
'SAVI': self.compute_savi,
'MSAVI': self.compute_msavi,
'NDWI': self.compute_ndwi,
'GNDVI': self.compute_gndvi,
'ARVI': self.compute_arvi,
'NDMI': self.compute_ndmi if swir_band_file else None,
'RENDVI': self.compute_rendvi if red_edge_band_file else None,
'VARI': self.compute_vari,
}

# Loop over the selected indices and compute them if not already saved
for index in tqdm(selected_indices, desc="Computing indices"):
output_file = output_dir / f"{index}.tif"
if output_file.exists():
print(f"{index} already exists, skipping computation.")
continue

if index_functions.get(index) is not None:
# Call the appropriate index function and save the result
index_functions[index](red_band, nir_band, green_band, blue_band, output_file, swir_band_file,
red_edge_band_file)
else:
print(f"{index} is not available or missing necessary bands (e.g., SWIR or Red Edge).")

# Functions to compute individual indices
def compute_ndvi(self, red_band, nir_band, *_):
ndvi = (nir_band - red_band) / (nir_band + red_band)
ndvi.rio.to_raster(_[0])

def compute_gcvi(self, red_band, nir_band, green_band, *_):
gcvi = (nir_band / green_band) - 1
gcvi.rio.to_raster(_[0])

def compute_evi(self, red_band, nir_band, green_band, blue_band, output_file, *_):
evi = 2.5 * (nir_band - red_band) / (nir_band + 6 * red_band - 7.5 * blue_band + 1)
evi.rio.to_raster(output_file)

def compute_savi(self, red_band, nir_band, *_):
L = 0.5
savi = ((nir_band - red_band) / (nir_band + red_band + L)) * (1 + L)
savi.rio.to_raster(_[0])

def compute_msavi(self, red_band, nir_band, *_):
msavi = (2 * nir_band + 1 - np.sqrt((2 * nir_band + 1) ** 2 - 8 * (nir_band - red_band))) / 2
msavi.rio.to_raster(_[0])

def compute_ndwi(self, red_band, nir_band, green_band, *_):
ndwi = (green_band - nir_band) / (green_band + nir_band)
ndwi.rio.to_raster(_[0])

def compute_gndvi(self, red_band, nir_band, green_band, *_):
gndvi = (nir_band - green_band) / (nir_band + green_band)
gndvi.rio.to_raster(_[0])

def compute_arvi(self, red_band, nir_band, green_band, blue_band, output_file, *_):
arvi = (nir_band - (2 * red_band - blue_band)) / (nir_band + (2 * red_band - blue_band))
arvi.rio.to_raster(output_file)

def compute_ndmi(self, red_band, nir_band, green_band, blue_band, output_file, swir_band_file, *_):
swir_band = rxr.open_rasterio(swir_band_file)
ndmi = (nir_band - swir_band) / (nir_band + swir_band)
ndmi.rio.to_raster(output_file)

def compute_rendvi(self, red_band, nir_band, green_band, blue_band, output_file, _, red_edge_band_file):
red_edge_band = rxr.open_rasterio(red_edge_band_file)
rendvi = (nir_band - red_edge_band) / (nir_band + red_edge_band)
rendvi.rio.to_raster(output_file)

def compute_vari(self, red_band, nir_band, green_band, blue_band, output_file, *_):
vari = (green_band - red_band) / (green_band + red_band - blue_band)
vari.rio.to_raster(output_file)

tif_files = [Path(self.input_dir) / filename for filename in files]
utils.mosaic(tif_files, mosaic_file)

0 comments on commit 6ddf840

Please sign in to comment.