Skip to content

Commit aabcc02

Browse files
committed
Fixes and tests
1 parent 1ce2e29 commit aabcc02

File tree

13 files changed

+180
-343
lines changed

13 files changed

+180
-343
lines changed

python/pyspark/sql/connect/expressions.py

Lines changed: 63 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,54 @@ def _infer_type(cls, value: Any) -> DataType:
377377
def _from_value(cls, value: Any) -> "LiteralExpression":
378378
return LiteralExpression(value=value, dataType=LiteralExpression._infer_type(value))
379379

380+
@classmethod
381+
def _infer_type_from_literal(cls, literal: "proto.Expression.Literal") -> Optional[DataType]:
382+
if literal.HasField("null"):
383+
return NullType()
384+
elif literal.HasField("binary"):
385+
return BinaryType()
386+
elif literal.HasField("boolean"):
387+
return BooleanType()
388+
elif literal.HasField("byte"):
389+
return ByteType()
390+
elif literal.HasField("short"):
391+
return ShortType()
392+
elif literal.HasField("integer"):
393+
return IntegerType()
394+
elif literal.HasField("long"):
395+
return LongType()
396+
elif literal.HasField("float"):
397+
return FloatType()
398+
elif literal.HasField("double"):
399+
return DoubleType()
400+
elif literal.HasField("date"):
401+
return DateType()
402+
elif literal.HasField("timestamp"):
403+
return TimestampType()
404+
elif literal.HasField("timestamp_ntz"):
405+
return TimestampNTZType()
406+
elif literal.HasField("array"):
407+
if len(literal.array.elements) == 0:
408+
if literal.array.HasField("element_type"):
409+
return ArrayType(
410+
proto_schema_to_pyspark_data_type(literal.array.element_type), True
411+
)
412+
raise PySparkTypeError(
413+
errorClass="CANNOT_INFER_ARRAY_ELEMENT_TYPE",
414+
messageParameters={},
415+
)
416+
element_type = LiteralExpression._infer_type_from_literal(literal.array.elements[0])
417+
if element_type is None:
418+
raise PySparkTypeError(
419+
errorClass="CANNOT_INFER_ARRAY_ELEMENT_TYPE",
420+
messageParameters={},
421+
)
422+
return ArrayType(element_type, True)
423+
# Not all data types support inferring the data type from the literal.
424+
# e.g. the type of DayTimeInterval contains extra information like start_field and
425+
# end_field and cannot be inferred from the literal.
426+
return None
427+
380428
@classmethod
381429
def _to_value(
382430
cls, literal: "proto.Expression.Literal", dataType: Optional[DataType] = None
@@ -426,26 +474,19 @@ def _to_value(
426474
assert dataType is None or isinstance(dataType, DayTimeIntervalType)
427475
return DayTimeIntervalType().fromInternal(literal.day_time_interval)
428476
elif literal.HasField("array"):
429-
elements = literal.array.elements
430-
result = []
431-
if dataType is not None:
432-
assert isinstance(dataType, ArrayType)
433-
elementType = dataType.elementType
434-
elif literal.array.HasField("element_type"):
477+
if literal.array.HasField("element_type"):
435478
elementType = proto_schema_to_pyspark_data_type(literal.array.element_type)
436-
elif len(elements) > 0:
437-
result.append(LiteralExpression._to_value(elements[0], None))
438-
elements = elements[1:]
439-
elementType = LiteralExpression._infer_type(result[0])
479+
if dataType is not None:
480+
assert isinstance(dataType, ArrayType)
481+
assert elementType == dataType.elementType
482+
elif len(literal.array.elements) > 0:
483+
elementType = LiteralExpression._infer_type_from_literal(literal.array.elements[0])
440484
else:
441485
raise PySparkTypeError(
442486
errorClass="CANNOT_INFER_ARRAY_ELEMENT_TYPE",
443487
messageParameters={},
444488
)
445-
446-
for element in elements:
447-
result.append(LiteralExpression._to_value(element, elementType))
448-
return result
489+
return [LiteralExpression._to_value(v, elementType) for v in literal.array.elements]
449490

450491
raise PySparkTypeError(
451492
errorClass="UNSUPPORTED_LITERAL",
@@ -490,14 +531,18 @@ def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
490531
elif isinstance(self._dataType, DayTimeIntervalType):
491532
expr.literal.day_time_interval = int(self._value)
492533
elif isinstance(self._dataType, ArrayType):
493-
if len(self._value) == 0:
494-
expr.literal.array.element_type.CopyFrom(
495-
pyspark_types_to_proto_types(self._dataType.elementType)
496-
)
497534
for v in self._value:
498535
expr.literal.array.elements.append(
499536
LiteralExpression(v, self._dataType.elementType).to_plan(session).literal
500537
)
538+
if (
539+
len(self._value) == 0
540+
or LiteralExpression._infer_type_from_literal(expr.literal.array.elements[0])
541+
is None
542+
):
543+
expr.literal.array.element_type.CopyFrom(
544+
pyspark_types_to_proto_types(self._dataType.elementType)
545+
)
501546
else:
502547
raise PySparkTypeError(
503548
errorClass="UNSUPPORTED_DATA_TYPE",

python/pyspark/sql/tests/connect/test_connect_plan.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -972,28 +972,27 @@ def test_column_expressions(self):
972972

973973
def test_literal_expression_with_arrays(self):
974974
l0 = LiteralExpression._from_value(["x", "y", "z"]).to_plan(None).literal
975-
self.assertTrue(l0.array.element_type.HasField("string"))
975+
self.assertFalse(l0.array.element_type.HasField("string"))
976976
self.assertEqual(len(l0.array.elements), 3)
977977
self.assertEqual(l0.array.elements[0].string, "x")
978978
self.assertEqual(l0.array.elements[1].string, "y")
979979
self.assertEqual(l0.array.elements[2].string, "z")
980980

981981
l1 = LiteralExpression._from_value([3, -3]).to_plan(None).literal
982-
self.assertTrue(l1.array.element_type.HasField("integer"))
982+
self.assertFalse(l1.array.element_type.HasField("integer"))
983983
self.assertEqual(len(l1.array.elements), 2)
984984
self.assertEqual(l1.array.elements[0].integer, 3)
985985
self.assertEqual(l1.array.elements[1].integer, -3)
986986

987987
l2 = LiteralExpression._from_value([float("nan"), -3.0, 0.0]).to_plan(None).literal
988-
self.assertTrue(l2.array.element_type.HasField("double"))
988+
self.assertFalse(l2.array.element_type.HasField("double"))
989989
self.assertEqual(len(l2.array.elements), 3)
990990
self.assertTrue(math.isnan(l2.array.elements[0].double))
991991
self.assertEqual(l2.array.elements[1].double, -3.0)
992992
self.assertEqual(l2.array.elements[2].double, 0.0)
993993

994994
l3 = LiteralExpression._from_value([[3, 4], [5, 6, 7]]).to_plan(None).literal
995-
self.assertTrue(l3.array.element_type.HasField("array"))
996-
self.assertTrue(l3.array.element_type.array.element_type.HasField("integer"))
995+
self.assertFalse(l3.array.element_type.HasField("array"))
997996
self.assertEqual(len(l3.array.elements), 2)
998997
self.assertEqual(len(l3.array.elements[0].array.elements), 2)
999998
self.assertEqual(len(l3.array.elements[1].array.elements), 3)
@@ -1003,8 +1002,7 @@ def test_literal_expression_with_arrays(self):
10031002
.to_plan(None)
10041003
.literal
10051004
)
1006-
self.assertTrue(l4.array.element_type.HasField("array"))
1007-
self.assertTrue(l4.array.element_type.array.element_type.HasField("double"))
1005+
self.assertFalse(l4.array.element_type.HasField("array"))
10081006
self.assertEqual(len(l4.array.elements), 3)
10091007
self.assertEqual(len(l4.array.elements[0].array.elements), 2)
10101008
self.assertEqual(len(l4.array.elements[1].array.elements), 2)
@@ -1033,6 +1031,8 @@ def test_literal_to_any_conversion(self):
10331031
]:
10341032
lit = LiteralExpression._from_value(value)
10351033
proto_lit = lit.to_plan(None).literal
1034+
if proto_lit.HasField("array"):
1035+
self.assertFalse(proto_lit.array.HasField("element_type"))
10361036
value2 = LiteralExpression._to_value(proto_lit)
10371037
self.assertEqual(value, value2)
10381038

sql/connect/common/src/main/protobuf/spark/connect/expressions.proto

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -216,17 +216,20 @@ message Expression {
216216

217217
message Array {
218218
// (Optional) The element type of the array. Only need to set this when the elements are
219-
// empty, since spark 4.1+ supports inferring the element type from the elements.
219+
// empty or the element type is not inferrable, since spark 4.1+ supports
220+
// inferring the element type from the elements.
220221
optional DataType element_type = 1;
221222
repeated Literal elements = 2;
222223
}
223224

224225
message Map {
225226
// (Optional) The key type of the map. Only need to set this when the keys are
226-
// empty, since spark 4.1+ supports inferring the key type from the keys
227+
// empty or the key type is not inferrable, since spark 4.1+ supports
228+
// inferring the key type from the keys
227229
optional DataType key_type = 1;
228230
// (Optional) The value type of the map. Only need to set this when the values are
229-
// empty, since spark 4.1+ supports inferring the value type from the values.
231+
// empty or the value type is not inferrable, since spark 4.1+ supports
232+
// inferring the value type from the values.
230233
optional DataType value_type = 2;
231234
repeated Literal keys = 3;
232235
repeated Literal values = 4;

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ object LiteralValueProtoConverter {
6464
def arrayBuilder(array: Array[_]) = {
6565
val ab = builder.getArrayBuilder
6666
array.foreach(x => ab.addElements(toLiteralProto(x)))
67-
if (ab.getElementsCount == 0) {
67+
if (ab.getElementsCount == 0 || getInferredDataType(ab.getElementsList.get(0)).isEmpty) {
6868
ab.setElementType(toConnectProtoType(toDataType(array.getClass.getComponentType)))
6969
}
7070
ab
@@ -129,7 +129,7 @@ object LiteralValueProtoConverter {
129129
throw new IllegalArgumentException(s"literal $other not supported (yet).")
130130
}
131131

132-
if (ab.getElementsCount == 0) {
132+
if (ab.getElementsCount == 0 || getInferredDataType(ab.getElementsList.get(0)).isEmpty) {
133133
ab.setElementType(toConnectProtoType(elementType))
134134
}
135135

@@ -138,8 +138,6 @@ object LiteralValueProtoConverter {
138138

139139
def mapBuilder(scalaValue: Any, keyType: DataType, valueType: DataType) = {
140140
val mb = builder.getMapBuilder
141-
.setKeyType(toConnectProtoType(keyType))
142-
.setValueType(toConnectProtoType(valueType))
143141

144142
scalaValue match {
145143
case map: scala.collection.Map[_, _] =>
@@ -151,11 +149,11 @@ object LiteralValueProtoConverter {
151149
throw new IllegalArgumentException(s"literal $other not supported (yet).")
152150
}
153151

154-
if (mb.getKeysCount == 0) {
152+
if (mb.getKeysCount == 0 || getInferredDataType(mb.getKeysList.get(0)).isEmpty) {
155153
mb.setKeyType(toConnectProtoType(keyType))
156154
}
157155

158-
if (mb.getValuesCount == 0) {
156+
if (mb.getValuesCount == 0 || getInferredDataType(mb.getValuesList.get(0)).isEmpty) {
159157
mb.setValueType(toConnectProtoType(valueType))
160158
}
161159

@@ -370,7 +368,11 @@ object LiteralValueProtoConverter {
370368
}
371369
}
372370

373-
private def getOuterDataType(literal: proto.Expression.Literal): proto.DataType = {
371+
private def getInferredDataType(literal: proto.Expression.Literal): Option[proto.DataType] = {
372+
if (literal.hasNull) {
373+
return Some(literal.getNull)
374+
}
375+
374376
val builder = proto.DataType.newBuilder()
375377
literal.getLiteralTypeCase match {
376378
case proto.Expression.Literal.LiteralTypeCase.BINARY =>
@@ -389,22 +391,14 @@ object LiteralValueProtoConverter {
389391
builder.setFloat(proto.DataType.Float.newBuilder.build())
390392
case proto.Expression.Literal.LiteralTypeCase.DOUBLE =>
391393
builder.setDouble(proto.DataType.Double.newBuilder.build())
392-
case proto.Expression.Literal.LiteralTypeCase.STRING =>
393-
builder.setString(proto.DataType.String.newBuilder.build())
394394
case proto.Expression.Literal.LiteralTypeCase.DATE =>
395395
builder.setDate(proto.DataType.Date.newBuilder.build())
396396
case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP =>
397397
builder.setTimestamp(proto.DataType.Timestamp.newBuilder.build())
398398
case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ =>
399399
builder.setTimestampNtz(proto.DataType.TimestampNTZ.newBuilder.build())
400-
case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL =>
401-
builder.setYearMonthInterval(proto.DataType.YearMonthInterval.newBuilder.build())
402-
case proto.Expression.Literal.LiteralTypeCase.DECIMAL =>
403-
builder.setDecimal(proto.DataType.Decimal.newBuilder.build())
404400
case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL =>
405401
builder.setCalendarInterval(proto.DataType.CalendarInterval.newBuilder.build())
406-
case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL =>
407-
builder.setDayTimeInterval(proto.DataType.DayTimeInterval.newBuilder.build())
408402
case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
409403
// Element type will be inferred from the elements in the array.
410404
builder.setArray(proto.DataType.Array.newBuilder.build())
@@ -413,9 +407,19 @@ object LiteralValueProtoConverter {
413407
builder.setMap(proto.DataType.Map.newBuilder.build())
414408
case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
415409
builder.setStruct(literal.getStruct.getStructType.getStruct)
416-
case _ => throw InvalidPlanInput(s"Unsupported Literal Type: ${literal.getLiteralTypeCase}")
410+
case _ =>
411+
// Not all data types support inferring the data type from the literal.
412+
// e.g. the type of DayTimeInterval contains extra information like start_field and
413+
// end_field and cannot be inferred from the literal.
414+
return None
415+
}
416+
Some(builder.build())
417+
}
418+
419+
private def getInferredDataTypeOrThrow(literal: proto.Expression.Literal): proto.DataType = {
420+
getInferredDataType(literal).getOrElse {
421+
throw InvalidPlanInput(s"Unsupported Literal Type: ${literal.getLiteralTypeCase}")
417422
}
418-
builder.build()
419423
}
420424

421425
def toCatalystArray(
@@ -437,7 +441,7 @@ object LiteralValueProtoConverter {
437441
protoArrayType(array.getElementType)
438442
} else if (iter.hasNext) {
439443
val firstElement = iter.next()
440-
val outerElementType = getOuterDataType(firstElement)
444+
val outerElementType = getInferredDataTypeOrThrow(firstElement)
441445
val (elem, inferredElementType) =
442446
getConverter(outerElementType, inferDataType = true)(firstElement) match {
443447
case LiteralValueWithDataType(elem, dataType) => (elem, dataType)
@@ -479,15 +483,23 @@ object LiteralValueProtoConverter {
479483
protoMapType(map.getKeyType, map.getValueType)
480484
} else if (iter.hasNext) {
481485
val (key, value) = iter.next()
482-
val outerKeyType = getOuterDataType(key)
486+
val (outerKeyType, inferKeyType) = if (map.hasKeyType) {
487+
(map.getKeyType, false)
488+
} else {
489+
(getInferredDataTypeOrThrow(key), true)
490+
}
483491
val (catalystKey, inferredKeyType) =
484-
getConverter(outerKeyType, inferDataType = true)(key) match {
492+
getConverter(outerKeyType, inferDataType = inferKeyType)(key) match {
485493
case LiteralValueWithDataType(catalystKey, dataType) => (catalystKey, dataType)
486494
case catalystKey => (catalystKey, outerKeyType)
487495
}
488-
val outerValueType = getOuterDataType(value)
496+
val (outerValueType, inferValueType) = if (map.hasValueType) {
497+
(map.getValueType, false)
498+
} else {
499+
(getInferredDataTypeOrThrow(value), true)
500+
}
489501
val (catalystValue, inferredValueType) =
490-
getConverter(outerValueType, inferDataType = true)(value) match {
502+
getConverter(outerValueType, inferDataType = inferValueType)(value) match {
491503
case LiteralValueWithDataType(catalystValue, dataType) => (catalystValue, dataType)
492504
case catalystValue => (catalystValue, outerValueType)
493505
}

sql/connect/common/src/test/resources/query-tests/queries/function_lit.json

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -358,10 +358,6 @@
358358
}, {
359359
"literal": {
360360
"array": {
361-
"elementType": {
362-
"integer": {
363-
}
364-
},
365361
"elements": [{
366362
"integer": 8
367363
}, {
Binary file not shown.

0 commit comments

Comments
 (0)