|
22 | 22 | package org.locationtech.rasterframes.expressions.localops
|
23 | 23 |
|
24 | 24 | import geotrellis.raster.Tile
|
25 |
| -import geotrellis.raster.resample.NearestNeighbor |
| 25 | +import geotrellis.raster.resample._ |
| 26 | +import geotrellis.raster.resample.{ResampleMethod ⇒ GTResampleMethod, Max ⇒ RMax, Min ⇒ RMin} |
26 | 27 | import org.apache.spark.sql.Column
|
27 | 28 | import org.apache.spark.sql.catalyst.InternalRow
|
| 29 | +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult |
| 30 | +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} |
28 | 31 | import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
|
29 |
| -import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription} |
| 32 | +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, Literal, TernaryExpression} |
30 | 33 | import org.apache.spark.sql.functions.lit
|
31 |
| -import org.locationtech.rasterframes.expressions.BinaryLocalRasterOp |
32 |
| -import org.locationtech.rasterframes.expressions.DynamicExtractors.tileExtractor |
| 34 | +import org.apache.spark.sql.rf.TileUDT |
| 35 | +import org.apache.spark.sql.types.{DataType, StringType} |
| 36 | +import org.apache.spark.unsafe.types.UTF8String |
| 37 | +import org.locationtech.rasterframes.util.ResampleMethod |
| 38 | +import org.locationtech.rasterframes.encoders.CatalystSerializer._ |
| 39 | +import org.locationtech.rasterframes.expressions.{fpTile, row} |
| 40 | +import org.locationtech.rasterframes.expressions.DynamicExtractors._ |
| 41 | + |
| 42 | + |
/**
 * Shared implementation behind the tile resampling expressions.
 *
 * Takes three child expressions:
 *  - `left`: the tile to resample;
 *  - `right`: the target, either a tile whose dimensions should be matched or a numeric
 *    scale factor;
 *  - `method`: a string expression naming the resampling method.
 *
 * The result has the same Catalyst data type as `left`.
 */
abstract class ResampleBase(left: Expression, right: Expression, method: Expression)
  extends TernaryExpression
    with CodegenFallback with Serializable {

  override val nodeName: String = "rf_resample"

  override def dataType: DataType = left.dataType
  override def children: Seq[Expression] = Seq(left, right, method)

  /**
   * Returns `t` unchanged for nearest-neighbor, mode, max, min and sum resampling; for every
   * other method the tile is first passed through `fpTile`.
   */
  def targetFloatIfNeeded(t: Tile, method: GTResampleMethod): Tile =
    method match {
      case NearestNeighbor | Mode | RMax | RMin | Sum => t
      case _ => fpTile(t)
    }

  // These methods define the core algorithms to be used.

  /** Resamples `left` to the column and row counts of `right`. */
  def op(left: Tile, right: Tile, method: GTResampleMethod): Tile =
    op(left, right.cols, right.rows, method)

  /**
   * Resamples `left` by the scale factor `right`. The target dimensions are the current
   * dimensions multiplied by the factor and truncated with `toInt`.
   */
  def op(left: Tile, right: Double, method: GTResampleMethod): Tile =
    // NOTE(review): a small (or non-positive) factor truncates to a zero or negative
    // dimension here; confirm how the underlying resample call handles that.
    op(left, (left.cols * right).toInt, (left.rows * right).toInt, method)

  /** Resamples `tile` to exactly `newCols` x `newRows` cells using `method`. */
  def op(tile: Tile, newCols: Int, newRows: Int, method: GTResampleMethod): Tile =
    targetFloatIfNeeded(tile, method).resample(newCols, newRows, method)

  /**
   * Accepts the inputs when `left` is a raster type, `right` is a raster or numeric type,
   * and `method` is a string.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    // copypasta from BinaryLocalRasterOp
    if (!tileExtractor.isDefinedAt(left.dataType)) {
      TypeCheckFailure(s"Input type '${left.dataType}' does not conform to a raster type.")
    }
    else if (!tileOrNumberExtractor.isDefinedAt(right.dataType)) {
      TypeCheckFailure(s"Input type '${right.dataType}' does not conform to a compatible type.")
    }
    else method.dataType match {
      case StringType => TypeCheckSuccess
      case _ => TypeCheckFailure(s"Cannot interpret value of type `${method.dataType.simpleString}` for resampling method; please provide a String method name.")
    }
  }

  /**
   * Resamples the tile in `input1` against the target in `input2` using the method named by
   * `input3`. Only called from `eval` once all three values are known to be non-null.
   *
   * @throws IllegalArgumentException if the method name is not recognized by `ResampleMethod`
   */
  override def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = {
    // more copypasta from BinaryLocalRasterOp
    implicit val tileSer = TileUDT.tileSerializer

    val (leftTile, leftCtx) = tileExtractor(left.dataType)(row(input1))
    val methodString = input3.asInstanceOf[UTF8String].toString

    val resamplingMethod = methodString match {
      case ResampleMethod(mm) => mm
      // Report the offending value so a failing job can be diagnosed from the message alone.
      case _ => throw new IllegalArgumentException(
        s"Unrecognized resampling method specified: '$methodString'")
    }

    val result: Tile = tileOrNumberExtractor(right.dataType)(input2) match {
      // in this case we expect the left and right contexts to vary. no warnings raised.
      case TileArg(rightTile, _) => op(leftTile, rightTile, resamplingMethod)
      case DoubleArg(d) => op(leftTile, d, resamplingMethod)
      case IntegerArg(i) => op(leftTile, i.toDouble, resamplingMethod)
    }

    // reassemble the leftTile with its context. Note that this operation does not change Extent and CRS
    leftCtx match {
      case Some(ctx) => ctx.toProjectRasterTile(result).toInternalRow
      case None => result.toInternalRow
    }
  }

  /**
   * Evaluates all three children and applies the null rules:
   * a null method or a null left tile yields null; a null target yields the left value
   * unchanged; otherwise the tile is resampled via `nullSafeEval`.
   */
  override def eval(input: InternalRow): Any = {
    if (input == null) null
    else {
      val l = left.eval(input)
      val r = right.eval(input)
      val m = method.eval(input)
      if (m == null) null // no method, return null
      else if (l == null) null // no l tile, return null
      else if (r == null) l // no target tile or factor, return l without changing it
      else nullSafeEval(l, r, m)
    }
  }

}
|
68 |
| -object Resample{ |
69 |
| - def apply(left: Column, right: Column): Column = |
70 |
| - new Column(Resample(left.expr, right.expr)) |
| 120 | + |
/**
 * Resamples a tile to new dimensions using a caller-selected resampling method.
 *
 * @param left   expression producing the tile to resample
 * @param factor expression producing either a numeric scale factor or a tile whose
 *               dimensions should be matched
 * @param method string expression naming the resampling method
 */
@ExpressionDescription(
  usage = "_FUNC_(tile, factor, method_name) - Resample tile to different dimension based on scalar `factor` or a tile whose dimension to match. Scalar less than one will downsample tile; greater than one will upsample. Uses resampling method named in the `method_name`. " +
    "Methods average, mode, median, max, min, and sum aggregate over cells when downsampling.",
  arguments = """
    Arguments:
      * tile - tile
      * factor - scalar or tile to match dimension
      * method_name - one of the following options: nearest_neighbor, bilinear, cubic_convolution, cubic_spline, lanczos, average, mode, median, max, min, sum
          This option can be CamelCase as well
""",
  examples = """
    Examples:
      > SELECT _FUNC_(tile, 0.2, 'median');
       ...
      > SELECT _FUNC_(tile1, tile2, 'cubic_spline');
       ..."""
)
case class Resample(left: Expression, factor: Expression, method: Expression)
  extends ResampleBase(left, factor, method)
| 140 | + |
/** Column-level constructors for the [[Resample]] expression. */
object Resample {

  /** Builds the [[Resample]] expression over the given children and wraps it in a `Column`. */
  private def resampled(tile: Expression, target: Expression, how: Expression): Column =
    new Column(Resample(tile, target, how))

  /** Resample `left` against the tile or factor column `right`, method given by name. */
  def apply(left: Column, right: Column, methodName: String): Column =
    resampled(left.expr, right.expr, lit(methodName).expr)

  /** Resample `left` against the tile or factor column `right`, method given as a column. */
  def apply(left: Column, right: Column, method: Column): Column =
    resampled(left.expr, right.expr, method.expr)

  /** Resample `left` by the literal numeric factor `right`, method given by name. */
  def apply[N: Numeric](left: Column, right: N, method: String): Column =
    resampled(left.expr, lit(right).expr, lit(method).expr)

  /** Resample `left` by the literal numeric factor `right`, method given as a column. */
  def apply[N: Numeric](left: Column, right: N, method: Column): Column =
    resampled(left.expr, lit(right).expr, method.expr)

}
| 153 | + |
/**
 * Two-argument resampling expression: a [[ResampleBase]] whose method argument is fixed to
 * the string literal "nearest".
 *
 * @param tile   expression producing the tile to resample
 * @param target expression producing either a numeric scale factor or a tile whose
 *               dimensions should be matched
 */
@ExpressionDescription(
  usage = "_FUNC_(tile, factor) - Resample tile to different size based on scalar factor or tile whose dimension to match. Scalar less than one will downsample tile; greater than one will upsample. Uses nearest-neighbor value.",
  arguments = """
    Arguments:
      * tile - tile
      * factor - scalar or tile to match dimension""",
  examples = """
    Examples:
      > SELECT _FUNC_(tile, 2.0);
       ...
      > SELECT _FUNC_(tile1, tile2);
       ...""")
case class ResampleNearest(tile: Expression, target: Expression)
  extends ResampleBase(tile, target, Literal("nearest")) {
  override val nodeName: String = "rf_resample_nearest"
}
/** Column-level constructors for the [[ResampleNearest]] expression. */
object ResampleNearest {

  /** Builds the [[ResampleNearest]] expression over the given children and wraps it in a `Column`. */
  private def nearest(source: Expression, goal: Expression): Column =
    new Column(ResampleNearest(source, goal))

  /** Resample `tile` against the tile or factor column `target`. */
  def apply(tile: Column, target: Column): Column =
    nearest(tile.expr, target.expr)

  /** Resample `tile` by the literal numeric factor `value`. */
  def apply[N: Numeric](tile: Column, value: N): Column =
    nearest(tile.expr, lit(value).expr)
}
|
| 177 | + |
| 178 | + |
0 commit comments