
Commit 021e7d0

moving all Seq references to the Scala-specific SparkSqlUtils implementation so it wouldn't leak
1 parent ba2bbeb commit 021e7d0

17 files changed: +166 −154 lines

spark-bigquery-connector-common/src/main/java/com/google/cloud/spark/bigquery/ProtobufUtils.java

Lines changed: 21 additions & 3 deletions
@@ -22,7 +22,6 @@
 import com.google.cloud.bigquery.storage.v1.ProtoRows;
 import com.google.cloud.bigquery.storage.v1.ProtoSchema;
 import com.google.cloud.bigquery.storage.v1.ProtoSchemaConverter;
-import com.google.cloud.spark.bigquery.util.ScalaUtils;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableMap;
@@ -35,9 +34,28 @@
 import org.apache.spark.bigquery.BigNumericUDT;
 import org.apache.spark.bigquery.BigQueryDataTypes;
 import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSqlUtils;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.util.ArrayData;
-import org.apache.spark.sql.types.*;
+import org.apache.spark.sql.types.ArrayType;
+import org.apache.spark.sql.types.BinaryType;
+import org.apache.spark.sql.types.BooleanType;
+import org.apache.spark.sql.types.ByteType;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.DateType;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.types.DecimalType;
+import org.apache.spark.sql.types.DoubleType;
+import org.apache.spark.sql.types.FloatType;
+import org.apache.spark.sql.types.IntegerType;
+import org.apache.spark.sql.types.LongType;
+import org.apache.spark.sql.types.MapType;
+import org.apache.spark.sql.types.ShortType;
+import org.apache.spark.sql.types.StringType;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.types.TimestampType;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import scala.collection.mutable.IndexedSeq;
@@ -354,7 +372,7 @@ private static Object convertSparkValueToProtoRowValue(
     if (sparkType instanceof StructType) {
       InternalRow internalRow = null;
       if (sparkValue instanceof Row) {
-        internalRow = ScalaUtils.getInstance().rowToInternalRow((Row) sparkValue);
+        internalRow = SparkSqlUtils.getInstance().rowToInternalRow((Row) sparkValue);
       } else {
         internalRow = (InternalRow) sparkValue;
       }
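
Note: the call pattern this hunk switches to can be exercised on its own. A minimal, hypothetical sketch (not part of the commit), assuming Spark SQL and a connector jar that registers a SparkSqlUtils provider are on the classpath; it mirrors the unit test further down, where the Row already carries catalyst-encoded values such as UTF8String:

import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSqlUtils;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericRow;
import org.apache.spark.unsafe.types.UTF8String;

public class RowConversionSketch {
  public static void main(String[] args) {
    // An external Row holding already catalyst-encoded values, as in the connector's test.
    Row row = new GenericRow(new Object[] {UTF8String.fromString("a"), 1});
    // The Scala-version-specific implementation is resolved via ServiceLoader, so no
    // scala.collection Seq type appears at this Java call site.
    InternalRow internalRow = SparkSqlUtils.getInstance().rowToInternalRow(row);
    System.out.println(internalRow.numFields()); // prints 2
  }
}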

spark-bigquery-connector-common/src/main/java/com/google/cloud/spark/bigquery/util/Scala213Utils.java

Lines changed: 0 additions & 39 deletions
This file was deleted.
spark-bigquery-connector-common/src/main/java/com/google/cloud/spark/bigquery/util/ScalaVersionSupport.java

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+package com.google.cloud.spark.bigquery.util;
+
+public class ScalaVersionSupport {}

spark-bigquery-connector-common/src/main/java/com/google/cloud/spark/bigquery/write/DataSourceWriterContextPartitionHandler.java

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,5 @@
 package com.google.cloud.spark.bigquery.write;
 
-import com.google.cloud.spark.bigquery.util.ScalaUtils;
 import com.google.cloud.spark.bigquery.write.context.DataWriterContext;
 import com.google.cloud.spark.bigquery.write.context.DataWriterContextFactory;
 import com.google.cloud.spark.bigquery.write.context.WriterCommitMessageContext;
@@ -11,6 +10,7 @@
 import org.apache.spark.TaskContext;
 import org.apache.spark.api.java.function.Function2;
 import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSqlUtils;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -42,7 +42,7 @@ public Iterator<WriterCommitMessageContext> call(Integer partitionId, Iterator<R
     try {
       while (rowIterator.hasNext()) {
         Row row = rowIterator.next();
-        InternalRow internalRow = ScalaUtils.getInstance().rowToInternalRow(row);
+        InternalRow internalRow = SparkSqlUtils.getInstance().rowToInternalRow(row);
         dataWriterContext.write(internalRow);
       }
       return Iterators.forArray(dataWriterContext.commit());
spark-bigquery-connector-common/src/main/java/org/apache/spark/sql/Scala213SparkSqlUtils.java

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+/*
+ * Copyright 2022 Google Inc. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *       http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql;
+
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer$;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.sql.catalyst.encoders.RowEncoder;
+import org.apache.spark.sql.catalyst.expressions.Attribute;
+import org.apache.spark.sql.catalyst.expressions.AttributeReference;
+import org.apache.spark.sql.catalyst.expressions.NamedExpression;
+import org.apache.spark.sql.types.StructType;
+import scala.collection.JavaConverters;
+import scala.collection.mutable.ListBuffer;
+
+public class Scala213SparkSqlUtils extends SparkSqlUtils {
+
+  @Override
+  public boolean supportsScalaVersion(String scalaVersion) {
+    return scalaVersion.compareTo("2.13") > 0;
+  }
+
+  @Override
+  public InternalRow rowToInternalRow(Row row) {
+    return InternalRow.fromSeq(row.toSeq());
+  }
+
+  @Override
+  public ExpressionEncoder<Row> createExpressionEncoder(StructType schema) {
+    List<Attribute> attributes =
+        JavaConverters.asJavaCollection(toAttributes(schema)).stream()
+            .map(Attribute::toAttribute)
+            .collect(Collectors.toList());
+    ExpressionEncoder<Row> expressionEncoder =
+        RowEncoder.apply(schema)
+            .resolveAndBind(
+                JavaConverters.asScalaIteratorConverter(attributes.iterator()).asScala().toSeq(),
+                SimpleAnalyzer$.MODULE$);
+    return expressionEncoder;
+  }
+
+  // `toAttributes` is protected[sql] starting spark 3.2.0, so we need this call to be in the same
+  // package. Since Scala 2.13/Spark 3.3 forbids it, the implementation has been ported to Java
+  public static scala.collection.immutable.Seq<AttributeReference> toAttributes(StructType schema) {
+    List<AttributeReference> result =
+        Stream.of(schema.fields())
+            .map(
+                field ->
+                    new AttributeReference(
+                        field.name(),
+                        field.dataType(),
+                        field.nullable(),
+                        field.metadata(),
+                        NamedExpression.newExprId(),
+                        new ListBuffer<String>().toSeq()))
+            .collect(Collectors.toList());
+    return JavaConverters.asScalaBuffer(result).toSeq();
+  }
+}
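
A hedged usage sketch (not part of the commit) of createExpressionEncoder, assuming a Spark 3.x runtime where RowEncoder.apply still exists and ExpressionEncoder exposes createSerializer(); the schema and values are illustrative only:

import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSqlUtils;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class EncoderSketch {
  public static void main(String[] args) {
    StructType schema =
        new StructType()
            .add("name", DataTypes.StringType)
            .add("age", DataTypes.IntegerType);
    // Resolved-and-bound encoder built by the Scala-2.13-specific implementation above.
    ExpressionEncoder<Row> encoder = SparkSqlUtils.getInstance().createExpressionEncoder(schema);
    // createSerializer() turns an external Row into a catalyst InternalRow.
    InternalRow internal = encoder.createSerializer().apply(RowFactory.create("a", 1));
    System.out.println(internal.numFields()); // prints 2
  }
}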
spark-bigquery-connector-common/src/main/java/org/apache/spark/sql/SparkSqlUtils.java

Lines changed: 10 additions & 10 deletions
@@ -13,21 +13,21 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package com.google.cloud.spark.bigquery.util;
+package org.apache.spark.sql;
 
 import com.google.common.collect.Streams;
 import java.util.ServiceLoader;
-import org.apache.spark.sql.Row;
 import org.apache.spark.sql.catalyst.InternalRow;
-import scala.collection.immutable.Seq;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.sql.types.StructType;
 
-public abstract class ScalaUtils {
-  private static ScalaUtils instance;
+public abstract class SparkSqlUtils {
+  private static SparkSqlUtils instance;
 
-  public static ScalaUtils getInstance() {
+  public static SparkSqlUtils getInstance() {
     String scalaVersion = scala.util.Properties.versionNumberString();
     if (instance == null) {
-      ServiceLoader<ScalaUtils> serviceLoader = ServiceLoader.load(ScalaUtils.class);
+      ServiceLoader<SparkSqlUtils> serviceLoader = ServiceLoader.load(SparkSqlUtils.class);
       instance =
           Streams.stream(serviceLoader.iterator())
               .filter(s -> s.supportsScalaVersion(scalaVersion))
@@ -37,14 +37,14 @@ public static ScalaUtils getInstance() {
                       new IllegalArgumentException(
                           String.format(
                               "Could not load instance of [%s], please check the META-INF/services directory in the connector's jar",
-                              ScalaUtils.class.getCanonicalName())));
+                              SparkSqlUtils.class.getCanonicalName())));
     }
     return instance;
  }
 
   public abstract boolean supportsScalaVersion(String scalaVersion);
 
-  public abstract <T> Seq<T> emptySeq(Class<T> clazz);
-
   public abstract InternalRow rowToInternalRow(Row row);
+
+  public abstract ExpressionEncoder<Row> createExpressionEncoder(StructType schema);
 }
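
getInstance() follows the standard java.util.ServiceLoader contract: each jar lists its implementation's fully qualified class name, one per line, in a resource named META-INF/services/<service FQCN>, and the provider whose supportsScalaVersion accepts the runtime Scala version is chosen. A standalone illustration of that contract with a hypothetical Greeter service (not connector code):

import java.util.ServiceLoader;

public class ServiceLoaderSketch {
  // Hypothetical service interface, analogous to SparkSqlUtils.
  public interface Greeter {
    boolean supports(String locale);
    String greet(String name);
  }

  public static void main(String[] args) {
    // Provider jars would list implementations in META-INF/services/ServiceLoaderSketch$Greeter,
    // just as the connector lists Scala213SparkSqlUtils under
    // META-INF/services/org.apache.spark.sql.SparkSqlUtils (see the services file below).
    ServiceLoader<Greeter> loader = ServiceLoader.load(Greeter.class);
    for (Greeter greeter : loader) {
      if (greeter.supports("en")) { // mirrors the supportsScalaVersion(...) filter above
        System.out.println(greeter.greet("world"));
        break;
      }
    }
  }
}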

spark-bigquery-connector-common/src/main/resources/META-INF/services/com.google.cloud.spark.bigquery.util.ScalaUtils

Lines changed: 0 additions & 1 deletion
This file was deleted.
spark-bigquery-connector-common/src/main/resources/META-INF/services/org.apache.spark.sql.SparkSqlUtils

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+org.apache.spark.sql.Scala213SparkSqlUtils
spark-bigquery-connector-common/src/test/java/org/apache/spark/sql/Scala213SparkSqlUtilsTest.java

Lines changed: 5 additions & 15 deletions
@@ -13,33 +13,23 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package com.google.cloud.spark.bigquery.util;
+package org.apache.spark.sql;
 
 import static com.google.common.truth.Truth.assertThat;
 
-import org.apache.spark.sql.Row;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.GenericRow;
 import org.apache.spark.unsafe.types.UTF8String;
 import org.junit.Test;
-import scala.collection.immutable.Seq;
 
-public class Scala213UtilsTest {
-
-  @Test
-  public void testRowToSeq() throws Exception {
-    ScalaUtils su = ScalaUtils.getInstance();
-    assertThat(su).isInstanceOf(Scala213Utils.class);
-    Seq<Object> seq = su.emptySeq(Object.class);
-    assertThat(seq.size()).isEqualTo(0);
-  }
+public class Scala213SparkSqlUtilsTest {
 
   @Test
   public void testRowToInternalRow() throws Exception {
-    ScalaUtils su = ScalaUtils.getInstance();
-    assertThat(su).isInstanceOf(Scala213Utils.class);
+    SparkSqlUtils ssu = SparkSqlUtils.getInstance();
+    assertThat(ssu).isInstanceOf(Scala213SparkSqlUtils.class);
     Row row = new GenericRow(new Object[] {UTF8String.fromString("a"), 1});
-    InternalRow internalRow = su.rowToInternalRow(row);
+    InternalRow internalRow = ssu.rowToInternalRow(row);
     assertThat(internalRow.numFields()).isEqualTo(2);
     assertThat(internalRow.getString(0).toString()).isEqualTo("a");
     assertThat(internalRow.getInt(1)).isEqualTo(1);

spark-bigquery-connector-common/third_party/apache-spark/src/main/java/com/google/cloud/spark/bigquery/AvroSchemaConverter.java

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,6 @@
  */
 package com.google.cloud.spark.bigquery;
 
-import com.google.cloud.spark.bigquery.util.ScalaUtils;
 import com.google.common.base.Preconditions;
 import org.apache.avro.Conversions;
 import org.apache.avro.LogicalTypes;
@@ -24,6 +23,7 @@
 import org.apache.avro.generic.GenericData;
 import org.apache.avro.util.Utf8;
 import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSqlUtils;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.SpecializedGetters;
 import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
@@ -306,7 +306,7 @@ static Converter createConverterFor(DataType sparkType, Schema avroType) {
       } else {
         Object obj = getter.get(ordinal, /* unused */ null);
         if (obj instanceof Row) {
-          internalRow = ScalaUtils.getInstance().rowToInternalRow((Row) obj);
+          internalRow = SparkSqlUtils.getInstance().rowToInternalRow((Row) obj);
         } else {
           internalRow = (InternalRow) obj;
         }
