@@ -8,7 +8,8 @@ import geotrellis.spark.partition.{PartitionerIndex, SpacePartitioner}
8
8
import geotrellis .spark .{MultibandTileLayerRDD , _ }
9
9
import geotrellis .util .GetComponent
10
10
import geotrellis .vector .{Extent , MultiPolygon , ProjectedExtent }
11
- import org .apache .spark .rdd .RDD
11
+ import org .apache .spark .Partitioner
12
+ import org .apache .spark .rdd .{CoGroupedRDD , RDD }
12
13
import org .slf4j .LoggerFactory
13
14
14
15
import java .time .ZonedDateTime
@@ -194,7 +195,25 @@ object DatacubeSupport {
194
195
ignoreKeysWithoutMask : Boolean = false ,
195
196
): RDD [(K , MultibandTile )] with Metadata [M ] = {
196
197
val joined = if (ignoreKeysWithoutMask) {
197
- val tmpRdd = SpatialJoin .join(datacube, mask).mapValues(v => (v._1, Option (v._2)))
198
+ // inner join, try to preserve partitioner
199
+ val tmpRdd : RDD [(K , (MultibandTile , Option [MultibandTile ]))] =
200
+ if (datacube.partitioner.isDefined && datacube.partitioner.get.isInstanceOf [SpacePartitioner [K ]]){
201
+ val part = datacube.partitioner.get.asInstanceOf [SpacePartitioner [K ]]
202
+ new CoGroupedRDD [K ](List (datacube, part(mask)), part)
203
+ .flatMapValues { case Array (l, r) =>
204
+ if (l.isEmpty) {
205
+ Seq .empty[(MultibandTile , Option [MultibandTile ])]
206
+ }
207
+ else if (r.isEmpty)
208
+ Seq .empty[(MultibandTile , Option [MultibandTile ])]
209
+ else
210
+ for (v <- l.iterator; w <- r.iterator) yield (v, Some (w))
211
+ }.asInstanceOf [RDD [(K , (MultibandTile , Option [MultibandTile ]))]]
212
+ }else {
213
+ SpatialJoin .join(datacube, mask).mapValues(v => (v._1, Option (v._2)))
214
+ }
215
+
216
+
198
217
ContextRDD (tmpRdd, datacube.metadata)
199
218
} else {
200
219
SpatialJoin .leftOuterJoin(datacube, mask)
@@ -210,7 +229,8 @@ object DatacubeSupport {
210
229
if (dataTile.bandCount == maskTile.bandCount) {
211
230
maskIndex = index
212
231
}
213
- tile.dualCombine(maskTile.band(maskIndex))((v1, v2) => if (v2 != 0 && isData(v1)) replacementInt else v1)((v1, v2) => if (v2 != 0.0 && isData(v1)) replacementDouble else v1)
232
+ // tile has to be 'mutable': for instance, ConstantTile implements dualCombine but does not correctly convert the cell type!
233
+ tile.mutable.dualCombine(maskTile.band(maskIndex))((v1, v2) => if (v2 != 0 && isData(v1)) replacementInt else v1)((v1, v2) => if (v2 != 0.0 && isData(v1)) replacementDouble else v1)
214
234
})
215
235
216
236
} else {
@@ -236,16 +256,8 @@ object DatacubeSupport {
236
256
logger.debug(s " Spacetime mask is used to reduce input. " )
237
257
}
238
258
239
- val alignedMask : MultibandTileLayerRDD [SpaceTimeKey ] =
240
- if (spacetimeMask.metadata.crs.equals(metadata.crs) && spacetimeMask.metadata.layout.equals(metadata.layout)) {
241
- spacetimeMask
242
- }else {
243
- logger.debug(s " mask: automatically resampling mask to match datacube: ${spacetimeMask.metadata}" )
244
- spacetimeMask.reproject(metadata.crs,metadata.layout,16 ,rdd.partitioner)._2
245
- }
246
-
247
- // retain only tiles where there is at least one valid pixel (mask value == 0), others will be fully removed
248
- val filtered = alignedMask.withContext{_.filter(_._2.band(0 ).toArray().exists(pixel => pixel == 0 ))}
259
+ val partitioner = rdd.partitioner
260
+ val filtered = prepareMask(spacetimeMask, metadata, partitioner)
249
261
250
262
if (pixelwiseMasking) {
251
263
val spacetimeDataContextRDD = ContextRDD (rdd, metadata)
@@ -263,4 +275,23 @@ object DatacubeSupport {
263
275
rdd
264
276
}
265
277
}
278
+
279
/**
 * Aligns a spacetime mask with the datacube grid and drops mask tiles that cannot contribute.
 *
 * The mask is reprojected onto `metadata.crs` / `metadata.layout` unless it already matches both.
 * After alignment, a mask tile is kept only when its key lies inside the datacube bounds and its
 * first band has at least one pixel equal to 0 (a valid, unmasked pixel).
 *
 * @param spacetimeMask mask layer to prepare
 * @param metadata      metadata of the datacube the mask will be applied to
 * @param partitioner   optional partitioner passed on to `reproject` when resampling is required
 * @return the aligned and filtered mask; it contains no tiles when the datacube bounds are empty
 */
def prepareMask(spacetimeMask: MultibandTileLayerRDD[SpaceTimeKey], metadata: TileLayerMetadata[SpaceTimeKey], partitioner: Option[Partitioner]): ContextRDD[SpaceTimeKey, MultibandTile, TileLayerMetadata[SpaceTimeKey]] = {
  val maskMetadata = spacetimeMask.metadata
  val alreadyAligned = maskMetadata.crs.equals(metadata.crs) && maskMetadata.layout.equals(metadata.layout)

  val alignedMask: MultibandTileLayerRDD[SpaceTimeKey] =
    if (alreadyAligned) {
      spacetimeMask
    } else {
      logger.debug(s"mask: automatically resampling mask to match datacube: ${spacetimeMask.metadata}")
      spacetimeMask.reproject(metadata.crs, metadata.layout, 16, partitioner)._2
    }

  // Query the Bounds value directly instead of calling `.get`: `get` throws on empty bounds,
  // whereas `includes` simply answers false for every key, which yields an empty mask.
  val cubeBounds = metadata.bounds

  // retain only tiles where there is at least one valid pixel (mask value == 0), others will be fully removed
  alignedMask.withContext {
    _.filter { case (key, maskTile) =>
      cubeBounds.includes(key) && maskTile.band(0).toArray().exists(_ == 0)
    }
  }
}
266
297
}
0 commit comments