Skip to content

Commit 7fa0fd4

Browse files
remtav and adamjstewart authored
GridGeoSampler: change stride of last patch to sample entire ROI (torchgeo#630)
* Adjust minx/miny with a smaller stride for the last sample per row/col and issue warning
* style and mypy fixes
* black test fix
* Adjust minx/miny with a smaller stride for the last sample per row/col and issue warning
* style and mypy fixes
* black test fix
* single.py: adapt gridgeosampler to sample beyond limit of ROI for a partial patch (to be padded)
  test_single.py: add tests for multiple limit cases (see issue torchgeo#448)
* format for black and flake8
* format for black and flake8
* once again, format for black and flake8
* Revert "Adjust minx/miny with a smaller stride for the last sample per row/col and issue warning"
  This reverts commit cb554c6
* adapt unit tests, remove warnings
* flake8: remove warnings import
* Address some comments
* Simplify computation of # rows/cols
* Document this new feature
* Fix size of ceiling symbol
* Simplify tests

Co-authored-by: Adam J. Stewart <[email protected]>
1 parent f41619a commit 7fa0fd4

File tree

2 files changed

+70
-13
lines changed

2 files changed

+70
-13
lines changed

tests/samplers/test_single.py

Lines changed: 29 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -182,9 +182,9 @@ def test_iter(self, sampler: GridGeoSampler) -> None:
182182
)
183183

184184
def test_len(self, sampler: GridGeoSampler) -> None:
185-
rows = ((100 - sampler.size[0]) // sampler.stride[0]) + 1
186-
cols = ((100 - sampler.size[1]) // sampler.stride[1]) + 1
187-
length = rows * cols * 2
185+
rows = math.ceil((100 - sampler.size[0]) / sampler.stride[0]) + 1
186+
cols = math.ceil((100 - sampler.size[1]) / sampler.stride[1]) + 1
187+
length = rows * cols * 2 # two items in dataset
188188
assert len(sampler) == length
189189

190190
def test_roi(self, dataset: CustomGeoDataset) -> None:
@@ -194,12 +194,35 @@ def test_roi(self, dataset: CustomGeoDataset) -> None:
194194
assert query in roi
195195

196196
def test_small_area(self) -> None:
197+
ds = CustomGeoDataset()
198+
ds.index.insert(0, (0, 1, 0, 1, 0, 1))
199+
sampler = GridGeoSampler(ds, 2, 10)
200+
assert len(sampler) == 0
201+
202+
def test_tiles_side_by_side(self) -> None:
197203
ds = CustomGeoDataset()
198204
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
199-
ds.index.insert(1, (20, 21, 20, 21, 20, 21))
205+
ds.index.insert(0, (0, 10, 10, 20, 0, 10))
200206
sampler = GridGeoSampler(ds, 2, 10)
201-
for _ in sampler:
202-
continue
207+
for bbox in sampler:
208+
assert bbox.area > 0
209+
210+
def test_integer_multiple(self) -> None:
211+
ds = CustomGeoDataset()
212+
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
213+
sampler = GridGeoSampler(ds, 10, 10, units=Units.CRS)
214+
iterator = iter(sampler)
215+
assert len(sampler) == 1
216+
assert next(iterator) == BoundingBox(0, 10, 0, 10, 0, 10)
217+
218+
def test_float_multiple(self) -> None:
219+
ds = CustomGeoDataset()
220+
ds.index.insert(0, (0, 6, 0, 5, 0, 10))
221+
sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS)
222+
iterator = iter(sampler)
223+
assert len(sampler) == 2
224+
assert next(iterator) == BoundingBox(0, 5, 0, 5, 0, 10)
225+
assert next(iterator) == BoundingBox(1, 6, 0, 5, 0, 10)
203226

204227
@pytest.mark.slow
205228
@pytest.mark.parametrize("num_workers", [0, 1, 2])

torchgeo/samplers/single.py

Lines changed: 41 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
"""TorchGeo samplers."""
55

66
import abc
7+
import math
78
from typing import Callable, Iterable, Iterator, Optional, Tuple, Union
89

910
import torch
@@ -146,7 +147,7 @@ def __len__(self) -> int:
146147

147148

148149
class GridGeoSampler(GeoSampler):
149-
"""Samples elements in a grid-like fashion.
150+
r"""Samples elements in a grid-like fashion.
150151
151152
This is particularly useful during evaluation when you want to make predictions for
152153
an entire region of interest. You want to minimize the amount of redundant
@@ -158,6 +159,21 @@ class GridGeoSampler(GeoSampler):
158159
The overlap between each chip (``chip_size - stride``) should be approximately equal
159160
to the `receptive field <https://distill.pub/2019/computing-receptive-fields/>`_ of
160161
the CNN.
162+
163+
Note that the stride of the final set of chips in each row/column may be adjusted so
164+
that the entire :term:`tile` is sampled without exceeding the bounds of the dataset.
165+
166+
Let :math:`i` be the size of the input tile. Let :math:`k` be the requested size of
167+
the output patch. Let :math:`s` be the requested stride. Let :math:`o` be the number
168+
of output rows/columns sampled from each tile. :math:`o` can then be computed as:
169+
170+
.. math::
171+
172+
o = \left\lceil \frac{i - k}{s} \right\rceil + 1
173+
174+
This is almost identical to relationship 5 in
175+
https://doi.org/10.48550/arXiv.1603.07285. However, we use ceiling instead of floor
176+
because we want to include the final remaining chip.
161177
"""
162178

163179
def __init__(
@@ -200,17 +216,23 @@ def __init__(
200216
for hit in self.index.intersection(tuple(self.roi), objects=True):
201217
bounds = BoundingBox(*hit.bounds)
202218
if (
203-
bounds.maxx - bounds.minx > self.size[1]
204-
and bounds.maxy - bounds.miny > self.size[0]
219+
bounds.maxx - bounds.minx >= self.size[1]
220+
and bounds.maxy - bounds.miny >= self.size[0]
205221
):
206222
self.hits.append(hit)
207223

208224
self.length = 0
209225
for hit in self.hits:
210226
bounds = BoundingBox(*hit.bounds)
211227

212-
rows = int((bounds.maxy - bounds.miny - self.size[0]) // self.stride[0]) + 1
213-
cols = int((bounds.maxx - bounds.minx - self.size[1]) // self.stride[1]) + 1
228+
rows = (
229+
math.ceil((bounds.maxy - bounds.miny - self.size[0]) / self.stride[0])
230+
+ 1
231+
)
232+
cols = (
233+
math.ceil((bounds.maxx - bounds.minx - self.size[1]) / self.stride[1])
234+
+ 1
235+
)
214236
self.length += rows * cols
215237

216238
def __iter__(self) -> Iterator[BoundingBox]:
@@ -223,8 +245,14 @@ def __iter__(self) -> Iterator[BoundingBox]:
223245
for hit in self.hits:
224246
bounds = BoundingBox(*hit.bounds)
225247

226-
rows = int((bounds.maxy - bounds.miny - self.size[0]) // self.stride[0]) + 1
227-
cols = int((bounds.maxx - bounds.minx - self.size[1]) // self.stride[1]) + 1
248+
rows = (
249+
math.ceil((bounds.maxy - bounds.miny - self.size[0]) / self.stride[0])
250+
+ 1
251+
)
252+
cols = (
253+
math.ceil((bounds.maxx - bounds.minx - self.size[1]) / self.stride[1])
254+
+ 1
255+
)
228256

229257
mint = bounds.mint
230258
maxt = bounds.maxt
@@ -233,11 +261,17 @@ def __iter__(self) -> Iterator[BoundingBox]:
233261
for i in range(rows):
234262
miny = bounds.miny + i * self.stride[0]
235263
maxy = miny + self.size[0]
264+
if maxy > bounds.maxy:
265+
maxy = bounds.maxy
266+
miny = bounds.maxy - self.size[0]
236267

237268
# For each column...
238269
for j in range(cols):
239270
minx = bounds.minx + j * self.stride[1]
240271
maxx = minx + self.size[1]
272+
if maxx > bounds.maxx:
273+
maxx = bounds.maxx
274+
minx = bounds.maxx - self.size[1]
241275

242276
yield BoundingBox(minx, maxx, miny, maxy, mint, maxt)
243277

0 commit comments

Comments
 (0)