Commit

refactor of CacheDirectory and CachePath classes

Karl5766 committed Sep 19, 2024
1 parent 95329b0 commit d8e86d0
Showing 12 changed files with 449 additions and 259 deletions.
2 changes: 2 additions & 0 deletions docs/API/imfs.rst
@@ -14,3 +14,5 @@ View source at `fs.py <https://github.com/khanlab/cvpl_tools/blob/main/src/cvpl_
     :members:
 .. autoclass:: cvpl_tools.im.fs.CacheDirectory
     :members:
+.. autoclass:: cvpl_tools.im.fs.CacheRootDirectory
+    :members:
24 changes: 12 additions & 12 deletions docs/GettingStarted/segmentation_pipeline.rst
@@ -113,22 +113,22 @@ and two optional parameters: cid and viewer_args.

 - cid specifies the subdirectory under the cache directory (set by the :code:`set_tmpdir` method of the base
   class) to save intermediate files. If not provided (:code:`cid=None`),
-  then the cache will be saved in a temporary directory that will be removed when the CacheDirectory is
+  then the cache will be saved in a temporary directory that will be removed when the CacheRootDirectory is
   closed. If provided, this cache file will persist. Within the :code:`forward()` method, you should use
   :code:`self.tmpdir.cache()` and :code:`self.tmpdir.cache_im()` to create cache files:
 
   .. code-block:: Python
 
      class ExampleSegProcess(SegProcess):
-         def forward(self, im, cid: str = None, viewer: napari.Viewer = None):
-             cache_exists, cache_path = self.tmpdir.cache(is_dir=True, cid=cid)
+         def forward(self, im, cptr: CachePointer, viewer: napari.Viewer = None):
+             cache_path = cptr.subpath()
-             # in the case the cache does not exist, cache_path.path is an empty path we can create a folder in:
-             if not cache_exists:
-                 os.makedirs(cache_path.path)
+             # in the case the cache does not exist, cache_path.abs_path is an empty path we can create a folder in:
+             if not cache_path.exists:
+                 os.makedirs(cache_path.abs_path)
              result = compute_result(im)
-             save(cache_path.path, result)
-             result = load(cache_path.path)
+             save(cache_path.abs_path, result)
+             result = load(cache_path.abs_path)
              return result
 
 - The :code:`viewer_args` parameter specifies the napari viewer to display the intermediate results. If not provided
@@ -139,11 +139,11 @@ and two optional parameters: cid and viewer_args.
   .. code-block:: Python
 
      class ExampleSegProcess(SegProcess):
-         def forward(self, im, cid: str = None, viewer_args: dict = None):
+         def forward(self, im, cptr: CachePointer, viewer_args: dict = None):
              if viewer_args is None:
                  viewer_args = {}
              result = compute_result(im)
-             result = self.tmpdir.cache_im(lambda: result, cid=cid, viewer_args=viewer_args)
+             result = cptr.im(lambda: result, viewer_args=viewer_args)  # cache the result at the location pointed to by cptr
              return result
 
      # ...
      viewer = napari.Viewer(ndisplay=2)
@@ -154,7 +154,7 @@ and two optional parameters: cid and viewer_args.
          multiscale=4 if viewer else 0,  # maximum downsampling level of OME ZARR files, necessary for very large images
      )
      process = ExampleSegProcess()
-     process.forward(im, cid=cid, viewer_args=viewer_args)
+     process.forward(im, cptr=root_dir.cache(cid='compute'), viewer_args=viewer_args)
 
 :code:`viewer_args` is a parameter that allows us to visualize the saved results as part of the caching
 function. The reason we need this is that displaying the saved result often requires a different (flatter)
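[Editor's note: a minimal sketch of building a :code:`viewer_args` dict under the new API. The :code:`multiscale` key and :code:`is_label` key appear elsewhere in this commit; the :code:`viewer` key is an assumption, and :code:`compute_result`/:code:`cptr` are the placeholders from the examples above.]

.. code-block:: Python

   import napari

   viewer = napari.Viewer(ndisplay=2)
   viewer_args = dict(
       viewer=viewer,  # hypothetical key: the napari viewer that displays cached results
       is_label=True,  # used in dask_label.py below; marks the image as a segmentation
       multiscale=4 if viewer else 0,  # maximum downsampling level of OME ZARR files
   )
   result = cptr.im(lambda: compute_result(im), viewer_args=viewer_args)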
@@ -179,7 +179,7 @@ to segment an input dataset. Note we need a dask cluster and a temporary directo
     from dask.distributed import Client
     import napari
 
     with (dask.config.set({'temporary_directory': TMP_PATH}),
-          imfs.CacheDirectory(
+          imfs.CacheRootDirectory(
               f'{TMP_PATH}/CacheDirectory',
               remove_when_done=False,
               read_if_exists=True) as temp_directory):
35 changes: 18 additions & 17 deletions docs/GettingStarted/setting_up_the_script.rst
@@ -97,8 +97,8 @@ accidentally close the command window. Below is an example:
     logfile_stdout = open('log_stdout.txt', mode='w')
     logfile_stderr = open('log_stderr.txt', mode='w')
-    sys.stdout = fs.MultiOutputStream(sys.stdout, logfile_stdout)
-    sys.stderr = fs.MultiOutputStream(sys.stderr, logfile_stderr)
+    sys.stdout = imfs.MultiOutputStream(sys.stdout, logfile_stdout)
+    sys.stderr = imfs.MultiOutputStream(sys.stderr, logfile_stderr)
 
     import dask
     import dask.config
@@ -130,7 +130,8 @@ may discard once computed, and for the others (like the final output) we may wan
 for access later without having to redo the computation. In order to cache the result, we need a fixed path
 that does not change across program executions. The :code:`CacheDirectory` class is one that manages and
 assigns paths for these intermediate results, based on their cache ID (cid) and the parent CacheDirectory
-they belong to.
+they belong to. :code:`CacheRootDirectory` is a subclass of :code:`CacheDirectory` that acts as the root
+of the cache directory structure.
 
 In cvpl_tool's model of caching, there is a root cache directory that is created or loaded when the program
 starts to run, and every cache directory may contain many sub-cache-directories or data directories in
@@ -140,33 +141,33 @@ which there are intermediate files. To create a cache directory, we write
     if __name__ == '__main__':
         import cvpl_tools.im.fs as imfs
-        with imfs.CacheDirectory(
+        with imfs.CacheRootDirectory(
                 f'{TMP_PATH}/CacheDirectory',
                 remove_when_done=False,
                 read_if_exists=True) as temp_directory:
             # Use case #1. Create a data directory for caching computation results
-            cache_exists, cache_path = temp_directory.cache(is_dir=False, cid='some_cache_path')
-            if not cache_exists:
-                os.makedirs(cache_path.path, exist_ok=True)
-            # PUT CODE HERE: Now write your data into cache_path.path and load it back later
+            cache_path = temp_directory.cache_subpath(cid='some_cache_path')
+            if not cache_path.exists:
+                os.makedirs(cache_path.abs_path, exist_ok=True)
+            # PUT CODE HERE: Now write your data into cache_path.abs_path and load it back later
 
             # Use case #2. Create a sub-directory and pass it to other processes for caching
             def multi_step_computation(cache_at: imfs.CacheDirectory):
-                cache_exists, cache_path = cache_at.cache(is_dir=False, cid='A')
-                if not cache_exists:
+                cache_path = cache_at.cache_subpath(cid='A')
+                if not cache_path.exists:
                     A = computeA()
-                    save(cache_path.path, A)
-                A = load(cache_path.path)
+                    save(cache_path.abs_path, A)
+                A = load(cache_path.abs_path)
 
-                cache_exists_B, cache_path_B = cache_at.cache(is_dir=False, cid='B')
-                if not cache_exists_B:
+                cache_path_B = cache_at.cache_subpath(cid='B')
+                if not cache_path_B.exists:
                     B = computeBFromA()
-                    save(cache_path_B.path, B)
-                B = load(cache_path_B.path)
+                    save(cache_path_B.abs_path, B)
+                B = load(cache_path_B.abs_path)
                 return B
 
-            sub_temp_directory = temp_directory.cache(is_dir=True, cid='mult_step_cache')
+            sub_temp_directory = temp_directory.cache_subdir(cid='mult_step_cache')
             result = multi_step_computation(cache_at=sub_temp_directory)
 After running the above code once, caching files will be created. The second time the code is run, the computation
 is skipped and the cached results are loaded back from disk instead.
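[Editor's note: a minimal sketch of the intended re-run behavior; :code:`computeA`, :code:`save` and :code:`load` are the placeholder helpers from the example above.]

.. code-block:: Python

   # First run: cache_path.exists is False, so A is computed and written to disk.
   # Second run: cache_path.exists is True, so computeA() is skipped entirely
   # and A is read back from the same fixed path.
   cache_path = temp_directory.cache_subpath(cid='some_cache_path')
   if not cache_path.exists:
       save(cache_path.abs_path, computeA())
   A = load(cache_path.abs_path)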
36 changes: 23 additions & 13 deletions src/cvpl_tools/im/dask_label.py
@@ -8,7 +8,7 @@
 import numpy.typing as npt
 from cvpl_tools.im.ndblock import NDBlock
 from cvpl_tools.im.partd_server import SQLiteKVStore, SQLitePartd, SqliteServer
-from cvpl_tools.im.fs import CacheDirectory
+from cvpl_tools.im.fs import CacheDirectory, CachePointer
 import os
 from scipy.ndimage import label as scipy_label
 import partd
@@ -71,11 +71,15 @@ def dfs(node, component):
     return components
 
 
-def label(im: npt.NDArray | da.Array | NDBlock, cache_dir: CacheDirectory, output_dtype: np.dtype = None,
+def label(im: npt.NDArray | da.Array | NDBlock,
+          cptr: CachePointer,
+          output_dtype: np.dtype = None,
           viewer_args: dict = None
           ) -> npt.NDArray | da.Array | NDBlock:
     """Return (lbl_im, nlbl) where lbl_im is a globally labeled image of the same type/chunk size as the input"""
 
+    cdir = cptr.subdir()
+
     ndim = im.ndim
     if viewer_args is None:
         viewer_args = {}
@@ -86,7 +90,7 @@
     is_dask = isinstance(im, da.Array)
     if not is_dask:
         assert isinstance(im, NDBlock)
-        im = im.as_dask_array(tmp_dirpath=cache_dir.path)
+        im = im.as_dask_array(tmp_dirpath=cdir.abs_path)
 
     def map_block(block: npt.NDArray, block_info: dict):
         lbl_im = scipy_label(block, output=output_dtype)[0]
@@ -98,7 +102,7 @@ def to_max(block: npt.NDArray, block_info: dict):
     # compute locally labelled chunks and save their bordering slices
     if is_logging:
         print('Locally label the image')
-    locally_labeled = cache_dir.cache_im(
+    locally_labeled = cdir.cache_im(
         lambda: im.map_blocks(map_block, meta=np.zeros(tuple(), dtype=output_dtype)),
         cid='locally_labeled_without_cumsum'
     )
@@ -116,15 +120,15 @@ def compute_nlbl_np_arr():
         nlbl_np_arr = nlbl_ndblock_arr.as_numpy()
         return nlbl_np_arr
 
-    nlbl_np_arr = cache_dir.cache_im(fn=compute_nlbl_np_arr, cid='nlbl_np_arr')
+    nlbl_np_arr = cdir.cache_im(fn=compute_nlbl_np_arr, cid='nlbl_np_arr')
 
     def compute_cumsum_np_arr():
         if is_logging:
             print('Compute prefix sum and reshape back')
         cumsum_np_arr = np.cumsum(nlbl_np_arr)
         return cumsum_np_arr
 
-    cumsum_np_arr = cache_dir.cache_im(fn=compute_cumsum_np_arr, cid='cumsum_np_arr')
+    cumsum_np_arr = cdir.cache_im(fn=compute_cumsum_np_arr, cid='cumsum_np_arr')
     assert cumsum_np_arr.ndim == 1
     total_nlbl = cumsum_np_arr[-1].item()
     cumsum_np_arr[1:] = cumsum_np_arr[:-1]
@@ -137,21 +141,22 @@ def compute_cumsum_np_arr():
     # Prepare cache file to be used
     if is_logging:
         print('Setting up cache sqlite database')
-    _, cache_file = cache_dir.cache(cid='border_slices')
-    os.makedirs(cache_file.path, exist_ok=True)
-    db_path = f'{cache_file.path}/border_slices.db'
+    cache_file = cdir.cache_subpath(cid='border_slices')
+    slices_abs_path = cache_file.abs_path
+    os.makedirs(slices_abs_path, exist_ok=True)
+    db_path = f'{slices_abs_path}/border_slices.db'
 
     def create_kv_store():
         kv_store = PairKVStore(db_path)
         return kv_store
 
     def get_sqlite_partd():
-        partd = SQLitePartd(cache_file.path, create_kv_store=create_kv_store)
+        partd = SQLitePartd(slices_abs_path, create_kv_store=create_kv_store)
         return partd
 
     if is_logging:
         print('Setting up partd server')
-    server = SqliteServer(cache_file.path, get_sqlite_partd=get_sqlite_partd)
+    server = SqliteServer(slices_abs_path, get_sqlite_partd=get_sqlite_partd)
     server_address = server.address
 
     # compute edge slices
@@ -177,7 +182,7 @@ def compute_slices(block: npt.NDArray, block2: npt.NDArray, block_info: dict = N
             client.close()
         return block
 
-    locally_labeled = cache_dir.cache_im(
+    locally_labeled = cdir.cache_im(
         lambda: da.map_blocks(compute_slices, locally_labeled, cumsum_da_arr,
                               meta=np.zeros(tuple(), dtype=output_dtype)),
         cid='locally_labeled_with_cumsum'
@@ -233,7 +238,7 @@ def compute_slices(block: npt.NDArray, block2: npt.NDArray, block_info: dict = N
     def local_to_global(block, block_info, ind_map_scatter):
         return ind_map_scatter[block]
 
-    globally_labeled = cache_dir.cache_im(
+    globally_labeled = cdir.cache_im(
         lambda: locally_labeled.map_blocks(func=local_to_global, meta=np.zeros(tuple(), dtype=output_dtype),
                                            ind_map_scatter=ind_map_scatter),
         cid='globally_labeled'
@@ -246,6 +251,11 @@ def local_to_global(block, block_info, ind_map_scatter):
     if is_logging:
         print('Function ends')
 
+    im = cdir.cache_im(lambda: result_arr,
+                       cid='global_os',
+                       cache_level=1,
+                       viewer_args=viewer_args | dict(is_label=True))
+
     server.close()  # TODO: find way to move this to where it should be
 
     return result_arr, comp_i
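[Editor's note: a hedged sketch of how the refactored label() might be called after this change. The setup mirrors the docs above; the cid names and the random input array are illustrative only, and TMP_PATH is the same placeholder path used in the docs.]

    import dask.array as da
    import numpy as np
    import cvpl_tools.im.fs as imfs
    from cvpl_tools.im.dask_label import label

    with imfs.CacheRootDirectory(f'{TMP_PATH}/CacheDirectory',
                                 remove_when_done=False,
                                 read_if_exists=True) as root_dir:
        # illustrative boolean mask to be labeled into connected components
        mask = da.random.random((512, 512), chunks=(256, 256)) > .5
        # label() now takes a CachePointer instead of a CacheDirectory:
        lbl_im, nlbl = label(mask, cptr=root_dir.cache(cid='dask_label'),
                             output_dtype=np.int32)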