Add storage class conversion to get_dataset and find_dataset
timj committed Nov 1, 2023
1 parent faa1b00 commit eadf4a5
Showing 6 changed files with 60 additions and 12 deletions.
9 changes: 8 additions & 1 deletion python/lsst/daf/butler/_butler.py
@@ -800,13 +800,16 @@ def get_dataset_type(self, name: str) -> DatasetType:
raise NotImplementedError()

@abstractmethod
def get_dataset(self, id: DatasetId) -> DatasetRef | None:
def get_dataset(self, id: DatasetId, storage_class: str | StorageClass | None) -> DatasetRef | None:
"""Retrieve a Dataset entry.
Parameters
----------
id : `DatasetId`
The unique identifier for the dataset.
storage_class : `str` or `StorageClass` or `None`
A storage class to use when creating the returned entry. If given,
it must be compatible with the default storage class.
Returns
-------
@@ -824,6 +827,7 @@ def find_dataset(
*,
collections: str | Sequence[str] | None = None,
timespan: Timespan | None = None,
storage_class: str | StorageClass | None = None,
datastore_records: bool = False,
**kwargs: Any,
) -> DatasetRef | None:
@@ -851,6 +855,9 @@
A timespan that the validity range of the dataset must overlap.
If not provided, any `~CollectionType.CALIBRATION` collections
matched by the ``collections`` argument will not be searched.
storage_class : `str` or `StorageClass` or `None`
A storage class to use when creating the returned entry. If given,
it must be compatible with the default storage class.
**kwargs
Additional keyword arguments passed to
`DataCoordinate.standardize` to convert ``dataId`` to a true
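Taken together, the two new parameters let a caller ask for a `DatasetRef` whose storage class differs from the registry default, whether the lookup is by data ID or by UUID. A minimal caller-side sketch; the repository path is hypothetical, and `"ExposureF"` is assumed to be a storage class compatible with the dataset's default:

```python
from lsst.daf.butler import Butler

# Hypothetical repository path; adjust to a real repo.
butler = Butler("/path/to/repo")

# Resolve a dataset, but have the returned DatasetRef advertise a
# compatible storage class override instead of the registry default.
ref = butler.find_dataset(
    "bias",
    collections="imported_g",
    instrument="Cam1",
    detector=1,
    storage_class="ExposureF",  # assumed compatible with the default
)

# The same override is available when looking a dataset up by its UUID.
if ref is not None:
    ref_by_id = butler.get_dataset(ref.id, storage_class="ExposureF")
```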
15 changes: 12 additions & 3 deletions python/lsst/daf/butler/direct_butler.py
@@ -1321,8 +1321,13 @@ def getURI(
def get_dataset_type(self, name: str) -> DatasetType:
return self._registry.getDatasetType(name)

def get_dataset(self, id: DatasetId) -> DatasetRef | None:
return self._registry.getDataset(id)
def get_dataset(
self, id: DatasetId, storage_class: str | StorageClass | None = None
) -> DatasetRef | None:
ref = self._registry.getDataset(id)
if ref is not None and storage_class:
ref = ref.overrideStorageClass(storage_class)
return ref

def find_dataset(
self,
@@ -1331,6 +1336,7 @@ def find_dataset(
*,
collections: str | Sequence[str] | None = None,
timespan: Timespan | None = None,
storage_class: str | StorageClass | None = None,
datastore_records: bool = False,
**kwargs: Any,
) -> DatasetRef | None:
@@ -1342,14 +1348,17 @@
actual_type = dataset_type
data_id, kwargs = self._rewrite_data_id(data_id, actual_type, **kwargs)

return self._registry.findDataset(
ref = self._registry.findDataset(
dataset_type,
data_id,
collections=collections,
timespan=timespan,
datastore_records=datastore_records,
**kwargs,
)
if ref is not None and storage_class is not None:
ref = ref.overrideStorageClass(storage_class)
return ref

def retrieveArtifacts(
self,
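Both `DirectButler` methods end with the same guard-then-override step on the registry result. A standalone sketch of that pattern (the helper name is mine, not the repository's):

```python
from lsst.daf.butler import DatasetRef, StorageClass

def _override_if_requested(
    ref: DatasetRef | None, storage_class: str | StorageClass | None
) -> DatasetRef | None:
    # DatasetRef is immutable; overrideStorageClass returns a new ref
    # whose dataset type carries the requested storage class, leaving
    # the registry's stored definition untouched.
    if ref is not None and storage_class is not None:
        ref = ref.overrideStorageClass(storage_class)
    return ref
```

Note the two call sites above are not quite identical: `get_dataset` tests truthiness (`and storage_class`) while `find_dataset` tests `is not None`. For any real storage class name or instance the behavior is the same; they differ only for an empty string, which is not a valid storage class name anyway.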
23 changes: 19 additions & 4 deletions python/lsst/daf/butler/remote_butler/_remote_butler.py
@@ -219,9 +219,18 @@ def get_dataset_type(self, name: str) -> DatasetType:
response.raise_for_status()
return DatasetType.from_simple(SerializedDatasetType(**response.json()), universe=self.dimensions)

def get_dataset(self, id: DatasetId) -> DatasetRef | None:
def get_dataset(
self, id: DatasetId, storage_class: str | StorageClass | None = None
) -> DatasetRef | None:
path = f"dataset/{id}"
response = self._client.get(self._get_url(path))
if isinstance(storage_class, StorageClass):
storage_class_name = storage_class.name
elif storage_class:
storage_class_name = storage_class
params: dict[str, str] = {}
if storage_class:
params["storage_class"] = storage_class_name
response = self._client.get(self._get_url(path), params=params)
response.raise_for_status()
if response.json() is None:
return None
@@ -234,6 +243,7 @@ def find_dataset(
*,
collections: str | Sequence[str] | None = None,
timespan: Timespan | None = None,
storage_class: str | StorageClass | None = None,
datastore_records: bool = False,
**kwargs: Any,
) -> DatasetRef | None:
@@ -251,13 +261,18 @@
if isinstance(dataset_type, DatasetType):
dataset_type = dataset_type.name

if isinstance(storage_class, StorageClass):
storage_class = storage_class.name

query = FindDatasetModel(
data_id=self._simplify_dataId(data_id, **kwargs), collections=wildcards.strings
data_id=self._simplify_dataId(data_id, **kwargs),
collections=wildcards.strings,
storage_class=storage_class,
)

path = f"find_dataset/{dataset_type}"
response = self._client.post(
self._get_url(path), json=query.model_dump(mode="json", exclude_unset=True)
self._get_url(path), json=query.model_dump(mode="json", exclude_unset=True, exclude_defaults=True)
)
response.raise_for_status()

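On the client, both methods reduce the `str | StorageClass` union to a plain name before it crosses the wire (as a query parameter for `get_dataset`, as a JSON field for `find_dataset`), since the server resolves names against its own storage class definitions. A condensed sketch of that normalization; the helper function is an illustration, not part of the commit:

```python
from lsst.daf.butler import StorageClass

def _storage_class_name(storage_class: str | StorageClass | None) -> str | None:
    """Collapse the union accepted by the client API into the plain
    string the HTTP layer sends."""
    if isinstance(storage_class, StorageClass):
        return storage_class.name
    return storage_class or None  # treat "" like None, mirroring the truthiness test above
```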
10 changes: 7 additions & 3 deletions python/lsst/daf/butler/remote_butler/server/_server.py
@@ -116,10 +116,12 @@ def get_dataset_type(
response_model_exclude_defaults=True,
response_model_exclude_none=True,
)
def get_dataset(id: uuid.UUID, factory: Factory = Depends(factory_dependency)) -> SerializedDatasetRef | None:
def get_dataset(
id: uuid.UUID, storage_class: str | None = None, factory: Factory = Depends(factory_dependency)
) -> SerializedDatasetRef | None:
"""Return a single dataset reference."""
butler = factory.create_butler()
ref = butler.get_dataset(id)
ref = butler.get_dataset(id, storage_class=storage_class)
if ref is not None:
return ref.to_simple()
# This could raise a 404 since id is not found. The standard implementation
@@ -150,5 +152,7 @@ def find_dataset(
data_id = query.data_id.dataId

butler = factory.create_butler()
ref = butler.find_dataset(dataset_type, None, collections=collection_query, **data_id)
ref = butler.find_dataset(
dataset_type, None, collections=collection_query, storage_class=query.storage_class, **data_id
)
return ref.to_simple() if ref else None
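On the server side, `storage_class` arrives as an ordinary optional query string parameter: declaring `storage_class: str | None = None` is all FastAPI needs to parse `GET /dataset/<uuid>?storage_class=NumpyArray` and to fall back to `None` when the client omits it. A self-contained illustration (a toy app, not the project's actual router wiring):

```python
import uuid

from fastapi import FastAPI

app = FastAPI()

@app.get("/dataset/{id}")
def get_dataset(id: uuid.UUID, storage_class: str | None = None) -> dict:
    # FastAPI converts the path segment to a UUID and binds the optional
    # query parameter; omitting ?storage_class=... yields None here.
    return {"id": str(id), "storage_class": storage_class}
```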
@@ -37,3 +37,4 @@
class FindDatasetModel(_BaseModelCompat):
data_id: SerializedDataCoordinate
collections: list[str]
storage_class: str | None
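The client's `exclude_unset=True, exclude_defaults=True` dump pairs with this model change: a `storage_class` left at its default stays out of the JSON body, which keeps requests well-formed for servers predating the field. A standalone sketch with plain pydantic (the real model derives from `_BaseModelCompat`, and the explicit `= None` default here is an assumption of this sketch):

```python
from pydantic import BaseModel

class FindDatasetModel(BaseModel):
    data_id: dict                     # SerializedDataCoordinate in the real model
    collections: list[str]
    storage_class: str | None = None  # explicit default for this sketch

query = FindDatasetModel(data_id={"instrument": "Cam1"}, collections=["imported_g"])
print(query.model_dump(mode="json", exclude_unset=True, exclude_defaults=True))
# {'data_id': {'instrument': 'Cam1'}, 'collections': ['imported_g']}
```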
14 changes: 13 additions & 1 deletion tests/test_server.py
@@ -38,7 +38,7 @@
TestClient = None
app = None

from lsst.daf.butler import Butler, DataCoordinate, DatasetRef
from lsst.daf.butler import Butler, DataCoordinate, DatasetRef, StorageClassFactory
from lsst.daf.butler.tests import DatastoreMock
from lsst.daf.butler.tests.utils import MetricTestRepo, makeTestTempDir, removeTestTempDir

@@ -64,6 +64,8 @@ class ButlerClientServerTestCase(unittest.TestCase):

@classmethod
def setUpClass(cls):
cls.storageClassFactory = StorageClassFactory()

# First create a butler and populate it.
cls.root = makeTestTempDir(TESTDIR)
cls.repo = MetricTestRepo(root=cls.root, configFile=os.path.join(TESTDIR, "config/basic/butler.yaml"))
@@ -106,6 +108,8 @@ def test_get_dataset_type(self):
self.assertEqual(bias_type.name, "bias")

def test_find_dataset(self):
storage_class = self.storageClassFactory.getStorageClass("Exposure")

ref = self.butler.find_dataset("bias", collections="imported_g", detector=1, instrument="Cam1")
self.assertIsInstance(ref, DatasetRef)
self.assertEqual(ref.id, uuid.UUID("e15ab039-bc8b-4135-87c5-90902a7c0b22"))
@@ -123,6 +127,7 @@ def test_find_dataset(self):
ref.datasetType,
DataCoordinate.standardize(detector=1, instrument="Cam1", universe=self.butler.dimensions),
collections="imported_g",
storage_class=storage_class,
)
self.assertEqual(ref_new, ref)

@@ -138,8 +143,15 @@
)
self.assertEqual(ref2, ref3)

# The test datasets are all Exposure so storage class conversion
cannot be tested until we fix that. For now at least test the
# code paths.
bias = self.butler.get_dataset(ref.id, storage_class=storage_class)
self.assertEqual(bias.datasetType.storageClass, storage_class)

# Unknown dataset should not fail.
self.assertIsNone(self.butler.get_dataset(uuid.uuid4()))
self.assertIsNone(self.butler.get_dataset(uuid.uuid4(), storage_class="NumpyArray"))


if __name__ == "__main__":
