Merged
204 changes: 204 additions & 0 deletions examples/user_guide/35_Focal_Variety.ipynb
@@ -0,0 +1,204 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Focal Variety\n",
"\n",
"Focal variety counts the number of distinct values in a sliding\n",
"neighbourhood window. It is most useful for categorical rasters\n",
"(land-cover, soil type, geology codes) where you want to map\n",
"boundary complexity or patch fragmentation.\n",
"\n",
"This notebook shows how to compute focal variety with\n",
"`xrspatial.focal.focal_stats` across different kernel shapes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import xarray as xr\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from xrspatial.convolution import circle_kernel\n",
"from xrspatial.focal import focal_stats"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create a synthetic land-cover raster\n",
"\n",
"We build a 60x60 grid with four land-cover classes arranged in\n",
"quadrants, plus a few scattered patches to make things interesting."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rng = np.random.default_rng(42)\n",
"rows, cols = 60, 60\n",
"\n",
"# Four quadrants: classes 1-4\n",
"lc = np.ones((rows, cols), dtype=np.float64)\n",
"lc[:rows//2, cols//2:] = 2\n",
"lc[rows//2:, :cols//2] = 3\n",
"lc[rows//2:, cols//2:] = 4\n",
"\n",
"# Scatter some class-5 patches\n",
"for _ in range(30):\n",
"    r, c = rng.integers(0, rows), rng.integers(0, cols)\n",
"    lc[r:r+3, c:c+3] = 5\n",
"\n",
"land_cover = xr.DataArray(lc, dims=['y', 'x'], name='land_cover')\n",
"\n",
"fig, ax = plt.subplots(figsize=(5, 5))\n",
"land_cover.plot(ax=ax, cmap='Set2', add_colorbar=True)\n",
"ax.set_title('Synthetic land-cover raster')\n",
"ax.set_aspect('equal')\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Compute focal variety with a 3x3 box kernel\n",
"\n",
"A 3x3 box kernel counts how many distinct classes appear among each\n",
"pixel and its eight immediate neighbours."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"kernel_box = np.ones((3, 3))\n",
"result_box = focal_stats(land_cover, kernel_box, stats_funcs=['variety'])\n",
"variety_box = result_box.sel(stats='variety')\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(10, 4))\n",
"land_cover.plot(ax=axes[0], cmap='Set2', add_colorbar=True)\n",
"axes[0].set_title('Land cover')\n",
"axes[0].set_aspect('equal')\n",
"\n",
"variety_box.plot(ax=axes[1], cmap='YlOrRd', add_colorbar=True)\n",
"axes[1].set_title('Focal variety (3x3 box)')\n",
"axes[1].set_aspect('equal')\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Pixels deep inside a uniform quadrant show variety = 1. Pixels on\n",
"boundaries between classes show variety = 2, 3, or 4 depending on\n",
"how many classes meet at that point. The scattered class-5 patches\n",
"create small pockets of higher variety."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Larger kernel: 5x5 circle\n",
"\n",
"Increasing the kernel radius captures more of the surrounding\n",
"landscape, so variety values near boundaries will be higher."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# circle_kernel(cellsize_x, cellsize_y, radius): unit cells with radius 2 give a 5x5 kernel\n",
"kernel_circle = circle_kernel(1, 1, 2)\n",
"result_circle = focal_stats(land_cover, kernel_circle, stats_funcs=['variety'])\n",
"variety_circle = result_circle.sel(stats='variety')\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(10, 4))\n",
"variety_box.plot(ax=axes[0], cmap='YlOrRd', add_colorbar=True)\n",
"axes[0].set_title('Variety (3x3 box)')\n",
"axes[0].set_aspect('equal')\n",
"\n",
"variety_circle.plot(ax=axes[1], cmap='YlOrRd', add_colorbar=True)\n",
"axes[1].set_title('Variety (5x5 circle)')\n",
"axes[1].set_aspect('equal')\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Combining variety with other focal stats\n",
"\n",
"You can request variety alongside other statistics in one call.\n",
"Here we grab both range and variety to compare continuous and\n",
"categorical measures of local heterogeneity."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"result_combo = focal_stats(land_cover, kernel_box,\n",
"                           stats_funcs=['range', 'variety'])\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(10, 4))\n",
"result_combo.sel(stats='range').plot(ax=axes[0], cmap='viridis',\n",
"                                     add_colorbar=True)\n",
"axes[0].set_title('Focal range')\n",
"axes[0].set_aspect('equal')\n",
"\n",
"result_combo.sel(stats='variety').plot(ax=axes[1], cmap='YlOrRd',\n",
"                                       add_colorbar=True)\n",
"axes[1].set_title('Focal variety')\n",
"axes[1].set_aspect('equal')\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Range measures the numeric spread (max minus min) while variety\n",
"counts distinct classes. For categorical data, variety is usually\n",
"the more meaningful measure since the numeric distance between\n",
"class codes is arbitrary."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
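Reviewer note: the notebook describes focal variety as a count of distinct values in a sliding window. A minimal, dependency-free sketch of that definition is given here as a cross-check for small inputs. The helper name `focal_variety_reference` is hypothetical and not part of xrspatial; it assumes kernel cells equal to 1 are active and that neighbours outside the grid are ignored, which is equivalent to treating them as NaN.

```python
import math


def focal_variety_reference(grid, kernel):
    """Count distinct non-NaN values under the kernel footprint of each cell.

    Hypothetical helper for illustration only; not part of xrspatial.
    `grid` and `kernel` are lists of lists; kernel cells equal to 1 are active.
    """
    rows, cols = len(grid), len(grid[0])
    krows, kcols = len(kernel), len(kernel[0])
    half_r, half_c = krows // 2, kcols // 2
    out = [[math.nan] * cols for _ in range(rows)]
    for y in range(rows):
        for x in range(cols):
            seen = set()
            for ky in range(krows):
                for kx in range(kcols):
                    if kernel[ky][kx] != 1:
                        continue
                    yy, xx = y + ky - half_r, x + kx - half_c
                    # Out-of-bounds neighbours are skipped (treated as NaN).
                    if 0 <= yy < rows and 0 <= xx < cols:
                        value = grid[yy][xx]
                        if not math.isnan(value):
                            seen.add(value)
            # Cells with no valid neighbour keep their NaN placeholder.
            if seen:
                out[y][x] = float(len(seen))
    return out


# Four 2x2 quadrants holding classes 1-4, a tiny version of the notebook raster.
quadrants = [
    [1.0, 1.0, 2.0, 2.0],
    [1.0, 1.0, 2.0, 2.0],
    [3.0, 3.0, 4.0, 4.0],
    [3.0, 3.0, 4.0, 4.0],
]
box = [[1, 1, 1], [1, 1, 1], [1, 1, 1]]
variety = focal_variety_reference(quadrants, box)
```

On this toy raster the corner cells see a single class, edge cells next to one boundary see two, and the four cells around the centre see all four classes, which matches the pattern the notebook describes for the quadrant layout.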
82 changes: 78 additions & 4 deletions xrspatial/focal.py
@@ -352,6 +352,28 @@ def _calc_var(array):
     return np.nanvar(array)


+@ngjit
+def _calc_variety(array):
+    """Count distinct non-NaN values in the flat kernel neighbourhood."""
+    count = 0
+    uvals = np.empty(array.size, dtype=array.dtype)
+    for i in range(array.size):
+        v = array.flat[i]
+        if np.isnan(v):
+            continue
+        found = False
+        for j in range(count):
+            if uvals[j] == v:
+                found = True
+                break
+        if not found:
+            uvals[count] = v
+            count += 1
+    if count == 0:
+        return np.nan
+    return np.float64(count)
+
+
 @ngjit
 def _apply_numpy(data, kernel, func):
     data = data.astype(np.float32)
@@ -762,6 +784,52 @@ def _focal_var_cuda(data, kernel, out):
         out[i, j] = 0.0


+@cuda.jit
+def _focal_variety_cuda(data, kernel, out):
+    i, j = cuda.grid(2)
+
+    rows, cols = data.shape
+    if i >= rows or j >= cols:
+        return
+
+    dr = kernel.shape[0] // 2
+    dc = kernel.shape[1] // 2
+
+    # Local buffer for up to 25 unique values (exact for kernels up to 5x5).
+    # For larger kernels the count saturates at MAX_UNIQ, so results can
+    # under-report variety; accepted to limit GPU register pressure.
+    MAX_UNIQ = 25
+    buf = cuda.local.array(MAX_UNIQ, nb.float32)
+    count = 0
+
+    for k in range(kernel.shape[0]):
+        for h in range(kernel.shape[1]):
+            if kernel[k, h] == 0:
+                continue
+
+            ii = i + k - dr
+            jj = j + h - dc
+
+            if 0 <= ii < rows and 0 <= jj < cols:
+                v = data[ii, jj]
+                if v != v:  # NaN check (NaN != NaN)
+                    continue
+                # check if already in buffer
+                found = False
+                for u in range(count):
+                    if buf[u] == v:
+                        found = True
+                        break
+                if not found and count < MAX_UNIQ:
+                    buf[count] = v
+                    count += 1
+
+    if count == 0:
+        out[i, j] = math.nan
+    else:
+        out[i, j] = float(count)
+
+
 def _focal_mean_cupy(data, kernel):
     out = convolve_2d(data, kernel / kernel.sum())
     return out
@@ -852,6 +920,7 @@ def _focal_stats_cupy(agg, kernel, stats_funcs):
         min=lambda *args: _focal_stats_func_cupy(*args, func=_focal_min_cuda),
         std=lambda *args: _focal_stats_func_cupy(*args, func=_focal_std_cuda),
         var=lambda *args: _focal_stats_func_cupy(*args, func=_focal_var_cuda),
+        variety=lambda *args: _focal_stats_func_cupy(*args, func=_focal_variety_cuda),
     )
     stats_aggs = []
     for stats in stats_funcs:
@@ -873,6 +942,7 @@ def _focal_stats_dask_cupy(agg, kernel, stats_funcs, boundary='nan'):
         mean=_focal_mean_cuda, sum=_focal_sum_cuda,
         range=_focal_range_cuda, max=_focal_max_cuda,
         min=_focal_min_cuda, std=_focal_std_cuda, var=_focal_var_cuda,
+        variety=_focal_variety_cuda,
     )
     pad_h = kernel.shape[0] // 2
     pad_w = kernel.shape[1] // 2
@@ -902,7 +972,8 @@ def _focal_stats_cpu(agg, kernel, stats_funcs, boundary='nan'):
         'range': _calc_range,
         'std': _calc_std,
         'var': _calc_var,
-        'sum': _calc_sum
+        'sum': _calc_sum,
+        'variety': _calc_variety,
     }
     stats_aggs = []
     for stats in stats_funcs:
@@ -916,13 +987,14 @@ def _focal_stats_cpu(agg, kernel, stats_funcs, boundary='nan'):
 def focal_stats(agg,
                 kernel,
                 stats_funcs=[
-                    'mean', 'max', 'min', 'range', 'std', 'var', 'sum'
+                    'mean', 'max', 'min', 'range', 'std', 'var',
+                    'sum', 'variety'
                 ],
                 boundary='nan'):
     """
     Calculates statistics of the values within a specified focal neighborhood
     for each pixel in an input raster. The statistics types are Mean, Maximum,
-    Minimum, Range, Standard deviation, Variation and Sum.
+    Minimum, Range, Standard deviation, Variance, Sum, and Variety.

     Parameters
     ----------
@@ -934,7 +1006,9 @@ def focal_stats(agg,
         2D array where values of 1 indicate the kernel.
     stats_funcs: list of string
         List of statistics types to be calculated.
-        Default set to ['mean', 'max', 'min', 'range', 'std', 'var', 'sum'].
+        Default set to ['mean', 'max', 'min', 'range', 'std', 'var',
+        'sum', 'variety']. ``'variety'`` counts the number of distinct
+        non-NaN values in the neighbourhood (useful for categorical rasters).
     boundary : str, default='nan'
         How to handle edges where the kernel extends beyond the raster.
         ``'nan'`` -- fill missing neighbours with NaN (default).
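
Reviewer note: the CPU function `_calc_variety` sizes its scratch array to the whole window, while the CUDA kernel caps its buffer at `MAX_UNIQ = 25`. The pure-Python sketch below imitates the same linear-scan de-duplication to illustrate how the two paths could diverge for windows larger than 5x5. The helper `count_distinct` and its `max_uniq` parameter are hypothetical names used only for this illustration; the sketch does not run any xrspatial or Numba code.

```python
import math


def count_distinct(values, max_uniq=None):
    """Linear-scan distinct count that skips NaN, with an optional buffer cap.

    Hypothetical helper for illustration only. `max_uniq=None` behaves like
    the uncapped CPU loop in the diff; an integer cap behaves like the
    fixed-size buffer in the CUDA kernel.
    """
    seen = []
    for v in values:
        if v != v:  # NaN is the only value not equal to itself
            continue
        if v in seen:
            continue
        # Once the buffer is full, further new values are silently dropped.
        if max_uniq is None or len(seen) < max_uniq:
            seen.append(v)
    return float(len(seen)) if seen else math.nan


# A 7x7 window flattened to 49 cells, every cell holding a different value.
window_7x7 = [float(i) for i in range(49)]
uncapped = count_distinct(window_7x7)
capped = count_distinct(window_7x7, max_uniq=25)
```

Here `uncapped` is 49.0 and `capped` is 25.0. Categorical rasters rarely have more than 25 classes in one window, so the cap is unlikely to matter for the notebook's use case, but it may be worth documenting for users who apply `'variety'` to continuous data with large kernels.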