diff --git a/python/lsst/daf/butler/_query_all_datasets.py b/python/lsst/daf/butler/_query_all_datasets.py
index 1432798dda..e8bed6c9e1 100644
--- a/python/lsst/daf/butler/_query_all_datasets.py
+++ b/python/lsst/daf/butler/_query_all_datasets.py
@@ -113,30 +113,33 @@ def query_all_datasets(
         raise InvalidQueryError("Can not use wildcards in collections when find_first=True")
 
     dataset_type_query = list(ensure_iterable(args.name))
-    dataset_type_collections = _filter_collections_and_dataset_types(
-        butler, args.collections, dataset_type_query
-    )
-
-    limit = args.limit
-    for dt, filtered_collections in sorted(dataset_type_collections.items()):
-        _LOG.debug("Querying dataset type %s", dt)
-        results = (
-            query.datasets(dt, filtered_collections, find_first=args.find_first)
-            .where(args.data_id, args.where, args.kwargs, bind=args.bind)
-            .limit(limit)
+    with butler.registry.caching_context():
+        dataset_type_collections = _filter_collections_and_dataset_types(
+            butler, args.collections, dataset_type_query
         )
-        if args.with_dimension_records:
-            results = results.with_dimension_records()
-
-        for page in results._iter_pages():
-            if limit is not None:
-                # Track how much of the limit has been used up by each query.
-                limit -= len(page)
-
-            yield DatasetsPage(dataset_type=dt, data=page)
-            if limit is not None and limit <= 0:
-                break
+
+        limit = args.limit
+        for dt, filtered_collections in sorted(dataset_type_collections.items()):
+            _LOG.debug("Querying dataset type %s", dt)
+            results = (
+                query.datasets(dt, filtered_collections, find_first=args.find_first)
+                .where(args.data_id, args.where, args.kwargs, bind=args.bind)
+                .limit(limit)
+            )
+            if args.with_dimension_records:
+                results = results.with_dimension_records()
+
+            for page in results._iter_pages():
+                if limit is not None:
+                    # Track how much of the limit has been used up by each
+                    # query.
+                    limit -= len(page)
+
+                yield DatasetsPage(dataset_type=dt, data=page)
+
+                if limit is not None and limit <= 0:
+                    break
 
 
 def _filter_collections_and_dataset_types(
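The hunk above wraps the whole per-dataset-type loop in `butler.registry.caching_context()` and shares a single `limit` budget across every dataset type: each returned page is subtracted from the same counter before it is yielded. A minimal sketch of that shared-limit pattern, reduced to plain sequences; the helper name `paginate_with_shared_limit` and its inputs are hypothetical, not part of the patch:

from collections.abc import Iterator, Sequence


def paginate_with_shared_limit(
    pages_by_type: dict[str, Sequence[Sequence[str]]], limit: int | None
) -> Iterator[tuple[str, Sequence[str]]]:
    """Yield (dataset_type, page) pairs until the shared limit is spent."""
    for dataset_type, pages in pages_by_type.items():
        for page in pages:
            if limit is not None:
                # Every page, whatever its dataset type, draws down the same budget.
                limit -= len(page)
            yield dataset_type, page
            if limit is not None and limit <= 0:
                return

The real code keeps iterating over the remaining dataset types with the reduced budget passed to `.limit()`, but the per-page bookkeeping is the same.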
diff --git a/python/lsst/daf/butler/direct_butler/_direct_butler.py b/python/lsst/daf/butler/direct_butler/_direct_butler.py
index 69972d8856..4095fd2af5 100644
--- a/python/lsst/daf/butler/direct_butler/_direct_butler.py
+++ b/python/lsst/daf/butler/direct_butler/_direct_butler.py
@@ -1426,18 +1426,19 @@ def removeRuns(self, names: Iterable[str], unstore: bool = True) -> None:
         names = list(names)
         refs: list[DatasetRef] = []
         all_dataset_types = [dt.name for dt in self._registry.queryDatasetTypes(...)]
-        for name in names:
-            collectionType = self._registry.getCollectionType(name)
-            if collectionType is not CollectionType.RUN:
-                raise TypeError(f"The collection type of '{name}' is {collectionType.name}, not RUN.")
-            with self.query() as query:
-                # Work out the dataset types that are relevant.
-                collections_info = self.collections.query_info(name, include_summary=True)
-                filtered_dataset_types = self.collections._filter_dataset_types(
-                    all_dataset_types, collections_info
-                )
-                for dt in filtered_dataset_types:
-                    refs.extend(query.datasets(dt, collections=name))
+        with self._caching_context():
+            for name in names:
+                collectionType = self._registry.getCollectionType(name)
+                if collectionType is not CollectionType.RUN:
+                    raise TypeError(f"The collection type of '{name}' is {collectionType.name}, not RUN.")
+                with self.query() as query:
+                    # Work out the dataset types that are relevant.
+                    collections_info = self.collections.query_info(name, include_summary=True)
+                    filtered_dataset_types = self.collections._filter_dataset_types(
+                        all_dataset_types, collections_info
+                    )
+                    for dt in filtered_dataset_types:
+                        refs.extend(query.datasets(dt, collections=name))
         with self._datastore.transaction(), self._registry.transaction():
             if unstore:
                 self._datastore.trash(refs)
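Outside of `DirectButler` internals, the same batching effect is available through the public `Butler.registry.caching_context()` used elsewhere in this patch. A hedged usage sketch; the repository path and run names are made up:

from lsst.daf.butler import Butler

butler = Butler.from_config("repo")  # hypothetical repository
with butler.registry.caching_context():
    for run in ["run/a", "run/b"]:  # hypothetical run names
        # Repeated collection and summary lookups inside this block can be
        # served from the caches instead of issuing separate database queries.
        info = butler.collections.query_info(run, include_summary=True)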
""" - if self._depth == 1: - self._collection_records = None - self._collection_summaries = None - self._depth = 0 - elif self._depth > 1: - self._depth -= 1 - else: - raise AssertionError("Bad caching context management detected.") + return self._collection_summaries.enable() @property def collection_records(self) -> CollectionRecordCache | None: """Cache for collection records (`CollectionRecordCache`).""" - return self._collection_records + return self._collection_records.cache @property def collection_summaries(self) -> CollectionSummaryCache | None: """Cache for collection summary records (`CollectionSummaryCache`).""" - return self._collection_summaries + return self._collection_summaries.cache + + +_T = TypeVar("_T") + + +class _CacheToggle(Generic[_T]): + """Utility class to track nested enable/disable calls for a cache.""" + + def __init__(self, enable_function: Callable[[], _T]): + self.cache: _T | None = None + self._enable_function = enable_function + self._depth = 0 + + @contextmanager + def enable(self) -> Iterator[None]: + """Context manager to enable the cache. This context may be nested any + number of times, and the cache will only be disabled once all callers + have exited the context manager. + """ + self._depth += 1 + try: + if self._depth == 1: + self.cache = self._enable_function() + yield + finally: + self._depth -= 1 + if self._depth == 0: + self.cache = None diff --git a/python/lsst/daf/butler/registry/collections/_base.py b/python/lsst/daf/butler/registry/collections/_base.py index d222b03ec5..6580dc0dc8 100644 --- a/python/lsst/daf/butler/registry/collections/_base.py +++ b/python/lsst/daf/butler/registry/collections/_base.py @@ -581,7 +581,7 @@ def _modify_collection_chain( skip_caching_check: bool = False, skip_cycle_check: bool = False, ) -> Iterator[_CollectionChainModificationContext[K]]: - if (not skip_caching_check) and self._caching_context.is_enabled: + if (not skip_caching_check) and self._caching_context.collection_records is not None: # Avoid having cache-maintenance code around that is unlikely to # ever be used. raise RuntimeError("Chained collection modification not permitted with active caching context.") diff --git a/python/lsst/daf/butler/registry/managers.py b/python/lsst/daf/butler/registry/managers.py index 159330defc..bd7f35604a 100644 --- a/python/lsst/daf/butler/registry/managers.py +++ b/python/lsst/daf/butler/registry/managers.py @@ -362,11 +362,11 @@ def caching_context_manager(self) -> Iterator[None]: may even be closed out of order, with only the context manager entered and the last context manager exited having any effect. """ - self.caching_context._enable() - try: + with ( + self.caching_context.enable_collection_record_cache(), + self.caching_context.enable_collection_summary_cache(), + ): yield - finally: - self.caching_context._disable() @classmethod def initialize( diff --git a/python/lsst/daf/butler/registry/queries/_sql_query_backend.py b/python/lsst/daf/butler/registry/queries/_sql_query_backend.py index dfc06f1712..7546db7075 100644 --- a/python/lsst/daf/butler/registry/queries/_sql_query_backend.py +++ b/python/lsst/daf/butler/registry/queries/_sql_query_backend.py @@ -86,7 +86,7 @@ def context(self) -> SqlQueryContext: def get_collection_name(self, key: Any) -> str: assert ( - self._managers.caching_context.is_enabled + self._managers.caching_context.collection_records is not None ), "Collection-record caching should already been enabled any time this is called." 
         return self._managers.collections[key].name
 
diff --git a/python/lsst/daf/butler/registry/sql_registry.py b/python/lsst/daf/butler/registry/sql_registry.py
index edc35da3d2..8e90bbacd9 100644
--- a/python/lsst/daf/butler/registry/sql_registry.py
+++ b/python/lsst/daf/butler/registry/sql_registry.py
@@ -2332,7 +2332,9 @@ def _query_driver(
         default_data_id: DataCoordinate,
     ) -> Iterator[DirectQueryDriver]:
         """Set up a `QueryDriver` instance for query execution."""
-        with self.caching_context():
+        # Query internals do repeated lookups of the same collections, so they
+        # benefit from the collection record cache.
+        with self._managers.caching_context.enable_collection_record_cache():
             driver = DirectQueryDriver(
                 self._db,
                 self.dimensions,
diff --git a/python/lsst/daf/butler/script/removeRuns.py b/python/lsst/daf/butler/script/removeRuns.py
index cf16d1812e..9c29172823 100644
--- a/python/lsst/daf/butler/script/removeRuns.py
+++ b/python/lsst/daf/butler/script/removeRuns.py
@@ -86,29 +86,34 @@ def _getCollectionInfo(
         The dataset types and and how many will be removed.
     """
     butler = Butler.from_config(repo)
-    try:
-        collections = butler.collections.query_info(
-            collection, CollectionType.RUN, include_chains=False, include_parents=True, include_summary=True
-        )
-    except MissingCollectionError:
-        # Act as if no collections matched.
-        collections = []
-    dataset_types = [dt.name for dt in butler.registry.queryDatasetTypes(...)]
-    dataset_types = list(butler.collections._filter_dataset_types(dataset_types, collections))
-
-    runs = []
-    datasets: dict[str, int] = defaultdict(int)
-    for collection_info in collections:
-        assert collection_info.type == CollectionType.RUN and collection_info.parents is not None
-        runs.append(RemoveRun(collection_info.name, list(collection_info.parents)))
-        with butler.query() as query:
-            for dt in dataset_types:
-                results = query.datasets(dt, collections=collection_info.name)
-                count = results.count(exact=False)
-                if count:
-                    datasets[dt] += count
-
-    return runs, {k: datasets[k] for k in sorted(datasets.keys())}
+    with butler.registry.caching_context():
+        try:
+            collections = butler.collections.query_info(
+                collection,
+                CollectionType.RUN,
+                include_chains=False,
+                include_parents=True,
+                include_summary=True,
+            )
+        except MissingCollectionError:
+            # Act as if no collections matched.
+            collections = []
+        dataset_types = [dt.name for dt in butler.registry.queryDatasetTypes(...)]
+        dataset_types = list(butler.collections._filter_dataset_types(dataset_types, collections))
+
+        runs = []
+        datasets: dict[str, int] = defaultdict(int)
+        for collection_info in collections:
+            assert collection_info.type == CollectionType.RUN and collection_info.parents is not None
+            runs.append(RemoveRun(collection_info.name, list(collection_info.parents)))
+            with butler.query() as query:
+                for dt in dataset_types:
+                    results = query.datasets(dt, collections=collection_info.name)
+                    count = results.count(exact=False)
+                    if count:
+                        datasets[dt] += count
+
+        return runs, {k: datasets[k] for k in sorted(datasets.keys())}
 
 
 def removeRuns(
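With the two caches toggled independently, code like the new `_query_driver` can turn on just the collection record cache and avoid the aggressive summary prefetch described in `enable_collection_summary_cache`. An illustrative check against the API added in this patch (the module path is the internal one shown above):

from lsst.daf.butler.registry._caching_context import CachingContext

caching_context = CachingContext()
with caching_context.enable_collection_record_cache():
    assert caching_context.collection_records is not None  # record cache active
    assert caching_context.collection_summaries is None  # summary cache left off
assert caching_context.collection_records is None  # disabled again after exit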