Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un

override def newInstance(): UnresolvedAttribute = this
override def withNullability(newNullability: Boolean): UnresolvedAttribute = this
// The requested determinism is ignored: the same instance is returned unchanged.
override def withDeterminism(newDeterminism: Boolean): UnresolvedAttribute = this
override def withQualifier(newQualifier: Seq[String]): UnresolvedAttribute = this
override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName)
override def withMetadata(newMetadata: Metadata): Attribute = this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ abstract class Attribute extends LeafExpression with NamedExpression {
override lazy val references: AttributeSet = AttributeSet(this)

def withNullability(newNullability: Boolean): Attribute
def withDeterminism(newDeterminism: Boolean): Attribute
def withQualifier(newQualifier: Seq[String]): Attribute
def withName(newName: String): Attribute
def withMetadata(newMetadata: Metadata): Attribute
Expand Down Expand Up @@ -203,7 +204,8 @@ case class Alias(child: Expression, name: String)(

override def toAttribute: Attribute = {
if (resolved) {
AttributeReference(name, child.dataType, child.nullable, metadata)(exprId, qualifier)
AttributeReference(name, child.dataType, child.nullable, metadata, child.deterministic)(
exprId, qualifier)
} else {
UnresolvedAttribute.quoted(name)
}
Expand Down Expand Up @@ -276,6 +278,7 @@ object AttributeReferenceTreeBits {
* @param dataType The [[DataType]] of this attribute.
* @param nullable True if null is a valid value for this attribute.
* @param metadata The metadata of this attribute.
* @param determinism True if this attribute reference is deterministic.
* @param exprId A globally unique id used to check if different AttributeReferences refer to the
* same attribute.
* @param qualifier An optional string that can be used to refer to this attribute in a fully
Expand All @@ -286,13 +289,16 @@ case class AttributeReference(
name: String,
dataType: DataType,
nullable: Boolean = true,
override val metadata: Metadata = Metadata.empty)(
override val metadata: Metadata = Metadata.empty,
determinism: Boolean = true)(
val exprId: ExprId = NamedExpression.newExprId,
val qualifier: Seq[String] = Seq.empty[String])
extends Attribute with Unevaluable {

override lazy val treePatternBits: BitSet = AttributeReferenceTreeBits.bits

override lazy val deterministic: Boolean = determinism

/**
* Returns true iff the expression id is the same for both attributes.
*/
Expand Down Expand Up @@ -326,7 +332,7 @@ case class AttributeReference(
}

override def newInstance(): AttributeReference =
AttributeReference(name, dataType, nullable, metadata)(qualifier = qualifier)
AttributeReference(name, dataType, nullable, metadata, determinism)(qualifier = qualifier)

/**
* Returns a copy of this [[AttributeReference]] with changed nullability.
Expand All @@ -335,15 +341,26 @@ case class AttributeReference(
if (nullable == newNullability) {
this
} else {
AttributeReference(name, dataType, newNullability, metadata)(exprId, qualifier)
this.copy(nullable = newNullability)(exprId, qualifier)
}
}

/**
 * Produces an [[AttributeReference]] whose determinism flag equals `newDeterminism`.
 *
 * When the flag already matches, the receiver itself is returned; otherwise a copy is
 * created that keeps the current `exprId` and `qualifier`.
 */
override def withDeterminism(newDeterminism: Boolean): AttributeReference =
  if (newDeterminism != determinism) {
    copy(determinism = newDeterminism)(exprId, qualifier)
  } else {
    this
  }

override def withName(newName: String): AttributeReference = {
if (name == newName) {
this
} else {
AttributeReference(newName, dataType, nullable, metadata)(exprId, qualifier)
this.copy(name = newName)(exprId, qualifier)
}
}

Expand All @@ -354,24 +371,24 @@ case class AttributeReference(
if (newQualifier == qualifier) {
this
} else {
AttributeReference(name, dataType, nullable, metadata)(exprId, newQualifier)
this.copy()(exprId, newQualifier)
}
}

override def withExprId(newExprId: ExprId): AttributeReference = {
if (exprId == newExprId) {
this
} else {
AttributeReference(name, dataType, nullable, metadata)(newExprId, qualifier)
this.copy()(newExprId, qualifier)
}
}

override def withMetadata(newMetadata: Metadata): AttributeReference = {
AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier)
this.copy(metadata = newMetadata)(exprId, qualifier)
}

override def withDataType(newType: DataType): AttributeReference = {
AttributeReference(name, newType, nullable, metadata)(exprId, qualifier)
this.copy(dataType = newType)(exprId, qualifier)
}

override protected final def otherCopyArgs: Seq[AnyRef] = {
Expand Down Expand Up @@ -400,6 +417,16 @@ case class AttributeReference(
}
}

object AttributeReference {
  /**
   * Hand-written extractor that exposes only `name`, `dataType`, `nullable` and `metadata`.
   *
   * The `determinism` field is deliberately left out of the extracted tuple so that adding
   * it to the class does not break API compatibility for existing four-field patterns.
   */
  def unapply(ar: AttributeReference): Some[(String, DataType, Boolean, Metadata)] = {
    val extracted = (ar.name, ar.dataType, ar.nullable, ar.metadata)
    Some(extracted)
  }
}

/**
* A place holder used when printing expressions without debugging information such as the
* expression id or the unresolved indicator.
Expand Down Expand Up @@ -434,6 +461,8 @@ case class PrettyAttribute(

override def withNullability(newNullability: Boolean): Attribute =
throw SparkUnsupportedOperationException()
// Changing determinism is not supported on this attribute type; every call throws
// a SparkUnsupportedOperationException.
override def withDeterminism(newDeterminism: Boolean): Attribute =
throw SparkUnsupportedOperationException()
override def newInstance(): Attribute =
throw SparkUnsupportedOperationException()
override def withQualifier(newQualifier: Seq[String]): Attribute =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -478,11 +478,12 @@ object Union {
childOutputs.transpose.map { attrs =>
val firstAttr = attrs.head
val nullable = attrs.exists(_.nullable)
val deterministic = attrs.forall(_.deterministic)
val newDt = attrs.map(_.dataType).reduce(StructType.unionLikeMerge)
if (firstAttr.dataType == newDt) {
firstAttr.withNullability(nullable)
firstAttr.withNullability(nullable).withDeterminism(deterministic)
} else {
AttributeReference(firstAttr.name, newDt, nullable, firstAttr.metadata)(
AttributeReference(firstAttr.name, newDt, nullable, firstAttr.metadata, deterministic)(
firstAttr.exprId, firstAttr.qualifier)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,50 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}


test("SPARK-47031 + SPARK-13473 union - non-deterministic") {
val testRelation2 = LocalRelation($"d".int, $"e".int, $"f".int)

// in subq1 j is deterministic
val subq1 = testRelation.select(
Literal(1.0).as("j")
)

// j is non-deterministic
val subq2 = testRelation2.select(
Rand(10).as("j")
)

// j is non-deterministic in the first sub-query of the union
val originalQuery = Union(Seq(subq2, subq1))
.where($"j" > 5L)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's wrong with pushing down j > 5 in this case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the j column should be non-deterministic (it is the union of a non-deterministic field and a deterministic field), I don't think we can push the filter down, right?


val optimized = Optimize.execute(originalQuery.analyze)

val correctAnswer = Union(Seq(
subq2,
subq1))
.where($"j" > 5L)
.analyze

comparePlans(optimized, correctAnswer)

// j is deterministic in the first sub-query but not in the second

val originalQueryReversed = Union(Seq(subq1, subq2))
.where($"j" > 5L)

val optimizedReversed = Optimize.execute(originalQueryReversed.analyze)

val correctAnswerReversed = Union(Seq(
subq1,
subq2))
.where($"j" > 5L)
.analyze

comparePlans(optimizedReversed, correctAnswerReversed)
}

test("union filter pushdown w/reference to grand-child field") {
val nonNullableArray = StructField("a", ArrayType(IntegerType, false))
val bField = StructField("b", IntegerType)
Expand Down