diff --git a/spark-bigquery-connector-common/src/test/java/com/google/cloud/spark/bigquery/integration/ReadByFormatIntegrationTestBase.java b/spark-bigquery-connector-common/src/test/java/com/google/cloud/spark/bigquery/integration/ReadByFormatIntegrationTestBase.java index d89049fe3..53e17bed8 100644 --- a/spark-bigquery-connector-common/src/test/java/com/google/cloud/spark/bigquery/integration/ReadByFormatIntegrationTestBase.java +++ b/spark-bigquery-connector-common/src/test/java/com/google/cloud/spark/bigquery/integration/ReadByFormatIntegrationTestBase.java @@ -16,6 +16,7 @@ package com.google.cloud.spark.bigquery.integration; import static com.google.common.truth.Truth.assertThat; +import static org.apache.spark.sql.functions.*; import static org.hamcrest.CoreMatchers.is; import static org.junit.Assume.assumeThat; import static org.junit.Assume.assumeTrue; @@ -42,6 +43,8 @@ import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Encoders; import org.apache.spark.sql.Row; +import org.apache.spark.sql.expressions.Window; +import org.apache.spark.sql.expressions.WindowSpec; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; @@ -316,6 +319,31 @@ public void testTimestampNTZReadFromBigQuery() { assertThat(row.get(0)).isEqualTo(dateTime); } + @Test + public void testWindowFunctionPartitionBy() { + WindowSpec windowSpec = + Window.partitionBy( + concat( + col("fullVisitorId"), col("visitStartTime"), col("trafficSource.referralPath"))) + .orderBy(lit("window_ordering")); + Dataset dataset = + spark + .read() + .format("bigquery") + .option("table", TestConstants.ANALYTICS_TABLE) + .option("readDataFormat", dataFormat) + .load() + .withColumn("row_num", row_number().over(windowSpec)); + + assertThat(dataset.columns().length).isEqualTo(17); + assertThat( + Arrays.stream(dataset.schema().fields()) + .filter(field -> field.name().equals("row_num")) + .count()) + .isEqualTo(1); + assertThat(dataset.head().get(16)).isEqualTo(1); + } + static Map scalaMapToJavaMap(scala.collection.Map map) { ImmutableMap.Builder result = ImmutableMap.builder(); map.foreach(entry -> result.put(entry._1(), entry._2())); diff --git a/spark-bigquery-connector-common/src/test/java/com/google/cloud/spark/bigquery/integration/TestConstants.java b/spark-bigquery-connector-common/src/test/java/com/google/cloud/spark/bigquery/integration/TestConstants.java index eb8fbffdd..d8af6a00b 100644 --- a/spark-bigquery-connector-common/src/test/java/com/google/cloud/spark/bigquery/integration/TestConstants.java +++ b/spark-bigquery-connector-common/src/test/java/com/google/cloud/spark/bigquery/integration/TestConstants.java @@ -133,6 +133,8 @@ public class TestConstants { static DataType BQ_NUMERIC = DataTypes.createDecimalType(38, 9); static DataType BQ_BIGNUMERIC = DataTypes.createDecimalType(38, 38); public static int BIG_NUMERIC_COLUMN_POSITION = 11; + static final String ANALYTICS_TABLE = + "bigquery-public-data.google_analytics_sample.ga_sessions_20170801"; public static StructType ALL_TYPES_TABLE_SCHEMA = new StructType(