Add support for nested fields in joins
andrei-ionescu committed Mar 2, 2021
1 parent 45fe1ce commit 2223313
Showing 6 changed files with 762 additions and 128 deletions.
@@ -110,12 +110,23 @@ private[hyperspace] case class BucketUnionExec(children: Seq[SparkPlan], bucketS
   override def output: Seq[Attribute] = children.head.output
 
   override def outputPartitioning: Partitioning = {
-    assert(children.map(_.outputPartitioning).toSet.size == 1)
-    assert(children.head.outputPartitioning.isInstanceOf[HashPartitioning])
-    assert(
-      children.head.outputPartitioning
-        .asInstanceOf[HashPartitioning]
-        .numPartitions == bucketSpec.numBuckets)
-    children.head.outputPartitioning
+    val parts = children.map(_.outputPartitioning).distinct
+    assert(parts.forall(_.isInstanceOf[HashPartitioning]))
+    assert(parts.forall(_.numPartitions == bucketSpec.numBuckets))
+
+    val reduced = parts.reduceLeft { (a, b) =>
+      val h1 = a.asInstanceOf[HashPartitioning]
+      val h2 = b.asInstanceOf[HashPartitioning]
+      val h1Name = h1.references.head.name
+      val h2Name = h2.references.head.name
+      val same = h1Name.contains(h2Name) || h2Name.contains(h1Name)
+      assert(same)
+      if (h1Name.length > h2Name.length) {
+        h1
+      } else {
+        h2
+      }
+    }
+    reduced
   }
 }
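The rewritten outputPartitioning merges the children's hash partitionings by comparing the names of their partitioning attributes and keeping the longer, more specific name when one contains the other. Below is a minimal, self-contained sketch of that name-based selection using plain strings instead of Spark's HashPartitioning; the field names are hypothetical.

object PartitioningNameDemo extends App {
  // Candidate partitioning column names from two children; one name contains
  // the other, as when a nested field meets its shorter prefix.
  val names = Seq("nested", "nested.leaf.cnt")

  val picked = names.reduceLeft { (a, b) =>
    assert(a.contains(b) || b.contains(a)) // names must be related, mirroring the assert above
    if (a.length > b.length) a else b      // prefer the more specific name
  }

  println(picked) // prints: nested.leaf.cnt
}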
@@ -20,7 +20,7 @@ import scala.util.Try
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.analysis.CleanupAliases
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, GetStructField}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, GetStructField}
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.types.{DataType, StructType}
@@ -200,12 +200,12 @@ object ExtractFilterNode {
       case AttributeReference(name, _, _, _) =>
         Set(s"$name")
       case otherExp =>
-        otherExp.containsChild.map {
+        otherExp.containsChild.flatMap {
           case g: GetStructField =>
-            s"${getChildNameFromStruct(g)}"
+            Set(s"${getChildNameFromStruct(g)}")
           case e: Expression =>
-            extractNamesFromExpression(e).filter(_.nonEmpty).mkString(".")
-          case _ => ""
+            extractNamesFromExpression(e).filter(_.nonEmpty)
+          case _ => Set.empty[String]
         }
     }
   }
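With flatMap, each child of an expression now contributes a whole Set of field names that is flattened into the result, instead of being collapsed to a single (possibly empty) string. A minimal sketch of that behaviour on a toy expression tree; the ADT below is hypothetical, not Spark's Expression hierarchy.

object ExtractNamesDemo extends App {
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class FieldAccess(child: Expr, field: String) extends Expr
  case class EqualTo(left: Expr, right: Expr) extends Expr

  // Every child yields a set of names, so the sets are flat-mapped together
  // rather than mapped to one string per child.
  def names(e: Expr): Set[String] = e match {
    case Attr(n)               => Set(n)
    case FieldAccess(child, f) => names(child).map(p => s"$p.$f")
    case EqualTo(l, r)         => Set(l, r).flatMap(names)
  }

  // nested.leaf = value  ->  Set(nested.leaf, value)
  println(names(EqualTo(FieldAccess(Attr("nested"), "leaf"), Attr("value"))))
}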
@@ -221,20 +221,6 @@ object ExtractFilterNode {
     }
   }
 
-  def extractSearchQuery(exp: Expression, name: String): (Expression, Expression) = {
-    val splits = name.split(SchemaUtils.NESTED_FIELD_NEEDLE_REGEX)
-    val expFound = exp.find {
-      case a: AttributeReference if splits.forall(s => a.name.contains(s)) => true
-      case f: GetStructField if splits.forall(s => f.toString().contains(s)) => true
-      case _ => false
-    }.get
-    val parent = exp.find {
-      case e: Expression if e.containsChild.contains(expFound) => true
-      case _ => false
-    }.get
-    (parent, expFound)
-  }
-
   def replaceInSearchQuery(
       parent: Expression,
       needle: Expression,
@@ -260,17 +246,26 @@ object ExtractFilterNode {
   def extractTypeFromExpression(exp: Expression, name: String): DataType = {
     val splits = name.split(SchemaUtils.NESTED_FIELD_NEEDLE_REGEX)
     val elem = exp.flatMap {
-      case a: AttributeReference =>
-        if (splits.forall(s => a.name == s)) {
-          Some((name, a.dataType))
+      case attrRef: AttributeReference =>
+        if (splits.forall(s => attrRef.name == s)) {
+          Some((name, attrRef.dataType))
         } else {
           Try({
             val h :: t = splits.toList
-            if (a.name == h && a.dataType.isInstanceOf[StructType]) {
-              val currentDataType = a.dataType.asInstanceOf[StructType]
+            if (attrRef.name == h && attrRef.dataType.isInstanceOf[StructType]) {
+              val currentDataType = attrRef.dataType.asInstanceOf[StructType]
+              var localDT = currentDataType
               val foldedFields = t.foldLeft(Seq.empty[(String, DataType)]) { (acc, i) =>
-                val idx = currentDataType.indexWhere(_.name.equalsIgnoreCase(i))
-                acc :+ (i, currentDataType(idx).dataType)
+                val collected = localDT.collect {
+                  case dt if dt.name == i =>
+                    dt.dataType match {
+                      case st: StructType =>
+                        localDT = st
+                      case _ =>
+                    }
+                    (i, dt.dataType)
+                }
+                acc ++ collected
               }
               Some(foldedFields.last)
             } else {
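The rewritten fold walks the nested schema one name segment at a time, switching localDT to the child StructType at each step so the last segment's type ends up in foldedFields. A minimal sketch of the same walk against a hypothetical nested schema, using only Spark's StructType API.

import org.apache.spark.sql.types._

object NestedTypeDemo extends App {
  // Hypothetical schema: nested.leaf.cnt resolves to an integer two structs down.
  val schema = StructType(Seq(
    StructField("nested", StructType(Seq(
      StructField("leaf", StructType(Seq(
        StructField("cnt", IntegerType)))))))))

  // Descend one segment at a time; every intermediate segment must be a struct.
  def typeOf(root: StructType, path: Seq[String]): DataType =
    path.foldLeft[DataType](root) {
      case (st: StructType, segment) => st(segment).dataType
      case (other, segment) => sys.error(s"'$segment' requested below non-struct $other")
    }

  println(typeOf(schema, Seq("nested", "leaf", "cnt"))) // prints: IntegerType
}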
@@ -284,6 +279,18 @@ object ExtractFilterNode {
     }
     elem.find(e => e._1 == name || e._1 == splits.last).get._2
   }
+
+  def collectAliases(plan: LogicalPlan): Seq[(String, Attribute, Expression)] = {
+    plan
+      .collect {
+        case Project(projectList, _) =>
+          projectList.collect {
+            case a @ Alias(child, name) =>
+              (name, a.toAttribute, child)
+          }
+      }
+      .flatten
+  }
 }
 
 object ExtractRelation extends ActiveSparkSession {
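The new collectAliases helper gathers, for every Project in a plan, the alias name, the attribute it exposes, and the expression it hides, so later rules can map aliased nested fields back to their origin. A hedged end-to-end sketch follows; the helper body is inlined so the example does not depend on Hyperspace's package-private objects, and the DataFrame and column names are hypothetical.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}

object CollectAliasesDemo extends App {
  val spark = SparkSession.builder().master("local[1]").appName("demo").getOrCreate()
  import spark.implicits._

  // Same shape as the collectAliases added above, inlined for the sketch.
  def collectAliases(plan: LogicalPlan): Seq[(String, Attribute, Expression)] =
    plan.collect {
      case Project(projectList, _) =>
        projectList.collect { case a @ Alias(child, name) => (name, a.toAttribute, child) }
    }.flatten

  // A projection that aliases a nested struct field.
  val df = Seq((("a", 1), "x")).toDF("nested", "other")
    .select($"nested._2".as("leafCnt"), $"other")

  // Prints something like: leafCnt <- nested#0._2
  collectAliases(df.queryExecution.analyzed)
    .foreach { case (name, _, child) => println(s"$name <- $child") }

  spark.stop()
}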