[SPARK-52153][SQL] Fix from_json and to_json with variant #50901

Open · wants to merge 1 commit into base: master
@@ -316,9 +316,15 @@ static void toJsonImpl(byte[] value, byte[] metadata, int pos, StringBuilder sb,
case STRING:
sb.append(escapeJson(VariantUtil.getString(value, pos)));
break;
case DOUBLE:
sb.append(VariantUtil.getDouble(value, pos));
case DOUBLE: {
double d = VariantUtil.getDouble(value, pos);
if (Double.isFinite(d)) {
sb.append(d);
} else {
appendQuoted(sb, Double.toString(d));
}
break;
}
case DECIMAL:
sb.append(VariantUtil.getDecimal(value, pos).toPlainString());
break;
@@ -333,9 +339,15 @@
appendQuoted(sb, TIMESTAMP_NTZ_FORMATTER.format(
microsToInstant(VariantUtil.getLong(value, pos)).atZone(ZoneOffset.UTC)));
break;
case FLOAT:
sb.append(VariantUtil.getFloat(value, pos));
case FLOAT: {
float f = VariantUtil.getFloat(value, pos);
if (Float.isFinite(f)) {
sb.append(f);
} else {
appendQuoted(sb, Float.toString(f));
}
break;
}
case BINARY:
appendQuoted(sb, Base64.getEncoder().encodeToString(VariantUtil.getBinary(value, pos)));
break;
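The change above matters because bare NaN and Infinity are not legal JSON literals: emitting them unquoted would make to_json produce output that a standards-compliant JSON parser (including Spark's own JSON reader) rejects. Below is a minimal spark-shell sketch of the intended round-trip, not part of the PR; it assumes this patch is applied and spark.implicits._ is in scope, and it mirrors the allowNonNumericNumbers test added further down.

// Parse a non-standard NaN literal into a variant, then serialize it back.
// allowNonNumericNumbers lets Jackson accept NaN on input; on output, to_json
// now wraps the non-finite value in quotes so the result remains valid JSON.
val df = Seq("""{"a": NaN}""").toDF("v")
df.selectExpr(
  "to_json(from_json(v, 'variant', map('allowNonNumericNumbers', 'true'))) AS j"
).show(truncate = false)
// expected: {"a":"NaN"}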
@@ -26,7 +26,6 @@ import com.fasterxml.jackson.core.json.JsonReadFeature
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{ExprUtils, GenericInternalRow}
import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils
import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonGenerator, JacksonParser, JsonInferSchema, JSONOptions}
import org.apache.spark.sql.catalyst.util.{ArrayData, FailFastMode, FailureSafeParser, MapData, PermissiveMode}
import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -123,6 +122,8 @@ case class JsonToStructsEvaluator(
(rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getArray(0) else null
case _: MapType =>
(rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getMap(0) else null
case _: VariantType =>
(rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getVariant(0) else null
}

@transient
@@ -152,13 +153,7 @@ case class JsonToStructsEvaluator(

final def evaluate(json: UTF8String): Any = {
if (json == null) return null
nullableSchema match {
case _: VariantType =>
VariantExpressionEvalUtils.parseJson(json,
allowDuplicateKeys = variantAllowDuplicateKeys)
case _ =>
converter(parser.parse(json))
}
converter(parser.parse(json))
}
}

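With the VariantType special case removed here, from_json(_, 'variant') no longer calls VariantExpressionEvalUtils.parseJson directly; it flows through the same JacksonParser/FailureSafeParser pipeline as struct, array, and map schemas, so JSON options and parse modes now apply to variant as well. A hedged spark-shell sketch of the consequence, assuming this patch, spark.implicits._ in scope, and SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS set to false as in the duplicate-key test further down:

// A record the parser rejects yields null in the default PERMISSIVE mode and an
// error in FAILFAST mode, instead of always throwing as before.
val bad = Seq("""{"a": 1, "a": 2}""").toDF("j")              // duplicate key "a"
bad.selectExpr("from_json(j, 'variant')").show()             // PERMISSIVE: prints null
bad.selectExpr("from_json(j, 'variant', map('mode', 'FAILFAST'))").collect()
// FAILFAST: throws SparkException (MALFORMED_RECORD_IN_PARSING),
// caused by VARIANT_DUPLICATE_KEY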
@@ -108,6 +108,9 @@ class JacksonParser(
*/
private def makeRootConverter(dt: DataType): JsonParser => Iterable[InternalRow] = {
dt match {
case _: VariantType => (parser: JsonParser) => {
Some(InternalRow(parseVariant(parser)))
}
case _: StructType if options.singleVariantColumn.isDefined => (parser: JsonParser) => {
Some(InternalRow(parseVariant(parser)))
}
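This converter addition handles a top-level 'variant' schema by wrapping the parsed variant in a single-column InternalRow, the same shape already used by the singleVariantColumn case just below it. Downstream, the result behaves like any other variant value. A short spark-shell sketch, assuming this patch and spark.implicits._ in scope; variant_get is an existing Spark SQL function used only for illustration and is not part of this diff:

// Parse a top-level JSON value into a variant, then extract a field from it.
Seq("""{"x": [1, 2, 3]}""").toDF("j")
  .selectExpr("variant_get(from_json(j, 'variant'), '$.x[1]', 'int') AS second")
  .show()
// expected: 2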
@@ -30,7 +30,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarArray
import org.apache.spark.types.variant.VariantBuilder
import org.apache.spark.types.variant.VariantUtil._
import org.apache.spark.unsafe.types.VariantVal
import org.apache.spark.unsafe.types.{UTF8String, VariantVal}

class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
import testImplicits._
@@ -101,6 +101,26 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
)
// scalastyle:on nonascii
check("[0.0, 1.00, 1.10, 1.23]", "[0,1,1.1,1.23]")

// Validate options work.
checkAnswer(Seq("""{"a": NaN}""").toDF("v")
.selectExpr("from_json(v, 'variant', map('allowNonNumericNumbers', 'false'))"), Row(null))
checkAnswer(Seq("""{"a": NaN}""").toDF("v")
.selectExpr("from_json(v, 'variant', map('allowNonNumericNumbers', 'true'))"),
Row(
VariantExpressionEvalUtils.castToVariant(InternalRow(Double.NaN),
StructType.fromDDL("a double"))))
// String input "NaN" will remain a string instead of double.
checkAnswer(Seq("""{"a": "NaN"}""").toDF("v")
.selectExpr("from_json(v, 'variant', map('allowNonNumericNumbers', 'true'))"),
Row(
VariantExpressionEvalUtils.castToVariant(InternalRow(UTF8String.fromString("NaN")),
StructType.fromDDL("a string"))))
// to_json should put special floating point values in quotes.
checkAnswer(Seq("""{"a": NaN}""").toDF("v")
.selectExpr("to_json(from_json(v, 'variant', map('allowNonNumericNumbers', 'true')))"),
Row("""{"a":"NaN"}"""))

}

test("try_parse_json/to_json round-trip") {
@@ -346,6 +366,7 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {

test("from_json(_, 'variant') with duplicate keys") {
val json: String = """{"a": 1, "b": 2, "c": "3", "a": 4}"""

withSQLConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS.key -> "true") {
val df = Seq(json).toDF("j")
.selectExpr("from_json(j,'variant')")
@@ -359,24 +380,25 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
val expectedMetadata: Array[Byte] = Array(VERSION, 3, 0, 1, 2, 3, 'a', 'b', 'c')
assert(actual === new VariantVal(expectedValue, expectedMetadata))
}
// Check whether the parse_json and from_json expressions throw the correct exception.
Seq("from_json(j, 'variant')", "parse_json(j)").foreach { expr =>
withSQLConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS.key -> "false") {
val df = Seq(json).toDF("j").selectExpr(expr)
val exception = intercept[SparkException] {
df.collect()
}
checkError(
exception = exception,
condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION",
parameters = Map("badRecord" -> json, "failFastMode" -> "FAILFAST")
)
checkError(
exception = exception.getCause.asInstanceOf[SparkRuntimeException],
condition = "VARIANT_DUPLICATE_KEY",
parameters = Map("key" -> "a")
)

withSQLConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS.key -> "false") {
// In default mode (PERMISSIVE), JSON with duplicate keys is still invalid, but no error will
// be thrown.
checkAnswer(Seq(json).toDF("j").selectExpr("from_json(j, 'variant')"), Row(null))

val exception = intercept[SparkException] {
Seq(json).toDF("j").selectExpr("from_json(j, 'variant', map('mode', 'FAILFAST'))").collect()
}
checkError(
exception = exception,
condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION",
parameters = Map("badRecord" -> "[null]", "failFastMode" -> "FAILFAST")
)
checkError(
exception = exception.getCause.asInstanceOf[SparkRuntimeException],
condition = "VARIANT_DUPLICATE_KEY",
parameters = Map("key" -> "a")
)
}
}
