Skip to content

Commit c579c1c

Browse files
committed
Fixes and tests
1 parent 34cd6f9 commit c579c1c

File tree

15 files changed

+250
-442
lines changed

15 files changed

+250
-442
lines changed

python/pyspark/sql/connect/expressions.py

Lines changed: 64 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,52 @@ def _infer_type(cls, value: Any) -> DataType:
377377
def _from_value(cls, value: Any) -> "LiteralExpression":
378378
return LiteralExpression(value=value, dataType=LiteralExpression._infer_type(value))
379379

380+
@classmethod
381+
def _infer_type_from_literal(cls, literal: "proto.Expression.Literal") -> Optional[DataType]:
382+
if literal.HasField("null"):
383+
return NullType()
384+
elif literal.HasField("binary"):
385+
return BinaryType()
386+
elif literal.HasField("boolean"):
387+
return BooleanType()
388+
elif literal.HasField("byte"):
389+
return ByteType()
390+
elif literal.HasField("short"):
391+
return ShortType()
392+
elif literal.HasField("integer"):
393+
return IntegerType()
394+
elif literal.HasField("long"):
395+
return LongType()
396+
elif literal.HasField("float"):
397+
return FloatType()
398+
elif literal.HasField("double"):
399+
return DoubleType()
400+
elif literal.HasField("date"):
401+
return DateType()
402+
elif literal.HasField("timestamp"):
403+
return TimestampType()
404+
elif literal.HasField("timestamp_ntz"):
405+
return TimestampNTZType()
406+
elif literal.HasField("array"):
407+
if literal.array.HasField("element_type"):
408+
return ArrayType(
409+
proto_schema_to_pyspark_data_type(literal.array.element_type), True
410+
)
411+
element_type = None
412+
if len(literal.array.elements) > 0:
413+
element_type = LiteralExpression._infer_type_from_literal(literal.array.elements[0])
414+
415+
if element_type is None:
416+
raise PySparkTypeError(
417+
errorClass="CANNOT_INFER_ARRAY_ELEMENT_TYPE",
418+
messageParameters={},
419+
)
420+
return ArrayType(element_type, True)
421+
# Not all data types support inferring the data type from the literal at the moment.
422+
# e.g. the type of DayTimeInterval contains extra information like start_field and
423+
# end_field and cannot be inferred from the literal.
424+
return None
425+
380426
@classmethod
381427
def _to_value(
382428
cls, literal: "proto.Expression.Literal", dataType: Optional[DataType] = None
@@ -426,26 +472,21 @@ def _to_value(
426472
assert dataType is None or isinstance(dataType, DayTimeIntervalType)
427473
return DayTimeIntervalType().fromInternal(literal.day_time_interval)
428474
elif literal.HasField("array"):
429-
elements = literal.array.elements
430-
result = []
431-
if dataType is not None:
432-
assert isinstance(dataType, ArrayType)
433-
elementType = dataType.elementType
434-
elif literal.array.HasField("element_type"):
475+
elementType = None
476+
if literal.array.HasField("element_type"):
435477
elementType = proto_schema_to_pyspark_data_type(literal.array.element_type)
436-
elif len(elements) > 0:
437-
result.append(LiteralExpression._to_value(elements[0], None))
438-
elements = elements[1:]
439-
elementType = LiteralExpression._infer_type(result[0])
440-
else:
478+
if dataType is not None:
479+
assert isinstance(dataType, ArrayType)
480+
assert elementType == dataType.elementType
481+
elif len(literal.array.elements) > 0:
482+
elementType = LiteralExpression._infer_type_from_literal(literal.array.elements[0])
483+
484+
if elementType is None:
441485
raise PySparkTypeError(
442486
errorClass="CANNOT_INFER_ARRAY_ELEMENT_TYPE",
443487
messageParameters={},
444488
)
445-
446-
for element in elements:
447-
result.append(LiteralExpression._to_value(element, elementType))
448-
return result
489+
return [LiteralExpression._to_value(v, elementType) for v in literal.array.elements]
449490

450491
raise PySparkTypeError(
451492
errorClass="UNSUPPORTED_LITERAL",
@@ -490,14 +531,18 @@ def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
490531
elif isinstance(self._dataType, DayTimeIntervalType):
491532
expr.literal.day_time_interval = int(self._value)
492533
elif isinstance(self._dataType, ArrayType):
493-
if len(self._value) == 0:
494-
expr.literal.array.element_type.CopyFrom(
495-
pyspark_types_to_proto_types(self._dataType.elementType)
496-
)
497534
for v in self._value:
498535
expr.literal.array.elements.append(
499536
LiteralExpression(v, self._dataType.elementType).to_plan(session).literal
500537
)
538+
if (
539+
len(self._value) == 0
540+
or LiteralExpression._infer_type_from_literal(expr.literal.array.elements[0])
541+
is None
542+
):
543+
expr.literal.array.element_type.CopyFrom(
544+
pyspark_types_to_proto_types(self._dataType.elementType)
545+
)
501546
else:
502547
raise PySparkTypeError(
503548
errorClass="UNSUPPORTED_DATA_TYPE",

python/pyspark/sql/connect/proto/expressions_pb2.py

Lines changed: 60 additions & 60 deletions
Large diffs are not rendered by default.

python/pyspark/sql/connect/proto/expressions_pb2.pyi

Lines changed: 9 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,8 @@ class Expression(google.protobuf.message.Message):
477477
@property
478478
def element_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
479479
"""(Optional) The element type of the array. Only need to set this when the elements are
480-
empty, since spark 4.1+ supports inferring the element type from the elements.
480+
empty or the element type is not inferable, since spark 4.1+ supports
481+
inferring the element type from the elements.
481482
"""
482483
@property
483484
def elements(
@@ -492,25 +493,14 @@ class Expression(google.protobuf.message.Message):
492493
elements: collections.abc.Iterable[global___Expression.Literal] | None = ...,
493494
) -> None: ...
494495
def HasField(
495-
self,
496-
field_name: typing_extensions.Literal[
497-
"_element_type", b"_element_type", "element_type", b"element_type"
498-
],
496+
self, field_name: typing_extensions.Literal["element_type", b"element_type"]
499497
) -> builtins.bool: ...
500498
def ClearField(
501499
self,
502500
field_name: typing_extensions.Literal[
503-
"_element_type",
504-
b"_element_type",
505-
"element_type",
506-
b"element_type",
507-
"elements",
508-
b"elements",
501+
"element_type", b"element_type", "elements", b"elements"
509502
],
510503
) -> None: ...
511-
def WhichOneof(
512-
self, oneof_group: typing_extensions.Literal["_element_type", b"_element_type"]
513-
) -> typing_extensions.Literal["element_type"] | None: ...
514504

515505
class Map(google.protobuf.message.Message):
516506
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -522,12 +512,14 @@ class Expression(google.protobuf.message.Message):
522512
@property
523513
def key_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
524514
"""(Optional) The key type of the map. Only need to set this when the keys are
525-
empty, since spark 4.1+ supports inferring the key type from the keys
515+
empty or the key type is not inferable, since spark 4.1+ supports
516+
inferring the key type from the keys.
526517
"""
527518
@property
528519
def value_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
529520
"""(Optional) The value type of the map. Only need to set this when the values are
530-
empty, since spark 4.1+ supports inferring the value type from the values.
521+
empty or the value type is not inferable, since spark 4.1+ supports
522+
inferring the value type from the values.
531523
"""
532524
@property
533525
def keys(
@@ -552,23 +544,12 @@ class Expression(google.protobuf.message.Message):
552544
def HasField(
553545
self,
554546
field_name: typing_extensions.Literal[
555-
"_key_type",
556-
b"_key_type",
557-
"_value_type",
558-
b"_value_type",
559-
"key_type",
560-
b"key_type",
561-
"value_type",
562-
b"value_type",
547+
"key_type", b"key_type", "value_type", b"value_type"
563548
],
564549
) -> builtins.bool: ...
565550
def ClearField(
566551
self,
567552
field_name: typing_extensions.Literal[
568-
"_key_type",
569-
b"_key_type",
570-
"_value_type",
571-
b"_value_type",
572553
"key_type",
573554
b"key_type",
574555
"keys",
@@ -579,14 +560,6 @@ class Expression(google.protobuf.message.Message):
579560
b"values",
580561
],
581562
) -> None: ...
582-
@typing.overload
583-
def WhichOneof(
584-
self, oneof_group: typing_extensions.Literal["_key_type", b"_key_type"]
585-
) -> typing_extensions.Literal["key_type"] | None: ...
586-
@typing.overload
587-
def WhichOneof(
588-
self, oneof_group: typing_extensions.Literal["_value_type", b"_value_type"]
589-
) -> typing_extensions.Literal["value_type"] | None: ...
590563

591564
class Struct(google.protobuf.message.Message):
592565
DESCRIPTOR: google.protobuf.descriptor.Descriptor

python/pyspark/sql/tests/connect/test_connect_plan.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -979,21 +979,20 @@ def test_literal_expression_with_arrays(self):
979979
self.assertEqual(l0.array.elements[2].string, "z")
980980

981981
l1 = LiteralExpression._from_value([3, -3]).to_plan(None).literal
982-
self.assertTrue(l1.array.element_type.HasField("integer"))
982+
self.assertFalse(l1.array.element_type.HasField("integer"))
983983
self.assertEqual(len(l1.array.elements), 2)
984984
self.assertEqual(l1.array.elements[0].integer, 3)
985985
self.assertEqual(l1.array.elements[1].integer, -3)
986986

987987
l2 = LiteralExpression._from_value([float("nan"), -3.0, 0.0]).to_plan(None).literal
988-
self.assertTrue(l2.array.element_type.HasField("double"))
988+
self.assertFalse(l2.array.element_type.HasField("double"))
989989
self.assertEqual(len(l2.array.elements), 3)
990990
self.assertTrue(math.isnan(l2.array.elements[0].double))
991991
self.assertEqual(l2.array.elements[1].double, -3.0)
992992
self.assertEqual(l2.array.elements[2].double, 0.0)
993993

994994
l3 = LiteralExpression._from_value([[3, 4], [5, 6, 7]]).to_plan(None).literal
995-
self.assertTrue(l3.array.element_type.HasField("array"))
996-
self.assertTrue(l3.array.element_type.array.element_type.HasField("integer"))
995+
self.assertFalse(l3.array.element_type.HasField("array"))
997996
self.assertEqual(len(l3.array.elements), 2)
998997
self.assertEqual(len(l3.array.elements[0].array.elements), 2)
999998
self.assertEqual(len(l3.array.elements[1].array.elements), 3)
@@ -1003,8 +1002,7 @@ def test_literal_expression_with_arrays(self):
10031002
.to_plan(None)
10041003
.literal
10051004
)
1006-
self.assertTrue(l4.array.element_type.HasField("array"))
1007-
self.assertTrue(l4.array.element_type.array.element_type.HasField("double"))
1005+
self.assertFalse(l4.array.element_type.HasField("array"))
10081006
self.assertEqual(len(l4.array.elements), 3)
10091007
self.assertEqual(len(l4.array.elements[0].array.elements), 2)
10101008
self.assertEqual(len(l4.array.elements[1].array.elements), 2)

sql/connect/common/src/main/protobuf/spark/connect/expressions.proto

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -216,18 +216,21 @@ message Expression {
216216

217217
message Array {
218218
// (Optional) The element type of the array. Only need to set this when the elements are
219-
// empty, since spark 4.1+ supports inferring the element type from the elements.
220-
optional DataType element_type = 1;
219+
// empty or the element type is not inferable, since spark 4.1+ supports
220+
// inferring the element type from the elements.
221+
DataType element_type = 1;
221222
repeated Literal elements = 2;
222223
}
223224

224225
message Map {
225226
// (Optional) The key type of the map. Only need to set this when the keys are
226-
// empty, since spark 4.1+ supports inferring the key type from the keys
227-
optional DataType key_type = 1;
227+
// empty or the key type is not inferable, since spark 4.1+ supports
228+
// inferring the key type from the keys
229+
DataType key_type = 1;
228230
// (Optional) The value type of the map. Only need to set this when the values are
229-
// empty, since spark 4.1+ supports inferring the value type from the values.
230-
optional DataType value_type = 2;
231+
// empty or the value type is not inferable, since spark 4.1+ supports
232+
// inferring the value type from the values.
233+
DataType value_type = 2;
231234
repeated Literal keys = 3;
232235
repeated Literal values = 4;
233236
}

0 commit comments

Comments
 (0)