diff --git a/src/ome2024_ngff_challenge/resave.py b/src/ome2024_ngff_challenge/resave.py
index 4ef116f..e4a8fe1 100755
--- a/src/ome2024_ngff_challenge/resave.py
+++ b/src/ome2024_ngff_challenge/resave.py
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 import argparse
+import itertools
 import json
 import logging
 import math
@@ -29,6 +30,39 @@
 #
 
 
+class Batched:
+    """
+    Implementation of itertools.batched for pre-3.12 Python versions,
+    from https://mathspp.com/blog/itertools-batched
+    """
+
+    def __init__(self, iterable, n: int):
+        if n < 1:
+            msg = f"n must be at least one ({n})"
+            raise ValueError(msg)
+        self.iter = iter(iterable)
+        self.n = n
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        batch = tuple(itertools.islice(self.iter, self.n))
+        if not batch:
+            raise StopIteration()
+        return batch
+
+
+class SafeEncoder(json.JSONEncoder):
+    # Handle any TypeErrors so we are safe to use this for logging
+    # E.g. dtype obj is not JSON serializable
+    def default(self, o):
+        try:
+            return super().default(o)
+        except TypeError:
+            return str(o)
+
+
 def guess_shards(shape: list, chunks: list):
     """
     Method to calculate best shard sizes. These values can be written to
@@ -36,10 +70,30 @@ def guess_shards(shape: list, chunks: list):
     ./resave.py input.zarr output.json --output-write-details
     """
-    # TODO: hard-coded to return the full size unless too large
-    if math.prod(shape) < 100_000_000:
-        return shape
-    raise ValueError(f"no shard guess: shape={shape}, chunks={chunks}")
+    # TODO: hard-coded to return the full size
+    assert chunks is not None  # marks the parameter as used for linters
+    return shape
+
+
+def chunk_iter(shape: list, chunks: list):
+    """
+    Returns a series of tuples, each containing the slices for one chunk
+    E.g. for 2D shape/chunks: ((slice(0, 512, 1), slice(0, 512, 1)), (slice(0, 512, 1), slice(512, 1024, 1))...)
+    Thanks to Davis Bennett.
+    """
+    assert len(shape) == len(chunks)
+    chunk_iters = []
+    for chunk_size, dim_size in zip(chunks, shape):
+        chunk_tuple = tuple(
+            slice(
+                c_index * chunk_size,
+                min(dim_size, c_index * chunk_size + chunk_size),
+                1,
+            )
+            for c_index in range(-(-dim_size // chunk_size))
+        )
+        chunk_iters.append(chunk_tuple)
+    return tuple(itertools.product(*chunk_iters))
 
 
 def csv_int(vstr, sep=",") -> list:
@@ -53,7 +107,7 @@ def csv_int(vstr, sep=",") -> list:
             values.append(v)
         except ValueError as ve:
             raise argparse.ArgumentError(
-                message="Invalid value %s, values must be a number" % v0
+                message=f"Invalid value {v0}, values must be a number"
             ) from ve
     return values
 
@@ -237,7 +291,9 @@ def check_or_delete_path(self):
                 else:
                     shutil.rmtree(self.path)
             else:
-                raise Exception(f"{self.path} exists. Exiting")
+                raise Exception(
+                    f"{self.path} exists. Use --output-overwrite to overwrite"
+                )
 
     def open_group(self):
         # Needs zarr_format=2 or we get ValueError("store mode does not support writing")
@@ -291,6 +347,7 @@ def convert_array(
     dimension_names: list,
     chunks: list,
     shards: list,
+    threads: int,
 ):
     read = input_config.ts_read()
 
@@ -340,13 +397,44 @@ def convert_array(
     write_config["create"] = True
     write_config["delete_existing"] = output_config.overwrite
 
+    LOGGER.log(
+        5,
+        f"""input_config:
+{json.dumps(input_config.ts_config, indent=4)}
+    """,
+    )
+    LOGGER.log(
+        5,
+        f"""write_config:
+{json.dumps(write_config, indent=4, cls=SafeEncoder)}
+    """,
+    )
+
     verify_config = base_config.copy()
 
     write = ts.open(write_config).result()
 
     before = TSMetrics(input_config.ts_config, write_config)
-    future = write.write(read)
-    future.result()
+
+    # read & write a chunk (or shard) at a time:
+    blocks = shards if shards is not None else chunks
+    for idx, batch in enumerate(Batched(chunk_iter(read.shape, blocks), threads)):
+        start = time.time()
+        with ts.Transaction() as txn:
+            LOGGER.log(5, f"batch {idx:03d}: scheduling transaction size={len(batch)}")
+            for slice_tuple in batch:
+                write.with_transaction(txn)[slice_tuple] = read[slice_tuple]
+                LOGGER.log(
+                    5, f"batch {idx:03d}: {slice_tuple} scheduled in transaction"
+                )
+            LOGGER.log(5, f"batch {idx:03d}: waiting on transaction size={len(batch)}")
+        stop = time.time()
+        elapsed = stop - start
+        avg = float(elapsed) / len(batch)
+        LOGGER.debug(
+            f"batch {idx:03d}: completed transaction size={len(batch)} in {elapsed:0.2f}s (avg={avg:0.2f})"
+        )
+
     after = TSMetrics(input_config.ts_config, write_config, before)
 
     LOGGER.info(f"""Re-encode (tensorstore) {input_config} to {output_config}
@@ -374,6 +462,7 @@ def convert_image(
     output_read_details: str | None,
     output_write_details: bool,
     output_script: bool,
+    threads: int,
 ):
     dimension_names = None
     # top-level version...
@@ -417,13 +506,21 @@ def convert_image(
             with output_config.path.open(mode="w") as o:
                 json.dump(details, o)
         else:
-            if output_chunks:
-                ds_chunks = output_chunks
-                ds_shards = output_shards
-            elif output_read_details:
+            if output_read_details:
                 # read row by row and overwrite
                 ds_chunks = details[idx]["chunks"]
                 ds_shards = details[idx]["shards"]
+            else:
+                if output_chunks:
+                    ds_chunks = output_chunks
+                if output_shards:
+                    ds_shards = output_shards
+                elif not output_script and math.prod(ds_shards) > 100_000_000:
+                    # if we're going to convert, and we guessed the shards,
+                    # let's validate the guess...
+                    raise ValueError(
+                        f"no shard guess: shape={ds_shape}, chunks={ds_chunks}"
+                    )
 
         if output_script:
             chunk_txt = ",".join(map(str, ds_chunks))
@@ -440,6 +537,7 @@
                 dimension_names,
                 ds_chunks,
                 ds_shards,
+                threads,
             )
 
 
@@ -549,6 +647,7 @@ def main(ns: argparse.Namespace, rocrate: ROCrateWriter | None = None) -> int:
                 ns.output_read_details,
                 ns.output_write_details,
                 ns.output_script,
+                ns.output_threads,
             )
             converted += 1
 
@@ -602,6 +701,7 @@ def main(ns: argparse.Namespace, rocrate: ROCrateWriter | None = None) -> int:
                 ns.output_read_details,
                 ns.output_write_details,
                 ns.output_script,
+                ns.output_threads,
             )
             converted += 1
             # Note: plates can *also* contain this metadata
@@ -644,6 +744,7 @@ def main(ns: argparse.Namespace, rocrate: ROCrateWriter | None = None) -> int:
             ns.output_read_details,
             ns.output_write_details,
             ns.output_script,
+            ns.output_threads,
         )
         converted += 1
     else:
@@ -669,12 +770,21 @@ def cli(args=sys.argv[1:]):
     parser.add_argument("--output-region", default="us-east-1")
     parser.add_argument("--output-overwrite", action="store_true")
     parser.add_argument("--output-script", action="store_true")
+    parser.add_argument(
+        "--output-threads",
+        type=int,
+        default=16,
+        help="number of simultaneous write threads",
+    )
     parser.add_argument("--rocrate-name", type=str)
     parser.add_argument("--rocrate-description", type=str)
    parser.add_argument("--rocrate-license", type=str)
     parser.add_argument("--rocrate-organism", type=str)
     parser.add_argument("--rocrate-modality", type=str)
     parser.add_argument("--rocrate-skip", action="store_true")
+    parser.add_argument(
+        "--log", default="warn", help="'error', 'warn', 'info', 'debug' or 'trace'"
+    )
     group_ex = parser.add_mutually_exclusive_group()
     group_ex.add_argument(
         "--output-write-details",
@@ -698,7 +808,18 @@
     parser.add_argument("output_path", type=Path)
     ns = parser.parse_args(args)
 
-    logging.basicConfig()
+    # configure logging
+    if ns.log.upper() == "TRACE":
+        numeric_level = 5
+    else:
+        numeric_level = getattr(logging, ns.log.upper(), None)
+        if not isinstance(numeric_level, int):
+            raise ValueError(f"Invalid log level: {ns.log}. Use 'error', 'warn', 'info', 'debug' or 'trace'")
+    logging.basicConfig(
+        level=numeric_level,
+        format="%(asctime)s.%(msecs)03d %(levelname)-8s %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+    )
 
     rocrate = None
     if not ns.rocrate_skip:
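
A few standalone sketches follow; they are illustrations for reviewers, not part of the patch. First, the chunking helpers: chunk_iter() turns a shape plus chunk sizes into one tuple of slices per chunk, and Batched() groups those tuples. The logic is restated here in condensed form so it runs without the package; the 4x6 shape and 4x3 chunks are made up for illustration.

import itertools


def chunk_iter(shape, chunks):
    # one tuple of per-dimension slices per chunk; -(-d // c) is ceil(d / c),
    # so partial chunks at the array edge are included
    per_dim = [
        tuple(slice(i * c, min(d, i * c + c), 1) for i in range(-(-d // c)))
        for c, d in zip(chunks, shape)
    ]
    return tuple(itertools.product(*per_dim))


slices = chunk_iter([4, 6], [4, 3])  # a 4x6 array stored as 4x3 chunks
print(slices[0])  # (slice(0, 4, 1), slice(0, 3, 1))
print(slices[1])  # (slice(0, 4, 1), slice(3, 6, 1))

# Batched(slices, 2) in the patch yields the same groups as this loop,
# matching itertools.batched on Python 3.12+:
it = iter(slices)
while batch := tuple(itertools.islice(it, 2)):
    print(batch)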
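
Second, the new write loop in convert_array(): instead of one bulk write.write(read), each batch of chunk-aligned writes is scheduled inside a tensorstore transaction and committed together. A hedged sketch of that pattern against a small in-memory zarr array; the spec, shape, and hard-coded slices are illustrative, not the challenge's real configuration.

import numpy as np
import tensorstore as ts

data = np.arange(16, dtype=np.uint8).reshape(4, 4)
dest = ts.open(
    {
        "driver": "zarr",
        "kvstore": {"driver": "memory"},
        "metadata": {"shape": [4, 4], "chunks": [2, 2], "dtype": "|u1"},
    },
    create=True,
).result()

# four 2x2 chunk-aligned slice tuples, written two per transaction
slices = [
    (slice(0, 2), slice(0, 2)),
    (slice(0, 2), slice(2, 4)),
    (slice(2, 4), slice(0, 2)),
    (slice(2, 4), slice(2, 4)),
]
for batch in (slices[:2], slices[2:]):
    # leaving the with-block commits everything scheduled in the transaction
    with ts.Transaction() as txn:
        for sl in batch:
            dest.with_transaction(txn)[sl] = data[sl]

print(np.array_equal(dest.read().result(), data))  # True

In the patch the batch size comes from --output-threads, so each transaction carries at most that many chunk (or shard) copies.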
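
Last, the --log plumbing: Python's logging module has no built-in TRACE level, so the patch maps "trace" to the numeric level 5, below DEBUG (10); that is why the LOGGER.log(5, ...) calls above appear only at --log trace. The addLevelName() call in this sketch is an optional extra the patch does not make; it only prettifies the level name in the output.

import logging

logging.addLevelName(5, "TRACE")  # optional; otherwise records print "Level 5"
logging.basicConfig(
    level=5,  # what the patch configures for --log trace
    format="%(asctime)s.%(msecs)03d %(levelname)-8s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

LOGGER = logging.getLogger("resave")
LOGGER.log(5, "shown only at --log trace")
LOGGER.debug("shown at --log debug and trace")
LOGGER.warning("shown at the default --log warn")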