diff --git a/doc/changes/DM-53233.bugfix.md b/doc/changes/DM-53233.bugfix.md new file mode 100644 index 0000000000..9c97ffb45a --- /dev/null +++ b/doc/changes/DM-53233.bugfix.md @@ -0,0 +1,3 @@ +The unit tests will no longer fail due to file descriptor exhaustion. +All `ResourceWarning` messages in the unit test suite have been fixed. +Butler instances are now garbage-collected immediately when they are no longer referenced, although their database connections still might not be released until mark-and-sweep garbage collection runs. diff --git a/doc/changes/DM-53233.feature.md b/doc/changes/DM-53233.feature.md new file mode 100644 index 0000000000..7f4ec02d0f --- /dev/null +++ b/doc/changes/DM-53233.feature.md @@ -0,0 +1 @@ +Added `Butler.close()`, a context manager implementation for `Butler`, and `QuantumBackedButler.close()`. These can be used to ensure that database connections are closed deterministically, rather than waiting for mark-and-sweep garbage collection. diff --git a/python/lsst/daf/butler/_butler.py b/python/lsst/daf/butler/_butler.py index be880678b4..a5d3daef9c 100644 --- a/python/lsst/daf/butler/_butler.py +++ b/python/lsst/daf/butler/_butler.py @@ -36,7 +36,7 @@ from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence from contextlib import AbstractContextManager from types import EllipsisType -from typing import TYPE_CHECKING, Any, TextIO +from typing import TYPE_CHECKING, Any, Literal, Self, TextIO from lsst.resources import ResourcePath, ResourcePathExpression from lsst.utils import doImportType @@ -94,7 +94,7 @@ class _DeprecatedDefault: """Default value for a deprecated parameter.""" -class Butler(LimitedButler): # numpydoc ignore=PR02 +class Butler(LimitedButler, AbstractContextManager): # numpydoc ignore=PR02 """Interface for data butler and factory for Butler instances. Parameters @@ -358,6 +358,16 @@ def from_config( case _: raise TypeError(f"Unknown Butler type '{butler_type}'") + def __enter__(self) -> Self: + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]: + try: + self.close() + except Exception: + _LOG.exception("An exception occurred during Butler.close()") + return False + @staticmethod def makeRepo( root: ResourcePathExpression, @@ -506,9 +516,10 @@ def makeRepo( # Create Registry and populate tables registryConfig = RegistryConfig(config.get("registry")) dimensionConfig = DimensionConfig(dimensionConfig) - _RegistryFactory(registryConfig).create_from_config( + registry = _RegistryFactory(registryConfig).create_from_config( dimensionConfig=dimensionConfig, butlerRoot=root_uri ) + registry.close() _LOG.verbose("Wrote new Butler configuration file to %s", configURI) @@ -2222,3 +2233,18 @@ def clone( Metrics object to record butler statistics. """ raise NotImplementedError() + + @abstractmethod + def close(self) -> None: + """Release all resources associated with this Butler instance. The + instance may no longer be used after this is called. + + Notes + ----- + Instead of calling ``close()`` directly, you can use the Butler object + as a context manager. For example:: + with Butler(...) as butler: + butler.get(...) + # butler is closed after exiting the block.
+ """ + raise NotImplementedError() diff --git a/python/lsst/daf/butler/_labeled_butler_factory.py b/python/lsst/daf/butler/_labeled_butler_factory.py index e96df9d5c4..f4fe46a6bf 100644 --- a/python/lsst/daf/butler/_labeled_butler_factory.py +++ b/python/lsst/daf/butler/_labeled_butler_factory.py @@ -25,9 +25,11 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +from __future__ import annotations + __all__ = ("LabeledButlerFactory", "LabeledButlerFactoryProtocol") -from collections.abc import Callable, Mapping +from collections.abc import Mapping from typing import Protocol from lsst.resources import ResourcePathExpression @@ -38,10 +40,6 @@ from ._utilities.named_locks import NamedLocks from ._utilities.thread_safe_cache import ThreadSafeCache -_FactoryFunction = Callable[[str | None], Butler] -"""Function that takes an access token string or `None`, and returns a Butler -instance.""" - class LabeledButlerFactoryProtocol(Protocol): """Callable to retrieve a butler from a label.""" @@ -84,7 +82,7 @@ def __init__(self, repositories: Mapping[str, str] | None = None) -> None: else: self._repositories = dict(repositories) - self._factories = ThreadSafeCache[str, _FactoryFunction]() + self._factories = ThreadSafeCache[str, _ButlerFactory]() self._initialization_locks = NamedLocks() # This may be overridden by unit tests. @@ -138,10 +136,18 @@ def create_butler(self, *, label: str, access_token: str | None) -> Butler: based on the end user instead of the service. See https://gafaelfawr.lsst.io/user-guide/gafaelfawringress.html#requesting-delegated-tokens """ - factory = self._get_or_create_butler_factory_function(label) - return factory(access_token) + factory = self._get_or_create_butler_factory(label) + return factory.create_butler(access_token) + + def close(self) -> None: + """Reset the factory cache, and release any resources associated with + the cached instances. + """ + factories = self._factories.clear() + for factory in factories.values(): + factory.close() - def _get_or_create_butler_factory_function(self, label: str) -> _FactoryFunction: + def _get_or_create_butler_factory(self, label: str) -> _ButlerFactory: # We maintain a separate lock per label. We only want to instantiate # one factory function per label, because creating the factory sets up # shared state that should only exist once per repository. 
However, we @@ -154,16 +160,16 @@ def _get_or_create_butler_factory_function(self, label: str) -> _FactoryFunction factory = self._create_butler_factory_function(label) return self._factories.set_or_get(label, factory) - def _create_butler_factory_function(self, label: str) -> _FactoryFunction: + def _create_butler_factory_function(self, label: str) -> _ButlerFactory: config_uri = self._get_config_uri(label) config = ButlerConfig(config_uri) butler_type = config.get_butler_type() match butler_type: case ButlerType.DIRECT: - return _create_direct_butler_factory(config, self._preload_unsafe_direct_butler_caches) + return _DirectButlerFactory(config, self._preload_unsafe_direct_butler_caches) case ButlerType.REMOTE: - return _create_remote_butler_factory(config) + return _RemoteButlerFactory(config) case _: raise TypeError(f"Unknown butler type '{butler_type}' for label '{label}'") @@ -177,34 +183,45 @@ def _get_config_uri(self, label: str) -> ResourcePathExpression: return config_uri -def _create_direct_butler_factory(config: ButlerConfig, preload_unsafe_caches: bool) -> _FactoryFunction: - import lsst.daf.butler.direct_butler +class _ButlerFactory(Protocol): + def create_butler(self, access_token: str | None) -> Butler: ... + def close(self) -> None: ... + + +class _DirectButlerFactory(_ButlerFactory): + def __init__(self, config: ButlerConfig, preload_unsafe_caches: bool) -> None: + import lsst.daf.butler.direct_butler - # Create a 'template' Butler that will be cloned when callers request an - # instance. - butler = Butler.from_config(config) - assert isinstance(butler, lsst.daf.butler.direct_butler.DirectButler) + # Create a 'template' Butler that will be cloned when callers request + # an instance. + self._butler = Butler.from_config(config) + assert isinstance(self._butler, lsst.daf.butler.direct_butler.DirectButler) - # Load caches so that data is available in cloned instances without - # needing to refetch it from the database for every instance. - butler._preload_cache(load_dimension_record_cache=preload_unsafe_caches) + # Load caches so that data is available in cloned instances without + # needing to refetch it from the database for every instance. + self._butler._preload_cache(load_dimension_record_cache=preload_unsafe_caches) - def create_butler(access_token: str | None) -> Butler: + def create_butler(self, access_token: str | None) -> Butler: # Access token is ignored because DirectButler does not use Gafaelfawr # authentication. 
- return butler.clone() + return self._butler.clone() - return create_butler + def close(self) -> None: + self._butler.close() -def _create_remote_butler_factory(config: ButlerConfig) -> _FactoryFunction: - import lsst.daf.butler.remote_butler._factory +class _RemoteButlerFactory(_ButlerFactory): + def __init__(self, config: ButlerConfig) -> None: + import lsst.daf.butler.remote_butler._factory - factory = lsst.daf.butler.remote_butler._factory.RemoteButlerFactory.create_factory_from_config(config) + self._factory = lsst.daf.butler.remote_butler._factory.RemoteButlerFactory.create_factory_from_config( + config + ) - def create_butler(access_token: str | None) -> Butler: + def create_butler(self, access_token: str | None) -> Butler: if access_token is None: raise ValueError("Access token is required to connect to a Butler server") - return factory.create_butler_for_access_token(access_token) + return self._factory.create_butler_for_access_token(access_token) - return create_butler + def close(self) -> None: + pass diff --git a/python/lsst/daf/butler/_quantum_backed.py b/python/lsst/daf/butler/_quantum_backed.py index c626a59351..41bfb42f25 100644 --- a/python/lsst/daf/butler/_quantum_backed.py +++ b/python/lsst/daf/butler/_quantum_backed.py @@ -55,7 +55,7 @@ from .datastore.record_data import DatastoreRecordData, SerializedDatastoreRecordData from .datastores.file_datastore.retrieve_artifacts import retrieve_and_zip from .dimensions import DimensionUniverse -from .registry.interfaces import DatastoreRegistryBridgeManager, OpaqueTableStorageManager +from .registry.interfaces import Database, DatastoreRegistryBridgeManager, OpaqueTableStorageManager if TYPE_CHECKING: from ._butler import Butler @@ -83,6 +83,9 @@ class QuantumBackedButler(LimitedButler): The registry dataset type definitions, indexed by name. metrics : `lsst.daf.butler.ButlerMetrics` or `None`, optional Metrics object for tracking butler statistics. + database : `Database`, optional + Database instance used by datastore. Not required -- only provided + to allow database connections to be closed during cleanup. Notes ----- @@ -130,6 +133,7 @@ def __init__( storageClasses: StorageClassFactory, dataset_types: Mapping[str, DatasetType] | None = None, metrics: ButlerMetrics | None = None, + database: Database | None = None, ): self._dimensions = dimensions self._predicted_inputs = set(predicted_inputs) @@ -142,6 +146,7 @@ def __init__( self.storageClasses = storageClasses self._dataset_types: Mapping[str, DatasetType] = {} self._metrics = metrics if metrics is not None else ButlerMetrics() + self._database = database if dataset_types is not None: self._dataset_types = dataset_types self._datastore.set_retrieve_dataset_type_method(self._retrieve_dataset_type) @@ -321,7 +326,7 @@ def _initialize( Metrics object for gathering butler statistics. 
""" butler_config = ButlerConfig(config, searchPaths=search_paths) - datastore, _ = instantiate_standalone_datastore( + datastore, database = instantiate_standalone_datastore( butler_config, dimensions, filename, OpaqueManagerClass, BridgeManagerClass ) @@ -342,8 +347,13 @@ def _initialize( storageClasses=storageClasses, dataset_types=dataset_types, metrics=metrics, + database=database, ) + def close(self) -> None: + if self._database is not None: + self._database.dispose() + def _retrieve_dataset_type(self, name: str) -> DatasetType | None: """Return DatasetType defined in registry given dataset type name.""" return self._dataset_types.get(name) diff --git a/python/lsst/daf/butler/_utilities/thread_safe_cache.py b/python/lsst/daf/butler/_utilities/thread_safe_cache.py index c8b06de70b..ac850893d0 100644 --- a/python/lsst/daf/butler/_utilities/thread_safe_cache.py +++ b/python/lsst/daf/butler/_utilities/thread_safe_cache.py @@ -76,3 +76,16 @@ def set_or_get(self, key: TKey, value: TValue) -> TValue: """ with self._mutex: return self._cache.setdefault(key, value) + + def clear(self) -> dict[TKey, TValue]: + """Clear the cache. + + Returns + ------- + old_cache : `dict` + The values that were contained in the cache prior to clearing it. + """ + with self._mutex: + old = self._cache + self._cache = {} + return old diff --git a/python/lsst/daf/butler/direct_butler/_direct_butler.py b/python/lsst/daf/butler/direct_butler/_direct_butler.py index 7a4724f2fc..d931b4533c 100644 --- a/python/lsst/daf/butler/direct_butler/_direct_butler.py +++ b/python/lsst/daf/butler/direct_butler/_direct_butler.py @@ -45,6 +45,7 @@ import warnings from collections import Counter, defaultdict from collections.abc import Collection, Iterable, Iterator, MutableMapping, Sequence +from functools import partial from types import EllipsisType from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple, TextIO, cast @@ -160,9 +161,9 @@ def __new__( # dependency-inversion trick. This is not used by regular butler, # but we do not have a way to distinguish regular butler from execution # butler. - self._datastore.set_retrieve_dataset_type_method(self._retrieve_dataset_type) + self._datastore.set_retrieve_dataset_type_method(partial(_retrieve_dataset_type, registry)) - self._registry_shim = RegistryShim(self) + self._closed = False return self @@ -250,6 +251,18 @@ def clone( metrics=metrics, ) + def close(self) -> None: + if not self._closed: + self._closed = True + self._registry.close() + # Cause exceptions to be raised if a user attempts to use the + # instance after closing it. Without this, Butler would still + # work after being closed because of implementation details + # of SqlAlchemy, but this may not continue to be the case in the + # future and we don't want users to get in the habit of doing this. + self._registry = _BUTLER_CLOSED_INSTANCE + self._datastore = _BUTLER_CLOSED_INSTANCE + GENERATION: ClassVar[int] = 3 """This is a Generation 3 Butler. @@ -258,13 +271,6 @@ def clone( code. """ - def _retrieve_dataset_type(self, name: str) -> DatasetType | None: - """Return DatasetType defined in registry given dataset type name.""" - try: - return self.get_dataset_type(name) - except MissingDatasetTypeError: - return None - @classmethod def _unpickle( cls, @@ -2530,7 +2536,7 @@ def registry(self) -> Registry: are accessible only via `Registry` methods. Eventually these methods will be replaced by equivalent `Butler` methods. 
""" - return self._registry_shim + return RegistryShim(self) @property def dimensions(self) -> DimensionUniverse: @@ -2574,22 +2580,14 @@ def _preload_cache(self, *, load_dimension_record_cache: bool = True) -> None: accessible only via `SqlRegistry` methods. """ - datastore: Datastore - """The object that manages actual dataset storage (`Datastore`). - - Direct user access to the datastore should rarely be necessary; the primary - exception is the case where a `Datastore` implementation provides extra - functionality beyond what the base class defines. - """ - storageClasses: StorageClassFactory """An object that maps known storage class names to objects that fully describe them (`StorageClassFactory`). """ - _registry_shim: RegistryShim - """Shim object to provide a legacy public interface for querying via the - the ``registry`` property. + _closed: bool + """`True` if close() has already been called on this instance; `False` + otherwise. """ @@ -2614,3 +2612,19 @@ def _to_uuid(id: DatasetId | str) -> uuid.UUID: return id else: return uuid.UUID(id) + + +class _ButlerClosed: + def __getattr__(self, name: str) -> Any: + raise RuntimeError("Attempted to use a Butler instance which has been closed.") + + +_BUTLER_CLOSED_INSTANCE: Any = _ButlerClosed() + + +def _retrieve_dataset_type(registry: SqlRegistry, name: str) -> DatasetType | None: + """Return DatasetType defined in registry given dataset type name.""" + try: + return registry.getDatasetType(name) + except MissingDatasetTypeError: + return None diff --git a/python/lsst/daf/butler/registry/sql_registry.py b/python/lsst/daf/butler/registry/sql_registry.py index 1d105866ab..4e3123f526 100644 --- a/python/lsst/daf/butler/registry/sql_registry.py +++ b/python/lsst/daf/butler/registry/sql_registry.py @@ -175,6 +175,7 @@ def createFromConfig( elif not isinstance(dimensionConfig, DimensionConfig): raise TypeError(f"Incompatible Dimension configuration type: {type(dimensionConfig)}") + managerTypes = RegistryManagerTypes.fromConfig(config) DatabaseClass = config.getDatabaseClass() database = DatabaseClass.fromUri( config.connectionString, @@ -182,9 +183,13 @@ def createFromConfig( namespace=config.get("namespace"), allow_temporary_tables=config.areTemporaryTablesAllowed, ) - managerTypes = RegistryManagerTypes.fromConfig(config) - managers = managerTypes.makeRepo(database, dimensionConfig) - return cls(database, RegistryDefaults(), managers) + + try: + managers = managerTypes.makeRepo(database, dimensionConfig) + return cls(database, RegistryDefaults(), managers) + except Exception: + database.dispose() + raise @classmethod def fromConfig( @@ -217,6 +222,8 @@ def fromConfig( """ config = cls.forceRegistryConfig(config) config.replaceRoot(butlerRoot) + if defaults is None: + defaults = RegistryDefaults() DatabaseClass = config.getDatabaseClass() database = DatabaseClass.fromUri( config.connectionString, @@ -225,13 +232,15 @@ def fromConfig( writeable=writeable, allow_temporary_tables=config.areTemporaryTablesAllowed, ) - managerTypes = RegistryManagerTypes.fromConfig(config) - with database.session(): - managers = managerTypes.loadRepo(database) - if defaults is None: - defaults = RegistryDefaults() + try: + managerTypes = RegistryManagerTypes.fromConfig(config) + with database.session(): + managers = managerTypes.loadRepo(database) - return cls(database, defaults, managers) + return cls(database, defaults, managers) + except Exception: + database.dispose() + raise def __init__( self, @@ -258,6 +267,17 @@ def __init__( # eventually we'll need 
to do it during construction. # The mapping is indexed by the opaque table name. self._datastore_record_classes: Mapping[str, type[StoredDatastoreItemInfo]] = {} + self._is_clone = False + + def close(self) -> None: + # Connection pool is shared between cloned instances, so only the root + # instance should close it. + # Note: The underlying SQLAlchemy call will create a fresh connection + # pool, so nothing breaks if the root instance is accidentally closed + # before the clones are finished -- we just have a small performance + # hit from re-creating the connections. + if not self._is_clone: + self._db.dispose() def __str__(self) -> str: return str(self._db) @@ -296,6 +316,7 @@ def copy(self, defaults: RegistryDefaults | None = None) -> SqlRegistry: result = SqlRegistry(db, defaults, self._managers.clone(db)) result._datastore_record_classes = dict(self._datastore_record_classes) result.dimension_record_cache.load_from(self.dimension_record_cache) + result._is_clone = True return result @property diff --git a/python/lsst/daf/butler/remote_butler/_remote_butler.py b/python/lsst/daf/butler/remote_butler/_remote_butler.py index a9e55ffa5c..f88c940877 100644 --- a/python/lsst/daf/butler/remote_butler/_remote_butler.py +++ b/python/lsst/daf/butler/remote_butler/_remote_butler.py @@ -735,6 +735,9 @@ def clone( connection=self._connection, cache=self._cache, defaults=defaults, metrics=metrics ) + def close(self) -> None: + pass + @property def _file_transfer_source(self) -> RemoteFileTransferSource: return RemoteFileTransferSource(self._connection) diff --git a/python/lsst/daf/butler/script/_associate.py b/python/lsst/daf/butler/script/_associate.py index 90e5849cb2..75160f0370 100644 --- a/python/lsst/daf/butler/script/_associate.py +++ b/python/lsst/daf/butler/script/_associate.py @@ -64,20 +64,18 @@ def associate( unlimited. A negative value is used to specify a cap where a warning is issued if that cap is hit. """ - butler = Butler.from_config(repo, writeable=True, without_datastore=True) + with Butler.from_config(repo, writeable=True, without_datastore=True) as butler: + butler.collections.register(collection, CollectionType.TAGGED) - butler.collections.register(collection, CollectionType.TAGGED) + results = QueryDatasets( + butler=butler, + glob=dataset_type, + collections=collections, + where=where, + find_first=find_first, + limit=limit, + order_by=(), + show_uri=False, + ) - results = QueryDatasets( - butler=butler, - glob=dataset_type, - collections=collections, - where=where, - find_first=find_first, - limit=limit, - order_by=(), - show_uri=False, - repo=None, - ) - - butler.registry.associate(collection, itertools.chain(*results.getDatasets())) + butler.registry.associate(collection, itertools.chain(*results.getDatasets())) diff --git a/python/lsst/daf/butler/script/_pruneDatasets.py b/python/lsst/daf/butler/script/_pruneDatasets.py index d2c535d86b..04841c2fd8 100644 --- a/python/lsst/daf/butler/script/_pruneDatasets.py +++ b/python/lsst/daf/butler/script/_pruneDatasets.py @@ -222,28 +222,29 @@ def pruneDatasets( return PruneDatasetsResult(state=PruneDatasetsResult.State.ERR_NO_COLLECTION_RESTRICTION) # If purging, verify that the collection to purge is RUN type collection. 
- if purge_run: - butler = Butler.from_config(repo, without_datastore=True) - collection_info = butler.collections.get_info(purge_run) - if collection_info.type is not CollectionType.RUN: - return PruneDatasetsResult( - state=PruneDatasetsResult.State.ERR_PRUNE_ON_NOT_RUN, errDict=dict(collection=purge_run) - ) + with Butler.from_config(repo, without_datastore=True) as butler: + if purge_run: + collection_info = butler.collections.get_info(purge_run) + if collection_info.type is not CollectionType.RUN: + return PruneDatasetsResult( + state=PruneDatasetsResult.State.ERR_PRUNE_ON_NOT_RUN, errDict=dict(collection=purge_run) + ) + + datasets_found = QueryDatasets( + butler=butler, + glob=datasets, + collections=collections, + where=where, + # By default we want find_first to be True if collections are + # provided + # (else False) (find_first requires collections to be provided). + # But the user may specify that they want to find all (thus forcing + # find_first to be False) + find_first=not find_all, + show_uri=False, + ) - datasets_found = QueryDatasets( - repo=repo, - glob=datasets, - collections=collections, - where=where, - # By default we want find_first to be True if collections are provided - # (else False) (find_first requires collections to be provided). - # But the user may specify that they want to find all (thus forcing - # find_first to be False) - find_first=not find_all, - show_uri=False, - ) - - result = PruneDatasetsResult(list(datasets_found.getTables())) + result = PruneDatasetsResult(list(datasets_found.getTables())) disassociate = bool(disassociate_tags) or bool(purge_run) purge = bool(purge_run) @@ -255,16 +256,16 @@ def pruneDatasets( return result def doPruneDatasets() -> PruneDatasetsResult: - butler = Butler.from_config(repo, writeable=True) - butler.pruneDatasets( - refs=list(itertools.chain(*datasets_found.getDatasets())), - disassociate=disassociate, - tags=disassociate_tags or (), - purge=purge, - unstore=unstore, - ) - result.state = PruneDatasetsResult.State.FINISHED - return result + with Butler.from_config(repo, writeable=True) as butler: + butler.pruneDatasets( + refs=list(itertools.chain(*datasets_found.getDatasets())), + disassociate=disassociate, + tags=disassociate_tags or (), + purge=purge, + unstore=unstore, + ) + result.state = PruneDatasetsResult.State.FINISHED + return result if confirm: result.state = PruneDatasetsResult.State.AWAITING_CONFIRMATION diff --git a/python/lsst/daf/butler/script/butlerImport.py b/python/lsst/daf/butler/script/butlerImport.py index ccf36a627c..cb1d7c6034 100644 --- a/python/lsst/daf/butler/script/butlerImport.py +++ b/python/lsst/daf/butler/script/butlerImport.py @@ -64,16 +64,15 @@ def butlerImport( be tracked by the datastore. Whether this parameter is honored depends on the specific datastore implementation. 
""" - butler = Butler.from_config(repo, writeable=True) + with Butler.from_config(repo, writeable=True) as butler: + if skip_dimensions is not None: + skip_dimensions = set(skip_dimensions) - if skip_dimensions is not None: - skip_dimensions = set(skip_dimensions) - - butler.import_( - directory=directory, - filename=export_file, - transfer=transfer, - format="yaml", - skip_dimensions=skip_dimensions, - record_validation_info=track_file_attrs, - ) + butler.import_( + directory=directory, + filename=export_file, + transfer=transfer, + format="yaml", + skip_dimensions=skip_dimensions, + record_validation_info=track_file_attrs, + ) diff --git a/python/lsst/daf/butler/script/certifyCalibrations.py b/python/lsst/daf/butler/script/certifyCalibrations.py index 0285bbb5b4..0cc60abda0 100644 --- a/python/lsst/daf/butler/script/certifyCalibrations.py +++ b/python/lsst/daf/butler/script/certifyCalibrations.py @@ -69,23 +69,23 @@ def certifyCalibrations( Search all children of the inputCollection if it is a CHAINED collection, instead of just the most recent one. """ - butler = Butler.from_config(repo, writeable=True, without_datastore=True) - timespan = Timespan( - begin=astropy.time.Time(begin_date, scale="tai") if begin_date is not None else None, - end=astropy.time.Time(end_date, scale="tai") if end_date is not None else None, - ) - if not search_all_inputs: - collection_info = butler.collections.get_info(input_collection) - if collection_info.type is CollectionType.CHAINED: - input_collection = collection_info.children[0] + with Butler.from_config(repo, writeable=True, without_datastore=True) as butler: + timespan = Timespan( + begin=astropy.time.Time(begin_date, scale="tai") if begin_date is not None else None, + end=astropy.time.Time(end_date, scale="tai") if end_date is not None else None, + ) + if not search_all_inputs: + collection_info = butler.collections.get_info(input_collection) + if collection_info.type is CollectionType.CHAINED: + input_collection = collection_info.children[0] - with butler.query() as query: - results = query.datasets(dataset_type_name, collections=input_collection) - refs = set(results) - if not refs: - explanation = "\n".join(results.explain_no_results()) - raise RuntimeError( - f"No inputs found for dataset {dataset_type_name} in {input_collection}. {explanation}" - ) - butler.collections.register(output_collection, type=CollectionType.CALIBRATION) - butler.registry.certify(output_collection, refs, timespan) + with butler.query() as query: + results = query.datasets(dataset_type_name, collections=input_collection) + refs = set(results) + if not refs: + explanation = "\n".join(results.explain_no_results()) + raise RuntimeError( + f"No inputs found for dataset {dataset_type_name} in {input_collection}. {explanation}" + ) + butler.collections.register(output_collection, type=CollectionType.CALIBRATION) + butler.registry.certify(output_collection, refs, timespan) diff --git a/python/lsst/daf/butler/script/collectionChain.py b/python/lsst/daf/butler/script/collectionChain.py index 3f6f946efd..1e49cbb5dd 100644 --- a/python/lsst/daf/butler/script/collectionChain.py +++ b/python/lsst/daf/butler/script/collectionChain.py @@ -72,34 +72,33 @@ def collectionChain( chain : `tuple` of `str` The collections in the chain following this command. """ - butler = Butler.from_config(repo, writeable=True, without_datastore=True) - - # Every mode needs children except pop. 
- if not children and mode != "pop": - raise RuntimeError(f"Must provide children when defining a collection chain in mode {mode}.") - - try: - butler.collections.get_info(parent) - except MissingCollectionError: - # Create it -- but only if mode can work with empty chain. - if mode in ("redefine", "extend", "prepend"): - if not doc: - doc = None - butler.collections.register(parent, CollectionType.CHAINED, doc) - else: - raise RuntimeError( - f"Mode '{mode}' requires that the collection exists " - f"but collection '{parent}' is not known to this registry" - ) from None - - if flatten: - if mode not in ("redefine", "prepend", "extend"): - raise RuntimeError(f"'flatten' flag is not allowed for {mode}") - children = butler.collections.query(children, flatten_chains=True) - - _modify_collection_chain(butler, mode, parent, children) - - return butler.collections.get_info(parent).children + with Butler.from_config(repo, writeable=True, without_datastore=True) as butler: + # Every mode needs children except pop. + if not children and mode != "pop": + raise RuntimeError(f"Must provide children when defining a collection chain in mode {mode}.") + + try: + butler.collections.get_info(parent) + except MissingCollectionError: + # Create it -- but only if mode can work with empty chain. + if mode in ("redefine", "extend", "prepend"): + if not doc: + doc = None + butler.collections.register(parent, CollectionType.CHAINED, doc) + else: + raise RuntimeError( + f"Mode '{mode}' requires that the collection exists " + f"but collection '{parent}' is not known to this registry" + ) from None + + if flatten: + if mode not in ("redefine", "prepend", "extend"): + raise RuntimeError(f"'flatten' flag is not allowed for {mode}") + children = butler.collections.query(children, flatten_chains=True) + + _modify_collection_chain(butler, mode, parent, children) + + return butler.collections.get_info(parent).children def _modify_collection_chain(butler: Butler, mode: str, parent: str, children: Iterable[str]) -> None: diff --git a/python/lsst/daf/butler/script/configValidate.py b/python/lsst/daf/butler/script/configValidate.py index ff50aa8ba6..6715ca65ff 100644 --- a/python/lsst/daf/butler/script/configValidate.py +++ b/python/lsst/daf/butler/script/configValidate.py @@ -52,12 +52,14 @@ def configValidate(repo: str, quiet: bool, dataset_type: list[str], ignore: list error. """ logFailures = not quiet - butler = Butler.from_config(config=repo) - is_good = True - try: - butler.validateConfiguration(logFailures=logFailures, datasetTypeNames=dataset_type, ignore=ignore) - except ValidationError: - is_good = False - else: - print("No problems encountered with configuration.") - return is_good + with Butler.from_config(config=repo) as butler: + is_good = True + try: + butler.validateConfiguration( + logFailures=logFailures, datasetTypeNames=dataset_type, ignore=ignore + ) + except ValidationError: + is_good = False + else: + print("No problems encountered with configuration.") + return is_good diff --git a/python/lsst/daf/butler/script/exportCalibs.py b/python/lsst/daf/butler/script/exportCalibs.py index 07be56c19b..bb2355ef2b 100644 --- a/python/lsst/daf/butler/script/exportCalibs.py +++ b/python/lsst/daf/butler/script/exportCalibs.py @@ -120,65 +120,64 @@ def exportCalibs( RuntimeError Raised if the output directory already exists. """ - butler = Butler.from_config(repo, writeable=False) - - dataset_type_query = dataset_type or ... 
- collections_query = collections or "*" - - calibTypes = [ - datasetType - for datasetType in butler.registry.queryDatasetTypes(dataset_type_query) - if datasetType.isCalibration() - ] - - collectionsToExport = [] - datasetsToExport = [] - - for collection in butler.collections.query_info( - collections_query, - flatten_chains=True, - include_chains=True, - include_doc=True, - collection_types={CollectionType.CALIBRATION, CollectionType.CHAINED}, - ): - log.info("Checking collection: %s", collection.name) - - # Get collection information. - collectionsToExport.append(collection.name) - if collection.type == CollectionType.CALIBRATION: - exportDatasets = find_calibration_datasets(butler, collection, calibTypes) - datasetsToExport.extend(exportDatasets) - - if os.path.exists(directory): - raise RuntimeError(f"Export directory exists: {directory}") - os.makedirs(directory) - with butler.export(directory=directory, format="yaml", transfer=transfer) as export: - collectionsToExport = list(set(collectionsToExport)) - datasetsToExport = list(set(datasetsToExport)) - - for exportable in collectionsToExport: - try: - export.saveCollection(exportable) - except Exception as e: - log.warning("Did not save collection %s due to %s.", exportable, e) - - log.info("Saving %d dataset(s)", len(datasetsToExport)) - export.saveDatasets(datasetsToExport) - - sortedDatasets = sorted(datasetsToExport, key=attrgetter("datasetType.name", "dataId")) - - requiredDimensions: set[str] = set() - for ref in sortedDatasets: - requiredDimensions.update(ref.dimensions.names) - dimensionColumns = { - dimensionName: [ref.dataId.get(dimensionName, "") for ref in sortedDatasets] - for dimensionName in sorted(requiredDimensions) - } - - return Table( - { - "calibrationType": [ref.datasetType.name for ref in sortedDatasets], - "run": [ref.run for ref in sortedDatasets], - **dimensionColumns, + with Butler.from_config(repo, writeable=False) as butler: + dataset_type_query = dataset_type or ... + collections_query = collections or "*" + + calibTypes = [ + datasetType + for datasetType in butler.registry.queryDatasetTypes(dataset_type_query) + if datasetType.isCalibration() + ] + + collectionsToExport = [] + datasetsToExport = [] + + for collection in butler.collections.query_info( + collections_query, + flatten_chains=True, + include_chains=True, + include_doc=True, + collection_types={CollectionType.CALIBRATION, CollectionType.CHAINED}, + ): + log.info("Checking collection: %s", collection.name) + + # Get collection information. 
+ collectionsToExport.append(collection.name) + if collection.type == CollectionType.CALIBRATION: + exportDatasets = find_calibration_datasets(butler, collection, calibTypes) + datasetsToExport.extend(exportDatasets) + + if os.path.exists(directory): + raise RuntimeError(f"Export directory exists: {directory}") + os.makedirs(directory) + with butler.export(directory=directory, format="yaml", transfer=transfer) as export: + collectionsToExport = list(set(collectionsToExport)) + datasetsToExport = list(set(datasetsToExport)) + + for exportable in collectionsToExport: + try: + export.saveCollection(exportable) + except Exception as e: + log.warning("Did not save collection %s due to %s.", exportable, e) + + log.info("Saving %d dataset(s)", len(datasetsToExport)) + export.saveDatasets(datasetsToExport) + + sortedDatasets = sorted(datasetsToExport, key=attrgetter("datasetType.name", "dataId")) + + requiredDimensions: set[str] = set() + for ref in sortedDatasets: + requiredDimensions.update(ref.dimensions.names) + dimensionColumns = { + dimensionName: [ref.dataId.get(dimensionName, "") for ref in sortedDatasets] + for dimensionName in sorted(requiredDimensions) } - ) + + return Table( + { + "calibrationType": [ref.datasetType.name for ref in sortedDatasets], + "run": [ref.run for ref in sortedDatasets], + **dimensionColumns, + } + ) diff --git a/python/lsst/daf/butler/script/ingest_files.py b/python/lsst/daf/butler/script/ingest_files.py index 31142e5273..21ce98f9b3 100644 --- a/python/lsst/daf/butler/script/ingest_files.py +++ b/python/lsst/daf/butler/script/ingest_files.py @@ -111,24 +111,23 @@ def ingest_files( id_gen_mode = DatasetIdGenEnum.__members__[id_generation_mode] # Create the butler with the relevant run attached. - butler = Butler.from_config(repo, run=run) + with Butler.from_config(repo, run=run) as butler: + datasetType = butler.get_dataset_type(dataset_type) - datasetType = butler.get_dataset_type(dataset_type) + # Convert the k=v strings into a dataId dict. + universe = butler.dimensions + common_data_id = parse_data_id_tuple(data_id, universe) - # Convert the k=v strings into a dataId dict. - universe = butler.dimensions - common_data_id = parse_data_id_tuple(data_id, universe) + # Read the table assuming that Astropy can work out the format. + uri = ResourcePath(table_file, forceAbsolute=False) + with uri.as_local() as local_file: + table = Table.read(local_file.ospath) - # Read the table assuming that Astropy can work out the format. - uri = ResourcePath(table_file, forceAbsolute=False) - with uri.as_local() as local_file: - table = Table.read(local_file.ospath) - - datasets = extract_datasets_from_table( - table, common_data_id, datasetType, run, formatter, prefix, id_gen_mode - ) + datasets = extract_datasets_from_table( + table, common_data_id, datasetType, run, formatter, prefix, id_gen_mode + ) - butler.ingest(*datasets, transfer=transfer, record_validation_info=track_file_attrs) + butler.ingest(*datasets, transfer=transfer, record_validation_info=track_file_attrs) def extract_datasets_from_table( diff --git a/python/lsst/daf/butler/script/ingest_zip.py b/python/lsst/daf/butler/script/ingest_zip.py index 014df1a565..5e702f18cc 100644 --- a/python/lsst/daf/butler/script/ingest_zip.py +++ b/python/lsst/daf/butler/script/ingest_zip.py @@ -59,6 +59,5 @@ def ingest_zip( If `True` no transfers are done but the number of transfers that would be done is reported. 
""" - butler = Butler.from_config(repo, writeable=True) - - butler.ingest_zip(zip, transfer=transfer, transfer_dimensions=transfer_dimensions, dry_run=dry_run) + with Butler.from_config(repo, writeable=True) as butler: + butler.ingest_zip(zip, transfer=transfer, transfer_dimensions=transfer_dimensions, dry_run=dry_run) diff --git a/python/lsst/daf/butler/script/queryCollections.py b/python/lsst/daf/butler/script/queryCollections.py index 496b0628ff..e1dfd7bc83 100644 --- a/python/lsst/daf/butler/script/queryCollections.py +++ b/python/lsst/daf/butler/script/queryCollections.py @@ -99,79 +99,80 @@ def _getTable( ) if show_dataset_types: table.add_column(Column(name="Dataset Types", dtype=str)) - butler = Butler.from_config(repo) - - def addDatasetTypes(collection_table: Table, collection: str, dataset_types: list[str]) -> Table: - if dataset_types[0] == "": - cinfo = butler.collections.get_info(collection, include_summary=True) - dataset_types = _parseDatasetTypes(cinfo.dataset_types) - if exclude_dataset_types: - dataset_types = [ - dt - for dt in dataset_types - if not any(fnmatch(dt, pattern) for pattern in exclude_dataset_types) - ] - dataset_types = _parseDatasetTypes(dataset_types) - types_table = Table({"Dataset Types": sorted(dataset_types)}, dtype=(str,)) - collection_table = hstack([collection_table, types_table]).filled("") - return collection_table - - def addCollection(info: CollectionInfo, relation: str) -> None: - try: - info_relatives = getattr(info, relation) - except AttributeError: - info_relatives = [] - # Parent results can be returned in a non-deterministic order, so sort - # them to make the output deterministic. - if relation == "parents": - info_relatives = sorted(info_relatives) - if info_relatives: - collection_table = Table([[info.name], [info.type.name]], names=("Name", typeCol)) - description_table = Table(names=(descriptionCol,), dtype=(str,)) - for info_relative in info_relatives: - relative_table = Table([[info_relative]], names=(descriptionCol,)) + + with Butler.from_config(repo) as butler: + + def addDatasetTypes(collection_table: Table, collection: str, dataset_types: list[str]) -> Table: + if dataset_types[0] == "": + cinfo = butler.collections.get_info(collection, include_summary=True) + dataset_types = _parseDatasetTypes(cinfo.dataset_types) + if exclude_dataset_types: + dataset_types = [ + dt + for dt in dataset_types + if not any(fnmatch(dt, pattern) for pattern in exclude_dataset_types) + ] + dataset_types = _parseDatasetTypes(dataset_types) + types_table = Table({"Dataset Types": sorted(dataset_types)}, dtype=(str,)) + collection_table = hstack([collection_table, types_table]).filled("") + return collection_table + + def addCollection(info: CollectionInfo, relation: str) -> None: + try: + info_relatives = getattr(info, relation) + except AttributeError: + info_relatives = [] + # Parent results can be returned in a non-deterministic order, so + # sort them to make the output deterministic. 
+ if relation == "parents": + info_relatives = sorted(info_relatives) + if info_relatives: + collection_table = Table([[info.name], [info.type.name]], names=("Name", typeCol)) + description_table = Table(names=(descriptionCol,), dtype=(str,)) + for info_relative in info_relatives: + relative_table = Table([[info_relative]], names=(descriptionCol,)) + if show_dataset_types: + relative_table = addDatasetTypes(relative_table, info_relative, [""]) + description_table = vstack([description_table, relative_table]) + collection_table = hstack([collection_table, description_table]).filled("") + for row in collection_table: + table.add_row(row) + else: + collection_table = Table( + [[info.name], [info.type.name], [""]], names=("Name", typeCol, descriptionCol) + ) if show_dataset_types: - relative_table = addDatasetTypes(relative_table, info_relative, [""]) - description_table = vstack([description_table, relative_table]) - collection_table = hstack([collection_table, description_table]).filled("") - for row in collection_table: - table.add_row(row) - else: - collection_table = Table( - [[info.name], [info.type.name], [""]], names=("Name", typeCol, descriptionCol) + collection_table = addDatasetTypes(collection_table, info.name, [""]) + for row in collection_table: + table.add_row(row) + + collections = sorted( + butler.collections.query_info( + glob or "*", + collection_types=frozenset(collection_type), + include_parents=inverse, + include_summary=show_dataset_types, ) - if show_dataset_types: - collection_table = addDatasetTypes(collection_table, info.name, [""]) - for row in collection_table: - table.add_row(row) - - collections = sorted( - butler.collections.query_info( - glob or "*", - collection_types=frozenset(collection_type), - include_parents=inverse, - include_summary=show_dataset_types, ) - ) - if inverse: - for info in collections: - addCollection(info, "parents") - # If none of the datasets has a parent dataset then remove the - # description column. - if not any(c for c in table[descriptionCol]): - del table[descriptionCol] - else: - for info in collections: - if info.type == CollectionType.CHAINED: - addCollection(info, "children") - else: - addCollection(info, "self") - # If there aren't any CHAINED datasets in the results then remove the - # description column. - if not any(columnVal == CollectionType.CHAINED.name for columnVal in table[typeCol]): - del table[descriptionCol] + if inverse: + for info in collections: + addCollection(info, "parents") + # If none of the datasets has a parent dataset then remove the + # description column. + if not any(c for c in table[descriptionCol]): + del table[descriptionCol] + else: + for info in collections: + if info.type == CollectionType.CHAINED: + addCollection(info, "children") + else: + addCollection(info, "self") + # If there aren't any CHAINED datasets in the results then remove + # the description column. 
+ if not any(columnVal == CollectionType.CHAINED.name for columnVal in table[typeCol]): + del table[descriptionCol] - return table + return table def _getTree( @@ -216,51 +217,52 @@ def _getTree( ) if show_dataset_types: table.add_column(Column(name="Dataset Types", dtype=str)) - butler = Butler.from_config(repo, without_datastore=True) - - def addCollection(info: CollectionInfo, level: int = 0) -> None: - collection_table = Table([[" " * level + info.name], [info.type.name]], names=["Name", "Type"]) - if show_dataset_types: - if info.type == CollectionType.CHAINED: - collection_table = hstack( - [collection_table, Table([[""] * len(collection_table)], names=["Dataset Types"])] - ) + + with Butler.from_config(repo, without_datastore=True) as butler: + + def addCollection(info: CollectionInfo, level: int = 0) -> None: + collection_table = Table([[" " * level + info.name], [info.type.name]], names=["Name", "Type"]) + if show_dataset_types: + if info.type == CollectionType.CHAINED: + collection_table = hstack( + [collection_table, Table([[""] * len(collection_table)], names=["Dataset Types"])] + ) + else: + dataset_types = _parseDatasetTypes(info.dataset_types) + if exclude_dataset_types: + dataset_types = [ + dt + for dt in dataset_types + if not any(fnmatch(dt, pattern) for pattern in exclude_dataset_types) + ] + dataset_types = _parseDatasetTypes(dataset_types) + dataset_types_table = Table({"Dataset Types": sorted(dataset_types)}, dtype=(str,)) + collection_table = hstack([collection_table, dataset_types_table]).filled("") + for row in collection_table: + table.add_row(row) + + if inverse: + assert info.parents is not None # For mypy. + for pname in sorted(info.parents): + pinfo = butler.collections.get_info( + pname, include_parents=inverse, include_summary=show_dataset_types + ) + addCollection(pinfo, level + 1) else: - dataset_types = _parseDatasetTypes(info.dataset_types) - if exclude_dataset_types: - dataset_types = [ - dt - for dt in dataset_types - if not any(fnmatch(dt, pattern) for pattern in exclude_dataset_types) - ] - dataset_types = _parseDatasetTypes(dataset_types) - dataset_types_table = Table({"Dataset Types": sorted(dataset_types)}, dtype=(str,)) - collection_table = hstack([collection_table, dataset_types_table]).filled("") - for row in collection_table: - table.add_row(row) + if info.type == CollectionType.CHAINED: + for name in info.children: + cinfo = butler.collections.get_info(name, include_summary=show_dataset_types) + addCollection(cinfo, level + 1) - if inverse: - assert info.parents is not None # For mypy. 
- for pname in sorted(info.parents): - pinfo = butler.collections.get_info( - pname, include_parents=inverse, include_summary=show_dataset_types - ) - addCollection(pinfo, level + 1) - else: - if info.type == CollectionType.CHAINED: - for name in info.children: - cinfo = butler.collections.get_info(name, include_summary=show_dataset_types) - addCollection(cinfo, level + 1) - - collections = butler.collections.query_info( - glob or "*", - collection_types=frozenset(collection_type), - include_parents=inverse, - include_summary=show_dataset_types, - ) - for collection in sorted(collections): - addCollection(collection) - return table + collections = butler.collections.query_info( + glob or "*", + collection_types=frozenset(collection_type), + include_parents=inverse, + include_summary=show_dataset_types, + ) + for collection in sorted(collections): + addCollection(collection) + return table def _getList( @@ -301,36 +303,37 @@ def _getList( ) if show_dataset_types: table.add_column(Column(name="Dataset Types", dtype=str)) - butler = Butler.from_config(repo) - def addCollection(info: CollectionInfo) -> None: - collection_table = Table([[info.name], [info.type.name]], names=["Name", "Type"]) - if show_dataset_types: - dataset_types = _parseDatasetTypes(info.dataset_types) - if exclude_dataset_types: - dataset_types = [ - dt - for dt in dataset_types - if not any(fnmatch(dt, pattern) for pattern in exclude_dataset_types) - ] - dataset_types = _parseDatasetTypes(dataset_types) - dataset_types_table = Table({"Dataset Types": sorted(dataset_types)}, dtype=(str,)) - collection_table = hstack([collection_table, dataset_types_table]).filled("") - for row in collection_table: - table.add_row(row) + with Butler.from_config(repo) as butler: - collections = list( - butler.collections.query_info( - glob or "*", - collection_types=frozenset(collection_type), - flatten_chains=flatten_chains, - include_summary=show_dataset_types, + def addCollection(info: CollectionInfo) -> None: + collection_table = Table([[info.name], [info.type.name]], names=["Name", "Type"]) + if show_dataset_types: + dataset_types = _parseDatasetTypes(info.dataset_types) + if exclude_dataset_types: + dataset_types = [ + dt + for dt in dataset_types + if not any(fnmatch(dt, pattern) for pattern in exclude_dataset_types) + ] + dataset_types = _parseDatasetTypes(dataset_types) + dataset_types_table = Table({"Dataset Types": sorted(dataset_types)}, dtype=(str,)) + collection_table = hstack([collection_table, dataset_types_table]).filled("") + for row in collection_table: + table.add_row(row) + + collections = list( + butler.collections.query_info( + glob or "*", + collection_types=frozenset(collection_type), + flatten_chains=flatten_chains, + include_summary=show_dataset_types, + ) ) - ) - for collection in collections: - addCollection(collection) + for collection in collections: + addCollection(collection) - return table + return table def queryCollections( diff --git a/python/lsst/daf/butler/script/queryDataIds.py b/python/lsst/daf/butler/script/queryDataIds.py index b6226fb88d..1250dfb5c1 100644 --- a/python/lsst/daf/butler/script/queryDataIds.py +++ b/python/lsst/daf/butler/script/queryDataIds.py @@ -141,84 +141,87 @@ def queryDataIds( if offset: raise NotImplementedError("--offset is no longer supported. 
It will be removed after v28.") - butler = Butler.from_config(repo, without_datastore=True) - - dataset_types = [] - if datasets: - dataset_types = list(butler.registry.queryDatasetTypes(datasets)) - - if datasets and collections and not dimensions: - # Determine the dimensions relevant to all given dataset types. - # Since we are going to AND together all dimensions, we can not - # seed the result with an empty set. - dataset_type_dimensions: DimensionGroup | None = None - for dataset_type in dataset_types: - if dataset_type_dimensions is None: - # Seed with dimensions of first dataset type. - dataset_type_dimensions = dataset_type.dimensions - else: - # Only retain dimensions that are in the current - # set AND the set from this dataset type. - dataset_type_dimensions = dataset_type_dimensions.intersection(dataset_type.dimensions) - _LOG.debug("Dimensions now %s from %s", set(dataset_type_dimensions.names), dataset_type.name) + with Butler.from_config(repo, without_datastore=True) as butler: + dataset_types = [] + if datasets: + dataset_types = list(butler.registry.queryDatasetTypes(datasets)) + + if datasets and collections and not dimensions: + # Determine the dimensions relevant to all given dataset types. + # Since we are going to AND together all dimensions, we can not + # seed the result with an empty set. + dataset_type_dimensions: DimensionGroup | None = None + for dataset_type in dataset_types: + if dataset_type_dimensions is None: + # Seed with dimensions of first dataset type. + dataset_type_dimensions = dataset_type.dimensions + else: + # Only retain dimensions that are in the current + # set AND the set from this dataset type. + dataset_type_dimensions = dataset_type_dimensions.intersection(dataset_type.dimensions) + _LOG.debug("Dimensions now %s from %s", set(dataset_type_dimensions.names), dataset_type.name) + + # Break out of the loop early. No additional dimensions + # can be added to an empty set when using AND. + if not dataset_type_dimensions: + break - # Break out of the loop early. No additional dimensions - # can be added to an empty set when using AND. if not dataset_type_dimensions: - break - - if not dataset_type_dimensions: - names = [d.name for d in dataset_types] - return None, f"No dimensions in common for specified dataset types ({names})" - dimensions = set(dataset_type_dimensions.names) - _LOG.info("Determined dimensions %s from datasets option %s", dimensions, datasets) - - with butler.query() as query: - if datasets: - # Need to constrain results based on dataset type and collection. - query_collections = collections or "*" - collections_info = butler.collections.query_info( - query_collections, include_summary=True, summary_datasets=dataset_types - ) - expanded_collections = [info.name for info in collections_info] - dataset_type_collections = butler.collections._group_by_dataset_type( - {dt.name for dt in dataset_types}, collections_info - ) - if not dataset_type_collections: - return ( - None, - f"No datasets of type {datasets!r} existed in the specified " - f"collections {','.join(expanded_collections)}.", + names = [d.name for d in dataset_types] + return None, f"No dimensions in common for specified dataset types ({names})" + dimensions = set(dataset_type_dimensions.names) + _LOG.info("Determined dimensions %s from datasets option %s", dimensions, datasets) + + with butler.query() as query: + if datasets: + # Need to constrain results based on dataset type and + # collection. 
+ query_collections = collections or "*" + collections_info = butler.collections.query_info( + query_collections, include_summary=True, summary_datasets=dataset_types ) - - for dt, dt_collections in dataset_type_collections.items(): - query = query.join_dataset_search(dt, collections=dt_collections) - - results = query.data_ids(dimensions) - - if where: - results = results.where(where) - if order_by: - results = results.order_by(*order_by) - query_limit = abs(limit) - warn_limit = False - if limit != 0: - if limit < 0: - query_limit += 1 - warn_limit = True - - results = results.limit(query_limit) - - if results.any(exact=False): - if results.dimensions: - table = _Table(results) - if warn_limit and len(table) == query_limit: - table.pop_last() - _LOG.warning("More data IDs are available than the request limit of %d", abs(limit)) - if not table.dataIds: - return None, "Post-query region filtering removed all rows, since nothing overlapped." - return table.getAstropyTable(not order_by), None + expanded_collections = [info.name for info in collections_info] + dataset_type_collections = butler.collections._group_by_dataset_type( + {dt.name for dt in dataset_types}, collections_info + ) + if not dataset_type_collections: + return ( + None, + f"No datasets of type {datasets!r} existed in the specified " + f"collections {','.join(expanded_collections)}.", + ) + + for dt, dt_collections in dataset_type_collections.items(): + query = query.join_dataset_search(dt, collections=dt_collections) + + results = query.data_ids(dimensions) + + if where: + results = results.where(where) + if order_by: + results = results.order_by(*order_by) + query_limit = abs(limit) + warn_limit = False + if limit != 0: + if limit < 0: + query_limit += 1 + warn_limit = True + + results = results.limit(query_limit) + + if results.any(exact=False): + if results.dimensions: + table = _Table(results) + if warn_limit and len(table) == query_limit: + table.pop_last() + _LOG.warning("More data IDs are available than the request limit of %d", abs(limit)) + if not table.dataIds: + return None, "Post-query region filtering removed all rows, since nothing overlapped." + return table.getAstropyTable(not order_by), None + else: + return ( + None, + "Result has one logical row but no columns because no dimensions were requested.", + ) else: - return None, "Result has one logical row but no columns because no dimensions were requested." - else: - return None, "\n".join(results.explain_no_results()) + return None, "\n".join(results.explain_no_results()) diff --git a/python/lsst/daf/butler/script/queryDatasetTypes.py b/python/lsst/daf/butler/script/queryDatasetTypes.py index a4ecd0682c..26553614a8 100644 --- a/python/lsst/daf/butler/script/queryDatasetTypes.py +++ b/python/lsst/daf/butler/script/queryDatasetTypes.py @@ -60,26 +60,29 @@ def queryDatasetTypes( A dict whose key is "datasetTypes" and whose value is a list of collection names. """ - butler = Butler.from_config(repo, without_datastore=True) - expression = glob or ... - datasetTypes = butler.registry.queryDatasetTypes(expression=expression) + with Butler.from_config(repo, without_datastore=True) as butler: + expression = glob or ... 
+ datasetTypes = butler.registry.queryDatasetTypes(expression=expression) - if collections: - collections_info = butler.collections.query_info(collections, include_summary=True) - filtered_dataset_types = set( - butler.collections._filter_dataset_types([d.name for d in datasetTypes], collections_info) - ) - datasetTypes = [d for d in datasetTypes if d.name in filtered_dataset_types] + if collections: + collections_info = butler.collections.query_info(collections, include_summary=True) + filtered_dataset_types = set( + butler.collections._filter_dataset_types([d.name for d in datasetTypes], collections_info) + ) + datasetTypes = [d for d in datasetTypes if d.name in filtered_dataset_types] - if verbose: - table = Table( - array( - [(d.name, str(list(d.dimensions.names)) or "None", d.storageClass_name) for d in datasetTypes] - ), - names=("name", "dimensions", "storage class"), - ) - else: - rows = ([d.name for d in datasetTypes],) - table = Table(rows, names=("name",)) - table.sort("name") - return table + if verbose: + table = Table( + array( + [ + (d.name, str(list(d.dimensions.names)) or "None", d.storageClass_name) + for d in datasetTypes + ] + ), + names=("name", "dimensions", "storage class"), + ) + else: + rows = ([d.name for d in datasetTypes],) + table = Table(rows, names=("name",)) + table.sort("name") + return table diff --git a/python/lsst/daf/butler/script/queryDatasets.py b/python/lsst/daf/butler/script/queryDatasets.py index 0cd0a2bcff..dc22073827 100644 --- a/python/lsst/daf/butler/script/queryDatasets.py +++ b/python/lsst/daf/butler/script/queryDatasets.py @@ -158,6 +158,8 @@ class QueryDatasets: wildcards. show_uri : `bool` If True, include the dataset URI in the output. + butler : `lsst.daf.butler.Butler` + The butler to use to query. limit : `int`, optional Limit the number of results to be returned. A value of 0 means unlimited. A negative value is used to specify a cap where a warning @@ -167,13 +169,6 @@ class QueryDatasets: results of ``limit`` are undefined and default sorting of the resulting datasets will be applied. It is an error if the requested ordering is inconsistent with the dimensions of the dataset type being queried. - repo : `str` or `None` - URI to the location of the repo or URI to a config file describing the - repo and its location. One of `repo` and `butler` must be `None` and - the other must not be `None`. - butler : `lsst.daf.butler.Butler` or `None` - The butler to use to query. One of `repo` and `butler` must be `None` - and the other must not be `None`. with_dimension_records : `bool`, optional If `True` (default is `False`) then returned data IDs will have dimension records. @@ -186,14 +181,11 @@ def __init__( where: str, find_first: bool, show_uri: bool, + butler: Butler, limit: int = 0, order_by: tuple[str, ...] = (), - repo: str | None = None, - butler: Butler | None = None, with_dimension_records: bool = False, ): - if (repo and butler) or (not repo and not butler): - raise RuntimeError("One of repo and butler must be provided and the other must be None.") collections = list(collections) if not collections: warnings.warn( @@ -215,9 +207,7 @@ def __init__( if order_by and searches_multiple_dataset_types: raise NotImplementedError("--order-by is only supported for queries with a single dataset type.") - # show_uri requires a datastore. 
- without_datastore = not show_uri - self.butler = butler or Butler.from_config(repo, without_datastore=without_datastore) + self.butler = butler self.showUri = show_uri self._dataset_type_glob = glob self._collections_wildcard = collections diff --git a/python/lsst/daf/butler/script/queryDimensionRecords.py b/python/lsst/daf/butler/script/queryDimensionRecords.py index b2575a9a0f..e7487f320c 100644 --- a/python/lsst/daf/butler/script/queryDimensionRecords.py +++ b/python/lsst/daf/butler/script/queryDimensionRecords.py @@ -81,65 +81,66 @@ def queryDimensionRecords( if offset: raise NotImplementedError("--offset is no longer supported. It will be removed after v28.") - butler = Butler.from_config(repo, without_datastore=True) - - with butler.query() as query: - if datasets: - query_collections = collections or "*" - dataset_types = butler.registry.queryDatasetTypes(datasets) - collections_info = butler.collections.query_info( - query_collections, include_summary=True, summary_datasets=dataset_types - ) - dataset_type_collections = butler.collections._group_by_dataset_type( - {dt.name for dt in dataset_types}, collections_info - ) - - if not dataset_type_collections: + with Butler.from_config(repo, without_datastore=True) as butler: + with butler.query() as query: + if datasets: + query_collections = collections or "*" + dataset_types = butler.registry.queryDatasetTypes(datasets) + collections_info = butler.collections.query_info( + query_collections, include_summary=True, summary_datasets=dataset_types + ) + dataset_type_collections = butler.collections._group_by_dataset_type( + {dt.name for dt in dataset_types}, collections_info + ) + + if not dataset_type_collections: + return None + + for dt, dt_collections in dataset_type_collections.items(): + query = query.join_dataset_search(dt, collections=dt_collections) + + query_results = query.dimension_records(element) + + if where: + query_results = query_results.where(where) + if order_by: + query_results = query_results.order_by(*order_by) + query_limit = abs(limit) + warn_limit = False + if limit != 0: + if limit < 0: + query_limit += 1 + warn_limit = True + + query_results = query_results.limit(query_limit) + + records = list(query_results) + if warn_limit and len(records) == query_limit: + records.pop(-1) + _LOG.warning("More data IDs are available than the request limit of %d", abs(limit)) + + if not records: return None - for dt, dt_collections in dataset_type_collections.items(): - query = query.join_dataset_search(dt, collections=dt_collections) - - query_results = query.dimension_records(element) - - if where: - query_results = query_results.where(where) - if order_by: - query_results = query_results.order_by(*order_by) - query_limit = abs(limit) - warn_limit = False - if limit != 0: - if limit < 0: - query_limit += 1 - warn_limit = True - - query_results = query_results.limit(query_limit) - - records = list(query_results) - if warn_limit and len(records) == query_limit: - records.pop(-1) - _LOG.warning("More data IDs are available than the request limit of %d", abs(limit)) - - if not records: - return None - - if not order_by: - # use the dataId to sort the rows if not ordered already - records.sort(key=attrgetter("dataId")) - - # order the columns the same as the record's `field.names`, and add units - # to timespans - keys = records[0].fields.names - headers = ["timespan (TAI)" if name == "timespan" else name for name in records[0].fields.names] - - def conform(v: Any) -> Any: - match v: - case Timespan(): - v = str(v) - case 
bytes(): - v = "0x" + v.hex() - case Region(): - v = "(elided)" - return v - - return Table([[conform(getattr(record, key, None)) for record in records] for key in keys], names=headers) + if not order_by: + # use the dataId to sort the rows if not ordered already + records.sort(key=attrgetter("dataId")) + + # order the columns the same as the record's `field.names`, and add + # units to timespans + keys = records[0].fields.names + headers = ["timespan (TAI)" if name == "timespan" else name for name in records[0].fields.names] + + def conform(v: Any) -> Any: + match v: + case Timespan(): + v = str(v) + case bytes(): + v = "0x" + v.hex() + case Region(): + v = "(elided)" + return v + + return Table( + [[conform(getattr(record, key, None)) for record in records] for key in keys], names=headers + ) diff --git a/python/lsst/daf/butler/script/register_dataset_type.py b/python/lsst/daf/butler/script/register_dataset_type.py index 41ab12d4fb..3f1ac32b5c 100644 --- a/python/lsst/daf/butler/script/register_dataset_type.py +++ b/python/lsst/daf/butler/script/register_dataset_type.py @@ -69,19 +69,20 @@ def register_dataset_type( be created by this command. They are always derived from the composite dataset type. """ - butler = Butler.from_config(repo, writeable=True, without_datastore=True) + with Butler.from_config(repo, writeable=True, without_datastore=True) as butler: + _, component = DatasetType.splitDatasetTypeName(dataset_type) + if component: + raise ValueError( + "Component dataset types are created automatically when the composite is created." + ) - _, component = DatasetType.splitDatasetTypeName(dataset_type) - if component: - raise ValueError("Component dataset types are created automatically when the composite is created.") + datasetType = DatasetType( + dataset_type, + butler.dimensions.conform(dimensions), + storage_class, + parentStorageClass=None, + isCalibration=is_calibration, + universe=butler.dimensions, + ) - datasetType = DatasetType( - dataset_type, - butler.dimensions.conform(dimensions), - storage_class, - parentStorageClass=None, - isCalibration=is_calibration, - universe=butler.dimensions, - ) - - return butler.registry.registerDatasetType(datasetType) + return butler.registry.registerDatasetType(datasetType) diff --git a/python/lsst/daf/butler/script/removeCollections.py b/python/lsst/daf/butler/script/removeCollections.py index c6ea6049c0..8b06392d90 100644 --- a/python/lsst/daf/butler/script/removeCollections.py +++ b/python/lsst/daf/butler/script/removeCollections.py @@ -90,26 +90,32 @@ def _getCollectionInfo(repo: str, collection: str, include_parents: bool) -> Col collectionInfo : `CollectionInfo` Contains tables with run and non-run collection info. """ - butler = Butler.from_config(repo, without_datastore=True) - try: - collections_info = sorted( - butler.collections.query_info(collection, include_chains=True, include_parents=include_parents) - ) - except MissingCollectionError: - # Hide the error and act like no collections should be removed. 
- collections_info = [] - collections = Table(names=("Collection", "Collection Type"), dtype=(str, str)) - runCollections = Table(names=("Collection",), dtype=(str,)) - parents: dict[str, tuple[str, ...]] = {} - for collection_info in collections_info: - if collection_info.type == CollectionType.RUN: - runCollections.add_row((collection_info.name,)) - else: - collections.add_row((collection_info.name, collection_info.type.name)) - if include_parents and collection_info.parents is not None and len(collection_info.parents) > 0: - parents[collection_info.name] = tuple(collection_info.parents) - - return CollectionInfo(collections, runCollections, parents) + with Butler.from_config(repo, without_datastore=True) as butler: + try: + collections_info = sorted( + butler.collections.query_info( + collection, include_chains=True, include_parents=include_parents + ) + ) + except MissingCollectionError: + # Hide the error and act like no collections should be removed. + collections_info = [] + collections = Table(names=("Collection", "Collection Type"), dtype=(str, str)) + runCollections = Table(names=("Collection",), dtype=(str,)) + parents: dict[str, tuple[str, ...]] = {} + for collection_info in collections_info: + if collection_info.type == CollectionType.RUN: + runCollections.add_row((collection_info.name,)) + else: + collections.add_row((collection_info.name, collection_info.type.name)) + if ( + include_parents + and collection_info.parents is not None + and len(collection_info.parents) > 0 + ): + parents[collection_info.name] = tuple(collection_info.parents) + + return CollectionInfo(collections, runCollections, parents) def removeCollections(repo: str, collection: str, remove_from_parents: bool) -> RemoveCollectionResult: @@ -136,20 +142,19 @@ def removeCollections(repo: str, collection: str, remove_from_parents: bool) -> def _doRemove(collections: Table) -> None: """Perform the prune collection step.""" - butler = Butler.from_config(repo, writeable=True, without_datastore=True) - - for name in collections["Collection"]: - with butler.transaction(): - for parent in collectionInfo.parentCollections.get(name, []): - butler.collections.remove_from_chain(parent, name) - try: - butler.collections.x_remove(name) - except OrphanedRecordError as e: - e.add_note( - "Add the --remove-from-parents flag to this command" - " if you are sure this collection is no longer needed." - ) - raise + with Butler.from_config(repo, writeable=True, without_datastore=True) as butler: + for name in collections["Collection"]: + with butler.transaction(): + for parent in collectionInfo.parentCollections.get(name, []): + butler.collections.remove_from_chain(parent, name) + try: + butler.collections.x_remove(name) + except OrphanedRecordError as e: + e.add_note( + "Add the --remove-from-parents flag to this command" + " if you are sure this collection is no longer needed." + ) + raise remove_chains_table = Table(names=("Child Collection", "Parent Collection"), dtype=(str, str)) for child in sorted(collectionInfo.parentCollections.keys()): diff --git a/python/lsst/daf/butler/script/removeDatasetType.py b/python/lsst/daf/butler/script/removeDatasetType.py index 3e4bf6d8b6..906a9018d7 100644 --- a/python/lsst/daf/butler/script/removeDatasetType.py +++ b/python/lsst/daf/butler/script/removeDatasetType.py @@ -43,5 +43,5 @@ def removeDatasetType(repo: str, dataset_type_name: tuple[str, ...]) -> None: dataset_type_name : `str` The name of the dataset type to be removed. 
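removeCollections' _doRemove keeps its per-collection transaction, now nested inside the Butler context: unlinking from parent chains and the removal itself either both apply or both roll back, and the connection is closed in every case. A hedged sketch of that nesting, with hypothetical repository path and collection names:

from lsst.daf.butler import Butler


def remove_collection(repo: str, name: str, parents: list[str]) -> None:
    with Butler.from_config(repo, writeable=True, without_datastore=True) as butler:
        # The transaction rolls back the chain edits if the removal fails;
        # the surrounding with statement still closes the connection.
        with butler.transaction():
            for parent in parents:
                butler.collections.remove_from_chain(parent, name)
            butler.collections.x_remove(name)


remove_collection("/path/to/repo", "u/someone/old-run", parents=["u/someone/chain"])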
""" - butler = Butler.from_config(repo, writeable=True, without_datastore=True) - butler.registry.removeDatasetType(dataset_type_name) + with Butler.from_config(repo, writeable=True, without_datastore=True) as butler: + butler.registry.removeDatasetType(dataset_type_name) diff --git a/python/lsst/daf/butler/script/removeRuns.py b/python/lsst/daf/butler/script/removeRuns.py index 680a399f9c..cdc2e40ea1 100644 --- a/python/lsst/daf/butler/script/removeRuns.py +++ b/python/lsst/daf/butler/script/removeRuns.py @@ -85,8 +85,7 @@ def _getCollectionInfo( datasets : `dict` [`str`, `int`] The dataset types and and how many will be removed. """ - butler = Butler.from_config(repo) - with butler.registry.caching_context(): + with Butler.from_config(repo) as butler, butler.registry.caching_context(): try: collections = butler.collections.query_info( collection, @@ -138,9 +137,8 @@ def removeRuns( def _doRemove(runs: Sequence[RemoveRun]) -> None: """Perform the remove step.""" - butler = Butler.from_config(repo, writeable=True) - - butler.removeRuns([r.name for r in runs], unlink_from_chains=True) + with Butler.from_config(repo, writeable=True) as butler: + butler.removeRuns([r.name for r in runs], unlink_from_chains=True) result = RemoveRunsResult( onConfirmation=partial(_doRemove, runs), diff --git a/python/lsst/daf/butler/script/retrieveArtifacts.py b/python/lsst/daf/butler/script/retrieveArtifacts.py index 5a85eae573..67ceaada0d 100644 --- a/python/lsst/daf/butler/script/retrieveArtifacts.py +++ b/python/lsst/daf/butler/script/retrieveArtifacts.py @@ -102,32 +102,35 @@ def retrieveArtifacts( query_types = dataset_type or "*" query_collections: tuple[str, ...] = collections or ("*",) - butler = Butler.from_config(repo, writeable=False) - - # Need to store in set so we can count the number to give some feedback - # to caller. - query = QueryDatasets( - butler=butler, - glob=query_types, - collections=query_collections, - where=where, - find_first=find_first, - limit=limit, - order_by=order_by, - show_uri=False, - with_dimension_records=True, - ) - refs = set(itertools.chain(*query.getDatasets())) - log.info("Number of datasets matching query: %d", len(refs)) - if not refs: - return [] - - if not zip: - transferred = butler.retrieveArtifacts( - refs, destination=destination, transfer=transfer, preserve_path=preserve_path, overwrite=clobber + with Butler.from_config(repo, writeable=False) as butler: + # Need to store in set so we can count the number to give some feedback + # to caller. 
+ query = QueryDatasets( + butler=butler, + glob=query_types, + collections=query_collections, + where=where, + find_first=find_first, + limit=limit, + order_by=order_by, + show_uri=False, + with_dimension_records=True, ) - else: - zip_file = butler.retrieve_artifacts_zip(refs, destination=destination, overwrite=clobber) - transferred = [zip_file] + refs = set(itertools.chain(*query.getDatasets())) + log.info("Number of datasets matching query: %d", len(refs)) + if not refs: + return [] + + if not zip: + transferred = butler.retrieveArtifacts( + refs, + destination=destination, + transfer=transfer, + preserve_path=preserve_path, + overwrite=clobber, + ) + else: + zip_file = butler.retrieve_artifacts_zip(refs, destination=destination, overwrite=clobber) + transferred = [zip_file] - return transferred + return transferred diff --git a/python/lsst/daf/butler/script/transferDatasets.py b/python/lsst/daf/butler/script/transferDatasets.py index d52c2c42ad..ed0573781f 100644 --- a/python/lsst/daf/butler/script/transferDatasets.py +++ b/python/lsst/daf/butler/script/transferDatasets.py @@ -89,33 +89,34 @@ def transferDatasets( If `True` no transfers are done but the number of transfers that would be done is reported. """ - source_butler = Butler.from_config(source, writeable=False) - dest_butler = Butler.from_config(dest, writeable=True) + with ( + Butler.from_config(source, writeable=False) as source_butler, + Butler.from_config(dest, writeable=True) as dest_butler, + ): + dataset_type_expr = dataset_type or "*" + collections_expr: tuple[str, ...] = collections or ("*",) - dataset_type_expr = dataset_type or "*" - collections_expr: tuple[str, ...] = collections or ("*",) + query = QueryDatasets( + butler=source_butler, + glob=dataset_type_expr, + collections=collections_expr, + where=where, + find_first=find_first, + limit=limit, + order_by=order_by, + show_uri=False, + with_dimension_records=True, + ) + # Place results in a set to remove duplicates (which should not exist + # in new query system) + source_refs_set = set(itertools.chain(*query.getDatasets())) - query = QueryDatasets( - butler=source_butler, - glob=dataset_type_expr, - collections=collections_expr, - where=where, - find_first=find_first, - limit=limit, - order_by=order_by, - show_uri=False, - with_dimension_records=True, - ) - # Place results in a set to remove duplicates (which should not exist - # in new query system) - source_refs_set = set(itertools.chain(*query.getDatasets())) - - transferred = dest_butler.transfer_from( - source_butler, - source_refs_set, - transfer=transfer, - register_dataset_types=register_dataset_types, - transfer_dimensions=transfer_dimensions, - dry_run=dry_run, - ) - return len(transferred) + transferred = dest_butler.transfer_from( + source_butler, + source_refs_set, + transfer=transfer, + register_dataset_types=register_dataset_types, + transfer_dimensions=transfer_dimensions, + dry_run=dry_run, + ) + return len(transferred) diff --git a/python/lsst/daf/butler/tests/hybrid_butler.py b/python/lsst/daf/butler/tests/hybrid_butler.py index df1d26ef2c..2a8dcca5d4 100644 --- a/python/lsst/daf/butler/tests/hybrid_butler.py +++ b/python/lsst/daf/butler/tests/hybrid_butler.py @@ -441,3 +441,7 @@ def _query_all_datasets_by_page( @property def _file_transfer_source(self) -> FileTransferSource: return self._remote_butler._file_transfer_source + + def close(self) -> None: + self._direct_butler.close() + self._remote_butler.close() diff --git a/python/lsst/daf/butler/tests/server.py 
b/python/lsst/daf/butler/tests/server.py index 884d585607..622235a11b 100644 --- a/python/lsst/daf/butler/tests/server.py +++ b/python/lsst/daf/butler/tests/server.py @@ -1,7 +1,7 @@ import json import os from collections.abc import Iterator -from contextlib import contextmanager +from contextlib import closing, contextmanager from dataclasses import dataclass from tempfile import TemporaryDirectory @@ -151,14 +151,13 @@ def create_test_server( app.dependency_overrides[user_name_dependency] = lambda: "mock-username" app.dependency_overrides[auth_delegated_token_dependency] = lambda: "mock-delegated-token" + direct_butler = Butler.from_config(config_file_path, writeable=True) + assert isinstance(direct_butler, DirectButler) # Using TestClient in a context manager ensures that it uses # the same async event loop for all requests -- otherwise it # starts a new one on each request. - with TestClient(app) as client: + with TestClient(app) as client, direct_butler, closing(server_butler_factory): remote_butler = _make_remote_butler(client) - - direct_butler = Butler.from_config(config_file_path, writeable=True) - assert isinstance(direct_butler, DirectButler) hybrid_butler = HybridButler(remote_butler, direct_butler) client_without_error_propagation = TestClient(app, raise_server_exceptions=False) diff --git a/tests/test_astropyTableFormatter.py b/tests/test_astropyTableFormatter.py index c93669a651..195555fa93 100644 --- a/tests/test_astropyTableFormatter.py +++ b/tests/test_astropyTableFormatter.py @@ -56,6 +56,7 @@ def tearDown(self): def testAstropyTableFormatter(self): butler = Butler(self.root, run="testrun") + self.enterContext(butler) datasetType = DatasetType("table", [], "AstropyTable", universe=butler.dimensions) butler.registry.registerDatasetType(datasetType) ref = butler.put(self.table, datasetType) diff --git a/tests/test_butler.py b/tests/test_butler.py index 0a16eafa7d..6186b9d740 100644 --- a/tests/test_butler.py +++ b/tests/test_butler.py @@ -43,6 +43,8 @@ import unittest import unittest.mock import uuid +import warnings +import weakref from collections.abc import Callable, Mapping from typing import TYPE_CHECKING, Any, cast @@ -225,12 +227,18 @@ def tearDown(self) -> None: removeTestTempDir(self.root) def create_empty_butler( - self, run: str | None = None, writeable: bool | None = None, metrics: ButlerMetrics | None = None + self, + run: str | None = None, + writeable: bool | None = None, + metrics: ButlerMetrics | None = None, + cleanup: bool = True, ): """Create a Butler for the test repository, without inserting test data. """ butler = Butler.from_config(self.tmpConfigFile, run=run, writeable=writeable, metrics=metrics) + if cleanup: + self.enterContext(butler) assert isinstance(butler, DirectButler), "Expect DirectButler in configuration" return butler @@ -653,16 +661,19 @@ def are_uris_equivalent(self, uri1: ResourcePath, uri2: ResourcePath) -> bool: def testConstructor(self) -> None: """Independent test of constructor.""" butler = Butler.from_config(self.tmpConfigFile, run=self.default_run) + self.enterContext(butler) self.assertIsInstance(butler, Butler) # Check that butler.yaml is added automatically. if self.tmpConfigFile.endswith(end := "/butler.yaml"): config_dir = self.tmpConfigFile[: -len(end)] butler = Butler.from_config(config_dir, run=self.default_run) + self.enterContext(butler) self.assertIsInstance(butler, Butler) # Even with a ResourcePath. 
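create_test_server now stacks the TestClient, the DirectButler, and closing(server_butler_factory) in one with statement. contextlib.closing is the stdlib adapter for objects that expose close() but are not context managers themselves; a minimal, self-contained illustration with a hypothetical resource class:

from contextlib import closing


class Resource:
    """Stand-in for an object that has close() but no __enter__/__exit__."""

    def close(self) -> None:
        print("resource closed")


with closing(Resource()) as resource:
    # closing() hands back the wrapped object and calls resource.close()
    # on exit, even if the block raises.
    print("using", resource)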
butler = Butler.from_config(ResourcePath(config_dir, forceDirectory=True), run=self.default_run) + self.enterContext(butler) self.assertIsInstance(butler, Butler) collections = set(butler.collections.query("*")) @@ -671,10 +682,12 @@ def testConstructor(self) -> None: # Check that some special characters can be included in run name. special_run = "u@b.c-A" butler_special = Butler.from_config(butler=butler, run=special_run) + self.enterContext(butler_special) collections = set(butler_special.registry.queryCollections("*@*")) self.assertEqual(collections, {special_run}) butler2 = Butler.from_config(butler=butler, collections=["other"]) + self.enterContext(butler2) self.assertEqual(butler2.collections.defaults, ("other",)) self.assertIsNone(butler2.run) self.assertEqual(type(butler._datastore), type(butler2._datastore)) @@ -698,8 +711,10 @@ def testConstructor(self) -> None: uri = Butler.get_repo_uri("label") butler = Butler.from_config(uri, writeable=False) self.assertIsInstance(butler, Butler) + butler.close() butler = Butler.from_config("label", writeable=False) self.assertIsInstance(butler, Butler) + butler.close() with self.assertRaisesRegex(FileNotFoundError, "aliases:.*bad_label"): Butler.from_config("not_there", writeable=False) with self.assertRaisesRegex(FileNotFoundError, "resolved from alias 'bad_label'"): @@ -739,6 +754,7 @@ def testConstructor(self) -> None: # Check that we can create Butler when the alias file is not found. butler = Butler.from_config(self.tmpConfigFile, writeable=False) + self.enterContext(butler) self.assertIsInstance(butler, Butler) with self.assertRaises(RuntimeError) as cm: # No environment variable set. @@ -750,6 +766,56 @@ def testConstructor(self) -> None: Butler.from_config("not_there") self.assertEqual(Butler.get_known_repos(), set()) + def testClose(self): + butler = self.create_empty_butler(cleanup=False) + is_direct_butler = isinstance(butler, DirectButler) + if is_direct_butler: + self.assertFalse(butler._closed) + + with butler as butler_from_context_manager: + self.assertIs(butler, butler_from_context_manager) + if is_direct_butler: + self.assertTrue(butler._closed) + with self.assertRaisesRegex(RuntimeError, "has been closed"): + butler.get_dataset_type("raw") + + # Close may be called multiple times. + butler.close() + if is_direct_butler: + self.assertTrue(butler._closed) + + def testGarbageCollection(self): + """Test that Butler does not have any circular references that prevent + it from being garbage collected immediately when it goes out of scope. + """ + butler = self.create_empty_butler(cleanup=False) + is_direct_butler = isinstance(butler, DirectButler) + butler_ref = weakref.ref(butler) + if is_direct_butler: + registry_ref = weakref.ref(butler._registry) + managers_ref = weakref.ref(butler._registry._managers) + datastore_ref = weakref.ref(butler._datastore) + db_ref = weakref.ref(butler._registry._db) + engine_ref = weakref.ref(butler._registry._db._engine) + + with warnings.catch_warnings(): + # Hide warnings from unclosed database handles. 
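The test updates follow one rule: every Butler created inside a test is passed to TestCase.enterContext() (Python 3.11+), which enters the context immediately and registers its exit, and therefore Butler.close(), as a cleanup that runs when the test finishes, pass or fail. A small sketch with a hypothetical config path:

import unittest

from lsst.daf.butler import Butler


class ExampleButlerTest(unittest.TestCase):
    def test_constructor(self) -> None:
        butler = Butler.from_config("/path/to/butler.yaml", writeable=False)
        # enterContext ties the Butler's lifetime to this test: close() runs
        # automatically during cleanup, even if an assertion fails first.
        self.enterContext(butler)
        self.assertIsInstance(butler, Butler)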
+ warnings.simplefilter("ignore", ResourceWarning) + del butler + self.assertIsNone(butler_ref(), "Butler should have been garbage collected") + if is_direct_butler: + self.assertIsNone(registry_ref(), "SqlRegistry should have been garbage collected") + self.assertIsNone(managers_ref(), "Registry managers should have been garbage collected") + self.assertIsNone(datastore_ref(), "Datastore should have been garbage collected") + self.assertIsNone(db_ref(), "Database should have been garbage collected") + # SQLAlchemy has internal reference cycles, so the Engine instance + # is not cleaned up promptly even if we release our reference to + # it. Explicitly clean it up here to avoid file handles leaking. + if is_direct_butler: + engine = engine_ref() + if engine is not None: + engine.dispose() + def testDafButlerRepositories(self): with unittest.mock.patch.dict( os.environ, @@ -973,6 +1039,7 @@ def test_ingest_zip(self) -> None: # Create an entirely new local file butler in this temp directory. new_butler_cfg = Butler.makeRepo(tmpdir) new_butler = Butler.from_config(new_butler_cfg, writeable=True) + self.enterContext(new_butler) # This will fail since dimensions records are missing. with self.assertRaises(ConflictingDefinitionError): @@ -1185,6 +1252,7 @@ def testPickle(self) -> None: butler = self.create_empty_butler(run=self.default_run) assert isinstance(butler, DirectButler), "Expect DirectButler in configuration" butlerOut = pickle.loads(pickle.dumps(butler)) + self.enterContext(butlerOut) self.assertIsInstance(butlerOut, Butler) self.assertEqual(butlerOut._config, butler._config) self.assertEqual(list(butlerOut.collections.defaults), list(butler.collections.defaults)) @@ -1352,10 +1420,12 @@ def testMakeRepo(self) -> None: butlerConfig = Butler.makeRepo(root1, config=Config(self.configFile)) limited = Config(self.configFile) butler1 = Butler.from_config(butlerConfig) + self.enterContext(butler1) assert isinstance(butler1, DirectButler), "Expect DirectButler in configuration" butlerConfig = Butler.makeRepo(root2, standalone=True, config=Config(self.configFile)) full = Config(self.tmpConfigFile) butler2 = Butler.from_config(butlerConfig) + self.enterContext(butler2) assert isinstance(butler2, DirectButler), "Expect DirectButler in configuration" # Butlers should have the same configuration regardless of whether # defaults were expanded. 
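testGarbageCollection checks for reference cycles by keeping only weak references and asserting they are dead as soon as the last strong reference is deleted; on CPython, an object with no cycles is freed immediately by reference counting, without waiting for the cyclic collector. A simplified, stand-alone version of the same technique:

import unittest
import weakref


class Widget:
    """Hypothetical object that holds no references back to itself."""


class PromptCollectionTest(unittest.TestCase):
    def test_collected_promptly(self) -> None:
        widget = Widget()
        ref = weakref.ref(widget)
        del widget
        # With no reference cycles, CPython frees the object as soon as the
        # last strong reference goes away, so the weak reference is cleared.
        self.assertIsNone(ref())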
@@ -1383,6 +1453,7 @@ def testMakeRepo(self) -> None: def testStringification(self) -> None: butler = Butler.from_config(self.tmpConfigFile, run=self.default_run) + self.enterContext(butler) butlerStr = str(butler) if self.datastoreStr is not None: @@ -1794,6 +1865,7 @@ def runImportExportTest(self, storageClass: StorageClass) -> None: skip_dimensions=None, ) importButler = Butler.from_config(importDir, run=self.default_run) + self.enterContext(importButler) for ref in datasets: with self.subTest(ref=repr(ref)): # Test for existence by passing in the DatasetType and @@ -2117,11 +2189,13 @@ class PosixDatastoreButlerTestCase(FileDatastoreButlerTests, unittest.TestCase): def testPathConstructor(self) -> None: """Independent test of constructor using PathLike.""" butler = Butler.from_config(self.tmpConfigFile, run=self.default_run) + self.enterContext(butler) self.assertIsInstance(butler, Butler) # And again with a Path object with the butler yaml path = pathlib.Path(self.tmpConfigFile) butler = Butler.from_config(path, writeable=False) + self.enterContext(butler) self.assertIsInstance(butler, Butler) # And again with a Path object without the butler yaml @@ -2130,6 +2204,7 @@ def testPathConstructor(self) -> None: if self.tmpConfigFile.endswith("butler.yaml"): path = pathlib.Path(os.path.dirname(self.tmpConfigFile)) butler = Butler.from_config(path, writeable=False) + self.enterContext(butler) self.assertIsInstance(butler, Butler) def testExportTransferCopy(self) -> None: @@ -2301,6 +2376,7 @@ def test_specialized_file_datasets_functions(self): # Make sure the target Butler can ingest the datasets. target_butler = Butler(target_repo_config, writeable=True) + self.enterContext(target_butler) target_butler.transfer_dimension_records_from(source_butler, refs) target_butler.ingest(*datasets, transfer=None) self.assertIsNotNone(target_butler.get(repo.ref1)) @@ -2358,6 +2434,7 @@ def test_specialized_file_datasets_functions(self): # Make sure the target Butler can ingest the datasets. target_butler = Butler(target_repo_config, writeable=True) + self.enterContext(target_butler) target_butler.transfer_dimension_records_from(source_butler, refs) target_butler.ingest(*datasets, transfer=None) self.assertIsNotNone(target_butler.get(repo.ref1)) @@ -2693,9 +2770,11 @@ def create_butler(self, manager: str | None, label: str, config_file: str | None ) config = Config(config_file if config_file is not None else self.configFile) config["registry", "managers", "datasets"] = manager - return Butler.from_config( + butler = Butler.from_config( Butler.makeRepo(f"{self.root}/butler{label}", config=config), writeable=True ) + self.enterContext(butler) + return butler def assertButlerTransfers( self, @@ -2768,6 +2847,7 @@ def assertButlerTransfers( # Will not be relevant for UUID. run = "distraction" butler = Butler.from_config(butler=self.source_butler, run=run) + self.enterContext(butler) butler.put( makeExampleMetrics(), datasetTypeName, @@ -2778,6 +2858,7 @@ def assertButlerTransfers( # Write some example metrics to the source butler = Butler.from_config(butler=self.source_butler) + self.enterContext(butler) # Set of DatasetRefs that should be in the list of refs to transfer # but which will not be transferred. @@ -3209,6 +3290,7 @@ def test_fallback(self) -> None: Butler.from_config(self.root, without_datastore=False) butler = Butler.from_config(self.root, writeable=True, without_datastore=True) + self.enterContext(butler) self.assertIsInstance(butler._datastore, NullDatastore) # Check that registry is working. 
@@ -3253,7 +3335,11 @@ def are_uris_equivalent(self, uri1: ResourcePath, uri2: ResourcePath) -> bool: return uri1.scheme == uri2.scheme and uri1.netloc == uri2.netloc and uri1.path == uri2.path def create_empty_butler( - self, run: str | None = None, writeable: bool | None = None, metrics: ButlerMetrics | None = None + self, + run: str | None = None, + writeable: bool | None = None, + metrics: ButlerMetrics | None = None, + cleanup: bool = True, ) -> Butler: return self.server_instance.hybrid_butler.clone(run=run, metrics=metrics) diff --git a/tests/test_butler_factory.py b/tests/test_butler_factory.py index 279643f05f..c199448f4d 100644 --- a/tests/test_butler_factory.py +++ b/tests/test_butler_factory.py @@ -44,7 +44,8 @@ class ButlerFactoryTestCase(unittest.TestCase): @classmethod def setUpClass(cls): repo_dir = cls.enterClassContext(tempfile.TemporaryDirectory()) - makeTestRepo(repo_dir) + butler = makeTestRepo(repo_dir) + butler.close() cls.config_file_uri = f"file://{repo_dir}" def test_factory_via_global_repository_index(self): @@ -53,10 +54,12 @@ def test_factory_via_global_repository_index(self): index_file.flush() with mock_env({"DAF_BUTLER_REPOSITORY_INDEX": index_file.name}): factory = LabeledButlerFactory() + self.addCleanup(factory.close) self._test_factory(factory) def test_factory_via_custom_index(self): factory = LabeledButlerFactory({"test_repo": self.config_file_uri}) + self.addCleanup(factory.close) self._test_factory(factory) def _test_factory(self, factory: LabeledButlerFactory) -> None: diff --git a/tests/test_cliCmdIngestFiles.py b/tests/test_cliCmdIngestFiles.py index bfd4672b17..0e85a7d171 100644 --- a/tests/test_cliCmdIngestFiles.py +++ b/tests/test_cliCmdIngestFiles.py @@ -52,6 +52,7 @@ def setUp(self): self.addCleanup(removeTestTempDir, self.root) self.testRepo = MetricTestRepo(self.root, configFile=self.configFile) + self.enterContext(self.testRepo.butler) self.root2 = makeTestTempDir(TESTDIR) self.addCleanup(removeTestTempDir, self.root2) @@ -105,6 +106,7 @@ def assertIngest(self, table, options): self.assertEqual(result.exit_code, 0, clickResultMsg(result)) butler = Butler.from_config(self.root) + self.enterContext(butler) refs = list(butler.registry.queryDatasets("test_metric_comp", collections=run)) self.assertEqual(len(refs), 2) diff --git a/tests/test_cliCmdPruneDatasets.py b/tests/test_cliCmdPruneDatasets.py index 269e3f1724..24b8416e40 100644 --- a/tests/test_cliCmdPruneDatasets.py +++ b/tests/test_cliCmdPruneDatasets.py @@ -29,7 +29,7 @@ import unittest from itertools import chain -from unittest.mock import patch +from unittest.mock import ANY, patch from astropy.table import Table @@ -99,9 +99,9 @@ def setUp(self): self.repo = "here" @staticmethod - def makeQueryDatasetsArgs(*, repo, **kwargs): + def makeQueryDatasetsArgs(**kwargs): expectedArgs = dict( - repo=repo, collections=("*",), where="", find_first=True, show_uri=False, glob=tuple() + butler=ANY, collections=("*",), where="", find_first=True, show_uri=False, glob=tuple() ) expectedArgs.update(kwargs) return expectedArgs @@ -233,7 +233,7 @@ def test_defaults_doContinue(self): self.run_test( cliArgs=["myCollection", "--unstore"], exPruneDatasetsCallArgs=self.makePruneDatasetsArgs(refs=getRefs(), unstore=True), - exQueryDatasetsCallArgs=self.makeQueryDatasetsArgs(repo=self.repo, collections=("myCollection",)), + exQueryDatasetsCallArgs=self.makeQueryDatasetsArgs(collections=("myCollection",)), exGetTablesCalled=True, exMsgs=( pruneDatasets_willRemoveMsg, @@ -253,7 +253,7 @@ def 
test_defaults_doNotContinue(self): self.run_test( cliArgs=["myCollection", "--unstore"], exPruneDatasetsCallArgs=None, - exQueryDatasetsCallArgs=self.makeQueryDatasetsArgs(repo=self.repo, collections=("myCollection",)), + exQueryDatasetsCallArgs=self.makeQueryDatasetsArgs(collections=("myCollection",)), exGetTablesCalled=True, exMsgs=( pruneDatasets_willRemoveMsg, @@ -272,7 +272,7 @@ def test_dryRun_unstore(self): self.run_test( cliArgs=["myCollection", "--dry-run", "--unstore"], exPruneDatasetsCallArgs=None, - exQueryDatasetsCallArgs=self.makeQueryDatasetsArgs(repo=self.repo, collections=("myCollection",)), + exQueryDatasetsCallArgs=self.makeQueryDatasetsArgs(collections=("myCollection",)), exGetTablesCalled=True, exMsgs=(pruneDatasets_wouldRemoveMsg, astropyTablesToStr(getTables())), ) @@ -287,7 +287,7 @@ def test_dryRun_disassociate(self): self.run_test( cliArgs=[collection, "--dry-run", "--disassociate", "tag1"], exPruneDatasetsCallArgs=None, - exQueryDatasetsCallArgs=self.makeQueryDatasetsArgs(repo=self.repo, collections=(collection,)), + exQueryDatasetsCallArgs=self.makeQueryDatasetsArgs(collections=(collection,)), exGetTablesCalled=True, exMsgs=( pruneDatasets_wouldDisassociateMsg.format(collections=(collection,)), @@ -305,7 +305,7 @@ def test_dryRun_unstoreAndDisassociate(self): self.run_test( cliArgs=[collection, "--dry-run", "--unstore", "--disassociate", "tag1"], exPruneDatasetsCallArgs=None, - exQueryDatasetsCallArgs=self.makeQueryDatasetsArgs(repo=self.repo, collections=(collection,)), + exQueryDatasetsCallArgs=self.makeQueryDatasetsArgs(collections=(collection,)), exGetTablesCalled=True, exMsgs=( pruneDatasets_wouldDisassociateAndRemoveMsg.format(collections=(collection,)), @@ -323,7 +323,7 @@ def test_noConfirm(self): self.run_test( cliArgs=["myCollection", "--no-confirm", "--unstore"], exPruneDatasetsCallArgs=self.makePruneDatasetsArgs(refs=getRefs(), unstore=True), - exQueryDatasetsCallArgs=self.makeQueryDatasetsArgs(repo=self.repo, collections=("myCollection",)), + exQueryDatasetsCallArgs=self.makeQueryDatasetsArgs(collections=("myCollection",)), exGetTablesCalled=True, exMsgs=(pruneDatasets_didRemoveMsg, astropyTablesToStr(getTables())), ) @@ -337,7 +337,7 @@ def test_quiet(self): self.run_test( cliArgs=["myCollection", "--quiet", "--unstore"], exPruneDatasetsCallArgs=self.makePruneDatasetsArgs(refs=getRefs(), unstore=True), - exQueryDatasetsCallArgs=self.makeQueryDatasetsArgs(repo=self.repo, collections=("myCollection",)), + exQueryDatasetsCallArgs=self.makeQueryDatasetsArgs(collections=("myCollection",)), exGetTablesCalled=True, exMsgs=None, ) @@ -373,9 +373,7 @@ def test_noDatasets(self): self.run_test( cliArgs=["myCollection", "--unstore"], exPruneDatasetsCallArgs=None, - exQueryDatasetsCallArgs=self.makeQueryDatasetsArgs( - repo=self.repo, collections=("myCollection",) - ), + exQueryDatasetsCallArgs=self.makeQueryDatasetsArgs(collections=("myCollection",)), exGetTablesCalled=True, exMsgs=(pruneDatasets_noDatasetsFound,), ) @@ -428,9 +426,7 @@ def test_purgeImpliedArgs(self, mockGetCollectionType): exPruneDatasetsCallArgs=self.makePruneDatasetsArgs( purge=True, refs=getRefs(), disassociate=True, unstore=True ), - exQueryDatasetsCallArgs=self.makeQueryDatasetsArgs( - repo=self.repo, collections=("run",), find_first=True - ), + exQueryDatasetsCallArgs=self.makeQueryDatasetsArgs(collections=("run",), find_first=True), exGetTablesCalled=True, exMsgs=( pruneDatasets_willRemoveMsg, @@ -454,7 +450,7 @@ def test_purgeImpliedArgsWithCollections(self, mockGetCollectionType): 
purge=True, disassociate=True, unstore=True, refs=getRefs() ), exQueryDatasetsCallArgs=self.makeQueryDatasetsArgs( - repo=self.repo, collections=("myCollection",), find_first=True + collections=("myCollection",), find_first=True ), exGetTablesCalled=True, exMsgs=( @@ -502,9 +498,7 @@ def test_disassociateImpliedArgs(self): exPruneDatasetsCallArgs=self.makePruneDatasetsArgs( tags=("tag1", "tag2"), disassociate=True, refs=getRefs() ), - exQueryDatasetsCallArgs=self.makeQueryDatasetsArgs( - repo=self.repo, collections=("tag1", "tag2"), find_first=True - ), + exQueryDatasetsCallArgs=self.makeQueryDatasetsArgs(collections=("tag1", "tag2"), find_first=True), exGetTablesCalled=True, exMsgs=(pruneDatasets_didRemoveMsg, astropyTablesToStr(getTables())), ) @@ -519,7 +513,7 @@ def test_disassociateImpliedArgsWithCollections(self): tags=("tag1", "tag2"), disassociate=True, refs=getRefs() ), exQueryDatasetsCallArgs=self.makeQueryDatasetsArgs( - repo=self.repo, collections=("myCollection",), find_first=True + collections=("myCollection",), find_first=True ), exGetTablesCalled=True, exMsgs=(pruneDatasets_didRemoveMsg, astropyTablesToStr(getTables())), diff --git a/tests/test_cliCmdQueryCollections.py b/tests/test_cliCmdQueryCollections.py index db794c6ce7..644c211376 100644 --- a/tests/test_cliCmdQueryCollections.py +++ b/tests/test_cliCmdQueryCollections.py @@ -154,6 +154,7 @@ def testGetCollections(self): butlerCfg = Butler.makeRepo("here") # the purpose of this call is to create some collections butler = Butler.from_config(butlerCfg, run=run, collections=[tag], writeable=True) + self.enterContext(butler) butler.registry.registerCollection(tag, CollectionType.TAGGED) # Verify collections that were created are found by @@ -196,6 +197,7 @@ def testChained(self): butlerCfg = Butler.makeRepo("here") butler1 = Butler.from_config(butlerCfg, writeable=True) + self.enterContext(butler1) # Replace datastore functions with mocks: DatastoreMock.apply(butler1) diff --git a/tests/test_cliCmdQueryDataIds.py b/tests/test_cliCmdQueryDataIds.py index fd4f300f99..5b94635b75 100644 --- a/tests/test_cliCmdQueryDataIds.py +++ b/tests/test_cliCmdQueryDataIds.py @@ -71,6 +71,7 @@ def loadData(self, *filenames: str) -> Butler: which should be a YAML import/export file. 
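Because the prune-datasets command now builds its own Butler, the CLI tests above cannot name the exact instance handed to QueryDatasets, so the expected call uses unittest.mock.ANY, which compares equal to any value while the remaining arguments are still matched exactly. A minimal illustration:

from unittest.mock import ANY, MagicMock

query_datasets = MagicMock()
query_datasets(butler=object(), collections=("myCollection",), find_first=True)

# ANY stands in for the argument whose identity the test cannot predict.
query_datasets.assert_called_once_with(butler=ANY, collections=("myCollection",), find_first=True)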
""" butler = Butler.from_config(self.repo, writeable=True) + self.enterContext(butler) assert isinstance(butler, DirectButler), "Test expects DirectButler" for filename in filenames: butler.import_( diff --git a/tests/test_cliCmdQueryDatasets.py b/tests/test_cliCmdQueryDatasets.py index 48b0f7ef70..920e5688d2 100644 --- a/tests/test_cliCmdQueryDatasets.py +++ b/tests/test_cliCmdQueryDatasets.py @@ -173,6 +173,7 @@ def testChained(self): testRepo = MetricTestRepo( self.repoDir, configFile=os.path.join(TESTDIR, "config/basic/butler-chained.yaml") ) + self.enterContext(testRepo.butler) tables = self._queryDatasets(repo=testRepo.butler, show_uri=True, collections="*", glob="*") @@ -191,6 +192,7 @@ def testChained(self): def testShowURI(self): """Test for expected output with show_uri=True.""" testRepo = MetricTestRepo(self.repoDir, configFile=self.configFile) + self.enterContext(testRepo.butler) tables = self._queryDatasets(repo=testRepo.butler, show_uri=True, collections="*", glob="*") @@ -208,6 +210,7 @@ def testShowUriNoDisassembly(self): configFile=self.configFile, storageClassName="StructuredCompositeReadCompNoDisassembly", ) + self.enterContext(testRepo.butler) tables = self._queryDatasets(repo=testRepo.butler, show_uri=True, collections="*", glob="*") @@ -251,6 +254,7 @@ def testShowUriNoDisassembly(self): def testNoShowURI(self): """Test for expected output without show_uri (default is False).""" testRepo = MetricTestRepo(self.repoDir, configFile=self.configFile) + self.enterContext(testRepo.butler) tables = self._queryDatasets(repo=testRepo.butler, collections="*", glob="*") @@ -273,6 +277,7 @@ def testWhere(self): queryDatasets. """ testRepo = MetricTestRepo(self.repoDir, configFile=self.configFile) + self.enterContext(testRepo.butler) for glob in (("*",), ("test_metric_comp",)): with self.subTest(glob=glob): @@ -299,6 +304,7 @@ def testGlobDatasetType(self): """Test specifying dataset type.""" # Create and register an additional DatasetType testRepo = MetricTestRepo(self.repoDir, configFile=self.configFile) + self.enterContext(testRepo.butler) testRepo.butler.registry.insertDimensionData( "visit", @@ -346,6 +352,7 @@ def test_limit_order(self): """Test limit and ordering.""" # Create and register an additional DatasetType testRepo = MetricTestRepo(self.repoDir, configFile=self.configFile) + self.enterContext(testRepo.butler) with self.assertLogs("lsst.daf.butler.script.queryDatasets", level="WARNING") as cm: tables = self._queryDatasets( @@ -417,6 +424,7 @@ def testFindFirstAndCollections(self): is required for find-first. """ testRepo = MetricTestRepo(self.repoDir, configFile=self.configFile) + self.enterContext(testRepo.butler) # Add a new run, and add a dataset to shadow an existing dataset. testRepo.addDataset(run="foo", dataId={"instrument": "DummyCamComp", "visit": 424}) diff --git a/tests/test_cliCmdQueryDimensionRecords.py b/tests/test_cliCmdQueryDimensionRecords.py index e90250e693..bac2510b33 100644 --- a/tests/test_cliCmdQueryDimensionRecords.py +++ b/tests/test_cliCmdQueryDimensionRecords.py @@ -77,6 +77,7 @@ def setUp(self): self.testRepo = MetricTestRepo( self.root, configFile=os.path.join(TESTDIR, "config/basic/butler.yaml") ) + self.enterContext(self.testRepo.butler) self.runner = LogCliRunner() def tearDown(self): @@ -162,6 +163,7 @@ def testWhere(self): def testCollection(self): butler = Butler.from_config(self.root, run="foo") + self.enterContext(butler) # try replacing the testRepo's butler with the one with the "foo" run. 
self.testRepo.butler = butler @@ -270,6 +272,7 @@ def testCollection(self): def testSkymap(self): butler = Butler.from_config(self.root, run="foo") + self.enterContext(butler) # try replacing the testRepo's butler with the one with the "foo" run. self.testRepo.butler = butler diff --git a/tests/test_cliCmdRemoveCollections.py b/tests/test_cliCmdRemoveCollections.py index cf5dfecd1a..4998ee5636 100644 --- a/tests/test_cliCmdRemoveCollections.py +++ b/tests/test_cliCmdRemoveCollections.py @@ -71,6 +71,7 @@ def setUp(self): self.testRepo = MetricTestRepo( self.root, configFile=os.path.join(TESTDIR, "config/basic/butler.yaml") ) + self.enterContext(self.testRepo.butler) def tearDown(self): removeTestTempDir(self.root) @@ -219,6 +220,7 @@ def testRemoveCmd(self): # verify chained-run-1 was removed: butler = Butler.from_config(self.root) + self.enterContext(butler) collections = butler.registry.queryCollections( collectionTypes=frozenset( ( @@ -275,6 +277,7 @@ def testRemoveCmd(self): def testRemoveFromParents(self) -> None: butler = Butler(self.root, writeable=True) + self.enterContext(butler) butler.collections.register("tag1", CollectionType.TAGGED) butler.collections.register("tag2", CollectionType.TAGGED) butler.collections.register("chain1", CollectionType.CHAINED) diff --git a/tests/test_cliCmdRemoveRuns.py b/tests/test_cliCmdRemoveRuns.py index 0c0b82390e..63cc7c3fcb 100644 --- a/tests/test_cliCmdRemoveRuns.py +++ b/tests/test_cliCmdRemoveRuns.py @@ -59,6 +59,7 @@ def test_removeRuns(self): with self.runner.isolated_filesystem(): root = "repo" repo = MetricTestRepo(root, configFile=os.path.join(TESTDIR, "config/basic/butler.yaml")) + self.enterContext(repo.butler) # Add a dataset type that will have no datasets to make sure it # isn't printed. repo.butler.registry.registerDatasetType( @@ -120,7 +121,8 @@ def test_removeRuns(self): # Remake the repo and check --no-confirm option. root = "repo1" - MetricTestRepo(root, configFile=os.path.join(TESTDIR, "config/basic/butler.yaml")) + repo1 = MetricTestRepo(root, configFile=os.path.join(TESTDIR, "config/basic/butler.yaml")) + repo1.butler.close() # Add the run to a CHAINED collection. parentCollection = "parent" diff --git a/tests/test_cliCmdRetrieveArtifacts.py b/tests/test_cliCmdRetrieveArtifacts.py index 16d1542bc4..30c256dedf 100644 --- a/tests/test_cliCmdRetrieveArtifacts.py +++ b/tests/test_cliCmdRetrieveArtifacts.py @@ -48,6 +48,7 @@ class CliRetrieveArtifactsTest(unittest.TestCase, ButlerTestHelper): def setUp(self): self.root = makeTestTempDir(TESTDIR) self.testRepo = MetricTestRepo(self.root, configFile=self.configFile) + self.enterContext(self.testRepo.butler) def tearDown(self): removeTestTempDir(self.root) diff --git a/tests/test_dimension_record_containers.py b/tests/test_dimension_record_containers.py index 85168015bd..157316d61f 100644 --- a/tests/test_dimension_record_containers.py +++ b/tests/test_dimension_record_containers.py @@ -60,6 +60,7 @@ def setUpClass(cls): # Create an in-memory SQLite database and Registry just to import the # YAML data. 
cls.butler = create_populated_sqlite_registry(*DIMENSION_DATA_FILES) + cls.enterClassContext(cls.butler) cls.records = { element: tuple(list(cls.butler.registry.queryDimensionRecords(element))) for element in ("visit", "skymap", "patch") diff --git a/tests/test_dimensions.py b/tests/test_dimensions.py index c77d33f019..95a210136b 100644 --- a/tests/test_dimensions.py +++ b/tests/test_dimensions.py @@ -73,9 +73,9 @@ def loadDimensionData() -> DataCoordinateSequence: """ # Create an in-memory SQLite database and Registry just to import the YAML # data and retrieve it as a set of DataCoordinate objects. - butler = create_populated_sqlite_registry(DIMENSION_DATA_FILE) - dimensions = butler.registry.dimensions.conform(["visit", "detector", "tract", "patch"]) - return butler.registry.queryDataIds(dimensions).expanded().toSequence() + with create_populated_sqlite_registry(DIMENSION_DATA_FILE) as butler: + dimensions = butler.registry.dimensions.conform(["visit", "detector", "tract", "patch"]) + return butler.registry.queryDataIds(dimensions).expanded().toSequence() class ConcreteTestDimensionPacker(DimensionPacker): diff --git a/tests/test_logFormatter.py b/tests/test_logFormatter.py index ae3fd36d88..2d87e2ca8c 100644 --- a/tests/test_logFormatter.py +++ b/tests/test_logFormatter.py @@ -56,6 +56,7 @@ def setUp(self): self.run = "testrun" self.butler = Butler.from_config(self.root, run=self.run) + self.enterContext(self.butler) self.datasetType = DatasetType("test_logs", [], "ButlerLogRecords", universe=self.butler.dimensions) self.butler.registry.registerDatasetType(self.datasetType) diff --git a/tests/test_matplotlibFormatter.py b/tests/test_matplotlibFormatter.py index d54683117d..ca60ee6a10 100644 --- a/tests/test_matplotlibFormatter.py +++ b/tests/test_matplotlibFormatter.py @@ -65,6 +65,7 @@ def tearDown(self): def testMatplotlibFormatter(self): butler = Butler.from_config(self.root, run="testrun") + self.enterContext(butler) datasetType = DatasetType("test_plot", [], "Plot", universe=butler.dimensions) butler.registry.registerDatasetType(datasetType) # Does not have to be a random image diff --git a/tests/test_obscore.py b/tests/test_obscore.py index 6c2ce39877..7553720e20 100644 --- a/tests/test_obscore.py +++ b/tests/test_obscore.py @@ -70,6 +70,7 @@ def make_registry( """Create new empty Registry.""" config = self.make_registry_config(collections, collection_type) registry = _RegistryFactory(config).create_from_config(butlerRoot=self.root) + self.addCleanup(registry.close) self.initialize_registry(registry) return registry @@ -526,7 +527,9 @@ def make_registry( ) -> SqlRegistry: """Create new empty Registry.""" original = super().make_registry(collections, collection_type) - return original.copy() + copy = original.copy() + self.addCleanup(copy.close) + return copy class PostgresObsCoreTest(ObsCoreTests, unittest.TestCase): diff --git a/tests/test_packages.py b/tests/test_packages.py index 506acdd6f7..1bfd53a6f0 100644 --- a/tests/test_packages.py +++ b/tests/test_packages.py @@ -45,6 +45,7 @@ def setUp(self): self.root = makeTestTempDir(TESTDIR) Butler.makeRepo(self.root) self.butler = Butler.from_config(self.root, run="test_run") + self.enterContext(self.butler) # No dimensions in dataset type so we don't have to worry about # inserting dimension data or defining data IDs. 
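loadDimensionData in test_dimensions.py shows the one caveat of the with-based style: results must be fully materialized (there with .expanded().toSequence()) before the block ends, since lazy query results cannot be consumed after the Butler is closed. A hedged sketch of the same idea, with a hypothetical repository path:

from lsst.daf.butler import Butler


def load_data_ids(repo: str) -> list:
    with Butler.from_config(repo, without_datastore=True) as butler:
        dimensions = butler.registry.dimensions.conform(["visit", "detector"])
        # list() runs the query while the connection is still open; returning
        # the lazy result object itself would fail once the block exits.
        return list(butler.registry.queryDataIds(dimensions))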
self.datasetType = DatasetType( diff --git a/tests/test_parquet.py b/tests/test_parquet.py index 80b7fc937f..7704d33c56 100644 --- a/tests/test_parquet.py +++ b/tests/test_parquet.py @@ -363,6 +363,7 @@ def setUp(self): self.butler = Butler.from_config( Butler.makeRepo(self.root, config=config), writeable=True, run=self.run ) + self.enterContext(self.butler) # No dimensions in dataset type so we don't have to worry about # inserting dimension data or defining data IDs. self.datasetType = DatasetType( @@ -957,6 +958,7 @@ def setUp(self): self.butler = Butler.from_config( Butler.makeRepo(self.root, config=config), writeable=True, run=self.run ) + self.enterContext(self.butler) # No dimensions in dataset type so we don't have to worry about # inserting dimension data or defining data IDs. self.datasetType = DatasetType( @@ -1352,6 +1354,7 @@ def setUp(self): self.butler = Butler.from_config( Butler.makeRepo(self.root, config=config), writeable=True, run="test_run" ) + self.enterContext(self.butler) # No dimensions in dataset type so we don't have to worry about # inserting dimension data or defining data IDs. self.datasetType = DatasetType( @@ -1651,6 +1654,7 @@ def setUp(self): self.butler = Butler.from_config( Butler.makeRepo(self.root, config=config), writeable=True, run="test_run" ) + self.enterContext(self.butler) # No dimensions in dataset type so we don't have to worry about # inserting dimension data or defining data IDs. self.datasetType = DatasetType( @@ -1964,6 +1968,7 @@ def setUp(self): self.butler = Butler.from_config( Butler.makeRepo(self.root, config=config), writeable=True, run="test_run" ) + self.enterContext(self.butler) # No dimensions in dataset type so we don't have to worry about # inserting dimension data or defining data IDs. self.datasetType = DatasetType( @@ -2125,6 +2130,7 @@ def setUp(self): self.butler = Butler.from_config( Butler.makeRepo(self.root, config=config), writeable=True, run="test_run" ) + self.enterContext(self.butler) # No dimensions in dataset type so we don't have to worry about # inserting dimension data or defining data IDs. self.datasetType = DatasetType( @@ -2334,6 +2340,7 @@ def setUp(self): self.tmpConfigFile = posixpath.join(rooturi, "butler.yaml") self.butler = Butler(self.tmpConfigFile, writeable=True, run="test_run") + self.enterContext(self.butler) # No dimensions in dataset type so we don't have to worry about # inserting dimension data or defining data IDs. 
diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index e7c2667532..540d73bfcd 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -32,6 +32,7 @@ import unittest import warnings from contextlib import contextmanager +from typing import cast import astropy.time import sqlalchemy @@ -218,12 +219,15 @@ def make_butler(self, config: RegistryConfig | None = None) -> Butler: self.postgres.patch_registry_config(config) registry = _RegistryFactory(config).create_from_config() - return DirectButler( + butler = DirectButler( config=ButlerConfig(), registry=registry, datastore=NullDatastore(None, None), storageClasses=StorageClassFactory(), ) + cast(unittest.TestCase, self).enterContext(butler) + + return butler def testSkipCalibs(self): if self.postgres.server_major_version() < 16: diff --git a/tests/test_quantumBackedButler.py b/tests/test_quantumBackedButler.py index 99acacdaf7..dd6063802e 100644 --- a/tests/test_quantumBackedButler.py +++ b/tests/test_quantumBackedButler.py @@ -67,8 +67,10 @@ def setUp(self) -> None: # Make a butler and import dimension definitions. registryConfig = RegistryConfig(self.config.get("registry")) - _RegistryFactory(registryConfig).create_from_config(butlerRoot=self.root) + registry = _RegistryFactory(registryConfig).create_from_config(butlerRoot=self.root) + registry.close() butler = Butler.from_config(self.config, writeable=True, run="RUN", metrics=self.metrics) + self.enterContext(butler) assert isinstance(butler, DirectButler) self.butler = butler self.butler.import_(filename="resource://lsst.daf.butler/tests/registry_data/base.yaml") @@ -149,6 +151,7 @@ def test_initialize(self) -> None: dataset_types=self.dataset_types, metrics=self.metrics, ) + self.addCleanup(qbb.close) self._test_factory(qbb) def test_initialize_repo_index(self) -> None: @@ -170,6 +173,7 @@ def test_initialize_repo_index(self) -> None: dataset_types=self.dataset_types, metrics=self.metrics, ) + self.addCleanup(qbb.close) self._test_factory(qbb) def test_from_predicted(self) -> None: @@ -183,6 +187,7 @@ def test_from_predicted(self) -> None: datastore_records=datastore_records, dataset_types=self.dataset_types, ) + self.addCleanup(qbb.close) self._test_factory(qbb) def _test_factory(self, qbb: QuantumBackedButler) -> None: @@ -205,6 +210,7 @@ def test_getput(self) -> None: dataset_types=self.dataset_types, metrics=self.metrics, ) + self.addCleanup(qbb.close) # Verify all input data are readable. 
for ref in self.input_refs: @@ -261,6 +267,7 @@ def test_getDeferred(self) -> None: qbb = QuantumBackedButler.initialize( config=self.config, quantum=quantum, dimensions=self.universe, dataset_types=self.dataset_types ) + self.addCleanup(qbb.close) # get some input data input_refs = self.input_refs[:2] @@ -286,6 +293,7 @@ def test_stored(self) -> None: qbb = QuantumBackedButler.initialize( config=self.config, quantum=quantum, dimensions=self.universe, dataset_types=self.dataset_types ) + self.addCleanup(qbb.close) # get some input data input_refs = self.input_refs[:2] @@ -320,6 +328,7 @@ def test_markInputUnused(self) -> None: qbb = QuantumBackedButler.initialize( config=self.config, quantum=quantum, dimensions=self.universe, dataset_types=self.dataset_types ) + self.addCleanup(qbb.close) # get some input data for ref in self.input_refs: @@ -341,6 +350,7 @@ def test_pruneDatasets(self) -> None: qbb = QuantumBackedButler.initialize( config=self.config, quantum=quantum, dimensions=self.universe, dataset_types=self.dataset_types ) + self.addCleanup(qbb.close) # Write all expected outputs. for ref in self.output_refs: @@ -385,6 +395,7 @@ def test_extract_provenance_data(self) -> None: qbb = QuantumBackedButler.initialize( config=self.config, quantum=quantum, dimensions=self.universe, dataset_types=self.dataset_types ) + self.addCleanup(qbb.close) # read/store everything for ref in self.input_refs: @@ -422,6 +433,7 @@ def test_export_predicted_datastore_records(self) -> None: qbb = QuantumBackedButler.initialize( config=self.config, quantum=quantum, dimensions=self.universe, dataset_types=self.dataset_types ) + self.addCleanup(qbb.close) records = qbb.export_predicted_datastore_records(self.output_refs) self.assertEqual(len(records["FileDatastore@/datastore"].records), len(self.output_refs)) @@ -435,11 +447,13 @@ def test_collect_and_transfer(self) -> None: qbb1 = QuantumBackedButler.initialize( config=self.config, quantum=quantum1, dimensions=self.universe, dataset_types=self.dataset_types ) + self.addCleanup(qbb1.close) quantum2 = self.make_quantum(2) qbb2 = QuantumBackedButler.initialize( config=self.config, quantum=quantum2, dimensions=self.universe, dataset_types=self.dataset_types ) + self.addCleanup(qbb2.close) # read/store everything for ref in self.input_refs: diff --git a/tests/test_query_direct_postgresql.py b/tests/test_query_direct_postgresql.py index 8678efcaca..e1c5eac5f2 100644 --- a/tests/test_query_direct_postgresql.py +++ b/tests/test_query_direct_postgresql.py @@ -61,6 +61,7 @@ def make_butler(self, *args: str) -> Butler: datastore=NullDatastore(None, None), storageClasses=StorageClassFactory(), ) + self.enterContext(butler) for arg in args: self.load_data(butler, arg) return butler diff --git a/tests/test_query_direct_sqlite.py b/tests/test_query_direct_sqlite.py index a2194c8545..adfabe0498 100644 --- a/tests/test_query_direct_sqlite.py +++ b/tests/test_query_direct_sqlite.py @@ -49,6 +49,7 @@ def make_butler(self, *args: str) -> Butler: # dir but create_populated_sqlite_registry does not and for consistency # call load_data to match the other usages. 
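The QuantumBackedButler tests use addCleanup(qbb.close), the non-context-manager counterpart of enterContext: registered callables run in reverse order once the test finishes, whether it passed or failed. A small, self-contained sketch with a hypothetical resource class:

import unittest


class Connection:
    """Hypothetical resource exposing close(), as QuantumBackedButler does."""

    def __init__(self) -> None:
        self.closed = False

    def close(self) -> None:
        self.closed = True


class CleanupTest(unittest.TestCase):
    def test_cleanup_runs_after_the_test(self) -> None:
        conn = Connection()
        # close() is deferred until after the test body, so the resource
        # stays usable for later assertions but never leaks.
        self.addCleanup(conn.close)
        self.assertFalse(conn.closed)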
butler = create_populated_sqlite_registry() + self.enterContext(butler) for arg in args: self.load_data(butler, arg) return butler diff --git a/tests/test_server.py b/tests/test_server.py index 694abc4c2f..945c84da31 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -247,6 +247,7 @@ def override_read(http_resource_path): collections=["collection1", "collection2"], run="collection2", ) + self.enterContext(butler) self.assertIsInstance(butler, RemoteButler) self.assertEqual(butler._connection.server_url, server_url) self.assertEqual(butler.collections.defaults, ("collection1", "collection2")) diff --git a/tests/test_simpleButler.py b/tests/test_simpleButler.py index 63668b6bd7..ef717a583b 100644 --- a/tests/test_simpleButler.py +++ b/tests/test_simpleButler.py @@ -599,12 +599,14 @@ def testRegistryDefaults(self): # This should not have a default instrument, because there are two. # Pass run instead of collections; this should set both. butler2 = Butler.from_config(butler=butler, run="imported_g") + self.enterContext(butler2) self.assertEqual(list(butler2.registry.defaults.collections), ["imported_g"]) self.assertEqual(butler2.registry.defaults.run, "imported_g") self.assertFalse(butler2.registry.defaults.dataId) # Initialize a new butler with an instrument default explicitly given. # Set collections instead of run, which should then be None. butler3 = Butler.from_config(butler=butler, collections=["imported_g"], instrument="Cam2") + self.enterContext(butler3) self.assertEqual(list(butler3.registry.defaults.collections), ["imported_g"]) self.assertIsNone(butler3.registry.defaults.run, None) self.assertEqual(butler3.registry.defaults.dataId.required, {"instrument": "Cam2"}) @@ -900,11 +902,13 @@ def makeButler(self, writeable: bool = False) -> Butler: # have to make a registry first registryConfig = RegistryConfig(config.get("registry")) - _RegistryFactory(registryConfig).create_from_config() + registry = _RegistryFactory(registryConfig).create_from_config() + registry.close() # Write the YAML file so that some tests can recreate butler from it. config.dumpToUri(os.path.join(self.root, "butler.yaml")) butler = Butler.from_config(config, writeable=writeable) + self.enterContext(butler) DatastoreMock.apply(butler) return butler @@ -929,6 +933,7 @@ def test_dataset_uris(self): index_file.flush() with mock_env({"DAF_BUTLER_REPOSITORY_INDEX": index_file.name}): butler_factory = LabeledButlerFactory() + self.addCleanup(butler_factory.close) factory = butler_factory.bind(access_token=None) for dataset_uri in ( @@ -938,6 +943,7 @@ def test_dataset_uris(self): f"ivo://org.rubinobs/usdac/lsst-dp1?repo={label}&id={ref.id}", ): result = Butler.get_dataset_from_uri(dataset_uri) + self.enterContext(result.butler) self.assertEqual(result.dataset, ref) # The returned butler needs to have the datastore mocked. DatastoreMock.apply(result.butler) @@ -945,6 +951,7 @@ def test_dataset_uris(self): self.assertEqual(dataset_id, ref.id) factory_result = Butler.get_dataset_from_uri(dataset_uri, factory=factory) + self.enterContext(factory_result.butler) self.assertEqual(factory_result.dataset, ref) # The returned butler needs to have the datastore mocked. DatastoreMock.apply(factory_result.butler) @@ -954,6 +961,7 @@ def test_dataset_uris(self): # Non existent dataset. missing_id = str(ref.id).replace("2", "3") result = Butler.get_dataset_from_uri(f"butler://{label}/{missing_id}") + self.enterContext(result.butler) self.assertIsNone(result.dataset) # Test some failure modes. 
diff --git a/tests/test_sqlite.py b/tests/test_sqlite.py
index ed05175db7..06a8da05b4 100644
--- a/tests/test_sqlite.py
+++ b/tests/test_sqlite.py
@@ -31,6 +31,7 @@
 import tempfile
 import unittest
 from contextlib import contextmanager
+from typing import cast

 import sqlalchemy

@@ -86,11 +87,15 @@ def tearDown(self):
     def makeEmptyDatabase(self, origin: int = 0) -> SqliteDatabase:
         _, filename = tempfile.mkstemp(dir=self.root, suffix=".sqlite3")
         engine = SqliteDatabase.makeEngine(filename=filename)
-        return SqliteDatabase.fromEngine(engine=engine, origin=origin)
+        db = SqliteDatabase.fromEngine(engine=engine, origin=origin)
+        self.addCleanup(db.dispose)
+        return db

     def getNewConnection(self, database: SqliteDatabase, *, writeable: bool) -> SqliteDatabase:
         engine = SqliteDatabase.makeEngine(filename=database.filename, writeable=writeable)
-        return SqliteDatabase.fromEngine(origin=database.origin, engine=engine, writeable=writeable)
+        db = SqliteDatabase.fromEngine(origin=database.origin, engine=engine, writeable=writeable)
+        self.addCleanup(db.dispose)
+        return db

     @contextmanager
     def asReadOnly(self, database: SqliteDatabase) -> SqliteDatabase:
@@ -104,12 +109,14 @@ def testConnection(self):
         _, filename = tempfile.mkstemp(dir=self.root, suffix=".sqlite3")
         # Create a read-write database by passing in the filename.
         rwFromFilename = SqliteDatabase.fromEngine(SqliteDatabase.makeEngine(filename=filename), origin=0)
+        self.addCleanup(rwFromFilename.dispose)
         self.assertEqual(os.path.realpath(rwFromFilename.filename), os.path.realpath(filename))
         self.assertEqual(rwFromFilename.origin, 0)
         self.assertTrue(rwFromFilename.isWriteable())
         self.assertTrue(isEmptyDatabaseActuallyWriteable(rwFromFilename))
         # Create a read-write database via a URI.
         rwFromUri = SqliteDatabase.fromUri(f"sqlite:///{filename}", origin=0)
+        self.addCleanup(rwFromUri.dispose)
         self.assertEqual(os.path.realpath(rwFromUri.filename), os.path.realpath(filename))
         self.assertEqual(rwFromUri.origin, 0)
         self.assertTrue(rwFromUri.isWriteable())
@@ -124,12 +131,14 @@ def testConnection(self):
         roFromFilename = SqliteDatabase.fromEngine(
             SqliteDatabase.makeEngine(filename=filename), origin=0, writeable=False
         )
+        self.addCleanup(roFromFilename.dispose)
         self.assertEqual(os.path.realpath(roFromFilename.filename), os.path.realpath(filename))
         self.assertEqual(roFromFilename.origin, 0)
         self.assertFalse(roFromFilename.isWriteable())
         self.assertFalse(isEmptyDatabaseActuallyWriteable(roFromFilename))
         # Create a read-write database via a URI.
         roFromUri = SqliteDatabase.fromUri(f"sqlite:///{filename}", origin=0, writeable=False)
+        self.addCleanup(roFromUri.dispose)
         self.assertEqual(os.path.realpath(roFromUri.filename), os.path.realpath(filename))
         self.assertEqual(roFromUri.origin, 0)
         self.assertFalse(roFromUri.isWriteable())
@@ -146,10 +155,14 @@ class SqliteMemoryDatabaseTestCase(unittest.TestCase, DatabaseTests):

     def makeEmptyDatabase(self, origin: int = 0) -> SqliteDatabase:
         engine = SqliteDatabase.makeEngine(filename=None)
-        return SqliteDatabase.fromEngine(engine=engine, origin=origin)
+        db = SqliteDatabase.fromEngine(engine=engine, origin=origin)
+        self.addCleanup(db.dispose)
+        return db

     def getNewConnection(self, database: SqliteDatabase, *, writeable: bool) -> SqliteDatabase:
-        return SqliteDatabase.fromEngine(origin=database.origin, engine=database._engine, writeable=writeable)
+        db = SqliteDatabase.fromEngine(origin=database.origin, engine=database._engine, writeable=writeable)
+        self.addCleanup(db.dispose)
+        return db

     @contextmanager
     def asReadOnly(self, database: SqliteDatabase) -> SqliteDatabase:
@@ -161,12 +174,14 @@ def testConnection(self):
         """
         # Create an in-memory database by passing filename=None.
         memFromFilename = SqliteDatabase.fromEngine(SqliteDatabase.makeEngine(filename=None), origin=0)
+        self.addCleanup(memFromFilename.dispose)
         self.assertIsNone(memFromFilename.filename)
         self.assertEqual(memFromFilename.origin, 0)
         self.assertTrue(memFromFilename.isWriteable())
         self.assertTrue(isEmptyDatabaseActuallyWriteable(memFromFilename))
         # Create an in-memory database via a URI.
         memFromUri = SqliteDatabase.fromUri("sqlite://", origin=0)
+        self.addCleanup(memFromUri.dispose)
         self.assertIsNone(memFromUri.filename)
         self.assertEqual(memFromUri.origin, 0)
         self.assertTrue(memFromUri.isWriteable())
@@ -209,7 +224,9 @@ def make_butler(self, registry_config: RegistryConfig | None = None) -> Butler:
         if registry_config is None:
             registry_config = self.makeRegistryConfig()
         config["registry"] = registry_config
-        return makeTestRepo(self.root, config=config)
+        butler = makeTestRepo(self.root, config=config)
+        cast(unittest.TestCase, self).enterContext(butler)
+        return butler


 class SqliteFileRegistryNameKeyCollMgrUUIDTestCase(SqliteFileRegistryTests, unittest.TestCase):
@@ -260,7 +277,9 @@ def make_butler(self, registry_config: RegistryConfig | None = None) -> Butler:
         # with default managers.
         if registry_config is None:
             registry_config = self.makeRegistryConfig()
-        return create_populated_sqlite_registry(registry_config=registry_config)
+        butler = create_populated_sqlite_registry(registry_config=registry_config)
+        cast(unittest.TestCase, self).enterContext(butler)
+        return butler

     def testMissingAttributes(self):
         """Test for instantiating a registry against outdated schema which
diff --git a/tests/test_testRepo.py b/tests/test_testRepo.py
index 22c983bb62..6133779a3a 100644
--- a/tests/test_testRepo.py
+++ b/tests/test_testRepo.py
@@ -66,6 +66,7 @@ def testMakeTestRepo(self):
         }

         butler = makeTestRepo(self.root, dataIds)
+        self.enterContext(butler)
         records = list(butler.registry.queryDimensionRecords("visit"))
         self.assertEqual(len(records), 3)

@@ -81,6 +82,7 @@ def setUpClass(cls):
         cls.root = makeTestTempDir(TESTDIR)
         cls.creatorButler = makeTestRepo(cls.root)
+        cls.enterClassContext(cls.creatorButler)
         addDataIdValue(cls.creatorButler, "instrument", "notACam")
         addDataIdValue(cls.creatorButler, "instrument", "dummyCam")
         addDataIdValue(cls.creatorButler, "physical_filter", "k2020", band="k", instrument="notACam")
@@ -112,7 +114,8 @@ def testButlerKwargs(self):
         # outfile has the most obvious effects of any Butler.makeRepo keyword
         with safeTestTempDir(TESTDIR) as temp:
             path = os.path.join(temp, "oddConfig.json")
-            makeTestRepo(temp, {}, outfile=path)
+            butler = makeTestRepo(temp, {}, outfile=path)
+            self.enterContext(butler)
             self.assertTrue(os.path.isfile(path))

     def _checkButlerDimension(self, dimensions, query, expected):
@@ -212,6 +215,7 @@ def testRegisterMetricsExampleChained(self):

             repo = lsst.daf.butler.Butler.makeRepo(temp, config=config)
             butler = lsst.daf.butler.Butler.from_config(repo, run="chainedExample")
+            self.enterContext(butler)
             registerMetricsExample(butler)
             addDatasetType(butler, "DummyType", {}, "StructuredDataNoComponents")
diff --git a/tests/test_versioning.py b/tests/test_versioning.py
index ea28c4053c..ac42212f8c 100644
--- a/tests/test_versioning.py
+++ b/tests/test_versioning.py
@@ -102,7 +102,9 @@ def tearDown(self):
     def makeEmptyDatabase(self, origin: int = 0) -> Database:
         _, filename = tempfile.mkstemp(dir=self.root, suffix=".sqlite3")
         engine = SqliteDatabase.makeEngine(filename=filename)
-        return SqliteDatabase.fromEngine(engine=engine, origin=origin)
+        db = SqliteDatabase.fromEngine(engine=engine, origin=origin)
+        self.addCleanup(db.dispose)
+        return db

     def test_new_schema(self) -> None:
         """Test for creating new database schema."""
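The low-level `SqliteDatabase` tests in test_sqlite.py and test_versioning.py apply the same idea one layer down: every handle produced by `makeEngine`/`fromEngine` or `fromUri` is paired with an `addCleanup(db.dispose)` so its SQLAlchemy engine is released at test teardown instead of at garbage collection. A minimal sketch of that pattern; the import path is an assumption taken from the registry test modules rather than from this diff::

    import os
    import tempfile
    import unittest

    # Assumed module path for SqliteDatabase; not shown in this changeset.
    from lsst.daf.butler.registry.databases.sqlite import SqliteDatabase


    class ExampleSqliteDisposeTestCase(unittest.TestCase):
        """Sketch of the dispose-on-cleanup pattern from test_sqlite.py."""

        def make_database(self, origin: int = 0) -> SqliteDatabase:
            # A temporary on-disk SQLite file keeps the example self-contained.
            fd, filename = tempfile.mkstemp(suffix=".sqlite3")
            os.close(fd)
            self.addCleanup(os.remove, filename)

            engine = SqliteDatabase.makeEngine(filename=filename)
            db = SqliteDatabase.fromEngine(engine=engine, origin=origin)
            # dispose() drops the connection pool at teardown; cleanups run
            # in reverse order, so the database closes before the file is
            # removed.
            self.addCleanup(db.dispose)
            return db

        def test_writeable(self):
            db = self.make_database()
            self.assertTrue(db.isWriteable())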