Merge pull request #23 from will-moore/shard_guess_validation
Don't check shard guess if output_script
joshmoore authored Aug 20, 2024
2 parents e472445 + 3e175be commit d066ef1
Showing 1 changed file with 134 additions and 13 deletions.
147 changes: 134 additions & 13 deletions src/ome2024_ngff_challenge/resave.py
@@ -2,6 +2,7 @@
from __future__ import annotations

import argparse
import itertools
import json
import logging
import math
@@ -29,17 +30,70 @@
#


class Batched:
"""
implementation of itertools.batched for pre-3.12 Python versions
from https://mathspp.com/blog/itertools-batched
"""

def __init__(self, iterable, n: int):
if n < 1:
msg = f"n must be at least one ({n})"
raise ValueError(msg)
self.iter = iter(iterable)
self.n = n

def __iter__(self):
return self

def __next__(self):
batch = tuple(itertools.islice(self.iter, self.n))
if not batch:
raise StopIteration()
return batch
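
For illustration (not part of the committed diff), Batched yields fixed-size tuples until the input is exhausted, with a shorter final batch:

    >>> list(Batched("ABCDEFG", 3))
    [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]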


class SafeEncoder(json.JSONEncoder):
# Handle any TypeErrors so we are safe to use this for logging
# E.g. dtype obj is not JSON serializable
def default(self, o):
try:
return super().default(o)
except TypeError:
return str(o)
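
For example (illustrative, and assuming numpy is importable), a value the base encoder rejects, such as a dtype object, is logged as its string form instead of raising:

    >>> import numpy as np
    >>> json.dumps({"dtype": np.dtype("uint8")}, cls=SafeEncoder)
    '{"dtype": "uint8"}'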


def guess_shards(shape: list, chunks: list):
"""
Method to calculate best shard sizes. These values can be written to
a file for the current dataset by using:
./resave.py input.zarr output.json --output-write-details
"""
-    # TODO: hard-coded to return the full size
-    assert chunks is not None  # fixes unused parameter
-    return shape
    # TODO: hard-coded to return the full size unless too large
    if math.prod(shape) < 100_000_000:
        return shape
    raise ValueError(f"no shard guess: shape={shape}, chunks={chunks}")
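
A worked example of the new threshold (values are illustrative): shapes under 100,000,000 elements are returned whole as the shard, while larger ones now raise rather than guess:

    >>> guess_shards([1, 1, 1024, 1024], [1, 1, 256, 256])    # ~1.0e6 elements
    [1, 1, 1024, 1024]
    >>> guess_shards([1, 512, 1024, 1024], [1, 1, 256, 256])  # ~5.4e8 elements: raises
    ValueError: no shard guess: shape=[1, 512, 1024, 1024], chunks=[1, 1, 256, 256]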


def chunk_iter(shape: list, chunks: list):
"""
Returns a series of tuples, each containing chunk slice
E.g. for 2D shape/chunks: ((slice(0, 512, 1), slice(0, 512, 1)), (slice(0, 512, 1), slice(512, 1024, 1))...)
Thanks to Davis Bennett.
"""
assert len(shape) == len(chunks)
chunk_iters = []
for chunk_size, dim_size in zip(chunks, shape):
chunk_tuple = tuple(
slice(
c_index * chunk_size,
min(dim_size, c_index * chunk_size + chunk_size),
1,
)
for c_index in range(-(-dim_size // chunk_size))
)
chunk_iters.append(chunk_tuple)
return tuple(itertools.product(*chunk_iters))
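
A quick sketch of the output (illustrative), showing how min() clamps the final slice of a dimension that is not an exact multiple of the chunk size:

    >>> chunk_iter((700,), (512,))
    ((slice(0, 512, 1),), (slice(512, 700, 1),))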


@@ -53,7 +107,7 @@ def csv_int(vstr, sep=",") -> list:
values.append(v)
except ValueError as ve:
raise argparse.ArgumentError(
message="Invalid value %s, values must be a number" % v0
message=f"Invalid value {v0}, values must be a number"
) from ve
return values

@@ -237,7 +291,9 @@ def check_or_delete_path(self):
else:
shutil.rmtree(self.path)
else:
raise Exception(f"{self.path} exists. Exiting")
raise Exception(
f"{self.path} exists. Use --output-overwrite to overwrite"
)

def open_group(self):
# Needs zarr_format=2 or we get ValueError("store mode does not support writing")
@@ -291,6 +347,7 @@ def convert_array(
dimension_names: list,
chunks: list,
shards: list,
threads: int,
):
read = input_config.ts_read()

@@ -340,13 +397,44 @@
write_config["create"] = True
write_config["delete_existing"] = output_config.overwrite

LOGGER.log(
5,
f"""input_config:
{json.dumps(input_config.ts_config, indent=4)}
""",
)
LOGGER.log(
5,
f"""write_config:
{json.dumps(write_config, indent=4, cls=SafeEncoder)}
""",
)

verify_config = base_config.copy()

write = ts.open(write_config).result()

before = TSMetrics(input_config.ts_config, write_config)
-    future = write.write(read)
-    future.result()

# read & write a chunk (or shard) at a time:
blocks = shards if shards is not None else chunks
for idx, batch in enumerate(Batched(chunk_iter(read.shape, blocks), threads)):
start = time.time()
with ts.Transaction() as txn:
LOGGER.log(5, f"batch {idx:03d}: scheduling transaction size={len(batch)}")
for slice_tuple in batch:
write.with_transaction(txn)[slice_tuple] = read[slice_tuple]
LOGGER.log(
5, f"batch {idx:03d}: {slice_tuple} scheduled in transaction"
)
LOGGER.log(5, f"batch {idx:03d}: waiting on transaction size={len(batch)}")
stop = time.time()
elapsed = stop - start
avg = float(elapsed) / len(batch)
LOGGER.debug(
f"batch {idx:03d}: completed transaction size={len(batch)} in {stop-start:0.2f}s (avg={avg:0.2f})"
)

after = TSMetrics(input_config.ts_config, write_config, before)

LOGGER.info(f"""Re-encode (tensorstore) {input_config} to {output_config}
@@ -374,6 +462,7 @@ def convert_image(
output_read_details: str | None,
output_write_details: bool,
output_script: bool,
threads: int,
):
dimension_names = None
# top-level version...
@@ -417,13 +506,21 @@
with output_config.path.open(mode="w") as o:
json.dump(details, o)
else:
-            if output_chunks:
-                ds_chunks = output_chunks
-                ds_shards = output_shards
-            elif output_read_details:
            if output_read_details:
                # read row by row and overwrite
                ds_chunks = details[idx]["chunks"]
                ds_shards = details[idx]["shards"]
            else:
                if output_chunks:
                    ds_chunks = output_chunks
                if output_shards:
                    ds_shards = output_shards
                elif not output_script and math.prod(ds_shards) > 100_000_000:
                    # if we're going to convert, and we guessed the shards,
                    # let's validate the guess...
                    raise ValueError(
                        f"no shard guess: shape={ds_shape}, chunks={ds_chunks}"
                    )

if output_script:
chunk_txt = ",".join(map(str, ds_chunks))
@@ -440,6 +537,7 @@
dimension_names,
ds_chunks,
ds_shards,
threads,
)


@@ -549,6 +647,7 @@ def main(ns: argparse.Namespace, rocrate: ROCrateWriter | None = None) -> int:
ns.output_read_details,
ns.output_write_details,
ns.output_script,
ns.output_threads,
)
converted += 1

@@ -602,6 +701,7 @@ def main(ns: argparse.Namespace, rocrate: ROCrateWriter | None = None) -> int:
ns.output_read_details,
ns.output_write_details,
ns.output_script,
ns.output_threads,
)
converted += 1
# Note: plates can *also* contain this metadata
@@ -644,6 +744,7 @@ def main(ns: argparse.Namespace, rocrate: ROCrateWriter | None = None) -> int:
ns.output_read_details,
ns.output_write_details,
ns.output_script,
ns.output_threads,
)
converted += 1
else:
@@ -669,12 +770,21 @@ def cli(args=sys.argv[1:]):
parser.add_argument("--output-region", default="us-east-1")
parser.add_argument("--output-overwrite", action="store_true")
parser.add_argument("--output-script", action="store_true")
parser.add_argument(
"--output-threads",
type=int,
default=16,
help="number of simultaneous write threads",
)
parser.add_argument("--rocrate-name", type=str)
parser.add_argument("--rocrate-description", type=str)
parser.add_argument("--rocrate-license", type=str)
parser.add_argument("--rocrate-organism", type=str)
parser.add_argument("--rocrate-modality", type=str)
parser.add_argument("--rocrate-skip", action="store_true")
parser.add_argument(
"--log", default="warn", help="'error', 'warn', 'info', 'debug' or 'trace'"
)
group_ex = parser.add_mutually_exclusive_group()
group_ex.add_argument(
"--output-write-details",
@@ -698,7 +808,18 @@
parser.add_argument("output_path", type=Path)
ns = parser.parse_args(args)

-    logging.basicConfig()
# configure logging
if ns.log.upper() == "TRACE":
numeric_level = 5
else:
numeric_level = getattr(logging, ns.log.upper(), None)
if not isinstance(numeric_level, int):
raise ValueError(f"Invalid log level: {ns.log}. Use 'info' or 'debug'")
logging.basicConfig(
level=numeric_level,
format="%(asctime)s.%(msecs)03d %(levelname)-8s %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
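
For instance (a hypothetical invocation, mirroring the docstring example earlier in the file), the new flags combine as:

    ./resave.py --log trace --output-threads 8 input.zarr output.zarr

where "trace" maps to the custom level 5 used by the LOGGER.log(5, ...) calls above.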

rocrate = None
if not ns.rocrate_skip:
