Skip to content

Commit

Permalink
Handle opening multi-file granules (#394)
Browse files Browse the repository at this point in the history
  • Loading branch information
jrbourbeau authored Dec 4, 2023
1 parent 1a45326 commit 68bea76
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 38 deletions.
64 changes: 26 additions & 38 deletions earthaccess/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from itertools import chain
from pathlib import Path
from pickle import dumps, loads
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
from uuid import uuid4

import fsspec
Expand Down Expand Up @@ -44,23 +44,15 @@ def __repr__(self) -> str:


def _open_files(
    url_mapping: Mapping[str, Union[DataGranule, None]],
    fs: fsspec.AbstractFileSystem,
    threads: Optional[int] = 8,
) -> List[fsspec.AbstractFileSystem]:
    """Open every url in ``url_mapping`` with ``fs`` and wrap the results.

    Args:
        url_mapping: Maps each file url to the granule it belongs to, or to
            ``None`` when the url was supplied without a granule.
        fs: Filesystem whose ``open`` method is called once per url.
        threads: Forwarded to ``pqdm`` as ``n_jobs``.

    Returns:
        Whatever ``pqdm`` returns for the mapped items; each successfully
        processed item is an ``EarthAccessFile`` pairing the opened file
        with its granule (or ``None``).
    """

    def multi_thread_open(data: tuple) -> EarthAccessFile:
        # Each work item is one (url, granule) pair taken from the mapping,
        # so every file of a multi-file granule is opened individually.
        url, granule = data
        return EarthAccessFile(fs.open(url), granule)

    fileset = pqdm(url_mapping.items(), multi_thread_open, n_jobs=threads)
    return fileset


Expand All @@ -84,6 +76,17 @@ def make_instance(
return EarthAccessFile(loads(data), granule)


def _get_url_granule_mapping(
    granules: List[DataGranule], access: str
) -> Mapping[str, DataGranule]:
    """Construct a mapping between file urls and granules.

    Every link returned by ``granule.data_links(access=access)`` becomes a
    key whose value is the granule it came from, so a granule that exposes
    several files contributes several entries.

    Args:
        granules: Granules whose data links should be indexed.
        access: Forwarded unchanged to ``data_links`` as its ``access``
            keyword argument.

    Returns:
        A dict keyed by url, in the order the links are produced. If the
        same url is produced more than once, the last granule yielding it
        is the one kept.
    """
    return {
        url: granule
        for granule in granules
        for url in granule.data_links(access=access)
    }


class Store(object):
"""
Store class to access granules on-prem or in the cloud.
Expand Down Expand Up @@ -320,7 +323,6 @@ def _open_granules(
threads: Optional[int] = 8,
) -> List[Any]:
fileset: List = []
data_links: List = []
total_size = round(sum([granule.size() for granule in granules]) / 1024, 2)
print(f"Opening {len(granules)} granules, approx size: {total_size} GB")

Expand All @@ -331,7 +333,7 @@ def _open_granules(

if self.running_in_aws:
if granules[0].cloud_hosted:
access_method = "direct"
access = "direct"
provider = granules[0]["meta"]["provider-id"]
# if the data has its own S3 credentials endpoint we'll use it
endpoint = self._own_s3_credentials(granules[0]["umm"]["RelatedUrls"])
Expand All @@ -342,20 +344,14 @@ def _open_granules(
print(f"using provider: {provider}")
s3_fs = self.get_s3fs_session(provider=provider)
else:
access_method = "on_prem"
access = "on_prem"
s3_fs = None

data_links = list(
chain.from_iterable(
granule.data_links(access=access_method) for granule in granules
)
)

url_mapping = _get_url_granule_mapping(granules, access)
if s3_fs is not None:
try:
fileset = _open_files(
data_links=data_links,
granules=granules,
url_mapping,
fs=s3_fs,
threads=threads,
)
Expand All @@ -366,16 +362,11 @@ def _open_granules(
f"Exception: {traceback.format_exc()}"
) from e
else:
fileset = self._open_urls_https(data_links, granules, threads=threads)
fileset = self._open_urls_https(url_mapping, threads=threads)
return fileset
else:
access_method = "on_prem"
data_links = list(
chain.from_iterable(
granule.data_links(access=access_method) for granule in granules
)
)
fileset = self._open_urls_https(data_links, granules, threads=threads)
url_mapping = _get_url_granule_mapping(granules, access="on_prem")
fileset = self._open_urls_https(url_mapping, threads=threads)
return fileset

@_open.register
Expand All @@ -386,14 +377,12 @@ def _open_urls(
threads: Optional[int] = 8,
) -> List[Any]:
fileset: List = []
data_links: List = []

if isinstance(granules[0], str) and (
granules[0].startswith("s3") or granules[0].startswith("http")
):
# TODO: method to derive the DAAC from url?
provider = provider
data_links = granules
else:
raise ValueError(
f"Schema for {granules[0]} is not recognized, must be an HTTP or S3 URL"
Expand All @@ -403,14 +392,14 @@ def _open_urls(
"A valid Earthdata login instance is required to retrieve S3 credentials"
)

url_mapping: Mapping[str, None] = {url: None for url in granules}
if self.running_in_aws and granules[0].startswith("s3"):
if provider is not None:
s3_fs = self.get_s3fs_session(provider=provider)
if s3_fs is not None:
try:
fileset = _open_files(
data_links=data_links,
granules=granules,
url_mapping,
fs=s3_fs,
threads=threads,
)
Expand All @@ -432,7 +421,7 @@ def _open_urls(
raise ValueError(
"We cannot open S3 links when we are not in-region, try using HTTPS links"
)
fileset = self._open_urls_https(data_links, granules, threads)
fileset = self._open_urls_https(url_mapping, threads)
return fileset

def get(
Expand Down Expand Up @@ -637,14 +626,13 @@ def _download_onprem_granules(

def _open_urls_https(
self,
urls: List[str],
granules: Union[List[str], List[DataGranule]],
url_mapping: Mapping[str, Union[DataGranule, None]],
threads: Optional[int] = 8,
) -> List[fsspec.AbstractFileSystem]:
https_fs = self.get_fsspec_session()
if https_fs is not None:
try:
fileset = _open_files(urls, granules, https_fs, threads)
fileset = _open_files(url_mapping, https_fs, threads)
except Exception:
print(
"An exception occurred while trying to access remote files via HTTPS: "
Expand Down
10 changes: 10 additions & 0 deletions tests/integration/test_cloud_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,13 @@ def test_earthaccess_can_download_cloud_collection_granules(daac):
f"Warning: {concept_id} downloaded size {total_mb_downloaded}MB is "
f"different from the size reported by CMR: {total_size_cmr}MB"
)


def test_multi_file_granule(tmp_path):
    """Downloading a multi-file granule yields one local file per data link."""
    results = earthaccess.search_data(short_name="HLSL30", count=1)
    assert len(results) == 1

    links = results[0].data_links()
    assert len(links) > 1

    downloaded = earthaccess.download(results, str(tmp_path))
    # Compare by file name only: local paths and remote urls differ in prefix.
    expected_names = {os.path.basename(link) for link in links}
    actual_names = {os.path.basename(path) for path in downloaded}
    assert expected_names == actual_names
10 changes: 10 additions & 0 deletions tests/integration/test_cloud_open.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,13 @@ def test_earthaccess_can_open_onprem_collection_granules(daac):
logger.info(f"File type: {magic.from_buffer(file.read(2048))}")
else:
logger.warning(f"File could not be open: {file}")


def test_multi_file_granule():
    """Opening a multi-file granule yields one file object per data link."""
    results = earthaccess.search_data(short_name="HLSL30", count=1)
    assert len(results) == 1

    links = results[0].data_links()
    assert len(links) > 1

    opened = earthaccess.open(results)
    opened_paths = {handle.path for handle in opened}
    assert set(links) == opened_paths

0 comments on commit 68bea76

Please sign in to comment.