diff --git a/earthaccess/store.py b/earthaccess/store.py index f2f3618e..940b8aec 100644 --- a/earthaccess/store.py +++ b/earthaccess/store.py @@ -6,7 +6,7 @@ from itertools import chain from pathlib import Path from pickle import dumps, loads -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Mapping, Optional, Tuple, Union from uuid import uuid4 import fsspec @@ -44,23 +44,15 @@ def __repr__(self) -> str: def _open_files( - data_links: List[str], - granules: Union[List[str], List[DataGranule]], + url_mapping: Mapping[str, Union[DataGranule, None]], fs: fsspec.AbstractFileSystem, threads: Optional[int] = 8, ) -> List[fsspec.AbstractFileSystem]: def multi_thread_open(data: tuple) -> EarthAccessFile: urls, granule = data - if not isinstance(granule, str): - if len(granule.data_links()) > 1: - print( - "Warning: This collection contains more than one file per granule. " - "earthaccess will only open the first data link, " - "try filtering the links before opening them." - ) return EarthAccessFile(fs.open(urls), granule) - fileset = pqdm(zip(data_links, granules), multi_thread_open, n_jobs=threads) + fileset = pqdm(url_mapping.items(), multi_thread_open, n_jobs=threads) return fileset @@ -84,6 +76,17 @@ def make_instance( return EarthAccessFile(loads(data), granule) +def _get_url_granule_mapping( + granules: List[DataGranule], access: str +) -> Mapping[str, DataGranule]: + """Construct a mapping between file urls and granules""" + url_mapping = {} + for granule in granules: + for url in granule.data_links(access=access): + url_mapping[url] = granule + return url_mapping + + class Store(object): """ Store class to access granules on-prem or in the cloud. @@ -320,7 +323,6 @@ def _open_granules( threads: Optional[int] = 8, ) -> List[Any]: fileset: List = [] - data_links: List = [] total_size = round(sum([granule.size() for granule in granules]) / 1024, 2) print(f"Opening {len(granules)} granules, approx size: {total_size} GB") @@ -331,7 +333,7 @@ def _open_granules( if self.running_in_aws: if granules[0].cloud_hosted: - access_method = "direct" + access = "direct" provider = granules[0]["meta"]["provider-id"] # if the data has its own S3 credentials endpoint we'll use it endpoint = self._own_s3_credentials(granules[0]["umm"]["RelatedUrls"]) @@ -342,20 +344,14 @@ def _open_granules( print(f"using provider: {provider}") s3_fs = self.get_s3fs_session(provider=provider) else: - access_method = "on_prem" + access = "on_prem" s3_fs = None - data_links = list( - chain.from_iterable( - granule.data_links(access=access_method) for granule in granules - ) - ) - + url_mapping = _get_url_granule_mapping(granules, access) if s3_fs is not None: try: fileset = _open_files( - data_links=data_links, - granules=granules, + url_mapping, fs=s3_fs, threads=threads, ) @@ -366,16 +362,11 @@ def _open_granules( f"Exception: {traceback.format_exc()}" ) from e else: - fileset = self._open_urls_https(data_links, granules, threads=threads) + fileset = self._open_urls_https(url_mapping, threads=threads) return fileset else: - access_method = "on_prem" - data_links = list( - chain.from_iterable( - granule.data_links(access=access_method) for granule in granules - ) - ) - fileset = self._open_urls_https(data_links, granules, threads=threads) + url_mapping = _get_url_granule_mapping(granules, access="on_prem") + fileset = self._open_urls_https(url_mapping, threads=threads) return fileset @_open.register @@ -386,14 +377,12 @@ def _open_urls( threads: Optional[int] = 8, ) -> List[Any]: fileset: List = [] - data_links: List = [] if isinstance(granules[0], str) and ( granules[0].startswith("s3") or granules[0].startswith("http") ): # TODO: method to derive the DAAC from url? provider = provider - data_links = granules else: raise ValueError( f"Schema for {granules[0]} is not recognized, must be an HTTP or S3 URL" @@ -403,14 +392,14 @@ def _open_urls( "A valid Earthdata login instance is required to retrieve S3 credentials" ) + url_mapping: Mapping[str, None] = {url: None for url in granules} if self.running_in_aws and granules[0].startswith("s3"): if provider is not None: s3_fs = self.get_s3fs_session(provider=provider) if s3_fs is not None: try: fileset = _open_files( - data_links=data_links, - granules=granules, + url_mapping, fs=s3_fs, threads=threads, ) @@ -432,7 +421,7 @@ def _open_urls( raise ValueError( "We cannot open S3 links when we are not in-region, try using HTTPS links" ) - fileset = self._open_urls_https(data_links, granules, threads) + fileset = self._open_urls_https(url_mapping, threads) return fileset def get( @@ -637,14 +626,13 @@ def _download_onprem_granules( def _open_urls_https( self, - urls: List[str], - granules: Union[List[str], List[DataGranule]], + url_mapping: Mapping[str, Union[DataGranule, None]], threads: Optional[int] = 8, ) -> List[fsspec.AbstractFileSystem]: https_fs = self.get_fsspec_session() if https_fs is not None: try: - fileset = _open_files(urls, granules, https_fs, threads) + fileset = _open_files(url_mapping, https_fs, threads) except Exception: print( "An exception occurred while trying to access remote files via HTTPS: " diff --git a/tests/integration/test_cloud_download.py b/tests/integration/test_cloud_download.py index 4ecc3137..63a05b93 100644 --- a/tests/integration/test_cloud_download.py +++ b/tests/integration/test_cloud_download.py @@ -157,3 +157,13 @@ def test_earthaccess_can_download_cloud_collection_granules(daac): f"Warning: {concept_id} downloaded size {total_mb_downloaded}MB is " f"different from the size reported by CMR: {total_size_cmr}MB" ) + + +def test_multi_file_granule(tmp_path): + # Ensure granules that contain multiple files are handled correctly + granules = earthaccess.search_data(short_name="HLSL30", count=1) + assert len(granules) == 1 + urls = granules[0].data_links() + assert len(urls) > 1 + files = earthaccess.download(granules, str(tmp_path)) + assert set(map(os.path.basename, urls)) == set(map(os.path.basename, files)) diff --git a/tests/integration/test_cloud_open.py b/tests/integration/test_cloud_open.py index d9d8fdea..78050f22 100644 --- a/tests/integration/test_cloud_open.py +++ b/tests/integration/test_cloud_open.py @@ -157,3 +157,13 @@ def test_earthaccess_can_open_onprem_collection_granules(daac): logger.info(f"File type: {magic.from_buffer(file.read(2048))}") else: logger.warning(f"File could not be open: {file}") + + +def test_multi_file_granule(): + # Ensure granules that contain multiple files are handled correctly + granules = earthaccess.search_data(short_name="HLSL30", count=1) + assert len(granules) == 1 + urls = granules[0].data_links() + assert len(urls) > 1 + files = earthaccess.open(granules) + assert set(urls) == set(f.path for f in files)