
Commit

0.5.7
Karl5766 committed Sep 18, 2024
1 parent b5c0797 commit 2dbb676
Showing 6 changed files with 128 additions and 69 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "cvpl-tools"
version = "0.5.6"
version = "0.5.7"
description = "A Python package for utilities and classes related to the file I/O, dataset record keeping and visualization for image processing and computer vision."
authors = ["Karl5766 <[email protected]>"]
license = "MIT"
142 changes: 93 additions & 49 deletions src/cvpl_tools/im/dask_label.py
@@ -3,7 +3,6 @@
version of the label() function of scipy.ndimage
"""


import dask.array as da
import numpy as np
import numpy.typing as npt
@@ -16,6 +15,7 @@
import pickle
import cvpl_tools.im.algorithms as cvpl_algorithms
from dask.distributed import print as dprint
from collections import defaultdict


class PairKVStore(SQLiteKVStore):
@@ -45,6 +45,32 @@ def read_all(self):
yield self.cursor.fetchone()


def find_connected_components(edges: set[tuple[int, int]]) -> list[set[int]]:
graph = defaultdict(set)

for u, v in edges:
graph[u].add(v)
graph[v].add(u)

visited = set()
components = []

def dfs(node, component):
visited.add(node)
component.add(node)
for neighbor in graph[node]:
if neighbor not in visited:
dfs(neighbor, component)

for node in graph:
if node not in visited:
component = set()
dfs(node, component)
components.append(component)

return components
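# A minimal usage sketch of find_connected_components (hypothetical label IDs):
# labels 1-3 touch across chunk borders, and so do 4-5.
#   >>> edges = {(2, 1), (3, 2), (5, 4)}
#   >>> find_connected_components(edges)
#   [{1, 2, 3}, {4, 5}]      # component order may vary
# Note the DFS is recursive, so a very large component could approach Python's
# default recursion limit.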


def label(im: npt.NDArray | da.Array | NDBlock, cache_dir: CacheDirectory, output_dtype: np.dtype = None,
viewer_args: dict = None
) -> npt.NDArray | da.Array | NDBlock:
@@ -58,8 +84,9 @@ def label(im: npt.NDArray | da.Array | NDBlock, cache_dir: CacheDirectory, outpu
if isinstance(im, np.ndarray):
return scipy_label(im, output=output_dtype)
is_dask = isinstance(im, da.Array)
if is_dask:
im = NDBlock(im)
if not is_dask:
assert isinstance(im, NDBlock)
im = im.as_dask_array(tmp_dirpath=cache_dir.path)

def map_block(block: npt.NDArray, block_info: dict):
lbl_im = scipy_label(block, output=output_dtype)[0]
@@ -72,35 +99,45 @@ def to_max(block: npt.NDArray, block_info: dict):
if is_logging:
print('Locally label the image')
locally_labeled = cache_dir.cache_im(
lambda: NDBlock.map_ndblocks([im], fn=map_block, out_dtype=output_dtype)
lambda: im.map_blocks(map_block, meta=np.zeros(tuple(), dtype=output_dtype)),
cid='locally_labeled_without_cumsum'
)
if is_logging:
print('Taking the max of each chunk to obtain number of labels')
new_slices = list(tuple(slice(0, 1) for _ in range(ndim)) for _ in locally_labeled.get_slices_list())
nlbl_ndblock_arr = NDBlock.map_ndblocks([locally_labeled], fn=to_max, out_dtype=output_dtype,
new_slices=new_slices)
if is_logging:
print('Convert number of labels of chunks to numpy array')
print(nlbl_ndblock_arr.get_repr_format())
print(nlbl_ndblock_arr.is_numpy())
print(type(nlbl_ndblock_arr.arr))
nlbl_np_arr = nlbl_ndblock_arr.as_numpy()
if is_logging:
print('Compute prefix sum and reshape back')
cumsum_np_arr = np.cumsum(nlbl_np_arr)

def compute_nlbl_np_arr():
if is_logging:
print('Taking the max of each chunk to obtain number of labels')
locally_labeled_ndblock = NDBlock(locally_labeled)
new_slices = list(tuple(slice(0, 1) for _ in range(ndim))
for _ in NDBlock(locally_labeled_ndblock).get_slices_list())
nlbl_ndblock_arr = NDBlock.map_ndblocks([locally_labeled_ndblock], fn=to_max, out_dtype=output_dtype,
new_slices=new_slices)
if is_logging:
print('Convert number of labels of chunks to numpy array')
nlbl_np_arr = nlbl_ndblock_arr.as_numpy()
return nlbl_np_arr

nlbl_np_arr = cache_dir.cache_im(fn=compute_nlbl_np_arr, cid='nlbl_np_arr')

def compute_cumsum_np_arr():
if is_logging:
print('Compute prefix sum and reshape back')
cumsum_np_arr = np.cumsum(nlbl_np_arr)
return cumsum_np_arr

cumsum_np_arr = cache_dir.cache_im(fn=compute_cumsum_np_arr, cid='cumsum_np_arr')
assert cumsum_np_arr.ndim == 1
total_nlbl = cumsum_np_arr[-1].item()
cumsum_np_arr[1:] = cumsum_np_arr[:-1]
cumsum_np_arr[0] = 0
cumsum_np_arr = cumsum_np_arr.reshape(nlbl_np_arr.shape)
if is_logging:
print(f'total_nlbl={total_nlbl}, Convert prefix sum to a dask array then to NDBlock')
cumsum_ndblock_arr = NDBlock(da.from_array(cumsum_np_arr, chunks=(1,) * cumsum_np_arr.ndim))
cumsum_da_arr = da.from_array(cumsum_np_arr, chunks=(1,) * cumsum_np_arr.ndim)
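    # Worked sketch of the offset computation above, using hypothetical per-chunk counts:
    #   >>> nlbl = np.array([3, 2, 4])   # chunk 0 holds labels 1..3, chunk 1 holds 1..2, chunk 2 holds 1..4
    #   >>> cs = np.cumsum(nlbl)         # array([3, 5, 9]) -> total_nlbl == 9
    #   >>> cs[1:] = cs[:-1]; cs[0] = 0  # exclusive prefix sum: array([0, 3, 5])
    # Adding these offsets to every non-zero label of each chunk makes the per-chunk
    # label ranges (1..3, 4..5, 6..9) globally disjoint before merging across borders.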

# Prepare cache file to be used
if is_logging:
print('Setting up cache sqlite database')
_, cache_file = cache_dir.cache()
_, cache_file = cache_dir.cache(cid='border_slices')
os.makedirs(cache_file.path, exist_ok=True)
db_path = f'{cache_file.path}/border_slices.db'

@@ -121,31 +158,30 @@ def get_sqlite_partd():
if is_logging:
print('Computing edge slices, writing to database')

def compute_slices(block1: npt.NDArray, block2: npt.NDArray, block_info: dict = None):
# block1 is the local label, block2 is the single element prefix summed number of labels
def compute_slices(block: npt.NDArray, block2: npt.NDArray, block_info: dict = None):
# block is the local label, block2 is the single element prefix summed number of labels

client = partd.Client(server_address)
block_index = list(block_info[0]['chunk-location'])
block1 = block1 + (block1 != 0).astype(block1.dtype) * block2
for ax in range(block1.ndim):
block = block + (block != 0).astype(block.dtype) * block2
for ax in range(block.ndim):
for face in range(2):
block_index[ax] += face
indstr = '_'.join(str(index) for index in block_index) + f'_{ax}'
sli_idx = face * (block1.shape[ax] - 1)
sli = np.take(block1, indices=sli_idx, axis=ax)
sli_idx = face * (block.shape[ax] - 1)
sli = np.take(block, indices=sli_idx, axis=ax)
client.append({
indstr: pickle.dumps(sli)
})
block_index[ax] -= face
client.close()
return block1
return block
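    # Sketch of the face extraction above, assuming a block of shape (4, 5): for ax=0 and
    # face=1, sli_idx = 1 * (4 - 1) = 3, and np.take(block, indices=3, axis=0) returns the
    # last row, i.e. the border face shared with the neighbouring chunk one step further
    # along axis 0 (which is why block_index[ax] is temporarily incremented by face, so
    # both sides of a shared border write their slice under the same key).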

locally_labeled = cache_dir.cache_im(
lambda: NDBlock.map_ndblocks([locally_labeled, cumsum_ndblock_arr],
compute_slices,
out_dtype=output_dtype,
use_input_index_as_arrloc=0)
lambda: da.map_blocks(compute_slices, locally_labeled, cumsum_da_arr,
meta=np.zeros(tuple(), dtype=output_dtype)),
cid='locally_labeled_with_cumsum'
)
server.close()

if is_logging:
print('Process locally to obtain a lower triangular adjacency matrix')
@@ -169,39 +205,47 @@ def compute_slices(block1: npt.NDArray, block2: npt.NDArray, block_info: dict =
tup = (i2, i1)
if tup not in lower_adj:
lower_adj.add(tup)
ind_map = {i: i for i in range(1, total_nlbl + 1)}
for i2, i1 in lower_adj:
if ind_map[i2] > i1:
ind_map[i2] = i1
connected_components = find_connected_components(lower_adj)
if is_logging:
print('Compute final indices remap array')
ind_map_np = np.zeros((total_nlbl + 1,), dtype=output_dtype)
for i in range(1, total_nlbl + 1):
direct_connected = ind_map[i]
ind_map_np[i] = direct_connected
ind_map_np[i] = ind_map_np[direct_connected]
ind_map_np = ind_map_np[ind_map_np]
ind_map_np = np.arange(total_nlbl + 1, dtype=output_dtype)
assigned_mask = np.zeros((total_nlbl + 1), dtype=np.uint8)
assigned_mask[0] = 1 # we don't touch background class
comp_i = 0
while comp_i < len(connected_components):
comp = connected_components[comp_i]
comp_i += 1
for j in comp:
ind_map_np[j] = comp_i
assigned_mask[j] = 1
for i in range(assigned_mask.shape[0]):
if assigned_mask[i] == 0:
comp_i += 1
ind_map_np[i] = comp_i
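    # Worked sketch with hypothetical values: for total_nlbl = 5 and
    # connected_components = [{2, 4}, {3, 5}], the loops above produce
    #   ind_map_np = [0, 3, 1, 2, 1, 2]
    # the merged groups {2, 4} and {3, 5} become the new labels 1 and 2, the untouched
    # label 1 becomes 3, and comp_i ends at 3 (the number of labels after merging).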

read_kv_store.close()

if is_logging:
print('Remapping the indices array to be globally consistent')
print(f'comp_i={comp_i}, Remapping the indices array to be globally consistent')
client = viewer_args['client']
ind_map_scatter = client.scatter(ind_map_np, broadcast=True)

def local_to_global(block, block_info, ind_map_scatter):
return ind_map_scatter[block]
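    # Sketch of the remap above, reusing the hypothetical ind_map_np = [0, 3, 1, 2, 1, 2]:
    #   >>> ind_map = np.array([0, 3, 1, 2, 1, 2])
    #   >>> ind_map[np.array([[0, 2], [4, 5]])]
    #   array([[0, 1],
    #          [1, 2]])
    # i.e. each voxel's chunk-local label (already offset by the prefix sums) is replaced
    # by its merged global label via NumPy fancy indexing.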

globally_labeled = cache_dir.cache_im(
fn=lambda: NDBlock.map_ndblocks([locally_labeled],
fn=local_to_global,
out_dtype=output_dtype,
fn_args=dict(ind_map_scatter=ind_map_scatter)),
lambda: locally_labeled.map_blocks(func=local_to_global, meta=np.zeros(tuple(), dtype=output_dtype),
ind_map_scatter=ind_map_scatter),
cid='globally_labeled'
)
result_arr = globally_labeled.as_dask_array(tmp_dirpath=f'{cache_file.path}/to_dask_array')
result_arr = globally_labeled
if not is_dask:
if is_logging:
print('converting the result to NDBlock')
result_arr = NDBlock(result_arr)
if is_logging:
print('Function ends')
return result_arr, total_nlbl

server.close() # TODO: find way to move this to where it should be

return result_arr, comp_i
3 changes: 0 additions & 3 deletions src/cvpl_tools/im/ndblock.py
@@ -214,15 +214,12 @@ def save(file: str, ndblock: NDBlock, downsample_level: int = 0):
@dask.delayed
def save_block(block_index, block):
block_id = ('_'.join(str(i) for i in block_index)).encode('utf-8')
dprint(f'saving_____{block_id}')
store = partd.Client(server_address)
store.append({
block_id: pickle.dumps(block)
})
store.close()

for block_index, (block, _) in ndblock.arr.items():
print('TO BE SAVED', block_index)
tasks = [save_block(block_index, block) for block_index, (block, _) in ndblock.arr.items()]
dask.compute(*tasks)
import time
9 changes: 0 additions & 9 deletions src/cvpl_tools/im/partd_server.py
@@ -16,7 +16,6 @@

class SQLiteKVStore:
def __init__(self, db_path: str):
dprint(f'OPENING A CONNECTION AT PATH {db_path}')
self.is_exists = os.path.exists(db_path) # may not be accurate
self.conn = sqlite3.connect(db_path)
self.cursor = self.conn.cursor()
@@ -42,33 +41,25 @@ def init_db(self):
"""

def write_rows(self, tups):
dprint(f'WRITING TUPLES AT {tuple(tup[0] for tup in tups)}')
self.cursor.executemany(self.write_row_stmt, tups)

def write_row(self, tup):
dprint(f'WRITING TUPLE AT {tup[0]}')
self.cursor.execute(self.write_row_stmt, tup)

def read_rows(self, ids):
dprint(f'READING ROWS AT {ids}')
return [self.read_row(id) for id in ids]

def read_row(self, id):
dprint(f'READING ROW AT {id}')
self.cursor.execute(self.read_row_stmt, (id,))
result = self.cursor.fetchone()

if result is None:
dprint('READING ROW FAILED AT', id, self.read_row_stmt)
result = result[0]
return result

def commit(self):
dprint('COMMITING')
self.conn.commit()

def close(self):
dprint('CLOSING')
self.conn.close()


32 changes: 25 additions & 7 deletions src/cvpl_tools/ome_zarr/io.py
@@ -15,6 +15,7 @@
import numpy.typing as npt
from ome_zarr.io import parse_url
import urllib.parse
import numcodecs


# ----------------------------------Part 1: utilities---------------------------------------
@@ -100,7 +101,9 @@ def write_ome_zarr_image_direct(zarr_group: zarr.Group,
da_arr: da.Array | None = None,
lbl_arr: da.Array | None = None,
lbl_name: str | None = None,
MAX_LAYER: int = 3):
MAX_LAYER: int = 3,
storage_options: dict = None,
lbl_storage_options: dict = None):
"""Direct write of dask array to target ome zarr group (can not be a zip)
Args:
@@ -109,7 +112,15 @@ def write_ome_zarr_image_direct(zarr_group: zarr.Group,
lbl_arr: If provided, this is the array to write at zarr_group['labels'][lbl_name]
lbl_name: name of the label array subgroup
MAX_LAYER: The maximum layer of down sampling; starting at layer=0
storage_options: options for storing the image
lbl_storage_options: options for storing the labels
"""
if storage_options is None:
compressor = numcodecs.Blosc(cname='lz4', clevel=9, shuffle=numcodecs.Blosc.BITSHUFFLE)
storage_options = dict(
dimension_separator='/',
compressor=compressor
)
if da_arr is not None:
# assert the group is empty, since we are writing a new group
for mem in zarr_group:
@@ -122,7 +133,7 @@ def write_ome_zarr_image_direct(zarr_group: zarr.Group,
group=zarr_group,
scaler=scaler,
coordinate_transformations=_get_coord_transform_yx_for_write(da_arr.ndim, MAX_LAYER),
storage_options={'dimension_separator': '/'},
storage_options=storage_options,
axes=_get_axes_for_write(da_arr.ndim))

if lbl_arr is not None:
@@ -133,14 +144,15 @@ def write_ome_zarr_image_direct(zarr_group: zarr.Group,
# type axes. So we need to fall back to manual definition, avoid 'c' which defaults to a channel type
lbl_axes = [{'name': ch, 'type': 'space'} for ch in _get_axes_for_write(lbl_arr.ndim)]

import numcodecs
compressor = numcodecs.Blosc(cname='lz4', clevel=9, shuffle=numcodecs.Blosc.BITSHUFFLE)
if lbl_storage_options is None:
compressor = numcodecs.Blosc(cname='lz4', clevel=9, shuffle=numcodecs.Blosc.BITSHUFFLE)
lbl_storage_options = dict(compressor=compressor)
writer.write_labels(labels=lbl_arr,
group=zarr_group,
scaler=scaler,
name=lbl_name,
coordinate_transformations=_get_coord_transform_yx_for_write(lbl_arr.ndim, MAX_LAYER),
storage_options=dict(compressor=compressor),
storage_options=lbl_storage_options,
axes=lbl_axes)
# ome_zarr.writer.write_label_metadata(group=g,
# name=f'/labels/{lbl_name}',
@@ -154,7 +166,9 @@ def write_ome_zarr_image(ome_zarr_path: str,
lbl_name: str | None = None,
make_zip: bool | None = None,
MAX_LAYER: int = 0,
logging=False):
logging=False,
storage_options: dict = None,
lbl_storage_options: dict = None):
"""Write dask array as an ome zarr
    For writing to a zip file: since dask does not directly support writing to a zip, we instead create a temp ome zarr
@@ -169,6 +183,8 @@
make_zip: bool, if True the output is a zip; if False a folder; if None, then determine based on file suffix
MAX_LAYER: The maximum layer of down sampling; starting at layer=0
logging: If true, print message when job starts and ends
storage_options: options for storing the image
lbl_storage_options: options for storing the labels
"""
if tmp_path is not None:
os.makedirs(tmp_path, exist_ok=True)
@@ -187,7 +203,9 @@

store = ome_zarr.io.parse_url(folder_ome_zarr_path, mode='w').store
g = zarr.group(store)
write_ome_zarr_image_direct(g, da_arr, lbl_arr, lbl_name, MAX_LAYER=MAX_LAYER)
write_ome_zarr_image_direct(g, da_arr, lbl_arr, lbl_name, MAX_LAYER=MAX_LAYER,
storage_options=storage_options,
lbl_storage_options=lbl_storage_options)
if logging:
print('Folder is written.')
store.close()
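
For illustration, a caller could now override the write defaults instead of relying on the built-in Blosc/lz4 settings. A minimal sketch, assuming write_ome_zarr_image is importable from cvpl_tools.ome_zarr.io and accepts da_arr as a keyword argument (the array shape and compressor settings here are arbitrary):

import dask.array as da
import numcodecs
from cvpl_tools.ome_zarr.io import write_ome_zarr_image

arr = da.zeros((1, 64, 64), dtype='uint16', chunks=(1, 32, 32))  # placeholder image
write_ome_zarr_image(
    'output.ome.zarr',
    da_arr=arr,
    MAX_LAYER=2,
    storage_options=dict(
        dimension_separator='/',
        compressor=numcodecs.Blosc(cname='zstd', clevel=5, shuffle=numcodecs.Blosc.SHUFFLE),
    ),
)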

1 comment on commit 2dbb676

@Karl5766
Collaborator Author


Lots of ad-hoc solutions. Some of them are temporary patches that need to be fixed later.
