From 314ff9ba5bd89e3476e7d4079edc418514c9faf3 Mon Sep 17 00:00:00 2001 From: Andrei Ionescu Date: Wed, 24 Feb 2021 15:50:51 +0200 Subject: [PATCH 1/3] Add support for building index on nested fields --- .../hyperspace/actions/CreateAction.scala | 10 +- .../hyperspace/actions/CreateActionBase.scala | 37 +- .../hyperspace/index/IndexConstants.scala | 3 + .../hyperspace/index/IndexLogEntry.scala | 5 + .../index/rules/FilterIndexRule.scala | 113 +++- .../hyperspace/index/rules/RuleUtils.scala | 120 ++++- .../FileBasedSourceProviderManager.scala | 40 +- .../hyperspace/util/SchemaUtils.scala | 60 +++ .../hyperspace/SampleNestedData.scala | 66 +++ .../index/CreateIndexNestedTest.scala | 196 +++++++ .../index/HybridScanForNestedFieldsTest.scala | 447 ++++++++++++++++ .../index/RefreshIndexNestedTest.scala | 498 ++++++++++++++++++ .../hyperspace/util/SchemaUtilsTest.scala | 203 +++++++ 13 files changed, 1766 insertions(+), 32 deletions(-) create mode 100644 src/main/scala/com/microsoft/hyperspace/util/SchemaUtils.scala create mode 100644 src/test/scala/com/microsoft/hyperspace/SampleNestedData.scala create mode 100644 src/test/scala/com/microsoft/hyperspace/index/CreateIndexNestedTest.scala create mode 100644 src/test/scala/com/microsoft/hyperspace/index/HybridScanForNestedFieldsTest.scala create mode 100644 src/test/scala/com/microsoft/hyperspace/index/RefreshIndexNestedTest.scala create mode 100644 src/test/scala/com/microsoft/hyperspace/util/SchemaUtilsTest.scala diff --git a/src/main/scala/com/microsoft/hyperspace/actions/CreateAction.scala b/src/main/scala/com/microsoft/hyperspace/actions/CreateAction.scala index b81deed86..48c682750 100644 --- a/src/main/scala/com/microsoft/hyperspace/actions/CreateAction.scala +++ b/src/main/scala/com/microsoft/hyperspace/actions/CreateAction.scala @@ -25,7 +25,7 @@ import com.microsoft.hyperspace.{Hyperspace, HyperspaceException} import com.microsoft.hyperspace.actions.Constants.States.{ACTIVE, CREATING, DOESNOTEXIST} import com.microsoft.hyperspace.index._ import com.microsoft.hyperspace.telemetry.{AppInfo, CreateActionEvent, HyperspaceEvent} -import com.microsoft.hyperspace.util.ResolverUtils +import com.microsoft.hyperspace.util.{ResolverUtils, SchemaUtils} class CreateAction( spark: SparkSession, @@ -65,9 +65,15 @@ class CreateAction( } private def isValidIndexSchema(config: IndexConfig, schema: StructType): Boolean = { + // Flatten the schema to support nested fields + val fields = SchemaUtils.escapeFieldNames(SchemaUtils.flatten(schema)) // Resolve index config columns from available column names present in the schema. 
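+    // Both the configured columns and the flattened schema are compared in escaped form,
+    // e.g. "nested.leaf.cnt" on either side is matched as "nested__leaf__cnt".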
ResolverUtils - .resolve(spark, config.indexedColumns ++ config.includedColumns, schema.fieldNames) + .resolve( + spark, + SchemaUtils.escapeFieldNames(config.indexedColumns) + ++ SchemaUtils.escapeFieldNames(config.includedColumns), + fields) .isDefined } diff --git a/src/main/scala/com/microsoft/hyperspace/actions/CreateActionBase.scala b/src/main/scala/com/microsoft/hyperspace/actions/CreateActionBase.scala index 08d875e37..732f9ce27 100644 --- a/src/main/scala/com/microsoft/hyperspace/actions/CreateActionBase.scala +++ b/src/main/scala/com/microsoft/hyperspace/actions/CreateActionBase.scala @@ -19,13 +19,15 @@ package com.microsoft.hyperspace.actions import org.apache.hadoop.fs.Path import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.plans.logical.LeafNode -import org.apache.spark.sql.functions.input_file_name +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.functions.{col, input_file_name} +import org.apache.spark.sql.types.StructType import com.microsoft.hyperspace.{Hyperspace, HyperspaceException} import com.microsoft.hyperspace.index._ import com.microsoft.hyperspace.index.DataFrameWriterExtensions.Bucketizer import com.microsoft.hyperspace.index.sources.FileBasedRelation -import com.microsoft.hyperspace.util.{HyperspaceConf, PathUtils, ResolverUtils} +import com.microsoft.hyperspace.util.{HyperspaceConf, PathUtils, ResolverUtils, SchemaUtils} /** * CreateActionBase provides functionality to write dataframe as covering index. @@ -73,7 +75,8 @@ private[actions] abstract class CreateActionBase(dataManager: IndexDataManager) LogicalPlanFingerprint.Properties(Seq(Signature(signatureProvider.name, s))))) val coveringIndexProperties = - (hasLineageProperty(spark) ++ hasParquetAsSourceFormatProperty(relation)).toMap + (hasLineageProperty(spark) ++ hasParquetAsSourceFormatProperty(relation) ++ + usesNestedFieldsProperty(indexConfig)).toMap IndexLogEntry( indexConfig.indexName, @@ -109,6 +112,14 @@ private[actions] abstract class CreateActionBase(dataManager: IndexDataManager) } } + private def usesNestedFieldsProperty(indexConfig: IndexConfig): Option[(String, String)] = { + if (SchemaUtils.hasNestedFields(indexConfig.indexedColumns ++ indexConfig.includedColumns)) { + Some(IndexConstants.USES_NESTED_FIELDS_PROPERTY -> "true") + } else { + None + } + } + protected def write(spark: SparkSession, df: DataFrame, indexConfig: IndexConfig): Unit = { val numBuckets = numBucketsForIndex(spark) @@ -117,7 +128,7 @@ private[actions] abstract class CreateActionBase(dataManager: IndexDataManager) // run job val repartitionedIndexDataFrame = - indexDataFrame.repartition(numBuckets, resolvedIndexedColumns.map(df(_)): _*) + indexDataFrame.repartition(numBuckets, resolvedIndexedColumns.map(c => col(s"$c")): _*) // Save the index with the number of buckets specified. 
repartitionedIndexDataFrame.write @@ -144,9 +155,9 @@ private[actions] abstract class CreateActionBase(dataManager: IndexDataManager) df: DataFrame, indexConfig: IndexConfig): (Seq[String], Seq[String]) = { val spark = df.sparkSession - val dfColumnNames = df.schema.fieldNames - val indexedColumns = indexConfig.indexedColumns - val includedColumns = indexConfig.includedColumns + val dfColumnNames = SchemaUtils.flatten(df.schema) + val indexedColumns = SchemaUtils.unescapeFieldNames(indexConfig.indexedColumns) + val includedColumns = SchemaUtils.unescapeFieldNames(indexConfig.includedColumns) val resolvedIndexedColumns = ResolverUtils.resolve(spark, indexedColumns, dfColumnNames) val resolvedIncludedColumns = ResolverUtils.resolve(spark, includedColumns, dfColumnNames) @@ -177,8 +188,8 @@ private[actions] abstract class CreateActionBase(dataManager: IndexDataManager) // 2. If source data is partitioned, all partitioning key(s) are added to index schema // as columns if they are not already part of the schema. val partitionColumns = relation.partitionSchema.map(_.name) - val missingPartitionColumns = partitionColumns.filter( - ResolverUtils.resolve(spark, _, columnsFromIndexConfig).isEmpty) + val missingPartitionColumns = + partitionColumns.filter(ResolverUtils.resolve(spark, _, columnsFromIndexConfig).isEmpty) val allIndexColumns = columnsFromIndexConfig ++ missingPartitionColumns // File id value in DATA_FILE_ID_COLUMN column (lineage column) is stored as a @@ -202,10 +213,16 @@ private[actions] abstract class CreateActionBase(dataManager: IndexDataManager) .select( allIndexColumns.head, allIndexColumns.tail :+ IndexConstants.DATA_FILE_NAME_ID: _*) + .toDF( + SchemaUtils.escapeFieldNames(allIndexColumns) :+ IndexConstants.DATA_FILE_NAME_ID: _*) } else { df.select(columnsFromIndexConfig.head, columnsFromIndexConfig.tail: _*) + .toDF(SchemaUtils.escapeFieldNames(columnsFromIndexConfig): _*) } - (indexDF, resolvedIndexedColumns, resolvedIncludedColumns) + val escapedIndexedColumns = SchemaUtils.escapeFieldNames(resolvedIndexedColumns) + val escapedIncludedColumns = SchemaUtils.escapeFieldNames(resolvedIncludedColumns) + + (indexDF, escapedIndexedColumns, escapedIncludedColumns) } } diff --git a/src/main/scala/com/microsoft/hyperspace/index/IndexConstants.scala b/src/main/scala/com/microsoft/hyperspace/index/IndexConstants.scala index e4e930358..39ec07850 100644 --- a/src/main/scala/com/microsoft/hyperspace/index/IndexConstants.scala +++ b/src/main/scala/com/microsoft/hyperspace/index/IndexConstants.scala @@ -109,4 +109,7 @@ object IndexConstants { // To provide multiple paths in the globbing pattern, separate them with commas, e.g. // "/temp/1/*, /temp/2/*" val GLOBBING_PATTERN_KEY = "spark.hyperspace.source.globbingPattern" + + // Indicate whether the index has been built over a nested field. 
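+  // It is set to "true" when any indexed or included column is a nested field, i.e. its
+  // name contains a dot, such as "nested.leaf.cnt".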
+ private[hyperspace] val USES_NESTED_FIELDS_PROPERTY = "hasNestedFields" } diff --git a/src/main/scala/com/microsoft/hyperspace/index/IndexLogEntry.scala b/src/main/scala/com/microsoft/hyperspace/index/IndexLogEntry.scala index 58817e1e6..b4f391e15 100644 --- a/src/main/scala/com/microsoft/hyperspace/index/IndexLogEntry.scala +++ b/src/main/scala/com/microsoft/hyperspace/index/IndexLogEntry.scala @@ -557,6 +557,11 @@ case class IndexLogEntry( config.hashCode + signature.hashCode + numBuckets.hashCode + content.hashCode } + def usesNestedFields: Boolean = { + derivedDataset.properties.properties.getOrElse( + IndexConstants.USES_NESTED_FIELDS_PROPERTY, "false").toBoolean + } + /** * A mutable map for holding auxiliary information of this index log entry while applying rules. */ diff --git a/src/main/scala/com/microsoft/hyperspace/index/rules/FilterIndexRule.scala b/src/main/scala/com/microsoft/hyperspace/index/rules/FilterIndexRule.scala index 23b1f5838..f1801e66d 100644 --- a/src/main/scala/com/microsoft/hyperspace/index/rules/FilterIndexRule.scala +++ b/src/main/scala/com/microsoft/hyperspace/index/rules/FilterIndexRule.scala @@ -16,11 +16,14 @@ package com.microsoft.hyperspace.index.rules +import scala.util.Try + import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.analysis.CleanupAliases -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, GetStructField} import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.{DataType, StructType} import com.microsoft.hyperspace.{ActiveSparkSession, Hyperspace} import com.microsoft.hyperspace.actions.Constants @@ -28,7 +31,7 @@ import com.microsoft.hyperspace.index.IndexLogEntry import com.microsoft.hyperspace.index.rankers.FilterIndexRanker import com.microsoft.hyperspace.index.sources.FileBasedRelation import com.microsoft.hyperspace.telemetry.{AppInfo, HyperspaceEventLogging, HyperspaceIndexUsageEvent} -import com.microsoft.hyperspace.util.{HyperspaceConf, ResolverUtils} +import com.microsoft.hyperspace.util.{HyperspaceConf, ResolverUtils, SchemaUtils} /** * FilterIndex rule looks for opportunities in a logical plan to replace @@ -113,8 +116,8 @@ object FilterIndexRule val candidateIndexes = allIndexes.filter { index => indexCoversPlan( - outputColumns, - filterColumns, + SchemaUtils.escapeFieldNames(outputColumns), + SchemaUtils.escapeFieldNames(filterColumns), index.indexedColumns, index.includedColumns) } @@ -168,9 +171,17 @@ object ExtractFilterNode { val projectColumnNames = CleanupAliases(project) .asInstanceOf[Project] .projectList - .map(_.references.map(_.asInstanceOf[AttributeReference].name)) + .map(extractNamesFromExpression) .flatMap(_.toSeq) - val filterColumnNames = condition.references.map(_.name).toSeq + val filterColumnNames = extractNamesFromExpression(condition).toSeq + .sortBy(-_.length) + .foldLeft(Seq.empty[String]) { (acc, e) => + if (!acc.exists(i => i.startsWith(e))) { + acc :+ e + } else { + acc + } + } Some(project, filter, projectColumnNames, filterColumnNames) @@ -183,6 +194,96 @@ object ExtractFilterNode { case _ => None // plan does not match with any of filter index rule patterns } + + def extractNamesFromExpression(exp: Expression): Set[String] = { + exp match { + case AttributeReference(name, _, _, _) => + Set(s"$name") + case otherExp => + 
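+        // For any other expression, walk its children: a GetStructField chain becomes a
+        // dotted name (e.g. "nested.leaf.cnt"); other child expressions are searched
+        // recursively.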
otherExp.containsChild.map { + case g: GetStructField => + s"${getChildNameFromStruct(g)}" + case e: Expression => + extractNamesFromExpression(e).filter(_.nonEmpty).mkString(".") + case _ => "" + } + } + } + + def getChildNameFromStruct(field: GetStructField): String = { + field.child match { + case f: GetStructField => + s"${getChildNameFromStruct(f)}.${field.name.get}" + case a: AttributeReference => + s"${a.name}.${field.name.get}" + case _ => + s"${field.name.get}" + } + } + + def extractSearchQuery(exp: Expression, name: String): (Expression, Expression) = { + val splits = name.split(SchemaUtils.NESTED_FIELD_NEEDLE_REGEX) + val expFound = exp.find { + case a: AttributeReference if splits.forall(s => a.name.contains(s)) => true + case f: GetStructField if splits.forall(s => f.toString().contains(s)) => true + case _ => false + }.get + val parent = exp.find { + case e: Expression if e.containsChild.contains(expFound) => true + case _ => false + }.get + (parent, expFound) + } + + def replaceInSearchQuery( + parent: Expression, + needle: Expression, + repl: Expression): Expression = { + parent.mapChildren { c => + if (c == needle) { + repl + } else { + c + } + } + } + + def extractAttributeRef(exp: Expression, name: String): AttributeReference = { + val splits = name.split(SchemaUtils.NESTED_FIELD_NEEDLE_REGEX) + val elem = exp.find { + case a: AttributeReference if splits.contains(a.name) => true + case _ => false + } + elem.get.asInstanceOf[AttributeReference] + } + + def extractTypeFromExpression(exp: Expression, name: String): DataType = { + val splits = name.split(SchemaUtils.NESTED_FIELD_NEEDLE_REGEX) + val elem = exp.flatMap { + case a: AttributeReference => + if (splits.forall(s => a.name == s)) { + Some((name, a.dataType)) + } else { + Try({ + val h :: t = splits.toList + if (a.name == h && a.dataType.isInstanceOf[StructType]) { + val currentDataType = a.dataType.asInstanceOf[StructType] + val foldedFields = t.foldLeft(Seq.empty[(String, DataType)]) { (acc, i) => + val idx = currentDataType.indexWhere(_.name.equalsIgnoreCase(i)) + acc :+ (i, currentDataType(idx).dataType) + } + Some(foldedFields.last) + } else { + None + } + }).getOrElse(None) + } + case f: GetStructField if splits.forall(s => f.toString().contains(s)) => + Some((name, f.dataType)) + case _ => None + } + elem.find(e => e._1 == name || e._1 == splits.last).get._2 + } } object ExtractRelation extends ActiveSparkSession { diff --git a/src/main/scala/com/microsoft/hyperspace/index/rules/RuleUtils.scala b/src/main/scala/com/microsoft/hyperspace/index/rules/RuleUtils.scala index aa95805a5..e20becbdd 100644 --- a/src/main/scala/com/microsoft/hyperspace/index/rules/RuleUtils.scala +++ b/src/main/scala/com/microsoft/hyperspace/index/rules/RuleUtils.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import org.apache.hadoop.fs.Path import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, In, Literal, Not} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, ExprId, GetStructField, In, Literal, Not} import org.apache.spark.sql.catalyst.optimizer.OptimizeIn import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.datasources._ @@ -34,6 +34,7 @@ import com.microsoft.hyperspace.index.IndexLogEntryTags.{HYBRIDSCAN_RELATED_CONF import com.microsoft.hyperspace.index.plans.logical.{BucketUnion, IndexHadoopFsRelation} import 
com.microsoft.hyperspace.index.sources.FileBasedRelation import com.microsoft.hyperspace.util.HyperspaceConf +import com.microsoft.hyperspace.util.SchemaUtils object RuleUtils { @@ -286,10 +287,30 @@ object RuleUtils { new ParquetFileFormat, Map(IndexConstants.INDEX_RELATION_IDENTIFIER))(spark, index) - val updatedOutput = relation.plan.output - .filter(attr => indexFsRelation.schema.fieldNames.contains(attr.name)) - .map(_.asInstanceOf[AttributeReference]) + val flatSchema = SchemaUtils.escapeFieldNames(SchemaUtils.flatten(relation.plan.schema)) + val updatedOutput = + if (SchemaUtils.hasNestedFields(SchemaUtils.unescapeFieldNames(flatSchema))) { + indexFsRelation.schema.flatMap { s => + val exprId = getFieldPosition(index, s.name) + relation.plan.output.find(a => s.name.contains(a.name)).map { a => + AttributeReference(s.name, s.dataType, a.nullable, a.metadata)( + ExprId(exprId), + a.qualifier) + } + } + } else { + relation.plan.output + .filter(attr => indexFsRelation.schema.fieldNames.contains(attr.name)) + .map(_.asInstanceOf[AttributeReference]) + } + relation.createLogicalRelation(indexFsRelation, updatedOutput) + + case p: Project if provider.isSupportedProject(p) => + transformProject(p, index) + + case f: Filter if provider.isSupportedFilter(f) => + transformFilter(f, index) } } @@ -353,7 +374,7 @@ object RuleUtils { val filesToRead = { if (useBucketSpec || !index.hasParquetAsSourceFormat || filesDeleted.nonEmpty || - relation.partitionSchema.nonEmpty) { + relation.partitionSchema.nonEmpty || index.usesNestedFields) { // Since the index data is in "parquet" format, we cannot read source files // in formats other than "parquet" using one FileScan node as the operator requires // files in one homogenous format. To address this, we need to read the appended @@ -377,9 +398,10 @@ object RuleUtils { // In order to handle deleted files, read index data with the lineage column so that // we could inject Filter-Not-In conditions on the lineage column to exclude the indexed // rows from the deleted files. 
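+      // Index columns over nested fields are stored under escaped, flattened names
+      // (e.g. "nested__leaf__cnt"), so the source schema is flattened and escaped the
+      // same way before comparing it with the index schema.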
+ val flatSchema = SchemaUtils.escapeFieldNames(SchemaUtils.flatten(relation.plan.schema)) val newSchema = StructType( index.schema.filter(s => - relation.plan.schema.contains(s) || (filesDeleted.nonEmpty && s.name.equals( + flatSchema.contains(s.name) || (filesDeleted.nonEmpty && s.name.equals( IndexConstants.DATA_FILE_NAME_ID)))) def fileIndex: InMemoryFileIndex = { @@ -400,9 +422,21 @@ object RuleUtils { new ParquetFileFormat, Map(IndexConstants.INDEX_RELATION_IDENTIFIER))(spark, index) - val updatedOutput = relation.plan.output - .filter(attr => indexFsRelation.schema.fieldNames.contains(attr.name)) - .map(_.asInstanceOf[AttributeReference]) + val updatedOutput = + if (SchemaUtils.hasNestedFields(SchemaUtils.unescapeFieldNames(flatSchema))) { + indexFsRelation.schema.flatMap { s => + val exprId = getFieldPosition(index, s.name) + relation.plan.output.find(a => s.name.contains(a.name)).map { a => + AttributeReference(s.name, s.dataType, a.nullable, a.metadata)( + ExprId(exprId), + a.qualifier) + } + } + } else { + relation.plan.output + .filter(attr => indexFsRelation.schema.fieldNames.contains(attr.name)) + .map(_.asInstanceOf[AttributeReference]) + } if (filesDeleted.isEmpty) { relation.createLogicalRelation(indexFsRelation, updatedOutput) @@ -414,6 +448,13 @@ object RuleUtils { val filterForDeleted = Filter(Not(In(lineageAttr, deletedFileIds)), rel) Project(updatedOutput, OptimizeIn(filterForDeleted)) } + + case p: Project if provider.isSupportedProject(p) => + transformProject(p, index) + + case f: Filter if provider.isSupportedFilter(f) => + transformFilter(f, index) + } if (unhandledAppendedFiles.nonEmpty) { @@ -487,11 +528,14 @@ object RuleUtils { // Set the same output schema with the index plan to merge them using BucketUnion. // Include partition columns for data loading. 
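+        // For nested fields the index stores escaped names such as "nested__leaf__cnt",
+        // so a source column like "nested" is kept when any index field name contains it
+        // rather than requiring an exact match.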
val partitionColumns = relation.partitionSchema.map(_.name) - val updatedSchema = StructType(relation.plan.schema.filter(col => - index.schema.contains(col) || relation.partitionSchema.contains(col))) + val updatedSchema = StructType( + relation.plan.schema.filter(col => + index.schema.fieldNames.exists(n => n.contains(col.name)) || + relation.partitionSchema.contains(col))) val updatedOutput = relation.plan.output .filter(attr => - index.schema.fieldNames.contains(attr.name) || partitionColumns.contains(attr.name)) + index.schema.fieldNames.exists(n => n.contains(attr.name)) || + partitionColumns.contains(attr.name)) .map(_.asInstanceOf[AttributeReference]) val newRelation = relation.createHadoopFsRelation( newLocation, @@ -576,4 +620,56 @@ object RuleUtils { assert(shuffleInjected) shuffled } + + private def transformProject(project: Project, index: IndexLogEntry): Project = { + val projectedFields = project.projectList.map { exp => + val fieldName = ExtractFilterNode.extractNamesFromExpression(exp).head + val escapedFieldName = SchemaUtils.escapeFieldName(fieldName) + val attr = ExtractFilterNode.extractAttributeRef(exp, fieldName) + val fieldType = ExtractFilterNode.extractTypeFromExpression(exp, fieldName) + val exprId = getFieldPosition(index, escapedFieldName) + attr.copy(escapedFieldName, fieldType, attr.nullable, attr.metadata)( + ExprId(exprId), + attr.qualifier) + } + project.copy(projectList = projectedFields) + } + + private def transformFilter(filter: Filter, index: IndexLogEntry): Filter = { + val fieldNames = ExtractFilterNode.extractNamesFromExpression(filter.condition) + var mutableFilter = filter + fieldNames.foreach { fieldName => + val escapedFieldName = SchemaUtils.escapeFieldName(fieldName) + val nestedFields = getNestedFields(index) + if (nestedFields.nonEmpty && + nestedFields.exists(i => i.equalsIgnoreCase(escapedFieldName))) { + val (parentExpresion, exp) = + ExtractFilterNode.extractSearchQuery(filter.condition, fieldName) + val fieldType = ExtractFilterNode.extractTypeFromExpression(exp, fieldName) + val attr = ExtractFilterNode.extractAttributeRef(exp, fieldName) + val exprId = getFieldPosition(index, escapedFieldName) + val newAttr = attr.copy(escapedFieldName, fieldType, attr.nullable, attr.metadata)( + ExprId(exprId), + attr.qualifier) + val newExp = exp match { + case _: GetStructField => newAttr + case other: Expression => other + } + val newParentExpression = + ExtractFilterNode.replaceInSearchQuery(parentExpresion, exp, newExp) + mutableFilter = filter.copy(condition = newParentExpression) + } else { + filter + } + } + mutableFilter + } + + private def getNestedFields(index: IndexLogEntry): Seq[String] = { + index.schema.fieldNames.filter(_.contains(SchemaUtils.NESTED_FIELD_REPLACEMENT)) + } + + private def getFieldPosition(index: IndexLogEntry, fieldName: String): Int = { + index.schema.fieldNames.indexWhere(_.equalsIgnoreCase(fieldName)) + } } diff --git a/src/main/scala/com/microsoft/hyperspace/index/sources/FileBasedSourceProviderManager.scala b/src/main/scala/com/microsoft/hyperspace/index/sources/FileBasedSourceProviderManager.scala index fc64537a3..a9cee8a0e 100644 --- a/src/main/scala/com/microsoft/hyperspace/index/sources/FileBasedSourceProviderManager.scala +++ b/src/main/scala/com/microsoft/hyperspace/index/sources/FileBasedSourceProviderManager.scala @@ -19,12 +19,13 @@ package com.microsoft.hyperspace.index.sources import scala.util.{Success, Try} import org.apache.spark.sql.SparkSession -import 
org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} import org.apache.spark.util.hyperspace.Utils import com.microsoft.hyperspace.HyperspaceException import com.microsoft.hyperspace.index.Relation -import com.microsoft.hyperspace.util.{CacheWithTransform, HyperspaceConf} +import com.microsoft.hyperspace.index.rules.ExtractFilterNode +import com.microsoft.hyperspace.util.{CacheWithTransform, HyperspaceConf, SchemaUtils} /** * [[FileBasedSourceProviderManager]] is responsible for loading source providers which implements @@ -90,6 +91,41 @@ class FileBasedSourceProviderManager(spark: SparkSession) { run(p => p.getRelation(plan)) } + /** + * Returns true if the given project is a supported project. If all of the registered + * providers return None, this returns false. + * + * @param project Project to check if it's supported. + * @return True if the given project is a supported relation. + */ + def isSupportedProject(project: Project): Boolean = { + val containsNestedFields = SchemaUtils.hasNestedFields( + project.projectList.flatMap(ExtractFilterNode.extractNamesFromExpression)) + var containsNestedChildren = false + project.child.foreach { + case f: Filter => + containsNestedChildren = containsNestedChildren || { + SchemaUtils.hasNestedFields(SchemaUtils.unescapeFieldNames( + ExtractFilterNode.extractNamesFromExpression(f.condition).toSeq)) + } + case _ => + } + containsNestedFields || containsNestedChildren + } + + /** + * Returns true if the given filter is a supported filter. If all of the registered + * providers return None, this returns false. + * + * @param filter Filter to check if it's supported. + * @return True if the given project is a supported relation. + */ + def isSupportedFilter(filter: Filter): Boolean = { + val containsNestedFields = SchemaUtils.hasNestedFields( + ExtractFilterNode.extractNamesFromExpression(filter.condition).toSeq) + containsNestedFields + } + /** * Runs the given function 'f', which executes a [[FileBasedSourceProvider]]'s API that returns * [[Option]] for each provider built. This function ensures that only one provider returns diff --git a/src/main/scala/com/microsoft/hyperspace/util/SchemaUtils.scala b/src/main/scala/com/microsoft/hyperspace/util/SchemaUtils.scala new file mode 100644 index 000000000..cba0d5a57 --- /dev/null +++ b/src/main/scala/com/microsoft/hyperspace/util/SchemaUtils.scala @@ -0,0 +1,60 @@ +/* + * Copyright (2020) The Hyperspace Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.microsoft.hyperspace.util + +import org.apache.spark.sql.types.{ArrayType, StructField, StructType} + +object SchemaUtils { + + val NESTED_FIELD_NEEDLE_REGEX = "\\." 
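+  // Flattened nested field names use "." as separator; it is replaced with "__" when the
+  // name is used as an index column, e.g. "nested.leaf.cnt" becomes "nested__leaf__cnt".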
+ val NESTED_FIELD_REPLACEMENT = "__" + + def flatten(structFields: Seq[StructField], prefix: Option[String] = None): Seq[String] = { + structFields.flatMap { + case StructField(name, StructType(fields), _, _) => + flatten(fields, Some(prefix.map(o => s"$o.$name").getOrElse(name))) + case StructField(name, ArrayType(StructType(fields), _), _, _) => + flatten(fields, Some(prefix.map(o => s"$o.$name").getOrElse(name))) + case other => + Seq(prefix.map(o => s"$o.${other.name}").getOrElse(other.name)) + } + } + + def escapeFieldNames(fields: Seq[String]): Seq[String] = { + fields.map(escapeFieldName) + } + + def escapeFieldName(field: String): String = { + field.replaceAll(NESTED_FIELD_NEEDLE_REGEX, NESTED_FIELD_REPLACEMENT) + } + + def unescapeFieldNames(fields: Seq[String]): Seq[String] = { + fields.map(unescapeFieldName) + } + + def unescapeFieldName(field: String): String = { + field.replaceAll(NESTED_FIELD_REPLACEMENT, NESTED_FIELD_NEEDLE_REGEX) + } + + def hasNestedFields(fields: Seq[String]): Boolean = { + fields.exists(isNestedField) + } + + def isNestedField(field: String): Boolean = { + NESTED_FIELD_NEEDLE_REGEX.r.findFirstIn(field).isDefined + } +} diff --git a/src/test/scala/com/microsoft/hyperspace/SampleNestedData.scala b/src/test/scala/com/microsoft/hyperspace/SampleNestedData.scala new file mode 100644 index 000000000..8c5201139 --- /dev/null +++ b/src/test/scala/com/microsoft/hyperspace/SampleNestedData.scala @@ -0,0 +1,66 @@ +/* + * Copyright (2020) The Hyperspace Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.microsoft.hyperspace + +import org.apache.spark.sql.SparkSession + +/** + * Sample data for testing. 
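+ * The `nested` column is a two-level struct ([[SampleNestedDataStruct]] wrapping a
+ * [[SampleNestedDataLeaf]]) used by the tests that build indexes over nested fields.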
+ */ +object SampleNestedData { + + val testData = Seq( + ("2017-09-03", "810a20a2baa24ff3ad493bfbf064569a", "donde", 2, 1000, + SampleNestedDataStruct("id1", SampleNestedDataLeaf("leaf_id1", 1))), + ("2017-09-03", "fd093f8a05604515957083e70cb3dceb", "facebook", 1, 3000, + SampleNestedDataStruct("id1", SampleNestedDataLeaf("leaf_id1", 2))), + ("2017-09-03", "af3ed6a197a8447cba8bc8ea21fad208", "facebook", 1, 3000, + SampleNestedDataStruct("id2", SampleNestedDataLeaf("leaf_id7", 1))), + ("2017-09-03", "975134eca06c4711a0406d0464cbe7d6", "facebook", 1, 4000, + SampleNestedDataStruct("id2", SampleNestedDataLeaf("leaf_id7", 2))), + ("2018-09-03", "e90a6028e15b4f4593eef557daf5166d", "ibraco", 2, 3000, + SampleNestedDataStruct("id2", SampleNestedDataLeaf("leaf_id7", 5))), + ("2018-09-03", "576ed96b0d5340aa98a47de15c9f87ce", "facebook", 2, 3000, + SampleNestedDataStruct("id2", SampleNestedDataLeaf("leaf_id9", 1))), + ("2018-09-03", "50d690516ca641438166049a6303650c", "ibraco", 2, 1000, + SampleNestedDataStruct("id3", SampleNestedDataLeaf("leaf_id9", 10))), + ("2019-10-03", "380786e6495d4cd8a5dd4cc8d3d12917", "facebook", 2, 3000, + SampleNestedDataStruct("id4", SampleNestedDataLeaf("leaf_id9", 12))), + ("2019-10-03", "ff60e4838b92421eafc3e6ee59a9e9f1", "miperro", 2, 2000, + SampleNestedDataStruct("id5", SampleNestedDataLeaf("leaf_id9", 21))), + ("2019-10-03", "187696fe0a6a40cc9516bc6e47c70bc1", "facebook", 4, 3000, + SampleNestedDataStruct("id6", SampleNestedDataLeaf("leaf_id9", 22)))) + + def save( + spark: SparkSession, + path: String, + columns: Seq[String], + partitionColumns: Option[Seq[String]] = None): Unit = { + val df = spark.createDataFrame( + spark.sparkContext.parallelize(testData) + ).toDF(columns: _*) + partitionColumns match { + case Some(pcs) => + df.write.partitionBy(pcs: _*).parquet(path) + case None => + df.write.parquet(path) + } + } +} + +case class SampleNestedDataStruct(id: String, leaf: SampleNestedDataLeaf) +case class SampleNestedDataLeaf(id: String, cnt: Int) diff --git a/src/test/scala/com/microsoft/hyperspace/index/CreateIndexNestedTest.scala b/src/test/scala/com/microsoft/hyperspace/index/CreateIndexNestedTest.scala new file mode 100644 index 000000000..f4f215098 --- /dev/null +++ b/src/test/scala/com/microsoft/hyperspace/index/CreateIndexNestedTest.scala @@ -0,0 +1,196 @@ +/* + * Copyright (2020) The Hyperspace Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.microsoft.hyperspace.index + +import scala.collection.mutable.WrappedArray + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.catalyst.plans.SQLHelper +import org.apache.spark.sql.functions._ + +import com.microsoft.hyperspace.{Hyperspace, HyperspaceException, SampleNestedData} +import com.microsoft.hyperspace.util.{FileUtils, SchemaUtils} + +class CreateIndexNestedTest extends HyperspaceSuite with SQLHelper { + override val systemPath = new Path("src/test/resources/indexLocation") + private val testDir = "src/test/resources/createIndexTests/" + private val nonPartitionedDataPath = testDir + "samplenestedparquet" + private val partitionedDataPath = testDir + "samplenestedpartitionedparquet" + private val partitionKeys = Seq("Date", "Query") + private val indexConfig1 = + IndexConfig("index1", Seq("nested.leaf.id"), Seq("Date", "nested.leaf.cnt")) + private val indexConfig2 = IndexConfig("index3", Seq("nested.leaf.id"), Seq("nested.leaf.cnt")) + private var nonPartitionedDataDF: DataFrame = _ + private var partitionedDataDF: DataFrame = _ + private var hyperspace: Hyperspace = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + hyperspace = new Hyperspace(spark) + FileUtils.delete(new Path(testDir), isRecursive = true) + + val dataColumns = Seq("Date", "RGUID", "Query", "imprs", "clicks", "nested") + // save test data non-partitioned. + SampleNestedData.save(spark, nonPartitionedDataPath, dataColumns) + nonPartitionedDataDF = spark.read.parquet(nonPartitionedDataPath) + + // save test data partitioned. + SampleNestedData.save(spark, partitionedDataPath, dataColumns, Some(partitionKeys)) + partitionedDataDF = spark.read.parquet(partitionedDataPath) + } + + override def afterAll(): Unit = { + FileUtils.delete(new Path(testDir), isRecursive = true) + super.afterAll() + } + + after { + FileUtils.delete(systemPath) + } + + test("Index creation with nested indexed and included columns") { + hyperspace.createIndex(nonPartitionedDataDF, indexConfig1) + assert(hyperspace.indexes.where(s"name = 'index1' ").count == 1) + assert(hyperspace.indexes.where( + array_contains(col("indexedColumns"), "nested__leaf__id")).count == 1) + assert(hyperspace.indexes.where( + array_contains(col("includedColumns"), "nested__leaf__cnt")).count == 1) + val colTypes = hyperspace.indexes.select("schema") + .collect().map(r => r.getString(0)).head + assert(colTypes.contains("nested__leaf__id")) + assert(colTypes.contains("nested__leaf__cnt")) + } + + test("Index creation passes with columns of different case if case-sensitivity is false.") { + hyperspace.createIndex( + nonPartitionedDataDF, + IndexConfig("index1", Seq("Nested.leaF.id"), Seq("nested.leaf.CNT"))) + val indexes = hyperspace.indexes.where(s"name = 'index1' ") + assert(indexes.count == 1) + assert( + indexes.head.getAs[WrappedArray[String]]("indexedColumns").head == "nested__leaf__id", + "Indexed columns with wrong case are stored in metadata") + assert( + indexes.head.getAs[WrappedArray[String]]("includedColumns").head == "nested__leaf__cnt", + "Included columns with wrong case are stored in metadata") + } + + test("Index creation fails with columns of different case if case-sensitivity is true.") { + withSQLConf("spark.sql.caseSensitive" -> "true") { + val exception = intercept[HyperspaceException] { + hyperspace.createIndex( + nonPartitionedDataDF, + IndexConfig("index1", Seq("Nested.leaF.id"), Seq("nested.leaf.CNT"))) + } + 
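+      // With case sensitivity enabled, "Nested.leaF.id" and "nested.leaf.CNT" cannot be
+      // resolved against the source schema, so index creation is rejected.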
assert(exception.getMessage.contains("Index config is not applicable to dataframe schema.")) + } + } + + test("Index creation fails since the dataframe has a filter node.") { + val dfWithFilter = nonPartitionedDataDF.filter("nested.leaf.id='leaf_id1'") + val exception = intercept[HyperspaceException] { + hyperspace.createIndex(dfWithFilter, indexConfig1) + } + assert( + exception.getMessage.contains( + "Only creating index over HDFS file based scan nodes is supported.")) + } + + test("Index creation fails since the dataframe has a projection node.") { + val dfWithSelect = nonPartitionedDataDF.select("nested.leaf.id") + val exception = intercept[HyperspaceException] { + hyperspace.createIndex(dfWithSelect, indexConfig1) + } + assert( + exception.getMessage.contains( + "Only creating index over HDFS file based scan nodes is supported.")) + } + + test("Index creation fails since the dataframe has a join node.") { + val dfJoin = nonPartitionedDataDF + .join(nonPartitionedDataDF, nonPartitionedDataDF("Query") === nonPartitionedDataDF("Query")) + .select( + nonPartitionedDataDF("RGUID"), + nonPartitionedDataDF("Query"), + nonPartitionedDataDF("nested.leaf.cnt")) + val exception = intercept[HyperspaceException] { + hyperspace.createIndex(dfJoin, indexConfig1) + } + assert( + exception.getMessage.contains( + "Only creating index over HDFS file based scan nodes is supported.")) + } + + test("Check lineage in index records for partitioned data when partition key is not in config.") { + withSQLConf(IndexConstants.INDEX_LINEAGE_ENABLED -> "true") { + hyperspace.createIndex(partitionedDataDF, indexConfig2) + val indexRecordsDF = spark.read.parquet( + s"$systemPath/${indexConfig2.indexName}/${IndexConstants.INDEX_VERSION_DIRECTORY_PREFIX}=0") + + // For partitioned data, beside file name lineage column all partition keys columns + // should be added to index schema if they are not already among index config columns. + assert( + indexRecordsDF.schema.fieldNames.sorted === + (SchemaUtils.escapeFieldNames(indexConfig2.indexedColumns ++ + indexConfig2.includedColumns) ++ + Seq(IndexConstants.DATA_FILE_NAME_ID) ++ partitionKeys).sorted) + } + } + + test("Check lineage in index records for non-partitioned data.") { + withSQLConf(IndexConstants.INDEX_LINEAGE_ENABLED -> "true") { + hyperspace.createIndex(nonPartitionedDataDF, indexConfig1) + val indexRecordsDF = spark.read.parquet( + s"$systemPath/${indexConfig1.indexName}/${IndexConstants.INDEX_VERSION_DIRECTORY_PREFIX}=0") + + // For non-partitioned data, only file name lineage column should be added to index schema. + assert( + indexRecordsDF.schema.fieldNames.sorted === + (SchemaUtils.escapeFieldNames(indexConfig1.indexedColumns ++ + indexConfig1.includedColumns) ++ + Seq(IndexConstants.DATA_FILE_NAME_ID)).sorted) + } + } + + test("Verify content of lineage column.") { + withSQLConf(IndexConstants.INDEX_LINEAGE_ENABLED -> "true") { + val dataPath = new Path(nonPartitionedDataPath, "*parquet") + val dataFilesCount = dataPath + .getFileSystem(new Configuration) + .globStatus(dataPath) + .length + .toLong + + // File ids are assigned incrementally starting from 0. 
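+      // Valid ids therefore fall in [0, dataFilesCount); the inclusive range below is a
+      // safe over-approximation used only for membership checks.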
+ val lineageRange = 0L to dataFilesCount + + hyperspace.createIndex(nonPartitionedDataDF, indexConfig1) + val indexRecordsDF = spark.read.parquet( + s"$systemPath/${indexConfig1.indexName}/${IndexConstants.INDEX_VERSION_DIRECTORY_PREFIX}=0") + val lineageValues = indexRecordsDF + .select(IndexConstants.DATA_FILE_NAME_ID) + .distinct() + .collect() + .map(r => r.getLong(0)) + + lineageValues.forall(lineageRange.contains(_)) + } + } +} diff --git a/src/test/scala/com/microsoft/hyperspace/index/HybridScanForNestedFieldsTest.scala b/src/test/scala/com/microsoft/hyperspace/index/HybridScanForNestedFieldsTest.scala new file mode 100644 index 000000000..2d6ba6c16 --- /dev/null +++ b/src/test/scala/com/microsoft/hyperspace/index/HybridScanForNestedFieldsTest.scala @@ -0,0 +1,447 @@ +/* + * Copyright (2020) The Hyperspace Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.microsoft.hyperspace.index + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.catalyst.expressions.{Attribute, EqualTo, In, InSet, Literal, Not} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project, RepartitionByExpression, Union} +import org.apache.spark.sql.execution.{FileSourceScanExec, ProjectExec} +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.internal.SQLConf + +import com.microsoft.hyperspace._ +import com.microsoft.hyperspace.{Hyperspace, SampleNestedData, TestConfig} +import com.microsoft.hyperspace.TestUtils.{latestIndexLogEntry, logManager} +import com.microsoft.hyperspace.index.execution.BucketUnionExec +import com.microsoft.hyperspace.index.plans.logical.BucketUnion +import com.microsoft.hyperspace.util.FileUtils + +// Hybrid Scan tests for non partitioned source data. Test cases of HybridScanSuite are also +// executed with non partitioned source data. 
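+// This suite covers indexes defined over nested fields such as "nested.leaf.cnt", which are
+// stored in the index schema under escaped names like "nested__leaf__cnt".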
+class HybridScanForNestedFieldsTest extends QueryTest with HyperspaceSuite { + override val systemPath = new Path("src/test/resources/hybridScanTestNestedFields") + + val sampleNestedData = SampleNestedData.testData + var hyperspace: Hyperspace = _ + + val fileFormat = "parquet" + + import spark.implicits._ + val nestedDf = sampleNestedData.toDF("Date", "RGUID", "Query", "imprs", "clicks", "nested") + val indexConfig1 = + IndexConfig("index1", Seq("nested.leaf.cnt"), Seq("query", "nested.leaf.id")) + val indexConfig2 = + IndexConfig("index2", Seq("nested.leaf.cnt"), Seq("Date", "nested.leaf.id")) + + def normalizePaths(in: Seq[String]): Seq[String] = { + in.map(_.replace("file:///", "file:/")) + } + def equalNormalizedPaths(a: Seq[String], b: Seq[String]): Boolean = { + normalizePaths(a).toSet === normalizePaths(b).toSet + } + + def setupIndexAndChangeData( + sourceFileFormat: String, + sourcePath: String, + indexConfig: IndexConfig, + appendCnt: Int, + deleteCnt: Int): (Seq[String], Seq[String]) = { + nestedDf.write.format(sourceFileFormat).save(sourcePath) + val df = spark.read.format(sourceFileFormat).load(sourcePath) + hyperspace.createIndex(df, indexConfig) + val inputFiles = df.inputFiles + assert(appendCnt + deleteCnt < inputFiles.length) + + val fs = systemPath.getFileSystem(new Configuration) + for (i <- 0 until appendCnt) { + val sourcePath = new Path(inputFiles(i)) + val destPath = new Path(inputFiles(i) + ".copy") + fs.copyToLocalFile(sourcePath, destPath) + } + + for (i <- 1 to deleteCnt) { + fs.delete(new Path(inputFiles(inputFiles.length - i)), false) + } + + val df2 = spark.read.format(sourceFileFormat).load(sourcePath) + (df2.inputFiles diff inputFiles, inputFiles diff df2.inputFiles) + } + + override def beforeAll(): Unit = { + super.beforeAll() + hyperspace = new Hyperspace(spark) + } + + before { + spark.conf.set(IndexConstants.INDEX_LINEAGE_ENABLED, "true") + spark.enableHyperspace() + } + + after { + FileUtils.delete(systemPath) + spark.disableHyperspace() + } + + private def getLatestStableLog(indexName: String): IndexLogEntry = { + val entry = logManager(systemPath, indexName).getLatestStableLog() + assert(entry.isDefined) + assert(entry.get.isInstanceOf[IndexLogEntry]) + entry.get.asInstanceOf[IndexLogEntry] + } + + def checkDeletedFiles( + plan: LogicalPlan, + indexName: String, + expectedDeletedFileNames: Seq[String]): Unit = { + + val fileNameToId = getLatestStableLog(indexName).fileIdTracker.getFileToIdMap.toSeq.map { + kv => + (kv._1._1, kv._2) + }.toMap + + val expectedDeletedFiles = + expectedDeletedFileNames.map(f => fileNameToId(f.replace("file:///", "file:/")).toString) + + if (expectedDeletedFiles.nonEmpty) { + log + val inputFiles = plan.collect { + case LogicalRelation(fsRelation: HadoopFsRelation, _, _, _) => + fsRelation.inputFiles.toSeq + }.flatten + val deletedFilesList = plan collect { + case Filter( + Not(EqualTo(left: Attribute, right: Literal)), + LogicalRelation(fsRelation: HadoopFsRelation, _, _, _)) => + // Check new filter condition on lineage column. + val colName = left.toString + val deletedFile = right.toString + assert(colName.toString.contains(IndexConstants.DATA_FILE_NAME_ID)) + val deleted = Seq(deletedFile) + assert(expectedDeletedFiles.length == 1) + // Check the location is replaced with index data files properly. 
+ val files = fsRelation.location.inputFiles + assert(files.nonEmpty && files.forall(_.contains(indexName))) + deleted + case Filter( + Not(InSet(attr, deletedFileIds)), + LogicalRelation(fsRelation: HadoopFsRelation, _, _, _)) => + // Check new filter condition on lineage column. + assert(attr.toString.contains(IndexConstants.DATA_FILE_NAME_ID)) + val deleted = deletedFileIds.map(_.toString).toSeq + assert( + expectedDeletedFiles.length > spark.conf + .get("spark.sql.optimizer.inSetConversionThreshold") + .toLong) + // Check the location is replaced with index data files properly. + val files = fsRelation.location.inputFiles + assert(files.nonEmpty && files.forall(_.contains(indexName))) + deleted + case Filter( + Not(In(attr, deletedFileIds)), + LogicalRelation(fsRelation: HadoopFsRelation, _, _, _)) => + // Check new filter condition on lineage column. + assert(attr.toString.contains(IndexConstants.DATA_FILE_NAME_ID)) + val deleted = deletedFileIds.map(_.toString) + assert( + expectedDeletedFiles.length <= spark.conf + .get("spark.sql.optimizer.inSetConversionThreshold") + .toLong) + // Check the location is replaced with index data files properly. + val files = fsRelation.location.inputFiles + assert(files.nonEmpty && files.forall(_.contains(indexName))) + deleted + } + assert(deletedFilesList.length === 1) + val deletedFiles = deletedFilesList.flatten + assert(deletedFiles.length === expectedDeletedFiles.size) + assert(deletedFiles.distinct.length === deletedFiles.length) + assert(deletedFiles.forall(f => !inputFiles.contains(f))) + assert(equalNormalizedPaths(deletedFiles, expectedDeletedFiles)) + + val execPlan = spark.sessionState.executePlan(plan).executedPlan + val execNodes = execPlan collect { + case p @ FileSourceScanExec(_, _, _, _, _, dataFilters, _) => + // Check deleted files. + assert(deletedFiles.forall(dataFilters.toString.contains)) + p + } + assert(execNodes.length === 1) + } + } + + def checkJoinIndexHybridScan( + plan: LogicalPlan, + leftIndexName: String, + leftAppended: Seq[String], + leftDeleted: Seq[String], + rightIndexName: String, + rightAppended: Seq[String], + rightDeleted: Seq[String], + filterConditions: Seq[String] = Nil): Unit = { + // Project - Join - children + val left = plan.children.head.children.head + val right = plan.children.head.children.last + + // Check deleted files with the first child of each left and right child. + checkDeletedFiles(left.children.head, leftIndexName, leftDeleted) + checkDeletedFiles(right.children.head, rightIndexName, rightDeleted) + + val leftNodes = left.collect { + case b @ BucketUnion(children, bucketSpec) => + assert(bucketSpec.numBuckets === 200) + assert( + bucketSpec.bucketColumnNames.size === 1 && + bucketSpec.bucketColumnNames.head === "clicks") + + val childNodes = children.collect { + case r @ RepartitionByExpression( + attrs, + Project(_, Filter(_, LogicalRelation(fsRelation: HadoopFsRelation, _, _, _))), + numBucket) => + assert(attrs.size === 1) + assert(attrs.head.asInstanceOf[Attribute].name.contains("clicks")) + + // Check appended file. + val files = fsRelation.location.inputFiles + assert(equalNormalizedPaths(files, leftAppended)) + assert(files.length === leftAppended.length) + assert(numBucket === 200) + r + case p @ Project(_, Filter(_, _)) => + val files = p collect { + case LogicalRelation(fsRelation: HadoopFsRelation, _, _, _) => + fsRelation.location.inputFiles + } + assert(files.nonEmpty && files.flatten.forall(_.contains(leftIndexName))) + p + } + + // BucketUnion has 2 children. 
+ assert(childNodes.size === 2) + assert(childNodes.count(_.isInstanceOf[Project]) === 1) + assert(childNodes.count(_.isInstanceOf[RepartitionByExpression]) === 1) + b + } + + val rightNodes = right.collect { + case b @ BucketUnion(children, bucketSpec) => + assert(bucketSpec.numBuckets === 200) + assert( + bucketSpec.bucketColumnNames.size === 1 && + bucketSpec.bucketColumnNames.head === "clicks") + + val childNodes = children.collect { + case r @ RepartitionByExpression( + attrs, + Project(_, Filter(_, LogicalRelation(fsRelation: HadoopFsRelation, _, _, _))), + numBucket) => + assert(attrs.size === 1) + assert(attrs.head.asInstanceOf[Attribute].name.contains("clicks")) + + // Check appended files. + val files = fsRelation.location.inputFiles + assert(equalNormalizedPaths(files, rightAppended)) + assert(files.length === rightAppended.length) + assert(numBucket === 200) + r + case p @ Project( + _, + Filter(_, LogicalRelation(fsRelation: HadoopFsRelation, _, _, _))) => + // Check index data files. + val files = fsRelation.location.inputFiles + assert(files.nonEmpty && files.forall(_.contains(rightIndexName))) + p + } + + // BucketUnion has 2 children. + assert(childNodes.size === 2) + assert(childNodes.count(_.isInstanceOf[Project]) === 1) + assert(childNodes.count(_.isInstanceOf[RepartitionByExpression]) === 1) + b + } + + // Check BucketUnion node if needed. + assert(leftAppended.isEmpty || leftNodes.count(_.isInstanceOf[BucketUnion]) === 1) + assert(rightAppended.isEmpty || rightNodes.count(_.isInstanceOf[BucketUnion]) === 1) + + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + val execPlan = spark.sessionState.executePlan(plan).executedPlan + val execNodes = execPlan.collect { + case p @ BucketUnionExec(children, bucketSpec) => + assert(children.size === 2) + // children.head is always the index plan. + assert(children.head.isInstanceOf[ProjectExec]) + assert(children.last.isInstanceOf[ShuffleExchangeExec]) + assert(bucketSpec.numBuckets === 200) + p + case p @ FileSourceScanExec(_, _, _, partitionFilters, _, dataFilters, _) => + // Check filter pushed down properly. + if (partitionFilters.nonEmpty) { + assert(filterConditions.forall(partitionFilters.toString.contains)) + } else { + assert(filterConditions.forall(dataFilters.toString.contains)) + } + p + } + var requiredBucketUnion = 0 + if (leftAppended.nonEmpty) requiredBucketUnion += 1 + if (rightAppended.nonEmpty) requiredBucketUnion += 1 + assert(execNodes.count(_.isInstanceOf[BucketUnionExec]) === requiredBucketUnion) + // 2 of index data and number of indexes with appended files. + assert(execNodes.count(_.isInstanceOf[FileSourceScanExec]) === 2 + requiredBucketUnion) + } + } + + test( + "Append-only: union over index and new files " + + "due to field names being different: `nested__leaf__cnt` + `nested.leaf.cnt`.") { + // This flag is for testing plan transformation if appended files could be load with index + // data scan node. Currently, it's applied for a very specific case: FilterIndexRule, + // Parquet source format, no partitioning, no deleted files. 
+ withTempPathAsString { testPath => + val (appendedFiles, _) = setupIndexAndChangeData( + "parquet", + testPath, + indexConfig1.copy(indexName = "index_Append"), + appendCnt = 1, + deleteCnt = 0) + + val df = spark.read.format("parquet").load(testPath) + def filterQuery: DataFrame = + df.filter(df("nested.leaf.cnt") <= 20).select(df("query")) + + val baseQuery = filterQuery + val basePlan = baseQuery.queryExecution.optimizedPlan + + withSQLConf(IndexConstants.INDEX_HYBRID_SCAN_ENABLED -> "false") { + val filter = filterQuery + assert(basePlan.equals(filter.queryExecution.optimizedPlan)) + } + + withSQLConf(TestConfig.HybridScanEnabledAppendOnly: _*) { + val filter = filterQuery + val planWithHybridScan = filter.queryExecution.optimizedPlan + assert(!basePlan.equals(planWithHybridScan)) + + // Check appended file is added to relation node or not. + val nodes = planWithHybridScan.collect { + case u @ Union(children) => + val indexChild = children.head + indexChild collect { + case LogicalRelation(fsRelation: HadoopFsRelation, _, _, _) => + assert(fsRelation.location.inputFiles.forall(_.contains("index_Append"))) + } + + assert(children.tail.size === 1) + val appendChild = children.last + appendChild collect { + case LogicalRelation(fsRelation: HadoopFsRelation, _, _, _) => + val files = fsRelation.location.inputFiles + assert(files.toSeq == appendedFiles) + assert(files.length === appendedFiles.size) + } + u + } + + // Filter Index and Parquet format source file can be handled with 1 LogicalRelation + assert(nodes.length === 1) + val left = baseQuery.collect().map(_.getString(0)) + val right = filter.collect().map(_.getString(0)) + assert(left.diff(right).isEmpty) + assert(right.diff(left).isEmpty) + } + } + } + + test("Delete-only: Hybrid Scan for delete support doesn't work without lineage column.") { + val indexConfig = IndexConfig("index_ParquetDelete2", Seq("nested.leaf.cnt"), Seq("query")) + Seq(("indexWithoutLineage", "false", false), ("indexWithLineage", "true", true)) foreach { + case (indexName, lineageColumnConfig, transformationExpected) => + withTempPathAsString { testPath => + withSQLConf(IndexConstants.INDEX_LINEAGE_ENABLED -> lineageColumnConfig) { + setupIndexAndChangeData( + fileFormat, + testPath, + indexConfig.copy(indexName = indexName), + appendCnt = 0, + deleteCnt = 1) + + val df = spark.read.format(fileFormat).load(testPath) + + def filterQuery: DataFrame = + df.filter(df("nested.leaf.cnt") <= 20).select(df("query")) + + val baseQuery = filterQuery + val basePlan = baseQuery.queryExecution.optimizedPlan + withSQLConf(TestConfig.HybridScanEnabledAppendOnly: _*) { + val filter = filterQuery + assert(basePlan.equals(filter.queryExecution.optimizedPlan)) + } + withSQLConf(TestConfig.HybridScanEnabled: _*) { + val filter = filterQuery + assert( + basePlan + .equals(filter.queryExecution.optimizedPlan) + .equals(!transformationExpected)) + } + } + } + } + } + + test("Delete-only: filter rule, number of delete files threshold.") { + withTempPathAsString { testPath => + val indexName = "IndexDeleteCntTest" + withSQLConf(IndexConstants.INDEX_LINEAGE_ENABLED -> "true") { + setupIndexAndChangeData( + fileFormat, + testPath, + indexConfig1.copy(indexName = indexName), + appendCnt = 0, + deleteCnt = 2) + } + + val df = spark.read.format(fileFormat).load(testPath) + def filterQuery: DataFrame = + df.filter(df("nested.leaf.cnt") <= 20).select(df("query")) + val baseQuery = filterQuery + val basePlan = baseQuery.queryExecution.optimizedPlan + val sourceSize = 
latestIndexLogEntry(systemPath, indexName).sourceFilesSizeInBytes + + val afterDeleteSize = FileUtils.getDirectorySize(new Path(testPath)) + val deletedRatio = 1 - (afterDeleteSize / sourceSize.toFloat) + + withSQLConf(TestConfig.HybridScanEnabled: _*) { + withSQLConf(IndexConstants.INDEX_HYBRID_SCAN_DELETED_RATIO_THRESHOLD -> + (deletedRatio + 0.1).toString) { + val filter = filterQuery + // As deletedRatio is less than the threshold, the index can be applied. + assert(!basePlan.equals(filter.queryExecution.optimizedPlan)) + } + withSQLConf(IndexConstants.INDEX_HYBRID_SCAN_DELETED_RATIO_THRESHOLD -> + (deletedRatio - 0.1).toString) { + val filter = filterQuery + // As deletedRatio is greater than the threshold, the index shouldn't be applied. + assert(basePlan.equals(filter.queryExecution.optimizedPlan)) + } + } + } + } +} diff --git a/src/test/scala/com/microsoft/hyperspace/index/RefreshIndexNestedTest.scala b/src/test/scala/com/microsoft/hyperspace/index/RefreshIndexNestedTest.scala new file mode 100644 index 000000000..64ecd1447 --- /dev/null +++ b/src/test/scala/com/microsoft/hyperspace/index/RefreshIndexNestedTest.scala @@ -0,0 +1,498 @@ +/* + * Copyright (2020) The Hyperspace Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.microsoft.hyperspace.index + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.spark.sql.{AnalysisException, QueryTest} + +import com.microsoft.hyperspace.{Hyperspace, HyperspaceException, MockEventLogger, SampleNestedData, TestUtils} +import com.microsoft.hyperspace.TestUtils.{getFileIdTracker, logManager} +import com.microsoft.hyperspace.actions.{RefreshIncrementalAction, RefreshQuickAction} +import com.microsoft.hyperspace.index.IndexConstants.REFRESH_MODE_INCREMENTAL +import com.microsoft.hyperspace.telemetry.RefreshIncrementalActionEvent +import com.microsoft.hyperspace.util.{FileUtils, PathUtils} +import com.microsoft.hyperspace.util.PathUtils.DataPathFilter + +/** + * Unit E2E test cases for RefreshIndex. + */ +class RefreshIndexNestedTest extends QueryTest with HyperspaceSuite { + override val systemPath = new Path("src/test/resources/indexLocation") + private val testDir = "src/test/resources/RefreshIndexDeleteTests/" + private val nonPartitionedDataPath = testDir + "nonpartitioned" + private val partitionedDataPath = testDir + "partitioned" + private val indexConfig = IndexConfig("index1", Seq("nested.leaf.id"), Seq("nested.leaf.cnt")) + private var hyperspace: Hyperspace = _ + + override def beforeAll(): Unit = { + super.beforeAll() + hyperspace = new Hyperspace(spark) + FileUtils.delete(new Path(testDir)) + } + + override def afterAll(): Unit = { + FileUtils.delete(new Path(testDir)) + super.afterAll() + } + + after { + FileUtils.delete(new Path(testDir)) + FileUtils.delete(systemPath) + } + + test("Validate incremental refresh index when some file gets deleted from the source data.") { + // Save test data non-partitioned. 
+ SampleNestedData.save( + spark, + nonPartitionedDataPath, + Seq("Date", "RGUID", "Query", "imprs", "clicks", "nested")) + val nonPartitionedDataDF = spark.read.parquet(nonPartitionedDataPath) + + // Save test data partitioned. + SampleNestedData.save( + spark, + partitionedDataPath, + Seq("Date", "RGUID", "Query", "imprs", "clicks", "nested"), + Some(Seq("Date", "Query"))) + val partitionedDataDF = spark.read.parquet(partitionedDataPath) + + Seq(nonPartitionedDataPath, partitionedDataPath).foreach { loc => + withSQLConf(IndexConstants.INDEX_LINEAGE_ENABLED -> "true") { + withIndex(indexConfig.indexName) { + val dfToIndex = + if (loc.equals(nonPartitionedDataPath)) nonPartitionedDataDF else partitionedDataDF + hyperspace.createIndex(dfToIndex, indexConfig) + + // Delete one source data file. + val deletedFile = if (loc.equals(nonPartitionedDataPath)) { + deleteOneDataFile(nonPartitionedDataPath) + } else { + deleteOneDataFile(partitionedDataPath, true) + } + + // Get deleted file's file id, used as lineage for its records. + val fileId = getFileIdTracker(systemPath, indexConfig).getFileId( + deletedFile.getPath.toString, + deletedFile.getLen, + deletedFile.getModificationTime) + assert(fileId.nonEmpty) + + // Validate only index records whose lineage is the deleted file are removed. + val originalIndexDF = spark.read.parquet(s"$systemPath/${indexConfig.indexName}/" + + s"${IndexConstants.INDEX_VERSION_DIRECTORY_PREFIX}=0") + val originalIndexWithoutDeletedFile = originalIndexDF + .filter(s"""${IndexConstants.DATA_FILE_NAME_ID} != ${fileId.get}""") + + hyperspace.refreshIndex(indexConfig.indexName, REFRESH_MODE_INCREMENTAL) + + val refreshedIndexDF = spark.read.parquet(s"$systemPath/${indexConfig.indexName}/" + + s"${IndexConstants.INDEX_VERSION_DIRECTORY_PREFIX}=1") + + checkAnswer(originalIndexWithoutDeletedFile, refreshedIndexDF) + } + } + } + } + + test( + "Validate incremental refresh index (to handle deletes from the source data) " + + "fails as expected on an index without lineage.") { + SampleNestedData.save( + spark, + nonPartitionedDataPath, + Seq("Date", "RGUID", "Query", "imprs", "clicks", "nested")) + val nonPartitionedDataDF = spark.read.parquet(nonPartitionedDataPath) + + withSQLConf(IndexConstants.INDEX_LINEAGE_ENABLED -> "false") { + hyperspace.createIndex(nonPartitionedDataDF, indexConfig) + + deleteOneDataFile(nonPartitionedDataPath) + + val ex = intercept[HyperspaceException]( + hyperspace.refreshIndex(indexConfig.indexName, REFRESH_MODE_INCREMENTAL)) + assert( + ex.getMessage.contains(s"Index refresh (to handle deleted source data) is " + + "only supported on an index with lineage.")) + } + } + + test( + "Validate incremental refresh index is a no-op if no source data file is deleted or " + + "appended.") { + SampleNestedData.save( + spark, + nonPartitionedDataPath, + Seq("Date", "RGUID", "Query", "imprs", "clicks", "nested")) + val nonPartitionedDataDF = spark.read.parquet(nonPartitionedDataPath) + + withSQLConf(IndexConstants.INDEX_LINEAGE_ENABLED -> "true") { + hyperspace.createIndex(nonPartitionedDataDF, indexConfig) + + val latestId = logManager(systemPath, indexConfig.indexName).getLatestId().get + + MockEventLogger.reset() + hyperspace.refreshIndex(indexConfig.indexName, REFRESH_MODE_INCREMENTAL) + // Check that no new log files were created in this operation. + assert(latestId == logManager(systemPath, indexConfig.indexName).getLatestId().get) + + // Check emitted events. 
+ MockEventLogger.emittedEvents match { + case Seq( + RefreshIncrementalActionEvent(_, _, "Operation started."), + RefreshIncrementalActionEvent(_, _, msg)) => + assert(msg.contains("Refresh incremental aborted as no source data change found.")) + case _ => fail() + } + } + } + + test( + "Validate incremental refresh index (to handle deletes from the source data) " + + "fails as expected when all source data files are deleted.") { + Seq(true, false).foreach { deleteDataFolder => + withSQLConf(IndexConstants.INDEX_LINEAGE_ENABLED -> "true") { + SampleNestedData.save( + spark, + nonPartitionedDataPath, + Seq("Date", "RGUID", "Query", "imprs", "clicks", "nested")) + val nonPartitionedDataDF = spark.read.parquet(nonPartitionedDataPath) + + hyperspace.createIndex(nonPartitionedDataDF, indexConfig) + + if (deleteDataFolder) { + FileUtils.delete(new Path(nonPartitionedDataPath)) + + val ex = intercept[AnalysisException]( + hyperspace.refreshIndex(indexConfig.indexName, REFRESH_MODE_INCREMENTAL)) + assert(ex.getMessage.contains("Path does not exist")) + + } else { + val dataPath = new Path(nonPartitionedDataPath, "*parquet") + dataPath + .getFileSystem(new Configuration) + .globStatus(dataPath) + .foreach(p => FileUtils.delete(p.getPath)) + + val ex = + intercept[HyperspaceException]( + hyperspace.refreshIndex(indexConfig.indexName, REFRESH_MODE_INCREMENTAL)) + assert(ex.getMessage.contains("Invalid plan for creating an index.")) + } + FileUtils.delete(new Path(nonPartitionedDataPath)) + FileUtils.delete(systemPath) + } + } + } + + test( + "Validate incremental refresh index (to handle deletes from the source data) " + + "works well when file info for an existing source data file changes.") { + SampleNestedData.save( + spark, + nonPartitionedDataPath, + Seq("Date", "RGUID", "Query", "imprs", "clicks", "nested")) + val nonPartitionedDataDF = spark.read.parquet(nonPartitionedDataPath) + + withSQLConf(IndexConstants.INDEX_LINEAGE_ENABLED -> "true") { + hyperspace.createIndex(nonPartitionedDataDF, indexConfig) + } + + // Replace a source data file with a new file with same name but different properties. + val deletedFile = deleteOneDataFile(nonPartitionedDataPath) + val sourcePath = new Path(spark.read.parquet(nonPartitionedDataPath).inputFiles.head) + val fs = deletedFile.getPath.getFileSystem(new Configuration) + fs.copyToLocalFile(sourcePath, deletedFile.getPath) + + { + // Check the index log entry before refresh. + val indexLogEntry = getLatestStableLog(indexConfig.indexName) + assert(logManager(systemPath, indexConfig.indexName).getLatestId().get == 1) + assert(getIndexFilesCount(indexLogEntry, version = 0) == indexLogEntry.content.files.size) + } + + val indexPath = PathUtils.makeAbsolute(s"$systemPath/${indexConfig.indexName}") + new RefreshIncrementalAction( + spark, + IndexLogManagerFactoryImpl.create(indexPath), + IndexDataManagerFactoryImpl.create(indexPath)) + .run() + + { + // Check the index log entry after RefreshIncrementalAction. + val indexLogEntry = getLatestStableLog(indexConfig.indexName) + assert(logManager(systemPath, indexConfig.indexName).getLatestId().get == 3) + assert(indexLogEntry.deletedFiles.isEmpty) + assert(indexLogEntry.appendedFiles.isEmpty) + + val files = indexLogEntry.relations.head.data.properties.content.files + assert(files.exists(_.equals(deletedFile.getPath))) + assert( + getIndexFilesCount(indexLogEntry, version = 1) == indexLogEntry.content.fileInfos.size) + } + + // Modify the file again. 
+ val sourcePath2 = new Path(spark.read.parquet(nonPartitionedDataPath).inputFiles.last) + fs.copyToLocalFile(sourcePath2, deletedFile.getPath) + + new RefreshIncrementalAction( + spark, + IndexLogManagerFactoryImpl.create(indexPath), + IndexDataManagerFactoryImpl.create(indexPath)) + .run() + + { + // Check non-empty deletedFiles after RefreshIncrementalAction. + val indexLogEntry = getLatestStableLog(indexConfig.indexName) + assert(indexLogEntry.deletedFiles.isEmpty) + assert(indexLogEntry.appendedFiles.isEmpty) + assert(logManager(systemPath, indexConfig.indexName).getLatestId().get == 5) + val files = indexLogEntry.relations.head.data.properties.content.files + assert(files.exists(_.equals(deletedFile.getPath))) + assert( + getIndexFilesCount(indexLogEntry, version = 2) == indexLogEntry.content.fileInfos.size) + } + } + + test( + "Validate RefreshIncrementalAction updates appended and deleted files in metadata " + + "as expected, when some file gets deleted and some appended to source data.") { + withTempPathAsString { testPath => + withSQLConf(IndexConstants.INDEX_LINEAGE_ENABLED -> "true") { + withIndex(indexConfig.indexName) { + SampleNestedData.save(spark, testPath, + Seq("Date", "RGUID", "Query", "imprs", "clicks", "nested")) + val df = spark.read.parquet(testPath) + hyperspace.createIndex(df, indexConfig) + + val oldFiles = listFiles(testPath, getFileIdTracker(systemPath, indexConfig)).toSet + + // Delete one source data file. + deleteOneDataFile(testPath) + + // Add some new data to source. + import spark.implicits._ + SampleNestedData.testData + .take(3) + .toDF("Date", "RGUID", "Query", "imprs", "clicks", "nested") + .write + .mode("append") + .parquet(testPath) + + val indexPath = PathUtils.makeAbsolute(s"$systemPath/${indexConfig.indexName}") + new RefreshIncrementalAction( + spark, + IndexLogManagerFactoryImpl.create(indexPath), + IndexDataManagerFactoryImpl.create(indexPath)) + .run() + + // Verify "appendedFiles" is cleared and "deletedFiles" is updated after refresh. + val indexLogEntry = getLatestStableLog(indexConfig.indexName) + assert(indexLogEntry.appendedFiles.isEmpty) + + val latestFiles = listFiles(testPath, getFileIdTracker(systemPath, indexConfig)).toSet + val indexSourceFiles = indexLogEntry.relations.head.data.properties.content.fileInfos + val expectedDeletedFiles = oldFiles -- latestFiles + val expectedAppendedFiles = latestFiles -- oldFiles + assert(expectedDeletedFiles.forall(f => !indexSourceFiles.contains(f))) + assert(expectedAppendedFiles.forall(indexSourceFiles.contains)) + assert(indexSourceFiles.forall(f => + expectedAppendedFiles.contains(f) || oldFiles.contains(f))) + } + } + } + } + + test( + "Validate incremental refresh index when some file gets deleted and some appended to " + + "source data.") { + withTempPathAsString { testPath => + withSQLConf(IndexConstants.INDEX_LINEAGE_ENABLED -> "true") { + withIndex(indexConfig.indexName) { + // Save test data non-partitioned. + SampleNestedData.save(spark, testPath, + Seq("Date", "RGUID", "Query", "imprs", "clicks", "nested")) + val df = spark.read.parquet(testPath) + hyperspace.createIndex(df, indexConfig) + val countOriginal = df.count() + + // Delete one source data file. + deleteOneDataFile(testPath) + val countAfterDelete = spark.read.parquet(testPath).count() + assert(countAfterDelete < countOriginal) + + // Add some new data to source. 
+ import spark.implicits._ + SampleNestedData.testData + .take(3) + .toDF("Date", "RGUID", "Query", "imprs", "clicks", "nested") + .write + .mode("append") + .parquet(testPath) + + val countAfterAppend = spark.read.parquet(testPath).count() + assert(countAfterDelete + 3 == countAfterAppend) + + hyperspace.refreshIndex(indexConfig.indexName, REFRESH_MODE_INCREMENTAL) + + // Check if refreshed index is updated appropriately. + val indexDf = spark.read + .parquet(s"$systemPath/${indexConfig.indexName}/" + + s"${IndexConstants.INDEX_VERSION_DIRECTORY_PREFIX}=1") + + assert(indexDf.count() == countAfterAppend) + } + } + } + } + + test( + "Validate the configs for incremental index data is consistent with" + + "the previous version.") { + withTempPathAsString { testPath => + SampleNestedData.save(spark, testPath, + Seq("Date", "RGUID", "Query", "imprs", "clicks", "nested")) + val df = spark.read.parquet(testPath) + + withSQLConf( + IndexConstants.INDEX_LINEAGE_ENABLED -> "false", + IndexConstants.INDEX_NUM_BUCKETS -> "20") { + hyperspace.createIndex(df, indexConfig) + } + + // Add some new data to source. + import spark.implicits._ + SampleNestedData.testData + .take(3) + .toDF("Date", "RGUID", "Query", "imprs", "clicks", "nested") + .write + .mode("append") + .parquet(testPath) + + withSQLConf( + IndexConstants.INDEX_LINEAGE_ENABLED -> "true", + IndexConstants.INDEX_NUM_BUCKETS -> "10") { + hyperspace.refreshIndex(indexConfig.indexName, REFRESH_MODE_INCREMENTAL) + } + + val indexLogEntry = getLatestStableLog(indexConfig.indexName) + assert(!indexLogEntry.hasLineageColumn) + assert(indexLogEntry.numBuckets === 20) + } + } + + test( + "Validate RefreshQuickAction updates appended and deleted files in metadata " + + "as expected, when some file gets deleted and some appended to source data.") { + withTempPathAsString { testPath => + withSQLConf(IndexConstants.INDEX_LINEAGE_ENABLED -> "true") { + withIndex(indexConfig.indexName) { + SampleNestedData.save(spark, testPath, + Seq("Date", "RGUID", "Query", "imprs", "clicks", "nested")) + val df = spark.read.parquet(testPath) + hyperspace.createIndex(df, indexConfig) + + val oldFiles = listFiles(testPath, getFileIdTracker(systemPath, indexConfig)).toSet + + // Delete one source data file. + deleteOneDataFile(testPath) + + // Add some new data to source. + import spark.implicits._ + SampleNestedData.testData + .take(3) + .toDF("Date", "RGUID", "Query", "imprs", "clicks", "nested") + .write + .mode("append") + .parquet(testPath) + + val prevIndexLogEntry = getLatestStableLog(indexConfig.indexName) + + val indexPath = PathUtils.makeAbsolute(s"$systemPath/${indexConfig.indexName}") + new RefreshQuickAction( + spark, + IndexLogManagerFactoryImpl.create(indexPath), + IndexDataManagerFactoryImpl.create(indexPath)) + .run() + + val indexLogEntry = getLatestStableLog(indexConfig.indexName) + val latestFiles = listFiles(testPath, getFileIdTracker(systemPath, indexConfig)).toSet + val expectedDeletedFiles = oldFiles -- latestFiles + val expectedAppendedFiles = latestFiles -- oldFiles + + val signatureProvider = LogicalPlanSignatureProvider.create() + val latestDf = spark.read.parquet(testPath) + val expectedLatestSignature = + signatureProvider.signature(latestDf.queryExecution.optimizedPlan).get + + // Check `Update` is collected properly. 
+ assert(indexLogEntry.sourceUpdate.isDefined) + assert( + indexLogEntry.source.plan.properties.fingerprint.properties.signatures.head.value + == expectedLatestSignature) + assert(indexLogEntry.appendedFiles === expectedAppendedFiles) + assert(indexLogEntry.deletedFiles === expectedDeletedFiles) + + // Check index data files and source data files are not updated. + assert( + indexLogEntry.relations.head.data.properties.content.fileInfos + === prevIndexLogEntry.relations.head.data.properties.content.fileInfos) + assert(indexLogEntry.content.fileInfos === prevIndexLogEntry.content.fileInfos) + } + } + } + } + + /** + * Delete one file from a given path. + * + * @param path Path to the parent folder containing data files. + * @param isPartitioned Is data folder partitioned or not. + * @return Deleted file's FileStatus. + */ + private def deleteOneDataFile(path: String, isPartitioned: Boolean = false): FileStatus = { + val dataPath = if (isPartitioned) s"$path/*/*" else path + TestUtils.deleteFiles(dataPath, "*parquet", 1).head + } + + private def listFiles(path: String, fileIdTracker: FileIdTracker): Seq[FileInfo] = { + val absolutePath = PathUtils.makeAbsolute(path) + val fs = absolutePath.getFileSystem(new Configuration) + fs.listStatus(absolutePath) + .toSeq + .filter(f => DataPathFilter.accept(f.getPath)) + .map(f => FileInfo(f, fileIdTracker.addFile(f), asFullPath = true)) + } + + private def getLatestStableLog(indexName: String): IndexLogEntry = { + val entry = logManager(systemPath, indexName).getLatestStableLog() + assert(entry.isDefined) + assert(entry.get.isInstanceOf[IndexLogEntry]) + entry.get.asInstanceOf[IndexLogEntry] + } + + private def getIndexFilesCount( + entry: IndexLogEntry, + version: Int, + allowEmpty: Boolean = false) = { + val cnt = entry.content.fileInfos + .count(_.name.contains(s"${IndexConstants.INDEX_VERSION_DIRECTORY_PREFIX}=$version")) + assert(allowEmpty || cnt > 0) + cnt + } + +} diff --git a/src/test/scala/com/microsoft/hyperspace/util/SchemaUtilsTest.scala b/src/test/scala/com/microsoft/hyperspace/util/SchemaUtilsTest.scala new file mode 100644 index 000000000..ad56c9d52 --- /dev/null +++ b/src/test/scala/com/microsoft/hyperspace/util/SchemaUtilsTest.scala @@ -0,0 +1,203 @@ +/* + * Copyright (2020) The Hyperspace Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.microsoft.hyperspace.util + +import org.apache.spark.SparkFunSuite + +import com.microsoft.hyperspace.SparkInvolvedSuite + +class SchemaUtilsTest extends SparkFunSuite with SparkInvolvedSuite { + + test("flatten - no nesting") { + import spark.implicits._ + + val dfNoNesting = Seq( + (1, "name1", "b1"), + (2, "name2", "b2"), + (3, "name3", "b3"), + (4, "name4", "b4") + ).toDF("id", "name", "other") + + val flattenedNoNesting = SchemaUtils.flatten(dfNoNesting.schema) + + assert(flattenedNoNesting.length == 3) + assert(flattenedNoNesting(0) == "id") + assert(flattenedNoNesting(1) == "name") + assert(flattenedNoNesting(2) == "other") + } + + test("flatten - struct") { + import spark.implicits._ + + val df1 = Seq( + (1, "name1", NestedType4("nf1", NestedType("n1", 1L))), + (2, "name2", NestedType4("nf2", NestedType("n2", 2L))), + (3, "name3", NestedType4("nf3", NestedType("n3", 3L))), + (4, "name4", NestedType4("nf4", NestedType("n4", 4L))) + ).toDF("id", "name", "nested") + + val flattened = SchemaUtils.flatten(df1.schema) + + assert(flattened.length == 5) + assert(flattened(0) == "id") + assert(flattened(1) == "name") + assert(flattened(2) == "nested.nf1_b") + assert(flattened(3) == "nested.n.f1") + assert(flattened(4) == "nested.n.f2") + + /** + * Given the dataset with schema below + * + * root + * |-- id: integer (nullable = false) + * |-- name: string (nullable = true) + * |-- nested: struct (nullable = true) + * | |-- nf1: string (nullable = true) + * | |-- n: struct (nullable = true) + * | | |-- nf_a: string (nullable = true) + * | | |-- n: struct (nullable = true) + * | | | |-- nf1_b: string (nullable = true) + * | | | |-- n: struct (nullable = true) + * | | | | |-- f1: string (nullable = true) + * | | | | |-- f2: long (nullable = false) + * + * The output should be the list of leaves maintaining the order + * + * id + * name + * nested.nf1 + * nested.n.nfa + * nested.n.n.nf1_b + * nested.n.n.n.f1 + * nested.n.n.n.f2 + */ + val df2 = Seq( + (1, "name1", NestedType2("nf1", NestedType3("n1", NestedType4("h1", + NestedType("end1", 1L))))), + (2, "name2", NestedType2("nf2", NestedType3("n2", NestedType4("h2", + NestedType("end2", 1L))))), + (3, "name3", NestedType2("nf3", NestedType3("n3", NestedType4("h3", + NestedType("end3", 1L))))), + (4, "name4", NestedType2("nf4", NestedType3("n4", NestedType4("h4", + NestedType("end4", 1L))))) + ).toDF("id", "name", "nested") + + val flattened2 = SchemaUtils.flatten(df2.schema) + + assert(flattened2.length == 7) + assert(flattened2(0) == "id") + assert(flattened2(1) == "name") + assert(flattened2(2) == "nested.nf1") + assert(flattened2(3) == "nested.n.nf_a") + assert(flattened2(4) == "nested.n.n.nf1_b") + assert(flattened2(5) == "nested.n.n.n.f1") + assert(flattened2(6) == "nested.n.n.n.f2") + } + + test("flatten - array") { + import spark.implicits._ + + val df1 = Seq( + (1, "name1", Array[NestedType](NestedType("n1", 1L), NestedType("o1", 10L))), + (2, "name2", Array[NestedType](NestedType("n2", 2L), NestedType("o2", 20L))), + (3, "name3", Array[NestedType](NestedType("n3", 3L), NestedType("o3", 30L))), + (4, "name4", Array[NestedType](NestedType("n4", 4L), NestedType("o4", 40L))) + ).toDF("id", "name", "arrayOfNested") + + val flattened = SchemaUtils.flatten(df1.schema) + + assert(flattened.length == 4) + assert(flattened(0) == "id") + assert(flattened(1) == "name") + assert(flattened(2) == "arrayOfNested.f1") + assert(flattened(3) == "arrayOfNested.f2") + + /** + * Given the dataset with schema below + * + * root + * |-- id: 
integer (nullable = false) + * |-- name: string (nullable = true) + * |-- arrayOfNested: array (nullable = true) + * | |-- element: struct (containsNull = true) + * | | |-- nf1_b: string (nullable = true) + * | | |-- n: struct (nullable = true) + * | | | |-- f1: string (nullable = true) + * | | | |-- f2: long (nullable = false) + * + * The output should be the list of leaves maintaining the order + * + * id + * name + * arrayOfNested.nf1_b + * arrayOfNested.n.f1 + * arrayOfNested.n.f2 + */ + val df2 = Seq( + (1, "name1", Array[NestedType4]( + NestedType4("n1", NestedType("o1", 11L)), + NestedType4("a1", NestedType("b1", 1L)))), + (2, "name2", Array[NestedType4]( + NestedType4("n2", NestedType("o2", 12L)), + NestedType4("a2", NestedType("b2", 2L)))), + (3, "name3", Array[NestedType4]( + NestedType4("n3", NestedType("o3", 13L)), + NestedType4("a3", NestedType("b3", 3L)))), + (4, "name4", Array[NestedType4]( + NestedType4("n4", NestedType("o4", 14L)), + NestedType4("a4", NestedType("b4", 4L)))) + ).toDF("id", "name", "arrayOfNested") + + val flattened2 = SchemaUtils.flatten(df2.schema) + + assert(flattened2.length == 5) + assert(flattened2(0) == "id") + assert(flattened2(1) == "name") + assert(flattened2(2) == "arrayOfNested.nf1_b") + assert(flattened2(3) == "arrayOfNested.n.f1") + assert(flattened2(4) == "arrayOfNested.n.f2") + } + + test("escapeFieldName") { + assert(SchemaUtils.escapeFieldName("a.b") == "a__b") + assert(SchemaUtils.escapeFieldName("a.b.c.d") == "a__b__c__d") + assert(SchemaUtils.escapeFieldName("a_b_c_d") == "a_b_c_d") + } + + test("escapeFieldNames") { + assert(SchemaUtils.escapeFieldNames( + Seq("a.b.c.d", "a.b", "A_B")) == Seq("a__b__c__d", "a__b", "A_B")) + assert(SchemaUtils.escapeFieldNames(Seq.empty[String]).isEmpty) + } + + test("unescapeFieldName") { + assert(SchemaUtils.unescapeFieldName("a__b") == "a.b") + assert(SchemaUtils.unescapeFieldName("a__b__c__d") == "a.b.c.d") + assert(SchemaUtils.unescapeFieldName("a_b_c_d") == "a_b_c_d") + } + + test("unescapeFieldNames") { + assert(SchemaUtils.unescapeFieldNames( + Seq("a__b__c__d", "a__b", "A_B")) == Seq("a.b.c.d", "a.b", "A_B")) + assert(SchemaUtils.escapeFieldNames(Seq.empty[String]).isEmpty) + } +} + +case class NestedType4(nf1_b: String, n: NestedType) +case class NestedType3(nf_a: String, n: NestedType4) +case class NestedType2(nf1: String, n: NestedType3) +case class NestedType(f1: String, f2: Long) From ef2b45e66347ecc50e4c426d436a3912e03fdeca Mon Sep 17 00:00:00 2001 From: Andrei Ionescu Date: Tue, 2 Mar 2021 17:07:07 +0200 Subject: [PATCH 2/3] Add support for nested fields in joins --- .../index/execution/BucketUnionExec.scala | 25 +- .../index/rules/FilterIndexRule.scala | 59 +- .../index/rules/JoinIndexRule.scala | 186 ++++++- .../hyperspace/index/rules/RuleUtils.scala | 100 +++- .../hyperspace/util/SchemaUtils.scala | 12 + .../index/HybridScanForNestedFieldsTest.scala | 508 ++++++++++++++++-- 6 files changed, 762 insertions(+), 128 deletions(-) diff --git a/src/main/scala/com/microsoft/hyperspace/index/execution/BucketUnionExec.scala b/src/main/scala/com/microsoft/hyperspace/index/execution/BucketUnionExec.scala index c13a1e6be..7f11d34c0 100644 --- a/src/main/scala/com/microsoft/hyperspace/index/execution/BucketUnionExec.scala +++ b/src/main/scala/com/microsoft/hyperspace/index/execution/BucketUnionExec.scala @@ -110,12 +110,23 @@ private[hyperspace] case class BucketUnionExec(children: Seq[SparkPlan], bucketS override def output: Seq[Attribute] = children.head.output override def outputPartitioning: 
Partitioning = { - assert(children.map(_.outputPartitioning).toSet.size == 1) - assert(children.head.outputPartitioning.isInstanceOf[HashPartitioning]) - assert( - children.head.outputPartitioning - .asInstanceOf[HashPartitioning] - .numPartitions == bucketSpec.numBuckets) - children.head.outputPartitioning + val parts = children.map(_.outputPartitioning).distinct + assert(parts.forall(_.isInstanceOf[HashPartitioning])) + assert(parts.forall(_.numPartitions == bucketSpec.numBuckets)) + + val reduced = parts.reduceLeft { (a, b) => + val h1 = a.asInstanceOf[HashPartitioning] + val h2 = b.asInstanceOf[HashPartitioning] + val h1Name = h1.references.head.name + val h2Name = h2.references.head.name + val same = h1Name.contains(h2Name) || h2Name.contains(h1Name) + assert(same) + if (h1Name.length > h2Name.length) { + h1 + } else { + h2 + } + } + reduced } } diff --git a/src/main/scala/com/microsoft/hyperspace/index/rules/FilterIndexRule.scala b/src/main/scala/com/microsoft/hyperspace/index/rules/FilterIndexRule.scala index f1801e66d..3f28bd51c 100644 --- a/src/main/scala/com/microsoft/hyperspace/index/rules/FilterIndexRule.scala +++ b/src/main/scala/com/microsoft/hyperspace/index/rules/FilterIndexRule.scala @@ -20,7 +20,7 @@ import scala.util.Try import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.analysis.CleanupAliases -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, GetStructField} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, GetStructField} import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types.{DataType, StructType} @@ -200,12 +200,12 @@ object ExtractFilterNode { case AttributeReference(name, _, _, _) => Set(s"$name") case otherExp => - otherExp.containsChild.map { + otherExp.containsChild.flatMap { case g: GetStructField => - s"${getChildNameFromStruct(g)}" + Set(s"${getChildNameFromStruct(g)}") case e: Expression => - extractNamesFromExpression(e).filter(_.nonEmpty).mkString(".") - case _ => "" + extractNamesFromExpression(e).filter(_.nonEmpty) + case _ => Set.empty[String] } } } @@ -221,20 +221,6 @@ object ExtractFilterNode { } } - def extractSearchQuery(exp: Expression, name: String): (Expression, Expression) = { - val splits = name.split(SchemaUtils.NESTED_FIELD_NEEDLE_REGEX) - val expFound = exp.find { - case a: AttributeReference if splits.forall(s => a.name.contains(s)) => true - case f: GetStructField if splits.forall(s => f.toString().contains(s)) => true - case _ => false - }.get - val parent = exp.find { - case e: Expression if e.containsChild.contains(expFound) => true - case _ => false - }.get - (parent, expFound) - } - def replaceInSearchQuery( parent: Expression, needle: Expression, @@ -260,17 +246,26 @@ object ExtractFilterNode { def extractTypeFromExpression(exp: Expression, name: String): DataType = { val splits = name.split(SchemaUtils.NESTED_FIELD_NEEDLE_REGEX) val elem = exp.flatMap { - case a: AttributeReference => - if (splits.forall(s => a.name == s)) { - Some((name, a.dataType)) + case attrRef: AttributeReference => + if (splits.forall(s => attrRef.name == s)) { + Some((name, attrRef.dataType)) } else { Try({ val h :: t = splits.toList - if (a.name == h && a.dataType.isInstanceOf[StructType]) { - val currentDataType = a.dataType.asInstanceOf[StructType] + if (attrRef.name == h && attrRef.dataType.isInstanceOf[StructType]) { + val 
currentDataType = attrRef.dataType.asInstanceOf[StructType] + var localDT = currentDataType val foldedFields = t.foldLeft(Seq.empty[(String, DataType)]) { (acc, i) => - val idx = currentDataType.indexWhere(_.name.equalsIgnoreCase(i)) - acc :+ (i, currentDataType(idx).dataType) + val collected = localDT.collect { + case dt if dt.name == i => + dt.dataType match { + case st: StructType => + localDT = st + case _ => + } + (i, dt.dataType) + } + acc ++ collected } Some(foldedFields.last) } else { @@ -284,6 +279,18 @@ object ExtractFilterNode { } elem.find(e => e._1 == name || e._1 == splits.last).get._2 } + + def collectAliases(plan: LogicalPlan): Seq[(String, Attribute, Expression)] = { + plan + .collect { + case Project(projectList, _) => + projectList.collect { + case a @ Alias(child, name) => + (name, a.toAttribute, child) + } + } + .flatten + } } object ExtractRelation extends ActiveSparkSession { diff --git a/src/main/scala/com/microsoft/hyperspace/index/rules/JoinIndexRule.scala b/src/main/scala/com/microsoft/hyperspace/index/rules/JoinIndexRule.scala index ce7dc3e7f..aaf7d1766 100644 --- a/src/main/scala/com/microsoft/hyperspace/index/rules/JoinIndexRule.scala +++ b/src/main/scala/com/microsoft/hyperspace/index/rules/JoinIndexRule.scala @@ -21,8 +21,8 @@ import scala.util.Try import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.analysis.CleanupAliases -import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression} -import org.apache.spark.sql.catalyst.plans.logical.{Join, LeafNode, LogicalPlan} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, AttributeReference, EqualTo, Expression, GetStructField, NamedExpression} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LeafNode, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule import com.microsoft.hyperspace.{ActiveSparkSession, Hyperspace} @@ -32,6 +32,7 @@ import com.microsoft.hyperspace.index.rankers.JoinIndexRanker import com.microsoft.hyperspace.index.sources.FileBasedRelation import com.microsoft.hyperspace.telemetry.{AppInfo, HyperspaceEventLogging, HyperspaceIndexUsageEvent} import com.microsoft.hyperspace.util.ResolverUtils._ +import com.microsoft.hyperspace.util.SchemaUtils /** * Rule to optimize a join between two indexed dataframes. @@ -108,16 +109,19 @@ object JoinIndexRule private def isApplicable(l: LogicalPlan, r: LogicalPlan, condition: Expression): Boolean = { // The given plan is eligible if it is supported and index has not been applied. def isEligible(optRel: Option[FileBasedRelation]): Boolean = { - optRel.map(!RuleUtils.isIndexApplied(_)).getOrElse(false) + optRel.exists(!RuleUtils.isIndexApplied(_)) } lazy val optLeftRel = RuleUtils.getRelation(spark, l) lazy val optRightRel = RuleUtils.getRelation(spark, r) + val lProj = collectProjections(l) + val rProj = collectProjections(r) + isJoinConditionSupported(condition) && isPlanLinear(l) && isPlanLinear(r) && isEligible(optLeftRel) && isEligible(optRightRel) && - ensureAttributeRequirements(optLeftRel.get, optRightRel.get, condition) + ensureAttributeRequirements(optLeftRel.get, optRightRel.get, lProj, rProj, condition) } /** @@ -226,6 +230,8 @@ object JoinIndexRule * * @param l left relation * @param r right relation + * @param lp left projections + * @param rp right projections * @param condition join condition * @return true if all attributes in join condition are from base relation nodes. 
False * otherwise @@ -233,15 +239,36 @@ object JoinIndexRule private def ensureAttributeRequirements( l: FileBasedRelation, r: FileBasedRelation, + lp: Seq[NamedExpression], + rp: Seq[NamedExpression], condition: Expression): Boolean = { + // Output attributes from base relations. Join condition attributes must belong to these // attributes. We work on canonicalized forms to make sure we support case-sensitivity. val lBaseAttrs = l.plan.output.map(_.canonicalized) val rBaseAttrs = r.plan.output.map(_.canonicalized) - def fromDifferentBaseRelations(c1: Expression, c2: Expression): Boolean = { - (lBaseAttrs.contains(c1) && rBaseAttrs.contains(c2)) || - (lBaseAttrs.contains(c2) && rBaseAttrs.contains(c1)) + def fromDifferentBaseRelations( + c1: Expression, + c2: Expression, + p1: Seq[NamedExpression], + p2: Seq[NamedExpression]): Boolean = { + val cr1 = if (p1.nonEmpty) { + Try { + extractFieldFromProjection(c1, p1).get.references.head.canonicalized + }.getOrElse(c1) + } else { + c1 + } + val cr2 = if (p2.nonEmpty) { + Try { + extractFieldFromProjection(c2, p2).get.references.head.canonicalized + }.getOrElse(c2) + } else { + c2 + } + (lBaseAttrs.contains(cr1) && rBaseAttrs.contains(cr2)) || + (lBaseAttrs.contains(cr2) && rBaseAttrs.contains(cr1)) } // Map to maintain and check one-to-one relation between join condition attributes. For join @@ -254,7 +281,7 @@ object JoinIndexRule case EqualTo(e1, e2) => val (c1, c2) = (e1.canonicalized, e2.canonicalized) // Check 1: c1 and c2 should belong to l and r respectively, or r and l respectively. - if (!fromDifferentBaseRelations(c1, c2)) { + if (!fromDifferentBaseRelations(c1, c2, lp, rp)) { return false } // Check 2: c1 is compared only against c2 and vice versa. @@ -271,6 +298,47 @@ object JoinIndexRule } } + /** + * The method extracts all the projection fields. + * + * @param plan The plan from which to extract projections + * @return A sequence of [[NamedExpression]] + */ + private def collectProjections(plan: LogicalPlan): Seq[NamedExpression] = { + plan.collect { + case p: Project => p.projectList + }.flatten + } + + /** + * The method tries to map any top level condition field to the fields present in relation. + * It does this by going through projections. + * + * @param projections The available projection expressions + * @return Some of the found expression when the condition field is found otherwise None + */ + private def conditionFieldsToRelationFields( + projections: Seq[NamedExpression]): Map[Expression, Expression] = { + projections.collect { + case a: Alias => + (a.toAttribute.canonicalized, a.child) + }.toMap + } + + /** + * The method tries to return a field out of the fields present in relation. + * + * @param conditionField The field to map to + * @param projections The available projection expressions + * @return Some of the found expression when the condition field is found otherwise None + */ + private def extractFieldFromProjection( + conditionField: Expression, + projections: Seq[NamedExpression]): Option[Expression] = { + val fields = conditionFieldsToRelationFields(projections) + Try(fields(conditionField.canonicalized)).toOption + } + /** * Get best ranked index pair from available indexes of both sides. * @@ -295,12 +363,14 @@ object JoinIndexRule // been already checked in `isApplicable`. 
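    // Aside: a minimal, hypothetical sketch of the expression shapes the projection helpers
    // above are written for (illustrative only, not executed by this rule). Selecting a nested
    // leaf puts an Alias over a GetStructField chain into the Project node, and the join
    // condition then references the alias attribute rather than a base relation column:
    //
    //   val nestedType = StructType(Seq(
    //     StructField("leaf", StructType(Seq(StructField("cnt", LongType))))))
    //   val nestedAttr = AttributeReference("nested", nestedType)()          // base relation output
    //   val leaf = GetStructField(nestedAttr, 0, Some("leaf"))
    //   val cntAlias = Alias(GetStructField(leaf, 0, Some("cnt")), "cnt")()  // exposed by Project
    //
    // The condition sees cntAlias.toAttribute; extractFieldFromProjection maps it back to the
    // GetStructField, whose references resolve to nestedAttr, i.e. a member of lBaseAttrs or
    // rBaseAttrs.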
val leftRelation = RuleUtils.getRelation(spark, left).get val rightRelation = RuleUtils.getRelation(spark, right).get - val lBaseAttrs = leftRelation.plan.output.map(_.name) - val rBaseAttrs = rightRelation.plan.output.map(_.name) + val lBaseAttrs = SchemaUtils.flatten(leftRelation.plan.output) + val rBaseAttrs = SchemaUtils.flatten(rightRelation.plan.output) // Map of left resolved columns with their corresponding right resolved // columns from condition. - val lRMap = getLRColumnMapping(lBaseAttrs, rBaseAttrs, joinCondition) + val lProj = collectProjections(left) + val rProj = collectProjections(right) + val lRMap = getLRColumnMapping(lBaseAttrs, rBaseAttrs, lProj, rProj, joinCondition) val lRequiredIndexedCols = lRMap.keys.toSeq val rRequiredIndexedCols = lRMap.values.toSeq @@ -370,16 +440,56 @@ object JoinIndexRule */ private def allRequiredCols(plan: LogicalPlan): Seq[String] = { val provider = Hyperspace.getContext(spark).sourceProviderManager - val cleaned = CleanupAliases(plan) - val allReferences = cleaned.collect { - case l: LeafNode if provider.isSupportedRelation(l) => Seq() - case other => other.references + val projectionFields = collectProjections(plan) + + val allReferences = plan.collect { + case l: LeafNode if provider.isSupportedRelation(l) => + Seq.empty[String] + case other => + other match { + case project: Project => + val fields = conditionFieldsToRelationFields(project.projectList).values + fields.flatMap { + case g: GetStructField => + Seq(ExtractFilterNode.getChildNameFromStruct(g)) + case otherFieldType => + ExtractFilterNode.extractNamesFromExpression(otherFieldType).toSeq + } + case filter: Filter => + var acc = Seq.empty[String] + val fls = ExtractFilterNode + .extractNamesFromExpression(filter.condition) + .toSeq + .distinct + .sortBy(-_.length) + .toList + var h :: t = fls + while (t.nonEmpty) { + if (!t.exists(_.contains(h))) { + acc = acc :+ h + } + h = t.head + t = t.tail + } + acc + case o => + o.references.map(_.name) + } }.flatten - val topLevelOutputs = cleaned.outputSet.toSeq - (allReferences ++ topLevelOutputs).distinct.collect { - case attr: AttributeReference => attr.name + val topLevelOutputs = if (projectionFields.nonEmpty) { + plan.outputSet.map { i => + val attr = extractFieldFromProjection(i, projectionFields) + val opt = attr.map { e => + ExtractFilterNode.getChildNameFromStruct(e.asInstanceOf[GetStructField]) + } + opt.getOrElse(i.name) + } + } else { + plan.outputSet.toSeq.map(_.name) } + + (allReferences ++ topLevelOutputs).distinct } /** @@ -399,18 +509,37 @@ object JoinIndexRule private def getLRColumnMapping( leftBaseAttrs: Seq[String], rightBaseAttrs: Seq[String], + lp: Seq[NamedExpression], + rp: Seq[NamedExpression], condition: Expression): Map[String, String] = { extractConditions(condition).map { case EqualTo(attr1: AttributeReference, attr2: AttributeReference) => + val attrLeftName = if (lp.nonEmpty) { + Try { + val attrLeft = extractFieldFromProjection(attr1, lp).get + ExtractFilterNode.getChildNameFromStruct(attrLeft.asInstanceOf[GetStructField]) + }.getOrElse(attr1.name) + } else { + attr1.name + } + val attrRightName = if (rp.nonEmpty) { + Try { + val attrRight = extractFieldFromProjection(attr2, rp).get + ExtractFilterNode.getChildNameFromStruct(attrRight.asInstanceOf[GetStructField]) + }.getOrElse(attr2.name) + } else { + attr2.name + } + Try { ( - resolve(spark, attr1.name, leftBaseAttrs).get, - resolve(spark, attr2.name, rightBaseAttrs).get) + resolve(spark, attrLeftName, leftBaseAttrs).get, + resolve(spark, 
attrRightName, rightBaseAttrs).get) }.getOrElse { Try { ( - resolve(spark, attr2.name, leftBaseAttrs).get, - resolve(spark, attr1.name, rightBaseAttrs).get) + resolve(spark, attrRightName, leftBaseAttrs).get, + resolve(spark, attrLeftName, rightBaseAttrs).get) }.getOrElse { throw new IllegalStateException("Unexpected exception while using join rule") } @@ -454,8 +583,8 @@ object JoinIndexRule // All required index columns should match one-to-one with all indexed columns and // vice-versa. All required columns must be present in the available index columns. - requiredIndexCols.toSet.equals(idx.indexedColumns.toSet) && - allRequiredCols.forall(allCols.contains) + SchemaUtils.escapeFieldNames(requiredIndexCols).toSet.equals(idx.indexedColumns.toSet) && + SchemaUtils.escapeFieldNames(allRequiredCols).forall(allCols.contains) } } @@ -522,10 +651,13 @@ object JoinIndexRule lIndex: IndexLogEntry, rIndex: IndexLogEntry, columnMapping: Map[String, String]): Boolean = { - require(columnMapping.keys.toSet.equals(lIndex.indexedColumns.toSet)) - require(columnMapping.values.toSet.equals(rIndex.indexedColumns.toSet)) + val escapedMap = columnMapping.map { + case (k, v) => SchemaUtils.escapeFieldName(k) -> SchemaUtils.escapeFieldName(v) + } + require(escapedMap.keys.toSet.equals(lIndex.indexedColumns.toSet)) + require(escapedMap.values.toSet.equals(rIndex.indexedColumns.toSet)) - val requiredRightIndexedCols = lIndex.indexedColumns.map(columnMapping) + val requiredRightIndexedCols = lIndex.indexedColumns.map(escapedMap) rIndex.indexedColumns.equals(requiredRightIndexedCols) } } diff --git a/src/main/scala/com/microsoft/hyperspace/index/rules/RuleUtils.scala b/src/main/scala/com/microsoft/hyperspace/index/rules/RuleUtils.scala index e20becbdd..29bbb61b7 100644 --- a/src/main/scala/com/microsoft/hyperspace/index/rules/RuleUtils.scala +++ b/src/main/scala/com/microsoft/hyperspace/index/rules/RuleUtils.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import org.apache.hadoop.fs.Path import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, ExprId, GetStructField, In, Literal, Not} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, ExprId, GetStructField, In, IsNotNull, Literal, Not} import org.apache.spark.sql.catalyst.optimizer.OptimizeIn import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.datasources._ @@ -475,7 +475,21 @@ object RuleUtils { // Although only numBuckets of BucketSpec is used in BucketUnion*, bucketColumnNames // and sortColumnNames are shown in plan string. So remove sortColumnNames to avoid // misunderstanding. - val bucketSpec = index.bucketSpec.copy(sortColumnNames = Nil) + + val aliases = ExtractFilterNode + .collectAliases(plan) + .collect { + case (shortName, _, ref: GetStructField) => + val escapedFieldName = + SchemaUtils.escapeFieldName(ExtractFilterNode.getChildNameFromStruct(ref)) + escapedFieldName -> shortName + } + .toMap + val aliasBucketNames = index.bucketSpec.bucketColumnNames.map { col => + aliases.getOrElse(col, col) + } + val bucketSpec = + index.bucketSpec.copy(bucketColumnNames = aliasBucketNames, sortColumnNames = Nil) // Merge index plan & newly shuffled plan by using bucket-aware union. 
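      // (Aside, illustrative values only: the index stores its bucket columns under escaped
      // nested names, e.g. SchemaUtils.escapeFieldName("nested.leaf.cnt") == "nested__leaf__cnt",
      // while the plan built for appended files exposes the short alias collected above, e.g.
      // Map("nested__leaf__cnt" -> "cnt"). Mapping bucketColumnNames through `aliases` therefore
      // yields Seq("cnt"), i.e. a column name the shuffled appended-data plan actually outputs.)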
BucketUnion( @@ -581,7 +595,14 @@ object RuleUtils { plan: LogicalPlan, indexedColumns: Seq[String]): Seq[Option[Attribute]] = { val attrMap = plan.output.attrs.map(attr => (attr.name, attr)).toMap - indexedColumns.map(colName => attrMap.get(colName)) + indexedColumns.map { colName => + attrMap + .find { + case (k, _) => + colName.contains(k) + } + .map(_._2) + } } } @@ -624,45 +645,66 @@ object RuleUtils { private def transformProject(project: Project, index: IndexLogEntry): Project = { val projectedFields = project.projectList.map { exp => val fieldName = ExtractFilterNode.extractNamesFromExpression(exp).head + val shortFieldName = fieldName.split(SchemaUtils.NESTED_FIELD_NEEDLE_REGEX).last val escapedFieldName = SchemaUtils.escapeFieldName(fieldName) val attr = ExtractFilterNode.extractAttributeRef(exp, fieldName) val fieldType = ExtractFilterNode.extractTypeFromExpression(exp, fieldName) val exprId = getFieldPosition(index, escapedFieldName) - attr.copy(escapedFieldName, fieldType, attr.nullable, attr.metadata)( + val attrCopy = attr.copy(escapedFieldName, fieldType, attr.nullable, attr.metadata)( ExprId(exprId), attr.qualifier) + if (fieldName != shortFieldName) { + Alias(attrCopy, shortFieldName)(exprId = exp.toAttribute.exprId) + } else { + Alias(attrCopy, attrCopy.name)(exprId = exp.toAttribute.exprId) + } } project.copy(projectList = projectedFields) } private def transformFilter(filter: Filter, index: IndexLogEntry): Filter = { - val fieldNames = ExtractFilterNode.extractNamesFromExpression(filter.condition) - var mutableFilter = filter - fieldNames.foreach { fieldName => - val escapedFieldName = SchemaUtils.escapeFieldName(fieldName) - val nestedFields = getNestedFields(index) - if (nestedFields.nonEmpty && - nestedFields.exists(i => i.equalsIgnoreCase(escapedFieldName))) { - val (parentExpresion, exp) = - ExtractFilterNode.extractSearchQuery(filter.condition, fieldName) - val fieldType = ExtractFilterNode.extractTypeFromExpression(exp, fieldName) - val attr = ExtractFilterNode.extractAttributeRef(exp, fieldName) - val exprId = getFieldPosition(index, escapedFieldName) - val newAttr = attr.copy(escapedFieldName, fieldType, attr.nullable, attr.metadata)( - ExprId(exprId), - attr.qualifier) - val newExp = exp match { - case _: GetStructField => newAttr - case other: Expression => other - } - val newParentExpression = - ExtractFilterNode.replaceInSearchQuery(parentExpresion, exp, newExp) - mutableFilter = filter.copy(condition = newParentExpression) - } else { - filter + val nestedFields = getNestedFields(index) + if (nestedFields.nonEmpty) { + val newCondition = filter.condition.transformDown { + case gsf: GetStructField => + val fieldName = ExtractFilterNode.getChildNameFromStruct(gsf) + val escapedFieldName = SchemaUtils.escapeFieldName(fieldName) + if (nestedFields.contains(escapedFieldName)) { + val fieldType = ExtractFilterNode.extractTypeFromExpression(gsf, fieldName) + val attr = ExtractFilterNode.extractAttributeRef(gsf, fieldName) + val exprId = getFieldPosition(index, escapedFieldName) + val newAttr = attr.copy(name = escapedFieldName, dataType = fieldType)( + ExprId(exprId), + attr.qualifier) + newAttr + } else { + gsf + } + case cond @ IsNotNull(child) => + val fieldName = + SchemaUtils.escapeFieldName(ExtractFilterNode.extractNamesFromExpression(child).head) + val elemFound = nestedFields.find(i => i.contains(fieldName)) + elemFound match { + case Some(name) => + val newChild = child match { + case attr: AttributeReference => + val fieldType = 
ExtractFilterNode.extractTypeFromExpression( + cond, + SchemaUtils.unescapeFieldName(name)) + val exprId = getFieldPosition(index, name) + attr.copy(name = name, dataType = fieldType)(ExprId(exprId), attr.qualifier) + case other => + other + } + cond.copy(child = newChild) + case _ => + cond + } } + filter.copy(condition = newCondition) + } else { + filter } - mutableFilter } private def getNestedFields(index: IndexLogEntry): Seq[String] = { diff --git a/src/main/scala/com/microsoft/hyperspace/util/SchemaUtils.scala b/src/main/scala/com/microsoft/hyperspace/util/SchemaUtils.scala index cba0d5a57..b4d7f9ef7 100644 --- a/src/main/scala/com/microsoft/hyperspace/util/SchemaUtils.scala +++ b/src/main/scala/com/microsoft/hyperspace/util/SchemaUtils.scala @@ -16,6 +16,7 @@ package com.microsoft.hyperspace.util +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.types.{ArrayType, StructField, StructType} object SchemaUtils { @@ -34,6 +35,17 @@ object SchemaUtils { } } + def flatten(attributes: Seq[Attribute]): Seq[String] = { + attributes.flatMap { a => + a.dataType match { + case struct: StructType => + flatten(struct, Some(a.name)) + case _ => + Seq(a.name) + } + } + } + def escapeFieldNames(fields: Seq[String]): Seq[String] = { fields.map(escapeFieldName) } diff --git a/src/test/scala/com/microsoft/hyperspace/index/HybridScanForNestedFieldsTest.scala b/src/test/scala/com/microsoft/hyperspace/index/HybridScanForNestedFieldsTest.scala index 2d6ba6c16..bb1c2bad9 100644 --- a/src/test/scala/com/microsoft/hyperspace/index/HybridScanForNestedFieldsTest.scala +++ b/src/test/scala/com/microsoft/hyperspace/index/HybridScanForNestedFieldsTest.scala @@ -21,7 +21,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.{DataFrame, QueryTest} import org.apache.spark.sql.catalyst.expressions.{Attribute, EqualTo, In, InSet, Literal, Not} import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project, RepartitionByExpression, Union} -import org.apache.spark.sql.execution.{FileSourceScanExec, ProjectExec} +import org.apache.spark.sql.execution.{FileSourceScanExec, ProjectExec, UnionExec} import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.internal.SQLConf @@ -31,24 +31,24 @@ import com.microsoft.hyperspace.{Hyperspace, SampleNestedData, TestConfig} import com.microsoft.hyperspace.TestUtils.{latestIndexLogEntry, logManager} import com.microsoft.hyperspace.index.execution.BucketUnionExec import com.microsoft.hyperspace.index.plans.logical.BucketUnion -import com.microsoft.hyperspace.util.FileUtils +import com.microsoft.hyperspace.util.{FileUtils, SchemaUtils} // Hybrid Scan tests for non partitioned source data. Test cases of HybridScanSuite are also // executed with non partitioned source data. 
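// As a quick orientation: a minimal, self-contained sketch (example path and names only, not
// part of this suite) of the capability exercised below, namely creating an index whose
// indexed and included columns are nested leaves addressed with dot notation.
object NestedIndexSketch {
  import org.apache.spark.sql.SparkSession
  import org.apache.spark.sql.functions.struct

  def main(args: Array[String]): Unit = {
    val spark =
      SparkSession.builder.master("local[*]").appName("nested-index-sketch").getOrCreate()
    import spark.implicits._

    // Write a tiny dataset with a nested column layout: nested.leaf.{id, cnt}.
    val sourcePath = "/tmp/nested_index_sketch" // example path
    Seq(("2019-10-03", "facebook", "id1", 100L), ("2019-10-03", "donde", "id2", 200L))
      .toDF("Date", "Query", "id", "cnt")
      .select($"Date", $"Query", struct(struct($"id", $"cnt").as("leaf")).as("nested"))
      .write.mode("overwrite").parquet(sourcePath)

    // Nested fields can now be referenced directly in IndexConfig; internally they are
    // flattened and escaped (e.g. "nested.leaf.cnt" is stored as "nested__leaf__cnt").
    val hs = new Hyperspace(spark)
    hs.createIndex(
      spark.read.parquet(sourcePath),
      IndexConfig("sketchIndex", Seq("nested.leaf.cnt"), Seq("Query", "nested.leaf.id")))
  }
}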
class HybridScanForNestedFieldsTest extends QueryTest with HyperspaceSuite { override val systemPath = new Path("src/test/resources/hybridScanTestNestedFields") + import spark.implicits._ val sampleNestedData = SampleNestedData.testData - var hyperspace: Hyperspace = _ - val fileFormat = "parquet" - - import spark.implicits._ + val fileFormat2 = "json" val nestedDf = sampleNestedData.toDF("Date", "RGUID", "Query", "imprs", "clicks", "nested") + val indexConfig1 = IndexConfig("index1", Seq("nested.leaf.cnt"), Seq("query", "nested.leaf.id")) val indexConfig2 = IndexConfig("index2", Seq("nested.leaf.cnt"), Seq("Date", "nested.leaf.id")) + var hyperspace: Hyperspace = _ def normalizePaths(in: Seq[String]): Seq[String] = { in.map(_.replace("file:///", "file:/")) @@ -99,13 +99,6 @@ class HybridScanForNestedFieldsTest extends QueryTest with HyperspaceSuite { spark.disableHyperspace() } - private def getLatestStableLog(indexName: String): IndexLogEntry = { - val entry = logManager(systemPath, indexName).getLatestStableLog() - assert(entry.isDefined) - assert(entry.get.isInstanceOf[IndexLogEntry]) - entry.get.asInstanceOf[IndexLogEntry] - } - def checkDeletedFiles( plan: LogicalPlan, indexName: String, @@ -127,8 +120,8 @@ class HybridScanForNestedFieldsTest extends QueryTest with HyperspaceSuite { }.flatten val deletedFilesList = plan collect { case Filter( - Not(EqualTo(left: Attribute, right: Literal)), - LogicalRelation(fsRelation: HadoopFsRelation, _, _, _)) => + Not(EqualTo(left: Attribute, right: Literal)), + LogicalRelation(fsRelation: HadoopFsRelation, _, _, _)) => // Check new filter condition on lineage column. val colName = left.toString val deletedFile = right.toString @@ -140,29 +133,29 @@ class HybridScanForNestedFieldsTest extends QueryTest with HyperspaceSuite { assert(files.nonEmpty && files.forall(_.contains(indexName))) deleted case Filter( - Not(InSet(attr, deletedFileIds)), - LogicalRelation(fsRelation: HadoopFsRelation, _, _, _)) => + Not(InSet(attr, deletedFileIds)), + LogicalRelation(fsRelation: HadoopFsRelation, _, _, _)) => // Check new filter condition on lineage column. assert(attr.toString.contains(IndexConstants.DATA_FILE_NAME_ID)) val deleted = deletedFileIds.map(_.toString).toSeq assert( expectedDeletedFiles.length > spark.conf - .get("spark.sql.optimizer.inSetConversionThreshold") - .toLong) + .get("spark.sql.optimizer.inSetConversionThreshold") + .toLong) // Check the location is replaced with index data files properly. val files = fsRelation.location.inputFiles assert(files.nonEmpty && files.forall(_.contains(indexName))) deleted case Filter( - Not(In(attr, deletedFileIds)), - LogicalRelation(fsRelation: HadoopFsRelation, _, _, _)) => + Not(In(attr, deletedFileIds)), + LogicalRelation(fsRelation: HadoopFsRelation, _, _, _)) => // Check new filter condition on lineage column. assert(attr.toString.contains(IndexConstants.DATA_FILE_NAME_ID)) val deleted = deletedFileIds.map(_.toString) assert( expectedDeletedFiles.length <= spark.conf - .get("spark.sql.optimizer.inSetConversionThreshold") - .toLong) + .get("spark.sql.optimizer.inSetConversionThreshold") + .toLong) // Check the location is replaced with index data files properly. 
val files = fsRelation.location.inputFiles assert(files.nonEmpty && files.forall(_.contains(indexName))) @@ -208,15 +201,15 @@ class HybridScanForNestedFieldsTest extends QueryTest with HyperspaceSuite { assert(bucketSpec.numBuckets === 200) assert( bucketSpec.bucketColumnNames.size === 1 && - bucketSpec.bucketColumnNames.head === "clicks") + bucketSpec.bucketColumnNames.head === "cnt") val childNodes = children.collect { case r @ RepartitionByExpression( - attrs, - Project(_, Filter(_, LogicalRelation(fsRelation: HadoopFsRelation, _, _, _))), - numBucket) => + attrs, + Project(_, Filter(_, LogicalRelation(fsRelation: HadoopFsRelation, _, _, _))), + numBucket) => assert(attrs.size === 1) - assert(attrs.head.asInstanceOf[Attribute].name.contains("clicks")) + assert(attrs.head.asInstanceOf[Attribute].name.contains("cnt")) // Check appended file. val files = fsRelation.location.inputFiles @@ -245,15 +238,15 @@ class HybridScanForNestedFieldsTest extends QueryTest with HyperspaceSuite { assert(bucketSpec.numBuckets === 200) assert( bucketSpec.bucketColumnNames.size === 1 && - bucketSpec.bucketColumnNames.head === "clicks") + bucketSpec.bucketColumnNames.head === "cnt") val childNodes = children.collect { case r @ RepartitionByExpression( - attrs, - Project(_, Filter(_, LogicalRelation(fsRelation: HadoopFsRelation, _, _, _))), - numBucket) => + attrs, + Project(_, Filter(_, LogicalRelation(fsRelation: HadoopFsRelation, _, _, _))), + numBucket) => assert(attrs.size === 1) - assert(attrs.head.asInstanceOf[Attribute].name.contains("clicks")) + assert(attrs.head.asInstanceOf[Attribute].name.contains("cnt")) // Check appended files. val files = fsRelation.location.inputFiles @@ -262,8 +255,8 @@ class HybridScanForNestedFieldsTest extends QueryTest with HyperspaceSuite { assert(numBucket === 200) r case p @ Project( - _, - Filter(_, LogicalRelation(fsRelation: HadoopFsRelation, _, _, _))) => + _, + Filter(_, LogicalRelation(fsRelation: HadoopFsRelation, _, _, _))) => // Check index data files. val files = fsRelation.location.inputFiles assert(files.nonEmpty && files.forall(_.contains(rightIndexName))) @@ -309,6 +302,73 @@ class HybridScanForNestedFieldsTest extends QueryTest with HyperspaceSuite { } } + def checkFilterIndexHybridScanUnion( + plan: LogicalPlan, + indexName: String, + expectedAppendedFiles: Seq[String] = Nil, + expectedDeletedFiles: Seq[String] = Nil, + filterConditions: Seq[String] = Nil): Unit = { + // The first child should be the index data scan; thus check if the deleted files are handled + // properly with the first child plan. + checkDeletedFiles(plan.children.head, indexName, expectedDeletedFiles) + + if (expectedAppendedFiles.nonEmpty) { + val nodes = plan.collect { + case u @ Union(children) => + val indexChild = children.head + indexChild collect { + case LogicalRelation(fsRelation: HadoopFsRelation, _, _, _) => + assert(fsRelation.location.inputFiles.forall(_.contains(indexName))) + } + + assert(children.tail.size === 1) + val appendChild = children.last + appendChild collect { + case LogicalRelation(fsRelation: HadoopFsRelation, _, _, _) => + val files = fsRelation.location.inputFiles + assert(equalNormalizedPaths(files, expectedAppendedFiles)) + assert(files.length === expectedAppendedFiles.size) + } + u + } + assert(nodes.count(_.isInstanceOf[Union]) === 1) + + // Make sure there is no shuffle. 
+ plan.foreach(p => assert(!p.isInstanceOf[RepartitionByExpression])) + + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + val execPlan = spark.sessionState.executePlan(plan).executedPlan + val execNodes = execPlan.collect { + case p @ UnionExec(children) => + assert(children.size === 2) + assert(children.head.isInstanceOf[ProjectExec]) // index data + assert(children.last.isInstanceOf[ProjectExec]) // appended data + p + case p @ FileSourceScanExec(_, _, _, partitionFilters, _, dataFilters, _) => + // Check filter pushed down properly. + if (partitionFilters.nonEmpty) { + assert(filterConditions.forall(partitionFilters.toString.contains)) + } else { + assert(filterConditions.forall(dataFilters.toString.contains)) + } + p + } + assert(execNodes.count(_.isInstanceOf[UnionExec]) === 1) + // 1 of index, 1 of appended file + assert(execNodes.count(_.isInstanceOf[FileSourceScanExec]) === 2) + // Make sure there is no shuffle. + execPlan.foreach(p => assert(!p.isInstanceOf[ShuffleExchangeExec])) + } + } + } + + private def getLatestStableLog(indexName: String): IndexLogEntry = { + val entry = logManager(systemPath, indexName).getLatestStableLog() + assert(entry.isDefined) + assert(entry.get.isInstanceOf[IndexLogEntry]) + entry.get.asInstanceOf[IndexLogEntry] + } + test( "Append-only: union over index and new files " + "due to field names being different: `nested__leaf__cnt` + `nested.leaf.cnt`.") { @@ -429,14 +489,16 @@ class HybridScanForNestedFieldsTest extends QueryTest with HyperspaceSuite { val deletedRatio = 1 - (afterDeleteSize / sourceSize.toFloat) withSQLConf(TestConfig.HybridScanEnabled: _*) { - withSQLConf(IndexConstants.INDEX_HYBRID_SCAN_DELETED_RATIO_THRESHOLD -> - (deletedRatio + 0.1).toString) { + withSQLConf( + IndexConstants.INDEX_HYBRID_SCAN_DELETED_RATIO_THRESHOLD -> + (deletedRatio + 0.1).toString) { val filter = filterQuery // As deletedRatio is less than the threshold, the index can be applied. assert(!basePlan.equals(filter.queryExecution.optimizedPlan)) } - withSQLConf(IndexConstants.INDEX_HYBRID_SCAN_DELETED_RATIO_THRESHOLD -> - (deletedRatio - 0.1).toString) { + withSQLConf( + IndexConstants.INDEX_HYBRID_SCAN_DELETED_RATIO_THRESHOLD -> + (deletedRatio - 0.1).toString) { val filter = filterQuery // As deletedRatio is greater than the threshold, the index shouldn't be applied. 
assert(basePlan.equals(filter.queryExecution.optimizedPlan)) @@ -444,4 +506,372 @@ class HybridScanForNestedFieldsTest extends QueryTest with HyperspaceSuite { } } } + + test( + "Append-only: join rule, appended data should be shuffled with indexed columns " + + "and merged by BucketUnion.") { + withTempPathAsString { testPath => + val appendPath1 = testPath + "/append1" + val appendPath2 = testPath + "/append2" + val leftIndexName = "index_Append" + val rightIndexName = "indexType2_Append" + val (leftAppended, leftDeleted) = setupIndexAndChangeData( + fileFormat, + appendPath1, + indexConfig1.copy(indexName = leftIndexName), + appendCnt = 1, + deleteCnt = 0) + val (rightAppended, rightDeleted) = setupIndexAndChangeData( + fileFormat, + appendPath2, + indexConfig2.copy(indexName = rightIndexName), + appendCnt = 1, + deleteCnt = 0) + + val df1 = spark.read.format(fileFormat).load(appendPath1) + val df2 = spark.read.format(fileFormat).load(appendPath2) + def joinQuery(): DataFrame = { + val query2 = df1 + .filter(df1("nested.leaf.cnt") >= 20) + .select(df1("nested.leaf.cnt"), df1("query"), df1("nested.leaf.id")) + val query = df2 + .filter(df2("nested.leaf.cnt") <= 40) + .select(df2("nested.leaf.cnt"), df2("Date"), df2("nested.leaf.id")) + query2.join(query, "cnt") + } + val baseQuery = joinQuery() + val basePlan = baseQuery.queryExecution.optimizedPlan + + withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") { + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + IndexConstants.INDEX_HYBRID_SCAN_ENABLED -> "false") { + val join = joinQuery() + checkAnswer(join, baseQuery) + } + + withSQLConf(TestConfig.HybridScanEnabled: _*) { + val join = joinQuery() + val planWithHybridScan = join.queryExecution.optimizedPlan + assert(!basePlan.equals(planWithHybridScan)) + checkJoinIndexHybridScan( + planWithHybridScan, + leftIndexName, + leftAppended, + leftDeleted, + rightIndexName, + rightAppended, + rightDeleted, + Seq(">= 20", "<= 40")) + checkAnswer(join, baseQuery) + } + } + } + } + + test( + "Append-only: filter rule and non-parquet format," + + "appended data should be shuffled and merged by Union.") { + // Note: for delta lake, this test is also eligible as the dataset is partitioned. + withTempPathAsString { testPath => + val (appendedFiles, deletedFiles) = setupIndexAndChangeData( + fileFormat2, + testPath, + indexConfig1.copy(indexName = "index_Format2Append"), + appendCnt = 1, + deleteCnt = 0) + + val df = spark.read.format(fileFormat2).load(testPath) + def filterQuery: DataFrame = df.filter(df("nested.leaf.cnt") <= 20).select(df("query")) + + val baseQuery = filterQuery + val basePlan = baseQuery.queryExecution.optimizedPlan + + withSQLConf(IndexConstants.INDEX_HYBRID_SCAN_ENABLED -> "false") { + val filter = filterQuery + assert(basePlan.equals(filter.queryExecution.optimizedPlan)) + } + + withSQLConf(TestConfig.HybridScanEnabledAppendOnly: _*) { + val filter = filterQuery + val planWithHybridScan = filter.queryExecution.optimizedPlan + assert(!basePlan.equals(planWithHybridScan)) + + checkFilterIndexHybridScanUnion( + planWithHybridScan, + "index_Format2Append", + appendedFiles, + deletedFiles, + Seq(" <= 20")) + + // Check bucketSpec is not used. + val bucketSpec = planWithHybridScan collect { + case LogicalRelation(HadoopFsRelation(_, _, _, bucketSpec, _, _), _, _, _) => + bucketSpec + } + assert(bucketSpec.length == 2) + + // bucketSpec.head is for the index plan, bucketSpec.last is for the plan + // for appended files. 
+ assert(bucketSpec.head.isEmpty && bucketSpec.last.isEmpty) + + checkAnswer(baseQuery, filter) + } + } + } + + test( + "Append-only: filter rule and non-parquet format," + + "appended data should be shuffled and merged by Union even with bucketSpec.") { + withTempPathAsString { testPath => + val (appendedFiles, deletedFiles) = setupIndexAndChangeData( + fileFormat2, + testPath, + indexConfig1.copy(indexName = "index_Format2Append"), + appendCnt = 1, + deleteCnt = 0) + + val df = spark.read.format(fileFormat2).load(testPath) + def filterQuery: DataFrame = df.filter(df("nested.leaf.cnt") <= 20).select(df("query")) + + val baseQuery = filterQuery + val basePlan = baseQuery.queryExecution.optimizedPlan + + withSQLConf(IndexConstants.INDEX_HYBRID_SCAN_ENABLED -> "false") { + val filter = filterQuery + assert(basePlan.equals(filter.queryExecution.optimizedPlan)) + } + + withSQLConf( + TestConfig.HybridScanEnabledAppendOnly :+ + IndexConstants.INDEX_FILTER_RULE_USE_BUCKET_SPEC -> "true": _*) { + val filter = filterQuery + val planWithHybridScan = filter.queryExecution.optimizedPlan + assert(!basePlan.equals(planWithHybridScan)) + + checkFilterIndexHybridScanUnion( + planWithHybridScan, + "index_Format2Append", + appendedFiles, + deletedFiles, + Seq(" <= 20")) + + // Check bucketSpec is used. + val bucketSpec = planWithHybridScan collect { + case LogicalRelation(HadoopFsRelation(_, _, _, bucketSpec, _, _), _, _, _) => + bucketSpec + } + assert(bucketSpec.length == 2) + // bucketSpec.head is for the index plan, bucketSpec.last is for the plan + // for appended files. + assert(bucketSpec.head.isDefined && bucketSpec.last.isEmpty) + assert( + bucketSpec.head.get.bucketColumnNames.toSet === indexConfig1.indexedColumns.toSet + .map(SchemaUtils.escapeFieldName)) + + checkAnswer(baseQuery, filter) + } + } + } + + test("Delete-only: index relation should have additional filter for deleted files.") { + val testSet = Seq(("index_ParquetDelete", fileFormat), ("index_JsonDelete", fileFormat2)) + testSet foreach { + case (indexName, dataFormat) => + withTempPathAsString { dataPath => + val (appendedFiles, deletedFiles) = setupIndexAndChangeData( + dataFormat, + dataPath, + indexConfig1.copy(indexName = indexName), + appendCnt = 0, + deleteCnt = 1) + + val df = spark.read.format(dataFormat).load(dataPath) + def filterQuery: DataFrame = + df.filter(df("nested.leaf.cnt") <= 20).select(df("query")) + + val baseQuery = filterQuery + val basePlan = baseQuery.queryExecution.optimizedPlan + + withSQLConf(TestConfig.HybridScanEnabledAppendOnly: _*) { + val filter = filterQuery + assert(basePlan.equals(filter.queryExecution.optimizedPlan)) + } + + withSQLConf(TestConfig.HybridScanEnabled: _*) { + val filter = filterQuery + val planWithHybridScan = filter.queryExecution.optimizedPlan + assert(!basePlan.equals(planWithHybridScan)) + checkFilterIndexHybridScanUnion( + planWithHybridScan, + indexName, + appendedFiles, + deletedFiles, + Seq(" <= 20")) + checkAnswer(baseQuery, filter) + } + } + } + } + + test("Delete-only: join rule, deleted files should be excluded from each index data relation.") { + withTempPathAsString { testPath => + val deletePath1 = testPath + "/delete1" + val deletePath2 = testPath + "/delete2" + val leftIndexName = "index_Delete" + val rightIndexName = "indexType2_Delete2" + val (leftAppended, leftDeleted) = setupIndexAndChangeData( + fileFormat, + deletePath1, + indexConfig1.copy(indexName = leftIndexName), + appendCnt = 0, + deleteCnt = 2) + val (rightAppended, rightDeleted) = 
setupIndexAndChangeData( + fileFormat, + deletePath2, + indexConfig2.copy(indexName = rightIndexName), + appendCnt = 0, + deleteCnt = 2) + + val df1 = spark.read.format(fileFormat).load(deletePath1) + val df2 = spark.read.format(fileFormat).load(deletePath2) + + def joinQuery(): DataFrame = { + val query = + df1.filter(df1("nested.leaf.cnt") >= 20).select(df1("nested.leaf.cnt"), df1("query")) + val query2 = + df2.filter(df2("nested.leaf.cnt") <= 40).select(df2("nested.leaf.cnt"), df2("Date")) + query.join(query2, "cnt") + } + + val baseQuery = joinQuery() + val basePlan = baseQuery.queryExecution.optimizedPlan + + withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") { + withSQLConf(IndexConstants.INDEX_HYBRID_SCAN_ENABLED -> "false") { + val join = joinQuery() + checkAnswer(baseQuery, join) + } + + withSQLConf(TestConfig.HybridScanEnabled: _*) { + val join = joinQuery() + val planWithHybridScan = join.queryExecution.optimizedPlan + assert(!basePlan.equals(planWithHybridScan)) + checkJoinIndexHybridScan( + planWithHybridScan, + leftIndexName, + leftAppended, + leftDeleted, + rightIndexName, + rightAppended, + rightDeleted, + Seq(" >= 20", " <= 40")) + checkAnswer(join, baseQuery) + } + } + } + } + + test( + "Append+Delete: filter rule, appended files should be handled " + + "with additional plan and merged by Union.") { + + withTempPathAsString { testPath => + val (appendedFiles, deletedFiles) = setupIndexAndChangeData( + fileFormat, + testPath, + indexConfig1.copy(indexName = "index_appendAndDelete"), + appendCnt = 1, + deleteCnt = 1) + + val df = spark.read.format(fileFormat).load(testPath) + + def filterQuery: DataFrame = + df.filter(df("nested.leaf.cnt") <= 20).select(df("query")) + + val baseQuery = filterQuery + val basePlan = baseQuery.queryExecution.optimizedPlan + + withSQLConf(IndexConstants.INDEX_HYBRID_SCAN_ENABLED -> "false") { + val filter = filterQuery + assert(basePlan.equals(filter.queryExecution.optimizedPlan)) + } + + withSQLConf(TestConfig.HybridScanEnabled: _*) { + val filter = filterQuery + val planWithHybridScan = filter.queryExecution.optimizedPlan + assert(!basePlan.equals(planWithHybridScan)) + + checkFilterIndexHybridScanUnion( + planWithHybridScan, + "index_appendAndDelete", + appendedFiles, + deletedFiles, + Seq(" <= 20")) + checkAnswer(baseQuery, filter) + } + } + } + + test( + "Append+Delete: join rule, appended data should be shuffled with indexed columns " + + "and merged by BucketUnion and deleted files are handled with index data.") { + // One relation has both deleted & appended files and the other one has only deleted files. 
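      // A minimal, self-contained sketch (not taken from this patch) of the idea this test
      // verifies on the index side: rows originating from deleted source files are filtered
      // out of the index relation before the join. The lineage column name is passed in as a
      // parameter here because the real column name is defined elsewhere in Hyperspace.
      import org.apache.spark.sql.DataFrame
      import org.apache.spark.sql.functions.{col, not}
      def excludeDeletedFiles(indexDf: DataFrame, lineageCol: String, deleted: Seq[String]): DataFrame =
        indexDf.filter(not(col(lineageCol).isin(deleted: _*)))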
+ withTempPathAsString { testPath => + val leftDataPath = testPath + "/leftPath" + val rightDataPath = testPath + "/rightPath" + val leftIndexName = "index_Both" + val rightIndexName = "indexType2_Delete" + val (leftAppended, leftDeleted) = setupIndexAndChangeData( + fileFormat, + leftDataPath, + indexConfig1.copy(indexName = leftIndexName), + appendCnt = 1, + deleteCnt = 1) + val (rightAppended, rightDeleted) = setupIndexAndChangeData( + fileFormat, + rightDataPath, + indexConfig2.copy(indexName = rightIndexName), + appendCnt = 0, + deleteCnt = 2) + + val df1 = spark.read.format(fileFormat).load(leftDataPath) + val df2 = spark.read.format(fileFormat).load(rightDataPath) + def joinQuery(): DataFrame = { + val query = + df1.filter(df1("nested.leaf.cnt") >= 20).select(df1("nested.leaf.cnt"), df1("query")) + val query2 = + df2.filter(df2("nested.leaf.cnt") <= 40).select(df2("nested.leaf.cnt"), df2("Date")) + query.join(query2, "cnt") + } + val baseQuery = joinQuery() + val basePlan = baseQuery.queryExecution.optimizedPlan + + withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") { + withSQLConf(IndexConstants.INDEX_HYBRID_SCAN_ENABLED -> "false") { + val join = joinQuery() + checkAnswer(baseQuery, join) + } + + withSQLConf( + TestConfig.HybridScanEnabled :+ + "spark.sql.optimizer.inSetConversionThreshold" -> "1": _*) { + // Changed inSetConversionThreshold to check InSet optimization. + val join = joinQuery() + val planWithHybridScan = join.queryExecution.optimizedPlan + assert(!basePlan.equals(planWithHybridScan)) + checkJoinIndexHybridScan( + planWithHybridScan, + leftIndexName, + leftAppended, + leftDeleted, + rightIndexName, + rightAppended, + rightDeleted, + Seq(" >= 20)", " <= 40)")) + checkAnswer(baseQuery, join) + } + } + } + } } From 19d26f8dd92fc96e1e5300e15fb6a23cf7e2b721 Mon Sep 17 00:00:00 2001 From: Andrei Ionescu Date: Wed, 3 Mar 2021 12:07:51 +0200 Subject: [PATCH 3/3] Integrate review feedback (1) --- .../index/rules/FilterIndexRule.scala | 103 +--------- .../index/rules/JoinIndexRule.scala | 12 +- .../hyperspace/index/rules/PlanUtils.scala | 190 ++++++++++++++++++ .../hyperspace/index/rules/RuleUtils.scala | 93 +++++++-- .../FileBasedSourceProviderManager.scala | 40 +--- 5 files changed, 275 insertions(+), 163 deletions(-) create mode 100644 src/main/scala/com/microsoft/hyperspace/index/rules/PlanUtils.scala diff --git a/src/main/scala/com/microsoft/hyperspace/index/rules/FilterIndexRule.scala b/src/main/scala/com/microsoft/hyperspace/index/rules/FilterIndexRule.scala index 3f28bd51c..9a7ac64be 100644 --- a/src/main/scala/com/microsoft/hyperspace/index/rules/FilterIndexRule.scala +++ b/src/main/scala/com/microsoft/hyperspace/index/rules/FilterIndexRule.scala @@ -171,9 +171,11 @@ object ExtractFilterNode { val projectColumnNames = CleanupAliases(project) .asInstanceOf[Project] .projectList - .map(extractNamesFromExpression) + .map(PlanUtils.extractNamesFromExpression) .flatMap(_.toSeq) - val filterColumnNames = extractNamesFromExpression(condition).toSeq + val filterColumnNames = PlanUtils + .extractNamesFromExpression(condition) + .toSeq .sortBy(-_.length) .foldLeft(Seq.empty[String]) { (acc, e) => if (!acc.exists(i => i.startsWith(e))) { @@ -194,103 +196,6 @@ object ExtractFilterNode { case _ => None // plan does not match with any of filter index rule patterns } - - def extractNamesFromExpression(exp: Expression): Set[String] = { - exp match { - case AttributeReference(name, _, _, _) => - Set(s"$name") - case otherExp => - otherExp.containsChild.flatMap { - case g: 
GetStructField => - Set(s"${getChildNameFromStruct(g)}") - case e: Expression => - extractNamesFromExpression(e).filter(_.nonEmpty) - case _ => Set.empty[String] - } - } - } - - def getChildNameFromStruct(field: GetStructField): String = { - field.child match { - case f: GetStructField => - s"${getChildNameFromStruct(f)}.${field.name.get}" - case a: AttributeReference => - s"${a.name}.${field.name.get}" - case _ => - s"${field.name.get}" - } - } - - def replaceInSearchQuery( - parent: Expression, - needle: Expression, - repl: Expression): Expression = { - parent.mapChildren { c => - if (c == needle) { - repl - } else { - c - } - } - } - - def extractAttributeRef(exp: Expression, name: String): AttributeReference = { - val splits = name.split(SchemaUtils.NESTED_FIELD_NEEDLE_REGEX) - val elem = exp.find { - case a: AttributeReference if splits.contains(a.name) => true - case _ => false - } - elem.get.asInstanceOf[AttributeReference] - } - - def extractTypeFromExpression(exp: Expression, name: String): DataType = { - val splits = name.split(SchemaUtils.NESTED_FIELD_NEEDLE_REGEX) - val elem = exp.flatMap { - case attrRef: AttributeReference => - if (splits.forall(s => attrRef.name == s)) { - Some((name, attrRef.dataType)) - } else { - Try({ - val h :: t = splits.toList - if (attrRef.name == h && attrRef.dataType.isInstanceOf[StructType]) { - val currentDataType = attrRef.dataType.asInstanceOf[StructType] - var localDT = currentDataType - val foldedFields = t.foldLeft(Seq.empty[(String, DataType)]) { (acc, i) => - val collected = localDT.collect { - case dt if dt.name == i => - dt.dataType match { - case st: StructType => - localDT = st - case _ => - } - (i, dt.dataType) - } - acc ++ collected - } - Some(foldedFields.last) - } else { - None - } - }).getOrElse(None) - } - case f: GetStructField if splits.forall(s => f.toString().contains(s)) => - Some((name, f.dataType)) - case _ => None - } - elem.find(e => e._1 == name || e._1 == splits.last).get._2 - } - - def collectAliases(plan: LogicalPlan): Seq[(String, Attribute, Expression)] = { - plan - .collect { - case Project(projectList, _) => - projectList.collect { - case a @ Alias(child, name) => - (name, a.toAttribute, child) - } - } - .flatten - } } object ExtractRelation extends ActiveSparkSession { diff --git a/src/main/scala/com/microsoft/hyperspace/index/rules/JoinIndexRule.scala b/src/main/scala/com/microsoft/hyperspace/index/rules/JoinIndexRule.scala index aaf7d1766..bab92a94d 100644 --- a/src/main/scala/com/microsoft/hyperspace/index/rules/JoinIndexRule.scala +++ b/src/main/scala/com/microsoft/hyperspace/index/rules/JoinIndexRule.scala @@ -451,13 +451,13 @@ object JoinIndexRule val fields = conditionFieldsToRelationFields(project.projectList).values fields.flatMap { case g: GetStructField => - Seq(ExtractFilterNode.getChildNameFromStruct(g)) + Seq(PlanUtils.getChildNameFromStruct(g)) case otherFieldType => - ExtractFilterNode.extractNamesFromExpression(otherFieldType).toSeq + PlanUtils.extractNamesFromExpression(otherFieldType).toSeq } case filter: Filter => var acc = Seq.empty[String] - val fls = ExtractFilterNode + val fls = PlanUtils .extractNamesFromExpression(filter.condition) .toSeq .distinct @@ -481,7 +481,7 @@ object JoinIndexRule plan.outputSet.map { i => val attr = extractFieldFromProjection(i, projectionFields) val opt = attr.map { e => - ExtractFilterNode.getChildNameFromStruct(e.asInstanceOf[GetStructField]) + PlanUtils.getChildNameFromStruct(e.asInstanceOf[GetStructField]) } opt.getOrElse(i.name) } @@ -517,7 +517,7 @@ 
object JoinIndexRule
     val attrLeftName = if (lp.nonEmpty) {
       Try {
         val attrLeft = extractFieldFromProjection(attr1, lp).get
-          ExtractFilterNode.getChildNameFromStruct(attrLeft.asInstanceOf[GetStructField])
+          PlanUtils.getChildNameFromStruct(attrLeft.asInstanceOf[GetStructField])
       }.getOrElse(attr1.name)
     } else {
       attr1.name
@@ -525,7 +525,7 @@ object JoinIndexRule
     val attrRightName = if (rp.nonEmpty) {
       Try {
         val attrRight = extractFieldFromProjection(attr2, rp).get
-          ExtractFilterNode.getChildNameFromStruct(attrRight.asInstanceOf[GetStructField])
+          PlanUtils.getChildNameFromStruct(attrRight.asInstanceOf[GetStructField])
       }.getOrElse(attr2.name)
     } else {
       attr2.name
diff --git a/src/main/scala/com/microsoft/hyperspace/index/rules/PlanUtils.scala b/src/main/scala/com/microsoft/hyperspace/index/rules/PlanUtils.scala
new file mode 100644
index 000000000..5a2f5ba0b
--- /dev/null
+++ b/src/main/scala/com/microsoft/hyperspace/index/rules/PlanUtils.scala
@@ -0,0 +1,190 @@
+/*
+ * Copyright (2020) The Hyperspace Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.microsoft.hyperspace.index.rules
+
+import scala.util.Try
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, GetStructField}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
+import org.apache.spark.sql.types.{DataType, StructType}
+
+import com.microsoft.hyperspace.util.SchemaUtils
+
+object PlanUtils {
+
+  /**
+   * Returns true if the given project is a supported project, i.e. it references at least
+   * one nested field, either directly in its project list or in a Filter child's condition.
+   *
+   * @param project Project to check if it's supported.
+   * @return True if the given project is a supported project.
+   */
+  def isSupportedProject(project: Project): Boolean = {
+    val containsNestedFields =
+      SchemaUtils.hasNestedFields(project.projectList.flatMap(extractNamesFromExpression))
+    var containsNestedChildren = false
+    project.child.foreach {
+      case f: Filter =>
+        containsNestedChildren = containsNestedChildren || {
+          SchemaUtils.hasNestedFields(
+            SchemaUtils.unescapeFieldNames(extractNamesFromExpression(f.condition).toSeq))
+        }
+      case _ =>
+    }
+    containsNestedFields || containsNestedChildren
+  }
+
+  /**
+   * Returns true if the given filter is a supported filter, i.e. its condition references
+   * at least one nested field.
+   *
+   * @param filter Filter to check if it's supported.
+   * @return True if the given filter is a supported filter.
+   */
+  def isSupportedFilter(filter: Filter): Boolean = {
+    val containsNestedFields =
+      SchemaUtils.hasNestedFields(extractNamesFromExpression(filter.condition).toSeq)
+    containsNestedFields
+  }
+
+  /**
+   * Given an expression, it extracts all the field names from it.
+   *
+   * @param exp Expression to extract field names from
+   * @return A set of distinct strings representing the field names
+   *         (e.g. `Set(nested.field.id, nested.field.other)`)
+   */
+  def extractNamesFromExpression(exp: Expression): Set[String] = {
+    exp match {
+      case AttributeReference(name, _, _, _) =>
+        Set(s"$name")
+      case otherExp =>
+        otherExp.containsChild.flatMap {
+          case g: GetStructField =>
+            Set(s"${getChildNameFromStruct(g)}")
+          case e: Expression =>
+            extractNamesFromExpression(e).filter(_.nonEmpty)
+          case _ => Set.empty[String]
+        }
+    }
+  }
+
+  /**
+   * Given a nested field, this method extracts the full name out of it.
+   *
+   * @param field The field from which to get the name
+   * @return The name of the field (e.g. `nested.field.id`)
+   */
+  def getChildNameFromStruct(field: GetStructField): String = {
+    field.child match {
+      case f: GetStructField =>
+        s"${getChildNameFromStruct(f)}.${field.name.get}"
+      case a: AttributeReference =>
+        s"${a.name}.${field.name.get}"
+      case _ =>
+        s"${field.name.get}"
+    }
+  }
+
+  /**
+   * Given an expression, it extracts the attribute reference by field name.
+   *
+   * @param exp The expression in which to look for the attribute reference
+   * @param name The name of the field to look for
+   * @return The attribute reference for that field name
+   */
+  def extractAttributeRef(exp: Expression, name: String): AttributeReference = {
+    val splits = name.split(SchemaUtils.NESTED_FIELD_NEEDLE_REGEX)
+    val elem = exp.find {
+      case a: AttributeReference if splits.contains(a.name) => true
+      case _ => false
+    }
+    elem.get.asInstanceOf[AttributeReference]
+  }
+
+  /**
+   * Given an expression, it extracts the type of the field by field name.
+   *
+   * @param exp The expression from which to extract the type
+   * @param name The name of the field to look for
+   * @return The type of the field as [[DataType]]
+   */
+  def extractTypeFromExpression(exp: Expression, name: String): DataType = {
+    val splits = name.split(SchemaUtils.NESTED_FIELD_NEEDLE_REGEX)
+    val elem = exp.flatMap {
+      case attrRef: AttributeReference =>
+        if (splits.forall(s => attrRef.name == s)) {
+          Some((name, attrRef.dataType))
+        } else {
+          Try({
+            val h :: t = splits.toList
+            if (attrRef.name == h && attrRef.dataType.isInstanceOf[StructType]) {
+              val currentDataType = attrRef.dataType.asInstanceOf[StructType]
+              var localDT = currentDataType
+              val foldedFields = t.foldLeft(Seq.empty[(String, DataType)]) { (acc, i) =>
+                val collected = localDT.collect {
+                  case dt if dt.name == i =>
+                    dt.dataType match {
+                      case st: StructType =>
+                        localDT = st
+                      case _ =>
+                    }
+                    (i, dt.dataType)
+                }
+                acc ++ collected
+              }
+              Some(foldedFields.last)
+            } else {
+              None
+            }
+          }).getOrElse(None)
+        }
+      case f: GetStructField if splits.forall(s => f.toString().contains(s)) =>
+        Some((name, f.dataType))
+      case _ => None
+    }
+    elem.find(e => e._1 == name || e._1 == splits.last).get._2
+  }
+
+  /**
+   * Given a logical plan, the method collects all the aliases in the plan.
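// A small, self-contained illustration of the helpers above (not part of this patch). It
// hand-builds the Catalyst form of `nested.leaf.cnt >= 20` and assumes
// SchemaUtils.NESTED_FIELD_NEEDLE_REGEX splits names on the dot, as the escaped names used
// elsewhere in this patch suggest.
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GetStructField, GreaterThanOrEqual, Literal}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import com.microsoft.hyperspace.index.rules.PlanUtils

object PlanUtilsSketch extends App {
  val leafType = StructType(Seq(StructField("cnt", IntegerType), StructField("id", StringType)))
  val nested = AttributeReference("nested", StructType(Seq(StructField("leaf", leafType))))()
  // nested.leaf.cnt expressed as a chain of GetStructField over the attribute reference.
  val cnt = GetStructField(GetStructField(nested, 0, Some("leaf")), 0, Some("cnt"))
  val condition = GreaterThanOrEqual(cnt, Literal(20))

  println(PlanUtils.getChildNameFromStruct(cnt))                             // nested.leaf.cnt
  println(PlanUtils.extractNamesFromExpression(condition))                   // Set(nested.leaf.cnt)
  println(PlanUtils.extractTypeFromExpression(condition, "nested.leaf.cnt")) // IntegerType
}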
+   * For example, given this projection
+   * `Project [nested#548.leaf.cnt AS cnt#659, Date#543, nested#548.leaf.id AS id#660]`
+   * the result will be:
+   * {{{
+   *   Seq(
+   *     ("cnt", cnt#659, nested#548.leaf.cnt),
+   *     ("id", id#660, nested#548.leaf.id)
+   *   )
+   * }}}
+   *
+   * @param plan The plan from which to collect the aliases
+   * @return A collection of tuples, each containing:
+   *         - a string representing the alias name
+   *         - the attribute the alias transforms to
+   *         - the expression from which the alias is derived
+   */
+  def collectAliases(plan: LogicalPlan): Seq[(String, Attribute, Expression)] = {
+    plan.collect {
+      case Project(projectList, _) =>
+        projectList.collect {
+          case a @ Alias(child, name) =>
+            (name, a.toAttribute, child)
+        }
+    }.flatten
+  }
+}
diff --git a/src/main/scala/com/microsoft/hyperspace/index/rules/RuleUtils.scala b/src/main/scala/com/microsoft/hyperspace/index/rules/RuleUtils.scala
index 29bbb61b7..c6f07800a 100644
--- a/src/main/scala/com/microsoft/hyperspace/index/rules/RuleUtils.scala
+++ b/src/main/scala/com/microsoft/hyperspace/index/rules/RuleUtils.scala
@@ -33,8 +33,7 @@ import com.microsoft.hyperspace.index._
 import com.microsoft.hyperspace.index.IndexLogEntryTags.{HYBRIDSCAN_RELATED_CONFIGS, IS_HYBRIDSCAN_CANDIDATE}
 import com.microsoft.hyperspace.index.plans.logical.{BucketUnion, IndexHadoopFsRelation}
 import com.microsoft.hyperspace.index.sources.FileBasedRelation
-import com.microsoft.hyperspace.util.HyperspaceConf
-import com.microsoft.hyperspace.util.SchemaUtils
+import com.microsoft.hyperspace.util.{HyperspaceConf, SchemaUtils}

 object RuleUtils {

@@ -289,7 +288,8 @@ object RuleUtils {
       val flatSchema = SchemaUtils.escapeFieldNames(SchemaUtils.flatten(relation.plan.schema))
       val updatedOutput =
-        if (SchemaUtils.hasNestedFields(SchemaUtils.unescapeFieldNames(flatSchema))) {
+        if (index.usesNestedFields &&
+            SchemaUtils.hasNestedFields(SchemaUtils.unescapeFieldNames(flatSchema))) {
           indexFsRelation.schema.flatMap { s =>
             val exprId = getFieldPosition(index, s.name)
             relation.plan.output.find(a => s.name.contains(a.name)).map { a =>
@@ -306,10 +306,10 @@ object RuleUtils {

       relation.createLogicalRelation(indexFsRelation, updatedOutput)

-    case p: Project if provider.isSupportedProject(p) =>
+    case p: Project if PlanUtils.isSupportedProject(p) =>
       transformProject(p, index)

-    case f: Filter if provider.isSupportedFilter(f) =>
+    case f: Filter if PlanUtils.isSupportedFilter(f) =>
       transformFilter(f, index)
   }
 }
@@ -423,7 +423,8 @@ object RuleUtils {
           Map(IndexConstants.INDEX_RELATION_IDENTIFIER))(spark, index)
       val updatedOutput =
-        if (SchemaUtils.hasNestedFields(SchemaUtils.unescapeFieldNames(flatSchema))) {
+        if (index.usesNestedFields &&
+            SchemaUtils.hasNestedFields(SchemaUtils.unescapeFieldNames(flatSchema))) {
           indexFsRelation.schema.flatMap { s =>
             val exprId = getFieldPosition(index, s.name)
             relation.plan.output.find(a => s.name.contains(a.name)).map { a =>
@@ -449,10 +450,10 @@ object RuleUtils {
         Project(updatedOutput, OptimizeIn(filterForDeleted))
       }

-    case p: Project if provider.isSupportedProject(p) =>
+    case p: Project if PlanUtils.isSupportedProject(p) =>
       transformProject(p, index)

-    case f: Filter if provider.isSupportedFilter(f) =>
+    case f: Filter if PlanUtils.isSupportedFilter(f) =>
       transformFilter(f, index)
   }

@@ -476,12 +477,12 @@ object RuleUtils {
     // and sortColumnNames are shown in plan string. So remove sortColumnNames to avoid
     // misunderstanding.
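// An illustrative sketch (not from this patch) of the flatten-and-escape convention that the
// RuleUtils changes above rely on. SchemaUtils itself is not shown in this part of the diff,
// so the expected values below are assumptions inferred from names like `nested__leaf__cnt`
// used throughout the tests.
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import com.microsoft.hyperspace.util.SchemaUtils

val leaf = StructType(Seq(StructField("cnt", IntegerType), StructField("id", StringType)))
val schema = StructType(Seq(
  StructField("query", StringType),
  StructField("nested", StructType(Seq(StructField("leaf", leaf))))))

val flat = SchemaUtils.flatten(schema)
// Assumed: flat == Seq("query", "nested.leaf.cnt", "nested.leaf.id")
val escaped = SchemaUtils.escapeFieldNames(flat)
// Assumed: escaped == Seq("query", "nested__leaf__cnt", "nested__leaf__id")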
- val aliases = ExtractFilterNode + val aliases = PlanUtils .collectAliases(plan) .collect { case (shortName, _, ref: GetStructField) => val escapedFieldName = - SchemaUtils.escapeFieldName(ExtractFilterNode.getChildNameFromStruct(ref)) + SchemaUtils.escapeFieldName(PlanUtils.getChildNameFromStruct(ref)) escapedFieldName -> shortName } .toMap @@ -642,13 +643,31 @@ object RuleUtils { shuffled } + /** + * Transforms the projection to use the field in the index. The nested field in the + * projection is different from the top level field stored in the index. For example + * {{{Project [ + * nested#536.leaf.cnt AS cnt#556, + * query#533, + * nested#536.leaf.id AS id#557]}}} + * + * must be transformed into something similar to this: + * {{{Project [ + * nested__leaf__cnt#0 AS cnt#653, + * query#1 AS query#533, + * nested__leaf__id#2 AS id#654]}}} + * + * @param project The projection we want to transform + * @param index The suitable index + * @return + */ private def transformProject(project: Project, index: IndexLogEntry): Project = { val projectedFields = project.projectList.map { exp => - val fieldName = ExtractFilterNode.extractNamesFromExpression(exp).head + val fieldName = PlanUtils.extractNamesFromExpression(exp).head val shortFieldName = fieldName.split(SchemaUtils.NESTED_FIELD_NEEDLE_REGEX).last val escapedFieldName = SchemaUtils.escapeFieldName(fieldName) - val attr = ExtractFilterNode.extractAttributeRef(exp, fieldName) - val fieldType = ExtractFilterNode.extractTypeFromExpression(exp, fieldName) + val attr = PlanUtils.extractAttributeRef(exp, fieldName) + val fieldType = PlanUtils.extractTypeFromExpression(exp, fieldName) val exprId = getFieldPosition(index, escapedFieldName) val attrCopy = attr.copy(escapedFieldName, fieldType, attr.nullable, attr.metadata)( ExprId(exprId), @@ -662,16 +681,37 @@ object RuleUtils { project.copy(projectList = projectedFields) } + /** + * Transforms the filter to use the field in the index. The nested field in the + * filter is different from the top level field stored in the index. 
For example
+   * {{{Filter (
+   *    (isnotnull(nested#536) &&
+   *    (nested#536.leaf.cnt >= 20)) &&
+   *    (nested#536.leaf.cnt <= 40))}}}
+   *
+   * must be transformed into something similar to this:
+   * {{{Filter (
+   *    (isnotnull(nested__leaf__cnt#0) &&
+   *    (nested__leaf__cnt#0 >= 20)) &&
+   *    (nested__leaf__cnt#0 <= 40))}}}
+   *
+   * Prerequisite:
+   * - The index must have at least one nested field
+   *
+   * @param filter The filter we want to transform
+   * @param index The suitable index
+   * @return The transformed filter
+   */
+  private def transformFilter(filter: Filter, index: IndexLogEntry): Filter = {
+    val nestedFields = getNestedFields(index)
+    if (nestedFields.nonEmpty) {
+      val newCondition = filter.condition.transformDown {
+        case gsf: GetStructField =>
-          val fieldName = ExtractFilterNode.getChildNameFromStruct(gsf)
+          val fieldName = PlanUtils.getChildNameFromStruct(gsf)
+          val escapedFieldName = SchemaUtils.escapeFieldName(fieldName)
+          if (nestedFields.contains(escapedFieldName)) {
-            val fieldType = ExtractFilterNode.extractTypeFromExpression(gsf, fieldName)
-            val attr = ExtractFilterNode.extractAttributeRef(gsf, fieldName)
+            val fieldType = PlanUtils.extractTypeFromExpression(gsf, fieldName)
+            val attr = PlanUtils.extractAttributeRef(gsf, fieldName)
+            val exprId = getFieldPosition(index, escapedFieldName)
+            val newAttr = attr.copy(name = escapedFieldName, dataType = fieldType)(
+              ExprId(exprId),
@@ -682,15 +722,14 @@ object RuleUtils {
+        }
+      case cond @ IsNotNull(child) =>
+        val fieldName =
-          SchemaUtils.escapeFieldName(ExtractFilterNode.extractNamesFromExpression(child).head)
+          SchemaUtils.escapeFieldName(PlanUtils.extractNamesFromExpression(child).head)
+        val elemFound = nestedFields.find(i => i.contains(fieldName))
+        elemFound match {
+          case Some(name) =>
+            val newChild = child match {
+              case attr: AttributeReference =>
-                val fieldType = ExtractFilterNode.extractTypeFromExpression(
-                  cond,
-                  SchemaUtils.unescapeFieldName(name))
+                val fieldType =
+                  PlanUtils.extractTypeFromExpression(cond, SchemaUtils.unescapeFieldName(name))
+                val exprId = getFieldPosition(index, name)
+                attr.copy(name = name, dataType = fieldType)(ExprId(exprId), attr.qualifier)
+              case other =>
@@ -707,10 +746,24 @@ object RuleUtils {
+    }
+  }
+
+  /**
+   * The method collects a list of nested field names from the index schema.
+   *
+   * @param index The chosen index
+   * @return A collection of nested field names
+   */
+  private def getNestedFields(index: IndexLogEntry): Seq[String] = {
+    index.schema.fieldNames.filter(_.contains(SchemaUtils.NESTED_FIELD_REPLACEMENT))
+  }
+
+  /**
+   * Given an index and a field name, it returns the position of the field in the index schema.
+   * This method is used to properly create attributes over the index dataset.
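   * For example (illustrative, not taken from this patch): if the index schema field names
   * are {{{ Seq("nested__leaf__cnt", "query", "nested__leaf__id") }}} then the position
   * returned for `query` is 1, and transformProject/transformFilter reuse that position as
   * the `ExprId` of the attribute they create over the index data.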
+ * + * @param index The chosen index + * @param fieldName The field name for which we need to find its position in the index schema + * @return + */ private def getFieldPosition(index: IndexLogEntry, fieldName: String): Int = { index.schema.fieldNames.indexWhere(_.equalsIgnoreCase(fieldName)) } diff --git a/src/main/scala/com/microsoft/hyperspace/index/sources/FileBasedSourceProviderManager.scala b/src/main/scala/com/microsoft/hyperspace/index/sources/FileBasedSourceProviderManager.scala index a9cee8a0e..fc64537a3 100644 --- a/src/main/scala/com/microsoft/hyperspace/index/sources/FileBasedSourceProviderManager.scala +++ b/src/main/scala/com/microsoft/hyperspace/index/sources/FileBasedSourceProviderManager.scala @@ -19,13 +19,12 @@ package com.microsoft.hyperspace.index.sources import scala.util.{Success, Try} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.util.hyperspace.Utils import com.microsoft.hyperspace.HyperspaceException import com.microsoft.hyperspace.index.Relation -import com.microsoft.hyperspace.index.rules.ExtractFilterNode -import com.microsoft.hyperspace.util.{CacheWithTransform, HyperspaceConf, SchemaUtils} +import com.microsoft.hyperspace.util.{CacheWithTransform, HyperspaceConf} /** * [[FileBasedSourceProviderManager]] is responsible for loading source providers which implements @@ -91,41 +90,6 @@ class FileBasedSourceProviderManager(spark: SparkSession) { run(p => p.getRelation(plan)) } - /** - * Returns true if the given project is a supported project. If all of the registered - * providers return None, this returns false. - * - * @param project Project to check if it's supported. - * @return True if the given project is a supported relation. - */ - def isSupportedProject(project: Project): Boolean = { - val containsNestedFields = SchemaUtils.hasNestedFields( - project.projectList.flatMap(ExtractFilterNode.extractNamesFromExpression)) - var containsNestedChildren = false - project.child.foreach { - case f: Filter => - containsNestedChildren = containsNestedChildren || { - SchemaUtils.hasNestedFields(SchemaUtils.unescapeFieldNames( - ExtractFilterNode.extractNamesFromExpression(f.condition).toSeq)) - } - case _ => - } - containsNestedFields || containsNestedChildren - } - - /** - * Returns true if the given filter is a supported filter. If all of the registered - * providers return None, this returns false. - * - * @param filter Filter to check if it's supported. - * @return True if the given project is a supported relation. - */ - def isSupportedFilter(filter: Filter): Boolean = { - val containsNestedFields = SchemaUtils.hasNestedFields( - ExtractFilterNode.extractNamesFromExpression(filter.condition).toSeq) - containsNestedFields - } - /** * Runs the given function 'f', which executes a [[FileBasedSourceProvider]]'s API that returns * [[Option]] for each provider built. This function ensures that only one provider returns