Skip to content

Commit

Permalink
Merge pull request #423 from s22s/fix/357
Browse files Browse the repository at this point in the history
Fixed `toLayer` on arbitrary RasterFrame.
  • Loading branch information
metasim authored Nov 19, 2019
2 parents 47b42a9 + c7bf0fb commit d8ea781
Show file tree
Hide file tree
Showing 11 changed files with 89 additions and 193 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,7 @@ object TileRasterizerAggregate {

def apply(tlm: TileLayerMetadata[_], sampler: ResampleMethod): ProjectedRasterDefinition = {
// Try to determine the actual dimensions of our data coverage
val actualSize = tlm.layout.toRasterExtent().gridBoundsFor(tlm.extent) // <--- Do we have the math right here?
val cols = actualSize.width
val rows = actualSize.height
val TileDimensions(cols, rows) = tlm.totalDimensions
new ProjectedRasterDefinition(cols, rows, tlm.cellType, tlm.crs, tlm.extent, sampler)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
package org.locationtech.rasterframes.expressions.transformers

import geotrellis.raster.Tile
import org.apache.spark.sql.{Column, TypedColumn}
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.types.{MetadataBuilder, Metadata => SMetadata}
import org.locationtech.rasterframes.model.TileDimensions
import spray.json.JsonFormat

import scala.reflect.runtime.universe._
Expand Down Expand Up @@ -79,6 +80,15 @@ trait Implicits {
private[rasterframes]
implicit class WithMetadataBuilderMethods(val self: MetadataBuilder)
extends MetadataBuilderMethods

private[rasterframes]
implicit class TLMHasTotalCells(tlm: TileLayerMetadata[_]) {
// TODO: With upgrade to GT 3.1, replace this with the more general `Dimensions[Long]`
def totalDimensions: TileDimensions = {
val gb = tlm.layout.toRasterExtent().gridBoundsFor(tlm.extent)
TileDimensions(gb.width, gb.height)
}
}
}

object Implicits extends Implicits
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ import org.locationtech.rasterframes.encoders.serialized_literal
*
* @since 12/15/17
*/
trait RFSpatialColumnMethods extends MethodExtensions[RasterFrameLayer] with StandardColumns {
trait LayerSpatialColumnMethods extends MethodExtensions[RasterFrameLayer] with StandardColumns {
import Implicits.{WithDataFrameMethods, WithRasterFrameLayerMethods}
import org.locationtech.geomesa.spark.jts._

Expand Down Expand Up @@ -112,7 +112,7 @@ trait RFSpatialColumnMethods extends MethodExtensions[RasterFrameLayer] with Sta
*/
def withCenterLatLng(colName: String = "center"): RasterFrameLayer = {
val key2Center = sparkUdf(keyCol2LatLng)
self.withColumn(colName, key2Center(self.spatialKeyColumn).cast(RFSpatialColumnMethods.LngLatStructType)).certify
self.withColumn(colName, key2Center(self.spatialKeyColumn).cast(LayerSpatialColumnMethods.LngLatStructType)).certify
}

/**
Expand All @@ -130,6 +130,6 @@ trait RFSpatialColumnMethods extends MethodExtensions[RasterFrameLayer] with Sta
}
}

object RFSpatialColumnMethods {
object LayerSpatialColumnMethods {
private[rasterframes] val LngLatStructType = StructType(Seq(StructField("longitude", DoubleType), StructField("latitude", DoubleType)))
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ import scala.reflect.runtime.universe._
* @since 7/18/17
*/
trait RasterFrameLayerMethods extends MethodExtensions[RasterFrameLayer]
with RFSpatialColumnMethods with MetadataKeys {
with LayerSpatialColumnMethods with MetadataKeys {
import Implicits.{WithDataFrameMethods, WithRasterFrameLayerMethods}

@transient protected lazy val logger = Logger(LoggerFactory.getLogger(getClass.getName))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ package org.locationtech.rasterframes.extensions
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.locationtech.rasterframes._
import org.locationtech.rasterframes.encoders.serialized_literal
import org.locationtech.rasterframes.expressions.SpatialRelation
import org.locationtech.rasterframes.expressions.accessors.ExtractTile
import org.locationtech.rasterframes.functions.reproject_and_merge
Expand Down Expand Up @@ -89,9 +90,12 @@ object RasterJoin {
// After the aggregation we take all the tiles we've collected and resample + merge
// into LHS extent/CRS.
// Use a representative tile from the left for the tile dimensions
val leftTile = left.tileColumns.headOption.getOrElse(throw new IllegalArgumentException("Need at least one target tile on LHS"))
val destDims = left.tileColumns.headOption
.map(t => rf_dimensions(unresolved(t)))
.getOrElse(serialized_literal(NOMINAL_TILE_DIMS))

val reprojCols = rightAggTiles.map(t => reproject_and_merge(
col(leftExtent2), col(leftCRS2), col(t.columnName), col(rightExtent2), col(rightCRS2), rf_dimensions(unresolved(leftTile))
col(leftExtent2), col(leftCRS2), col(t.columnName), col(rightExtent2), col(rightCRS2), destDims
) as t.columnName)

val finalCols = leftAggCols.map(unresolved) ++ reprojCols ++ rightAggOther.map(unresolved)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import org.apache.spark.sql._
import org.apache.spark.sql.functions.broadcast
import org.locationtech.rasterframes._
import org.locationtech.rasterframes.util._

/** Algorithm for projecting an arbitrary RasterFrame into a layer with consistent CRS and gridding. */
object ReprojectToLayer {
def apply(df: DataFrame, tlm: TileLayerMetadata[SpatialKey]): RasterFrameLayer = {
// create a destination dataframe with crs and extend columns
Expand All @@ -42,8 +44,9 @@ object ReprojectToLayer {
e = tlm.mapTransform(sk)
} yield (sk, e, crs)

// Create effectively a target RasterFrame, but with no tiles.
val dest = gridItems.toSeq.toDF(SPATIAL_KEY_COLUMN.columnName, EXTENT_COLUMN.columnName, CRS_COLUMN.columnName)
dest.show(false)

val joined = RasterJoin(broadcast(dest), df)

joined.asLayer(SPATIAL_KEY_COLUMN, tlm)
Expand Down
92 changes: 0 additions & 92 deletions core/src/test/scala/examples/CreatingRasterFrames.scala

This file was deleted.

50 changes: 0 additions & 50 deletions core/src/test/scala/examples/MeanValue.scala

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,21 @@

package org.locationtech.rasterframes

import java.net.URI
import java.sql.Timestamp
import java.time.ZonedDateTime

import org.locationtech.rasterframes.util._
import geotrellis.proj4.LatLng
import geotrellis.raster.render.{ColorMap, ColorRamp}
import geotrellis.raster.{ProjectedRaster, Tile, TileFeature, TileLayout, UByteCellType}
import geotrellis.proj4.{CRS, LatLng}
import geotrellis.raster.{MultibandTile, ProjectedRaster, Raster, Tile, TileFeature, TileLayout, UByteCellType, UByteConstantNoDataCellType}
import geotrellis.spark._
import geotrellis.spark.tiling._
import geotrellis.vector.{Extent, ProjectedExtent}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{SQLContext, SparkSession}
import org.apache.spark.sql.{Encoders, SQLContext, SparkSession}
import org.locationtech.rasterframes.model.TileDimensions
import org.locationtech.rasterframes.ref.RasterSource
import org.locationtech.rasterframes.tiles.ProjectedRasterTile
import org.locationtech.rasterframes.util._

import scala.util.control.NonFatal

Expand All @@ -44,7 +46,7 @@ import scala.util.control.NonFatal
*
* @since 7/10/17
*/
class RasterFrameSpec extends TestEnvironment with MetadataKeys
class RasterLayerSpec extends TestEnvironment with MetadataKeys
with TestData {
import TestData.randomTile
import spark.implicits._
Expand Down Expand Up @@ -232,17 +234,40 @@ class RasterFrameSpec extends TestEnvironment with MetadataKeys
assert(bounds._2 === SpaceTimeKey(3, 1, now))
}

def basicallySame(expected: Extent, computed: Extent): Unit = {
val components = Seq(
(expected.xmin, computed.xmin),
(expected.ymin, computed.ymin),
(expected.xmax, computed.xmax),
(expected.ymax, computed.ymax)
)
forEvery(components)(c
assert(c._1 === c._2 +- 0.000001)
it("should create layer from arbitrary RasterFrame") {
val src = RasterSource(URI.create("https://raw.githubusercontent.com/locationtech/rasterframes/develop/core/src/test/resources/LC08_RGB_Norfolk_COG.tiff"))
val srcCrs = src.crs

def project(r: Raster[MultibandTile]): Seq[ProjectedRasterTile] =
r.tile.bands.map(b => ProjectedRasterTile(b, r.extent, srcCrs))

val prtEnc = ProjectedRasterTile.prtEncoder
implicit val enc = Encoders.tuple(prtEnc, prtEnc, prtEnc)

val rasters = src.readAll(bands = Seq(0, 1, 2)).map(project).map(p => (p(0), p(1), p(2)))

val df = rasters.toDF("red", "green", "blue")

val crs = CRS.fromString("+proj=utm +zone=18 +datum=WGS84 +units=m +no_defs")

val extent = Extent(364455.0, 4080315.0, 395295.0, 4109985.0)
val layout = LayoutDefinition(extent, TileLayout(2, 2, 32, 32))

val tlm = new TileLayerMetadata[SpatialKey](
UByteConstantNoDataCellType,
layout,
extent,
crs,
KeyBounds(SpatialKey(0, 0), SpatialKey(1, 1))
)
}
val layer = df.toLayer(tlm)

val TileDimensions(cols, rows) = tlm.totalDimensions
val prt = layer.toMultibandRaster(Seq($"red", $"green", $"blue"), cols, rows)
prt.tile.dimensions should be((cols, rows))
prt.crs should be(crs)
prt.extent should be(extent)
}

it("shouldn't clip already clipped extents") {
val rf = TestData.randomSpatialTileLayerRDD(1024, 1024, 8, 8).toLayer
Expand All @@ -258,27 +283,8 @@ class RasterFrameSpec extends TestEnvironment with MetadataKeys
basicallySame(expected2, computed2)
}

def Greyscale(stops: Int): ColorRamp = {
val colors = (0 to stops)
.map(i {
val c = java.awt.Color.HSBtoRGB(0f, 0f, i / stops.toFloat)
(c << 8) | 0xFF // Add alpha channel.
})
ColorRamp(colors)
}

def render(tile: Tile, tag: String): Unit = {
if(false && !isCI) {
val colors = ColorMap.fromQuantileBreaks(tile.histogram, Greyscale(128))
val path = s"target/${getClass.getSimpleName}_$tag.png"
logger.info(s"Writing '$path'")
tile.color(colors).renderPng().write(path)
}
}

it("should rasterize with a spatiotemporal key") {
val rf = TestData.randomSpatioTemporalTileLayerRDD(20, 20, 2, 2).toLayer

noException shouldBe thrownBy {
rf.toRaster($"tile", 128, 128)
}
Expand All @@ -291,7 +297,6 @@ class RasterFrameSpec extends TestEnvironment with MetadataKeys
val joinTypes = Seq("inner", "outer", "fullouter", "left_outer", "right_outer", "leftsemi")
forEvery(joinTypes) { jt
val joined = rf1.spatialJoin(rf2, jt)
//println(joined.schema.json)
assert(joined.tileLayerMetadata.isRight)
}
}
Expand Down
Loading

0 comments on commit d8ea781

Please sign in to comment.