Skip to content

Commit

Permalink
Merge branch 'develop' into feature/gt-3.0
Browse files Browse the repository at this point in the history
* develop: (24 commits)
  remove unused imports
  Python unit tests check read pipeline stages
  Cruft removal.
  Fix for ML transformer read/write
  Add failing unit tests for issue 425 ML custom transformer loading is broken
  PR feedback.
  Fixed `ReprojectToLayer` to work with `proj_raster`. Closes #357 Depends on #420.
  break out commented assert into skipped unit test around masking and deserialization
  register rf_local_extract_bit with SQL functions
  Added the ability to do a raster_join on proj_raster types. Fixes #419.
  Add landsat masking section to masking docs page
  Add mask bits  python api and unit test
  Extract bits should throw on non-integral cell types
  Fix for both masking by def and value; expand code comments; update tests
  Python regression.
  Masking improvements and unit tests.
  Regression
  Added (disabled) integration test for profiling spatial index effects. Reorganized non-evaluated documentationm files to `docs`.
  Add failing unit test for mask by value on 0
  Updated python tests against spatial index functions.
  ...

# Conflicts:
#	core/src/main/scala/org/locationtech/rasterframes/expressions/DynamicExtractors.scala
#	core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/XZ2Indexer.scala
#	core/src/main/scala/org/locationtech/rasterframes/extensions/Implicits.scala
#	core/src/test/scala/examples/CreatingRasterFrames.scala
#	core/src/test/scala/org/locationtech/rasterframes/RasterLayerSpec.scala
#	core/src/test/scala/org/locationtech/rasterframes/expressions/XZ2IndexerSpec.scala
#	datasource/src/main/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceRelation.scala
#	datasource/src/test/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceDataSourceSpec.scala
  • Loading branch information
metasim committed Nov 22, 2019
2 parents 7345251 + d8ea781 commit d223b82
Show file tree
Hide file tree
Showing 46 changed files with 1,548 additions and 523 deletions.
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ lazy val root = project
.enablePlugins(RFReleasePlugin)
.settings(
publish / skip := true,
clean := clean.dependsOn(`rf-notebook`/clean).value
clean := clean.dependsOn(`rf-notebook`/clean, docs/clean).value
)

lazy val `rf-notebook` = project
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,30 +52,50 @@ trait RasterFunctions {
/** Query the number of (cols, rows) in a Tile. */
def rf_dimensions(col: Column): TypedColumn[Any, Dimensions[Int]] = GetDimensions(col)

/** Extracts the CRS from a RasterSource or ProjectedRasterTile */
def rf_crs(col: Column): TypedColumn[Any, CRS] = GetCRS(col)

/** Extracts the bounding box of a geometry as an Extent */
def st_extent(col: Column): TypedColumn[Any, Extent] = GeometryToExtent(col)

/** Extracts the bounding box from a RasterSource or ProjectedRasterTile */
def rf_extent(col: Column): TypedColumn[Any, Extent] = GetExtent(col)

/** Constructs a XZ2 index in WGS84 from either a Geometry, Extent, ProjectedRasterTile, or RasterSource and its CRS
/** Constructs a XZ2 index in WGS84 from either a Geometry, Extent, ProjectedRasterTile, or RasterSource and its CRS.
* For details: https://www.geomesa.org/documentation/user/datastores/index_overview.html */
def rf_spatial_index(targetExtent: Column, targetCRS: Column, indexResolution: Short) = XZ2Indexer(targetExtent, targetCRS, indexResolution)
def rf_xz2_index(targetExtent: Column, targetCRS: Column, indexResolution: Short) = XZ2Indexer(targetExtent, targetCRS, indexResolution)

/** Constructs a XZ2 index in WGS84 from either a Geometry, Extent, ProjectedRasterTile, or RasterSource and its CRS
* For details: https://www.geomesa.org/documentation/user/datastores/index_overview.html */
def rf_spatial_index(targetExtent: Column, targetCRS: Column) = XZ2Indexer(targetExtent, targetCRS, 18: Short)
def rf_xz2_index(targetExtent: Column, targetCRS: Column) = XZ2Indexer(targetExtent, targetCRS, 18: Short)

/** Constructs a XZ2 index with level 18 resolution in WGS84 from either a ProjectedRasterTile or RasterSource
/** Constructs a XZ2 index with provided resolution level in WGS84 from either a ProjectedRasterTile or RasterSource.
* For details: https://www.geomesa.org/documentation/user/datastores/index_overview.html */
def rf_spatial_index(targetExtent: Column, indexResolution: Short) = XZ2Indexer(targetExtent, indexResolution)
def rf_xz2_index(targetExtent: Column, indexResolution: Short) = XZ2Indexer(targetExtent, indexResolution)

/** Constructs a XZ2 index with level 18 resolution in WGS84 from either a ProjectedRasterTile or RasterSource
/** Constructs a XZ2 index with level 18 resolution in WGS84 from either a ProjectedRasterTile or RasterSource.
* For details: https://www.geomesa.org/documentation/user/datastores/index_overview.html */
def rf_spatial_index(targetExtent: Column) = XZ2Indexer(targetExtent, 18: Short)
def rf_xz2_index(targetExtent: Column) = XZ2Indexer(targetExtent, 18: Short)

/** Extracts the CRS from a RasterSource or ProjectedRasterTile */
def rf_crs(col: Column): TypedColumn[Any, CRS] = GetCRS(col)
/** Constructs a Z2 index in WGS84 from either a Geometry, Extent, ProjectedRasterTile, or RasterSource and its CRS.
* First the native extent is extracted or computed, and then center is used as the indexing location.
* For details: https://www.geomesa.org/documentation/user/datastores/index_overview.html */
def rf_z2_index(targetExtent: Column, targetCRS: Column, indexResolution: Short) = Z2Indexer(targetExtent, targetCRS, indexResolution)

/** Constructs a Z2 index with index resolution of 31 in WGS84 from either a Geometry, Extent, ProjectedRasterTile, or RasterSource and its CRS.
* First the native extent is extracted or computed, and then center is used as the indexing location.
* For details: https://www.geomesa.org/documentation/user/datastores/index_overview.html */
def rf_z2_index(targetExtent: Column, targetCRS: Column) = Z2Indexer(targetExtent, targetCRS, 31: Short)

/** Constructs a Z2 index with the given index resolution in WGS84 from either a ProjectedRasterTile or RasterSource
* First the native extent is extracted or computed, and then center is used as the indexing location.
* For details: https://www.geomesa.org/documentation/user/datastores/index_overview.html */
def rf_z2_index(targetExtent: Column, indexResolution: Short) = Z2Indexer(targetExtent, indexResolution)

/** Constructs a Z2 index with index resolution of 31 in WGS84 from either a ProjectedRasterTile or RasterSource
* First the native extent is extracted or computed, and then center is used as the indexing location.
* For details: https://www.geomesa.org/documentation/user/datastores/index_overview.html */
def rf_z2_index(targetExtent: Column) = Z2Indexer(targetExtent, 31: Short)

/** Extracts the tile from a ProjectedRasterTile, or passes through a Tile. */
def rf_tile(col: Column): TypedColumn[Any, Tile] = RealizeTile(col)
Expand Down Expand Up @@ -318,7 +338,7 @@ trait RasterFunctions {

/** Generate a tile with the values from `data_tile`, but where cells in the `mask_tile` are in the `mask_values`
list, replace the value with NODATA. */
def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Seq[Int]): TypedColumn[Any, Tile] = {
def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Int*): TypedColumn[Any, Tile] = {
import org.apache.spark.sql.functions.array
val valuesCol: Column = array(maskValues.map(lit).toSeq: _*)
rf_mask_by_values(sourceTile, maskTile, valuesCol)
Expand All @@ -336,6 +356,52 @@ trait RasterFunctions {
def rf_inverse_mask_by_value(sourceTile: Column, maskTile: Column, maskValue: Int): TypedColumn[Any, Tile] =
Mask.InverseMaskByValue(sourceTile, maskTile, lit(maskValue))

/** Applies a mask using bit values in the `mask_tile`. Working from the right, extract the bit at `bitPosition` from the `maskTile`. In all locations where these are equal to the `valueToMask`, the returned tile is set to NoData, else the original `dataTile` cell value. */
def rf_mask_by_bit(dataTile: Column, maskTile: Column, bitPosition: Int, valueToMask: Boolean): TypedColumn[Any, Tile] =
rf_mask_by_bit(dataTile, maskTile, lit(bitPosition), lit(if (valueToMask) 1 else 0))

/** Applies a mask using bit values in the `mask_tile`. Working from the right, extract the bit at `bitPosition` from the `maskTile`. In all locations where these are equal to the `valueToMask`, the returned tile is set to NoData, else the original `dataTile` cell value. */
def rf_mask_by_bit(dataTile: Column, maskTile: Column, bitPosition: Column, valueToMask: Column): TypedColumn[Any, Tile] = {
import org.apache.spark.sql.functions.array
rf_mask_by_bits(dataTile, maskTile, bitPosition, lit(1), array(valueToMask))
}

/** Applies a mask from blacklisted bit values in the `mask_tile`. Working from the right, the bits from `start_bit` to `start_bit + num_bits` are @ref:[extracted](reference.md#rf_local_extract_bits) from cell values of the `mask_tile`. In all locations where these are in the `mask_values`, the returned tile is set to NoData; otherwise the original `tile` cell value is returned. */
def rf_mask_by_bits(dataTile: Column, maskTile: Column, startBit: Column, numBits: Column, valuesToMask: Column): TypedColumn[Any, Tile] = {
val bitMask = rf_local_extract_bits(maskTile, startBit, numBits)
rf_mask_by_values(dataTile, bitMask, valuesToMask)
}


/** Applies a mask from blacklisted bit values in the `mask_tile`. Working from the right, the bits from `start_bit` to `start_bit + num_bits` are @ref:[extracted](reference.md#rf_local_extract_bits) from cell values of the `mask_tile`. In all locations where these are in the `mask_values`, the returned tile is set to NoData; otherwise the original `tile` cell value is returned. */
def rf_mask_by_bits(dataTile: Column, maskTile: Column, startBit: Int, numBits: Int, valuesToMask: Int*): TypedColumn[Any, Tile] = {
import org.apache.spark.sql.functions.array
val values = array(valuesToMask.map(lit):_*)
rf_mask_by_bits(dataTile, maskTile, lit(startBit), lit(numBits), values)
}

/** Extract value from specified bits of the cells' underlying binary data.
* `startBit` is the first bit to consider, working from the right. It is zero indexed.
* `numBits` is the number of bits to take moving further to the left. */
def rf_local_extract_bits(tile: Column, startBit: Column, numBits: Column): Column =
ExtractBits(tile, startBit, numBits)

/** Extract value from specified bits of the cells' underlying binary data.
* `bitPosition` is bit to consider, working from the right. It is zero indexed. */
def rf_local_extract_bits(tile: Column, bitPosition: Column): Column =
rf_local_extract_bits(tile, bitPosition, lit(1))

/** Extract value from specified bits of the cells' underlying binary data.
* `startBit` is the first bit to consider, working from the right. It is zero indexed.
* `numBits` is the number of bits to take, moving further to the left. */
def rf_local_extract_bits(tile: Column, startBit: Int, numBits: Int): Column =
rf_local_extract_bits(tile, lit(startBit), lit(numBits))

/** Extract value from specified bits of the cells' underlying binary data.
* `bitPosition` is bit to consider, working from the right. It is zero indexed. */
def rf_local_extract_bits(tile: Column, bitPosition: Int): Column =
rf_local_extract_bits(tile, lit(bitPosition))

/** Create a tile where cells in the grid defined by cols, rows, and bounds are filled with the given value. */
def rf_rasterize(geometry: Column, bounds: Column, value: Column, cols: Int, rows: Int): TypedColumn[Any, Tile] =
withTypedAlias("rf_rasterize", geometry)(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import org.apache.spark.sql.jts.JTSTypes
import org.apache.spark.sql.rf.{RasterSourceUDT, TileUDT}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.locationtech.jts.geom.Envelope
import org.locationtech.jts.geom.{Envelope, Point}
import org.locationtech.rasterframes.encoders.CatalystSerializer._
import org.locationtech.rasterframes.model.{LazyCRS, TileContext}
import org.locationtech.rasterframes.ref.{ProjectedRasterLike, RasterRef, RFRasterSource}
Expand Down Expand Up @@ -69,13 +69,13 @@ object DynamicExtractors {
}

/** Partial function for pulling a ProjectedRasterLike an input row. */
lazy val projectedRasterLikeExtractor: PartialFunction[DataType, InternalRow ProjectedRasterLike] = {
lazy val projectedRasterLikeExtractor: PartialFunction[DataType, Any ProjectedRasterLike] = {
case _: RasterSourceUDT
(row: InternalRow) => row.to[RFRasterSource](RasterSourceUDT.rasterSourceSerializer)
(input: Any) => input.asInstanceOf[InternalRow].to[RFRasterSource](RasterSourceUDT.rasterSourceSerializer)
case t if t.conformsTo[ProjectedRasterTile] =>
(row: InternalRow) => row.to[ProjectedRasterTile]
(input: Any) => input.asInstanceOf[InternalRow].to[ProjectedRasterTile]
case t if t.conformsTo[RasterRef] =>
(row: InternalRow) => row.to[RasterRef]
(input: Any) => input.asInstanceOf[InternalRow].to[RasterRef]
}

/** Partial function for pulling a CellGrid from an input row. */
Expand All @@ -97,13 +97,36 @@ object DynamicExtractors {
(v: Any) => v.asInstanceOf[InternalRow].to[CRS]
}

lazy val extentLikeExtractor: PartialFunction[DataType, Any Extent] = {
case t if org.apache.spark.sql.rf.WithTypeConformity(t).conformsTo(JTSTypes.GeometryTypeInstance) =>
(input: Any) => JTSTypes.GeometryTypeInstance.deserialize(input).getEnvelopeInternal
case t if t.conformsTo[Extent] =>
(input: Any) => input.asInstanceOf[InternalRow].to[Extent]
case t if t.conformsTo[Envelope] =>
(input: Any) => Extent(input.asInstanceOf[InternalRow].to[Envelope])
lazy val extentExtractor: PartialFunction[DataType, Any Extent] = {
val base: PartialFunction[DataType, Any Extent]= {
case t if org.apache.spark.sql.rf.WithTypeConformity(t).conformsTo(JTSTypes.GeometryTypeInstance) =>
(input: Any) => Extent(JTSTypes.GeometryTypeInstance.deserialize(input).getEnvelopeInternal)
case t if t.conformsTo[Extent] =>
(input: Any) => input.asInstanceOf[InternalRow].to[Extent]
case t if t.conformsTo[Envelope] =>
(input: Any) => Extent(input.asInstanceOf[InternalRow].to[Envelope])
}

val fromPRL = projectedRasterLikeExtractor.andThen(_.andThen(_.extent))
fromPRL orElse base
}

lazy val envelopeExtractor: PartialFunction[DataType, Any => Envelope] = {
val base = PartialFunction[DataType, Any => Envelope] {
case t if org.apache.spark.sql.rf.WithTypeConformity(t).conformsTo(JTSTypes.GeometryTypeInstance) =>
(input: Any) => JTSTypes.GeometryTypeInstance.deserialize(input).getEnvelopeInternal
case t if t.conformsTo[Extent] =>
(input: Any) => input.asInstanceOf[InternalRow].to[Extent].jtsEnvelope
case t if t.conformsTo[Envelope] =>
(input: Any) => input.asInstanceOf[InternalRow].to[Envelope]
}

val fromPRL = projectedRasterLikeExtractor.andThen(_.andThen(_.extent.jtsEnvelope))
fromPRL orElse base
}

lazy val centroidExtractor: PartialFunction[DataType, Any Point] = {
extentExtractor.andThen(_.andThen(_.center))
}

sealed trait TileOrNumberArg
Expand All @@ -130,8 +153,7 @@ object DynamicExtractors {
case _: DoubleType | _: FloatType | _: DecimalType => {
case d: Double => DoubleArg(d)
case f: Float => DoubleArg(f.toDouble)
case d: Decimal => DoubleArg(d.toDouble)
}
case d: Decimal => DoubleArg(d.toDouble) }
}

lazy val intArgExtractor: PartialFunction[DataType, Any => IntegerArg] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ package org.locationtech.rasterframes.expressions.aggregates
import geotrellis.proj4.CRS
import geotrellis.raster.reproject.Reproject
import geotrellis.raster.resample.ResampleMethod
import geotrellis.raster.{ArrayTile, CellType, Dimensions, MultibandTile, ProjectedRaster, Raster, Tile}
import geotrellis.raster.{ArrayTile, CellType, Dimensions, GeoAttrsError, MultibandTile, ProjectedRaster, Raster, Tile}
import geotrellis.layer._
import geotrellis.vector.Extent
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
Expand Down Expand Up @@ -99,10 +99,9 @@ object TileRasterizerAggregate {

def apply(tlm: TileLayerMetadata[_], sampler: ResampleMethod): ProjectedRasterDefinition = {
// Try to determine the actual dimensions of our data coverage
val actualSize = tlm.layout.toRasterExtent().gridBoundsFor(tlm.extent) // <--- Do we have the math right here?
val cols = actualSize.width
val rows = actualSize.height
new ProjectedRasterDefinition(cols, rows, tlm.cellType, tlm.crs, tlm.extent, sampler)
val Dimensions(cols, rows) = tlm.totalDimensions
require(cols <= Int.MaxValue && rows <= Int.MaxValue, s"Can't construct a Raster of size $cols x $rows. (Too big!)")
new ProjectedRasterDefinition(cols.toInt, rows.toInt, tlm.cellType, tlm.crs, tlm.extent, sampler)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,12 @@ package object expressions {
registry.registerExpression[RenderPNG.RenderCompositePNG]("rf_render_png")
registry.registerExpression[RGBComposite]("rf_rgb_composite")

registry.registerExpression[XZ2Indexer]("rf_spatial_index")
registry.registerExpression[XZ2Indexer]("rf_xz2_index")
registry.registerExpression[Z2Indexer]("rf_z2_index")

registry.registerExpression[transformers.ReprojectGeometry]("st_reproject")

registry.registerExpression[ExtractBits]("rf_local_extract_bits")
registry.registerExpression[ExtractBits]("rf_local_extract_bit")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* This software is licensed under the Apache 2 license, quoted below.
*
* Copyright 2019 Astraea, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* [http://www.apache.org/licenses/LICENSE-2.0]
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*
* SPDX-License-Identifier: Apache-2.0
*
*/

package org.locationtech.rasterframes.expressions.transformers

import geotrellis.raster.Tile
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, TernaryExpression}
import org.apache.spark.sql.rf.TileUDT
import org.apache.spark.sql.types.DataType
import org.locationtech.rasterframes.encoders.CatalystSerializer._
import org.locationtech.rasterframes.expressions.DynamicExtractors._
import org.locationtech.rasterframes.expressions._

@ExpressionDescription(
usage = "_FUNC_(tile, start_bit, num_bits) - In each cell of `tile`, extract `num_bits` from the cell value, starting at `start_bit` from the left.",
arguments = """
Arguments:
* tile - tile column to extract values
* start_bit -
* num_bits -
""",
examples = """
Examples:
> SELECT _FUNC_(tile, lit(4), lit(2))
..."""
)
case class ExtractBits(child1: Expression, child2: Expression, child3: Expression) extends TernaryExpression with CodegenFallback with Serializable {
override val nodeName: String = "rf_local_extract_bits"

override def children: Seq[Expression] = Seq(child1, child2, child3)

override def dataType: DataType = child1.dataType

override def checkInputDataTypes(): TypeCheckResult =
if(!tileExtractor.isDefinedAt(child1.dataType)) {
TypeCheckFailure(s"Input type '${child1.dataType}' does not conform to a raster type.")
} else if (!intArgExtractor.isDefinedAt(child2.dataType)) {
TypeCheckFailure(s"Input type '${child2.dataType}' isn't an integral type.")
} else if (!intArgExtractor.isDefinedAt(child3.dataType)) {
TypeCheckFailure(s"Input type '${child3.dataType}' isn't an integral type.")
} else TypeCheckSuccess


override protected def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = {
implicit val tileSer = TileUDT.tileSerializer
val (childTile, childCtx) = tileExtractor(child1.dataType)(row(input1))

val startBits = intArgExtractor(child2.dataType)(input2).value

val numBits = intArgExtractor(child2.dataType)(input3).value

childCtx match {
case Some(ctx) => ctx.toProjectRasterTile(op(childTile, startBits, numBits)).toInternalRow
case None => op(childTile, startBits, numBits).toInternalRow
}
}

protected def op(tile: Tile, startBit: Int, numBits: Int): Tile = ExtractBits(tile, startBit, numBits)

}

object ExtractBits{
def apply(tile: Column, startBit: Column, numBits: Column): Column =
new Column(ExtractBits(tile.expr, startBit.expr, numBits.expr))

def apply(tile: Tile, startBit: Int, numBits: Int): Tile = {
assert(!tile.cellType.isFloatingPoint, "ExtractBits operation requires integral CellType")
// this is the last `numBits` positions of "111111111111111"
val widthMask = Int.MaxValue >> (63 - numBits)
// map preserving the nodata structure
tile.mapIfSet(x x >> startBit & widthMask)
}

}
Loading

0 comments on commit d223b82

Please sign in to comment.