Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ object GeographyType extends SpatialType {
GeographyType(MIXED_CRS, GEOGRAPHY_DEFAULT_ALGORITHM)

/**
 * Returns whether the given SRID is supported, i.e. whether
 * GeographicSpatialReferenceSystemMapper.getStringId yields a non-null string ID for it.
 *
 * NOTE(review): two signature lines are visible below (a `private[types]` one and a public
 * one); this looks like diff residue, so presumably only one of them belongs in the real file.
 */
private[types] def isSridSupported(srid: Int): Boolean = {
def isSridSupported(srid: Int): Boolean = {
GeographicSpatialReferenceSystemMapper.getStringId(srid) != null
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ object GeometryType extends SpatialType {
GeometryType(MIXED_CRS)

/**
 * Returns whether the given SRID is supported, i.e. whether
 * CartesianSpatialReferenceSystemMapper.getStringId yields a non-null string ID for it.
 *
 * NOTE(review): two signature lines are visible below (a `private[types]` one and a public
 * one); this looks like diff residue, so presumably only one of them belongs in the real file.
 */
private[types] def isSridSupported(srid: Int): Boolean = {
def isSridSupported(srid: Int): Boolean = {
CartesianSpatialReferenceSystemMapper.getStringId(srid) != null
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,7 @@ interface Geo {
// Returns the Spatial Reference Identifier (SRID) value of the geo object, as an int.
int srid();

// Sets the Spatial Reference Identifier (SRID) value of the geo object to the given value.
// NOTE(review): the method returns nothing, so implementations presumably update the object
// itself; the interface does not validate the SRID, so callers should check it beforehand.
void setSrid(int srid);

}
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,20 @@ public byte[] toEwkt() {
@Override
public int srid() {
// This method gets the SRID value from the in-memory Geography representation header: it reads
// an int at SRID_OFFSET from a DEFAULT_ENDIANNESS-ordered byte buffer over the value's bytes.
// NOTE(review): two return statements are visible below (inline wrapping and the getWrapper()
// helper); this looks like diff residue, so presumably only one belongs in the real file.
return ByteBuffer.wrap(getBytes()).order(DEFAULT_ENDIANNESS).getInt(SRID_OFFSET);
return getWrapper().getInt(SRID_OFFSET);
}

@Override
public void setSrid(int srid) {
  // Write the new SRID as an int at SRID_OFFSET of the in-memory Geography representation
  // header, going through the byte buffer wrapper of this value.
  ByteBuffer header = getWrapper();
  header.putInt(SRID_OFFSET, srid);
}

/** Other private helper/utility methods used for implementation. */

// Wraps the bytes of this geography value in a byte buffer that uses DEFAULT_ENDIANNESS as its
// byte order, so that header fields can be accessed at fixed offsets.
private ByteBuffer getWrapper() {
  ByteBuffer wrapper = ByteBuffer.wrap(getBytes());
  // ByteBuffer.order(...) configures the buffer itself and returns that same buffer.
  wrapper.order(DEFAULT_ENDIANNESS);
  return wrapper;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,20 @@ public byte[] toEwkt() {
@Override
public int srid() {
// This method gets the SRID value from the in-memory Geometry representation header: it reads
// an int at SRID_OFFSET from a DEFAULT_ENDIANNESS-ordered byte buffer over the value's bytes.
// NOTE(review): two return statements are visible below (inline wrapping and the getWrapper()
// helper); this looks like diff residue, so presumably only one belongs in the real file.
return ByteBuffer.wrap(getBytes()).order(DEFAULT_ENDIANNESS).getInt(SRID_OFFSET);
return getWrapper().getInt(SRID_OFFSET);
}

@Override
public void setSrid(int srid) {
  // Write the new SRID as an int at SRID_OFFSET of the in-memory Geometry representation
  // header, going through the byte buffer wrapper of this value.
  ByteBuffer header = getWrapper();
  header.putInt(SRID_OFFSET, srid);
}

/** Other private helper/utility methods used for implementation. */

// Wraps the bytes of this geometry value in a byte buffer that uses DEFAULT_ENDIANNESS as its
// byte order, so that header fields can be accessed at fixed offsets.
private ByteBuffer getWrapper() {
  ByteBuffer wrapper = ByteBuffer.wrap(getBytes());
  // ByteBuffer.order(...) configures the buffer itself and returns that same buffer.
  wrapper.order(DEFAULT_ENDIANNESS);
  return wrapper;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.spark.sql.catalyst.util;

import org.apache.spark.sql.errors.QueryExecutionErrors;
import org.apache.spark.sql.types.GeographyType;
import org.apache.spark.sql.types.GeometryType;
import org.apache.spark.unsafe.types.GeographyVal;
Expand Down Expand Up @@ -101,6 +102,31 @@ public static GeometryVal stGeomFromWKB(byte[] wkb) {
return toPhysVal(Geometry.fromWkb(wkb));
}

// ST_SetSrid
/**
 * Returns a geography value that carries the given SRID.
 *
 * @param geo the input geography value
 * @param srid the SRID to set; must be supported by GeographyType
 * @return the physical value of a copy of the input whose SRID is set to {@code srid}
 * @throws SparkIllegalArgumentException (ST_INVALID_SRID_VALUE) if the SRID is not supported
 */
public static GeographyVal stSetSrid(GeographyVal geo, int srid) {
  // Only SRID values that GeographyType reports as supported may be set on a geography.
  boolean supported = GeographyType.isSridSupported(srid);
  if (!supported) {
    throw QueryExecutionErrors.stInvalidSridValueError(srid);
  }
  // Apply the new SRID to a copy of the input rather than to the input object itself.
  Geography updated = fromPhysVal(geo).copy();
  updated.setSrid(srid);
  // Convert the updated copy back to its physical value.
  return toPhysVal(updated);
}

/**
 * Returns a geometry value that carries the given SRID.
 *
 * @param geo the input geometry value
 * @param srid the SRID to set; must be supported by GeometryType
 * @return the physical value of a copy of the input whose SRID is set to {@code srid}
 * @throws SparkIllegalArgumentException (ST_INVALID_SRID_VALUE) if the SRID is not supported
 */
public static GeometryVal stSetSrid(GeometryVal geo, int srid) {
  // Only SRID values that GeometryType reports as supported may be set on a geometry.
  boolean supported = GeometryType.isSridSupported(srid);
  if (!supported) {
    throw QueryExecutionErrors.stInvalidSridValueError(srid);
  }
  // Apply the new SRID to a copy of the input rather than to the input object itself.
  Geometry updated = fromPhysVal(geo).copy();
  updated.setSrid(srid);
  // Convert the updated copy back to its physical value.
  return toPhysVal(updated);
}

// ST_Srid
public static int stSrid(GeographyVal geog) {
return fromPhysVal(geog).srid();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,7 @@ object FunctionRegistry {
expression[ST_GeogFromWKB]("st_geogfromwkb"),
expression[ST_GeomFromWKB]("st_geomfromwkb"),
expression[ST_Srid]("st_srid"),
expression[ST_SetSrid]("st_setsrid"),

// cast
expression[Cast]("cast"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.expressions.st

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

private[sql] object STExpressionUtils {
Expand All @@ -29,4 +30,49 @@ private[sql] object STExpressionUtils {
case _ => false
}

/**
 * Computes the data type of a geospatial value of the given source type once its SRID is set
 * through the `srid` expression. GEOMETRY and GEOGRAPHY source types are forwarded to their
 * dedicated helper methods; any other source type is rejected with an
 * IllegalArgumentException.
 */
def geospatialTypeWithSrid(sourceType: DataType, srid: Expression): DataType =
  sourceType match {
    case _: GeometryType => geometryTypeWithSrid(srid)
    case _: GeographyType => geographyTypeWithSrid(srid)
    // Only geospatial source types are meaningful here.
    case other => throw new IllegalArgumentException(s"Unexpected data type: $other.")
  }

/**
 * Derives the GEOMETRY type that corresponds to the given SRID expression. An integer literal
 * yields the GEOMETRY type for exactly that SRID value; every other expression yields the
 * mixed SRID ("ANY") GEOMETRY type.
 */
private def geometryTypeWithSrid(srid: Expression): GeometryType = srid match {
  // The SRID is an integer literal, so its value can be read directly from the expression.
  case Literal(literalSrid: Int, IntegerType) => GeometryType(literalSrid)
  // Otherwise, only the mixed SRID value can be used for the output GEOMETRY type.
  case _ => GeometryType("ANY")
}

/**
 * Derives the GEOGRAPHY type that corresponds to the given SRID expression. An integer literal
 * yields the GEOGRAPHY type for exactly that SRID value; every other expression yields the
 * mixed SRID ("ANY") GEOGRAPHY type.
 */
private def geographyTypeWithSrid(srid: Expression): GeographyType = srid match {
  // The SRID is an integer literal, so its value can be read directly from the expression.
  case Literal(literalSrid: Int, IntegerType) => GeographyType(literalSrid)
  // Otherwise, only the mixed SRID value can be used for the output GEOGRAPHY type.
  case _ => GeographyType("ANY")
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,56 @@ case class ST_Srid(geo: Expression)
override protected def withNewChildInternal(newChild: Expression): ST_Srid =
copy(geo = newChild)
}

/** ST modifier expressions. */

/**
 * Expression that returns a new GEOGRAPHY or GEOMETRY value whose SRID is the specified SRID
 * value. It takes a geospatial value and an integer SRID, and is rewritten into a static
 * invocation of the `stSetSrid` method of `STUtils`.
 */
@ExpressionDescription(
usage = "_FUNC_(geo, srid) - Returns a new GEOGRAPHY or GEOMETRY value whose SRID is " +
"the specified SRID value.",
arguments = """
Arguments:
* geo - A GEOGRAPHY or GEOMETRY value.
* srid - The new SRID value of the geography or geometry.
""",
examples = """
Examples:
> SELECT st_srid(_FUNC_(ST_GeogFromWKB(X'0101000000000000000000F03F0000000000000040'), 4326));
4326
> SELECT st_srid(_FUNC_(ST_GeomFromWKB(X'0101000000000000000000F03F0000000000000040'), 3857));
3857
""",
since = "4.1.0",
group = "st_funcs"
)
case class ST_SetSrid(geo: Expression, srid: Expression)
  extends RuntimeReplaceable
  with ImplicitCastInputTypes
  with BinaryLike[Expression] {

  // The two children of this binary expression: the geospatial value and the SRID.
  override def left: Expression = geo

  override def right: Expression = srid

  override def prettyName: String = "st_setsrid"

  // First argument: a GEOGRAPHY or GEOMETRY value; second argument: an integer SRID.
  override def inputTypes: Seq[AbstractDataType] =
    Seq(TypeCollection(GeographyType, GeometryType), IntegerType)

  override lazy val replacement: Expression = {
    // The result type is derived from the input geospatial type and the SRID expression.
    val resultType = STExpressionUtils.geospatialTypeWithSrid(geo.dataType, srid)
    StaticInvoke(
      classOf[STUtils],
      resultType,
      "stSetSrid",
      Seq(geo, srid),
      returnNullable = false)
  }

  override protected def withNewChildrenInternal(
      newLeft: Expression,
      newRight: Expression): ST_SetSrid = {
    copy(geo = newLeft, srid = newRight)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,17 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
summary = "")
}

/**
 * Builds the ST_INVALID_SRID_VALUE error for an SRID given in its string form. The SRID is
 * passed to the error as the "srid" message parameter.
 */
def stInvalidSridValueError(srid: String): SparkIllegalArgumentException = {
  val params = Map("srid" -> srid)
  new SparkIllegalArgumentException(
    errorClass = "ST_INVALID_SRID_VALUE",
    messageParameters = params)
}

/**
 * Integer overload of the ST_INVALID_SRID_VALUE error builder: renders the SRID as a string
 * and delegates to the string-based variant.
 */
def stInvalidSridValueError(srid: Int): SparkIllegalArgumentException = {
  val sridText = srid.toString
  stInvalidSridValueError(sridText)
}

def withSuggestionIntervalArithmeticOverflowError(
suggestedFunc: String,
context: QueryContext): ArithmeticException = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.util;

import org.apache.spark.SparkIllegalArgumentException;
import org.apache.spark.unsafe.types.GeographyVal;
import org.apache.spark.unsafe.types.GeometryVal;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -110,4 +111,49 @@ void testStSridGeometry() {
assertEquals(testGeometrySrid, STUtils.stSrid(geometryVal));
}

// ST_SetSrid
@Test
void testStSetSridGeography() {
  // SRID values that are expected to be accepted for geographies.
  int[] validGeographySrids = {4326};
  for (int srid : validGeographySrids) {
    GeographyVal input = GeographyVal.fromBytes(testGeographyBytes);
    GeographyVal result = STUtils.stSetSrid(input, srid);
    assertNotNull(result);
    // Decode the returned value and check that it carries the requested SRID.
    int actualSrid = Geography.fromBytes(result.getBytes()).srid();
    assertEquals(srid, actualSrid);
  }
}

@Test
void testStSetSridGeographyInvalidSrid() {
  // SRID values that are expected to be rejected for geographies.
  int[] invalidGeographySrids = {-9999, -2, -1, 0, 1, 2, 3857, 9999};
  for (int srid : invalidGeographySrids) {
    GeographyVal input = GeographyVal.fromBytes(testGeographyBytes);
    SparkIllegalArgumentException error = assertThrows(
      SparkIllegalArgumentException.class, () -> STUtils.stSetSrid(input, srid));
    assertEquals("ST_INVALID_SRID_VALUE", error.getCondition());
    // The error message is expected to mention the rejected SRID value.
    String expectedFragment = "value: " + srid + ".";
    assertTrue(error.getMessage().contains(expectedFragment));
  }
}

@Test
void testStSetSridGeometry() {
  // Each of these SRID values is expected to be accepted for geometries. The loop variable is
  // named after geometries (not geographies) to match what this test actually exercises.
  for (int validGeometrySrid : new int[]{0, 3857, 4326}) {
    GeometryVal geometryVal = GeometryVal.fromBytes(testGeometryBytes);
    GeometryVal updatedGeometryVal = STUtils.stSetSrid(geometryVal, validGeometrySrid);
    assertNotNull(updatedGeometryVal);
    // Decode the returned value and check that it carries the requested SRID.
    Geometry updatedGeometry = Geometry.fromBytes(updatedGeometryVal.getBytes());
    assertEquals(validGeometrySrid, updatedGeometry.srid());
  }
}

@Test
void testStSetSridGeometryInvalidSrid() {
  // SRID values that are expected to be rejected for geometries.
  int[] invalidGeometrySrids = {-9999, -2, -1, 1, 2, 9999};
  for (int srid : invalidGeometrySrids) {
    GeometryVal input = GeometryVal.fromBytes(testGeometryBytes);
    SparkIllegalArgumentException error = assertThrows(
      SparkIllegalArgumentException.class, () -> STUtils.stSetSrid(input, srid));
    assertEquals("ST_INVALID_SRID_VALUE", error.getCondition());
    // The error message is expected to mention the rejected SRID value.
    String expectedFragment = "value: " + srid + ".";
    assertTrue(error.getMessage().contains(expectedFragment));
  }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,7 @@
| org.apache.spark.sql.catalyst.expressions.st.ST_AsBinary | st_asbinary | SELECT hex(st_asbinary(st_geogfromwkb(X'0101000000000000000000F03F0000000000000040'))) | struct<hex(st_asbinary(st_geogfromwkb(X'0101000000000000000000F03F0000000000000040'))):string> |
| org.apache.spark.sql.catalyst.expressions.st.ST_GeogFromWKB | st_geogfromwkb | SELECT hex(st_asbinary(st_geogfromwkb(X'0101000000000000000000F03F0000000000000040'))) | struct<hex(st_asbinary(st_geogfromwkb(X'0101000000000000000000F03F0000000000000040'))):string> |
| org.apache.spark.sql.catalyst.expressions.st.ST_GeomFromWKB | st_geomfromwkb | SELECT hex(st_asbinary(st_geomfromwkb(X'0101000000000000000000F03F0000000000000040'))) | struct<hex(st_asbinary(st_geomfromwkb(X'0101000000000000000000F03F0000000000000040'))):string> |
| org.apache.spark.sql.catalyst.expressions.st.ST_SetSrid | st_setsrid | SELECT st_srid(st_setsrid(ST_GeogFromWKB(X'0101000000000000000000F03F0000000000000040'), 4326)) | struct<st_srid(st_setsrid(st_geogfromwkb(X'0101000000000000000000F03F0000000000000040'), 4326)):int> |
| org.apache.spark.sql.catalyst.expressions.st.ST_Srid | st_srid | SELECT st_srid(st_geogfromwkb(X'0101000000000000000000F03F0000000000000040')) | struct<st_srid(st_geogfromwkb(X'0101000000000000000000F03F0000000000000040')):int> |
| org.apache.spark.sql.catalyst.expressions.variant.IsVariantNull | is_variant_null | SELECT is_variant_null(parse_json('null')) | struct<is_variant_null(parse_json(null)):boolean> |
| org.apache.spark.sql.catalyst.expressions.variant.ParseJsonExpressionBuilder | parse_json | SELECT parse_json('{"a":1,"b":0.8}') | struct<parse_json({"a":1,"b":0.8}):variant> |
Expand Down
Loading