From 788284a5e6feaad9cb95ed12eef3f72fca4e3672 Mon Sep 17 00:00:00 2001 From: Zhongsheng Ji <9573586@qq.com> Date: Tue, 19 Dec 2023 14:32:21 +0800 Subject: [PATCH] Add subdir for NDArrayLoader to prevent collision of cache files (#78) * Add subdir for NDArrayLoader to prevent collision of cache files * Release model when cleanup --- .../components/optimize/ndarray_loader.py | 23 ++++++++++++++++--- sdgx/synthesizer.py | 5 ++++ tests/conftest.py | 2 ++ tests/optmize/test_ndarry_loader.py | 2 +- 4 files changed, 28 insertions(+), 4 deletions(-) diff --git a/sdgx/models/components/optimize/ndarray_loader.py b/sdgx/models/components/optimize/ndarray_loader.py index 16639796..12afcca2 100644 --- a/sdgx/models/components/optimize/ndarray_loader.py +++ b/sdgx/models/components/optimize/ndarray_loader.py @@ -1,13 +1,17 @@ from __future__ import annotations +import os import shutil from functools import cached_property from pathlib import Path from typing import Generator +from uuid import uuid4 import numpy as np from numpy import ndarray +DEFAULT_CACHE_ROOT = os.getenv("SDG_NDARRAY_CACHE_ROOT", "./.ndarry_cache") + class NDArrayLoader: """ @@ -16,10 +20,23 @@ class NDArrayLoader: Support for storing two-dimensional data by columns. """ - def __init__(self, cache_dir: str | Path = "./.ndarry_cache") -> None: + def __init__(self, cache_root: str | Path = DEFAULT_CACHE_ROOT) -> None: self.store_index = 0 - self.cache_dir = Path(cache_dir).expanduser().resolve() - self.cache_dir.mkdir(exist_ok=True, parents=True) + self.cache_root = Path(cache_root).expanduser().resolve() + self.cache_root.mkdir(exist_ok=True, parents=True) + + @cached_property + def subdir(self) -> str: + """ + Prevent collision of cache files. + """ + return uuid4().hex + + @cached_property + def cache_dir(self) -> Path: + """Cache directory for storing ndarray.""" + + return self.cache_root / self.subdir def _get_cache_filename(self, index: int) -> Path: return self.cache_dir / f"{index}.npy" diff --git a/sdgx/synthesizer.py b/sdgx/synthesizer.py index 9d9d3d23..d301441b 100644 --- a/sdgx/synthesizer.py +++ b/sdgx/synthesizer.py @@ -234,3 +234,8 @@ def generator_sample_caller(): def cleanup(self): if self.dataloader: self.dataloader.finalize(clear_cache=True) + # Release resources + del self.model + + def __del__(self): + self.cleanup() diff --git a/tests/conftest.py b/tests/conftest.py index 33d345a9..3b863a74 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,6 @@ import os + +os.environ["SDG_NDARRAY_CACHE_ROOT"] = "/tmp/sdgx/ndarray_cache" import shutil import pytest diff --git a/tests/optmize/test_ndarry_loader.py b/tests/optmize/test_ndarry_loader.py index b2315642..2bd0ccfc 100644 --- a/tests/optmize/test_ndarry_loader.py +++ b/tests/optmize/test_ndarry_loader.py @@ -7,7 +7,7 @@ @pytest.fixture def ndarray_loader(tmp_path, ndarray_list): cache_dir = tmp_path / "ndarrycache" - loader = NDArrayLoader(cache_dir=cache_dir) + loader = NDArrayLoader(cache_root=cache_dir) for ndarray in ndarray_list: loader.store(ndarray) yield loader