diff --git a/README.md b/README.md index d6834656..6f015807 100644 --- a/README.md +++ b/README.md @@ -159,9 +159,9 @@ write_geotiff(data, 'out.tif', gpu=True) # force GPU compress write_vrt('mosaic.vrt', ['tile1.tif', 'tile2.tif']) # generate VRT ``` -**Compression codecs:** Deflate, LZW (Numba JIT), ZSTD, PackBits, JPEG (Pillow), uncompressed +**Compression codecs:** Deflate, LZW (Numba JIT), ZSTD, PackBits, JPEG (Pillow), JPEG 2000 (glymur), uncompressed -**GPU codecs:** Deflate and ZSTD via nvCOMP batch API; LZW via Numba CUDA kernels +**GPU codecs:** Deflate and ZSTD via nvCOMP batch API; JPEG 2000 via nvJPEG2000; LZW via Numba CUDA kernels **Features:** - Tiled, stripped, BigTIFF, multi-band (RGB/RGBA), sub-byte (1/2/4/12-bit) @@ -540,7 +540,7 @@ Check out the user guide [here](/examples/user_guide/). - **Zero GDAL installation hassle.** `pip install xarray-spatial` gets you everything needed to read and write GeoTIFFs, COGs, and VRT files. - **Pure Python, fully extensible.** All codec, header parsing, and metadata code is readable Python/Numba, not wrapped C/C++. -- **GPU-accelerated reads.** With optional nvCOMP, compressed tiles decompress directly on the GPU via CUDA -- something GDAL cannot do. +- **GPU-accelerated reads.** With optional nvCOMP and nvJPEG2000, compressed tiles decompress directly on the GPU via CUDA -- something GDAL cannot do. The native reader is pixel-exact against rasterio/GDAL across Landsat 8, Copernicus DEM, USGS 1-arc-second, and USGS 1-meter DEMs. For uncompressed files it reads 5-7x faster than rioxarray; for compressed COGs it is comparable or faster with GPU acceleration. 
diff --git a/examples/user_guide/35_JPEG2000_Compression.ipynb b/examples/user_guide/35_JPEG2000_Compression.ipynb new file mode 100644 index 00000000..0e1e4207 --- /dev/null +++ b/examples/user_guide/35_JPEG2000_Compression.ipynb @@ -0,0 +1,101 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "yox7s6qx13e", + "source": "# JPEG 2000 compression for GeoTIFFs\n\nThe geotiff package supports JPEG 2000 (J2K) as a compression codec for both reading and writing. This is useful for satellite imagery workflows where J2K is common (Sentinel-2, Landsat, etc.).\n\nTwo acceleration tiers are available:\n- **CPU** via `glymur` (pip install glymur) -- works anywhere OpenJPEG is installed\n- **GPU** via NVIDIA's nvJPEG2000 library -- same optional pattern as nvCOMP for deflate/ZSTD\n\nThis notebook demonstrates write/read roundtrips with JPEG 2000 compression.", + "metadata": {} + }, + { + "cell_type": "code", + "id": "kamu534xsm", + "source": "import numpy as np\nimport xarray as xr\nimport matplotlib.pyplot as plt\nimport tempfile\nimport os\n\nfrom xrspatial.geotiff import read_geotiff, write_geotiff", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "w7tlml1cyqj", + "source": "## Generate synthetic elevation data\n\nWe'll create a small terrain-like raster to use as test data.", + "metadata": {} + }, + { + "cell_type": "code", + "id": "9fzhnpcn4xq", + "source": "# Create a 256x256 synthetic terrain (uint16, typical for satellite imagery)\nrng = np.random.RandomState(42)\nyy, xx = np.meshgrid(np.linspace(-2, 2, 256), np.linspace(-2, 2, 256), indexing='ij')\nterrain = np.exp(-(xx**2 + yy**2)) * 10000 + rng.normal(0, 100, (256, 256))\nterrain = np.clip(terrain, 0, 65535).astype(np.uint16)\n\nda = xr.DataArray(\n terrain,\n dims=['y', 'x'],\n coords={\n 'y': np.linspace(45.0, 44.0, 256),\n 'x': np.linspace(-120.0, -119.0, 256),\n },\n attrs={'crs': 4326},\n name='elevation',\n)\n\nfig, ax = plt.subplots(figsize=(6, 
5))\nda.plot(ax=ax, cmap='terrain')\nax.set_title('Synthetic elevation (uint16)')\nplt.tight_layout()\nplt.show()", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "8tsuyr3jbay", + "source": "## Write with JPEG 2000 (lossless)\n\nPass `compression='jpeg2000'` to `write_geotiff`. The default is lossless encoding.", + "metadata": {} + }, + { + "cell_type": "code", + "id": "ystjp6v30d", + "source": "tmpdir = tempfile.mkdtemp(prefix='j2k_demo_')\n\n# Write with JPEG 2000 compression\nj2k_path = os.path.join(tmpdir, 'elevation_j2k.tif')\nwrite_geotiff(da, j2k_path, compression='jpeg2000')\n\n# Compare file sizes with deflate\ndeflate_path = os.path.join(tmpdir, 'elevation_deflate.tif')\nwrite_geotiff(da, deflate_path, compression='deflate')\n\nnone_path = os.path.join(tmpdir, 'elevation_none.tif')\nwrite_geotiff(da, none_path, compression='none')\n\nj2k_size = os.path.getsize(j2k_path)\ndeflate_size = os.path.getsize(deflate_path)\nnone_size = os.path.getsize(none_path)\n\nprint(f\"Uncompressed: {none_size:>8,} bytes\")\nprint(f\"Deflate: {deflate_size:>8,} bytes ({deflate_size/none_size:.1%} of original)\")\nprint(f\"JPEG 2000: {j2k_size:>8,} bytes ({j2k_size/none_size:.1%} of original)\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "89y9zun97nb", + "source": "## Read it back and verify lossless roundtrip\n\n`read_geotiff` auto-detects the compression from the TIFF header. 
No special arguments needed.", + "metadata": {} + }, + { + "cell_type": "code", + "id": "8vf9ljxkx03", + "source": "# Read back and check lossless roundtrip\nda_read = read_geotiff(j2k_path)\n\nprint(f\"Shape: {da_read.shape}\")\nprint(f\"Dtype: {da_read.dtype}\")\nprint(f\"CRS: {da_read.attrs.get('crs')}\")\nprint(f\"Exact match: {np.array_equal(da_read.values, terrain)}\")\n\nfig, axes = plt.subplots(1, 2, figsize=(12, 5))\nda.plot(ax=axes[0], cmap='terrain')\naxes[0].set_title('Original')\nda_read.plot(ax=axes[1], cmap='terrain')\naxes[1].set_title('After JPEG 2000 roundtrip')\nplt.tight_layout()\nplt.show()", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "gcj96utnd3u", + "source": "## Multi-band example (RGB)\n\nJPEG 2000 also handles multi-band imagery, which is the common case for satellite data.", + "metadata": {} + }, + { + "cell_type": "code", + "id": "mgv9xhsrcen", + "source": "# Create a 3-band uint8 image\nrgb = np.zeros((128, 128, 3), dtype=np.uint8)\nrgb[:, :, 0] = np.linspace(0, 255, 128).astype(np.uint8)[None, :] # red gradient\nrgb[:, :, 1] = np.linspace(0, 255, 128).astype(np.uint8)[:, None] # green gradient\nrgb[:, :, 2] = 128 # constant blue\n\nda_rgb = xr.DataArray(\n rgb, dims=['y', 'x', 'band'],\n coords={'y': np.arange(128), 'x': np.arange(128), 'band': [0, 1, 2]},\n)\n\nrgb_path = os.path.join(tmpdir, 'rgb_j2k.tif')\nwrite_geotiff(da_rgb, rgb_path, compression='jpeg2000')\n\nda_rgb_read = read_geotiff(rgb_path)\nprint(f\"RGB shape: {da_rgb_read.shape}, dtype: {da_rgb_read.dtype}\")\nprint(f\"Exact match: {np.array_equal(da_rgb_read.values, rgb)}\")\n\nfig, axes = plt.subplots(1, 2, figsize=(10, 4))\naxes[0].imshow(rgb)\naxes[0].set_title('Original RGB')\naxes[1].imshow(da_rgb_read.values)\naxes[1].set_title('After J2K roundtrip')\nplt.tight_layout()\nplt.show()", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "zzga5hc3a99", + 
"source": "## GPU acceleration\n\nOn systems with nvJPEG2000 installed (CUDA toolkit, RAPIDS environments), pass `gpu=True` to use GPU-accelerated J2K encode/decode. The API is the same -- it falls back to CPU automatically if the library isn't found.\n\n```python\n# GPU write (nvJPEG2000 if available, else CPU fallback)\nwrite_geotiff(cupy_data, \"output.tif\", compression=\"jpeg2000\", gpu=True)\n\n# GPU read (nvJPEG2000 decode if available)\nda = read_geotiff(\"satellite.tif\", gpu=True)\n```", + "metadata": {} + }, + { + "cell_type": "code", + "id": "x74nrht8kx", + "source": "# Cleanup temp files\nimport shutil\nshutil.rmtree(tmpdir, ignore_errors=True)", + "metadata": {}, + "execution_count": null, + "outputs": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/xrspatial/geotiff/_compression.py b/xrspatial/geotiff/_compression.py index c78c6ebc..e7213c9a 100644 --- a/xrspatial/geotiff/_compression.py +++ b/xrspatial/geotiff/_compression.py @@ -727,6 +727,69 @@ def zstd_compress(data: bytes, level: int = 3) -> bytes: return _zstd.ZstdCompressor(level=level).compress(data) +# -- JPEG 2000 codec (via glymur) -------------------------------------------- + +JPEG2000_AVAILABLE = False +try: + import glymur as _glymur + JPEG2000_AVAILABLE = True +except ImportError: + _glymur = None + + +def jpeg2000_decompress(data: bytes, width: int = 0, height: int = 0, + samples: int = 1) -> bytes: + """Decompress a JPEG 2000 codestream. Requires ``glymur``.""" + if not JPEG2000_AVAILABLE: + raise ImportError( + "glymur is required to read JPEG 2000-compressed TIFFs. 
" + "Install it with: pip install glymur") + import tempfile + import os + # glymur reads from files, so write the codestream to a temp file + fd, tmp = tempfile.mkstemp(suffix='.j2k') + try: + os.write(fd, data) + os.close(fd) + jp2 = _glymur.Jp2k(tmp) + arr = jp2[:] + return arr.tobytes() + finally: + os.unlink(tmp) + + +def jpeg2000_compress(data: bytes, width: int, height: int, + samples: int = 1, dtype: np.dtype = np.dtype('uint8'), + lossless: bool = True) -> bytes: + """Compress raw pixel data as JPEG 2000 codestream. Requires ``glymur``.""" + if not JPEG2000_AVAILABLE: + raise ImportError( + "glymur is required to write JPEG 2000-compressed TIFFs. " + "Install it with: pip install glymur") + import math + import tempfile + import os + if samples == 1: + arr = np.frombuffer(data, dtype=dtype).reshape(height, width) + else: + arr = np.frombuffer(data, dtype=dtype).reshape(height, width, samples) + fd, tmp = tempfile.mkstemp(suffix='.j2k') + os.close(fd) + os.unlink(tmp) # glymur needs the file to not exist + try: + cratios = [1] if lossless else [20] + # numres must be <= log2(min_dim) + 1 to avoid OpenJPEG errors + min_dim = max(min(width, height), 1) + numres = min(6, int(math.log2(min_dim)) + 1) + numres = max(numres, 1) + _glymur.Jp2k(tmp, data=arr, cratios=cratios, numres=numres) + with open(tmp, 'rb') as f: + return f.read() + finally: + if os.path.exists(tmp): + os.unlink(tmp) + + # -- Dispatch helpers --------------------------------------------------------- # TIFF compression tag values @@ -734,6 +797,7 @@ def zstd_compress(data: bytes, level: int = 3) -> bytes: COMPRESSION_LZW = 5 COMPRESSION_JPEG = 7 COMPRESSION_DEFLATE = 8 +COMPRESSION_JPEG2000 = 34712 COMPRESSION_ZSTD = 50000 COMPRESSION_PACKBITS = 32773 COMPRESSION_ADOBE_DEFLATE = 32946 @@ -771,6 +835,9 @@ def decompress(data, compression: int, expected_size: int = 0, dtype=np.uint8) elif compression == COMPRESSION_ZSTD: return np.frombuffer(zstd_decompress(data), dtype=np.uint8) + elif 
compression == COMPRESSION_JPEG2000: + return np.frombuffer( + jpeg2000_decompress(data, width, height, samples), dtype=np.uint8) else: raise ValueError(f"Unsupported compression type: {compression}") @@ -803,5 +870,7 @@ def compress(data: bytes, compression: int, level: int = 6) -> bytes: return zstd_compress(data, level) elif compression == COMPRESSION_JPEG: raise ValueError("Use jpeg_compress() directly with width/height/samples") + elif compression == COMPRESSION_JPEG2000: + raise ValueError("Use jpeg2000_compress() directly with width/height/samples/dtype") else: raise ValueError(f"Unsupported compression type: {compression}") diff --git a/xrspatial/geotiff/_gpu_decode.py b/xrspatial/geotiff/_gpu_decode.py index 93f2ae1a..649b425b 100644 --- a/xrspatial/geotiff/_gpu_decode.py +++ b/xrspatial/geotiff/_gpu_decode.py @@ -1182,6 +1182,29 @@ def gpu_decode_tiles( ) cuda.synchronize() + elif compression == 34712: # JPEG 2000 + nvj2k_result = _try_nvjpeg2k_batch_decode( + compressed_tiles, tile_width, tile_height, dtype, samples) + if nvj2k_result is not None: + d_decomp = nvj2k_result + decomp_offsets = np.arange(n_tiles, dtype=np.int64) * tile_bytes + d_decomp_offsets = cupy.asarray(decomp_offsets) + else: + # CPU fallback for JPEG 2000 + from ._compression import jpeg2000_decompress + raw_host = np.empty(n_tiles * tile_bytes, dtype=np.uint8) + for i, tile in enumerate(compressed_tiles): + start = i * tile_bytes + chunk = np.frombuffer( + jpeg2000_decompress(tile, tile_width, tile_height, samples), + dtype=np.uint8) + raw_host[start:start + min(len(chunk), tile_bytes)] = \ + chunk[:tile_bytes] if len(chunk) >= tile_bytes else \ + np.pad(chunk, (0, tile_bytes - len(chunk))) + d_decomp = cupy.asarray(raw_host) + decomp_offsets = np.arange(n_tiles, dtype=np.int64) * tile_bytes + d_decomp_offsets = cupy.asarray(decomp_offsets) + elif compression == 1: # Uncompressed raw_host = np.empty(n_tiles * tile_bytes, dtype=np.uint8) for i, tile in enumerate(compressed_tiles): @@ 
-1476,6 +1499,340 @@ class _DeflateCompOpts(ctypes.Structure): return None + + +# --------------------------------------------------------------------------- +# nvJPEG2000 batch decode/encode (optional, GPU-accelerated JPEG 2000) +# --------------------------------------------------------------------------- + +_nvjpeg2k_lib = None +_nvjpeg2k_checked = False + + +def _find_nvjpeg2k_lib(): + """Find and load libnvjpeg2k.so. Returns ctypes.CDLL or None.""" + import ctypes + import os + + search_paths = [ + 'libnvjpeg2k.so', # system LD_LIBRARY_PATH + ] + + conda_prefix = os.environ.get('CONDA_PREFIX', '') + if conda_prefix: + search_paths.append(os.path.join(conda_prefix, 'lib', 'libnvjpeg2k.so')) + + conda_base = os.path.dirname(conda_prefix) if conda_prefix else '' + if conda_base: + for env in ['rapids', 'test-again', 'rtxpy-fire']: + p = os.path.join(conda_base, env, 'lib', 'libnvjpeg2k.so') + if os.path.exists(p): + search_paths.append(p) + + for path in search_paths: + try: + return ctypes.CDLL(path) + except OSError: + continue + return None + + +def _get_nvjpeg2k(): + """Get the nvJPEG2000 library handle (cached). Returns CDLL or None.""" + global _nvjpeg2k_lib, _nvjpeg2k_checked + if not _nvjpeg2k_checked: + _nvjpeg2k_checked = True + _nvjpeg2k_lib = _find_nvjpeg2k_lib() + return _nvjpeg2k_lib + + +def _try_nvjpeg2k_batch_decode(compressed_tiles, tile_width, tile_height, + dtype, samples): + """Try decoding JPEG 2000 tiles via nvJPEG2000. Returns a flat CuPy uint8 buffer or None. + + Each tile is decoded independently. The decoded pixels are returned as a + flat CuPy uint8 buffer (all tiles concatenated), matching the layout + expected by _apply_predictor_and_assemble / the assembly kernel. 
+ """ + lib = _get_nvjpeg2k() + if lib is None: + return None + + import ctypes + import cupy + + n_tiles = len(compressed_tiles) + bytes_per_pixel = dtype.itemsize * samples + tile_bytes = tile_width * tile_height * bytes_per_pixel + + try: + # Create nvjpeg2k handle + handle = ctypes.c_void_p() + s = lib.nvjpeg2kCreateSimple(ctypes.byref(handle)) + if s != 0: + return None + + # Create decode state and params + state = ctypes.c_void_p() + s = lib.nvjpeg2kDecodeStateCreate(handle, ctypes.byref(state)) + if s != 0: + lib.nvjpeg2kDestroy(handle) + return None + + stream = ctypes.c_void_p() + s = lib.nvjpeg2kStreamCreate(ctypes.byref(stream)) + if s != 0: + lib.nvjpeg2kDecodeStateDestroy(state) + lib.nvjpeg2kDestroy(handle) + return None + + params = ctypes.c_void_p() + s = lib.nvjpeg2kDecodeParamsCreate(ctypes.byref(params)) + if s != 0: + lib.nvjpeg2kStreamDestroy(stream) + lib.nvjpeg2kDecodeStateDestroy(state) + lib.nvjpeg2kDestroy(handle) + return None + + # nvjpeg2kImage_t: array of pointers (pixel_data) + array of pitches + MAX_COMPONENTS = 4 + + class _NvJpeg2kImage(ctypes.Structure): + _fields_ = [ + ('pixel_data', ctypes.c_void_p * MAX_COMPONENTS), + ('pitch_in_bytes', ctypes.c_size_t * MAX_COMPONENTS), + ('num_components', ctypes.c_uint32), + ('pixel_type', ctypes.c_int), # NVJPEG2K_UINT8=0, UINT16=1, INT16=2 + ] + + # Map numpy dtype to nvjpeg2k pixel type + if dtype == np.uint8: + pixel_type = 0 # NVJPEG2K_UINT8 + elif dtype == np.uint16: + pixel_type = 1 # NVJPEG2K_UINT16 + elif dtype == np.int16: + pixel_type = 2 # NVJPEG2K_INT16 + else: + # Unsupported dtype for nvJPEG2000 -- fall back + lib.nvjpeg2kDecodeParamsDestroy(params) + lib.nvjpeg2kStreamDestroy(stream) + lib.nvjpeg2kDecodeStateDestroy(state) + lib.nvjpeg2kDestroy(handle) + return None + + # Decode each tile + d_all_tiles = cupy.empty(n_tiles * tile_bytes, dtype=cupy.uint8) + + for i, tile_data in enumerate(compressed_tiles): + # Parse the J2K codestream + src = np.frombuffer(tile_data, 
dtype=np.uint8) + s = lib.nvjpeg2kStreamParse( + handle, + ctypes.c_void_p(src.ctypes.data), + ctypes.c_size_t(len(src)), + ctypes.c_int(0), # save_metadata + ctypes.c_int(0), # save_stream + stream, + ) + if s != 0: + continue + + # Allocate per-component output buffers on GPU + comp_bufs = [] + pitch = tile_width * dtype.itemsize + for c in range(samples): + buf = cupy.empty(tile_height * pitch, dtype=cupy.uint8) + comp_bufs.append(buf) + + # Build nvjpeg2kImage_t + img = _NvJpeg2kImage() + img.num_components = samples + img.pixel_type = pixel_type + for c in range(samples): + img.pixel_data[c] = comp_bufs[c].data.ptr + img.pitch_in_bytes[c] = pitch + + # Decode + s = lib.nvjpeg2kDecode( + handle, state, stream, params, + ctypes.byref(img), + ctypes.c_void_p(0), # default CUDA stream + ) + cupy.cuda.Device().synchronize() + + if s != 0: + continue + + # Interleave components into pixel order (comp0,comp1,...) per pixel + tile_offset = i * tile_bytes + if samples == 1: + d_all_tiles[tile_offset:tile_offset + tile_bytes] = comp_bufs[0][:tile_bytes] + else: + # Interleave: separate planes -> pixel-interleaved + comp_arrays = [ + comp_bufs[c][:tile_height * pitch].view( + dtype=cupy.dtype(dtype)).reshape(tile_height, tile_width) + for c in range(samples) + ] + interleaved = cupy.stack(comp_arrays, axis=-1) + d_all_tiles[tile_offset:tile_offset + tile_bytes] = \ + interleaved.view(cupy.uint8).ravel() + + # Cleanup + lib.nvjpeg2kDecodeParamsDestroy(params) + lib.nvjpeg2kStreamDestroy(stream) + lib.nvjpeg2kDecodeStateDestroy(state) + lib.nvjpeg2kDestroy(handle) + + return d_all_tiles + + except Exception: + return None + + +def _nvjpeg2k_batch_encode(d_tile_bufs, tile_width, tile_height, + dtype, samples, n_tiles, lossless=True): + """Encode tiles as JPEG 2000 via nvJPEG2000. 
Returns list of bytes or None.""" + lib = _get_nvjpeg2k() + if lib is None: + return None + + import ctypes + import cupy + + try: + bytes_per_pixel = dtype.itemsize * samples + tile_bytes = tile_width * tile_height * bytes_per_pixel + + # Create encoder + encoder = ctypes.c_void_p() + s = lib.nvjpeg2kEncoderCreateSimple(ctypes.byref(encoder)) + if s != 0: + return None + + enc_state = ctypes.c_void_p() + s = lib.nvjpeg2kEncodeStateCreate(encoder, ctypes.byref(enc_state)) + if s != 0: + lib.nvjpeg2kEncoderDestroy(encoder) + return None + + enc_params = ctypes.c_void_p() + s = lib.nvjpeg2kEncodeParamsCreate(ctypes.byref(enc_params)) + if s != 0: + lib.nvjpeg2kEncodeStateDestroy(enc_state) + lib.nvjpeg2kEncoderDestroy(encoder) + return None + + # Set encoding parameters + if lossless: + lib.nvjpeg2kEncodeParamsSetQuality(enc_params, ctypes.c_int(1)) + + MAX_COMPONENTS = 4 + + class _NvJpeg2kImage(ctypes.Structure): + _fields_ = [ + ('pixel_data', ctypes.c_void_p * MAX_COMPONENTS), + ('pitch_in_bytes', ctypes.c_size_t * MAX_COMPONENTS), + ('num_components', ctypes.c_uint32), + ('pixel_type', ctypes.c_int), + ] + + if dtype == np.uint8: + pixel_type = 0 + elif dtype == np.uint16: + pixel_type = 1 + elif dtype == np.int16: + pixel_type = 2 + else: + lib.nvjpeg2kEncodeParamsDestroy(enc_params) + lib.nvjpeg2kEncodeStateDestroy(enc_state) + lib.nvjpeg2kEncoderDestroy(encoder) + return None + + pitch = tile_width * dtype.itemsize + result = [] + + for i in range(n_tiles): + tile_data = d_tile_bufs[i * tile_bytes:(i + 1) * tile_bytes] + + # De-interleave into per-component planes for the encoder + if samples == 1: + comp_bufs = [tile_data] + else: + tile_arr = tile_data.view(dtype=cupy.dtype(dtype)).reshape( + tile_height, tile_width, samples) + comp_bufs = [ + cupy.ascontiguousarray(tile_arr[:, :, c]).view(cupy.uint8).ravel() + for c in range(samples) + ] + + img = _NvJpeg2kImage() + img.num_components = samples + img.pixel_type = pixel_type + for c in range(samples): + 
img.pixel_data[c] = comp_bufs[c].data.ptr + img.pitch_in_bytes[c] = pitch + + # Set image info on params + class _CompInfo(ctypes.Structure): + _fields_ = [ + ('component_width', ctypes.c_uint32), + ('component_height', ctypes.c_uint32), + ('precision', ctypes.c_uint8), + ('sgn', ctypes.c_uint8), + ] + + precision = dtype.itemsize * 8 + sgn = 1 if dtype.kind == 'i' else 0 + + comp_info = (_CompInfo * samples)() + for c in range(samples): + comp_info[c].component_width = tile_width + comp_info[c].component_height = tile_height + comp_info[c].precision = precision + comp_info[c].sgn = sgn + + # Encode + s = lib.nvjpeg2kEncode( + encoder, enc_state, enc_params, + ctypes.byref(img), + ctypes.c_void_p(0), # default CUDA stream + ) + cupy.cuda.Device().synchronize() + if s != 0: + lib.nvjpeg2kEncodeParamsDestroy(enc_params) + lib.nvjpeg2kEncodeStateDestroy(enc_state) + lib.nvjpeg2kEncoderDestroy(encoder) + return None + + # Retrieve bitstream size + bs_size = ctypes.c_size_t(0) + lib.nvjpeg2kEncoderRetrieveBitstream( + encoder, enc_state, + ctypes.c_void_p(0), + ctypes.byref(bs_size), + ctypes.c_void_p(0), + ) + + # Retrieve bitstream data + bs_buf = np.empty(bs_size.value, dtype=np.uint8) + lib.nvjpeg2kEncoderRetrieveBitstream( + encoder, enc_state, + ctypes.c_void_p(bs_buf.ctypes.data), + ctypes.byref(bs_size), + ctypes.c_void_p(0), + ) + + result.append(bs_buf[:bs_size.value].tobytes()) + + lib.nvjpeg2kEncodeParamsDestroy(enc_params) + lib.nvjpeg2kEncodeStateDestroy(enc_state) + lib.nvjpeg2kEncoderDestroy(encoder) + + return result + + except Exception: + return None + + # --------------------------------------------------------------------------- # High-level GPU write pipeline # --------------------------------------------------------------------------- @@ -1550,6 +1907,24 @@ def gpu_compress_tiles(d_image, tile_width, tile_height, d_tile_buf, d_tmp, tile_width * samples, total_rows, dtype.itemsize) cuda.synchronize() + # JPEG 2000: use nvJPEG2000 (image codec, not 
byte-stream codec) + if compression == 34712: + result = _nvjpeg2k_batch_encode( + d_tile_buf, tile_width, tile_height, dtype, samples, n_tiles) + if result is not None: + return result + # CPU fallback for JPEG 2000 + from ._compression import jpeg2000_compress + cpu_buf = d_tile_buf.get() + result = [] + for i in range(n_tiles): + start = i * tile_bytes + tile_data = bytes(cpu_buf[start:start + tile_bytes]) + result.append(jpeg2000_compress( + tile_data, tile_width, tile_height, + samples=samples, dtype=dtype)) + return result + # Split into per-tile buffers for nvCOMP d_tiles = [d_tile_buf[i * tile_bytes:(i + 1) * tile_bytes] for i in range(n_tiles)] diff --git a/xrspatial/geotiff/_writer.py b/xrspatial/geotiff/_writer.py index ae7658ab..063b0c7f 100644 --- a/xrspatial/geotiff/_writer.py +++ b/xrspatial/geotiff/_writer.py @@ -8,6 +8,7 @@ from ._compression import ( COMPRESSION_DEFLATE, + COMPRESSION_JPEG2000, COMPRESSION_LZW, COMPRESSION_NONE, COMPRESSION_PACKBITS, @@ -67,6 +68,8 @@ def _compression_tag(compression_name: str) -> int: 'lzw': COMPRESSION_LZW, 'packbits': COMPRESSION_PACKBITS, 'zstd': COMPRESSION_ZSTD, + 'jpeg2000': COMPRESSION_JPEG2000, + 'j2k': COMPRESSION_JPEG2000, } name = compression_name.lower() if name not in _map: @@ -318,7 +321,12 @@ def _write_stripped(data: np.ndarray, compression: int, predictor: bool, else: strip_data = np.ascontiguousarray(data[r0:r1]).tobytes() - compressed = compress(strip_data, compression) + if compression == COMPRESSION_JPEG2000: + from ._compression import jpeg2000_compress + compressed = jpeg2000_compress( + strip_data, width, strip_rows, samples=samples, dtype=dtype) + else: + compressed = compress(strip_data, compression) rel_offsets.append(current_offset) byte_counts.append(len(compressed)) @@ -391,7 +399,12 @@ def _write_tiled(data: np.ndarray, compression: int, predictor: bool, else: tile_data = tile_arr.tobytes() - compressed = compress(tile_data, compression) + if compression == COMPRESSION_JPEG2000: + 
from ._compression import jpeg2000_compress + compressed = jpeg2000_compress( + tile_data, tw, th, samples=samples, dtype=dtype) + else: + compressed = compress(tile_data, compression) rel_offsets.append(current_offset) byte_counts.append(len(compressed)) diff --git a/xrspatial/geotiff/tests/test_edge_cases.py b/xrspatial/geotiff/tests/test_edge_cases.py index 1a8a8680..0c2fd2a8 100644 --- a/xrspatial/geotiff/tests/test_edge_cases.py +++ b/xrspatial/geotiff/tests/test_edge_cases.py @@ -50,7 +50,7 @@ def test_0d_scalar(self, tmp_path): def test_unsupported_compression(self, tmp_path): arr = np.zeros((4, 4), dtype=np.float32) with pytest.raises(ValueError, match="Unsupported compression"): - write_geotiff(arr, str(tmp_path / 'bad.tif'), compression='jpeg2000') + write_geotiff(arr, str(tmp_path / 'bad.tif'), compression='webp') def test_complex_dtype(self, tmp_path): arr = np.zeros((4, 4), dtype=np.complex64) diff --git a/xrspatial/geotiff/tests/test_jpeg2000.py b/xrspatial/geotiff/tests/test_jpeg2000.py new file mode 100644 index 00000000..2ada8aac --- /dev/null +++ b/xrspatial/geotiff/tests/test_jpeg2000.py @@ -0,0 +1,186 @@ +"""Tests for JPEG 2000 compression codec (#1048).""" +from __future__ import annotations + +import numpy as np +import pytest + +from xrspatial.geotiff._compression import ( + COMPRESSION_JPEG2000, + JPEG2000_AVAILABLE, + jpeg2000_compress, + jpeg2000_decompress, + decompress, +) + +pytestmark = pytest.mark.skipif( + not JPEG2000_AVAILABLE, + reason="glymur not installed", +) + + +class TestJPEG2000Codec: + """CPU JPEG 2000 codec roundtrip via glymur.""" + + def test_roundtrip_uint8(self): + arr = np.arange(64, dtype=np.uint8).reshape(8, 8) + compressed = jpeg2000_compress( + arr.tobytes(), 8, 8, samples=1, dtype=np.dtype('uint8'), + lossless=True) + assert isinstance(compressed, bytes) + assert len(compressed) > 0 + + decompressed = jpeg2000_decompress(compressed, 8, 8, 1) + result = np.frombuffer(decompressed, dtype=np.uint8).reshape(8, 8) + 
np.testing.assert_array_equal(result, arr) + + def test_roundtrip_uint16(self): + arr = np.arange(64, dtype=np.uint16).reshape(8, 8) + compressed = jpeg2000_compress( + arr.tobytes(), 8, 8, samples=1, dtype=np.dtype('uint16'), + lossless=True) + decompressed = jpeg2000_decompress(compressed, 8, 8, 1) + result = np.frombuffer(decompressed, dtype=np.uint16).reshape(8, 8) + np.testing.assert_array_equal(result, arr) + + def test_roundtrip_multiband(self): + arr = np.arange(192, dtype=np.uint8).reshape(8, 8, 3) + compressed = jpeg2000_compress( + arr.tobytes(), 8, 8, samples=3, dtype=np.dtype('uint8'), + lossless=True) + decompressed = jpeg2000_decompress(compressed, 8, 8, 3) + result = np.frombuffer(decompressed, dtype=np.uint8).reshape(8, 8, 3) + np.testing.assert_array_equal(result, arr) + + def test_single_pixel(self): + arr = np.array([[42]], dtype=np.uint8) + compressed = jpeg2000_compress( + arr.tobytes(), 1, 1, samples=1, dtype=np.dtype('uint8'), + lossless=True) + decompressed = jpeg2000_decompress(compressed, 1, 1, 1) + result = np.frombuffer(decompressed, dtype=np.uint8) + assert result[0] == 42 + + def test_lossy_produces_smaller_output(self): + rng = np.random.RandomState(1048) + arr = rng.randint(0, 256, size=(64, 64), dtype=np.uint8) + lossless = jpeg2000_compress( + arr.tobytes(), 64, 64, samples=1, dtype=np.dtype('uint8'), + lossless=True) + lossy = jpeg2000_compress( + arr.tobytes(), 64, 64, samples=1, dtype=np.dtype('uint8'), + lossless=False) + # Lossy should generally be smaller + assert len(lossy) <= len(lossless) + + def test_dispatch_decompress(self): + arr = np.arange(16, dtype=np.uint8).reshape(4, 4) + compressed = jpeg2000_compress( + arr.tobytes(), 4, 4, samples=1, dtype=np.dtype('uint8'), + lossless=True) + result = decompress(compressed, COMPRESSION_JPEG2000, + width=4, height=4, samples=1) + np.testing.assert_array_equal( + result.reshape(4, 4), + arr, + ) + + +class TestJPEG2000WriteRoundTrip: + """Write-read roundtrip using the TIFF 
writer with JPEG 2000 compression.""" + + def test_tiled_uint8(self, tmp_path): + from xrspatial.geotiff._writer import write + from xrspatial.geotiff._reader import read_to_array + + expected = np.arange(64, dtype=np.uint8).reshape(8, 8) + path = str(tmp_path / 'j2k_1048_tiled_uint8.tif') + write(expected, path, compression='jpeg2000', tiled=True, tile_size=8) + + arr, geo = read_to_array(path) + np.testing.assert_array_equal(arr, expected) + + def test_tiled_uint16(self, tmp_path): + from xrspatial.geotiff._writer import write + from xrspatial.geotiff._reader import read_to_array + + expected = np.arange(64, dtype=np.uint16).reshape(8, 8) + path = str(tmp_path / 'j2k_1048_tiled_uint16.tif') + write(expected, path, compression='jpeg2000', tiled=True, tile_size=8) + + arr, geo = read_to_array(path) + np.testing.assert_array_equal(arr, expected) + + def test_stripped_uint8(self, tmp_path): + from xrspatial.geotiff._writer import write + from xrspatial.geotiff._reader import read_to_array + + expected = np.arange(64, dtype=np.uint8).reshape(8, 8) + path = str(tmp_path / 'j2k_1048_stripped.tif') + write(expected, path, compression='jpeg2000', tiled=False) + + arr, geo = read_to_array(path) + np.testing.assert_array_equal(arr, expected) + + def test_with_geo_info(self, tmp_path): + from xrspatial.geotiff._writer import write + from xrspatial.geotiff._reader import read_to_array + from xrspatial.geotiff._geotags import GeoTransform + + expected = np.ones((8, 8), dtype=np.uint8) * 100 + gt = GeoTransform(-120.0, 45.0, 0.001, -0.001) + path = str(tmp_path / 'j2k_1048_geo.tif') + write(expected, path, compression='jpeg2000', tiled=True, tile_size=8, + geo_transform=gt, crs_epsg=4326, nodata=0) + + arr, geo = read_to_array(path) + np.testing.assert_array_equal(arr, expected) + assert geo.crs_epsg == 4326 + + def test_public_api_roundtrip(self, tmp_path): + """Test via read_geotiff / write_geotiff public API.""" + import xarray as xr + from xrspatial.geotiff import 
read_geotiff, write_geotiff + + data = np.arange(64, dtype=np.uint8).reshape(8, 8) + da = xr.DataArray(data, dims=['y', 'x'], + coords={'y': np.arange(8), 'x': np.arange(8)}, + attrs={'crs': 4326}) + path = str(tmp_path / 'j2k_1048_api.tif') + write_geotiff(da, path, compression='jpeg2000') + + result = read_geotiff(path) + np.testing.assert_array_equal(result.values, data) + + +class TestJPEG2000Availability: + """Test the availability flag and error handling. + + These don't need glymur. NOTE(review): pytest merges class- and + module-level pytestmark, so the empty list below does NOT clear the + module-level skipif -- without glymur these tests are still skipped. + Move this class to a separate module if it must always run. + """ + + # NOTE(review): intended to override the module-level skip, but pytest + # accumulates markers rather than replacing them -- this has no effect + pytestmark = [] + + def test_compression_constant(self): + assert COMPRESSION_JPEG2000 == 34712 + + def test_compression_tag_mapping(self): + from xrspatial.geotiff._writer import _compression_tag + assert _compression_tag('jpeg2000') == 34712 + assert _compression_tag('j2k') == 34712 + + def test_unavailable_raises_import_error(self): + """If glymur is missing, codec functions raise ImportError.""" + import unittest.mock + import importlib + import xrspatial.geotiff._compression as comp_mod + # Temporarily pretend glymur is unavailable + orig = comp_mod.JPEG2000_AVAILABLE + comp_mod.JPEG2000_AVAILABLE = False + try: + with pytest.raises(ImportError, match="glymur"): + comp_mod.jpeg2000_decompress(b'\x00', 1, 1, 1) + with pytest.raises(ImportError, match="glymur"): + comp_mod.jpeg2000_compress(b'\x00', 1, 1, dtype=np.dtype('uint8')) + finally: + comp_mod.JPEG2000_AVAILABLE = orig