@@ -64,7 +64,7 @@ object LiteralValueProtoConverter {
64
64
def arrayBuilder (array : Array [_]) = {
65
65
val ab = builder.getArrayBuilder
66
66
array.foreach(x => ab.addElements(toLiteralProto(x)))
67
- if (ab.getElementsCount == 0 ) {
67
+ if (ab.getElementsCount == 0 || getInferredDataType(ab.getElementsList.get( 0 )).isEmpty ) {
68
68
ab.setElementType(toConnectProtoType(toDataType(array.getClass.getComponentType)))
69
69
}
70
70
ab
@@ -129,7 +129,7 @@ object LiteralValueProtoConverter {
129
129
throw new IllegalArgumentException (s " literal $other not supported (yet). " )
130
130
}
131
131
132
- if (ab.getElementsCount == 0 ) {
132
+ if (ab.getElementsCount == 0 || getInferredDataType(ab.getElementsList.get( 0 )).isEmpty ) {
133
133
ab.setElementType(toConnectProtoType(elementType))
134
134
}
135
135
@@ -138,8 +138,6 @@ object LiteralValueProtoConverter {
138
138
139
139
def mapBuilder (scalaValue : Any , keyType : DataType , valueType : DataType ) = {
140
140
val mb = builder.getMapBuilder
141
- .setKeyType(toConnectProtoType(keyType))
142
- .setValueType(toConnectProtoType(valueType))
143
141
144
142
scalaValue match {
145
143
case map : scala.collection.Map [_, _] =>
@@ -151,11 +149,11 @@ object LiteralValueProtoConverter {
151
149
throw new IllegalArgumentException (s " literal $other not supported (yet). " )
152
150
}
153
151
154
- if (mb.getKeysCount == 0 ) {
152
+ if (mb.getKeysCount == 0 || getInferredDataType(mb.getKeysList.get( 0 )).isEmpty ) {
155
153
mb.setKeyType(toConnectProtoType(keyType))
156
154
}
157
155
158
- if (mb.getValuesCount == 0 ) {
156
+ if (mb.getValuesCount == 0 || getInferredDataType(mb.getValuesList.get( 0 )).isEmpty ) {
159
157
mb.setValueType(toConnectProtoType(valueType))
160
158
}
161
159
@@ -370,7 +368,11 @@ object LiteralValueProtoConverter {
370
368
}
371
369
}
372
370
373
- private def getOuterDataType (literal : proto.Expression .Literal ): proto.DataType = {
371
+ private def getInferredDataType (literal : proto.Expression .Literal ): Option [proto.DataType ] = {
372
+ if (literal.hasNull) {
373
+ return Some (literal.getNull)
374
+ }
375
+
374
376
val builder = proto.DataType .newBuilder()
375
377
literal.getLiteralTypeCase match {
376
378
case proto.Expression .Literal .LiteralTypeCase .BINARY =>
@@ -389,22 +391,14 @@ object LiteralValueProtoConverter {
389
391
builder.setFloat(proto.DataType .Float .newBuilder.build())
390
392
case proto.Expression .Literal .LiteralTypeCase .DOUBLE =>
391
393
builder.setDouble(proto.DataType .Double .newBuilder.build())
392
- case proto.Expression .Literal .LiteralTypeCase .STRING =>
393
- builder.setString(proto.DataType .String .newBuilder.build())
394
394
case proto.Expression .Literal .LiteralTypeCase .DATE =>
395
395
builder.setDate(proto.DataType .Date .newBuilder.build())
396
396
case proto.Expression .Literal .LiteralTypeCase .TIMESTAMP =>
397
397
builder.setTimestamp(proto.DataType .Timestamp .newBuilder.build())
398
398
case proto.Expression .Literal .LiteralTypeCase .TIMESTAMP_NTZ =>
399
399
builder.setTimestampNtz(proto.DataType .TimestampNTZ .newBuilder.build())
400
- case proto.Expression .Literal .LiteralTypeCase .YEAR_MONTH_INTERVAL =>
401
- builder.setYearMonthInterval(proto.DataType .YearMonthInterval .newBuilder.build())
402
- case proto.Expression .Literal .LiteralTypeCase .DECIMAL =>
403
- builder.setDecimal(proto.DataType .Decimal .newBuilder.build())
404
400
case proto.Expression .Literal .LiteralTypeCase .CALENDAR_INTERVAL =>
405
401
builder.setCalendarInterval(proto.DataType .CalendarInterval .newBuilder.build())
406
- case proto.Expression .Literal .LiteralTypeCase .DAY_TIME_INTERVAL =>
407
- builder.setDayTimeInterval(proto.DataType .DayTimeInterval .newBuilder.build())
408
402
case proto.Expression .Literal .LiteralTypeCase .ARRAY =>
409
403
// Element type will be inferred from the elements in the array.
410
404
builder.setArray(proto.DataType .Array .newBuilder.build())
@@ -413,9 +407,19 @@ object LiteralValueProtoConverter {
413
407
builder.setMap(proto.DataType .Map .newBuilder.build())
414
408
case proto.Expression .Literal .LiteralTypeCase .STRUCT =>
415
409
builder.setStruct(literal.getStruct.getStructType.getStruct)
416
- case _ => throw InvalidPlanInput (s " Unsupported Literal Type: ${literal.getLiteralTypeCase}" )
410
+ case _ =>
411
+ // Not all data types support inferring the data type from the literal.
412
+ // e.g. the type of DayTimeInterval contains extra information like start_field and
413
+ // end_field and cannot be inferred from the literal.
414
+ return None
415
+ }
416
+ Some (builder.build())
417
+ }
418
+
419
+ private def getInferredDataTypeOrThrow (literal : proto.Expression .Literal ): proto.DataType = {
420
+ getInferredDataType(literal).getOrElse {
421
+ throw InvalidPlanInput (s " Unsupported Literal Type: ${literal.getLiteralTypeCase}" )
417
422
}
418
- builder.build()
419
423
}
420
424
421
425
def toCatalystArray (
@@ -437,7 +441,7 @@ object LiteralValueProtoConverter {
437
441
protoArrayType(array.getElementType)
438
442
} else if (iter.hasNext) {
439
443
val firstElement = iter.next()
440
- val outerElementType = getOuterDataType (firstElement)
444
+ val outerElementType = getInferredDataTypeOrThrow (firstElement)
441
445
val (elem, inferredElementType) =
442
446
getConverter(outerElementType, inferDataType = true )(firstElement) match {
443
447
case LiteralValueWithDataType (elem, dataType) => (elem, dataType)
@@ -479,15 +483,23 @@ object LiteralValueProtoConverter {
479
483
protoMapType(map.getKeyType, map.getValueType)
480
484
} else if (iter.hasNext) {
481
485
val (key, value) = iter.next()
482
- val outerKeyType = getOuterDataType(key)
486
+ val (outerKeyType, inferKeyType) = if (map.hasKeyType) {
487
+ (map.getKeyType, false )
488
+ } else {
489
+ (getInferredDataTypeOrThrow(key), true )
490
+ }
483
491
val (catalystKey, inferredKeyType) =
484
- getConverter(outerKeyType, inferDataType = true )(key) match {
492
+ getConverter(outerKeyType, inferDataType = inferKeyType )(key) match {
485
493
case LiteralValueWithDataType (catalystKey, dataType) => (catalystKey, dataType)
486
494
case catalystKey => (catalystKey, outerKeyType)
487
495
}
488
- val outerValueType = getOuterDataType(value)
496
+ val (outerValueType, inferValueType) = if (map.hasValueType) {
497
+ (map.getValueType, false )
498
+ } else {
499
+ (getInferredDataTypeOrThrow(value), true )
500
+ }
489
501
val (catalystValue, inferredValueType) =
490
- getConverter(outerValueType, inferDataType = true )(value) match {
502
+ getConverter(outerValueType, inferDataType = inferValueType )(value) match {
491
503
case LiteralValueWithDataType (catalystValue, dataType) => (catalystValue, dataType)
492
504
case catalystValue => (catalystValue, outerValueType)
493
505
}
0 commit comments