Skip to content

Commit

Permalink
[OPTIQ-427] Off-by-one issues in RemoveDistinctAggregateRule, Aggrega…
Browse files Browse the repository at this point in the history
…teFilterTransposeRule
  • Loading branch information
julianhyde committed Sep 26, 2014
1 parent a1539e3 commit 5d95c5f
Show file tree
Hide file tree
Showing 15 changed files with 183 additions and 83 deletions.
1 change: 1 addition & 0 deletions core/src/main/java/org/eigenbase/rel/AggregateRelBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ protected AggregateRelBase(
assert groupSet.isEmpty() == (groupSet.cardinality() == 0)
: "See https://bugs.openjdk.java.net/browse/JDK-6222207, "
+ "BitSet internal invariants may be violated";
assert groupSet.length() <= child.getRowType().getFieldCount();
for (AggregateCall aggCall : aggCalls) {
assert typeMatchesInferred(aggCall, true);
}
Expand Down
15 changes: 13 additions & 2 deletions core/src/main/java/org/eigenbase/rel/FilterRelBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,15 @@ protected FilterRelBase(
assert condition != null;
assert RexUtil.isFlat(condition) : condition;
this.condition = condition;
// Too expensive for everyday use:
// assert isValid(true);
}

/**
* Creates a FilterRelBase by parsing serialized output.
*/
protected FilterRelBase(RelInput input) {
this(
input.getCluster(), input.getTraitSet(), input.getInput(),
this(input.getCluster(), input.getTraitSet(), input.getInput(),
input.getExpression("condition"));
}

Expand All @@ -84,6 +85,16 @@ public RexNode getCondition() {
return condition;
}

@Override public boolean isValid(boolean fail) {
final RexChecker checker = new RexChecker(getChild().getRowType(), fail);
condition.accept(checker);
if (checker.getFailureCount() > 0) {
assert !fail;
return false;
}
return true;
}

public RelOptCost computeSelfCost(RelOptPlanner planner) {
double dRows = RelMetadataQuery.getRowCount(this);
double dCpu = RelMetadataQuery.getRowCount(getChild());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import net.hydromatic.optiq.util.BitSets;

import com.google.common.base.Function;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;

/**
Expand Down Expand Up @@ -108,14 +109,8 @@ public Integer apply(Integer a0) {
topGroupSet.set(BitSets.toList(newGroupSet).indexOf(c));
}
final List<AggregateCall> topAggCallList = Lists.newArrayList();
final int offset = newGroupSet.cardinality()
- aggregate.getGroupSet().cardinality();
assert offset > 0;
int i = newGroupSet.cardinality();
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
final List<Integer> args = Lists.newArrayList();
for (int arg : aggregateCall.getArgList()) {
args.add(arg + offset);
}
final Aggregation rollup =
SubstitutionVisitor.getRollup(aggregateCall.getAggregation());
if (rollup == null) {
Expand All @@ -127,8 +122,8 @@ public Integer apply(Integer a0) {
return;
}
topAggCallList.add(
new AggregateCall(rollup, aggregateCall.isDistinct(), args,
aggregateCall.type, aggregateCall.name));
new AggregateCall(rollup, aggregateCall.isDistinct(),
ImmutableList.of(i++), aggregateCall.type, aggregateCall.name));
}
final AggregateRelBase topAggregate =
aggregate.copy(aggregate.getTraitSet(), newFilter, topGroupSet,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
import org.eigenbase.relopt.RelOptRuleOperand;
import org.eigenbase.relopt.RelOptTable;
import org.eigenbase.relopt.RelOptUtil;
import org.eigenbase.relopt.SubstitutionVisitor;
import org.eigenbase.reltype.RelDataType;
import org.eigenbase.sql.fun.SqlStdOperatorTable;
import org.eigenbase.util.Pair;
import org.eigenbase.util.mapping.AbstractSourceMapping;

Expand Down Expand Up @@ -194,7 +194,7 @@ private static AggregateCall rollUp(AggregateCall aggregateCall,
final int i = find(measures, seek);
tryRoll:
if (i >= 0) {
final Aggregation roll = getRollup(aggregation);
final Aggregation roll = SubstitutionVisitor.getRollup(aggregation);
if (roll == null) {
break tryRoll;
}
Expand All @@ -221,18 +221,6 @@ private static AggregateCall rollUp(AggregateCall aggregateCall,
return null;
}

private static Aggregation getRollup(Aggregation aggregation) {
if (aggregation == SqlStdOperatorTable.SUM
|| aggregation == SqlStdOperatorTable.MIN
|| aggregation == SqlStdOperatorTable.MAX) {
return aggregation;
} else if (aggregation == SqlStdOperatorTable.COUNT) {
return SqlStdOperatorTable.SUM;
} else {
return null;
}
}

private static int find(ImmutableList<Lattice.Measure> measures,
Pair<Aggregation, List<Integer>> seek) {
for (int i = 0; i < measures.size(); i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;

/**
* Rule to remove distinct aggregates from a {@link AggregateRel}.
Expand Down Expand Up @@ -94,7 +95,7 @@ public void onMatch(RelOptRuleCall call) {
final List<RexInputRef> refs = new ArrayList<RexInputRef>();
final List<String> fieldNames = aggregate.getRowType().getFieldNames();
final BitSet groupSet = aggregate.getGroupSet();
for (int i : BitSets.toIter(groupSet)) {
for (int i : Util.range(groupSet.cardinality())) {
refs.add(RexInputRef.of(i, aggFields));
}

Expand Down Expand Up @@ -176,7 +177,7 @@ private RelNode convertMonopole(
return aggregate.copy(
aggregate.getTraitSet(),
distinct,
aggregate.getGroupSet(),
BitSets.range(aggregate.getGroupSet().cardinality()),
newAggCalls);
}

Expand Down Expand Up @@ -300,16 +301,12 @@ private RelNode doRewrite(
aggCall.getName());
assert refs.get(i) == null;
if (left == null) {
refs.set(
i,
new RexInputRef(
groupCount + aggCallList.size(),
refs.set(i,
new RexInputRef(groupCount + aggCallList.size(),
newAggCall.getType()));
} else {
refs.set(
i,
new RexInputRef(
leftFields.size() + groupCount + aggCallList.size(),
refs.set(i,
new RexInputRef(leftFields.size() + groupCount + aggCallList.size(),
newAggCall.getType()));
}
aggCallList.add(newAggCall);
Expand All @@ -319,7 +316,7 @@ private RelNode doRewrite(
aggregate.copy(
aggregate.getTraitSet(),
distinct,
aggregate.getGroupSet(),
BitSets.range(aggregate.getGroupSet().cardinality()),
aggCallList);

// If there's no left child yet, no need to create the join
Expand All @@ -332,37 +329,22 @@ private RelNode doRewrite(
// where {f0, f1, ...} are the GROUP BY fields.
final List<RelDataTypeField> distinctFields =
distinctAgg.getRowType().getFieldList();
RexNode condition = rexBuilder.makeLiteral(true);
final List<RexNode> conditions = Lists.newArrayList();
for (i = 0; i < groupCount; ++i) {
final int leftOrdinal = i;
final int rightOrdinal = sourceOf.get(i);

// null values form its own group
// use "is not distinct from" so that the join condition
// allows null values to match.
RexNode equi =
rexBuilder.makeCall(
SqlStdOperatorTable.IS_NOT_DISTINCT_FROM,
RexInputRef.of(leftOrdinal, leftFields),
new RexInputRef(
leftFields.size() + rightOrdinal,
distinctFields.get(rightOrdinal).getType()));
if (i == 0) {
condition = equi;
} else {
condition =
rexBuilder.makeCall(
SqlStdOperatorTable.AND,
condition,
equi);
}
conditions.add(
rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_DISTINCT_FROM,
RexInputRef.of(i, leftFields),
new RexInputRef(leftFields.size() + i,
distinctFields.get(i).getType())));
}

// Join in the new 'select distinct' relation.
return joinFactory.createJoin(
left,
return joinFactory.createJoin(left,
distinctAgg,
condition,
RexUtil.composeConjunction(rexBuilder, conditions, false),
JoinRelType.INNER,
ImmutableSet.<String>of(),
false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ private void match(ProjectFilterTable left, ProjectFilterTable right,
Mappings.asList(mapping.inverse()));
final List<RexNode> conditions = Lists.newArrayList();
if (left.condition != null) {
conditions.add(RexUtil.apply(mapping, left.condition));
conditions.add(left.condition);
}
if (right.condition != null) {
conditions.add(
Expand All @@ -174,12 +174,12 @@ private void match(ProjectFilterTable left, ProjectFilterTable right,
Mappings.asList(mapping.inverse()));
final List<RexNode> conditions = Lists.newArrayList();
if (left.condition != null) {
conditions.add(RexUtil.apply(mapping, left.condition));
}
if (right.condition != null) {
conditions.add(
RexUtil.apply(mapping,
RexUtil.shift(right.condition, offset)));
RexUtil.shift(left.condition, offset)));
}
if (right.condition != null) {
conditions.add(RexUtil.apply(mapping, right.condition));
}
final RelNode filter =
RelOptUtil.createFilter(project, conditions);
Expand Down
18 changes: 9 additions & 9 deletions core/src/main/java/org/eigenbase/relopt/RelOptUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

import net.hydromatic.optiq.util.BitSets;

import com.google.common.base.Function;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;

Expand All @@ -46,6 +47,13 @@ public abstract class RelOptUtil {

public static final double EPSILON = 1.0e-5;

private static final Function<RelDataTypeField, RelDataType> GET_TYPE =
new Function<RelDataTypeField, RelDataType>() {
public RelDataType apply(RelDataTypeField field) {
return field.getType();
}
};

//~ Methods ----------------------------------------------------------------

/**
Expand Down Expand Up @@ -131,15 +139,7 @@ public static void go(
* @see org.eigenbase.reltype.RelDataType#getFieldNames()
*/
public static List<RelDataType> getFieldTypeList(final RelDataType type) {
return new AbstractList<RelDataType>() {
public RelDataType get(int index) {
return type.getFieldList().get(index).getType();
}

public int size() {
return type.getFieldCount();
}
};
return Lists.transform(type.getFieldList(), GET_TYPE);
}

public static boolean areRowTypesEqual(
Expand Down
13 changes: 10 additions & 3 deletions core/src/main/java/org/eigenbase/relopt/SubstitutionVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -1166,7 +1166,7 @@ public static MutableRel unifyAggregates(MutableAggregate query,
aggregateCalls.add(
new AggregateCall(getRollup(aggregateCall.getAggregation()),
aggregateCall.isDistinct(),
ImmutableList.of(groupSet.cardinality() + i),
ImmutableList.of(target.groupSet.cardinality() + i),
aggregateCall.type, aggregateCall.name));
}
result = MutableAggregate.of(target, groupSet, aggregateCalls);
Expand Down Expand Up @@ -1215,8 +1215,15 @@ public UnifyResult apply(UnifyRuleCall call) {
}

public static Aggregation getRollup(Aggregation aggregation) {
// TODO: count rolls up using sum; etc.
return aggregation;
if (aggregation == SqlStdOperatorTable.SUM
|| aggregation == SqlStdOperatorTable.MIN
|| aggregation == SqlStdOperatorTable.MAX) {
return aggregation;
} else if (aggregation == SqlStdOperatorTable.COUNT) {
return SqlStdOperatorTable.SUM;
} else {
return null;
}
}

/** Builds a shuttle that stores a list of expressions, and can map incoming
Expand Down
8 changes: 5 additions & 3 deletions core/src/main/java/org/eigenbase/sql2rel/RelDecorrelator.java
Original file line number Diff line number Diff line change
Expand Up @@ -2204,8 +2204,9 @@ public void onMatch(RelOptRuleCall call) {

rightInputRel =
createProjectWithAdditionalExprs(rightInputRel,
ImmutableList.of(Pair.<RexNode, String>of(rexBuilder.makeLiteral(
true), "nullIndicator")));
ImmutableList.of(
Pair.<RexNode, String>of(rexBuilder.makeLiteral(true),
"nullIndicator")));

JoinRel joinRel =
new JoinRel(
Expand Down Expand Up @@ -2279,7 +2280,8 @@ public void onMatch(RelOptRuleCall call) {
}
}

newAggCalls.add(aggCall.adaptTo(joinOutputProjRel, newAggArgs,
newAggCalls.add(
aggCall.adaptTo(joinOutputProjRel, newAggArgs,
aggRel.getGroupCount(), groupCount));
}

Expand Down
12 changes: 12 additions & 0 deletions core/src/main/java/org/eigenbase/util/Util.java
Original file line number Diff line number Diff line change
Expand Up @@ -1990,6 +1990,18 @@ public static <E> List<E> skip(List<E> list, int fromIndex) {
return list.subList(fromIndex, list.size());
}

public static List<Integer> range(final int end) {
return new AbstractList<Integer>() {
public int size() {
return end;
}

public Integer get(int index) {
return index;
}
};
}

public static List<Integer> range(final int start, final int end) {
return new AbstractList<Integer>() {
public int size() {
Expand Down
20 changes: 18 additions & 2 deletions core/src/test/java/net/hydromatic/optiq/test/FoodmartTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import com.google.common.collect.ImmutableList;

import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
Expand Down Expand Up @@ -151,10 +152,8 @@ public FoodmartTest(int id) throws IOException {
public void test() {
try {
OptiqAssert.that()
// .withModel(JdbcTest.FOODMART_MODEL)
.with(OptiqAssert.Config.FOODMART_CLONE)
.pooled()
// .withSchema("foodmart")
.query(query.sql)
.runs();
} catch (Throwable e) {
Expand All @@ -163,6 +162,23 @@ public void test() {
}
}

@Test(timeout = 60000)
@Ignore
public void testWithLattice() {
try {
OptiqAssert.that()
.with(OptiqAssert.Config.JDBC_FOODMART_WITH_LATTICE)
.pooled()
.withSchema("foodmart")
.query(query.sql)
.enableMaterializations(true)
.runs();
} catch (Throwable e) {
throw new RuntimeException("Test failed, id=" + query.id + ", sql="
+ query.sql, e);
}
}

public static class FoodMartQuerySet {
private static SoftReference<FoodMartQuerySet> ref;

Expand Down
Loading

0 comments on commit 5d95c5f

Please sign in to comment.