[SPARK-52449][CONNECT][PYTHON][ML] Make datatypes for Expression.Literal.Map/Array optional #51473

Open · wants to merge 1 commit into master
80 changes: 71 additions & 9 deletions python/pyspark/sql/connect/expressions.py
@@ -29,9 +29,9 @@
    Optional,
)

-import json
-import decimal
import datetime
+import decimal
+import json
import warnings
from threading import Lock

@@ -377,6 +377,52 @@ def _infer_type(cls, value: Any) -> DataType:
    def _from_value(cls, value: Any) -> "LiteralExpression":
        return LiteralExpression(value=value, dataType=LiteralExpression._infer_type(value))

    @classmethod
    def _infer_type_from_literal(cls, literal: "proto.Expression.Literal") -> Optional[DataType]:
        if literal.HasField("null"):
            return NullType()
        elif literal.HasField("binary"):
            return BinaryType()
        elif literal.HasField("boolean"):
            return BooleanType()
        elif literal.HasField("byte"):
            return ByteType()
        elif literal.HasField("short"):
            return ShortType()
        elif literal.HasField("integer"):
            return IntegerType()
        elif literal.HasField("long"):
            return LongType()
        elif literal.HasField("float"):
            return FloatType()
        elif literal.HasField("double"):
            return DoubleType()
        elif literal.HasField("date"):
            return DateType()
        elif literal.HasField("timestamp"):
            return TimestampType()
        elif literal.HasField("timestamp_ntz"):
            return TimestampNTZType()
        elif literal.HasField("array"):
            if literal.array.HasField("element_type"):
                return ArrayType(
                    proto_schema_to_pyspark_data_type(literal.array.element_type), True
                )
            element_type = None
            if len(literal.array.elements) > 0:
                element_type = LiteralExpression._infer_type_from_literal(literal.array.elements[0])

            if element_type is None:
                raise PySparkTypeError(
                    errorClass="CANNOT_INFER_ARRAY_ELEMENT_TYPE",
                    messageParameters={},
                )
            return ArrayType(element_type, True)
        # Not all data types support inferring the data type from the literal at the moment,
        # e.g. the type of DayTimeInterval contains extra information like start_field and
        # end_field and cannot be inferred from the literal.
        return None

    @classmethod
    def _to_value(
        cls, literal: "proto.Expression.Literal", dataType: Optional[DataType] = None
@@ -426,10 +472,20 @@ def _to_value
            assert dataType is None or isinstance(dataType, DayTimeIntervalType)
            return DayTimeIntervalType().fromInternal(literal.day_time_interval)
        elif literal.HasField("array"):
-            elementType = proto_schema_to_pyspark_data_type(literal.array.element_type)
-            if dataType is not None:
-                assert isinstance(dataType, ArrayType)
-                assert elementType == dataType.elementType
+            elementType = None
+            if literal.array.HasField("element_type"):
+                elementType = proto_schema_to_pyspark_data_type(literal.array.element_type)
+                if dataType is not None:
+                    assert isinstance(dataType, ArrayType)
+                    assert elementType == dataType.elementType
+            elif len(literal.array.elements) > 0:
+                elementType = LiteralExpression._infer_type_from_literal(literal.array.elements[0])
+
+            if elementType is None:
+                raise PySparkTypeError(
+                    errorClass="CANNOT_INFER_ARRAY_ELEMENT_TYPE",
+                    messageParameters={},
+                )
            return [LiteralExpression._to_value(v, elementType) for v in literal.array.elements]

        raise PySparkTypeError(
@@ -475,11 +531,17 @@ def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
        elif isinstance(self._dataType, DayTimeIntervalType):
            expr.literal.day_time_interval = int(self._value)
        elif isinstance(self._dataType, ArrayType):
-            element_type = self._dataType.elementType
-            expr.literal.array.element_type.CopyFrom(pyspark_types_to_proto_types(element_type))
            for v in self._value:
                expr.literal.array.elements.append(
-                    LiteralExpression(v, element_type).to_plan(session).literal
+                    LiteralExpression(v, self._dataType.elementType).to_plan(session).literal
                )
+            if (
+                len(self._value) == 0
+                or LiteralExpression._infer_type_from_literal(expr.literal.array.elements[0])
+                is None
+            ):
+                expr.literal.array.element_type.CopyFrom(
+                    pyspark_types_to_proto_types(self._dataType.elementType)
+                )
        else:
            raise PySparkTypeError(
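Taken together, the serializer and deserializer changes let the element type round-trip without being spelled out on the wire. Below is a minimal sketch of the new behavior, adapted from the updated tests in this PR (which pass None as the session); the assertions mirror this diff rather than any documented public API:

    from pyspark.sql.connect.expressions import LiteralExpression
    from pyspark.sql.types import ArrayType, IntegerType

    # Writer side: IntegerType is inferable from the first element, so
    # to_plan() no longer sets array.element_type.
    lit = LiteralExpression._from_value([3, -3]).to_plan(None).literal
    assert not lit.array.HasField("element_type")

    # Reader side: _to_value() falls back to _infer_type_from_literal on the
    # first element and still recovers the Python values.
    assert LiteralExpression._to_value(lit) == [3, -3]

    # An empty array has nothing to infer from, so element_type is still written.
    empty = LiteralExpression([], ArrayType(IntegerType())).to_plan(None).literal
    assert empty.array.HasField("element_type")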
18 changes: 15 additions & 3 deletions python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -475,7 +475,11 @@ class Expression(google.protobuf.message.Message):
            ELEMENT_TYPE_FIELD_NUMBER: builtins.int
            ELEMENTS_FIELD_NUMBER: builtins.int
            @property
-            def element_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
+            def element_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
+                """(Optional) The element type of the array. This only needs to be set when
+                the elements are empty or the element type is not inferable, since Spark 4.1+
+                supports inferring the element type from the elements.
+                """
            @property
            def elements(
                self,
@@ -506,9 +510,17 @@
            KEYS_FIELD_NUMBER: builtins.int
            VALUES_FIELD_NUMBER: builtins.int
            @property
-            def key_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
+            def key_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
+                """(Optional) The key type of the map. This only needs to be set when the
+                keys are empty or the key type is not inferable, since Spark 4.1+ supports
+                inferring the key type from the keys.
+                """
            @property
-            def value_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
+            def value_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
+                """(Optional) The value type of the map. This only needs to be set when the
+                values are empty or the value type is not inferable, since Spark 4.1+ supports
+                inferring the value type from the values.
+                """
            @property
            def keys(
                self,
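Because the stubs now mark these fields as optional, consumers of the generated classes should check presence before converting. A small illustrative helper (hypothetical, not part of this PR) using proto_schema_to_pyspark_data_type from pyspark.sql.connect.types:

    from typing import Optional

    import pyspark.sql.connect.proto as proto
    from pyspark.sql.connect.types import proto_schema_to_pyspark_data_type
    from pyspark.sql.types import DataType

    def explicit_element_type(arr: "proto.Expression.Literal.Array") -> Optional[DataType]:
        # element_type is only guaranteed to be present when the elements
        # alone cannot determine it (e.g. an empty array).
        if arr.HasField("element_type"):
            return proto_schema_to_pyspark_data_type(arr.element_type)
        return None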
10 changes: 4 additions & 6 deletions python/pyspark/sql/tests/connect/test_connect_plan.py
@@ -979,21 +979,20 @@ def test_literal_expression_with_arrays(self):
        self.assertEqual(l0.array.elements[2].string, "z")

        l1 = LiteralExpression._from_value([3, -3]).to_plan(None).literal
-        self.assertTrue(l1.array.element_type.HasField("integer"))
+        self.assertFalse(l1.array.element_type.HasField("integer"))
        self.assertEqual(len(l1.array.elements), 2)
        self.assertEqual(l1.array.elements[0].integer, 3)
        self.assertEqual(l1.array.elements[1].integer, -3)

        l2 = LiteralExpression._from_value([float("nan"), -3.0, 0.0]).to_plan(None).literal
-        self.assertTrue(l2.array.element_type.HasField("double"))
+        self.assertFalse(l2.array.element_type.HasField("double"))
        self.assertEqual(len(l2.array.elements), 3)
        self.assertTrue(math.isnan(l2.array.elements[0].double))
        self.assertEqual(l2.array.elements[1].double, -3.0)
        self.assertEqual(l2.array.elements[2].double, 0.0)

        l3 = LiteralExpression._from_value([[3, 4], [5, 6, 7]]).to_plan(None).literal
-        self.assertTrue(l3.array.element_type.HasField("array"))
-        self.assertTrue(l3.array.element_type.array.element_type.HasField("integer"))
+        self.assertFalse(l3.array.element_type.HasField("array"))
        self.assertEqual(len(l3.array.elements), 2)
        self.assertEqual(len(l3.array.elements[0].array.elements), 2)
        self.assertEqual(len(l3.array.elements[1].array.elements), 3)
@@ -1003,8 +1002,7 @@
            .to_plan(None)
            .literal
        )
-        self.assertTrue(l4.array.element_type.HasField("array"))
-        self.assertTrue(l4.array.element_type.array.element_type.HasField("double"))
+        self.assertFalse(l4.array.element_type.HasField("array"))
        self.assertEqual(len(l4.array.elements), 3)
        self.assertEqual(len(l4.array.elements[0].array.elements), 2)
        self.assertEqual(len(l4.array.elements[1].array.elements), 2)
sql/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -215,12 +215,21 @@
    }

    message Array {
+     // (Optional) The element type of the array. This only needs to be set when the
+     // elements are empty or the element type is not inferable, since Spark 4.1+
+     // supports inferring the element type from the elements.
      DataType element_type = 1;
      repeated Literal elements = 2;
zhengruifeng (Contributor) commented on Jul 17, 2025:

I am not sure whether it is worthwhile to just optimize out the element_type. For large arrays of primitive types, e.g. a large dense matrix for ML, we introduced SpecializedArray.

Contributor:

For an Array[Array[Int]] case, how do we infer the nullability?

heyihong (Contributor, Author) commented on Jul 17, 2025:

You mean the nullable field is missing in the array literal? I was thinking of deprecating element_type and introducing a new DataType.Array field so that each array literal includes the nullable field within DataType.Array, for example:

    message Array {
      DataType element_type = 1 [deprecated = true];
      repeated Literal elements = 2;
      DataType.Array data_type_array = 3;
    }
heyihong (Contributor, Author) commented on Jul 17, 2025:

@zhengruifeng This change optimizes out both arrays and maps, and also applies to non-primitive types. Also, the reduction in size of function_lit_array.json is noticeable.

heyihong (Contributor, Author):

I created a separate ticket to track the Protobuf message change: https://issues.apache.org/jira/browse/SPARK-52930

    }

    message Map {
+     // (Optional) The key type of the map. This only needs to be set when the keys
+     // are empty or the key type is not inferable, since Spark 4.1+ supports
+     // inferring the key type from the keys.
      DataType key_type = 1;
+     // (Optional) The value type of the map. This only needs to be set when the
+     // values are empty or the value type is not inferable, since Spark 4.1+
+     // supports inferring the value type from the values.
      DataType value_type = 2;
      repeated Literal keys = 3;
      repeated Literal values = 4;
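The same presence convention applies to maps, although this PR's Python client changes only exercise arrays. A hypothetical sketch against the generated messages, building a map literal {1: "a"} whose key and value types a Spark 4.1+ writer could leave unset:

    import pyspark.sql.connect.proto as proto

    lit = proto.Expression.Literal()
    key = lit.map.keys.add()      # repeated Literal keys = 3
    key.integer = 1
    value = lit.map.values.add()  # repeated Literal values = 4
    value.string = "a"

    # Both types are inferable from the first entry, so key_type and
    # value_type may stay unset; an empty map would still require them.
    assert not lit.map.HasField("key_type")
    assert not lit.map.HasField("value_type")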