Skip to content

Commit

Permalink
fix per_band tests to use larger source images and not use tracemallo…
Browse files Browse the repository at this point in the history
…c.reset_peak()
  • Loading branch information
dugalh committed Jun 25, 2024
1 parent e83958f commit 1c06df6
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 29 deletions.
53 changes: 39 additions & 14 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,34 +533,59 @@ def test_frame_dem_interp_error(frame_legacy_ngi_cli_str: str, tmp_path: Path, r
assert '--dem-interp' in result.stdout and 'invalid' in result.stdout.lower()


def test_frame_per_band(frame_legacy_ngi_cli_str: str, tmp_path: Path, runner: CliRunner):
def test_frame_per_band(
ngi_image_file: Path,
ngi_dem_file: Path,
ngi_legacy_config_file: Path,
ngi_legacy_csv_file: Path,
tmp_path: Path,
runner: CliRunner,
):
"""Test ``oty frame --per-band`` by comparing memory usage between ``--per-band`` and ``--no-
per-band``.
"""
# make a temporary 4 band float64 source image from ngi_image_file (for --per-band to make
# a measurable memory difference, the source image needs to be relatively large, have many
# bands and/or a 'big' dtype.)
src_file = tmp_path.joinpath(ngi_image_file.name)
with rio.open(ngi_image_file, 'r') as ngi_im:
array = ngi_im.read(out_dtype='float32')
array = np.stack((*array, array[0]), axis=0)
profile = ngi_im.profile
profile.update(
count=array.shape[0], dtype=array.dtype, compress='deflate', photometric='minisblack'
)
with rio.open(src_file, 'w', **profile) as src_im:
src_im.write(array)

# compare memory usage between --no-per-band and --per-band
cli_str = (
f'frame --dem {ngi_dem_file} --int-param {ngi_legacy_config_file} '
f'--ext-param {ngi_legacy_csv_file} --res 30 --compress deflate {src_file}'
)
mem_peaks = []
tracemalloc.start()
try:
for per_band in ['per-band', 'no-per-band']:
# create ortho
tracemalloc.start()
for per_band in ['no-per-band', 'per-band']:
out_dir = tmp_path.joinpath(per_band)
out_dir.mkdir()
cli_str = (
frame_legacy_ngi_cli_str
+ f' --out-dir {out_dir} --res 30 --compress deflate --{per_band}'
)
cli_str = cli_str + f' --out-dir {out_dir} --{per_band}'

# find peak memory used by the command
mem_start = tracemalloc.get_traced_memory()
result = runner.invoke(cli, cli_str.split())
mem_end = tracemalloc.get_traced_memory()
mem_peaks.append(mem_end[1] - mem_start[0])
tracemalloc.clear_traces() # clears the peak

assert result.exit_code == 0, result.stdout
ortho_files = [*out_dir.glob('*_ORTHO.tif')]
assert len(ortho_files) == 1

_, mem_peak = tracemalloc.get_traced_memory()
tracemalloc.reset_peak()
mem_peaks.append(mem_peak)

assert mem_peaks[1] > mem_peaks[0]
finally:
tracemalloc.stop()

assert mem_peaks[1] < mem_peaks[0]


def test_frame_full_remap(
odm_image_file: Path,
Expand Down
45 changes: 30 additions & 15 deletions tests/test_ortho.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,39 +730,54 @@ def test_process_dem_interp(rgb_pinhole_utm34n_ortho: Ortho, dem_interp: Interp,
assert cc[0, 1] != 1.0


def test_process_per_band(rgb_pinhole_utm34n_ortho: Ortho, tmp_path: Path):
def test_process_per_band(
ms_float_src_file: Path,
float_utm34n_dem_file: Path,
frame_args: dict,
utm34n_crs: str,
tmp_path: Path,
):
"""Test ortho equivalence for ``per_band=True/False`` and that ``per_band=True`` uses less
memory than ``per_band=False``."""
resolution = (5, 5)
# Note: Allocated memory depends on thread timing and is noisy. For per_band to make
# measurable memory differences, the source image needs to be relatively large, have many bands
# and/or a 'big' dtype. Also, bear in mind that tracemalloc does not track GDAL allocations.

# create a camera for ms_float_src_file
cam_args = dict(**frame_args)
with rio.open(ms_float_src_file) as src_im:
cam_args.update(im_size=src_im.shape[::-1])
camera = PinholeCamera(**cam_args)

# create orthos with per_band=True/False, tracking memory usage
resolution = (5, 5)
ortho_files = [tmp_path.joinpath('ref_ortho.tif'), tmp_path.joinpath('test_ortho.tif')]
per_bands = [True, False]
mem_peaks = []

tracemalloc.start()
peak_mems = []
try:
tracemalloc.start()
for ortho_file, per_band in zip(ortho_files, per_bands):
rgb_pinhole_utm34n_ortho.process(
ortho_file, resolution, per_band=per_band, compress=Compress.deflate
)
_, mem_peak = tracemalloc.get_traced_memory()
tracemalloc.reset_peak()
mem_peaks.append(mem_peak)
start_mem = tracemalloc.get_traced_memory()
ortho = Ortho(ms_float_src_file, float_utm34n_dem_file, camera, utm34n_crs)
ortho.process(ortho_file, resolution, per_band=per_band, compress=Compress.deflate)
end_mem = tracemalloc.get_traced_memory()
peak_mems.append(end_mem[1] - start_mem[0])
tracemalloc.clear_traces() # clears the peak
del ortho
finally:
tracemalloc.stop()

# compare memory usage
assert mem_peaks[1] > mem_peaks[0]
assert peak_mems[1] > peak_mems[0]

# compare ref and test orthos
# compare pre_band=True/False orthos
assert ortho_files[0].exists() and ortho_files[1].exists()
with rio.open(ortho_files[0], 'r') as ref_im, rio.open(ortho_files[1], 'r') as test_im:
ref_array = ref_im.read()
test_array = test_im.read()

assert test_array.shape[0] == 3
assert test_array.shape == ref_array.shape
assert np.all(test_array == ref_array)
assert np.all(nan_equals(test_array, ref_array))


@pytest.mark.parametrize(
Expand Down

0 comments on commit 1c06df6

Please sign in to comment.