From a5c651c7cc6c87923a0ce1e1b0edd8ce53936cb3 Mon Sep 17 00:00:00 2001 From: "Jason T. Brown" Date: Wed, 6 Nov 2019 13:41:56 -0500 Subject: [PATCH 01/24] Initial implementation of extracting bit values, e.g. for a quality band Signed-off-by: Jason T. Brown --- .../rasterframes/RasterFunctions.scala | 9 ++ .../rasterframes/expressions/package.scala | 2 + .../transformers/ExtractBits.scala | 94 +++++++++++++++++++ .../rasterframes/RasterFunctionsSpec.scala | 57 +++++++++++ 4 files changed, 162 insertions(+) create mode 100644 core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtractBits.scala diff --git a/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala b/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala index 4301b188f..4d90203c0 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala @@ -353,6 +353,15 @@ trait RasterFunctions { def rf_inverse_mask_by_value(sourceTile: Column, maskTile: Column, maskValue: Int): TypedColumn[Any, Tile] = Mask.InverseMaskByValue(sourceTile, maskTile, lit(maskValue)) + def rf_local_extract_bits(tile: Column, startBit: Column, numBits: Column): Column = + ExtractBits(tile, startBit, numBits) + def rf_local_extract_bits(tile: Column, startBit: Column): Column = + rf_local_extract_bits(tile, startBit, lit(1)) + def rf_local_extract_bits(tile: Column, startBit: Int, numBits: Int): Column = + rf_local_extract_bits(tile, lit(startBit), lit(numBits)) + def rf_local_extract_bits(tile: Column, startBit: Int): Column = + rf_local_extract_bits(tile, lit(startBit)) + /** Create a tile where cells in the grid defined by cols, rows, and bounds are filled with the given value. */ def rf_rasterize(geometry: Column, bounds: Column, value: Column, cols: Int, rows: Int): TypedColumn[Any, Tile] = withTypedAlias("rf_rasterize", geometry)( diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala index d289242bc..873fea004 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala @@ -137,5 +137,7 @@ package object expressions { registry.registerExpression[XZ2Indexer]("rf_spatial_index") registry.registerExpression[transformers.ReprojectGeometry]("st_reproject") + + registry.registerExpression[ExtractBits]("rf_local_extract_bits") } } diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtractBits.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtractBits.scala new file mode 100644 index 000000000..45e2db88c --- /dev/null +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtractBits.scala @@ -0,0 +1,94 @@ +/* + * This software is licensed under the Apache 2 license, quoted below. + * + * Copyright 2019 Astraea, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * [http://www.apache.org/licenses/LICENSE-2.0] + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package org.locationtech.rasterframes.expressions.transformers
+
+import geotrellis.raster.Tile
+import org.apache.spark.sql.{Column, TypedColumn}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, TernaryExpression}
+import org.apache.spark.sql.rf.TileUDT
+import org.apache.spark.sql.types.DataType
+import org.locationtech.rasterframes.encoders.CatalystSerializer._
+import org.locationtech.rasterframes.expressions.DynamicExtractors._
+import org.locationtech.rasterframes.expressions._
+
+@ExpressionDescription(
+  usage = "_FUNC_(tile, start_bit, num_bits) - In each cell of `tile`, extract `num_bits` bits from the cell value, starting at `start_bit` counted from the right (the least significant bit).",
+  arguments = """
+  Arguments:
+    * tile - tile column from which to extract bit values
+    * start_bit - zero-based position of the first bit to extract, counting from the right
+    * num_bits - number of bits to extract, moving left from `start_bit`
+  """,
+  examples = """
+  Examples:
+    > SELECT _FUNC_(tile, 4, 2)
+       ..."""
+)
+case class ExtractBits(child1: Expression, child2: Expression, child3: Expression) extends TernaryExpression with CodegenFallback with Serializable {
+  override val nodeName: String = "rf_local_extract_bits"
+
+  override def children: Seq[Expression] = Seq(child1, child2, child3)
+
+  override def dataType: DataType = child1.dataType
+
+  override def checkInputDataTypes(): TypeCheckResult =
+    if(!tileExtractor.isDefinedAt(child1.dataType)) {
+      TypeCheckFailure(s"Input type '${child1.dataType}' does not conform to a raster type.")
+    } else if (!intArgExtractor.isDefinedAt(child2.dataType)) {
+      TypeCheckFailure(s"Input type '${child2.dataType}' isn't an integral type.")
+    } else if (!intArgExtractor.isDefinedAt(child3.dataType)) {
+      TypeCheckFailure(s"Input type '${child3.dataType}' isn't an integral type.")
+    } else TypeCheckSuccess
+
+
+  override protected def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = {
+    implicit val tileSer = TileUDT.tileSerializer
+    val (childTile, childCtx) = tileExtractor(child1.dataType)(row(input1))
+
+    val startBits = intArgExtractor(child2.dataType)(input2).value.toShort
+
+    val numBits = intArgExtractor(child3.dataType)(input3).value.toShort
+
+    childCtx match {
+      case Some(ctx) => ctx.toProjectRasterTile(op(childTile, startBits, numBits)).toInternalRow
+      case None => op(childTile, startBits, numBits).toInternalRow
+    }
+  }
+
+  protected def op(tile: Tile, startBit: Short, numBits: Short): Tile = {
+    // this is the last `numBits` positions of "111111111111111"
+    val widthMask = Short.MaxValue >> (15 - numBits)
+    // map preserving the nodata structure
+    tile.mapIfSet(x ⇒ x >> startBit & widthMask)
+  }
+
+}
+
+object ExtractBits{
+  def apply(tile: Column, startBit: Column, numBits: Column): Column =
+    new Column(ExtractBits(tile.expr, startBit.expr, numBits.expr))
+
+}
+
diff --git a/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala
index dd33d80f6..bfa01b9c4 100644
--- a/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala
+++ b/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala
@@ -1020,4 +1020,61 @@ class RasterFunctionsSpec
extends TestEnvironment with RasterMatchers { // lazy val invalid = df.select(rf_local_is_in($"t", lit("foobar"))).as[Tile].first() } + + it("should unpack QA bits"){ + // Sample of https://www.usgs.gov/media/images/landsat-8-quality-assessment-band-pixel-value-interpretations + val fill = 1 + val clear = 2720 + val cirrus = 6816 + val med_cloud = 2756 // with 1-2 bands saturated + val tiles = Seq(fill, clear, cirrus, med_cloud).map(v ⇒ + TestData.projectedRasterTile(3, 3, v, TestData.extent, TestData.crs, UShortCellType)) + + val df = tiles.toDF("tile") + .withColumn("val", rf_tile_min($"tile")) + + + val result = df + .withColumn("qa_fill", rf_local_extract_bits($"tile", lit(0))) + .withColumn("qa_sat", rf_local_extract_bits($"tile", lit(2), lit(2))) + .withColumn("qa_cloud", rf_local_extract_bits($"tile", lit(4))) + .withColumn("qa_cconf", rf_local_extract_bits($"tile", 5, 2)) + .withColumn("qa_snow", rf_local_extract_bits($"tile", lit(9), lit(2))) + .withColumn("qa_circonf", rf_local_extract_bits($"tile", 11, 2)) + + + def checker(colName: String, valFilter: Int, assertValue: Int): Unit = { + // print this so we can see what's happening if something wrong + println(s"${colName} should be ${assertValue} for qa val ${valFilter}") + result.filter($"val" === lit(valFilter)) + .select(col(colName)) + .as[ProjectedRasterTile] + .first() + .get(0, 0) should be (assertValue) + } + + checker("qa_fill", fill, 1) + checker("qa_cloud", fill, 0) + checker("qa_cconf", fill, 0) + checker("qa_sat", fill, 0) + checker("qa_snow", fill, 0) + checker("qa_circonf", fill, 0) + + checker("qa_fill", clear, 0) + checker("qa_cloud", clear, 0) + checker("qa_cconf", clear, 1) + + checker("qa_fill", med_cloud, 0) + checker("qa_cloud", med_cloud, 0) + checker("qa_cconf", med_cloud, 2) // L8 only tags hi conf in the cloud assessment + checker("qa_sat", med_cloud, 1) + + checker("qa_fill", cirrus, 0) + checker("qa_sat", cirrus, 0) + checker("qa_cloud", cirrus, 0) //low cloud conf + checker("qa_cconf", cirrus, 1) //low cloud conf + checker("qa_circonf", cirrus, 3) //high cirrus conf + + + } } From bebf00fef7e7fddeba769431867372f28308d020 Mon Sep 17 00:00:00 2001 From: "Jason T. Brown" Date: Mon, 11 Nov 2019 10:34:42 -0500 Subject: [PATCH 02/24] WIP: mask bits Signed-off-by: Jason T. Brown --- .../rasterframes/RasterFunctions.scala | 34 ++++++++++++++++--- .../rasterframes/RasterFunctionsSpec.scala | 9 +++-- 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala b/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala index 4d90203c0..4034e9b58 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala @@ -353,14 +353,40 @@ trait RasterFunctions { def rf_inverse_mask_by_value(sourceTile: Column, maskTile: Column, maskValue: Int): TypedColumn[Any, Tile] = Mask.InverseMaskByValue(sourceTile, maskTile, lit(maskValue)) + def rf_mask_by_bit(dataTile: Column, maskTile: Column, bitPosition: Int, valueToMask: Boolean): Column = + rf_mask_by_bit(dataTile, maskTile, lit(bitPosition), lit(bitPosition)) + + def rf_mask_by_bit(dataTile: Column, maskTile: Column, bitPosition: Column, valueToMask: Column): Column = ??? + + def rf_mask_by_bits(dataTile: Column, maskTile: Column, startBit: Column, numBits: Column, valuesToMask: Column): Column = ??? 
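+  // TODO (sketch of intended behavior): extract the selected bits from `maskTile` with
+  // rf_local_extract_bits, then set matching cells to NoData via rf_mask_by_values;
+  // stubbed with ??? in this WIP commit.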
+
+  def rf_mask_by_bits(dataTile: Column, maskTile: Column, startBit: Int, numBits: Int, valuesToMask: Int*): Column = {
+    import org.apache.spark.sql.functions.array
+    val values = array(valuesToMask.map(lit):_*)
+    rf_mask_by_bits(dataTile, maskTile, lit(startBit), lit(numBits), values)
+  }
+
+  /** Extract value from specified bits of the cells' underlying binary data.
+    * `startBit` is the first bit to consider, working from the right. It is zero indexed.
+    * `numBits` is the number of bits to take, moving further to the left. */
   def rf_local_extract_bits(tile: Column, startBit: Column, numBits: Column): Column =
     ExtractBits(tile, startBit, numBits)
-  def rf_local_extract_bits(tile: Column, startBit: Column): Column =
-    rf_local_extract_bits(tile, startBit, lit(1))
+
+  /** Extract value from specified bits of the cells' underlying binary data.
+    * `bitPosition` is the bit to consider, working from the right. It is zero indexed. */
+  def rf_local_extract_bits(tile: Column, bitPosition: Column): Column =
+    rf_local_extract_bits(tile, bitPosition, lit(1))
+
+  /** Extract value from specified bits of the cells' underlying binary data.
+    * `startBit` is the first bit to consider, working from the right. It is zero indexed.
+    * `numBits` is the number of bits to take, moving further to the left. */
   def rf_local_extract_bits(tile: Column, startBit: Int, numBits: Int): Column =
     rf_local_extract_bits(tile, lit(startBit), lit(numBits))
-  def rf_local_extract_bits(tile: Column, startBit: Int): Column =
-    rf_local_extract_bits(tile, lit(startBit))
+
+  /** Extract value from specified bits of the cells' underlying binary data.
+    * `bitPosition` is the bit to consider, working from the right. It is zero indexed. */
+  def rf_local_extract_bits(tile: Column, bitPosition: Int): Column =
+    rf_local_extract_bits(tile, lit(bitPosition))
 
   /** Create a tile where cells in the grid defined by cols, rows, and bounds are filled with the given value.
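     * For example (a sketch): `rf_rasterize($"geom", $"bounds", lit(1), 256, 256)` burns the value 1 into a 256x256 grid where `geom` intersects `bounds`.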
*/ def rf_rasterize(geometry: Column, bounds: Column, value: Column, cols: Int, rows: Int): TypedColumn[Any, Tile] = diff --git a/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala index bfa01b9c4..4c1134828 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala @@ -1022,6 +1022,8 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers { } it("should unpack QA bits"){ + checkDocs("rf_local_extract_bits") + // Sample of https://www.usgs.gov/media/images/landsat-8-quality-assessment-band-pixel-value-interpretations val fill = 1 val clear = 2720 @@ -1033,7 +1035,6 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers { val df = tiles.toDF("tile") .withColumn("val", rf_tile_min($"tile")) - val result = df .withColumn("qa_fill", rf_local_extract_bits($"tile", lit(0))) .withColumn("qa_sat", rf_local_extract_bits($"tile", lit(2), lit(2))) @@ -1042,7 +1043,6 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers { .withColumn("qa_snow", rf_local_extract_bits($"tile", lit(9), lit(2))) .withColumn("qa_circonf", rf_local_extract_bits($"tile", 11, 2)) - def checker(colName: String, valFilter: Int, assertValue: Int): Unit = { // print this so we can see what's happening if something wrong println(s"${colName} should be ${assertValue} for qa val ${valFilter}") @@ -1060,6 +1060,11 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers { checker("qa_snow", fill, 0) checker("qa_circonf", fill, 0) + // trivial bits selection (numBits=0) and SQL + df.filter($"val" === lit(fill)) + .selectExpr("rf_local_extract_bits(tile, 0, 0) AS t") + .select(rf_exists($"t")).as[Boolean].first() should be (false) + checker("qa_fill", clear, 0) checker("qa_cloud", clear, 0) checker("qa_cconf", clear, 1) From b45543b58858cd67150a36e6d32b7d68650a1021 Mon Sep 17 00:00:00 2001 From: "Jason T. Brown" Date: Mon, 11 Nov 2019 16:57:03 -0500 Subject: [PATCH 03/24] Update docs and scala function api Signed-off-by: Jason T. Brown --- .../rasterframes/RasterFunctions.scala | 15 +++++-- .../transformers/ExtractBits.scala | 19 ++++----- docs/src/main/paradox/reference.md | 39 +++++++++++++++++-- docs/src/main/paradox/release-notes.md | 4 ++ 4 files changed, 61 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala b/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala index f8ed154d5..7142e8b6c 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala @@ -337,13 +337,22 @@ trait RasterFunctions { def rf_inverse_mask_by_value(sourceTile: Column, maskTile: Column, maskValue: Int): TypedColumn[Any, Tile] = Mask.InverseMaskByValue(sourceTile, maskTile, lit(maskValue)) + /** Applies a mask using bit values in the `mask_tile`. Working from the right, extract the bit at `bitPosition` from the `maskTile`. In all locations where these are equal to the `valueToMask`, the returned tile is set to NoData, else the original `dataTile` cell value. 
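+    * For example (a sketch; in the Landsat 8 BQA band, bit 4 flags cloud): `rf_mask_by_bit($"tile", $"qa", 4, true)` sets cloudy cells to NoData.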
*/ def rf_mask_by_bit(dataTile: Column, maskTile: Column, bitPosition: Int, valueToMask: Boolean): Column = - rf_mask_by_bit(dataTile, maskTile, lit(bitPosition), lit(bitPosition)) + rf_mask_by_bit(dataTile, maskTile, lit(bitPosition), lit(valueToMask)) - def rf_mask_by_bit(dataTile: Column, maskTile: Column, bitPosition: Column, valueToMask: Column): Column = ??? + /** Applies a mask using bit values in the `mask_tile`. Working from the right, extract the bit at `bitPosition` from the `maskTile`. In all locations where these are equal to the `valueToMask`, the returned tile is set to NoData, else the original `dataTile` cell value. */ + def rf_mask_by_bit(dataTile: Column, maskTile: Column, bitPosition: Column, valueToMask: Column): Column = + rf_mask_by_bits(dataTile, maskTile, bitPosition, lit(1), valueToMask) + + /** Applies a mask from blacklisted bit values in the `mask_tile`. Working from the right, the bits from `start_bit` to `start_bit + num_bits` are @ref:[extracted](reference.md#rf_local_extract_bits) from cell values of the `mask_tile`. In all locations where these are in the `mask_values`, the returned tile is set to NoData; otherwise the original `tile` cell value is returned. */ + def rf_mask_by_bits(dataTile: Column, maskTile: Column, startBit: Column, numBits: Column, valuesToMask: Column): Column = { + val bitMask = rf_local_extract_bits(maskTile, startBit, numBits) + rf_mask_by_values(dataTile, bitMask, valuesToMask) + } - def rf_mask_by_bits(dataTile: Column, maskTile: Column, startBit: Column, numBits: Column, valuesToMask: Column): Column = ??? + /** Applies a mask from blacklisted bit values in the `mask_tile`. Working from the right, the bits from `start_bit` to `start_bit + num_bits` are @ref:[extracted](reference.md#rf_local_extract_bits) from cell values of the `mask_tile`. In all locations where these are in the `mask_values`, the returned tile is set to NoData; otherwise the original `tile` cell value is returned. 
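+    * For example (a sketch; bits 5-6 of the Landsat 8 BQA band hold cloud confidence): `rf_mask_by_bits($"tile", $"qa", 5, 2, 2, 3)` masks cells with medium or high cloud confidence.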
*/
   def rf_mask_by_bits(dataTile: Column, maskTile: Column, startBit: Int, numBits: Int, valuesToMask: Int*): Column = {
     import org.apache.spark.sql.functions.array
     val values = array(valuesToMask.map(lit):_*)
diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtractBits.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtractBits.scala
index 45e2db88c..3219fc9b3 100644
--- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtractBits.scala
+++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtractBits.scala
@@ -67,9 +67,9 @@ case class ExtractBits(child1: Expression, child2: Expression, child3: Expressio
     implicit val tileSer = TileUDT.tileSerializer
     val (childTile, childCtx) = tileExtractor(child1.dataType)(row(input1))
 
-    val startBits = intArgExtractor(child2.dataType)(input2).value.toShort
+    val startBits = intArgExtractor(child2.dataType)(input2).value
 
-    val numBits = intArgExtractor(child3.dataType)(input3).value.toShort
+    val numBits = intArgExtractor(child3.dataType)(input3).value
 
     childCtx match {
       case Some(ctx) => ctx.toProjectRasterTile(op(childTile, startBits, numBits)).toInternalRow
@@ -77,12 +77,7 @@ case class ExtractBits(child1: Expression, child2: Expression, child3: Expressio
     }
   }
 
-  protected def op(tile: Tile, startBit: Short, numBits: Short): Tile = {
-    // this is the last `numBits` positions of "111111111111111"
-    val widthMask = Short.MaxValue >> (15 - numBits)
-    // map preserving the nodata structure
-    tile.mapIfSet(x ⇒ x >> startBit & widthMask)
-  }
+  protected def op(tile: Tile, startBit: Int, numBits: Int): Tile = ExtractBits(tile, startBit, numBits)
 
 }
 
@@ -90,5 +85,11 @@ object ExtractBits{
   def apply(tile: Column, startBit: Column, numBits: Column): Column =
     new Column(ExtractBits(tile.expr, startBit.expr, numBits.expr))
 
-}
+  def apply(tile: Tile, startBit: Int, numBits: Int): Tile = {
+    // this is the last `numBits` positions of a 31-bit run of ones
+    val widthMask = Int.MaxValue >> (31 - numBits)
+    // map preserving the nodata structure
+    tile.mapIfSet(x ⇒ x >> startBit & widthMask)
+  }
+}
diff --git a/docs/src/main/paradox/reference.md b/docs/src/main/paradox/reference.md
index 1121bbd36..aef211035 100644
--- a/docs/src/main/paradox/reference.md
+++ b/docs/src/main/paradox/reference.md
@@ -167,7 +167,6 @@ Tile rf_make_ones_tile(Int tile_columns, Int tile_rows, [CellType cell_type])
     Tile rf_make_ones_tile(Int tile_columns, Int tile_rows, [String cell_type_name])
 
-
 Create a `tile` of shape `tile_columns` by `tile_rows` full of ones, with the optional cell type; default is float64. See @ref:[this discussion](nodata-handling.md#cell-types) on cell types for info on the `cell_type` argument. All arguments are literal values and not column expressions.
 
 ### rf_make_constant_tile
 
     Tile rf_make_constant_tile(Numeric constant, Int tile_columns, Int tile_rows, [CellType cell_type])
     Tile rf_make_constant_tile(Numeric constant, Int tile_columns, Int tile_rows, [String cell_type_name])
 
-
 Create a `tile` of shape `tile_columns` by `tile_rows` full of `constant`, with the optional cell type; default is float64. See @ref:[this discussion](nodata-handling.md#cell-types) on cell types for info on the `cell_type` argument. All arguments are literal values and not column expressions.
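+
+For example, a minimal sketch (assuming `rf_make_constant_tile` is registered among the session's SQL functions):
+
+```sql
+SELECT rf_make_constant_tile(1, 3, 3, 'uint16') AS tile
+```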
@@ -183,7 +181,6 @@
 
 ### rf_rasterize
 
     Tile rf_rasterize(Geometry geom, Geometry tile_bounds, Int value, Int tile_columns, Int tile_rows)
 
-
 Convert a vector Geometry `geom` into a Tile representation. The `value` will be "burned-in" to the returned `tile` where the `geom` intersects the `tile_bounds`. Returned `tile` will have shape `tile_columns` by `tile_rows`. Values outside the `geom` will be assigned a NoData value. Returned `tile` has cell type `int32`; note that `value` is of type Int.
 
 Parameters `tile_columns` and `tile_rows` are literals, not column expressions. The others are column expressions.
 
@@ -236,10 +233,35 @@ Generate a `tile` with the values from `data_tile`, with NoData in cells where t
 ### rf_mask_by_values
 
     Tile rf_mask_by_values(Tile data_tile, Tile mask_tile, Array mask_values)
-    Tile rf_mask_by_values(Tile data_tile, Tile mask_tile, seq mask_values)
+    Tile rf_mask_by_values(Tile data_tile, Tile mask_tile, list mask_values)
 
 Generate a `tile` with the values from `data_tile`, with NoData in cells where the `mask_tile` is in the `mask_values` Array or list. `mask_values` can be a [`pyspark.sql.ArrayType`][Array] or a `list`.
 
+### rf_mask_by_bit
+
+    Tile rf_mask_by_bit(Tile tile, Tile mask_tile, Int bit_position, Bool mask_value)
+
+Applies a mask using bit values in the `mask_tile`. Working from the right, the bit at `bit_position` is @ref:[extracted](reference.md#rf_local_extract_bits) from cell values of the `mask_tile`. In all locations where these are equal to the `mask_value`, the returned tile is set to NoData; otherwise the original `tile` cell value is returned.
+
+This is a single-bit version of @ref:[`rf_mask_by_bits`](reference.md#rf-mask-by-bits).
+
+### rf_mask_by_bits
+
+    Tile rf_mask_by_bits(Tile tile, Tile mask_tile, Int start_bit, Int num_bits, Array mask_values)
+    Tile rf_mask_by_bits(Tile tile, Tile mask_tile, Int start_bit, Int num_bits, list mask_values)
+
+Applies a mask from blacklisted bit values in the `mask_tile`. Working from the right, the bits from `start_bit` to `start_bit + num_bits` are @ref:[extracted](reference.md#rf_local_extract_bits) from cell values of the `mask_tile`. In all locations where these are in the `mask_values`, the returned tile is set to NoData; otherwise the original `tile` cell value is returned.
+
+This function is not available in the SQL API. The following expression is equivalent:
+
+```sql
+SELECT rf_mask_by_values(
+   tile,
+   rf_local_extract_bits(mask_tile, start_bit, num_bits),
+   mask_values
+)
+```
+
 ### rf_inverse_mask
 
     Tile rf_inverse_mask(Tile tile, Tile mask)
 
@@ -406,6 +428,15 @@ Returns a `tile` column containing the element-wise inequality of `tile1` and `r
 
 Returns a `tile` column with cell values of 1 where the `tile` cell value is in the provided array or list. The `array` is a Spark SQL [Array][Array]. A python `list` of numeric values can also be passed.
 
+### rf_local_extract_bits
+
+    Tile rf_local_extract_bits(Tile tile, Int start_bit, Int num_bits)
+    Tile rf_local_extract_bits(Tile tile, Int start_bit)
+
+Extract value from specified bits of the cells' underlying binary data. Working from the right, the bits from `start_bit` to `start_bit + num_bits` are extracted from cell values of the `tile`. The `start_bit` is zero indexed. If `num_bits` is not provided, a single bit is extracted.
+
+A common use case for this function is covered by @ref:[`rf_mask_by_bits`](reference.md#rf-mask-by-bits).
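+
+For example, bits 5-6 of the Landsat 8 BQA band encode cloud confidence; a sketch of decoding them (the `l8_qa` table name is hypothetical):
+
+```sql
+SELECT rf_local_extract_bits(tile, 5, 2) AS cloud_conf FROM l8_qa
+```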
+ ### rf_round Tile rf_round(Tile tile) diff --git a/docs/src/main/paradox/release-notes.md b/docs/src/main/paradox/release-notes.md index c181d55da..f8007ab18 100644 --- a/docs/src/main/paradox/release-notes.md +++ b/docs/src/main/paradox/release-notes.md @@ -2,6 +2,10 @@ ## 0.8.x +### 0.8.5 + +* Add `rf_mask_by_bit`, `rf_mask_by_bits` and `rf_local_extract_bits` to deal with bit packed quality masks. Updated the masking documentation to demonstrate the use of these functions. + ### 0.8.4 * Upgraded to Spark 2.4.4 From d150dbc8f021b83f3d1dfdd2b7c62cf88877fec8 Mon Sep 17 00:00:00 2001 From: "Simeon H.K. Fitch" Date: Wed, 13 Nov 2019 10:46:59 -0500 Subject: [PATCH 04/24] Added Z2Indexer updating API to differentiate it with XZ2Indexer. --- .../rasterframes/RasterFunctions.scala | 32 ++- .../expressions/DynamicExtractors.scala | 47 +++- .../rasterframes/expressions/package.scala | 3 +- .../expressions/transformers/XZ2Indexer.scala | 42 +--- .../expressions/transformers/Z2Indexer.scala | 95 ++++++++ .../jts/ReprojectionTransformer.scala | 7 +- .../rasterframes/RasterFunctionsSpec.scala | 4 +- .../expressions/SFCIndexerSpec.scala | 219 ++++++++++++++++++ .../expressions/XZ2IndexerSpec.scala | 124 ---------- .../raster/RasterSourceDataSource.scala | 20 +- .../raster/RasterSourceRelation.scala | 29 ++- .../datasource/raster/package.scala | 4 +- .../raster/RasterSourceDataSourceSpec.scala | 23 +- docs/src/main/paradox/reference.md | 8 +- docs/src/main/paradox/release-notes.md | 8 +- .../python/pyrasterframes/rasterfunctions.py | 16 +- .../src/main/python/tests/VectorTypesTests.py | 17 +- 17 files changed, 495 insertions(+), 203 deletions(-) create mode 100644 core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Z2Indexer.scala create mode 100644 core/src/test/scala/org/locationtech/rasterframes/expressions/SFCIndexerSpec.scala delete mode 100644 core/src/test/scala/org/locationtech/rasterframes/expressions/XZ2IndexerSpec.scala diff --git a/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala b/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala index 94dcef333..1f1c51de9 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala @@ -53,6 +53,9 @@ trait RasterFunctions { /** Query the number of (cols, rows) in a Tile. 
*/
   def rf_dimensions(col: Column): TypedColumn[Any, TileDimensions] = GetDimensions(col)
 
+  /** Extracts the CRS from a RasterSource or ProjectedRasterTile */
+  def rf_crs(col: Column): TypedColumn[Any, CRS] = GetCRS(col)
+
   /** Extracts the bounding box of a geometry as an Extent */
   def st_extent(col: Column): TypedColumn[Any, Extent] = GeometryToExtent(col)
 
   /** Constructs an XZ2 index in WGS84 from either a Geometry, Extent, ProjectedRasterTile, or RasterSource and its CRS.
     * For details: https://www.geomesa.org/documentation/user/datastores/index_overview.html */
-  def rf_spatial_index(targetExtent: Column, targetCRS: Column, indexResolution: Short) = XZ2Indexer(targetExtent, targetCRS, indexResolution)
+  def rf_xz2_index(targetExtent: Column, targetCRS: Column, indexResolution: Short) = XZ2Indexer(targetExtent, targetCRS, indexResolution)
 
   /** Constructs an XZ2 index in WGS84 from either a Geometry, Extent, ProjectedRasterTile, or RasterSource and its CRS.
     * For details: https://www.geomesa.org/documentation/user/datastores/index_overview.html */
-  def rf_spatial_index(targetExtent: Column, targetCRS: Column) = XZ2Indexer(targetExtent, targetCRS, 18: Short)
+  def rf_xz2_index(targetExtent: Column, targetCRS: Column) = XZ2Indexer(targetExtent, targetCRS, 18: Short)
 
   /** Constructs an XZ2 index at the given resolution in WGS84 from either a ProjectedRasterTile or RasterSource.
     * For details: https://www.geomesa.org/documentation/user/datastores/index_overview.html */
-  def rf_spatial_index(targetExtent: Column, indexResolution: Short) = XZ2Indexer(targetExtent, indexResolution)
+  def rf_xz2_index(targetExtent: Column, indexResolution: Short) = XZ2Indexer(targetExtent, indexResolution)
 
   /** Constructs an XZ2 index with level 18 resolution in WGS84 from either a ProjectedRasterTile or RasterSource.
     * For details: https://www.geomesa.org/documentation/user/datastores/index_overview.html */
-  def rf_spatial_index(targetExtent: Column) = XZ2Indexer(targetExtent, 18: Short)
+  def rf_xz2_index(targetExtent: Column) = XZ2Indexer(targetExtent, 18: Short)
 
-  /** Extracts the CRS from a RasterSource or ProjectedRasterTile */
-  def rf_crs(col: Column): TypedColumn[Any, CRS] = GetCRS(col)
+  /** Constructs a Z2 index in WGS84 from either a Geometry, Extent, ProjectedRasterTile, or RasterSource and its CRS.
+    * First the native extent is extracted or computed, and then its center is used as the indexing location.
+    * For details: https://www.geomesa.org/documentation/user/datastores/index_overview.html */
+  def rf_z2_index(targetExtent: Column, targetCRS: Column, indexResolution: Short) = Z2Indexer(targetExtent, targetCRS, indexResolution)
+
+  /** Constructs a Z2 index in WGS84 from either a Geometry, Extent, ProjectedRasterTile, or RasterSource and its CRS.
+    * First the native extent is extracted or computed, and then its center is used as the indexing location.
+    * For details: https://www.geomesa.org/documentation/user/datastores/index_overview.html */
+  def rf_z2_index(targetExtent: Column, targetCRS: Column) = Z2Indexer(targetExtent, targetCRS, 31: Short)
+
+  /** Constructs a Z2 index at the given resolution in WGS84 from either a ProjectedRasterTile or RasterSource.
+    * First the native extent is extracted or computed, and then its center is used as the indexing location.
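+    * For example (a sketch): `rf_z2_index($"proj_raster", 16: Short)` indexes each tile's centroid at resolution 16.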
+    * For details: https://www.geomesa.org/documentation/user/datastores/index_overview.html */
+  def rf_z2_index(targetExtent: Column, indexResolution: Short) = Z2Indexer(targetExtent, indexResolution)
+
+  /** Constructs a Z2 index with level 31 resolution in WGS84 from either a ProjectedRasterTile or RasterSource.
+    * First the native extent is extracted or computed, and then its center is used as the indexing location.
+    * For details: https://www.geomesa.org/documentation/user/datastores/index_overview.html */
+  def rf_z2_index(targetExtent: Column) = Z2Indexer(targetExtent, 31: Short)
 
   /** Extracts the tile from a ProjectedRasterTile, or passes through a Tile. */
   def rf_tile(col: Column): TypedColumn[Any, Tile] = RealizeTile(col)
diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/DynamicExtractors.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/DynamicExtractors.scala
index e72f158aa..4bb78faea 100644
--- a/core/src/main/scala/org/locationtech/rasterframes/expressions/DynamicExtractors.scala
+++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/DynamicExtractors.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.jts.JTSTypes
 import org.apache.spark.sql.rf.{RasterSourceUDT, TileUDT}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
-import org.locationtech.jts.geom.Envelope
+import org.locationtech.jts.geom.{Envelope, Point}
 import org.locationtech.rasterframes.encoders.CatalystSerializer._
 import org.locationtech.rasterframes.model.{LazyCRS, TileContext}
 import org.locationtech.rasterframes.ref.{ProjectedRasterLike, RasterRef, RasterSource}
@@ -69,13 +69,13 @@ object DynamicExtractors {
   /** Partial function for pulling a ProjectedRasterLike from an input row. */
-  lazy val projectedRasterLikeExtractor: PartialFunction[DataType, InternalRow ⇒ ProjectedRasterLike] = {
+  lazy val projectedRasterLikeExtractor: PartialFunction[DataType, Any ⇒ ProjectedRasterLike] = {
     case _: RasterSourceUDT ⇒
-      (row: InternalRow) => row.to[RasterSource](RasterSourceUDT.rasterSourceSerializer)
+      (input: Any) => input.asInstanceOf[InternalRow].to[RasterSource](RasterSourceUDT.rasterSourceSerializer)
     case t if t.conformsTo[ProjectedRasterTile] =>
-      (row: InternalRow) => row.to[ProjectedRasterTile]
+      (input: Any) => input.asInstanceOf[InternalRow].to[ProjectedRasterTile]
     case t if t.conformsTo[RasterRef] =>
-      (row: InternalRow) => row.to[RasterRef]
+      (input: Any) => input.asInstanceOf[InternalRow].to[RasterRef]
   }
 
   /** Partial function for pulling a CellGrid from an input row.
*/ @@ -97,13 +97,36 @@ object DynamicExtractors { (v: Any) => v.asInstanceOf[InternalRow].to[CRS] } - lazy val extentLikeExtractor: PartialFunction[DataType, Any ⇒ Extent] = { - case t if org.apache.spark.sql.rf.WithTypeConformity(t).conformsTo(JTSTypes.GeometryTypeInstance) => - (input: Any) => JTSTypes.GeometryTypeInstance.deserialize(input).getEnvelopeInternal - case t if t.conformsTo[Extent] => - (input: Any) => input.asInstanceOf[InternalRow].to[Extent] - case t if t.conformsTo[Envelope] => - (input: Any) => Extent(input.asInstanceOf[InternalRow].to[Envelope]) + lazy val extentExtractor: PartialFunction[DataType, Any ⇒ Extent] = { + val base: PartialFunction[DataType, Any ⇒ Extent]= { + case t if org.apache.spark.sql.rf.WithTypeConformity(t).conformsTo(JTSTypes.GeometryTypeInstance) => + (input: Any) => Extent(JTSTypes.GeometryTypeInstance.deserialize(input).getEnvelopeInternal) + case t if t.conformsTo[Extent] => + (input: Any) => input.asInstanceOf[InternalRow].to[Extent] + case t if t.conformsTo[Envelope] => + (input: Any) => Extent(input.asInstanceOf[InternalRow].to[Envelope]) + } + + val fromPRL = projectedRasterLikeExtractor.andThen(_.andThen(_.extent)) + fromPRL orElse base + } + + lazy val envelopeExtractor: PartialFunction[DataType, Any => Envelope] = { + val base = PartialFunction[DataType, Any => Envelope] { + case t if org.apache.spark.sql.rf.WithTypeConformity(t).conformsTo(JTSTypes.GeometryTypeInstance) => + (input: Any) => JTSTypes.GeometryTypeInstance.deserialize(input).getEnvelopeInternal + case t if t.conformsTo[Extent] => + (input: Any) => input.asInstanceOf[InternalRow].to[Extent].jtsEnvelope + case t if t.conformsTo[Envelope] => + (input: Any) => input.asInstanceOf[InternalRow].to[Envelope] + } + + val fromPRL = projectedRasterLikeExtractor.andThen(_.andThen(_.extent.jtsEnvelope)) + fromPRL orElse base + } + + lazy val centroidExtractor: PartialFunction[DataType, Any ⇒ Point] = { + extentExtractor.andThen(_.andThen(_.center.jtsGeom)) } sealed trait TileOrNumberArg diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala index d2163f72b..26bef04e3 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala @@ -135,7 +135,8 @@ package object expressions { registry.registerExpression[RenderPNG.RenderCompositePNG]("rf_render_png") registry.registerExpression[RGBComposite]("rf_rgb_composite") - registry.registerExpression[XZ2Indexer]("rf_spatial_index") + registry.registerExpression[XZ2Indexer]("rf_xz2_index") + registry.registerExpression[Z2Indexer]("rf_z2_index") registry.registerExpression[transformers.ReprojectGeometry]("st_reproject") } diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/XZ2Indexer.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/XZ2Indexer.scala index 7acbb3277..9d4ea57da 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/XZ2Indexer.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/XZ2Indexer.scala @@ -22,24 +22,16 @@ package org.locationtech.rasterframes.expressions.transformers import geotrellis.proj4.LatLng -import geotrellis.vector.Extent import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, 
TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, ExpressionDescription} -import org.apache.spark.sql.jts.JTSTypes -import org.apache.spark.sql.rf.RasterSourceUDT import org.apache.spark.sql.types.{DataType, LongType} -import org.apache.spark.sql.{Column, TypedColumn, rf} +import org.apache.spark.sql.{Column, TypedColumn} import org.locationtech.geomesa.curve.XZ2SFC -import org.locationtech.jts.geom.{Envelope, Geometry} -import org.locationtech.rasterframes.encoders.CatalystSerializer._ import org.locationtech.rasterframes.expressions.DynamicExtractors._ import org.locationtech.rasterframes.expressions.accessors.GetCRS -import org.locationtech.rasterframes.expressions.row import org.locationtech.rasterframes.jts.ReprojectionTransformer -import org.locationtech.rasterframes.ref.{RasterRef, RasterSource} -import org.locationtech.rasterframes.tiles.ProjectedRasterTile /** * Constructs a XZ2 index in WGS84 from either a Geometry, Extent, ProjectedRasterTile, or RasterSource @@ -63,12 +55,12 @@ import org.locationtech.rasterframes.tiles.ProjectedRasterTile case class XZ2Indexer(left: Expression, right: Expression, indexResolution: Short) extends BinaryExpression with CodegenFallback { - override def nodeName: String = "rf_spatial_index" + override def nodeName: String = "rf_xz2_index" override def dataType: DataType = LongType override def checkInputDataTypes(): TypeCheckResult = { - if (!extentLikeExtractor.orElse(projectedRasterLikeExtractor).isDefinedAt(left.dataType)) + if (!envelopeExtractor.isDefinedAt(left.dataType)) TypeCheckFailure(s"Input type '${left.dataType}' does not look like something with an Extent or something with one.") else if(!crsExtractor.isDefinedAt(right.dataType)) TypeCheckFailure(s"Input type '${right.dataType}' does not look like something with a CRS.") @@ -79,36 +71,14 @@ case class XZ2Indexer(left: Expression, right: Expression, indexResolution: Shor override protected def nullSafeEval(leftInput: Any, rightInput: Any): Any = { val crs = crsExtractor(right.dataType)(rightInput) - - val coords = left.dataType match { - case t if rf.WithTypeConformity(t).conformsTo(JTSTypes.GeometryTypeInstance) => - JTSTypes.GeometryTypeInstance.deserialize(leftInput) - case t if t.conformsTo[Extent] => - row(leftInput).to[Extent] - case t if t.conformsTo[Envelope] => - row(leftInput).to[Envelope] - case _: RasterSourceUDT ⇒ - row(leftInput).to[RasterSource](RasterSourceUDT.rasterSourceSerializer).extent - case t if t.conformsTo[ProjectedRasterTile] => - row(leftInput).to[ProjectedRasterTile].extent - case t if t.conformsTo[RasterRef] => - row(leftInput).to[RasterRef].extent - } + val coords = envelopeExtractor(left.dataType)(leftInput) // If no transformation is needed then just normalize to an Envelope - val env = if(crs == LatLng) coords match { - case e: Extent => e.jtsEnvelope - case g: Geometry => g.getEnvelopeInternal - case e: Envelope => e - } + val env = if(crs == LatLng) coords // Otherwise convert to geometry, transform, and get envelope else { val trans = new ReprojectionTransformer(crs, LatLng) - coords match { - case e: Extent => trans(e).getEnvelopeInternal - case g: Geometry => trans(g).getEnvelopeInternal - case e: Envelope => trans(e).getEnvelopeInternal - } + trans(coords).getEnvelopeInternal } val index = indexer.index( diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Z2Indexer.scala 
b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Z2Indexer.scala new file mode 100644 index 000000000..42d2a120d --- /dev/null +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Z2Indexer.scala @@ -0,0 +1,95 @@ +/* + * This software is licensed under the Apache 2 license, quoted below. + * + * Copyright 2019 Astraea, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * [http://www.apache.org/licenses/LICENSE-2.0] + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + * + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.locationtech.rasterframes.expressions.transformers + +import geotrellis.proj4.LatLng +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, ExpressionDescription} +import org.apache.spark.sql.types.{DataType, LongType} +import org.apache.spark.sql.{Column, TypedColumn} +import org.locationtech.geomesa.curve.Z2SFC +import org.locationtech.rasterframes.expressions.DynamicExtractors._ +import org.locationtech.rasterframes.expressions.accessors.GetCRS +import org.locationtech.rasterframes.jts.ReprojectionTransformer + +/** + * Constructs a Z2 index in WGS84 from either a Geometry, Extent, ProjectedRasterTile, or RasterSource. First the + * native extent is extracted or computed, and then center is used as the indexing location. + * This function is useful for [range partitioning](http://spark.apache.org/docs/latest/api/python/pyspark.sql.html?highlight=registerjava#pyspark.sql.DataFrame.repartitionByRange). + * Also see: https://www.geomesa.org/documentation/user/datastores/index_overview.html + * + * @param left geometry-like column + * @param right CRS column + * @param indexResolution resolution level of the space filling curve - + * i.e. how many times the space will be recursively quartered + * 1-31 is typical. 
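+ * For example (a sketch): `df.repartitionByRange(rf_z2_index($"proj_raster"))` uses the default
+ * resolution to range-partition rows so that spatially near tiles land in the same partitions.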
+ */ +@ExpressionDescription( + usage = "_FUNC_(geom, crs) - Constructs a Z2 index in WGS84/EPSG:4326", + arguments = """ + Arguments: + * geom - Geometry or item with Geometry: Extent, ProjectedRasterTile, or RasterSource + * crs - the native CRS of the `geom` column +""" +) +case class Z2Indexer(left: Expression, right: Expression, indexResolution: Short) + extends BinaryExpression with CodegenFallback { + + override def nodeName: String = "rf_z2_index" + + override def dataType: DataType = LongType + + override def checkInputDataTypes(): TypeCheckResult = { + if (!centroidExtractor.isDefinedAt(left.dataType)) + TypeCheckFailure(s"Input type '${left.dataType}' does not look like something with a point.") + else if(!crsExtractor.isDefinedAt(right.dataType)) + TypeCheckFailure(s"Input type '${right.dataType}' does not look like something with a CRS.") + else TypeCheckSuccess + } + + private lazy val indexer = new Z2SFC(indexResolution) + + override protected def nullSafeEval(leftInput: Any, rightInput: Any): Any = { + val crs = crsExtractor(right.dataType)(rightInput) + val coord = centroidExtractor(left.dataType)(leftInput) + + val pt = if(crs == LatLng) coord + else { + val trans = new ReprojectionTransformer(crs, LatLng) + trans(coord) + } + + indexer.index(pt.getX, pt.getY, lenient = false).z + } +} + +object Z2Indexer { + import org.locationtech.rasterframes.encoders.SparkBasicEncoders.longEnc + def apply(targetExtent: Column, targetCRS: Column, indexResolution: Short): TypedColumn[Any, Long] = + new Column(new Z2Indexer(targetExtent.expr, targetCRS.expr, indexResolution)).as[Long] + def apply(targetExtent: Column, targetCRS: Column): TypedColumn[Any, Long] = + new Column(new Z2Indexer(targetExtent.expr, targetCRS.expr, 31)).as[Long] + def apply(targetExtent: Column, indexResolution: Short = 31): TypedColumn[Any, Long] = + new Column(new Z2Indexer(targetExtent.expr, GetCRS(targetExtent.expr), indexResolution)).as[Long] +} diff --git a/core/src/main/scala/org/locationtech/rasterframes/jts/ReprojectionTransformer.scala b/core/src/main/scala/org/locationtech/rasterframes/jts/ReprojectionTransformer.scala index 54b45c034..0ceacc965 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/jts/ReprojectionTransformer.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/jts/ReprojectionTransformer.scala @@ -21,10 +21,11 @@ package org.locationtech.rasterframes.jts -import org.locationtech.jts.geom.{CoordinateSequence, Envelope, Geometry, GeometryFactory} +import org.locationtech.jts.geom.{CoordinateSequence, Envelope, Geometry, GeometryFactory, Point} import org.locationtech.jts.geom.util.GeometryTransformer import geotrellis.proj4.CRS import geotrellis.vector.Extent +import org.locationtech.jts.algorithm.Centroid /** * JTS Geometry reprojection transformation routine. 
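 * For example (a sketch): `new ReprojectionTransformer(WebMercator, LatLng)(geom)` returns `geom` reprojected into WGS84.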
@@ -38,6 +39,10 @@ class ReprojectionTransformer(src: CRS, dst: CRS) extends GeometryTransformer { def apply(geometry: Geometry): Geometry = transform(geometry) def apply(extent: Extent): Geometry = transform(extent.jtsGeom) def apply(env: Envelope): Geometry = transform(gf.toGeometry(env)) + def apply(pt: Point): Point = { + val t = transform(pt) + gf.createPoint(Centroid.getCentroid(t)) + } override def transformCoordinates(coords: CoordinateSequence, parent: Geometry): CoordinateSequence = { val fact = parent.getFactory diff --git a/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala index 2ef126790..cf1e52cb9 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala @@ -28,7 +28,7 @@ import geotrellis.raster._ import geotrellis.raster.render.ColorRamps import geotrellis.raster.testkit.RasterMatchers import javax.imageio.ImageIO -import org.apache.spark.sql.{Column, Encoders, TypedColumn} +import org.apache.spark.sql.Encoders import org.apache.spark.sql.functions._ import org.locationtech.rasterframes.expressions.accessors.ExtractTile import org.locationtech.rasterframes.model.TileDimensions @@ -701,7 +701,7 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers { val withMasked = withMask.withColumn("masked", rf_inverse_mask_by_value($"tile", $"mask", mask_value)) .withColumn("masked2", rf_mask_by_value($"tile", $"mask", lit(mask_value), true)) - withMasked.explain(true) + val result = withMasked.agg(rf_agg_no_data_cells($"tile") < rf_agg_no_data_cells($"masked")).as[Boolean] result.first() should be(true) diff --git a/core/src/test/scala/org/locationtech/rasterframes/expressions/SFCIndexerSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/expressions/SFCIndexerSpec.scala new file mode 100644 index 000000000..13ea363e3 --- /dev/null +++ b/core/src/test/scala/org/locationtech/rasterframes/expressions/SFCIndexerSpec.scala @@ -0,0 +1,219 @@ +/* + * This software is licensed under the Apache 2 license, quoted below. + * + * Copyright 2019 Astraea, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * [http://www.apache.org/licenses/LICENSE-2.0] + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.locationtech.rasterframes.expressions +import geotrellis.proj4.{CRS, LatLng, WebMercator} +import geotrellis.raster.CellType +import geotrellis.vector.Extent +import org.apache.spark.sql.Encoders +import org.apache.spark.sql.jts.JTSTypes +import org.locationtech.geomesa.curve.{XZ2SFC, Z2SFC} +import org.locationtech.rasterframes.{TestEnvironment, _} +import org.locationtech.rasterframes.encoders.serialized_literal +import org.locationtech.rasterframes.ref.{InMemoryRasterSource, RasterSource} +import org.locationtech.rasterframes.tiles.ProjectedRasterTile +import org.scalatest.Inspectors + +class SFCIndexerSpec extends TestEnvironment with Inspectors { + val testExtents = Seq( + Extent(10, 10, 12, 12), + Extent(9.0, 9.0, 13.0, 13.0), + Extent(-180.0, -90.0, 180.0, 90.0), + Extent(0.0, 0.0, 180.0, 90.0), + Extent(0.0, 0.0, 20.0, 20.0), + Extent(11.0, 11.0, 13.0, 13.0), + Extent(9.0, 9.0, 11.0, 11.0), + Extent(10.5, 10.5, 11.5, 11.5), + Extent(11.0, 11.0, 11.0, 11.0), + Extent(-180.0, -90.0, 8.0, 8.0), + Extent(0.0, 0.0, 8.0, 8.0), + Extent(9.0, 9.0, 9.5, 9.5), + Extent(20.0, 20.0, 180.0, 90.0) + ) + def reproject(dst: CRS)(e: Extent): Extent = e.reproject(LatLng, dst) + + val xzsfc = XZ2SFC(18) + val zsfc = new Z2SFC(31) + val xzExpected = testExtents.map(e => xzsfc.index(e.xmin, e.ymin, e.xmax, e.ymax)) + val zExpected = (crs: CRS) => testExtents.map(reproject(crs)).map(e => { + val p = e.center.reproject(crs, LatLng) + zsfc.index(p.x, p.y).z + }) + + describe("Centroid extraction") { + import org.locationtech.rasterframes.encoders.CatalystSerializer._ + val expected = testExtents.map(_.center.jtsGeom) + it("should extract from Extent") { + val dt = schemaOf[Extent] + val extractor = DynamicExtractors.centroidExtractor(dt) + val inputs = testExtents.map(_.toInternalRow).map(extractor) + forEvery(inputs.zip(expected)) { case (i, e) => + i should be(e) + } + } + it("should extract from Geometry") { + val dt = JTSTypes.GeometryTypeInstance + val extractor = DynamicExtractors.centroidExtractor(dt) + val inputs = testExtents.map(_.jtsGeom).map(dt.serialize).map(extractor) + forEvery(inputs.zip(expected)) { case (i, e) => + i should be(e) + } + } + it("should extract from ProjectedRasterTile") { + val crs: CRS = WebMercator + val tile = TestData.randomTile(2, 2, CellType.fromName("uint8")) + val dt = schemaOf[ProjectedRasterTile] + val extractor = DynamicExtractors.centroidExtractor(dt) + val inputs = testExtents.map(ProjectedRasterTile(tile, _, crs)) + .map(_.toInternalRow).map(extractor) + forEvery(inputs.zip(expected)) { case (i, e) => + i should be(e) + } + } + it("should extract from RasterSource") { + val crs: CRS = WebMercator + val tile = TestData.randomTile(2, 2, CellType.fromName("uint8")) + val dt = RasterSourceType + val extractor = DynamicExtractors.centroidExtractor(dt) + val inputs = testExtents.map(InMemoryRasterSource(tile, _, crs): RasterSource) + .map(RasterSourceType.serialize).map(extractor) + forEvery(inputs.zip(expected)) { case (i, e) => + i should be(e) + } + } + } + + describe("Spatial index generation") { + import spark.implicits._ + it("should be SQL registered with docs") { + checkDocs("rf_xz2_index") + checkDocs("rf_z2_index") + } + it("should create index from Extent") { + val crs: CRS = WebMercator + val df = testExtents.map(reproject(crs)).map(Tuple1.apply).toDF("extent") + + withClue("XZ2") { + val indexes = df.select(rf_xz2_index($"extent", serialized_literal(crs))).collect() + 
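+        // xzExpected was computed above with XZ2SFC(18), the same default resolution rf_xz2_index uses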
forEvery(indexes.zip(xzExpected)) { case (i, e) => + i should be(e) + } + } + withClue("Z2") { + val indexes = df.select(rf_z2_index($"extent", serialized_literal(crs))).collect() + forEvery(indexes.zip(zExpected(crs))) { case (i, e) => + i should be(e) + } + indexes.distinct.length should be (indexes.length) + } + } + it("should create index from Geometry") { + val crs: CRS = LatLng + val df = testExtents.map(_.jtsGeom).map(Tuple1.apply).toDF("extent") + withClue("XZ2") { + val indexes = df.select(rf_xz2_index($"extent", serialized_literal(crs))).collect() + forEvery(indexes.zip(xzExpected)) { case (i, e) => + i should be(e) + } + } + withClue("Z2") { + val indexes = df.select(rf_z2_index($"extent", serialized_literal(crs))).collect() + forEvery(indexes.zip(zExpected(crs))) { case (i, e) => + i should be(e) + } + } + } + it("should create index from ProjectedRasterTile") { + val crs: CRS = WebMercator + val tile = TestData.randomTile(2, 2, CellType.fromName("uint8")) + val prts = testExtents.map(reproject(crs)).map(ProjectedRasterTile(tile, _, crs)) + + implicit val enc = Encoders.tuple(ProjectedRasterTile.prtEncoder, Encoders.scalaInt) + // The `id` here is to deal with Spark auto projecting single columns dataframes and needing to provide an encoder + val df = prts.zipWithIndex.toDF("proj_raster", "id") + withClue("XZ2") { + val indexes = df.select(rf_xz2_index($"proj_raster")).collect() + forEvery(indexes.zip(xzExpected)) { case (i, e) => + i should be(e) + } + } + withClue("Z2") { + val indexes = df.select(rf_z2_index($"proj_raster")).collect() + forEvery(indexes.zip(zExpected(crs))) { case (i, e) => + i should be(e) + } + } + } + it("should create index from RasterSource") { + val crs: CRS = WebMercator + val tile = TestData.randomTile(2, 2, CellType.fromName("uint8")) + val srcs = testExtents.map(reproject(crs)).map(InMemoryRasterSource(tile, _, crs): RasterSource).toDF("src") + withClue("XZ2") { + val indexes = srcs.select(rf_xz2_index($"src")).collect() + forEvery(indexes.zip(xzExpected)) { case (i, e) => + i should be(e) + } + } + withClue("Z2") { + val indexes = srcs.select(rf_z2_index($"src")).collect() + forEvery(indexes.zip(zExpected(crs))) { case (i, e) => + i should be(e) + } + } + } + it("should work when CRS is LatLng") { + val df = testExtents.map(Tuple1.apply).toDF("extent") + val crs: CRS = LatLng + withClue("XZ2") { + val indexes = df.select(rf_xz2_index($"extent", serialized_literal(crs))).collect() + forEvery(indexes.zip(xzExpected)) { case (i, e) => + i should be(e) + } + } + withClue("Z2") { + val indexes = df.select(rf_z2_index($"extent", serialized_literal(crs))).collect() + forEvery(indexes.zip(zExpected(crs))) { case (i, e) => + i should be(e) + } + } + } + it("should support custom resolution") { + val df = testExtents.map(Tuple1.apply).toDF("extent") + val crs: CRS = LatLng + withClue("XZ2") { + val sfc = XZ2SFC(3) + val expected = testExtents.map(e => sfc.index(e.xmin, e.ymin, e.xmax, e.ymax)) + val indexes = df.select(rf_xz2_index($"extent", serialized_literal(crs), 3)).collect() + forEvery(indexes.zip(expected)) { case (i, e) => + i should be(e) + } + } + withClue("Z2") { + val sfc = new Z2SFC(3) + val expected = testExtents.map(e => sfc.index(e.center.x, e.center.y).z) + val indexes = df.select(rf_z2_index($"extent", serialized_literal(crs), 3)).collect() + forEvery(indexes.zip(expected)) { case (i, e) => + i should be(e) + } + } + } + } +} diff --git a/core/src/test/scala/org/locationtech/rasterframes/expressions/XZ2IndexerSpec.scala 
b/core/src/test/scala/org/locationtech/rasterframes/expressions/XZ2IndexerSpec.scala deleted file mode 100644 index fc62949dd..000000000 --- a/core/src/test/scala/org/locationtech/rasterframes/expressions/XZ2IndexerSpec.scala +++ /dev/null @@ -1,124 +0,0 @@ -/* - * This software is licensed under the Apache 2 license, quoted below. - * - * Copyright 2019 Astraea, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not - * use this file except in compliance with the License. You may obtain a copy of - * the License at - * - * [http://www.apache.org/licenses/LICENSE-2.0] - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations under - * the License. - * - * SPDX-License-Identifier: Apache-2.0 - * - */ - -package org.locationtech.rasterframes.expressions -import geotrellis.proj4.{CRS, LatLng, WebMercator} -import geotrellis.raster.CellType -import geotrellis.vector.Extent -import org.apache.spark.sql.Encoders -import org.locationtech.geomesa.curve.XZ2SFC -import org.locationtech.rasterframes.{TestEnvironment, _} -import org.locationtech.rasterframes.encoders.serialized_literal -import org.locationtech.rasterframes.ref.{InMemoryRasterSource, RasterSource} -import org.locationtech.rasterframes.tiles.ProjectedRasterTile -import org.scalatest.Inspectors - -class XZ2IndexerSpec extends TestEnvironment with Inspectors { - val testExtents = Seq( - Extent(10, 10, 12, 12), - Extent(9.0, 9.0, 13.0, 13.0), - Extent(-180.0, -90.0, 180.0, 90.0), - Extent(0.0, 0.0, 180.0, 90.0), - Extent(0.0, 0.0, 20.0, 20.0), - Extent(11.0, 11.0, 13.0, 13.0), - Extent(9.0, 9.0, 11.0, 11.0), - Extent(10.5, 10.5, 11.5, 11.5), - Extent(11.0, 11.0, 11.0, 11.0), - Extent(-180.0, -90.0, 8.0, 8.0), - Extent(0.0, 0.0, 8.0, 8.0), - Extent(9.0, 9.0, 9.5, 9.5), - Extent(20.0, 20.0, 180.0, 90.0) - ) - val sfc = XZ2SFC(18) - val expected = testExtents.map(e => sfc.index(e.xmin, e.ymin, e.xmax, e.ymax)) - - def reproject(dst: CRS)(e: Extent): Extent = e.reproject(LatLng, dst) - - describe("Spatial index generation") { - import spark.implicits._ - it("should be SQL registered with docs") { - checkDocs("rf_spatial_index") - } - it("should create index from Extent") { - val crs: CRS = WebMercator - val df = testExtents.map(reproject(crs)).map(Tuple1.apply).toDF("extent") - val indexes = df.select(rf_spatial_index($"extent", serialized_literal(crs))).collect() - - forEvery(indexes.zip(expected)) { case (i, e) => - i should be (e) - } - } - it("should create index from Geometry") { - val crs: CRS = LatLng - val df = testExtents.map(_.jtsGeom).map(Tuple1.apply).toDF("extent") - val indexes = df.select(rf_spatial_index($"extent", serialized_literal(crs))).collect() - - forEvery(indexes.zip(expected)) { case (i, e) => - i should be (e) - } - } - it("should create index from ProjectedRasterTile") { - val crs: CRS = WebMercator - val tile = TestData.randomTile(2, 2, CellType.fromName("uint8")) - val prts = testExtents.map(reproject(crs)).map(ProjectedRasterTile(tile, _, crs)) - - implicit val enc = Encoders.tuple(ProjectedRasterTile.prtEncoder, Encoders.scalaInt) - // The `id` here is to deal with Spark auto projecting single columns dataframes and needing to provide an encoder - val df = prts.zipWithIndex.toDF("proj_raster", "id") - val indexes = 
df.select(rf_spatial_index($"proj_raster")).collect() - - forEvery(indexes.zip(expected)) { case (i, e) => - i should be (e) - } - } - it("should create index from RasterSource") { - val crs: CRS = WebMercator - val tile = TestData.randomTile(2, 2, CellType.fromName("uint8")) - val srcs = testExtents.map(reproject(crs)).map(InMemoryRasterSource(tile, _, crs): RasterSource).toDF("src") - val indexes = srcs.select(rf_spatial_index($"src")).collect() - - forEvery(indexes.zip(expected)) { case (i, e) => - i should be (e) - } - - } - it("should work when CRS is LatLng") { - val df = testExtents.map(Tuple1.apply).toDF("extent") - val crs: CRS = LatLng - val indexes = df.select(rf_spatial_index($"extent", serialized_literal(crs))).collect() - - forEvery(indexes.zip(expected)) { case (i, e) => - i should be (e) - } - } - it("should support custom resolution") { - val sfc = XZ2SFC(3) - val expected = testExtents.map(e => sfc.index(e.xmin, e.ymin, e.xmax, e.ymax)) - val df = testExtents.map(Tuple1.apply).toDF("extent") - val crs: CRS = LatLng - val indexes = df.select(rf_spatial_index($"extent", serialized_literal(crs), 3)).collect() - - forEvery(indexes.zip(expected)) { case (i, e) => - i should be (e) - } - } - } -} diff --git a/datasource/src/main/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceDataSource.scala b/datasource/src/main/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceDataSource.scala index 03b2fd0da..be8d11b31 100644 --- a/datasource/src/main/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceDataSource.scala +++ b/datasource/src/main/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceDataSource.scala @@ -32,6 +32,8 @@ import org.locationtech.rasterframes.model.TileDimensions import shapeless.tag import shapeless.tag.@@ +import scala.util.Try + class RasterSourceDataSource extends DataSourceRegister with RelationProvider { import RasterSourceDataSource._ override def shortName(): String = SHORT_NAME @@ -39,9 +41,10 @@ class RasterSourceDataSource extends DataSourceRegister with RelationProvider { val bands = parameters.bandIndexes val tiling = parameters.tileDims.orElse(Some(NOMINAL_TILE_DIMS)) val lazyTiles = parameters.lazyTiles + val spatialIndex = parameters.spatialIndex val spec = parameters.pathSpec val catRef = spec.fold(_.registerAsTable(sqlContext), identity) - RasterSourceRelation(sqlContext, catRef, bands, tiling, lazyTiles) + RasterSourceRelation(sqlContext, catRef, bands, tiling, lazyTiles, spatialIndex) } } @@ -55,6 +58,7 @@ object RasterSourceDataSource { final val CATALOG_TABLE_COLS_PARAM = "catalog_col_names" final val CATALOG_CSV_PARAM = "catalog_csv" final val LAZY_TILES_PARAM = "lazy_tiles" + final val SPATIAL_INDEX_PARTITIONS_PARAM = "spatial_index_partitions" final val DEFAULT_COLUMN_NAME = PROJECTED_RASTER_COLUMN.columnName @@ -115,10 +119,12 @@ object RasterSourceDataSource { .map(tokenize(_).map(_.toInt)) .getOrElse(Seq(0)) - def lazyTiles: Boolean = parameters .get(LAZY_TILES_PARAM).forall(_.toBoolean) + def spatialIndex: Option[Int] = parameters + .get(SPATIAL_INDEX_PARTITIONS_PARAM).flatMap(p => Try(p.toInt).toOption) + def catalog: Option[RasterSourceCatalog] = { val paths = ( parameters @@ -158,6 +164,16 @@ object RasterSourceDataSource { } } + /** Mixin for adding extension methods on DataFrameReader for RasterSourceDataSource-like readers. 
*/ + trait SpatialIndexOptionsSupport[ReaderTag] { + type _TaggedReader = DataFrameReader @@ ReaderTag + val reader: _TaggedReader + def withSpatialIndex(numPartitions: Int = -1): _TaggedReader = + tag[ReaderTag][DataFrameReader]( + reader.option(RasterSourceDataSource.SPATIAL_INDEX_PARTITIONS_PARAM, numPartitions) + ) + } + /** Mixin for adding extension methods on DataFrameReader for RasterSourceDataSource-like readers. */ trait CatalogReaderOptionsSupport[ReaderTag] { type TaggedReader = DataFrameReader @@ ReaderTag diff --git a/datasource/src/main/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceRelation.scala b/datasource/src/main/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceRelation.scala index 9b381d3a6..706bb9680 100644 --- a/datasource/src/main/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceRelation.scala +++ b/datasource/src/main/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceRelation.scala @@ -24,13 +24,14 @@ package org.locationtech.rasterframes.datasource.raster import org.apache.spark.rdd.RDD import org.apache.spark.sql.functions._ import org.apache.spark.sql.sources.{BaseRelation, TableScan} -import org.apache.spark.sql.types.{StringType, StructField, StructType} +import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType} import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.locationtech.rasterframes.datasource.raster.RasterSourceDataSource.RasterSourceCatalogRef import org.locationtech.rasterframes.encoders.CatalystSerializer._ +import org.locationtech.rasterframes.expressions.accessors.{GetCRS, GetExtent} import org.locationtech.rasterframes.expressions.generators.{RasterSourceToRasterRefs, RasterSourceToTiles} import org.locationtech.rasterframes.expressions.generators.RasterSourceToRasterRefs.bandNames -import org.locationtech.rasterframes.expressions.transformers.{RasterRefToTile, URIToRasterSource} +import org.locationtech.rasterframes.expressions.transformers.{RasterRefToTile, URIToRasterSource, XZ2Indexer} import org.locationtech.rasterframes.model.TileDimensions import org.locationtech.rasterframes.tiles.ProjectedRasterTile @@ -40,13 +41,16 @@ import org.locationtech.rasterframes.tiles.ProjectedRasterTile * @param catalogTable Specification of raster path sources * @param bandIndexes band indexes to fetch * @param subtileDims how big to tile/subdivide rasters info + * @param lazyTiles if true, creates a lazy representation of tile instead of fetching contents. + * @param spatialIndexPartitions if not None, adds a spatial index. If Option value < 1, uses the value of `numShufflePartitions` in SparkContext. 
*/ case class RasterSourceRelation( sqlContext: SQLContext, catalogTable: RasterSourceCatalogRef, bandIndexes: Seq[Int], subtileDims: Option[TileDimensions], - lazyTiles: Boolean + lazyTiles: Boolean, + spatialIndexPartitions: Option[Int] ) extends BaseRelation with TableScan { lazy val inputColNames = catalogTable.bandColumnNames @@ -69,6 +73,9 @@ case class RasterSourceRelation( catalog.schema.fields.filter(f => !catalogTable.bandColumnNames.contains(f.name)) } + lazy val indexCols: Seq[StructField] = + if (spatialIndexPartitions.isDefined) Seq(StructField("spatial_index", LongType, false)) else Seq.empty + protected def defaultNumPartitions: Int = sqlContext.sparkSession.sessionState.conf.numShufflePartitions @@ -81,17 +88,17 @@ case class RasterSourceRelation( tileColName <- tileColNames } yield StructField(tileColName, tileSchema, true) - StructType(paths ++ tiles ++ extraCols) + StructType(paths ++ tiles ++ extraCols ++ indexCols) } override def buildScan(): RDD[Row] = { import sqlContext.implicits._ - + val numParts = spatialIndexPartitions.filter(_ > 0).getOrElse(defaultNumPartitions) // The general transformation is: // input -> path -> src -> ref -> tile // Each step is broken down for readability val inputs: DataFrame = sqlContext.table(catalogTable.tableName) - .repartition(defaultNumPartitions) + .repartition(numParts) // Basically renames the input columns to have the '_path' suffix val pathsAliasing = for { @@ -135,6 +142,14 @@ case class RasterSourceRelation( withPaths .select((paths :+ tiles) ++ extras: _*) } - df.rdd + + if (spatialIndexPartitions.isDefined) { + val sample = col(tileColNames.head) + val indexed = df + .withColumn("spatial_index", XZ2Indexer(GetExtent(sample), GetCRS(sample))) + .repartitionByRange(numParts,$"spatial_index") + indexed.rdd + } + else df.rdd } } diff --git a/datasource/src/main/scala/org/locationtech/rasterframes/datasource/raster/package.scala b/datasource/src/main/scala/org/locationtech/rasterframes/datasource/raster/package.scala index 48c0e9642..b70d782af 100644 --- a/datasource/src/main/scala/org/locationtech/rasterframes/datasource/raster/package.scala +++ b/datasource/src/main/scala/org/locationtech/rasterframes/datasource/raster/package.scala @@ -22,6 +22,7 @@ package org.locationtech.rasterframes.datasource import org.apache.spark.sql.DataFrameReader +import org.locationtech.rasterframes.datasource.raster.RasterSourceDataSource._ import shapeless.tag import shapeless.tag.@@ package object raster { @@ -38,5 +39,6 @@ package object raster { /** Adds option methods relevant to RasterSourceDataSource. 
*/ implicit class RasterSourceDataFrameReaderHasOptions(val reader: RasterSourceDataFrameReader) - extends RasterSourceDataSource.CatalogReaderOptionsSupport[RasterSourceDataFrameReaderTag] + extends CatalogReaderOptionsSupport[RasterSourceDataFrameReaderTag] with + SpatialIndexOptionsSupport[RasterSourceDataFrameReaderTag] } diff --git a/datasource/src/test/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceDataSourceSpec.scala b/datasource/src/test/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceDataSourceSpec.scala index 7bd46ce37..345579567 100644 --- a/datasource/src/test/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceDataSourceSpec.scala +++ b/datasource/src/test/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceDataSourceSpec.scala @@ -21,7 +21,8 @@ package org.locationtech.rasterframes.datasource.raster import geotrellis.raster.Tile -import org.apache.spark.sql.functions.{lit, udf, round} +import org.apache.spark.sql.functions.{lit, round, udf} +import org.apache.spark.sql.types.LongType import org.locationtech.rasterframes.{TestEnvironment, _} import org.locationtech.rasterframes.datasource.raster.RasterSourceDataSource.{RasterSourceCatalog, _} import org.locationtech.rasterframes.model.TileDimensions @@ -78,6 +79,12 @@ class RasterSourceDataSourceSpec extends TestEnvironment with TestData { val p = Map(CATALOG_CSV_PARAM -> csv) p.pathSpec should be (Left(RasterSourceCatalog(csv))) } + + it("should parse spatial index state") { + Map(SPATIAL_INDEX_PARTITIONS_PARAM -> "12").spatialIndex should be (Some(12)) + Map(SPATIAL_INDEX_PARTITIONS_PARAM -> "-1").spatialIndex should be (Some(-1)) + Map("foo"-> "bar").spatialIndex should be (None) + } } describe("RasterSource as relation reading") { @@ -326,4 +333,18 @@ class RasterSourceDataSourceSpec extends TestEnvironment with TestData { res.length should be (1) } } + + describe("attaching a spatial index") { + val l8_df = spark.read.raster + .withSpatialIndex() + .load(remoteL8.toASCIIString) + .cache() + + it("should add index") { + l8_df.columns should contain("spatial_index") + l8_df.schema("spatial_index").dataType should be(LongType) + l8_df.show(false) + l8_df.select($"spatial_index").distinct().count() should be(l8_df.count()) + } + } } diff --git a/docs/src/main/paradox/reference.md b/docs/src/main/paradox/reference.md index 1121bbd36..717336336 100644 --- a/docs/src/main/paradox/reference.md +++ b/docs/src/main/paradox/reference.md @@ -67,11 +67,11 @@ See also GeoMesa [st_envelope](https://www.geomesa.org/documentation/user/spark/ Convert an extent to a Geometry. The extent likely comes from @ref:[`st_extent`](reference.md#st-extent) or @ref:[`rf_extent`](reference.md#rf-extent). -### rf_spatial_index +### rf_xz2_index - Long rf_spatial_index(Geometry geom, CRS crs) - Long rf_spatial_index(Extent extent, CRS crs) - Long rf_spatial_index(ProjectedRasterTile proj_raster, CRS crs) + Long rf_xz2_index(Geometry geom, CRS crs) + Long rf_xz2_index(Extent extent, CRS crs) + Long rf_xz2_index(ProjectedRasterTile proj_raster, CRS crs) Constructs an XZ2 index in WGS84/EPSG:4326 from a Geometry, Extent, or ProjectedRasterTile and its CRS. This function is useful for [range partitioning](http://spark.apache.org/docs/latest/api/python/pyspark.sql.html?highlight=registerjava#pyspark.sql.DataFrame.repartitionByRange).
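A minimal PySpark sketch of the range-partitioning pattern this enables, assuming an active RasterFrames-enabled `spark` session; the `uri` value, the partition count of 16, and the default `proj_raster` column name are illustrative placeholders:

```python
from pyrasterframes.rasterfunctions import rf_xz2_index

df = spark.read.raster(uri)
indexed = df.withColumn('ix', rf_xz2_index(df.proj_raster))
# Rows whose tiles are close in space land in the same or adjacent partitions.
partitioned = indexed.repartitionByRange(16, 'ix')
```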
diff --git a/docs/src/main/paradox/release-notes.md b/docs/src/main/paradox/release-notes.md index c181d55da..7a6152873 100644 --- a/docs/src/main/paradox/release-notes.md +++ b/docs/src/main/paradox/release-notes.md @@ -2,6 +2,12 @@ ## 0.8.x +### 0.8.5 + +* _Breaking_: `rf_spatial_index` renamed `rf_xz2_index` to differentiate between XZ2 and Z2 variants. +* Added `rf_z2_index` for constructing a Z2 index on types with bounds. +* Added `withSpatialIndex` to RasterSourceDataSource to pre-partition tiles based on tile extents mapped to a Z2 space-filling curve + ### 0.8.4 * Upgraded to Spark 2.4.4 @@ -17,7 +23,7 @@ * Updated to GeoTrellis 2.3.3 and Proj4j 1.1.0. * Fixed issues with `LazyLogger` and shading assemblies ([#293](https://github.com/locationtech/rasterframes/issues/293)) * Updated `rf_crs` to accept string columns containing CRS specifications. ([#366](https://github.com/locationtech/rasterframes/issues/366)) -* Added `rf_spatial_index` function. ([#368](https://github.com/locationtech/rasterframes/issues/368)) +* Added `rf_xz2_index` function. ([#368](https://github.com/locationtech/rasterframes/issues/368)) * _Breaking_ (potentially): removed `pyrasterframes.create_spark_session` in lieu of `pyrasterframes.utils.create_rf_spark_session` ### 0.8.2 diff --git a/pyrasterframes/src/main/python/pyrasterframes/rasterfunctions.py b/pyrasterframes/src/main/python/pyrasterframes/rasterfunctions.py index ae5977f51..2cf25cbd7 100644 --- a/pyrasterframes/src/main/python/pyrasterframes/rasterfunctions.py +++ b/pyrasterframes/src/main/python/pyrasterframes/rasterfunctions.py @@ -622,11 +622,23 @@ def rf_geometry(proj_raster_col): return _apply_column_function('rf_geometry', proj_raster_col) -def rf_spatial_index(geom_col, crs_col=None, index_resolution = 18): +def rf_xz2_index(geom_col, crs_col=None, index_resolution = 18): """Constructs a XZ2 index in WGS84 from either a Geometry, Extent, ProjectedRasterTile, or RasterSource and its CRS. For details: https://www.geomesa.org/documentation/user/datastores/index_overview.html """ - jfcn = RFContext.active().lookup('rf_spatial_index') + jfcn = RFContext.active().lookup('rf_xz2_index') + + if crs_col is not None: + return Column(jfcn(_to_java_column(geom_col), _to_java_column(crs_col), index_resolution)) + else: + return Column(jfcn(_to_java_column(geom_col), index_resolution)) + +def rf_z2_index(geom_col, crs_col=None, index_resolution = 18): + """Constructs a Z2 index in WGS84 from either a Geometry, Extent, ProjectedRasterTile, or RasterSource and its CRS. + First the native extent is extracted or computed, and then center is used as the indexing location. 
+ For details: https://www.geomesa.org/documentation/user/datastores/index_overview.html """ + + jfcn = RFContext.active().lookup('rf_z2_index') if crs_col is not None: return Column(jfcn(_to_java_column(geom_col), _to_java_column(crs_col), index_resolution)) diff --git a/pyrasterframes/src/main/python/tests/VectorTypesTests.py b/pyrasterframes/src/main/python/tests/VectorTypesTests.py index 70e94af72..c23c07552 100644 --- a/pyrasterframes/src/main/python/tests/VectorTypesTests.py +++ b/pyrasterframes/src/main/python/tests/VectorTypesTests.py @@ -157,15 +157,26 @@ def test_geojson(self): geo.show() self.assertEqual(geo.select('geometry').count(), 8) - def test_spatial_index(self): - df = self.df.select(rf_spatial_index(self.df.poly_geom, rf_crs(lit("EPSG:4326"))).alias('index')) + def test_xz2_index(self): + df = self.df.select(rf_xz2_index(self.df.poly_geom, rf_crs(lit("EPSG:4326"))).alias('index')) expected = {22858201775, 38132946267, 38166922588, 38180072113} indexes = {x[0] for x in df.collect()} self.assertSetEqual(indexes, expected) # Custom resolution - df = self.df.select(rf_spatial_index(self.df.poly_geom, rf_crs(lit("EPSG:4326")), 3).alias('index')) + df = self.df.select(rf_xz2_index(self.df.poly_geom, rf_crs(lit("EPSG:4326")), 3).alias('index')) expected = {21, 36} indexes = {x[0] for x in df.collect()} self.assertSetEqual(indexes, expected) + def test_z2_index(self): + df = self.df.select(rf_z2_index(self.df.poly_geom, rf_crs(lit("EPSG:4326"))).alias('index')) + expected = {22858201775, 38132946267, 38166922588, 38180072113} + indexes = {x[0] for x in df.collect()} + self.assertSetEqual(indexes, expected) + + # Custom resolution + df = self.df.select(rf_z2_index(self.df.poly_geom, rf_crs(lit("EPSG:4326")), 3).alias('index')) + expected = {21, 36} + indexes = {x[0] for x in df.collect()} + self.assertSetEqual(indexes, expected) \ No newline at end of file From f9084f88b270ad51f8109090c2359cfb011b2b5b Mon Sep 17 00:00:00 2001 From: "Simeon H.K. Fitch" Date: Wed, 13 Nov 2019 11:08:12 -0500 Subject: [PATCH 05/24] Updated python tests against spatial index functions. 
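Unlike XZ2, the Z2 variant indexes only the center of the geometry's extent, which is why the expected values in `test_z2_index` are corrected below. A sketch of the contrast, reusing the `poly_geom` column and CRS literal that appear in these tests:

```python
from pyspark.sql.functions import lit
from pyrasterframes.rasterfunctions import rf_crs, rf_xz2_index, rf_z2_index

crs = rf_crs(lit("EPSG:4326"))
contrast = df.select(
    rf_xz2_index(df.poly_geom, crs).alias('xz2'),  # indexes the full extent
    rf_z2_index(df.poly_geom, crs).alias('z2'))    # indexes the extent's center
```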
--- build.sbt | 3 +-- .../rasterframes/RasterFunctions.scala | 6 ++--- docs/src/main/paradox/reference.md | 10 ++++++++- docs/src/main/paradox/release-notes.md | 2 +- .../main/python/tests/RasterFunctionsTests.py | 12 ---------- .../src/main/python/tests/VectorTypesTests.py | 22 ++++++++++++++----- 6 files changed, 31 insertions(+), 24 deletions(-) diff --git a/build.sbt b/build.sbt index f941ea060..4fcb29806 100644 --- a/build.sbt +++ b/build.sbt @@ -34,7 +34,7 @@ lazy val root = project .enablePlugins(RFReleasePlugin) .settings( publish / skip := true, - clean := clean.dependsOn(`rf-notebook`/clean).value + clean := clean.dependsOn(`rf-notebook`/clean, docs/clean).value ) lazy val `rf-notebook` = project @@ -76,7 +76,6 @@ lazy val core = project buildInfoObject := "RFBuildInfo", buildInfoOptions := Seq( BuildInfoOption.ToMap, - BuildInfoOption.BuildTime, BuildInfoOption.ToJson ) ) diff --git a/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala b/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala index 1f1c51de9..3f238ea3e 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala @@ -79,7 +79,7 @@ trait RasterFunctions { def rf_xz2_index(targetExtent: Column) = XZ2Indexer(targetExtent, 18: Short) /** Constructs a Z2 index in WGS84 from either a Geometry, Extent, ProjectedRasterTile, or RasterSource and its CRS - * First the native extent is extracted or computed, and then center is used as the indexing location. + * First the native extent is extracted or computed, and then center is used as the indexing location. * For details: https://www.geomesa.org/documentation/user/datastores/index_overview.html */ def rf_z2_index(targetExtent: Column, targetCRS: Column, indexResolution: Short) = Z2Indexer(targetExtent, targetCRS, indexResolution) @@ -89,12 +89,12 @@ trait RasterFunctions { def rf_z2_index(targetExtent: Column, targetCRS: Column) = Z2Indexer(targetExtent, targetCRS, 31: Short) /** Constructs a Z2 index with level 18 resolution in WGS84 from either a ProjectedRasterTile or RasterSource - * First the native extent is extracted or computed, and then center is used as the indexing location. + * First the native extent is extracted or computed, and then center is used as the indexing location. * For details: https://www.geomesa.org/documentation/user/datastores/index_overview.html */ def rf_z2_index(targetExtent: Column, indexResolution: Short) = Z2Indexer(targetExtent, indexResolution) /** Constructs a Z2 index with level 18 resolution in WGS84 from either a ProjectedRasterTile or RasterSource - * First the native extent is extracted or computed, and then center is used as the indexing location. + * First the native extent is extracted or computed, and then center is used as the indexing location. * For details: https://www.geomesa.org/documentation/user/datastores/index_overview.html */ def rf_z2_index(targetExtent: Column) = Z2Indexer(targetExtent, 31: Short) diff --git a/docs/src/main/paradox/reference.md b/docs/src/main/paradox/reference.md index 717336336..11a02d3ce 100644 --- a/docs/src/main/paradox/reference.md +++ b/docs/src/main/paradox/reference.md @@ -71,10 +71,18 @@ Convert an extent to a Geometry. 
The extent likely comes from @ref:[`st_extent`] Long rf_xz2_index(Geometry geom, CRS crs) Long rf_xz2_index(Extent extent, CRS crs) - Long rf_xz2_index(ProjectedRasterTile proj_raster, CRS crs) + Long rf_xz2_index(ProjectedRasterTile proj_raster) Constructs an XZ2 index in WGS84/EPSG:4326 from a Geometry, Extent, or ProjectedRasterTile and its CRS. This function is useful for [range partitioning](http://spark.apache.org/docs/latest/api/python/pyspark.sql.html?highlight=registerjava#pyspark.sql.DataFrame.repartitionByRange). +### rf_z2_index + + Long rf_z2_index(Geometry geom, CRS crs) + Long rf_z2_index(Extent extent, CRS crs) + Long rf_z2_index(ProjectedRasterTile proj_raster) + +Constructs a Z2 index in WGS84/EPSG:4326 from a Geometry, Extent, or ProjectedRasterTile and its CRS. First the native extent is extracted or computed, and then its center is used as the indexing location. This function is useful for [range partitioning](http://spark.apache.org/docs/latest/api/python/pyspark.sql.html?highlight=registerjava#pyspark.sql.DataFrame.repartitionByRange). + ## Tile Metadata and Mutation Functions to access and change the particulars of a `tile`: its shape and the data type of its cells. See section on @ref:["NoData" handling](nodata-handling.md) for additional discussion of cell types. diff --git a/docs/src/main/paradox/release-notes.md b/docs/src/main/paradox/release-notes.md index 7a6152873..a907c868a 100644 --- a/docs/src/main/paradox/release-notes.md +++ b/docs/src/main/paradox/release-notes.md @@ -4,8 +4,8 @@ ### 0.8.5 -* _Breaking_: `rf_spatial_index` renamed `rf_xz2_index` to differentiate between XZ2 and Z2 variants. * Added `rf_z2_index` for constructing a Z2 index on types with bounds. +* _Breaking_: `rf_spatial_index` renamed `rf_xz2_index` to differentiate between XZ2 and Z2 variants.
* Added `withSpatialIndex` to RasterSourceDataSource to pre-partition tiles based on tile extents mapped to a Z2 space-filling curve ### 0.8.4 diff --git a/pyrasterframes/src/main/python/tests/RasterFunctionsTests.py b/pyrasterframes/src/main/python/tests/RasterFunctionsTests.py index 6c82f867c..a8e4aa3a7 100644 --- a/pyrasterframes/src/main/python/tests/RasterFunctionsTests.py +++ b/pyrasterframes/src/main/python/tests/RasterFunctionsTests.py @@ -409,15 +409,3 @@ def test_rf_local_is_in(self): self.assertEqual(result['in_list'].cells.sum(), 2, "Tile value {} should contain two 1s as: [[1, 0, 1],[0, 0, 0]]" .format(result['in_list'].cells)) - - def test_rf_spatial_index(self): - from pyspark.sql.functions import min as F_min - result_one_arg = self.df.select(rf_spatial_index('tile').alias('ix')) \ - .agg(F_min('ix')).first()[0] - print(result_one_arg) - - result_two_arg = self.df.select(rf_spatial_index(rf_extent('tile'), rf_crs('tile')).alias('ix')) \ - .agg(F_min('ix')).first()[0] - - self.assertEqual(result_two_arg, result_one_arg) - self.assertEqual(result_one_arg, 55179438768) # this is a bit more fragile but less important diff --git a/pyrasterframes/src/main/python/tests/VectorTypesTests.py b/pyrasterframes/src/main/python/tests/VectorTypesTests.py index c23c07552..7c85d3119 100644 --- a/pyrasterframes/src/main/python/tests/VectorTypesTests.py +++ b/pyrasterframes/src/main/python/tests/VectorTypesTests.py @@ -132,7 +132,7 @@ def buffer(g, d): cols = 194 # from dims of tile rows = 250 # from dims of tile with_raster = with_poly.withColumn('rasterized', - rf_rasterize('poly', 'geometry', lit(16), lit(cols), lit(rows))) + rf_rasterize('poly', 'geometry', lit(16), lit(cols), lit(rows))) result = with_raster.select(rf_tile_sum(rf_local_equal_int(with_raster.rasterized, 16)), rf_tile_sum(with_raster.rasterized)) # @@ -158,11 +158,22 @@ def test_geojson(self): self.assertEqual(geo.select('geometry').count(), 8) def test_xz2_index(self): + from pyspark.sql.functions import min as F_min df = self.df.select(rf_xz2_index(self.df.poly_geom, rf_crs(lit("EPSG:4326"))).alias('index')) expected = {22858201775, 38132946267, 38166922588, 38180072113} indexes = {x[0] for x in df.collect()} self.assertSetEqual(indexes, expected) + # Test against proj_raster (has CRS and Extent embedded). 
+ result_one_arg = self.df.select(rf_xz2_index('tile').alias('ix')) \ + .agg(F_min('ix')).first()[0] + + result_two_arg = self.df.select(rf_xz2_index(rf_extent('tile'), rf_crs('tile')).alias('ix')) \ + .agg(F_min('ix')).first()[0] + + self.assertEqual(result_two_arg, result_one_arg) + self.assertEqual(result_one_arg, 55179438768) # this is a bit more fragile but less important + # Custom resolution df = self.df.select(rf_xz2_index(self.df.poly_geom, rf_crs(lit("EPSG:4326")), 3).alias('index')) expected = {21, 36} @@ -171,12 +182,13 @@ def test_xz2_index(self): def test_z2_index(self): df = self.df.select(rf_z2_index(self.df.poly_geom, rf_crs(lit("EPSG:4326"))).alias('index')) - expected = {22858201775, 38132946267, 38166922588, 38180072113} + + expected = {28596898472, 28625192874, 28635062506, 28599712232} indexes = {x[0] for x in df.collect()} self.assertSetEqual(indexes, expected) # Custom resolution - df = self.df.select(rf_z2_index(self.df.poly_geom, rf_crs(lit("EPSG:4326")), 3).alias('index')) - expected = {21, 36} + df = self.df.select(rf_z2_index(self.df.poly_geom, rf_crs(lit("EPSG:4326")), 6).alias('index')) + expected = {1704, 1706} indexes = {x[0] for x in df.collect()} - self.assertSetEqual(indexes, expected) \ No newline at end of file + self.assertSetEqual(indexes, expected) From b2f8d3cf674f4920e2f8d43f8610a1dec4c70fda Mon Sep 17 00:00:00 2001 From: "Jason T. Brown" Date: Wed, 13 Nov 2019 13:08:05 -0500 Subject: [PATCH 06/24] Add failing unit test for mask by value on 0 Signed-off-by: Jason T. Brown --- .../rasterframes/RasterFunctionsSpec.scala | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala index 816771950..955daab04 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala @@ -686,6 +686,25 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers { checkDocs("rf_mask_by_value") } + it("should mask by value for value 0.") { + // maskingTile has -4, ND, and -15 values. Expect mask by value with 0 to not change the data tile. + val df = Seq((byteArrayTile, maskingTile)).toDF("data", "mask") + + // data tile is all data cells + df.select(rf_no_data_cells($"data")).first() should be (byteArrayTile.size) + + // mask by value against 15 should set 3 cell locations to Nodata + df.withColumn("mbv", rf_mask_by_value($"data", $"mask", 15)) + .select(rf_data_cells($"mbv")) + .first() should be (byteArrayTile.size - 3) + + // breaks with issue https://github.com/locationtech/rasterframes/issues/416 + val result = df.withColumn("mbv", rf_mask_by_value($"data", $"mask", 0)) + .select(rf_data_cells($"mbv")) + .first() + result should be (byteArrayTile.size) + } + it("should inverse mask tile by another identified by specified value") { val df = Seq[Tile](randPRT).toDF("tile") val mask_value = 4 From e113ffdd3425db0d33edfbdc56c2b013f1d97ddb Mon Sep 17 00:00:00 2001 From: "Simeon H.K. Fitch" Date: Wed, 13 Nov 2019 16:27:34 -0500 Subject: [PATCH 07/24] Added (disabled) integration test for profiling spatial index effects. Reorganized non-evaluated documentation files to `docs`.
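From the Python side, the reader option introduced in this patch can be exercised as in the sketch below; `uri` stands in for any readable raster path, and the parameter name comes from the `__init__.py` change that follows:

```python
# True: spatially partition using the session's default shuffle partitioning.
df = spark.read.raster(uri, spatial_index_partitions=True)
# A positive integer requests an explicit partition count instead.
df16 = spark.read.raster(uri, spatial_index_partitions=16)
```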
--- .../raster/RaterSourceDataSourceIT.scala | 64 +++++++++++++++++++ .../src/main/paradox}/concepts.md | 0 .../docs => docs/src/main/paradox}/index.md | 0 .../src/main/paradox}/machine-learning.md | 0 .../src/main/paradox}/raster-io.md | 0 .../src/main/paradox}/raster-processing.md | 0 .../src/main/python/docs/raster-read.pymd | 8 +++ .../main/python/pyrasterframes/__init__.py | 13 ++++ .../src/main/python/tests/RasterSourceTest.py | 6 ++ 9 files changed, 91 insertions(+) create mode 100644 datasource/src/it/scala/org/locationtech/rasterframes/datasource/raster/RaterSourceDataSourceIT.scala rename {pyrasterframes/src/main/python/docs => docs/src/main/paradox}/concepts.md (100%) rename {pyrasterframes/src/main/python/docs => docs/src/main/paradox}/index.md (100%) rename {pyrasterframes/src/main/python/docs => docs/src/main/paradox}/machine-learning.md (100%) rename {pyrasterframes/src/main/python/docs => docs/src/main/paradox}/raster-io.md (100%) rename {pyrasterframes/src/main/python/docs => docs/src/main/paradox}/raster-processing.md (100%) diff --git a/datasource/src/it/scala/org/locationtech/rasterframes/datasource/raster/RaterSourceDataSourceIT.scala b/datasource/src/it/scala/org/locationtech/rasterframes/datasource/raster/RaterSourceDataSourceIT.scala new file mode 100644 index 000000000..f2ffd407b --- /dev/null +++ b/datasource/src/it/scala/org/locationtech/rasterframes/datasource/raster/RaterSourceDataSourceIT.scala @@ -0,0 +1,64 @@ +/* + * This software is licensed under the Apache 2 license, quoted below. + * + * Copyright 2019 Astraea, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * [http://www.apache.org/licenses/LICENSE-2.0] + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + * + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.locationtech.rasterframes.datasource.raster + +import org.locationtech.rasterframes._ + +class RaterSourceDataSourceIT extends TestEnvironment with TestData { + + describe("RasterJoin Performance") { + import spark.implicits._ + ignore("joining classification raster against L8 should run in a reasonable amount of time") { + // A regression test. 
+ val rf = spark.read.raster + .withSpatialIndex() + .load("https://s22s-test-geotiffs.s3.amazonaws.com/water_class/seasonality_90W_50N.tif") + + val target_rf = + rf.select(rf_extent($"proj_raster").alias("extent"), rf_crs($"proj_raster").alias("crs"), rf_tile($"proj_raster").alias("target")) + + val cat = + """ + B3,B5 + https://landsat-pds.s3.us-west-2.amazonaws.com/c1/L8/021/028/LC08_L1TP_021028_20180515_20180604_01_T1/LC08_L1TP_021028_20180515_20180604_01_T1_B3.TIF,https://landsat-pds.s3.us-west-2.amazonaws.com/c1/L8/021/028/LC08_L1TP_021028_20180515_20180604_01_T1/LC08_L1TP_021028_20180515_20180604_01_T1_B5.TIF + """ + + val features_rf = spark.read.raster + .fromCSV(cat, "B3", "B5") + .withSpatialIndex() + .load() + .withColumn("extent", rf_extent($"B3")) + .withColumn("crs", rf_crs($"B3")) + .withColumn("B3", rf_tile($"B3")) + .withColumn("B5", rf_tile($"B5")) + .withColumn("ndwi", rf_normalized_difference($"B3", $"B5")) + .where(!rf_is_no_data_tile($"B3")) + .select("extent", "crs", "ndwi") + + features_rf.explain(true) + + val joined = target_rf.rasterJoin(features_rf).cache() + joined.show(false) + //println(joined.select("ndwi").toMarkdown()) + } + } +} diff --git a/pyrasterframes/src/main/python/docs/concepts.md b/docs/src/main/paradox/concepts.md similarity index 100% rename from pyrasterframes/src/main/python/docs/concepts.md rename to docs/src/main/paradox/concepts.md diff --git a/pyrasterframes/src/main/python/docs/index.md b/docs/src/main/paradox/index.md similarity index 100% rename from pyrasterframes/src/main/python/docs/index.md rename to docs/src/main/paradox/index.md diff --git a/pyrasterframes/src/main/python/docs/machine-learning.md b/docs/src/main/paradox/machine-learning.md similarity index 100% rename from pyrasterframes/src/main/python/docs/machine-learning.md rename to docs/src/main/paradox/machine-learning.md diff --git a/pyrasterframes/src/main/python/docs/raster-io.md b/docs/src/main/paradox/raster-io.md similarity index 100% rename from pyrasterframes/src/main/python/docs/raster-io.md rename to docs/src/main/paradox/raster-io.md diff --git a/pyrasterframes/src/main/python/docs/raster-processing.md b/docs/src/main/paradox/raster-processing.md similarity index 100% rename from pyrasterframes/src/main/python/docs/raster-processing.md rename to docs/src/main/paradox/raster-processing.md diff --git a/pyrasterframes/src/main/python/docs/raster-read.pymd b/pyrasterframes/src/main/python/docs/raster-read.pymd index 443ee0d96..73a1a018b 100644 --- a/pyrasterframes/src/main/python/docs/raster-read.pymd +++ b/pyrasterframes/src/main/python/docs/raster-read.pymd @@ -215,6 +215,14 @@ non_lazy In the initial examples on this page, you may have noticed that the realized (non-lazy) _tiles_ are shown, but we did not change `lazy_tiles`. Instead, we used @ref:[`rf_tile`](reference.md#rf-tile) to explicitly request the realized _tile_ from the lazy representation. +## Spatial Indexing and Partitioning + +It's often desirable to take extra steps in ensuring your data is effectively distributed over your computing resources. One way of doing that is using something called a ["space filling curve"](https://en.wikipedia.org/wiki/Space-filling_curve), which turns an N-dimensional value into a one dimensional value, with properties that favor keeping entities near each other in N-space near each other in index space. 
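To make the idea concrete, the toy sketch below interleaves the bits of a discretized `(x, y)` location into a single Z-order ("Morton") value. It is purely illustrative; RasterFrames delegates to GeoMesa's curve implementations rather than anything like this:

```python
def toy_z2(x, y, bits=4):
    # Interleave the low `bits` bits of x and y: x fills the even bit
    # positions of the result and y fills the odd ones.
    z = 0
    for i in range(bits):
        z |= ((x >> i) & 1) << (2 * i)
        z |= ((y >> i) & 1) << (2 * i + 1)
    return z

assert toy_z2(2, 3) == 0b1110  # nearby (x, y) pairs yield nearby indexes
```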
To have RasterFrames add spatial-index-based partitioning to raster reads, use the `spatial_index_partitions` parameter: + +```python, spatial_indexing +df = spark.read.raster(uri, spatial_index_partitions=True) +df +``` ## GeoTrellis diff --git a/pyrasterframes/src/main/python/pyrasterframes/__init__.py b/pyrasterframes/src/main/python/pyrasterframes/__init__.py index 7915af34e..78cafdff5 100644 --- a/pyrasterframes/src/main/python/pyrasterframes/__init__.py +++ b/pyrasterframes/src/main/python/pyrasterframes/__init__.py @@ -109,6 +109,7 @@ def _raster_reader( band_indexes=None, tile_dimensions=(256, 256), lazy_tiles=True, + spatial_index_partitions=None, **options): """ Returns a Spark DataFrame from raster data files specified by URIs. @@ -124,6 +125,8 @@ def _raster_reader( :param band_indexes: list of integers indicating which bands, zero-based, to read from the raster files specified; default is to read only the first band. :param tile_dimensions: tuple or list of two indicating the default tile dimension as (columns, rows). :param lazy_tiles: If true (default) only generate minimal references to tile contents; if false, fetch tile cell values. + :param spatial_index_partitions: If true, partitions read tiles by a Z2 spatial index using the default shuffle partitioning. + If a value > 0, the given number of partitions is created instead of the default. :param options: Additional keyword arguments to pass to the Spark DataSource. """ @@ -146,6 +149,16 @@ def temp_name(): if band_indexes is None: band_indexes = [0] + if spatial_index_partitions is False: + spatial_index_partitions = None + + if spatial_index_partitions is not None: + if spatial_index_partitions is True: + spatial_index_partitions = "-1" + else: + spatial_index_partitions = str(spatial_index_partitions) + options.update({"spatial_index_partitions": spatial_index_partitions}) + options.update({ "band_indexes": to_csv(band_indexes), "tile_dimensions": to_csv(tile_dimensions), diff --git a/pyrasterframes/src/main/python/tests/RasterSourceTest.py b/pyrasterframes/src/main/python/tests/RasterSourceTest.py index c4c8e64f7..e840092ca 100644 --- a/pyrasterframes/src/main/python/tests/RasterSourceTest.py +++ b/pyrasterframes/src/main/python/tests/RasterSourceTest.py @@ -211,3 +211,9 @@ def test_catalog_named_arg(self): df = self.spark.read.raster(catalog=self.path_pandas_df(), catalog_col_names=['b1', 'b2', 'b3']) self.assertEqual(len(df.columns), 7) # three path cols, three tile cols, and geo self.assertTrue(df.select('b1_path').distinct().count() == 3) + + def test_spatial_partitioning(self): + df = self.spark.read.raster(self.path(1, 1), spatial_index_partitions=True) + self.assertTrue('spatial_index' in df.columns) + # Other tests? + From 0a7f90cc2c2d06be0b4e2faba85bbe05bdcdc2f7 Mon Sep 17 00:00:00 2001 From: "Simeon H.K.
Fitch" Date: Thu, 14 Nov 2019 12:57:47 -0500 Subject: [PATCH 08/24] Regression --- .../datasource/raster/RasterSourceDataSourceSpec.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/datasource/src/test/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceDataSourceSpec.scala b/datasource/src/test/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceDataSourceSpec.scala index 345579567..0e92e5ba5 100644 --- a/datasource/src/test/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceDataSourceSpec.scala +++ b/datasource/src/test/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceDataSourceSpec.scala @@ -23,11 +23,11 @@ package org.locationtech.rasterframes.datasource.raster import geotrellis.raster.Tile import org.apache.spark.sql.functions.{lit, round, udf} import org.apache.spark.sql.types.LongType -import org.locationtech.rasterframes.{TestEnvironment, _} import org.locationtech.rasterframes.datasource.raster.RasterSourceDataSource.{RasterSourceCatalog, _} import org.locationtech.rasterframes.model.TileDimensions import org.locationtech.rasterframes.ref.RasterRef.RasterRefTile import org.locationtech.rasterframes.util._ +import org.locationtech.rasterframes.{TestEnvironment, _} class RasterSourceDataSourceSpec extends TestEnvironment with TestData { import spark.implicits._ @@ -336,15 +336,16 @@ class RasterSourceDataSourceSpec extends TestEnvironment with TestData { describe("attaching a spatial index") { val l8_df = spark.read.raster - .withSpatialIndex() + .withSpatialIndex(5) .load(remoteL8.toASCIIString) .cache() it("should add index") { l8_df.columns should contain("spatial_index") l8_df.schema("spatial_index").dataType should be(LongType) - l8_df.show(false) - l8_df.select($"spatial_index").distinct().count() should be(l8_df.count()) + val parts = l8_df.rdd.partitions + parts.length should be (5) + parts.map(_.getClass.getSimpleName).forall(_ == "ShuffledRowRDDPartition") should be (true) } } } From f4f4e9a3d3cc7155684d350f62a6fc9a85269bf5 Mon Sep 17 00:00:00 2001 From: "Jason T. Brown" Date: Thu, 14 Nov 2019 13:23:06 -0500 Subject: [PATCH 09/24] Masking improvements and unit tests. Mask by value of 0 fixed Masking cell type mutations in rf_mask_by_values is resolved Masking by extraction of bits implemented Signed-off-by: Jason T. Brown --- .../rasterframes/RasterFunctions.scala | 16 +- .../expressions/transformers/Mask.scala | 32 +- .../rasterframes/RasterFunctionsSpec.scala | 387 ++++++++++++------ docs/src/main/paradox/reference.md | 2 +- 4 files changed, 290 insertions(+), 147 deletions(-) diff --git a/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala b/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala index 7142e8b6c..2a8f11878 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala @@ -319,7 +319,7 @@ trait RasterFunctions { /** Generate a tile with the values from `data_tile`, but where cells in the `mask_tile` are in the `mask_values` list, replace the value with NODATA. 
*/ - def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Seq[Int]): TypedColumn[Any, Tile] = { + def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Int*): TypedColumn[Any, Tile] = { import org.apache.spark.sql.functions.array val valuesCol: Column = array(maskValues.map(lit).toSeq: _*) rf_mask_by_values(sourceTile, maskTile, valuesCol) @@ -338,22 +338,24 @@ trait RasterFunctions { Mask.InverseMaskByValue(sourceTile, maskTile, lit(maskValue)) /** Applies a mask using bit values in the `mask_tile`. Working from the right, extract the bit at `bitPosition` from the `maskTile`. In all locations where these are equal to the `valueToMask`, the returned tile is set to NoData, else the original `dataTile` cell value. */ - def rf_mask_by_bit(dataTile: Column, maskTile: Column, bitPosition: Int, valueToMask: Boolean): Column = - rf_mask_by_bit(dataTile, maskTile, lit(bitPosition), lit(valueToMask)) + def rf_mask_by_bit(dataTile: Column, maskTile: Column, bitPosition: Int, valueToMask: Boolean): TypedColumn[Any, Tile] = + rf_mask_by_bit(dataTile, maskTile, lit(bitPosition), lit(if (valueToMask) 1 else 0)) /** Applies a mask using bit values in the `mask_tile`. Working from the right, extract the bit at `bitPosition` from the `maskTile`. In all locations where these are equal to the `valueToMask`, the returned tile is set to NoData, else the original `dataTile` cell value. */ - def rf_mask_by_bit(dataTile: Column, maskTile: Column, bitPosition: Column, valueToMask: Column): Column = - rf_mask_by_bits(dataTile, maskTile, bitPosition, lit(1), valueToMask) + def rf_mask_by_bit(dataTile: Column, maskTile: Column, bitPosition: Column, valueToMask: Column): TypedColumn[Any, Tile] = { + import org.apache.spark.sql.functions.array + rf_mask_by_bits(dataTile, maskTile, bitPosition, lit(1), array(valueToMask)) + } /** Applies a mask from blacklisted bit values in the `mask_tile`. Working from the right, the bits from `start_bit` to `start_bit + num_bits` are @ref:[extracted](reference.md#rf_local_extract_bits) from cell values of the `mask_tile`. In all locations where these are in the `mask_values`, the returned tile is set to NoData; otherwise the original `tile` cell value is returned. */ - def rf_mask_by_bits(dataTile: Column, maskTile: Column, startBit: Column, numBits: Column, valuesToMask: Column): Column = { + def rf_mask_by_bits(dataTile: Column, maskTile: Column, startBit: Column, numBits: Column, valuesToMask: Column): TypedColumn[Any, Tile] = { val bitMask = rf_local_extract_bits(maskTile, startBit, numBits) rf_mask_by_values(dataTile, bitMask, valuesToMask) } /** Applies a mask from blacklisted bit values in the `mask_tile`. Working from the right, the bits from `start_bit` to `start_bit + num_bits` are @ref:[extracted](reference.md#rf_local_extract_bits) from cell values of the `mask_tile`. In all locations where these are in the `mask_values`, the returned tile is set to NoData; otherwise the original `tile` cell value is returned. 
*/ - def rf_mask_by_bits(dataTile: Column, maskTile: Column, startBit: Int, numBits: Int, valuesToMask: Int*): Column = { + def rf_mask_by_bits(dataTile: Column, maskTile: Column, startBit: Int, numBits: Int, valuesToMask: Int*): TypedColumn[Any, Tile] = { import org.apache.spark.sql.functions.array val values = array(valuesToMask.map(lit):_*) rf_mask_by_bits(dataTile, maskTile, lit(startBit), lit(numBits), values) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Mask.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Mask.scala index c6b9b75ec..67b2529ba 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Mask.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Mask.scala @@ -24,13 +24,13 @@ package org.locationtech.rasterframes.expressions.transformers import com.typesafe.scalalogging.Logger import geotrellis.raster import geotrellis.raster.Tile -import geotrellis.raster.mapalgebra.local.{Defined, InverseMask => gtInverseMask, Mask => gtMask} +import geotrellis.raster.mapalgebra.local.{Defined, InverseMask ⇒ gtInverseMask, Mask ⇒ gtMask} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, Literal, TernaryExpression} import org.apache.spark.sql.rf.TileUDT -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, NullType} import org.apache.spark.sql.{Column, TypedColumn} import org.locationtech.rasterframes.encoders.CatalystSerializer._ import org.locationtech.rasterframes.expressions.DynamicExtractors._ @@ -38,7 +38,15 @@ import org.locationtech.rasterframes.expressions.localops.IsIn import org.locationtech.rasterframes.expressions.row import org.slf4j.LoggerFactory -abstract class Mask(val left: Expression, val middle: Expression, val right: Expression, inverse: Boolean) +/** Convert cells in the `left` to NoData based on another tile's contents + * + * @param left a tile of data values, with valid nodata cell type + * @param middle a tile indicating locations to set to nodata + * @param right optional, cell values in the `middle` tile indicating locations to set NoData + * @param defined if true, consider NoData in the `middle` as the locations to mask; else use `right` valued cells + * @param inverse if true, and defined is true, set `left` to NoData where `middle` is NOT nodata + */ +abstract class Mask(val left: Expression, val middle: Expression, val right: Expression, defined: Boolean, inverse: Boolean) extends TernaryExpression with CodegenFallback with Serializable { // aliases. 
def targetExp = left @@ -77,13 +85,13 @@ abstract class Mask(val left: Expression, val middle: Expression, val right: Exp val maskValue = intArgExtractor(maskValueExp.dataType)(maskValueInput) - val masking = if (maskValue.value == 0) Defined(maskTile) - else maskTile + val masking = if (defined) Defined(maskTile) + else maskTile.localEqual(maskValue.value) val result = if (inverse) - gtInverseMask(targetTile, masking, maskValue.value, raster.NODATA) + gtInverseMask(targetTile, masking, 1, raster.NODATA) else - gtMask(targetTile, masking, maskValue.value, raster.NODATA) + gtMask(targetTile, masking, 1, raster.NODATA) targetCtx match { case Some(ctx) => ctx.toProjectRasterTile(result).toInternalRow @@ -106,7 +114,7 @@ object Mask { ...""" ) case class MaskByDefined(target: Expression, mask: Expression) - extends Mask(target, mask, Literal(0), false) { + extends Mask(target, mask, Literal(0), true, false) { override def nodeName: String = "rf_mask" } object MaskByDefined { @@ -126,7 +134,7 @@ object Mask { ...""" ) case class InverseMaskByDefined(leftTile: Expression, rightTile: Expression) - extends Mask(leftTile, rightTile, Literal(0), true) { + extends Mask(leftTile, rightTile, Literal(0), true, true) { override def nodeName: String = "rf_inverse_mask" } object InverseMaskByDefined { @@ -146,7 +154,7 @@ object Mask { ...""" ) case class MaskByValue(leftTile: Expression, rightTile: Expression, maskValue: Expression) - extends Mask(leftTile, rightTile, maskValue, false) { + extends Mask(leftTile, rightTile, maskValue, false, false) { override def nodeName: String = "rf_mask_by_value" } object MaskByValue { @@ -168,7 +176,7 @@ object Mask { ...""" ) case class InverseMaskByValue(leftTile: Expression, rightTile: Expression, maskValue: Expression) - extends Mask(leftTile, rightTile, maskValue, true) { + extends Mask(leftTile, rightTile, maskValue, false, true) { override def nodeName: String = "rf_inverse_mask_by_value" } object InverseMaskByValue { @@ -190,7 +198,7 @@ object Mask { ...""" ) case class MaskByValues(dataTile: Expression, maskTile: Expression) - extends Mask(dataTile, maskTile, Literal(1), inverse = false) { + extends Mask(dataTile, maskTile, Literal(1), false, false) { def this(dataTile: Expression, maskTile: Expression, maskValues: Expression) = this(dataTile, IsIn(maskTile, maskValues)) override def nodeName: String = "rf_mask_by_values" diff --git a/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala index 955daab04..b4455c69d 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala @@ -642,6 +642,16 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers { checkDocs("rf_mask") } + it("should mask without mutating cell type") { + val result = Seq((byteArrayTile, maskingTile)) + .toDF("tile", "mask") + .select(rf_mask($"tile", $"mask").as("masked_tile")) + .select(rf_cell_type($"masked_tile")) + .first() + + result should be (byteArrayTile.cellType) + } + it("should inverse mask one tile against another") { val df = Seq[Tile](randPRT).toDF("tile") @@ -691,7 +701,7 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers { val df = Seq((byteArrayTile, maskingTile)).toDF("data", "mask") // data tile is all data cells - df.select(rf_no_data_cells($"data")).first() should be (byteArrayTile.size) + 
df.select(rf_data_cells($"data")).first() should be (byteArrayTile.size) // mask by value against 15 should set 3 cell locations to Nodata df.withColumn("mbv", rf_mask_by_value($"data", $"mask", 15)) @@ -734,23 +744,23 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers { it("should mask tile by another identified by sequence of specified values") { val squareIncrementingPRT = ProjectedRasterTile(squareIncrementingTile(six.rows), six.extent, six.crs) val df = Seq((six, squareIncrementingPRT)) - .toDF("tile", "mask") + .toDF("tile", "mask") val mask_values = Seq(4, 5, 6, 12) val withMasked = df.withColumn("masked", - rf_mask_by_values($"tile", $"mask", mask_values)) + rf_mask_by_values($"tile", $"mask", mask_values:_*)) val expected = squareIncrementingPRT.toArray().count(v ⇒ mask_values.contains(v)) val result = withMasked.agg(rf_agg_no_data_cells($"masked") as "masked_nd") - .first() + .first() - result.getAs[BigInt](0) should be (expected) + result.getAs[BigInt](0) should be(expected) val withMaskedSql = df.selectExpr("rf_mask_by_values(tile, mask, array(4, 5, 6, 12)) AS masked") val resultSql = withMaskedSql.agg(rf_agg_no_data_cells($"masked")).as[Long] - resultSql.first() should be (expected) + resultSql.first() should be(expected) checkDocs("rf_mask_by_values") } @@ -902,6 +912,7 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers { } } + it("should resample") { def lowRes = { def base = ArrayTile(Array(1,2,3,4), 2, 2) @@ -936,74 +947,77 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers { checkDocs("rf_resample") } - it("should create RGB composite") { - val red = TestData.l8Sample(4).toProjectedRasterTile - val green = TestData.l8Sample(3).toProjectedRasterTile - val blue = TestData.l8Sample(2).toProjectedRasterTile + describe("create encoded representation of 3 band images") { - val expected = ArrayMultibandTile( - red.rescale(0, 255), - green.rescale(0, 255), - blue.rescale(0, 255) - ).color() + it("should create RGB composite") { + val red = TestData.l8Sample(4).toProjectedRasterTile + val green = TestData.l8Sample(3).toProjectedRasterTile + val blue = TestData.l8Sample(2).toProjectedRasterTile - val df = Seq((red, green, blue)).toDF("red", "green", "blue") + val expected = ArrayMultibandTile( + red.rescale(0, 255), + green.rescale(0, 255), + blue.rescale(0, 255) + ).color() - val expr = df.select(rf_rgb_composite($"red", $"green", $"blue")).as[ProjectedRasterTile] + val df = Seq((red, green, blue)).toDF("red", "green", "blue") - val nat_color = expr.first() + val expr = df.select(rf_rgb_composite($"red", $"green", $"blue")).as[ProjectedRasterTile] - checkDocs("rf_rgb_composite") - assertEqual(nat_color.toArrayTile(), expected) - } + val nat_color = expr.first() - it("should create an RGB PNG image") { - val red = TestData.l8Sample(4).toProjectedRasterTile - val green = TestData.l8Sample(3).toProjectedRasterTile - val blue = TestData.l8Sample(2).toProjectedRasterTile + checkDocs("rf_rgb_composite") + assertEqual(nat_color.toArrayTile(), expected) + } - val df = Seq((red, green, blue)).toDF("red", "green", "blue") + it("should create an RGB PNG image") { + val red = TestData.l8Sample(4).toProjectedRasterTile + val green = TestData.l8Sample(3).toProjectedRasterTile + val blue = TestData.l8Sample(2).toProjectedRasterTile - val expr = df.select(rf_render_png($"red", $"green", $"blue")) + val df = Seq((red, green, blue)).toDF("red", "green", "blue") - val pngData = expr.first() + val expr = df.select(rf_render_png($"red", 
$"green", $"blue")) - val image = ImageIO.read(new ByteArrayInputStream(pngData)) - image.getWidth should be(red.cols) - image.getHeight should be(red.rows) - } + val pngData = expr.first() - it("should create a color-ramp PNG image") { - val red = TestData.l8Sample(4).toProjectedRasterTile + val image = ImageIO.read(new ByteArrayInputStream(pngData)) + image.getWidth should be(red.cols) + image.getHeight should be(red.rows) + } - val df = Seq(red).toDF("red") + it("should create a color-ramp PNG image") { + val red = TestData.l8Sample(4).toProjectedRasterTile - val expr = df.select(rf_render_png($"red", ColorRamps.HeatmapBlueToYellowToRedSpectrum)) + val df = Seq(red).toDF("red") - val pngData = expr.first() + val expr = df.select(rf_render_png($"red", ColorRamps.HeatmapBlueToYellowToRedSpectrum)) - val image = ImageIO.read(new ByteArrayInputStream(pngData)) - image.getWidth should be(red.cols) - image.getHeight should be(red.rows) - } - it("should interpret cell values with a specified cell type") { - checkDocs("rf_interpret_cell_type_as") - val df = Seq(randNDPRT).toDF("t") - .withColumn("tile", rf_interpret_cell_type_as($"t", "int8raw")) - val resultTile = df.select("tile").as[Tile].first() + val pngData = expr.first() + + val image = ImageIO.read(new ByteArrayInputStream(pngData)) + image.getWidth should be(red.cols) + image.getHeight should be(red.rows) + } + it("should interpret cell values with a specified cell type") { + checkDocs("rf_interpret_cell_type_as") + val df = Seq(randNDPRT).toDF("t") + .withColumn("tile", rf_interpret_cell_type_as($"t", "int8raw")) + val resultTile = df.select("tile").as[Tile].first() - resultTile.cellType should be (CellType.fromName("int8raw")) - // should have same number of values that are -2 the old ND - val countOldNd = df.select( - rf_tile_sum(rf_local_equal($"tile", ct.noDataValue)), - rf_no_data_cells($"t") - ).first() - countOldNd._1 should be (countOldNd._2) + resultTile.cellType should be (CellType.fromName("int8raw")) + // should have same number of values that are -2 the old ND + val countOldNd = df.select( + rf_tile_sum(rf_local_equal($"tile", ct.noDataValue)), + rf_no_data_cells($"t") + ).first() + countOldNd._1 should be (countOldNd._2) - // should not have no data any more (raw type) - val countNewNd = df.select(rf_no_data_cells($"tile")).first() - countNewNd should be (0L) + // should not have no data any more (raw type) + val countNewNd = df.select(rf_no_data_cells($"tile")).first() + countNewNd should be (0L) + } } it("should return local data and nodata"){ @@ -1021,92 +1035,211 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers { dResult should be (randNDPRT.localDefined()) } - it("should check values isin"){ - checkDocs("rf_local_is_in") - // tile is 3 by 3 with values, 1 to 9 - val df = Seq(byteArrayTile).toDF("t") - .withColumn("one", lit(1)) - .withColumn("five", lit(5)) - .withColumn("ten", lit(10)) - .withColumn("in_expect_2", rf_local_is_in($"t", array($"one", $"five"))) - .withColumn("in_expect_1", rf_local_is_in($"t", array($"ten", $"five"))) - .withColumn("in_expect_1a", rf_local_is_in($"t", Array(10, 5))) - .withColumn("in_expect_0", rf_local_is_in($"t", array($"ten"))) + describe("masking by specific bit values") { + // Define a dataframe set up similar to the Landsat8 masking scheme + // Sample of https://www.usgs.gov/media/images/landsat-8-quality-assessment-band-pixel-value-interpretations + val fill = 1 + val clear = 2720 + val cirrus = 6816 + val med_cloud = 2756 // with 1-2 bands saturated + 
val hi_cirrus = 6900 // yes cloud, hi conf cloud and hi conf cirrus and 1-2band sat + val dataColumnCellType = UShortConstantNoDataCellType + val tiles = Seq(fill, clear, cirrus, med_cloud, hi_cirrus).map{v ⇒ + ( + TestData.projectedRasterTile(3, 3, 6, TestData.extent, TestData.crs, dataColumnCellType), + TestData.projectedRasterTile(3, 3, v, TestData.extent, TestData.crs, UShortCellType) // because masking returns the union of cell types + ) + } - val e2Result = df.select(rf_tile_sum($"in_expect_2")).as[Double].first() - e2Result should be (2.0) + val df = tiles.toDF("data", "mask") + .withColumn("val", rf_tile_min($"mask")) - val e1Result = df.select(rf_tile_sum($"in_expect_1")).as[Double].first() - e1Result should be (1.0) + it("should give LHS cell type"){ + val resultMask = df.select( + rf_cell_type( + rf_mask($"data", $"mask") + ) + ).distinct().collect() + all (resultMask) should be (dataColumnCellType) - val e1aResult = df.select(rf_tile_sum($"in_expect_1a")).as[Double].first() - e1aResult should be (1.0) + val resultMaskVal = df.select( + rf_cell_type( + rf_mask_by_value($"data", $"mask", 5) + ) + ).distinct().collect() - val e0Result = df.select($"in_expect_0").as[Tile].first() - e0Result.toArray() should contain only (0) + all(resultMaskVal) should be (dataColumnCellType) - } + val resultMaskValues = df.select( + rf_cell_type( + rf_mask_by_values($"data", $"mask", 5, 6, 7 ) + ) + ).distinct().collect() + all(resultMaskValues) should be (dataColumnCellType) - it("should unpack QA bits"){ - checkDocs("rf_local_extract_bits") + val resultMaskBit = df.select( + rf_cell_type( + rf_mask_by_bit($"data", $"mask", 5, true) + ) + ).distinct().collect() + all(resultMaskBit) should be (dataColumnCellType) + + val resultMaskValInv = df.select( + rf_cell_type( + rf_inverse_mask_by_value($"data", $"mask", 5) + ) + ).distinct().collect() + all(resultMaskValInv) should be (dataColumnCellType) - // Sample of https://www.usgs.gov/media/images/landsat-8-quality-assessment-band-pixel-value-interpretations - val fill = 1 - val clear = 2720 - val cirrus = 6816 - val med_cloud = 2756 // with 1-2 bands saturated - val tiles = Seq(fill, clear, cirrus, med_cloud).map(v ⇒ - TestData.projectedRasterTile(3, 3, v, TestData.extent, TestData.crs, UShortCellType)) - - val df = tiles.toDF("tile") - .withColumn("val", rf_tile_min($"tile")) - - val result = df - .withColumn("qa_fill", rf_local_extract_bits($"tile", lit(0))) - .withColumn("qa_sat", rf_local_extract_bits($"tile", lit(2), lit(2))) - .withColumn("qa_cloud", rf_local_extract_bits($"tile", lit(4))) - .withColumn("qa_cconf", rf_local_extract_bits($"tile", 5, 2)) - .withColumn("qa_snow", rf_local_extract_bits($"tile", lit(9), lit(2))) - .withColumn("qa_circonf", rf_local_extract_bits($"tile", 11, 2)) - - def checker(colName: String, valFilter: Int, assertValue: Int): Unit = { - // print this so we can see what's happening if something wrong - println(s"${colName} should be ${assertValue} for qa val ${valFilter}") - result.filter($"val" === lit(valFilter)) - .select(col(colName)) - .as[ProjectedRasterTile] - .first() - .get(0, 0) should be (assertValue) } - checker("qa_fill", fill, 1) - checker("qa_cloud", fill, 0) - checker("qa_cconf", fill, 0) - checker("qa_sat", fill, 0) - checker("qa_snow", fill, 0) - checker("qa_circonf", fill, 0) + it("should check values isin"){ + checkDocs("rf_local_is_in") - // trivial bits selection (numBits=0) and SQL - df.filter($"val" === lit(fill)) - .selectExpr("rf_local_extract_bits(tile, 0, 0) AS t") - 
.select(rf_exists($"t")).as[Boolean].first() should be (false) + // tile is 3 by 3 with values, 1 to 9 + val rf = Seq(byteArrayTile).toDF("t") + .withColumn("one", lit(1)) + .withColumn("five", lit(5)) + .withColumn("ten", lit(10)) + .withColumn("in_expect_2", rf_local_is_in($"t", array($"one", $"five"))) + .withColumn("in_expect_1", rf_local_is_in($"t", array($"ten", $"five"))) + .withColumn("in_expect_1a", rf_local_is_in($"t", Array(10, 5))) + .withColumn("in_expect_0", rf_local_is_in($"t", array($"ten"))) - checker("qa_fill", clear, 0) - checker("qa_cloud", clear, 0) - checker("qa_cconf", clear, 1) + val e2Result = rf.select(rf_tile_sum($"in_expect_2")).as[Double].first() + e2Result should be (2.0) - checker("qa_fill", med_cloud, 0) - checker("qa_cloud", med_cloud, 0) - checker("qa_cconf", med_cloud, 2) // L8 only tags hi conf in the cloud assessment - checker("qa_sat", med_cloud, 1) + val e1Result = rf.select(rf_tile_sum($"in_expect_1")).as[Double].first() + e1Result should be (1.0) - checker("qa_fill", cirrus, 0) - checker("qa_sat", cirrus, 0) - checker("qa_cloud", cirrus, 0) //low cloud conf - checker("qa_cconf", cirrus, 1) //low cloud conf - checker("qa_circonf", cirrus, 3) //high cirrus conf + val e1aResult = rf.select(rf_tile_sum($"in_expect_1a")).as[Double].first() + e1aResult should be (1.0) + val e0Result = rf.select($"in_expect_0").as[Tile].first() + e0Result.toArray() should contain only (0) + } + it("should unpack QA bits"){ + checkDocs("rf_local_extract_bits") + + val result = df + .withColumn("qa_fill", rf_local_extract_bits($"mask", lit(0))) + .withColumn("qa_sat", rf_local_extract_bits($"mask", lit(2), lit(2))) + .withColumn("qa_cloud", rf_local_extract_bits($"mask", lit(4))) + .withColumn("qa_cconf", rf_local_extract_bits($"mask", 5, 2)) + .withColumn("qa_snow", rf_local_extract_bits($"mask", lit(9), lit(2))) + .withColumn("qa_circonf", rf_local_extract_bits($"mask", 11, 2)) + + def checker(colName: String, valFilter: Int, assertValue: Int): Unit = { + // print this so we can see what's happening if something wrong + println(s"${colName} should be ${assertValue} for qa val ${valFilter}") + result.filter($"val" === lit(valFilter)) + .select(col(colName)) + .as[ProjectedRasterTile] + .first() + .get(0, 0) should be (assertValue) + } + + checker("qa_fill", fill, 1) + checker("qa_cloud", fill, 0) + checker("qa_cconf", fill, 0) + checker("qa_sat", fill, 0) + checker("qa_snow", fill, 0) + checker("qa_circonf", fill, 0) + + // trivial bits selection (numBits=0) and SQL + df.filter($"val" === lit(fill)) + .selectExpr("rf_local_extract_bits(mask, 0, 0) AS t") + .select(rf_exists($"t")).as[Boolean].first() should be (false) + + checker("qa_fill", clear, 0) + checker("qa_cloud", clear, 0) + checker("qa_cconf", clear, 1) + + checker("qa_fill", med_cloud, 0) + checker("qa_cloud", med_cloud, 0) + checker("qa_cconf", med_cloud, 2) // L8 only tags hi conf in the cloud assessment + checker("qa_sat", med_cloud, 1) + + checker("qa_fill", cirrus, 0) + checker("qa_sat", cirrus, 0) + checker("qa_cloud", cirrus, 0) //low cloud conf + checker("qa_cconf", cirrus, 1) //low cloud conf + checker("qa_circonf", cirrus, 3) //high cirrus conf + } + it("should mask by QA bits"){ + val result = df + .withColumn("fill_no", rf_mask_by_bit($"data", $"mask", 0, true)) + .withColumn("sat_0", rf_mask_by_bits($"data", $"mask", 2, 2, 1, 2, 3)) // strict no bands + .withColumn("sat_2", rf_mask_by_bits($"data", $"mask", 2, 2, 2, 3)) // up to 2 bands contain sat + .withColumn("sat_4", + rf_mask_by_bits($"data", 
$"mask", lit(2), lit(2), array(lit(3)))) // up to 4 bands contain sat + .withColumn("cloud_no", rf_mask_by_bit($"data", $"mask", lit(4), lit(true))) + .withColumn("cloud_only", rf_mask_by_bit($"data", $"mask", 4, false)) // mask if *not* cloud + .withColumn("cloud_conf_low", rf_mask_by_bits($"data", $"mask", lit(5), lit(2), array(lit(0), lit(1)))) + .withColumn("cloud_conf_med", rf_mask_by_bits($"data", $"mask", 5, 2, 0, 1, 2)) + .withColumn("cirrus_med", rf_mask_by_bits($"data", $"mask", 11, 2, 3, 2)) // n.b. this is masking out more likely cirrus. + + result.select(rf_cell_type($"fill_no")).first() should be (dataColumnCellType) + + def checker(columnName: String, maskValueFilter: Int, resultIsNoData: Boolean = true): Unit = { + /** in this unit test setup, the `val` column is an integer that the entire row's mask is full of + * filter for the maskValueFilter + * then check the columnName and look at the masked data tile given by `columnName` + * assert that the `columnName` tile is / is not all nodata based on `resultIsNoData` + * */ + + val printOutcome = if (resultIsNoData) "all NoData cells" + else "all data cells" + + println(s"${columnName} should contain ${printOutcome} for qa val ${maskValueFilter}") + val resultDf = result + .filter($"val" === lit(maskValueFilter)) + + val resultToCheck: Boolean = resultDf + .select(rf_is_no_data_tile(col(columnName))) + .first() + + val dataTile = resultDf.select(col(columnName)).as[ProjectedRasterTile].first() + println(s"\tData tile values for col ${columnName}: ${dataTile.toArray().mkString(",")}") +// val celltype = resultDf.select(rf_cell_type(col(columnName))).as[CellType].first() +// println(s"Cell type for col ${columnName}: ${celltype}") + + resultToCheck should be (resultIsNoData) + } + + checker("fill_no", fill, true) + checker("cloud_only", clear, true) + checker("cloud_only", hi_cirrus, false) + checker("cloud_no", hi_cirrus, false) + checker("sat_0", clear, false) + checker("cloud_no", clear, false) + checker("cloud_no", med_cloud, false) + checker("cloud_conf_low", med_cloud, false) + checker("cloud_conf_med", med_cloud, true) + checker("cirrus_med", cirrus, true) + checker("cloud_no", cirrus, false) + } + + it("should have SQL equivalent"){ + + df.createOrReplaceTempView("df_maskbits") + + val maskedCol = "cloud_conf_med" + val result = spark.sql( + s""" + |SELECT rf_mask_by_values( + | data, + | rf_local_extract_bits(mask, 5, 2), + | array(0, 1, 2) + | ) as ${maskedCol} + | FROM df_maskbits + | WHERE val = 2756 + |""".stripMargin) + result.select(rf_is_no_data_tile(col(maskedCol))).first() should be (true) + + } } + } diff --git a/docs/src/main/paradox/reference.md b/docs/src/main/paradox/reference.md index aef211035..b6af613ac 100644 --- a/docs/src/main/paradox/reference.md +++ b/docs/src/main/paradox/reference.md @@ -257,7 +257,7 @@ This function is not available in the SQL API. The below is equivalent: ```sql SELECT rf_mask_by_values( tile, - rf_extract_bits(mask_tile, start_bit, num_bits), + rf_local_extract_bits(mask_tile, start_bit, num_bits), mask_values ), ``` From 0ec9b6e1d23125bdb99f710e73def7037966b69c Mon Sep 17 00:00:00 2001 From: "Simeon H.K. Fitch" Date: Thu, 14 Nov 2019 13:41:56 -0500 Subject: [PATCH 10/24] Python regression. 
--- pyrasterframes/src/main/python/tests/VectorTypesTests.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pyrasterframes/src/main/python/tests/VectorTypesTests.py b/pyrasterframes/src/main/python/tests/VectorTypesTests.py index 7c85d3119..0eeaf6eda 100644 --- a/pyrasterframes/src/main/python/tests/VectorTypesTests.py +++ b/pyrasterframes/src/main/python/tests/VectorTypesTests.py @@ -165,10 +165,11 @@ def test_xz2_index(self): self.assertSetEqual(indexes, expected) # Test against proj_raster (has CRS and Extent embedded). - result_one_arg = self.df.select(rf_xz2_index('tile').alias('ix')) \ + df = self.spark.read.raster(self.img_uri) + result_one_arg = df.select(rf_xz2_index('proj_raster').alias('ix')) \ .agg(F_min('ix')).first()[0] - result_two_arg = self.df.select(rf_xz2_index(rf_extent('tile'), rf_crs('tile')).alias('ix')) \ + result_two_arg = df.select(rf_xz2_index(rf_extent('proj_raster'), rf_crs('proj_raster')).alias('ix')) \ .agg(F_min('ix')).first()[0] self.assertEqual(result_two_arg, result_one_arg) From abb7add31e1242678decfc0c2a8df07afa5f5936 Mon Sep 17 00:00:00 2001 From: "Jason T. Brown" Date: Thu, 14 Nov 2019 16:13:41 -0500 Subject: [PATCH 11/24] Fix for both masking by def and value; expand code comments; update tests Signed-off-by: Jason T. Brown --- .../expressions/transformers/Mask.scala | 15 +++++++++------ .../rasterframes/RasterFunctionsSpec.scala | 12 ++++++++++++ .../main/python/tests/RasterFunctionsTests.py | 16 ++++++++++------ 3 files changed, 31 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Mask.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Mask.scala index 67b2529ba..03dd2ffbf 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Mask.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Mask.scala @@ -24,13 +24,13 @@ package org.locationtech.rasterframes.expressions.transformers import com.typesafe.scalalogging.Logger import geotrellis.raster import geotrellis.raster.Tile -import geotrellis.raster.mapalgebra.local.{Defined, InverseMask ⇒ gtInverseMask, Mask ⇒ gtMask} +import geotrellis.raster.mapalgebra.local.{Undefined, InverseMask ⇒ gtInverseMask, Mask ⇒ gtMask} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, Literal, TernaryExpression} import org.apache.spark.sql.rf.TileUDT -import org.apache.spark.sql.types.{DataType, NullType} +import org.apache.spark.sql.types.DataType import org.apache.spark.sql.{Column, TypedColumn} import org.locationtech.rasterframes.encoders.CatalystSerializer._ import org.locationtech.rasterframes.expressions.DynamicExtractors._ @@ -43,10 +43,10 @@ import org.slf4j.LoggerFactory * @param left a tile of data values, with valid nodata cell type * @param middle a tile indicating locations to set to nodata * @param right optional, cell values in the `middle` tile indicating locations to set NoData - * @param defined if true, consider NoData in the `middle` as the locations to mask; else use `right` valued cells + * @param undefined if true, consider NoData in the `middle` as the locations to mask; else use `right` valued cells * @param inverse if true, and defined is true, 
set `left` to NoData where `middle` is NOT nodata
  */
-abstract class Mask(val left: Expression, val middle: Expression, val right: Expression, defined: Boolean, inverse: Boolean)
+abstract class Mask(val left: Expression, val middle: Expression, val right: Expression, undefined: Boolean, inverse: Boolean)
   extends TernaryExpression with CodegenFallback with Serializable {
   // aliases.
   def targetExp = left
@@ -85,9 +85,12 @@ abstract class Mask(val left: Expression, val middle: Expression, val right: Exp
 
     val maskValue = intArgExtractor(maskValueExp.dataType)(maskValueInput)
 
-    val masking = if (defined) Defined(maskTile)
-    else maskTile.localEqual(maskValue.value)
+    // Build a tile whose value is 1 at each location to set to ND in the target tile.
+    // When `undefined` is true, mark the locations where `maskTile` is ND
+    val masking = if (undefined) Undefined(maskTile)
+    else maskTile.localEqual(maskValue.value) // otherwise, mark the locations where `maskTile` equals `maskValue`
 
+    // Apply `masking`: wherever its value is 1, set the target to ND (or the complement, when inverted)
     val result = if (inverse)
       gtInverseMask(targetTile, masking, 1, raster.NODATA)
     else
diff --git a/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala
index b4455c69d..49968a85b 100644
--- a/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala
+++ b/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala
@@ -642,6 +642,17 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers {
     checkDocs("rf_mask")
   }
 
+  it("should mask with expected results") {
+    val df = Seq((byteArrayTile, maskingTile)).toDF("tile", "mask")
+
+    val withMasked = df.withColumn("masked",
+      rf_mask($"tile", $"mask"))
+
+    val result: Tile = withMasked.select($"masked").as[Tile].first()
+
+    result.localUndefined().toArray() should be (maskingTile.localUndefined().toArray())
+  }
+
   it("should mask without mutating cell type") {
     val result = Seq((byteArrayTile, maskingTile))
       .toDF("tile", "mask")
@@ -1227,6 +1238,7 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers {
     df.createOrReplaceTempView("df_maskbits")
 
     val maskedCol = "cloud_conf_med"
+    // this is the example in the docs
     val result = spark.sql(
       s"""
          |SELECT rf_mask_by_values(
diff --git a/pyrasterframes/src/main/python/tests/RasterFunctionsTests.py b/pyrasterframes/src/main/python/tests/RasterFunctionsTests.py
index 6c82f867c..a5c98088b 100644
--- a/pyrasterframes/src/main/python/tests/RasterFunctionsTests.py
+++ b/pyrasterframes/src/main/python/tests/RasterFunctionsTests.py
@@ -261,19 +261,23 @@ def test_mask(self):
         from pyrasterframes.rf_types import Tile, CellType
 
         np.random.seed(999)
-        ma = np.ma.array(np.random.randint(0, 10, (5, 5), dtype='int8'), mask=np.random.rand(5, 5) > 0.7)
+        # importantly, exclude 0 from the range because that's the nodata value for the `data_tile`'s cell type
+        ma = np.ma.array(np.random.randint(1, 10, (5, 5), dtype='int8'), mask=np.random.rand(5, 5) > 0.7)
         expected_data_values = ma.compressed().size
         expected_no_data_values = ma.size - expected_data_values
         self.assertTrue(expected_data_values > 0, "Make sure random seed is cooperative ")
         self.assertTrue(expected_no_data_values > 0, "Make sure random seed is cooperative ")
 
-        df = self.spark.createDataFrame([
-            Row(t=Tile(np.ones(ma.shape, ma.dtype)), m=Tile(ma))
-        ])
+        data_tile = Tile(np.ones(ma.shape, ma.dtype), CellType.uint8())
+
+        df 
= self.spark.createDataFrame([Row(t=data_tile, m=Tile(ma))]) \ + .withColumn('masked_t', rf_mask('t', 'm')) - df = df.withColumn('masked_t', rf_mask('t', 'm')) result = df.select(rf_data_cells('masked_t')).first()[0] - self.assertEqual(result, expected_data_values) + self.assertEqual(result, expected_data_values, + f"Masked tile should have {expected_data_values} data values but found: {df.select('masked_t').first()[0].cells}." + f"Original data: {data_tile.cells}" + f"Masked by {ma}") nd_result = df.select(rf_no_data_cells('masked_t')).first()[0] self.assertEqual(nd_result, expected_no_data_values) From b67049a49d55d0dcab08bfdf5b38c50e9c23eb13 Mon Sep 17 00:00:00 2001 From: "Jason T. Brown" Date: Fri, 15 Nov 2019 10:25:52 -0500 Subject: [PATCH 12/24] Extract bits should throw on non-integral cell types Signed-off-by: Jason T. Brown --- .../transformers/ExtractBits.scala | 1 + .../rasterframes/RasterFunctionsSpec.scala | 31 +++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtractBits.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtractBits.scala index 3219fc9b3..352f78ac9 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtractBits.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtractBits.scala @@ -86,6 +86,7 @@ object ExtractBits{ new Column(ExtractBits(tile.expr, startBit.expr, numBits.expr)) def apply(tile: Tile, startBit: Int, numBits: Int): Tile = { + assert(!tile.cellType.isFloatingPoint, "ExtractBits operation requires integral CellType") // this is the last `numBits` positions of "111111111111111" val widthMask = Int.MaxValue >> (63 - numBits) // map preserving the nodata structure diff --git a/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala index 49968a85b..73cce1f27 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala @@ -1179,6 +1179,37 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers { checker("qa_cconf", cirrus, 1) //low cloud conf checker("qa_circonf", cirrus, 3) //high cirrus conf } + it("should extract bits from different cell types") { + import org.locationtech.rasterframes.expressions.transformers.ExtractBits + + case class TestCase[N: Numeric](cellType: CellType, cellValue: N, bitPosition: Int, numBits: Int, expectedValue: Int) { + def testIt(): Unit = { + val tile = projectedRasterTile(3, 3, cellValue, TestData.extent, TestData.crs, cellType) + val extracted = ExtractBits(tile, bitPosition, numBits) + all(extracted.toArray()) should be (expectedValue) + } + } + + Seq( + TestCase(BitCellType, 1, 0, 1, 1), + TestCase(ByteCellType, 127, 6, 2, 1), // 7th bit is sign + TestCase(ByteCellType, 127, 5, 2, 3), + TestCase(ByteCellType, -128, 6, 2, 2), // 7th bit is sign + TestCase(UByteCellType, 255, 6, 2, 3), + TestCase(UByteCellType, 255, 10, 2, 0), // shifting beyond range of cell type results in 0 + TestCase(ShortCellType, 32767, 15, 1, 0), + TestCase(ShortCellType, 32767, 14, 2, 1), + TestCase(ShortUserDefinedNoDataCellType(0), -32768, 14, 2, 2), + TestCase(UShortCellType, 65535, 14, 2, 3), + TestCase(UShortCellType, 65535, 18, 2, 0), // shifting beyond range of cell type results in 0 + TestCase(IntCellType, 2147483647, 30, 
2, 1), + TestCase(IntCellType, 2147483647, 29, 2, 3) + ).foreach(_.testIt) + + // floating point types + an [AssertionError] should be thrownBy TestCase[Float](FloatCellType, Float.MaxValue, 29, 2, 3).testIt() + + } it("should mask by QA bits"){ val result = df .withColumn("fill_no", rf_mask_by_bit($"data", $"mask", 0, true)) From 5529d482808a87ed92b1fa78e4a9f005dd6cdb14 Mon Sep 17 00:00:00 2001 From: "Jason T. Brown" Date: Fri, 15 Nov 2019 13:30:13 -0500 Subject: [PATCH 13/24] Add mask bits python api and unit test Signed-off-by: Jason T. Brown --- .../rasterframes/RasterFunctionsSpec.scala | 2 +- .../python/pyrasterframes/rasterfunctions.py | 33 ++++++++++++++ .../main/python/tests/PyRasterFramesTests.py | 19 +++++++- .../main/python/tests/RasterFunctionsTests.py | 45 +++++++++++++++++++ 4 files changed, 97 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala index 73cce1f27..c293f5341 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala @@ -1264,7 +1264,7 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers { checker("cloud_no", cirrus, false) } - it("should have SQL equivalent"){ + it("mask bits should have SQL equivalent"){ df.createOrReplaceTempView("df_maskbits") diff --git a/pyrasterframes/src/main/python/pyrasterframes/rasterfunctions.py b/pyrasterframes/src/main/python/pyrasterframes/rasterfunctions.py index ae5977f51..7177614da 100644 --- a/pyrasterframes/src/main/python/pyrasterframes/rasterfunctions.py +++ b/pyrasterframes/src/main/python/pyrasterframes/rasterfunctions.py @@ -495,6 +495,39 @@ def rf_inverse_mask_by_value(data_tile, mask_tile, mask_value): return _apply_column_function('rf_inverse_mask_by_value', data_tile, mask_tile, mask_value) +def rf_mask_by_bit(data_tile, mask_tile, bit_position, value_to_mask): + """Applies a mask using bit values in the `mask_tile`. Working from the right, extract the bit at `bitPosition` from the `maskTile`. In all locations where these are equal to the `valueToMask`, the returned tile is set to NoData, else the original `dataTile` cell value.""" + if isinstance(bit_position, int): + bit_position = lit(bit_position) + if isinstance(value_to_mask, (int, float, bool)): + value_to_mask = lit(bool(value_to_mask)) + return _apply_column_function('rf_mask_by_bit', data_tile, mask_tile, bit_position, value_to_mask) + + +def rf_mask_by_bits(data_tile, mask_tile, start_bit, num_bits, values_to_mask): + """Applies a mask from blacklisted bit values in the `mask_tile`. Working from the right, the bits from `start_bit` to `start_bit + num_bits` are @ref:[extracted](reference.md#rf_local_extract_bits) from cell values of the `mask_tile`. 
In all locations where these are in `values_to_mask`, the returned tile is set to NoData; otherwise the original `data_tile` cell value is returned."""
+    if isinstance(start_bit, int):
+        start_bit = lit(start_bit)
+    if isinstance(num_bits, int):
+        num_bits = lit(num_bits)
+    if isinstance(values_to_mask, (tuple, list)):
+        from pyspark.sql.functions import array
+        values_to_mask = array([lit(v) for v in values_to_mask])
+
+    return _apply_column_function('rf_mask_by_bits', data_tile, mask_tile, start_bit, num_bits, values_to_mask)
+
+
+def rf_local_extract_bits(tile, start_bit, num_bits=1):
+    """Extract value from specified bits of the cells' underlying binary data.
+    * `start_bit` is the first bit to consider, working from the right. It is zero indexed.
+    * `num_bits` is the number of bits to take, moving further to the left. """
+    if isinstance(start_bit, int):
+        start_bit = lit(start_bit)
+    if isinstance(num_bits, int):
+        num_bits = lit(num_bits)
+    return _apply_column_function('rf_local_extract_bits', tile, start_bit, num_bits)
+
+
 def rf_local_less(left_tile_col, right_tile_col):
     """Cellwise less than comparison between two tiles"""
     return _apply_column_function('rf_local_less', left_tile_col, right_tile_col)
diff --git a/pyrasterframes/src/main/python/tests/PyRasterFramesTests.py b/pyrasterframes/src/main/python/tests/PyRasterFramesTests.py
index 3bb2ce491..0dc36a8e7 100644
--- a/pyrasterframes/src/main/python/tests/PyRasterFramesTests.py
+++ b/pyrasterframes/src/main/python/tests/PyRasterFramesTests.py
@@ -25,6 +25,8 @@
 from pyrasterframes.rf_types import *
 from pyspark.sql import SQLContext
 from pyspark.sql.functions import *
+from pyspark.sql import Row
+
 
 from . import TestEnvironment
 
@@ -139,6 +141,22 @@ def test_tile_udt_serialization(self):
         long_trip = df.first()["tile"]
         self.assertEqual(long_trip, a_tile)
 
+    def test_masked_deser(self):
+        t = Tile(np.array([[1, 2, 3,], [4, 5, 6], [7, 8, 9]]),
+                 CellType('uint8'))
+
+        df = self.spark.createDataFrame([Row(t=t)])
+        roundtrip = df.select(rf_mask_by_value('t',
+                                               rf_local_greater('t', lit(6)),
+                                               1)) \
+            .first()[0]
+        self.assertEqual(
+            roundtrip.cells.mask.sum(),
+            3,
+            f"Expected {3} nodata values but found Tile "
+            f"{roundtrip}"
+        )
+
     def test_udf_on_tile_type_input(self):
         import numpy.testing
         df = self.spark.read.raster(self.img_uri)
@@ -248,7 +266,6 @@ def less_pi(t):
 
 class TileOps(TestEnvironment):
 
     def setUp(self):
-        from pyspark.sql import Row
         # convenience so we can assert around Tile() == Tile()
         self.t1 = Tile(np.array([[1, 2], [3, 4]]), CellType.int8().with_no_data_value(3))
diff --git a/pyrasterframes/src/main/python/tests/RasterFunctionsTests.py b/pyrasterframes/src/main/python/tests/RasterFunctionsTests.py
index a5c98088b..a251f682b 100644
--- a/pyrasterframes/src/main/python/tests/RasterFunctionsTests.py
+++ b/pyrasterframes/src/main/python/tests/RasterFunctionsTests.py
@@ -256,6 +256,44 @@ def test_mask_by_values(self):
         # assert_equal(result0[0].cells, expected_diag_nd)
         self.assertTrue(result0[0] == expected_diag_nd)
 
+    def test_mask_bits(self):
+        t = Tile(42 * np.ones((4, 4), 'uint16'), CellType.uint16())
+        # with a variety of known values
+        mask = Tile(np.array([
+            [1, 1, 2720, 2720],
+            [1, 6816, 6816, 2756],
+            [2720, 2720, 6900, 2720],
+            [2720, 6900, 6816, 1]
+        ]), CellType('uint16raw'))
+
+        df = self.spark.createDataFrame([Row(t=t, mask=mask)])
+
+        # removes fill value 1
+        mask_fill_df = df.select(rf_mask_by_bit('t', 'mask', 0, True).alias('mbb'))
+        mask_fill_tile = mask_fill_df.first()['mbb']
+
+        
self.assertTrue(mask_fill_tile.cell_type.has_no_data())
+
+        self.assertEqual(
+            mask_fill_df.select(rf_data_cells('mbb')).first()[0],
+            16 - 4
+        )
+        # Unsure why this fails. mask_fill_tile.cells is all 42 unmasked.
+        # self.assertEqual(mask_fill_tile.cells.mask.sum(), 4,
+        #                  f'Expected {16 - 4} data values but got the masked tile:'
+        #                  f'{mask_fill_tile}'
+        #                  )
+        #
+        # mask out 6816, 6900
+        mask_med_hi_cir = df.withColumn('mask_cir_mh',
+                                        rf_mask_by_bits('t', 'mask', 11, 2, [2, 3])) \
+            .first()['mask_cir_mh'].cells
+
+        self.assertEqual(
+            mask_med_hi_cir.mask.sum(),
+            5
+        )
+
     def test_mask(self):
         from pyspark.sql import Row
         from pyrasterframes.rf_types import Tile, CellType
@@ -282,6 +320,13 @@ def test_mask(self):
         nd_result = df.select(rf_no_data_cells('masked_t')).first()[0]
         self.assertEqual(nd_result, expected_no_data_values)
 
+        # deserialization of the tile is correct
+        self.assertEqual(
+            df.select('masked_t').first()[0].cells.compressed().size,
+            expected_data_values
+        )
+
+
     def test_resample(self):
         from pyspark.sql.functions import lit
         result = self.rf.select(

From ea8974cdba3c762d02d9f015254033f6f47fed98 Mon Sep 17 00:00:00 2001
From: "Jason T. Brown"
Date: Fri, 15 Nov 2019 14:53:44 -0500
Subject: [PATCH 14/24] Add landsat masking section to masking docs page

Signed-off-by: Jason T. Brown

---
 .../src/main/python/docs/masking.pymd         | 75 +++++++++++++++++--
 1 file changed, 70 insertions(+), 5 deletions(-)

diff --git a/pyrasterframes/src/main/python/docs/masking.pymd b/pyrasterframes/src/main/python/docs/masking.pymd
index f3ab08826..024faab50 100644
--- a/pyrasterframes/src/main/python/docs/masking.pymd
+++ b/pyrasterframes/src/main/python/docs/masking.pymd
@@ -14,8 +14,10 @@ spark = create_rf_spark_session()
 ```
 
 Masking is a common operation in raster processing. It is setting certain cells to the @ref:[NoData value](nodata-handling.md). This is usually done to remove low-quality observations from the raster processing. Another related use case is to @ref:["clip"](masking.md#clipping) a raster to a given polygon.
+
+In this section we will demonstrate two common schemes for masking. In Sentinel 2, there is a separate classification raster that defines low-quality areas. In Landsat 8, several quality factors are measured and their indications are packed into a single integer, which we have to unpack.
 
-## Masking Example
+## Masking Sentinel 2
 
 Let's demonstrate masking with a pair of bands of Sentinel-2 data. The measurement bands we will use, blue and green, have no defined NoData. They share quality information from a separate file called the scene classification (SCL), which delineates areas of missing data and probable clouds. For more information on this, see the [Sentinel-2 algorithm overview](https://earth.esa.int/web/sentinel/technical-guides/sentinel-2-msi/level-2a/algorithm). Figure 3 tells us how to interpret the scene classification. For this example, we will exclude NoData, defective pixels, probable clouds, and cirrus clouds: values 0, 1, 8, 9, and 10.
 
@@ -40,7 +42,7 @@ unmasked.printSchema()
 unmasked.select(rf_cell_type('blue'), rf_cell_type('scl')).distinct()
 ```
 
-## Define CellType for Masked Tile
+### Define CellType for Masked Tile
 
 Because there is not a NoData already defined for the blue band, we must choose one. In this particular example, the minimum value is greater than zero, so we can use 0 as the NoData value. We will construct a new `CellType` object to represent this.
 
@@ -59,7 +61,7 @@ We next convert the blue band to this cell type.
converted = unmasked.select('scl', 'green', rf_convert_cell_type('blue', masked_blue_ct).alias('blue'))
 ```
 
-## Apply Mask from Quality Band
+### Apply Mask from Quality Band
 
 Now we set cells of our `blue` column to NoData for all locations where the `scl` tile is in our set of undesirable values. This is the actual _masking_ operation.
 
@@ -89,7 +91,7 @@ And the original SCL data. The bright yellow is a cloudy region in the original
 display(sample[1])
 ```
 
-## Transferring Mask
+### Transferring Mask
 
 We can now apply the same mask from the blue column to the green column. Note here we have suppressed the step of explicitly checking what a "safe" NoData value for the green band should be.
 
@@ -98,9 +100,72 @@ masked.withColumn('green_masked', rf_mask(rf_convert_cell_type('green', masked_b
     .orderBy(-rf_no_data_cells('blue_masked'))
 ```
 
+## Masking Landsat 8
+
+
+We will work with the Landsat scene [here](https://landsat-pds.s3.us-west-2.amazonaws.com/c1/L8/153/075/LC08_L1TP_153075_20190718_20190731_01_T1/index.html). For simplicity, we will just use two of the seven 30m bands. The quality mask for all bands is contained in the `BQA` band.
+
+
+```python, build_l8_df
+base_url = 'https://landsat-pds.s3.us-west-2.amazonaws.com/c1/L8/153/075/LC08_L1TP_153075_20190718_20190731_01_T1/LC08_L1TP_153075_20190718_20190731_01_T1_'
+data4 = base_url + 'B4.TIF'
+data2 = base_url + 'B2.TIF'
+mask = base_url + 'BQA.TIF'
+l8_df = spark.read.raster([[data4, data2, mask]]) \
+    .withColumnRenamed('proj_raster_0', 'data') \
+    .withColumnRenamed('proj_raster_1', 'data2') \
+    .withColumnRenamed('proj_raster_2', 'mask')
+```
+
+Masking is described [on the Landsat Missions page](https://www.usgs.gov/land-resources/nli/landsat/landsat-collection-1-level-1-quality-assessment-band). The documentation is dense; the section relevant to this data set is Collection 1 Level-1 for Landsat 8.
+
+There are several interrelated factors to consider. In this exercise we will mask away the following.
+
+ * Designated Fill = yes
+ * Cloud = yes
+ * Cloud Shadow Confidence = Medium or High
+ * Cirrus Confidence = Medium or High
+
+Note that you should consider your application and do your own exploratory analysis to determine the most appropriate mask!
+
+According to the information on the Landsat site, this translates to masking by bit values in the BQA according to the following table.
+
+| Description         | Value     | Bits   | Bit values      |
+|-------------------- |---------- |------- |---------------- |
+| Designated fill     | yes       | 0      | 1               |
+| Cloud               | yes       | 4      | 1               |
+| Cloud shadow conf.  | med / hi  | 7-8    | 10, 11 (2, 3)   |
+| Cirrus conf.        | med / hi  | 11-12  | 10, 11 (2, 3)   |
+
+
+In this case, we will use the value 0 as NoData in the band data. Inspecting the associated [MTL txt file](https://landsat-pds.s3.us-west-2.amazonaws.com/c1/L8/153/075/LC08_L1TP_153075_20190718_20190731_01_T1/LC08_L1TP_153075_20190718_20190731_01_T1_MTL.txt), we can see that the minimum value in the band is 1, which makes 0 safe to use as NoData.
+
+A short sanity check of the bit table follows; then the main code chunk works through each of the rows in the table above. Its first expression sets the cell type to have the selected NoData. The @ref:[`rf_mask_by_bit`](reference.md#rf-mask-by-bit) and @ref:[`rf_mask_by_bits`](reference.md#rf-mask-by-bits) functions extract the selected bit or bits from the `mask` cells and compare them to the provided values.
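+As a quick sanity check of the table, we can decode a few sample BQA cell values with plain Python bit operators. This is a minimal sketch, independent of RasterFrames; the helper below simply mirrors, per integer, what `rf_local_extract_bits` does per cell, and the sample values (1 fill, 2720 clear, 2756 medium cloud confidence, 6900 cloud with high cirrus confidence) come from the USGS pixel-value interpretation examples.
+
+```python, decode_qa_sketch
+def extract_bits(value, start_bit, num_bits):
+    # pure-Python analogue of rf_local_extract_bits for a single integer
+    return (value >> start_bit) & ((1 << num_bits) - 1)
+
+for qa in [1, 2720, 2756, 6900]:
+    print(qa,
+          'fill:', extract_bits(qa, 0, 1),
+          'cloud:', extract_bits(qa, 4, 1),
+          'shadow_conf:', extract_bits(qa, 7, 2),
+          'cirrus_conf:', extract_bits(qa, 11, 2))
+```
+
+Of these values, only 1 (designated fill) and 6900 (cloud, with high cirrus confidence) trip the masking rules above.
+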
+```python, build_l8_mask
+l8_df = l8_df.withColumn('data_masked',  # set to a cell type that has a NoData
+                         rf_convert_cell_type('data', CellType.uint16())) \
+    .withColumn('data_masked',  # fill yes
+                rf_mask_by_bit('data_masked', 'mask', 0, 1)) \
+    .withColumn('data_masked',  # cloud yes
+                rf_mask_by_bit('data_masked', 'mask', 4, 1)) \
+    .withColumn('data_masked',  # cloud shadow conf is medium or high
+                rf_mask_by_bits('data_masked', 'mask', 7, 2, [2, 3])) \
+    .withColumn('data_masked',  # cirrus conf is medium or high
+                rf_mask_by_bits('data_masked', 'mask', 11, 2, [2, 3])) \
+    .withColumn('data2',  # mask the other data column using the already-masked band
+                rf_mask(rf_convert_cell_type('data2', CellType.uint16()), 'data_masked')) \
+    .filter(rf_data_cells('data_masked') > 0)  # remove any entirely ND rows
+
+# Inspect a sample
+l8_df.select('data', 'mask', 'data_masked', 'data2', rf_extent('data_masked')) \
+    .filter(rf_data_cells('data_masked') > 32000)
+```
+
+
 ## Clipping
 
-Clipping is the use of a polygon to determine the areas to mask in a raster. Typically the areas inside a polygon are retained and the cells outside are set to NoData. Given a geometry column on our DataFrame, we have to carry out three basic steps. First we have to ensure the vector geometry is correctly projected to the same @ref:[CRS](concepts.md#coordinate-reference-system-crs) as the raster. We'll continue with our example creating a simple polygon. Buffering a point will create an approximate circle.
+Clipping is the use of a polygon to determine the areas to mask in a raster. Typically the areas inside a polygon are retained and the cells outside are set to NoData. Given a geometry column on our DataFrame, we have to carry out three basic steps. First we have to ensure the vector geometry is correctly projected to the same @ref:[CRS](concepts.md#coordinate-reference-system-crs) as the raster. We'll continue with our Sentinel 2 example, creating a simple polygon. Buffering a point will create an approximate circle.
 
 ```python, reproject_geom

From 9385136bb4932fb952076804f5e4985cab2fd99e Mon Sep 17 00:00:00 2001
From: "Simeon H.K. Fitch"
Date: Mon, 18 Nov 2019 11:32:56 -0500
Subject: [PATCH 15/24] Added the ability to do a raster_join on proj_raster types. Fixes #419.

---
 .../rasterframes/extensions/RasterJoin.scala  | 23 +++++++++++++++----
 .../rasterframes/RasterJoinSpec.scala         | 11 +++++++--
 2 files changed, 28 insertions(+), 6 deletions(-)

diff --git a/core/src/main/scala/org/locationtech/rasterframes/extensions/RasterJoin.scala b/core/src/main/scala/org/locationtech/rasterframes/extensions/RasterJoin.scala
index e0cec7a8c..d7e71dbe8 100644
--- a/core/src/main/scala/org/locationtech/rasterframes/extensions/RasterJoin.scala
+++ b/core/src/main/scala/org/locationtech/rasterframes/extensions/RasterJoin.scala
@@ -23,6 +23,8 @@ package org.locationtech.rasterframes.extensions
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.locationtech.rasterframes._
+import org.locationtech.rasterframes.expressions.SpatialRelation
+import org.locationtech.rasterframes.expressions.accessors.ExtractTile
 import org.locationtech.rasterframes.functions.reproject_and_merge
 import org.locationtech.rasterframes.util._
 
@@ -30,15 +32,28 @@ import scala.util.Random
 
 object RasterJoin {
 
+  /** Perform a raster join on dataframes that each have proj_raster columns, or crs and extent explicitly included. 
*/ def apply(left: DataFrame, right: DataFrame): DataFrame = { - val df = apply(left, right, left("extent"), left("crs"), right("extent"), right("crs")) - df.drop(right("extent")).drop(right("crs")) + def usePRT(d: DataFrame) = + d.projRasterColumns.headOption + .map(p => (rf_crs(p), rf_extent(p))) + .orElse(Some(col("crs"), col("extent"))) + .map { case (crs, extent) => + val d2 = d.withColumn("crs", crs).withColumn("extent", extent) + (d2, d2("crs"), d2("extent")) + } + .get + + val (ldf, lcrs, lextent) = usePRT(left) + val (rdf, rcrs, rextent) = usePRT(right) + + apply(ldf, rdf, lextent, lcrs, rextent, rcrs) } def apply(left: DataFrame, right: DataFrame, leftExtent: Column, leftCRS: Column, rightExtent: Column, rightCRS: Column): DataFrame = { val leftGeom = st_geometry(leftExtent) val rightGeomReproj = st_reproject(st_geometry(rightExtent), rightCRS, leftCRS) - val joinExpr = st_intersects(leftGeom, rightGeomReproj) + val joinExpr = new Column(SpatialRelation.Intersects(leftGeom.expr, rightGeomReproj.expr)) apply(left, right, joinExpr, leftExtent, leftCRS, rightExtent, rightCRS) } @@ -65,7 +80,7 @@ object RasterJoin { val leftAggCols = left.columns.map(s => first(left(s), true) as s) // On the RHS we collect result as a list. val rightAggCtx = Seq(collect_list(rightExtent) as rightExtent2, collect_list(rightCRS) as rightCRS2) - val rightAggTiles = right.tileColumns.map(c => collect_list(c) as c.columnName) + val rightAggTiles = right.tileColumns.map(c => collect_list(ExtractTile(c)) as c.columnName) val rightAggOther = right.notTileColumns .filter(n => n.columnName != rightExtent.columnName && n.columnName != rightCRS.columnName) .map(c => collect_list(c) as (c.columnName + "_agg")) diff --git a/core/src/test/scala/org/locationtech/rasterframes/RasterJoinSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/RasterJoinSpec.scala index b2cd5d8ce..ff0cfeebb 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/RasterJoinSpec.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/RasterJoinSpec.scala @@ -154,8 +154,6 @@ class RasterJoinSpec extends TestEnvironment with TestData with RasterMatchers { total18 should be > 0.0 total18 should be < total17 - - } it("should pass through ancillary columns") { @@ -164,5 +162,14 @@ class RasterJoinSpec extends TestEnvironment with TestData with RasterMatchers { val joined = left.rasterJoin(right) joined.columns should contain allElementsOf Seq("left_id", "right_id_agg") } + + it("should handle proj_raster types") { + val df1 = Seq(one).toDF("one") + val df2 = Seq(two).toDF("two") + noException shouldBe thrownBy { + val joined1 = df1.rasterJoin(df2) + val joined2 = df2.rasterJoin(df1) + } + } } } From a71505dba2713faa0f74d2391cf8f71372578162 Mon Sep 17 00:00:00 2001 From: "Jason T. Brown" Date: Mon, 18 Nov 2019 12:37:04 -0500 Subject: [PATCH 16/24] register rf_local_extract_bit with SQL functions Signed-off-by: Jason T. 
Brown

---
 .../org/locationtech/rasterframes/expressions/package.scala | 1 +
 1 file changed, 1 insertion(+)

diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala
index d2896eb1e..52c53adf6 100644
--- a/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala
+++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala
@@ -140,5 +140,6 @@ package object expressions {
     registry.registerExpression[transformers.ReprojectGeometry]("st_reproject")
 
     registry.registerExpression[ExtractBits]("rf_local_extract_bits")
+    registry.registerExpression[ExtractBits]("rf_local_extract_bit")
   }
 }

From 419a920f54f27b8e14ea743b1d864443f0bdb394 Mon Sep 17 00:00:00 2001
From: "Jason T. Brown"
Date: Mon, 18 Nov 2019 13:57:19 -0500
Subject: [PATCH 17/24] break out commented assert into skipped unit test around masking and deserialization

Signed-off-by: Jason T. Brown

---
 .../main/python/tests/RasterFunctionsTests.py | 34 +++++++++++++++----
 1 file changed, 28 insertions(+), 6 deletions(-)

diff --git a/pyrasterframes/src/main/python/tests/RasterFunctionsTests.py b/pyrasterframes/src/main/python/tests/RasterFunctionsTests.py
index a251f682b..46bb71886 100644
--- a/pyrasterframes/src/main/python/tests/RasterFunctionsTests.py
+++ b/pyrasterframes/src/main/python/tests/RasterFunctionsTests.py
@@ -27,6 +27,7 @@
 
 import numpy as np
 from numpy.testing import assert_equal
+from unittest import skip
 
 from . import TestEnvironment
 
@@ -278,12 +279,7 @@ def test_mask_bits(self):
             mask_fill_df.select(rf_data_cells('mbb')).first()[0],
             16 - 4
         )
-        # Unsure why this fails. mask_fill_tile.cells is all 42 unmasked.
-        # self.assertEqual(mask_fill_tile.cells.mask.sum(), 4,
-        #                  f'Expected {16 - 4} data values but got the masked tile:'
-        #                  f'{mask_fill_tile}'
-        #                  )
-        #
+
         # mask out 6816, 6900
         mask_med_hi_cir = df.withColumn('mask_cir_mh',
                                         rf_mask_by_bits('t', 'mask', 11, 2, [2, 3])) \
@@ -294,6 +290,32 @@ def test_mask_bits(self):
             5
         )
 
+    @skip('Issue #422 https://github.com/locationtech/rasterframes/issues/422')
+    def test_mask_and_deser(self):
+        # duplicates much of test_mask_bits, but exercises the cell-level deserialization assertion
+        t = Tile(42 * np.ones((4, 4), 'uint16'), CellType.uint16())
+        # with a variety of known values
+        mask = Tile(np.array([
+            [1, 1, 2720, 2720],
+            [1, 6816, 6816, 2756],
+            [2720, 2720, 6900, 2720],
+            [2720, 6900, 6816, 1]
+        ]), CellType('uint16raw'))
+
+        df = self.spark.createDataFrame([Row(t=t, mask=mask)])
+
+        # removes fill value 1
+        mask_fill_df = df.select(rf_mask_by_bit('t', 'mask', 0, True).alias('mbb'))
+        mask_fill_tile = mask_fill_df.first()['mbb']
+
+        self.assertTrue(mask_fill_tile.cell_type.has_no_data())
+
+        # Unsure why this fails. mask_fill_tile.cells is all 42 unmasked.
+        self.assertEqual(mask_fill_tile.cells.mask.sum(), 4,
+                         f'Expected {16 - 4} data values but got the masked tile:'
+                         f'{mask_fill_tile}'
+                         )
+
     def test_mask(self):
         from pyspark.sql import Row
         from pyrasterframes.rf_types import Tile, CellType

From f4d307452b1c1c0ef1f4356b194341c09d0908f6 Mon Sep 17 00:00:00 2001
From: "Simeon H.K. Fitch"
Date: Mon, 18 Nov 2019 14:01:54 -0500
Subject: [PATCH 18/24] Fixed `ReprojectToLayer` to work with `proj_raster`. Closes #357

Depends on #420.
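A DataFrame whose tile columns are `proj_raster` structs (CRS and extent embedded) can now be re-gridded into a layer directly. A minimal sketch of the newly supported pattern, assuming `df` carries such columns and `extent`/`crs` describe the target grid (the new test in RasterLayerSpec is the complete example):

```scala
// Describe the destination gridding.
val layout = LayoutDefinition(extent, TileLayout(2, 2, 32, 32))
val tlm = new TileLayerMetadata[SpatialKey](
  UByteConstantNoDataCellType, layout, extent, crs,
  KeyBounds(SpatialKey(0, 0), SpatialKey(1, 1)))
// Internally builds a tile-less destination frame and raster-joins `df` onto it.
val layer: RasterFrameLayer = df.toLayer(tlm)
```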
--- .../aggregates/TileRasterizerAggregate.scala | 4 +- .../rasterframes/extensions/Implicits.scala | 10 ++ ....scala => LayerSpatialColumnMethods.scala} | 6 +- .../extensions/RasterFrameLayerMethods.scala | 2 +- .../rasterframes/extensions/RasterJoin.scala | 9 +- .../extensions/ReprojectToLayer.scala | 5 +- .../scala/examples/CreatingRasterFrames.scala | 92 ------------------- core/src/test/scala/examples/MeanValue.scala | 50 ---------- ...rFrameSpec.scala => RasterLayerSpec.scala} | 77 ++++++++-------- .../rasterframes/TestEnvironment.scala | 26 +++++- 10 files changed, 89 insertions(+), 192 deletions(-) rename core/src/main/scala/org/locationtech/rasterframes/extensions/{RFSpatialColumnMethods.scala => LayerSpatialColumnMethods.scala} (96%) delete mode 100644 core/src/test/scala/examples/CreatingRasterFrames.scala delete mode 100644 core/src/test/scala/examples/MeanValue.scala rename core/src/test/scala/org/locationtech/rasterframes/{RasterFrameSpec.scala => RasterLayerSpec.scala} (87%) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/TileRasterizerAggregate.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/TileRasterizerAggregate.scala index 6647f4258..098841362 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/TileRasterizerAggregate.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/TileRasterizerAggregate.scala @@ -100,9 +100,7 @@ object TileRasterizerAggregate { def apply(tlm: TileLayerMetadata[_], sampler: ResampleMethod): ProjectedRasterDefinition = { // Try to determine the actual dimensions of our data coverage - val actualSize = tlm.layout.toRasterExtent().gridBoundsFor(tlm.extent) // <--- Do we have the math right here? 
- val cols = actualSize.width - val rows = actualSize.height + val TileDimensions(cols, rows) = tlm.totalDimensions new ProjectedRasterDefinition(cols, rows, tlm.cellType, tlm.crs, tlm.extent, sampler) } } diff --git a/core/src/main/scala/org/locationtech/rasterframes/extensions/Implicits.scala b/core/src/main/scala/org/locationtech/rasterframes/extensions/Implicits.scala index 563e03e87..46be52a4e 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/extensions/Implicits.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/extensions/Implicits.scala @@ -31,6 +31,7 @@ import org.apache.spark.SparkConf import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.types.{MetadataBuilder, Metadata => SMetadata} +import org.locationtech.rasterframes.model.TileDimensions import spray.json.JsonFormat import scala.reflect.runtime.universe._ @@ -79,6 +80,15 @@ trait Implicits { private[rasterframes] implicit class WithMetadataBuilderMethods(val self: MetadataBuilder) extends MetadataBuilderMethods + + private[rasterframes] + implicit class TLMHasTotalCells(tlm: TileLayerMetadata[_]) { + // TODO: With upgrade to GT 3.1, replace this with the more general `Dimensions[Long]` + def totalDimensions: TileDimensions = { + val gb = tlm.layout.toRasterExtent().gridBoundsFor(tlm.extent) + TileDimensions(gb.width, gb.height) + } + } } object Implicits extends Implicits diff --git a/core/src/main/scala/org/locationtech/rasterframes/extensions/RFSpatialColumnMethods.scala b/core/src/main/scala/org/locationtech/rasterframes/extensions/LayerSpatialColumnMethods.scala similarity index 96% rename from core/src/main/scala/org/locationtech/rasterframes/extensions/RFSpatialColumnMethods.scala rename to core/src/main/scala/org/locationtech/rasterframes/extensions/LayerSpatialColumnMethods.scala index af79c1c05..e78d07017 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/extensions/RFSpatialColumnMethods.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/extensions/LayerSpatialColumnMethods.scala @@ -41,7 +41,7 @@ import org.locationtech.rasterframes.encoders.serialized_literal * * @since 12/15/17 */ -trait RFSpatialColumnMethods extends MethodExtensions[RasterFrameLayer] with StandardColumns { +trait LayerSpatialColumnMethods extends MethodExtensions[RasterFrameLayer] with StandardColumns { import Implicits.{WithDataFrameMethods, WithRasterFrameLayerMethods} import org.locationtech.geomesa.spark.jts._ @@ -112,7 +112,7 @@ trait RFSpatialColumnMethods extends MethodExtensions[RasterFrameLayer] with Sta */ def withCenterLatLng(colName: String = "center"): RasterFrameLayer = { val key2Center = sparkUdf(keyCol2LatLng) - self.withColumn(colName, key2Center(self.spatialKeyColumn).cast(RFSpatialColumnMethods.LngLatStructType)).certify + self.withColumn(colName, key2Center(self.spatialKeyColumn).cast(LayerSpatialColumnMethods.LngLatStructType)).certify } /** @@ -130,6 +130,6 @@ trait RFSpatialColumnMethods extends MethodExtensions[RasterFrameLayer] with Sta } } -object RFSpatialColumnMethods { +object LayerSpatialColumnMethods { private[rasterframes] val LngLatStructType = StructType(Seq(StructField("longitude", DoubleType), StructField("latitude", DoubleType))) } diff --git a/core/src/main/scala/org/locationtech/rasterframes/extensions/RasterFrameLayerMethods.scala b/core/src/main/scala/org/locationtech/rasterframes/extensions/RasterFrameLayerMethods.scala index e9d375f12..49a1cfdee 100644 --- 
a/core/src/main/scala/org/locationtech/rasterframes/extensions/RasterFrameLayerMethods.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/extensions/RasterFrameLayerMethods.scala @@ -52,7 +52,7 @@ import scala.reflect.runtime.universe._ * @since 7/18/17 */ trait RasterFrameLayerMethods extends MethodExtensions[RasterFrameLayer] - with RFSpatialColumnMethods with MetadataKeys { + with LayerSpatialColumnMethods with MetadataKeys { import Implicits.{WithDataFrameMethods, WithRasterFrameLayerMethods} @transient protected lazy val logger = Logger(LoggerFactory.getLogger(getClass.getName)) diff --git a/core/src/main/scala/org/locationtech/rasterframes/extensions/RasterJoin.scala b/core/src/main/scala/org/locationtech/rasterframes/extensions/RasterJoin.scala index d7e71dbe8..fe6b65c7d 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/extensions/RasterJoin.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/extensions/RasterJoin.scala @@ -22,7 +22,9 @@ package org.locationtech.rasterframes.extensions import org.apache.spark.sql._ import org.apache.spark.sql.functions._ +import org.locationtech.rasterframes import org.locationtech.rasterframes._ +import org.locationtech.rasterframes.encoders.serialized_literal import org.locationtech.rasterframes.expressions.SpatialRelation import org.locationtech.rasterframes.expressions.accessors.ExtractTile import org.locationtech.rasterframes.functions.reproject_and_merge @@ -89,9 +91,12 @@ object RasterJoin { // After the aggregation we take all the tiles we've collected and resample + merge // into LHS extent/CRS. // Use a representative tile from the left for the tile dimensions - val leftTile = left.tileColumns.headOption.getOrElse(throw new IllegalArgumentException("Need at least one target tile on LHS")) + val destDims = left.tileColumns.headOption + .map(t => rf_dimensions(unresolved(t))) + .getOrElse(serialized_literal(NOMINAL_TILE_DIMS)) + val reprojCols = rightAggTiles.map(t => reproject_and_merge( - col(leftExtent2), col(leftCRS2), col(t.columnName), col(rightExtent2), col(rightCRS2), rf_dimensions(unresolved(leftTile)) + col(leftExtent2), col(leftCRS2), col(t.columnName), col(rightExtent2), col(rightCRS2), destDims ) as t.columnName) val finalCols = leftAggCols.map(unresolved) ++ reprojCols ++ rightAggOther.map(unresolved) diff --git a/core/src/main/scala/org/locationtech/rasterframes/extensions/ReprojectToLayer.scala b/core/src/main/scala/org/locationtech/rasterframes/extensions/ReprojectToLayer.scala index d5e6f5e31..6c80e9d0d 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/extensions/ReprojectToLayer.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/extensions/ReprojectToLayer.scala @@ -26,6 +26,8 @@ import org.apache.spark.sql._ import org.apache.spark.sql.functions.broadcast import org.locationtech.rasterframes._ import org.locationtech.rasterframes.util._ + +/** Algorithm for projecting an arbitrary RasterFrame into a layer with consistent CRS and gridding. */ object ReprojectToLayer { def apply(df: DataFrame, tlm: TileLayerMetadata[SpatialKey]): RasterFrameLayer = { // create a destination dataframe with crs and extend columns @@ -42,8 +44,9 @@ object ReprojectToLayer { e = tlm.mapTransform(sk) } yield (sk, e, crs) + // Create effectively a target RasterFrame, but with no tiles. 
val dest = gridItems.toSeq.toDF(SPATIAL_KEY_COLUMN.columnName, EXTENT_COLUMN.columnName, CRS_COLUMN.columnName) - dest.show(false) + val joined = RasterJoin(broadcast(dest), df) joined.asLayer(SPATIAL_KEY_COLUMN, tlm) diff --git a/core/src/test/scala/examples/CreatingRasterFrames.scala b/core/src/test/scala/examples/CreatingRasterFrames.scala deleted file mode 100644 index 8b5c00c72..000000000 --- a/core/src/test/scala/examples/CreatingRasterFrames.scala +++ /dev/null @@ -1,92 +0,0 @@ -/* - * This software is licensed under the Apache 2 license, quoted below. - * - * Copyright 2017 Astraea, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not - * use this file except in compliance with the License. You may obtain a copy of - * the License at - * - * [http://www.apache.org/licenses/LICENSE-2.0] - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations under - * the License. - * - */ - -package examples - -/** - * - * @author sfitch - * @since 11/6/17 - */ -object CreatingRasterFrames extends App { -// # Creating RasterFrames -// -// There are a number of ways to create a `RasterFrameLayer`, as enumerated in the sections below. -// -// ## Initialization -// -// First, some standard `import`s: - - import org.locationtech.rasterframes._ - import geotrellis.raster._ - import geotrellis.raster.io.geotiff.SinglebandGeoTiff - import geotrellis.spark.io._ - import org.apache.spark.sql._ - -// Next, initialize the `SparkSession`, and call the `withRasterFrames` method on it: - - implicit val spark = SparkSession.builder(). - master("local[*]").appName("RasterFrames"). - getOrCreate(). - withRasterFrames - spark.sparkContext.setLogLevel("ERROR") - -// ## From `ProjectedExtent` -// -// The simplest mechanism for getting a RasterFrameLayer is to use the `toLayer(tileCols, tileRows)` extension method on `ProjectedRaster`. - - val scene = SinglebandGeoTiff("src/test/resources/L8-B8-Robinson-IL.tiff") - val rf = scene.projectedRaster.toLayer(128, 128) - rf.show(5, false) - - -// ## From `TileLayerRDD` -// -// Another option is to use a GeoTrellis [`LayerReader`](https://docs.geotrellis.io/en/latest/guide/tile-backends.html), to get a `TileLayerRDD` for which there's also a `toLayer` extension method. - - -// ## Inspecting Structure -// -// `RasterFrameLayer` has a number of methods providing access to metadata about the contents of the RasterFrameLayer. -// -// ### Tile Column Names - - rf.tileColumns.map(_.toString) - -// ### Spatial Key Column Name - - rf.spatialKeyColumn.toString - -// ### Temporal Key Column -// -// Returns an `Option[Column]` since not all RasterFrames have an explicit temporal dimension. - - rf.temporalKeyColumn.map(_.toString) - -// ### Tile Layer Metadata -// -// The Tile Layer Metadata defines how the spatial/spatiotemporal domain is discretized into tiles, -// and what the key bounds are. - - import spray.json._ - // The `fold` is required because an `Either` is retured, depending on the key type. 
- rf.tileLayerMetadata.fold(_.toJson, _.toJson).prettyPrint - - spark.stop() -} diff --git a/core/src/test/scala/examples/MeanValue.scala b/core/src/test/scala/examples/MeanValue.scala deleted file mode 100644 index 2ee264469..000000000 --- a/core/src/test/scala/examples/MeanValue.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * This software is licensed under the Apache 2 license, quoted below. - * - * Copyright 2017 Astraea, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not - * use this file except in compliance with the License. You may obtain a copy of - * the License at - * - * [http://www.apache.org/licenses/LICENSE-2.0] - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations under - * the License. - * - */ - -package examples - -import org.locationtech.rasterframes._ -import geotrellis.raster.io.geotiff.SinglebandGeoTiff -import org.apache.spark.sql.SparkSession - -/** - * Compute the cell mean value of an image. - * - * @since 10/23/17 - */ -object MeanValue extends App { - - implicit val spark = SparkSession.builder() - .master("local[*]") - .appName(getClass.getName) - .getOrCreate() - .withRasterFrames - - - val scene = SinglebandGeoTiff("src/test/resources/L8-B8-Robinson-IL.tiff") - - val rf = scene.projectedRaster.toLayer(128, 128) // <-- tile size - - rf.printSchema - - val tileCol = rf("tile") - rf.agg(rf_agg_no_data_cells(tileCol), rf_agg_data_cells(tileCol), rf_agg_mean(tileCol)).show(false) - - spark.stop() -} diff --git a/core/src/test/scala/org/locationtech/rasterframes/RasterFrameSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/RasterLayerSpec.scala similarity index 87% rename from core/src/test/scala/org/locationtech/rasterframes/RasterFrameSpec.scala rename to core/src/test/scala/org/locationtech/rasterframes/RasterLayerSpec.scala index f37c5150a..931ca409d 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/RasterFrameSpec.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/RasterLayerSpec.scala @@ -23,19 +23,21 @@ package org.locationtech.rasterframes +import java.net.URI import java.sql.Timestamp import java.time.ZonedDateTime -import org.locationtech.rasterframes.util._ -import geotrellis.proj4.LatLng -import geotrellis.raster.render.{ColorMap, ColorRamp} -import geotrellis.raster.{ProjectedRaster, Tile, TileFeature, TileLayout, UByteCellType} +import geotrellis.proj4.{CRS, LatLng} +import geotrellis.raster.{MultibandTile, ProjectedRaster, Raster, Tile, TileFeature, TileLayout, UByteCellType, UByteConstantNoDataCellType} import geotrellis.spark._ import geotrellis.spark.tiling._ import geotrellis.vector.{Extent, ProjectedExtent} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.{SQLContext, SparkSession} +import org.apache.spark.sql.{Encoders, SQLContext, SparkSession} import org.locationtech.rasterframes.model.TileDimensions +import org.locationtech.rasterframes.ref.RasterSource +import org.locationtech.rasterframes.tiles.ProjectedRasterTile +import org.locationtech.rasterframes.util._ import scala.util.control.NonFatal @@ -44,7 +46,7 @@ import scala.util.control.NonFatal * * @since 7/10/17 */ -class RasterFrameSpec extends TestEnvironment with MetadataKeys +class RasterLayerSpec extends TestEnvironment with MetadataKeys with 
TestData { import TestData.randomTile import spark.implicits._ @@ -232,17 +234,40 @@ class RasterFrameSpec extends TestEnvironment with MetadataKeys assert(bounds._2 === SpaceTimeKey(3, 1, now)) } - def basicallySame(expected: Extent, computed: Extent): Unit = { - val components = Seq( - (expected.xmin, computed.xmin), - (expected.ymin, computed.ymin), - (expected.xmax, computed.xmax), - (expected.ymax, computed.ymax) - ) - forEvery(components)(c ⇒ - assert(c._1 === c._2 +- 0.000001) + it("should create layer from arbitrary RasterFrame") { + val src = RasterSource(URI.create("https://raw.githubusercontent.com/locationtech/rasterframes/develop/core/src/test/resources/LC08_RGB_Norfolk_COG.tiff")) + val srcCrs = src.crs + + def project(r: Raster[MultibandTile]): Seq[ProjectedRasterTile] = + r.tile.bands.map(b => ProjectedRasterTile(b, r.extent, srcCrs)) + + val prtEnc = ProjectedRasterTile.prtEncoder + implicit val enc = Encoders.tuple(prtEnc, prtEnc, prtEnc) + + val rasters = src.readAll(bands = Seq(0, 1, 2)).map(project).map(p => (p(0), p(1), p(2))) + + val df = rasters.toDF("red", "green", "blue") + + val crs = CRS.fromString("+proj=utm +zone=18 +datum=WGS84 +units=m +no_defs") + + val extent = Extent(364455.0, 4080315.0, 395295.0, 4109985.0) + val layout = LayoutDefinition(extent, TileLayout(2, 2, 32, 32)) + + val tlm = new TileLayerMetadata[SpatialKey]( + UByteConstantNoDataCellType, + layout, + extent, + crs, + KeyBounds(SpatialKey(0, 0), SpatialKey(1, 1)) ) - } + val layer = df.toLayer(tlm) + + val TileDimensions(cols, rows) = tlm.totalDimensions + val prt = layer.toMultibandRaster(Seq($"red", $"green", $"blue"), cols, rows) + prt.tile.dimensions should be((cols, rows)) + prt.crs should be(crs) + prt.extent should be(extent) + } it("shouldn't clip already clipped extents") { val rf = TestData.randomSpatialTileLayerRDD(1024, 1024, 8, 8).toLayer @@ -258,27 +283,8 @@ class RasterFrameSpec extends TestEnvironment with MetadataKeys basicallySame(expected2, computed2) } - def Greyscale(stops: Int): ColorRamp = { - val colors = (0 to stops) - .map(i ⇒ { - val c = java.awt.Color.HSBtoRGB(0f, 0f, i / stops.toFloat) - (c << 8) | 0xFF // Add alpha channel. 
- }) - ColorRamp(colors) - } - - def render(tile: Tile, tag: String): Unit = { - if(false && !isCI) { - val colors = ColorMap.fromQuantileBreaks(tile.histogram, Greyscale(128)) - val path = s"target/${getClass.getSimpleName}_$tag.png" - logger.info(s"Writing '$path'") - tile.color(colors).renderPng().write(path) - } - } - it("should rasterize with a spatiotemporal key") { val rf = TestData.randomSpatioTemporalTileLayerRDD(20, 20, 2, 2).toLayer - noException shouldBe thrownBy { rf.toRaster($"tile", 128, 128) } @@ -291,7 +297,6 @@ class RasterFrameSpec extends TestEnvironment with MetadataKeys val joinTypes = Seq("inner", "outer", "fullouter", "left_outer", "right_outer", "leftsemi") forEvery(joinTypes) { jt ⇒ val joined = rf1.spatialJoin(rf2, jt) - //println(joined.schema.json) assert(joined.tileLayerMetadata.isRight) } } diff --git a/core/src/test/scala/org/locationtech/rasterframes/TestEnvironment.scala b/core/src/test/scala/org/locationtech/rasterframes/TestEnvironment.scala index 01fbffcd0..e9af4f382 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/TestEnvironment.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/TestEnvironment.scala @@ -23,7 +23,10 @@ package org.locationtech.rasterframes import java.nio.file.{Files, Path} import com.typesafe.scalalogging.Logger +import geotrellis.raster.Tile +import geotrellis.raster.render.{ColorMap, ColorRamps} import geotrellis.raster.testkit.RasterMatchers +import geotrellis.vector.Extent import org.apache.spark.sql._ import org.apache.spark.sql.functions.col import org.apache.spark.sql.types.StructType @@ -78,6 +81,13 @@ trait TestEnvironment extends FunSpec rows.length == inRows } + def render(tile: Tile, tag: String): Unit = { + val colors = ColorMap.fromQuantileBreaks(tile.histogram, ColorRamps.greyscale(128)) + val path = s"target/${getClass.getSimpleName}_$tag.png" + logger.info(s"Writing '$path'") + tile.color(colors).renderPng().write(path) + } + /** * Constructor for creating a DataFrame with a single row and no columns. * Useful for testing the invocation of data constructing UDFs. @@ -99,6 +109,18 @@ trait TestEnvironment extends FunSpec def matchGeom(g: Geometry, tolerance: Double) = new GeometryMatcher(g, tolerance) + def basicallySame(expected: Extent, computed: Extent): Unit = { + val components = Seq( + (expected.xmin, computed.xmin), + (expected.ymin, computed.ymin), + (expected.xmax, computed.xmax), + (expected.ymax, computed.ymax) + ) + forEvery(components)(c ⇒ + assert(c._1 === c._2 +- 0.000001) + ) + } + def checkDocs(name: String): Unit = { import spark.implicits._ val docs = sql(s"DESCRIBE FUNCTION EXTENDED $name").as[String].collect().mkString("\n") @@ -107,8 +129,4 @@ trait TestEnvironment extends FunSpec docs shouldNot include("null") docs shouldNot include("N/A") } -} - -object TestEnvironment { - } \ No newline at end of file From 047a63d1d60fe6d66b5c5608c8e2009183709d1b Mon Sep 17 00:00:00 2001 From: "Simeon H.K. Fitch" Date: Mon, 18 Nov 2019 14:58:59 -0500 Subject: [PATCH 19/24] PR feedback. 
---
 .../rasterframes/RasterFunctions.scala           | 14 +++++++-------
 .../expressions/transformers/XZ2Indexer.scala    |  4 ++--
 .../expressions/transformers/Z2Indexer.scala     |  4 ++--
 .../datasource/raster/RasterSourceRelation.scala |  5 ++++-
 docs/src/main/paradox/reference.md               |  2 +-
 docs/src/main/paradox/release-notes.md           |  2 +-
 .../src/main/python/pyrasterframes/__init__.py   | 14 +++++++-----
 .../src/main/python/tests/RasterSourceTest.py    |  9 +++++++--
 8 files changed, 33 insertions(+), 21 deletions(-)

diff --git a/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala b/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala
index 3f238ea3e..09016d527 100644
--- a/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala
+++ b/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala
@@ -62,7 +62,7 @@ trait RasterFunctions {
   /** Extracts the bounding box from a RasterSource or ProjectedRasterTile */
   def rf_extent(col: Column): TypedColumn[Any, Extent] = GetExtent(col)
 
-  /** Constructs a XZ2 index in WGS84 from either a Geometry, Extent, ProjectedRasterTile, or RasterSource and its CRS
+  /** Constructs a XZ2 index in WGS84 from either a Geometry, Extent, ProjectedRasterTile, or RasterSource and its CRS.
    * For details: https://www.geomesa.org/documentation/user/datastores/index_overview.html */
  def rf_xz2_index(targetExtent: Column, targetCRS: Column, indexResolution: Short) = XZ2Indexer(targetExtent, targetCRS, indexResolution)
 
@@ -70,30 +70,30 @@ trait RasterFunctions {
    * For details: https://www.geomesa.org/documentation/user/datastores/index_overview.html */
  def rf_xz2_index(targetExtent: Column, targetCRS: Column) = XZ2Indexer(targetExtent, targetCRS, 18: Short)
 
-  /** Constructs a XZ2 index with level 18 resolution in WGS84 from either a ProjectedRasterTile or RasterSource
+  /** Constructs a XZ2 index with the provided resolution level in WGS84 from either a ProjectedRasterTile or RasterSource.
    * For details: https://www.geomesa.org/documentation/user/datastores/index_overview.html */
  def rf_xz2_index(targetExtent: Column, indexResolution: Short) = XZ2Indexer(targetExtent, indexResolution)
 
-  /** Constructs a XZ2 index with level 18 resolution in WGS84 from either a ProjectedRasterTile or RasterSource
+  /** Constructs a XZ2 index with level 18 resolution in WGS84 from either a ProjectedRasterTile or RasterSource.
    * For details: https://www.geomesa.org/documentation/user/datastores/index_overview.html */
  def rf_xz2_index(targetExtent: Column) = XZ2Indexer(targetExtent, 18: Short)
 
-  /** Constructs a Z2 index in WGS84 from either a Geometry, Extent, ProjectedRasterTile, or RasterSource and its CRS
+  /** Constructs a Z2 index in WGS84 from either a Geometry, Extent, ProjectedRasterTile, or RasterSource and its CRS.
    * First the native extent is extracted or computed, and then center is used as the indexing location.
    * For details: https://www.geomesa.org/documentation/user/datastores/index_overview.html */
  def rf_z2_index(targetExtent: Column, targetCRS: Column, indexResolution: Short) = Z2Indexer(targetExtent, targetCRS, indexResolution)
 
-  /** Constructs a Z2 index in WGS84 from either a Geometry, Extent, ProjectedRasterTile, or RasterSource and its CRS
+  /** Constructs a Z2 index with index resolution of 31 in WGS84 from either a Geometry, Extent, ProjectedRasterTile, or RasterSource and its CRS.
    * First the native extent is extracted or computed, and then center is used as the indexing location.
   * For details: https://www.geomesa.org/documentation/user/datastores/index_overview.html */
  def rf_z2_index(targetExtent: Column, targetCRS: Column) = Z2Indexer(targetExtent, targetCRS, 31: Short)
 
-  /** Constructs a Z2 index with level 18 resolution in WGS84 from either a ProjectedRasterTile or RasterSource
+  /** Constructs a Z2 index with the given index resolution in WGS84 from either a ProjectedRasterTile or RasterSource.
    * First the native extent is extracted or computed, and then center is used as the indexing location.
    * For details: https://www.geomesa.org/documentation/user/datastores/index_overview.html */
  def rf_z2_index(targetExtent: Column, indexResolution: Short) = Z2Indexer(targetExtent, indexResolution)
 
-  /** Constructs a Z2 index with level 18 resolution in WGS84 from either a ProjectedRasterTile or RasterSource
+  /** Constructs a Z2 index with index resolution of 31 in WGS84 from either a ProjectedRasterTile or RasterSource.
    * First the native extent is extracted or computed, and then center is used as the indexing location.
    * For details: https://www.geomesa.org/documentation/user/datastores/index_overview.html */
  def rf_z2_index(targetExtent: Column) = Z2Indexer(targetExtent, 31: Short)
diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/XZ2Indexer.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/XZ2Indexer.scala
index 9d4ea57da..4329ae105 100644
--- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/XZ2Indexer.scala
+++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/XZ2Indexer.scala
@@ -61,9 +61,9 @@ case class XZ2Indexer(left: Expression, right: Expression, indexResolution: Shor
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (!envelopeExtractor.isDefinedAt(left.dataType))
-      TypeCheckFailure(s"Input type '${left.dataType}' does not look like something with an Extent or something with one.")
+      TypeCheckFailure(s"Input type '${left.dataType}' does not look like a geometry, extent, or something with one.")
     else if(!crsExtractor.isDefinedAt(right.dataType))
-      TypeCheckFailure(s"Input type '${right.dataType}' does not look like something with a CRS.")
+      TypeCheckFailure(s"Input type '${right.dataType}' does not look like a CRS or something with one.")
     else TypeCheckSuccess
   }
diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Z2Indexer.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Z2Indexer.scala
index 42d2a120d..e6c1c428b 100644
--- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Z2Indexer.scala
+++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Z2Indexer.scala
@@ -62,9 +62,9 @@ case class Z2Indexer(left: Expression, right: Expression, indexResolution: Short
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (!centroidExtractor.isDefinedAt(left.dataType))
-      TypeCheckFailure(s"Input type '${left.dataType}' does not look like something with a point.")
+      TypeCheckFailure(s"Input type '${left.dataType}' does not look like something with a centroid.")
     else if(!crsExtractor.isDefinedAt(right.dataType))
-      TypeCheckFailure(s"Input type '${right.dataType}' does not look like something with a CRS.")
+      TypeCheckFailure(s"Input type '${right.dataType}' does not look like a CRS or something with one.")
     else TypeCheckSuccess
   }
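To make the intended use of these indexing functions concrete, here is a brief sketch (an editor's illustration under stated assumptions, not part of the patch): it assumes a RasterFrames-enabled SparkSession in Python and that `rf_z2_index` is exposed through `pyrasterframes.rasterfunctions` with the single-column form shown in the reference documentation below; the path is a placeholder.

    # Sketch: cluster a raster catalog spatially using the Z2 index.
    from pyrasterframes.rasterfunctions import rf_z2_index

    df = spark.read.raster('LC08_scene.tif')  # placeholder path
    # rf_z2_index derives the index from the center of each tile's extent,
    # so nearby tiles receive numerically close index values.
    indexed = df.withColumn('spatial_index', rf_z2_index(df.proj_raster))
    # Range partitioning on the index then groups spatially adjacent tiles
    # into the same partitions.
    clustered = indexed.repartitionByRange(16, 'spatial_index')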
diff --git a/datasource/src/main/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceRelation.scala b/datasource/src/main/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceRelation.scala
index 706bb9680..17ea45393 100644
--- a/datasource/src/main/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceRelation.scala
+++ b/datasource/src/main/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceRelation.scala
@@ -42,7 +42,10 @@ import org.locationtech.rasterframes.tiles.ProjectedRasterTile
   * @param bandIndexes band indexes to fetch
   * @param subtileDims how big to tile/subdivide rasters into
   * @param lazyTiles if true, creates a lazy representation of tile instead of fetching contents.
-  * @param spatialIndexPartitions if not None, adds a spatial index. If Option value < 1, uses the value of `numShufflePartitions` in SparkContext.
+  * @param spatialIndexPartitions Number of spatial index-based partitions to create.
+  *                               If Option value > 0, that number of partitions is created after adding a spatial index.
+  *                               If Option value <= 0, uses the value of `numShufflePartitions` in SparkContext.
+  *                               If None, no spatial index is added and hash partitioning is used.
   */
 case class RasterSourceRelation(
   sqlContext: SQLContext,
diff --git a/docs/src/main/paradox/reference.md b/docs/src/main/paradox/reference.md
index 11a02d3ce..f4d3425c5 100644
--- a/docs/src/main/paradox/reference.md
+++ b/docs/src/main/paradox/reference.md
@@ -81,7 +81,7 @@ Constructs a XZ2 index in WGS84/EPSG:4326 from either a Geometry, Extent, Projec
     Long rf_z2_index(Extent extent, CRS crs)
     Long rf_z2_index(ProjectedRasterTile proj_raster)
 
-Constructs a Z2 index in WGS84/EPSG:4326 from either a Geometry, Extent, ProjectedRasterTile and its CRS. First the native extent is extracted or computed, and then center is used as the indexing location. This function is useful for [range partitioning](http://spark.apache.org/docs/latest/api/python/pyspark.sql.html?highlight=registerjava#pyspark.sql.DataFrame.repartitionByRange).
+Constructs a Z2 index in WGS84/EPSG:4326 from either a Geometry, Extent, or ProjectedRasterTile and its CRS. First the native extent is extracted or computed, and then center is used as the indexing location. This function is useful for [range partitioning](http://spark.apache.org/docs/latest/api/python/pyspark.sql.html?highlight=registerjava#pyspark.sql.DataFrame.repartitionByRange). See the @ref:[Reading Raster Data](raster-read.md#spatial-indexing-and-partitioning) section for details on how to have an index automatically added when reading raster data.
 
 ## Tile Metadata and Mutation
diff --git a/docs/src/main/paradox/release-notes.md b/docs/src/main/paradox/release-notes.md
index a907c868a..09d52eaf7 100644
--- a/docs/src/main/paradox/release-notes.md
+++ b/docs/src/main/paradox/release-notes.md
@@ -23,7 +23,7 @@
 * Updated to GeoTrellis 2.3.3 and Proj4j 1.1.0.
 * Fixed issues with `LazyLogger` and shading assemblies ([#293](https://github.com/locationtech/rasterframes/issues/293))
 * Updated `rf_crs` to accept string columns containing CRS specifications. ([#366](https://github.com/locationtech/rasterframes/issues/366))
-* Added `rf_xz2_index` function. ([#368](https://github.com/locationtech/rasterframes/issues/368))
+* Added `rf_spatial_index` function. ([#368](https://github.com/locationtech/rasterframes/issues/368))
 * _Breaking_ (potentially): removed `pyrasterframes.create_spark_session` in lieu of `pyrasterframes.utils.create_rf_spark_session`
 
 ### 0.8.2
diff --git a/pyrasterframes/src/main/python/pyrasterframes/__init__.py b/pyrasterframes/src/main/python/pyrasterframes/__init__.py
index 78cafdff5..5f89508b1 100644
--- a/pyrasterframes/src/main/python/pyrasterframes/__init__.py
+++ b/pyrasterframes/src/main/python/pyrasterframes/__init__.py
@@ -149,11 +149,15 @@ def temp_name():
     if band_indexes is None:
         band_indexes = [0]
 
-    if spatial_index_partitions is False:
-        spatial_index_partitions = None
-
-    if spatial_index_partitions is not None:
-        if spatial_index_partitions is True:
+    if spatial_index_partitions:
+        num = int(spatial_index_partitions)
+        if num < 0:
+            spatial_index_partitions = '-1'
+        elif num == 0:
+            spatial_index_partitions = None
+
+    if spatial_index_partitions:
+        # Identity check, not equality: `1 == True` in Python, which would
+        # silently coerce a request for a single partition into the default.
+        if spatial_index_partitions is True:
             spatial_index_partitions = "-1"
         else:
             spatial_index_partitions = str(spatial_index_partitions)
diff --git a/pyrasterframes/src/main/python/tests/RasterSourceTest.py b/pyrasterframes/src/main/python/tests/RasterSourceTest.py
index e840092ca..4687864ad 100644
--- a/pyrasterframes/src/main/python/tests/RasterSourceTest.py
+++ b/pyrasterframes/src/main/python/tests/RasterSourceTest.py
@@ -213,7 +213,12 @@ def test_catalog_named_arg(self):
         self.assertTrue(df.select('b1_path').distinct().count() == 3)
 
     def test_spatial_partitioning(self):
-        df = self.spark.read.raster(self.path(1, 1), spatial_index_partitions=True)
+        f = self.path(1, 1)
+        df = self.spark.read.raster(f, spatial_index_partitions=True)
         self.assertTrue('spatial_index' in df.columns)
-        # Other tests?
+        self.assertEqual(df.rdd.getNumPartitions(), int(self.spark.conf.get("spark.sql.shuffle.partitions")))
+        self.assertEqual(self.spark.read.raster(f, spatial_index_partitions=34).rdd.getNumPartitions(), 34)
+        self.assertEqual(self.spark.read.raster(f, spatial_index_partitions="42").rdd.getNumPartitions(), 42)
+        self.assertFalse('spatial_index' in self.spark.read.raster(f, spatial_index_partitions=False).columns)
+        self.assertFalse('spatial_index' in self.spark.read.raster(f, spatial_index_partitions=0).columns)
\ No newline at end of file
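In practical terms, the option semantics pinned down by `test_spatial_partitioning` above reduce to the following sketch (paths are placeholders and `spark` is assumed to be a RasterFrames-enabled session):

    # `spatial_index_partitions` governs both the presence of a
    # `spatial_index` column and the partition count of the result.
    df = spark.read.raster('scene.tif', spatial_index_partitions=True)
    assert 'spatial_index' in df.columns   # count follows spark.sql.shuffle.partitions

    df34 = spark.read.raster('scene.tif', spatial_index_partitions=34)
    assert df34.rdd.getNumPartitions() == 34   # explicit counts are honored

    plain = spark.read.raster('scene.tif', spatial_index_partitions=False)
    assert 'spatial_index' not in plain.columns   # False or 0: no index, hash partitioning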
From f55088664e1f235c68aed1d7497d05ff8206b8c8 Mon Sep 17 00:00:00 2001
From: "Jason T. Brown"
Date: Mon, 18 Nov 2019 16:40:58 -0500
Subject: [PATCH 20/24] Add failing unit tests for issue 425: ML custom
 transformer loading is broken

Signed-off-by: Jason T. Brown
---
 .../src/main/python/tests/ExploderTests.py    | 14 +++++-
 .../main/python/tests/NoDataFilterTests.py    | 48 +++++++++++++++++++
 2 files changed, 61 insertions(+), 1 deletion(-)
 create mode 100644 pyrasterframes/src/main/python/tests/NoDataFilterTests.py

diff --git a/pyrasterframes/src/main/python/tests/ExploderTests.py b/pyrasterframes/src/main/python/tests/ExploderTests.py
index f7635ad7a..9740ab697 100644
--- a/pyrasterframes/src/main/python/tests/ExploderTests.py
+++ b/pyrasterframes/src/main/python/tests/ExploderTests.py
@@ -25,7 +25,7 @@
 from pyrasterframes import TileExploder
 from pyspark.ml.feature import VectorAssembler
-from pyspark.ml import Pipeline
+from pyspark.ml import Pipeline, PipelineModel
 from pyspark.sql.functions import *
 
 import unittest
@@ -56,3 +56,15 @@ def test_tile_exploder_pipeline_for_tile(self):
         pipe_model = pipe.fit(df)
         transformed_df = pipe_model.transform(df)
         self.assertTrue(transformed_df.count() > df.count())
+
+    def test_tile_exploder_read_write(self):
+        path = 'test_tile_exploder_read_write.pipe'
+        df = self.spark.read.raster(self.img_uri)
+
+        assembler = VectorAssembler().setInputCols(['proj_raster'])
+        pipe = Pipeline().setStages([TileExploder(), assembler])
+
+        pipe.fit(df).write().overwrite().save(path)
+
+        read_pipe = PipelineModel.load(path)
+        self.assertEqual(len(read_pipe.stages), 2)
diff --git a/pyrasterframes/src/main/python/tests/NoDataFilterTests.py b/pyrasterframes/src/main/python/tests/NoDataFilterTests.py
new file mode 100644
index 000000000..7e2db88ac
--- /dev/null
+++ b/pyrasterframes/src/main/python/tests/NoDataFilterTests.py
@@ -0,0 +1,48 @@
+#
+# This software is licensed under the Apache 2 license, quoted below.
+#
+# Copyright 2019 Astraea, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# [http://www.apache.org/licenses/LICENSE-2.0]
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+
+from . import TestEnvironment
+
+from pyrasterframes.rasterfunctions import *
+from pyrasterframes.rf_types import *
+
+from pyspark.ml.feature import VectorAssembler
+from pyspark.ml import Pipeline, PipelineModel
+from pyspark.sql.functions import *
+
+import unittest
+
+
+class NoDataFilterTests(TestEnvironment):
+
+    def test_no_data_filter_read_write(self):
+        path = 'test_no_data_filter_read_write.pipe'
+        df = self.spark.read.raster(self.img_uri) \
+            .select(rf_tile_mean('proj_raster').alias('mean'))
+
+        ndf = NoDataFilter().setInputCols(['mean'])
+        assembler = VectorAssembler().setInputCols(['mean'])
+
+        pipe = Pipeline().setStages([ndf, assembler])
+
+        pipe.fit(df).write().overwrite().save(path)
+
+        read_pipe = PipelineModel.load(path)
+        self.assertEqual(len(read_pipe.stages), 2)

From 13d04e630df3be51c1b1e1f565f7cddce6aa5a40 Mon Sep 17 00:00:00 2001
From: "Jason T. Brown"
Date: Mon, 18 Nov 2019 17:09:31 -0500
Subject: [PATCH 21/24] Fix for ML transformer read/write

Signed-off-by: Jason T.
Brown --- pyrasterframes/src/main/python/pyrasterframes/rf_types.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyrasterframes/src/main/python/pyrasterframes/rf_types.py b/pyrasterframes/src/main/python/pyrasterframes/rf_types.py index a54617ca1..3e9563242 100644 --- a/pyrasterframes/src/main/python/pyrasterframes/rf_types.py +++ b/pyrasterframes/src/main/python/pyrasterframes/rf_types.py @@ -31,7 +31,7 @@ class here provides the PyRasterFrames entry point. from pyspark.ml.param.shared import HasInputCols from pyspark.ml.wrapper import JavaTransformer -from pyspark.ml.util import JavaMLReadable, JavaMLWritable +from pyspark.ml.util import JavaMLReadable, JavaMLWritable, DefaultParamsReadable, DefaultParamsWritable from pyrasterframes.rf_context import RFContext @@ -462,7 +462,7 @@ def deserialize(self, datum): Tile.__UDT__ = TileUDT() -class TileExploder(JavaTransformer, JavaMLReadable, JavaMLWritable): +class TileExploder(JavaTransformer, DefaultParamsReadable, DefaultParamsWritable): """ Python wrapper for TileExploder.scala """ @@ -472,7 +472,7 @@ def __init__(self): self._java_obj = self._new_java_obj("org.locationtech.rasterframes.ml.TileExploder", self.uid) -class NoDataFilter(JavaTransformer, HasInputCols, JavaMLReadable, JavaMLWritable): +class NoDataFilter(JavaTransformer, HasInputCols, DefaultParamsReadable, DefaultParamsWritable): """ Python wrapper for NoDataFilter.scala """ From c7bf0fb8de75c48bdf209f50c75f3b0ff9aa7785 Mon Sep 17 00:00:00 2001 From: "Simeon H.K. Fitch" Date: Mon, 18 Nov 2019 18:08:34 -0500 Subject: [PATCH 22/24] Cruft removal. --- .../rasterframes/expressions/transformers/ExtractBits.scala | 2 +- .../org/locationtech/rasterframes/extensions/RasterJoin.scala | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtractBits.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtractBits.scala index 352f78ac9..4ba658baa 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtractBits.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtractBits.scala @@ -22,7 +22,7 @@ package org.locationtech.rasterframes.expressions.transformers import geotrellis.raster.Tile -import org.apache.spark.sql.{Column, TypedColumn} +import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback diff --git a/core/src/main/scala/org/locationtech/rasterframes/extensions/RasterJoin.scala b/core/src/main/scala/org/locationtech/rasterframes/extensions/RasterJoin.scala index fe6b65c7d..6bd66bab4 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/extensions/RasterJoin.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/extensions/RasterJoin.scala @@ -22,7 +22,6 @@ package org.locationtech.rasterframes.extensions import org.apache.spark.sql._ import org.apache.spark.sql.functions._ -import org.locationtech.rasterframes import org.locationtech.rasterframes._ import org.locationtech.rasterframes.encoders.serialized_literal import org.locationtech.rasterframes.expressions.SpatialRelation From e4c2903b1b4a7fdee31933321f978c5f8ced289e Mon Sep 17 00:00:00 2001 From: "Jason T. 
Brown" Date: Tue, 19 Nov 2019 09:15:08 -0500 Subject: [PATCH 23/24] Python unit tests check read pipeline stages Signed-off-by: Jason T. Brown --- pyrasterframes/src/main/python/tests/ExploderTests.py | 1 + pyrasterframes/src/main/python/tests/NoDataFilterTests.py | 7 +++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pyrasterframes/src/main/python/tests/ExploderTests.py b/pyrasterframes/src/main/python/tests/ExploderTests.py index 9740ab697..4b24f2f6b 100644 --- a/pyrasterframes/src/main/python/tests/ExploderTests.py +++ b/pyrasterframes/src/main/python/tests/ExploderTests.py @@ -68,3 +68,4 @@ def test_tile_exploder_read_write(self): read_pipe = PipelineModel.load(path) self.assertEqual(len(read_pipe.stages), 2) + self.assertTrue(isinstance(read_pipe.stages[0], TileExploder)) diff --git a/pyrasterframes/src/main/python/tests/NoDataFilterTests.py b/pyrasterframes/src/main/python/tests/NoDataFilterTests.py index 7e2db88ac..169783358 100644 --- a/pyrasterframes/src/main/python/tests/NoDataFilterTests.py +++ b/pyrasterframes/src/main/python/tests/NoDataFilterTests.py @@ -37,8 +37,9 @@ def test_no_data_filter_read_write(self): df = self.spark.read.raster(self.img_uri) \ .select(rf_tile_mean('proj_raster').alias('mean')) - ndf = NoDataFilter().setInputCols(['mean']) - assembler = VectorAssembler().setInputCols(['mean']) + input_cols = ['mean'] + ndf = NoDataFilter().setInputCols(input_cols) + assembler = VectorAssembler().setInputCols(input_cols) pipe = Pipeline().setStages([ndf, assembler]) @@ -46,3 +47,5 @@ def test_no_data_filter_read_write(self): read_pipe = PipelineModel.load(path) self.assertEqual(len(read_pipe.stages), 2) + actual_stages_ndf = read_pipe.stages[0].getInputCols() + self.assertEqual(actual_stages_ndf, input_cols) From 66452467ef48c0f34a82917bd0abb6900dacfd71 Mon Sep 17 00:00:00 2001 From: "Jason T. Brown" Date: Tue, 19 Nov 2019 09:20:01 -0500 Subject: [PATCH 24/24] remove unused imports Signed-off-by: Jason T. Brown --- pyrasterframes/src/main/python/pyrasterframes/rf_types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrasterframes/src/main/python/pyrasterframes/rf_types.py b/pyrasterframes/src/main/python/pyrasterframes/rf_types.py index 3e9563242..d76e3832c 100644 --- a/pyrasterframes/src/main/python/pyrasterframes/rf_types.py +++ b/pyrasterframes/src/main/python/pyrasterframes/rf_types.py @@ -31,7 +31,7 @@ class here provides the PyRasterFrames entry point. from pyspark.ml.param.shared import HasInputCols from pyspark.ml.wrapper import JavaTransformer -from pyspark.ml.util import JavaMLReadable, JavaMLWritable, DefaultParamsReadable, DefaultParamsWritable +from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable from pyrasterframes.rf_context import RFContext
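Taken together, patches 20 through 24 make the Python transformer wrappers survive ML pipeline persistence. A condensed sketch of the round trip the final tests verify (the raster path is a placeholder and `spark` is assumed to be a RasterFrames-enabled session):

    from pyrasterframes import TileExploder
    from pyspark.ml import Pipeline, PipelineModel
    from pyspark.ml.feature import VectorAssembler

    df = spark.read.raster('scene.tif')  # placeholder path

    # With DefaultParamsReadable/DefaultParamsWritable, loading restores the
    # Python transformer classes themselves, not just their JVM counterparts.
    pipe = Pipeline().setStages(
        [TileExploder(), VectorAssembler().setInputCols(['proj_raster'])])
    pipe.fit(df).write().overwrite().save('exploder.pipe')

    reloaded = PipelineModel.load('exploder.pipe')
    assert isinstance(reloaded.stages[0], TileExploder)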