Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,18 @@

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.TreeNode;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Multiply;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.Subtract;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.coercion.IntegralType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.Utils;

Expand Down Expand Up @@ -71,7 +74,7 @@ public Rule build() {
}

@VisibleForTesting
protected static boolean isBinaryArithmeticSlot(TreeNode<Expression> expr) {
protected static boolean isBinaryArithmeticSlot(Expression expr) {
if (expr instanceof Slot) {
return true;
}
Expand All @@ -81,7 +84,70 @@ protected static boolean isBinaryArithmeticSlot(TreeNode<Expression> expr) {
if (!supportedFunctions.contains(expr.getClass())) {
return false;
}
return ExpressionUtils.isSlotOrCastOnSlot(expr.child(0)).isPresent() && expr.child(1) instanceof Literal
|| ExpressionUtils.isSlotOrCastOnSlot(expr.child(1)).isPresent() && expr.child(0) instanceof Literal;

// Float/double arithmetic: precision loss for all operations
if (expr.child(0).getDataType().isFloatLikeType()
|| expr.child(1).getDataType().isFloatLikeType()) {
return false;
}

Expression slotExpr;
Literal literal;
if (expr.child(0) instanceof Literal) {
literal = (Literal) expr.child(0);
slotExpr = expr.child(1);
} else if (expr.child(1) instanceof Literal) {
literal = (Literal) expr.child(1);
slotExpr = expr.child(0);
} else {
return false;
}

if (!canExtractSlot(slotExpr)) {
return false;
}

return checkLiteral(expr, literal);
}

@VisibleForTesting
protected static boolean checkLiteral(Expression expr, Literal literal) {
if (literal.isNullLiteral()) {
return false;
}
if (expr instanceof Multiply || expr instanceof Divide) {
if (literal.isZero()) {
return false;
}
}
return true;
}

@VisibleForTesting
protected static boolean canExtractSlot(Expression expr) {
while (expr instanceof Cast) {
Cast cast = (Cast) expr;
Expression inner = cast.child();
if (!isLosslessWidening(inner.getDataType(), cast.getDataType())) {
return false;
}
expr = inner;
}
return expr instanceof Slot;
}

@VisibleForTesting
protected static boolean isLosslessWidening(DataType src, DataType tgt) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in #64080, DataType add a new interface isInjectiveCastTo, i think we could reuse it here

if (src instanceof IntegralType && tgt instanceof IntegralType) {
return src.width() <= tgt.width();
}
if (src.isDecimalLikeType() && tgt.isDecimalLikeType()) {
return DecimalV3Type.forType(src).getRange() <= DecimalV3Type.forType(tgt).getRange()
&& DecimalV3Type.forType(src).getScale() <= DecimalV3Type.forType(tgt).getScale();
}
if (src instanceof IntegralType && tgt.isDecimalLikeType()) {
return ((IntegralType) src).range() <= DecimalV3Type.forType(tgt).getRange();
}
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IsNull;
Expand Down Expand Up @@ -422,40 +421,6 @@ public static <S extends NamedExpression> S selectMinimumColumn(Collection<S> sl
}

/**
* Check whether the input expression is a {@link org.apache.doris.nereids.trees.expressions.Slot}
* or at least one {@link Cast} on a {@link org.apache.doris.nereids.trees.expressions.Slot}
* <p>
* for example:
* - SlotReference to a column:
* col
* - Cast on SlotReference:
* cast(int_col as string)
* cast(cast(int_col as long) as string)
*
* @param expr input expression
* @return Return Optional[ExprId] of underlying slot reference if input expression is a slot or cast on slot.
* Otherwise, return empty optional result.
*/
public static Optional<ExprId> isSlotOrCastOnSlot(Expression expr) {
return extractSlotOrCastOnSlot(expr).map(Slot::getExprId);
}

/**
* Check whether the input expression is a {@link org.apache.doris.nereids.trees.expressions.Slot}
* or at least one {@link Cast} on a {@link org.apache.doris.nereids.trees.expressions.Slot}
*/
public static Optional<Slot> extractSlotOrCastOnSlot(Expression expr) {
while (expr instanceof Cast) {
expr = expr.child(0);
}

if (expr instanceof SlotReference) {
return Optional.of((Slot) expr);
} else {
return Optional.empty();
}
}

/**
* Generate replaceMap Slot -> Expression from NamedExpression[Expression as name]
*/
Expand Down
Loading
Loading