diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java b/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java
index fe6809d4aeeca..0e2c4816a3c1a 100644
--- a/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java
+++ b/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java
@@ -316,9 +316,15 @@ static void toJsonImpl(byte[] value, byte[] metadata, int pos, StringBuilder sb,
       case STRING:
         sb.append(escapeJson(VariantUtil.getString(value, pos)));
         break;
-      case DOUBLE:
-        sb.append(VariantUtil.getDouble(value, pos));
+      case DOUBLE: {
+        double d = VariantUtil.getDouble(value, pos);
+        if (Double.isFinite(d)) {
+          sb.append(d);
+        } else {
+          appendQuoted(sb, Double.toString(d));
+        }
         break;
+      }
       case DECIMAL:
         sb.append(VariantUtil.getDecimal(value, pos).toPlainString());
         break;
@@ -333,9 +339,15 @@ static void toJsonImpl(byte[] value, byte[] metadata, int pos, StringBuilder sb,
         appendQuoted(sb, TIMESTAMP_NTZ_FORMATTER.format(
           microsToInstant(VariantUtil.getLong(value, pos)).atZone(ZoneOffset.UTC)));
         break;
-      case FLOAT:
-        sb.append(VariantUtil.getFloat(value, pos));
+      case FLOAT: {
+        float f = VariantUtil.getFloat(value, pos);
+        if (Float.isFinite(f)) {
+          sb.append(f);
+        } else {
+          appendQuoted(sb, Float.toString(f));
+        }
         break;
+      }
       case BINARY:
         appendQuoted(sb, Base64.getEncoder().encodeToString(VariantUtil.getBinary(value, pos)));
         break;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala
index 641f22ba3f786..b942006e87e9d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala
@@ -26,7 +26,6 @@ import com.fasterxml.jackson.core.json.JsonReadFeature
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{ExprUtils, GenericInternalRow}
-import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils
 import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonGenerator, JacksonParser, JsonInferSchema, JSONOptions}
 import org.apache.spark.sql.catalyst.util.{ArrayData, FailFastMode, FailureSafeParser, MapData, PermissiveMode}
 import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -123,6 +122,8 @@ case class JsonToStructsEvaluator(
       (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getArray(0) else null
     case _: MapType =>
       (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getMap(0) else null
+    case _: VariantType =>
+      (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getVariant(0) else null
   }
 
   @transient
@@ -152,13 +153,7 @@ case class JsonToStructsEvaluator(
 
   final def evaluate(json: UTF8String): Any = {
     if (json == null) return null
-    nullableSchema match {
-      case _: VariantType =>
-        VariantExpressionEvalUtils.parseJson(json,
-          allowDuplicateKeys = variantAllowDuplicateKeys)
-      case _ =>
-        converter(parser.parse(json))
-    }
+    converter(parser.parse(json))
   }
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
index 1cd4b4cd29bcf..849946f6a2a55 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
@@ -108,6 +108,9 @@ class JacksonParser(
    */
   private def makeRootConverter(dt: DataType): JsonParser => Iterable[InternalRow] = {
     dt match {
+      case _: VariantType => (parser: JsonParser) => {
+        Some(InternalRow(parseVariant(parser)))
+      }
       case _: StructType if options.singleVariantColumn.isDefined => (parser: JsonParser) => {
         Some(InternalRow(parseVariant(parser)))
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
index cf83ae30a6798..a40e34d94d085 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarArray
 import org.apache.spark.types.variant.VariantBuilder
 import org.apache.spark.types.variant.VariantUtil._
-import org.apache.spark.unsafe.types.VariantVal
+import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
 
 class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
   import testImplicits._
@@ -101,6 +101,26 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
     )
     // scalastyle:on nonascii
     check("[0.0, 1.00, 1.10, 1.23]", "[0,1,1.1,1.23]")
+
+    // Validate options work.
+    checkAnswer(Seq("""{"a": NaN}""").toDF("v")
+      .selectExpr("from_json(v, 'variant', map('allowNonNumericNumbers', 'false'))"), Row(null))
+    checkAnswer(Seq("""{"a": NaN}""").toDF("v")
+      .selectExpr("from_json(v, 'variant', map('allowNonNumericNumbers', 'true'))"),
+      Row(
+        VariantExpressionEvalUtils.castToVariant(InternalRow(Double.NaN),
+          StructType.fromDDL("a double"))))
+    // String input "NaN" will remain a string instead of double.
+    checkAnswer(Seq("""{"a": "NaN"}""").toDF("v")
+      .selectExpr("from_json(v, 'variant', map('allowNonNumericNumbers', 'true'))"),
+      Row(
+        VariantExpressionEvalUtils.castToVariant(InternalRow(UTF8String.fromString("NaN")),
+          StructType.fromDDL("a string"))))
+    // to_json should put special floating point values in quotes.
+    checkAnswer(Seq("""{"a": NaN}""").toDF("v")
+      .selectExpr("to_json(from_json(v, 'variant', map('allowNonNumericNumbers', 'true')))"),
+      Row("""{"a":"NaN"}"""))
+
   }
 
   test("try_parse_json/to_json round-trip") {
@@ -346,6 +366,7 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
 
   test("from_json(_, 'variant') with duplicate keys") {
     val json: String = """{"a": 1, "b": 2, "c": "3", "a": 4}"""
+
     withSQLConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS.key -> "true") {
       val df = Seq(json).toDF("j")
         .selectExpr("from_json(j,'variant')")
@@ -359,24 +380,25 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
       val expectedMetadata: Array[Byte] = Array(VERSION, 3, 0, 1, 2, 3, 'a', 'b', 'c')
       assert(actual === new VariantVal(expectedValue, expectedMetadata))
     }
-    // Check whether the parse_json and from_json expressions throw the correct exception.
-    Seq("from_json(j, 'variant')", "parse_json(j)").foreach { expr =>
-      withSQLConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS.key -> "false") {
-        val df = Seq(json).toDF("j").selectExpr(expr)
-        val exception = intercept[SparkException] {
-          df.collect()
-        }
-        checkError(
-          exception = exception,
-          condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION",
-          parameters = Map("badRecord" -> json, "failFastMode" -> "FAILFAST")
-        )
-        checkError(
-          exception = exception.getCause.asInstanceOf[SparkRuntimeException],
-          condition = "VARIANT_DUPLICATE_KEY",
-          parameters = Map("key" -> "a")
-        )
+
+    withSQLConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS.key -> "false") {
+      // In default mode (PERMISSIVE), JSON with duplicate keys is still invalid, but no error will
+      // be thrown.
+      checkAnswer(Seq(json).toDF("j").selectExpr("from_json(j, 'variant')"), Row(null))
+
+      val exception = intercept[SparkException] {
+        Seq(json).toDF("j").selectExpr("from_json(j, 'variant', map('mode', 'FAILFAST'))").collect()
      }
+      checkError(
+        exception = exception,
+        condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION",
+        parameters = Map("badRecord" -> "[null]", "failFastMode" -> "FAILFAST")
+      )
+      checkError(
+        exception = exception.getCause.asInstanceOf[SparkRuntimeException],
+        condition = "VARIANT_DUPLICATE_KEY",
+        parameters = Map("key" -> "a")
+      )
    }
  }