Merge pull request #1135 from lsst/tickets/DM-47980
DM-47980: Add an option to include dimension records into general query result
andy-slac authored Dec 19, 2024
2 parents d1c5477 + df7a350 commit 797dd73
Showing 8 changed files with 221 additions and 27 deletions.
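As context for the diff below, here is a minimal usage sketch of the new option, adapted from the tests added in this PR (the repository, dimension names, and record values are illustrative):

# Illustrative sketch only, assuming an existing Butler repository `butler`
# whose default dimension universe defines "instrument" and "detector".
with butler.query() as query:
    result = query.general(butler.dimensions["detector"].minimal_group)
    # Opting in to dimension records adds "<element>.<field>" entries,
    # e.g. "detector.full_name", to every returned row dict.
    for row in result.with_dimension_records():
        print(row["instrument"], row["detector"], row["detector.full_name"])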
@@ -326,21 +326,48 @@ class GeneralResultPageConverter(ResultPageConverter): # numpydoc ignore=PR01

def __init__(self, spec: GeneralResultSpec, ctx: ResultPageConverterContext) -> None:
self.spec = spec

result_columns = spec.get_result_columns()
# If `spec.include_dimension_records` is True then, in addition to the
# columns returned by the query, we have to add columns from dimension
# records that are not returned by the query. These columns belong to
# either cached or skypix dimensions.
columns = spec.get_result_columns()
universe = spec.dimensions.universe
self.converters: list[_GeneralColumnConverter] = []
for column in result_columns:
self.record_converters: dict[DimensionElement, _DimensionRecordRowConverter] = {}
for column in columns:
column_name = qt.ColumnSet.get_qualified_name(column.logical_table, column.field)
converter: _GeneralColumnConverter
if column.field == TimespanDatabaseRepresentation.NAME:
self.converters.append(_TimespanGeneralColumnConverter(column_name, ctx.db))
converter = _TimespanGeneralColumnConverter(column_name, ctx.db)
elif column.field == "ingest_date":
self.converters.append(_TimestampGeneralColumnConverter(column_name))
converter = _TimestampGeneralColumnConverter(column_name)
else:
self.converters.append(_DefaultGeneralColumnConverter(column_name))
converter = _DefaultGeneralColumnConverter(column_name)
self.converters.append(converter)

if spec.include_dimension_records:
universe = self.spec.dimensions.universe
for element_name in self.spec.dimensions.elements:
element = universe[element_name]
if isinstance(element, SkyPixDimension):
self.record_converters[element] = _SkypixDimensionRecordRowConverter(element)
elif element.is_cached:
self.record_converters[element] = _CachedDimensionRecordRowConverter(
element, ctx.dimension_record_cache
)

def convert(self, raw_rows: Iterable[sqlalchemy.Row]) -> GeneralResultPage:
rows = [tuple(cvt.convert(row) for cvt in self.converters) for row in raw_rows]
return GeneralResultPage(spec=self.spec, rows=rows)
rows = []
dimension_records = None
if self.spec.include_dimension_records:
dimension_records = {element: DimensionRecordSet(element) for element in self.record_converters}
for row in raw_rows:
rows.append(tuple(cvt.convert(row) for cvt in self.converters))
if dimension_records:
for element, converter in self.record_converters.items():
dimension_records[element].add(converter.convert(row))

return GeneralResultPage(spec=self.spec, rows=rows, dimension_records=dimension_records)


class _GeneralColumnConverter:
100 changes: 86 additions & 14 deletions python/lsst/daf/butler/queries/_general_query_results.py
@@ -35,11 +35,11 @@

from .._dataset_ref import DatasetRef
from .._dataset_type import DatasetType
from ..dimensions import DataCoordinate, DimensionGroup
from ..dimensions import DataCoordinate, DimensionElement, DimensionGroup, DimensionRecord, DimensionRecordSet
from ._base import QueryResultsBase
from .driver import QueryDriver
from .result_specs import GeneralResultSpec
from .tree import QueryTree
from .tree import QueryTree, ResultColumn


class GeneralResultTuple(NamedTuple):
@@ -101,7 +101,11 @@ def __iter__(self) -> Iterator[dict[str, Any]]:
for page in self._driver.execute(self._spec, self._tree):
columns = tuple(str(column) for column in page.spec.get_result_columns())
for row in page.rows:
yield dict(zip(columns, row))
result = dict(zip(columns, row, strict=True))
if page.dimension_records:
records = self._get_cached_dimension_records(result, page.dimension_records)
self._add_dimension_records(result, records)
yield result

def iter_tuples(self, *dataset_types: DatasetType) -> Iterator[GeneralResultTuple]:
"""Iterate over result rows and return data coordinate, and dataset
@@ -124,23 +128,40 @@ def iter_tuples(self, *dataset_types: DatasetType) -> Iterator[GeneralResultTupl
id_key = f"{dataset_type.name}.dataset_id"
run_key = f"{dataset_type.name}.run"
dataset_keys.append((dataset_type, dimensions, id_key, run_key))
for row in self:
values = tuple(
row[key] for key in itertools.chain(all_dimensions.required, all_dimensions.implied)
)
data_coordinate = DataCoordinate.from_full_values(all_dimensions, values)
refs = []
for dataset_type, dimensions, id_key, run_key in dataset_keys:
values = tuple(row[key] for key in itertools.chain(dimensions.required, dimensions.implied))
data_id = DataCoordinate.from_full_values(dimensions, values)
refs.append(DatasetRef(dataset_type, data_id, row[run_key], id=row[id_key]))
yield GeneralResultTuple(data_id=data_coordinate, refs=refs, raw_row=row)
for page in self._driver.execute(self._spec, self._tree):
columns = tuple(str(column) for column in page.spec.get_result_columns())
for page_row in page.rows:
row = dict(zip(columns, page_row, strict=True))
if page.dimension_records:
cached_records = self._get_cached_dimension_records(row, page.dimension_records)
self._add_dimension_records(row, cached_records)
else:
cached_records = {}
data_coordinate = self._make_data_id(row, all_dimensions, cached_records)
refs = []
for dataset_type, dimensions, id_key, run_key in dataset_keys:
data_id = data_coordinate.subset(dimensions)
refs.append(DatasetRef(dataset_type, data_id, row[run_key], id=row[id_key]))
yield GeneralResultTuple(data_id=data_coordinate, refs=refs, raw_row=row)

@property
def dimensions(self) -> DimensionGroup:
# Docstring inherited
return self._spec.dimensions

@property
def has_dimension_records(self) -> bool:
"""Whether all data IDs in this iterable contain dimension records."""
return self._spec.include_dimension_records

def with_dimension_records(self) -> GeneralQueryResults:
"""Return a results object for which `has_dimension_records` is
`True`.
"""
if self.has_dimension_records:
return self
return self._copy(tree=self._tree, include_dimension_records=True)

def count(self, *, exact: bool = True, discard: bool = False) -> int:
# Docstring inherited.
return self._driver.count(self._tree, self._spec, exact=exact, discard=discard)
@@ -152,3 +173,54 @@ def _copy(self, tree: QueryTree, **kwargs: Any) -> GeneralQueryResults:
def _get_datasets(self) -> frozenset[str]:
# Docstring inherited.
return frozenset(self._spec.dataset_fields)

def _make_data_id(
self,
row: dict[str, Any],
dimensions: DimensionGroup,
cached_row_records: dict[DimensionElement, DimensionRecord],
) -> DataCoordinate:
values = tuple(row[key] for key in itertools.chain(dimensions.required, dimensions.implied))
data_coordinate = DataCoordinate.from_full_values(dimensions, values)
if self.has_dimension_records:
records = {}
for name in dimensions.elements:
element = dimensions.universe[name]
record = cached_row_records.get(element)
if record is None:
record = self._make_dimension_record(row, dimensions.universe[name])
records[name] = record
data_coordinate = data_coordinate.expanded(records)
return data_coordinate

def _make_dimension_record(self, row: dict[str, Any], element: DimensionElement) -> DimensionRecord:
column_map = list(
zip(
element.schema.dimensions.names,
element.dimensions.names,
)
)
for field in element.schema.remainder.names:
column_map.append((field, str(ResultColumn(element.name, field))))
d = {k: row[v] for k, v in column_map}
record_cls = element.RecordClass
return record_cls(**d)

def _get_cached_dimension_records(
self, row: dict[str, Any], dimension_records: dict[DimensionElement, DimensionRecordSet]
) -> dict[DimensionElement, DimensionRecord]:
"""Find cached dimension records matching this row."""
records = {}
for element, element_records in dimension_records.items():
required_values = tuple(row[key] for key in element.required.names)
records[element] = element_records.find_with_required_values(required_values)
return records

def _add_dimension_records(
self, row: dict[str, Any], records: dict[DimensionElement, DimensionRecord]
) -> None:
"""Extend row with the fields from cached dimension records."""
for element, record in records.items():
for name, value in record.toDict().items():
if name not in element.schema.required.names:
row[f"{element.name}.{name}"] = value
5 changes: 5 additions & 0 deletions python/lsst/daf/butler/queries/driver.py
@@ -47,6 +47,7 @@
from ..dimensions import (
DataCoordinate,
DataIdValue,
DimensionElement,
DimensionGroup,
DimensionRecord,
DimensionRecordSet,
@@ -120,6 +121,10 @@ class GeneralResultPage:
# spec.get_result_columns().
rows: list[tuple[Any, ...]]

# This map contains dimension records for cached and skypix elements; it is
# only set when spec.include_dimension_records is True.
dimension_records: dict[DimensionElement, DimensionRecordSet] | None


ResultPage: TypeAlias = Union[
DataCoordinateResultPage, DimensionRecordResultPage, DatasetRefResultPage, GeneralResultPage
11 changes: 11 additions & 0 deletions python/lsst/daf/butler/queries/result_specs.py
@@ -213,6 +213,11 @@ class GeneralResultSpec(ResultSpecBase):
dataset_fields: Mapping[str, set[DatasetFieldName]]
"""Dataset fields included in this query."""

include_dimension_records: bool = False
"""Whether to include fields for all dimension records, in addition to
explicitly specified in `dimension_fields`.
"""

find_first: bool
"""Whether this query requires find-first resolution for a dataset.
@@ -241,6 +246,12 @@ def get_result_columns(self) -> ColumnSet:
result.dimension_fields[element_name].update(fields_for_element)
for dataset_type, fields_for_dataset in self.dataset_fields.items():
result.dataset_fields[dataset_type].update(fields_for_dataset)
if self.include_dimension_records:
# This only adds record fields for non-cached and non-skypix
# elements, which is what we want when generating the query. When
# `include_dimension_records` is True, dimension records for cached
# and skypix elements are added to result pages by the page converter.
_add_dimension_records_to_column_set(self.dimensions, result)
return result

@pydantic.model_validator(mode="after")
32 changes: 29 additions & 3 deletions python/lsst/daf/butler/remote_butler/_query_driver.py
@@ -40,7 +40,14 @@

from ...butler import Butler
from .._dataset_type import DatasetType
from ..dimensions import DataCoordinate, DataIdValue, DimensionGroup, DimensionRecord, DimensionUniverse
from ..dimensions import (
DataCoordinate,
DataIdValue,
DimensionGroup,
DimensionRecord,
DimensionRecordSet,
DimensionUniverse,
)
from ..queries.driver import (
DataCoordinateResultPage,
DatasetRefResultPage,
@@ -257,12 +264,31 @@ def _convert_query_result_page(

def _convert_general_result(spec: GeneralResultSpec, model: GeneralResultModel) -> GeneralResultPage:
"""Convert GeneralResultModel to a general result page."""
if spec.include_dimension_records:
# dimension_records must not be None when `include_dimension_records`
# is True, but it will be None if the remote server has not been upgraded.
if model.dimension_records is None:
raise ValueError(
"Missing dimension records in general result -- " "it is likely that server needs an upgrade."
)

columns = spec.get_result_columns()
serializers = [
columns.get_column_spec(column.logical_table, column.field).serializer() for column in columns
]
rows = [
tuple(serializer.deserialize(value) for value, serializer in zip(row, serializers))
tuple(serializer.deserialize(value) for value, serializer in zip(row, serializers, strict=True))
for row in model.rows
]
return GeneralResultPage(spec=spec, rows=rows)

universe = spec.dimensions.universe
dimension_records = None
if model.dimension_records is not None:
dimension_records = {}
for name, records in model.dimension_records.items():
element = universe[name]
dimension_records[element] = DimensionRecordSet(
element, (DimensionRecord.from_simple(r, universe) for r in records)
)

return GeneralResultPage(spec=spec, rows=rows, dimension_records=dimension_records)
@@ -84,6 +84,13 @@ def _convert_general_result(page: GeneralResultPage) -> GeneralResultModel:
columns.get_column_spec(column.logical_table, column.field).serializer() for column in columns
]
rows = [
tuple(serializer.serialize(value) for value, serializer in zip(row, serializers)) for row in page.rows
tuple(serializer.serialize(value) for value, serializer in zip(row, serializers, strict=True))
for row in page.rows
]
return GeneralResultModel(rows=rows)
dimension_records = None
if page.dimension_records is not None:
dimension_records = {
element.name: [record.to_simple() for record in records]
for element, records in page.dimension_records.items()
}
return GeneralResultModel(rows=rows, dimension_records=dimension_records)
4 changes: 4 additions & 0 deletions python/lsst/daf/butler/remote_butler/server_models.py
@@ -313,6 +313,10 @@ class GeneralResultModel(pydantic.BaseModel):

type: Literal["general"] = "general"
rows: list[tuple[Any, ...]]
# Dimension records indexed by element name; only cached and skypix
# elements are included. Default is used for compatibility with older
# servers that do not set this field.
dimension_records: dict[str, list[SerializedDimensionRecord]] | None = None


class QueryErrorResultModel(pydantic.BaseModel):
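For reference, a hypothetical sketch of the wire shape of GeneralResultModel once the new field is populated (record payloads elided; each entry would be a SerializedDimensionRecord):

# Hypothetical payload shape only -- not taken verbatim from this PR.
model = {
    "type": "general",
    "rows": [["Cam1", 1], ["Cam1", 2]],
    "dimension_records": {
        # element name -> list of serialized dimension records (contents elided)
        "instrument": ["<SerializedDimensionRecord>"],
        "detector": ["<SerializedDimensionRecord>", "<SerializedDimensionRecord>"],
    },
}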
42 changes: 42 additions & 0 deletions python/lsst/daf/butler/tests/butler_queries.py
@@ -436,7 +436,9 @@ def test_general_query(self) -> None:
self.assertEqual(len(row_tuple.refs), 1)
self.assertEqual(row_tuple.refs[0].datasetType, flat)
self.assertTrue(row_tuple.refs[0].dataId.hasFull())
self.assertFalse(row_tuple.refs[0].dataId.hasRecords())
self.assertTrue(row_tuple.data_id.hasFull())
self.assertFalse(row_tuple.data_id.hasRecords())
self.assertEqual(row_tuple.data_id.dimensions, dimensions)
self.assertEqual(row_tuple.raw_row["flat.run"], "imported_g")

@@ -511,6 +513,46 @@ def test_general_query(self) -> None:
{Timespan(t1, t2), Timespan(t2, t3), Timespan(t3, None), Timespan.makeEmpty(), None},
)

dimensions = butler.dimensions["detector"].minimal_group

# Include dimension records into query.
with butler.query() as query:
query = query.join_dimensions(dimensions)
result = query.general(dimensions).order_by("detector")
rows = list(result.with_dimension_records())
self.assertEqual(
rows[0],
{
"instrument": "Cam1",
"detector": 1,
"instrument.visit_max": 1024,
"instrument.visit_system": 1,
"instrument.exposure_max": 512,
"instrument.detector_max": 4,
"instrument.class_name": "lsst.pipe.base.Instrument",
"detector.full_name": "Aa",
"detector.name_in_raft": "a",
"detector.raft": "A",
"detector.purpose": "SCIENCE",
},
)

dimensions = butler.dimensions.conform(["detector", "physical_filter"])

# DataIds should come with records.
with butler.query() as query:
query = query.join_dataset_search("flat", "imported_g")
result = query.general(dimensions, dataset_fields={"flat": ...}, find_first=True).order_by(
"detector"
)
result = result.with_dimension_records()
row_tuples = list(result.iter_tuples(flat))
self.assertEqual(len(row_tuples), 3)
for row_tuple in row_tuples:
self.assertTrue(row_tuple.data_id.hasRecords())
self.assertEqual(len(row_tuple.refs), 1)
self.assertTrue(row_tuple.refs[0].dataId.hasRecords())

def test_query_ingest_date(self) -> None:
"""Test general query returning ingest_date field."""
before_ingest = astropy.time.Time.now()
