Skip to content

Commit 9dab649

Browse files
committed
Add parametric max and min function
1 parent 99633bb commit 9dab649

26 files changed

+695
-278
lines changed

presto-benchmark/src/main/java/com/facebook/presto/benchmark/BenchmarkSuite.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ public static List<AbstractBenchmark> createBenchmarks(LocalQueryRunner localQue
6161
new Top100SqlBenchmark(localQueryRunner),
6262
new SqlHashJoinBenchmark(localQueryRunner),
6363
new SqlJoinWithPredicateBenchmark(localQueryRunner),
64+
new LongMaxAggregationSqlBenchmark(localQueryRunner),
6465
new VarBinaryMaxAggregationSqlBenchmark(localQueryRunner),
6566
new SqlDistinctMultipleFields(localQueryRunner),
6667
new SqlDistinctSingleField(localQueryRunner),
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package com.facebook.presto.benchmark;
15+
16+
import com.facebook.presto.testing.LocalQueryRunner;
17+
18+
import static com.facebook.presto.benchmark.BenchmarkQueryRunner.createLocalQueryRunner;
19+
20+
public class LongMaxAggregationSqlBenchmark
21+
extends AbstractSqlBenchmark
22+
{
23+
public LongMaxAggregationSqlBenchmark(LocalQueryRunner localQueryRunner)
24+
{
25+
super(localQueryRunner, "sql_long_max", 40, 200, "select max(partkey) from lineitem");
26+
}
27+
28+
public static void main(String[] args)
29+
{
30+
new LongMaxAggregationSqlBenchmark(createLocalQueryRunner()).runBenchmark(new SimpleLineBenchmarkResultWriter(System.out));
31+
}
32+
}

presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,10 @@
2626
import com.facebook.presto.operator.aggregation.BooleanOrAggregation;
2727
import com.facebook.presto.operator.aggregation.CountAggregation;
2828
import com.facebook.presto.operator.aggregation.CountIfAggregation;
29-
import com.facebook.presto.operator.aggregation.DoubleMaxAggregation;
30-
import com.facebook.presto.operator.aggregation.DoubleMinAggregation;
3129
import com.facebook.presto.operator.aggregation.DoubleSumAggregation;
32-
import com.facebook.presto.operator.aggregation.LongMaxAggregation;
33-
import com.facebook.presto.operator.aggregation.LongMinAggregation;
3430
import com.facebook.presto.operator.aggregation.LongSumAggregation;
3531
import com.facebook.presto.operator.aggregation.MergeHyperLogLogAggregation;
3632
import com.facebook.presto.operator.aggregation.NumericHistogramAggregation;
37-
import com.facebook.presto.operator.aggregation.VarBinaryMaxAggregation;
38-
import com.facebook.presto.operator.aggregation.VarBinaryMinAggregation;
3933
import com.facebook.presto.operator.aggregation.VarianceAggregation;
4034
import com.facebook.presto.operator.scalar.ArrayFunctions;
4135
import com.facebook.presto.operator.scalar.ColorFunctions;
@@ -131,7 +125,9 @@
131125
import static com.facebook.presto.operator.aggregation.ArbitraryAggregation.ARBITRARY_AGGREGATION;
132126
import static com.facebook.presto.operator.aggregation.CountColumn.COUNT_COLUMN;
133127
import static com.facebook.presto.operator.aggregation.MapAggregation.MAP_AGG;
128+
import static com.facebook.presto.operator.aggregation.MaxAggregation.MAX_AGGREGATION;
134129
import static com.facebook.presto.operator.aggregation.MaxBy.MAX_BY;
130+
import static com.facebook.presto.operator.aggregation.MinAggregation.MIN_AGGREGATION;
135131
import static com.facebook.presto.operator.aggregation.MinBy.MIN_BY;
136132
import static com.facebook.presto.operator.scalar.ArrayCardinalityFunction.ARRAY_CARDINALITY;
137133
import static com.facebook.presto.operator.scalar.ArrayConcatFunction.ARRAY_CONCAT_FUNCTION;
@@ -271,12 +267,6 @@ public FunctionInfo load(SpecializedFunctionKey key)
271267
.aggregate(CountIfAggregation.class)
272268
.aggregate(BooleanAndAggregation.class)
273269
.aggregate(BooleanOrAggregation.class)
274-
.aggregate(DoubleMinAggregation.class)
275-
.aggregate(DoubleMaxAggregation.class)
276-
.aggregate(LongMinAggregation.class)
277-
.aggregate(LongMaxAggregation.class)
278-
.aggregate(VarBinaryMinAggregation.class)
279-
.aggregate(VarBinaryMaxAggregation.class)
280270
.aggregate(DoubleSumAggregation.class)
281271
.aggregate(LongSumAggregation.class)
282272
.aggregate(AverageAggregations.class)
@@ -322,6 +312,7 @@ public FunctionInfo load(SpecializedFunctionKey key)
322312
.function(GREATEST)
323313
.function(MAX_BY)
324314
.function(MIN_BY)
315+
.functions(MAX_AGGREGATION, MIN_AGGREGATION)
325316
.function(COUNT_COLUMN)
326317
.functions(ROW_HASH_CODE, ROW_TO_JSON, ROW_EQUAL, ROW_NOT_EQUAL)
327318
.function(TRY_CAST);
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package com.facebook.presto.operator.aggregation;
15+
16+
import com.facebook.presto.byteCode.DynamicClassLoader;
17+
import com.facebook.presto.metadata.FunctionInfo;
18+
import com.facebook.presto.metadata.FunctionRegistry;
19+
import com.facebook.presto.metadata.OperatorType;
20+
import com.facebook.presto.metadata.ParametricAggregation;
21+
import com.facebook.presto.metadata.Signature;
22+
import com.facebook.presto.operator.aggregation.state.AccumulatorState;
23+
import com.facebook.presto.operator.aggregation.state.AccumulatorStateFactory;
24+
import com.facebook.presto.operator.aggregation.state.AccumulatorStateSerializer;
25+
import com.facebook.presto.operator.aggregation.state.NullableBooleanState;
26+
import com.facebook.presto.operator.aggregation.state.NullableBooleanStateSerializer;
27+
import com.facebook.presto.operator.aggregation.state.NullableDoubleState;
28+
import com.facebook.presto.operator.aggregation.state.NullableDoubleStateSerializer;
29+
import com.facebook.presto.operator.aggregation.state.NullableLongState;
30+
import com.facebook.presto.operator.aggregation.state.NullableLongStateSerializer;
31+
import com.facebook.presto.operator.aggregation.state.SliceState;
32+
import com.facebook.presto.operator.aggregation.state.SliceStateSerializer;
33+
import com.facebook.presto.operator.aggregation.state.StateCompiler;
34+
import com.facebook.presto.spi.PrestoException;
35+
import com.facebook.presto.spi.StandardErrorCode;
36+
import com.facebook.presto.spi.type.Type;
37+
import com.facebook.presto.spi.type.TypeManager;
38+
import com.google.common.base.Throwables;
39+
import com.google.common.collect.ImmutableList;
40+
import io.airlift.slice.Slice;
41+
42+
import java.lang.invoke.MethodHandle;
43+
import java.util.List;
44+
import java.util.Map;
45+
46+
import static com.facebook.presto.metadata.Signature.orderableTypeParameter;
47+
import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata;
48+
import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.INPUT_CHANNEL;
49+
import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE;
50+
import static com.facebook.presto.operator.aggregation.AggregationUtils.generateAggregationName;
51+
import static com.facebook.presto.spi.StandardErrorCode.INTERNAL_ERROR;
52+
import static com.facebook.presto.util.Reflection.methodHandle;
53+
import static com.google.common.base.Preconditions.checkNotNull;
54+
55+
public abstract class AbstractMinMaxAggregation
56+
extends ParametricAggregation
57+
{
58+
private static final MethodHandle LONG_INPUT_FUNCTION = methodHandle(AbstractMinMaxAggregation.class, "input", MethodHandle.class, NullableLongState.class, long.class);
59+
private static final MethodHandle DOUBLE_INPUT_FUNCTION = methodHandle(AbstractMinMaxAggregation.class, "input", MethodHandle.class, NullableDoubleState.class, double.class);
60+
private static final MethodHandle SLICE_INPUT_FUNCTION = methodHandle(AbstractMinMaxAggregation.class, "input", MethodHandle.class, SliceState.class, Slice.class);
61+
private static final MethodHandle BOOLEAN_INPUT_FUNCTION = methodHandle(AbstractMinMaxAggregation.class, "input", MethodHandle.class, NullableBooleanState.class, boolean.class);
62+
63+
private final String name;
64+
private final OperatorType operatorType;
65+
private final Signature signature;
66+
67+
private final StateCompiler compiler = new StateCompiler();
68+
69+
protected AbstractMinMaxAggregation(String name, OperatorType operatorType)
70+
{
71+
checkNotNull(name);
72+
checkNotNull(operatorType);
73+
this.name = name;
74+
this.operatorType = operatorType;
75+
this.signature = new Signature(name, ImmutableList.of(orderableTypeParameter("E")), "E", ImmutableList.of("E"), false, false);
76+
}
77+
78+
@Override
79+
public Signature getSignature()
80+
{
81+
return signature;
82+
}
83+
84+
@Override
85+
public FunctionInfo specialize(Map<String, Type> types, int arity, TypeManager typeManager, FunctionRegistry functionRegistry)
86+
{
87+
Type type = types.get("E");
88+
MethodHandle compareMethodHandle = functionRegistry.resolveOperator(operatorType, ImmutableList.of(type, type)).getMethodHandle();
89+
Signature signature = new Signature(name, type.getTypeSignature(), type.getTypeSignature());
90+
InternalAggregationFunction aggregation = generateAggregation(type, compareMethodHandle);
91+
return new FunctionInfo(signature, getDescription(), aggregation.getIntermediateType().getTypeSignature(), aggregation, false);
92+
}
93+
94+
protected InternalAggregationFunction generateAggregation(Type type, MethodHandle compareMethodHandle)
95+
{
96+
DynamicClassLoader classLoader = new DynamicClassLoader(AbstractMinMaxAggregation.class.getClassLoader());
97+
98+
List<Type> inputTypes = ImmutableList.of(type);
99+
100+
AccumulatorStateSerializer stateSerializer;
101+
AccumulatorStateFactory stateFactory;
102+
MethodHandle inputFunction;
103+
Class<? extends AccumulatorState> stateInterface;
104+
105+
if (type.getJavaType() == long.class) {
106+
stateFactory = compiler.generateStateFactory(NullableLongState.class, classLoader);
107+
stateSerializer = new NullableLongStateSerializer(type);
108+
stateInterface = NullableLongState.class;
109+
inputFunction = LONG_INPUT_FUNCTION;
110+
}
111+
else if (type.getJavaType() == double.class) {
112+
stateFactory = compiler.generateStateFactory(NullableDoubleState.class, classLoader);
113+
stateSerializer = new NullableDoubleStateSerializer(type);
114+
stateInterface = NullableDoubleState.class;
115+
inputFunction = DOUBLE_INPUT_FUNCTION;
116+
}
117+
else if (type.getJavaType() == Slice.class) {
118+
stateFactory = compiler.generateStateFactory(SliceState.class, classLoader);
119+
stateSerializer = new SliceStateSerializer(type);
120+
stateInterface = SliceState.class;
121+
inputFunction = SLICE_INPUT_FUNCTION;
122+
}
123+
else if (type.getJavaType() == boolean.class) {
124+
stateFactory = compiler.generateStateFactory(NullableBooleanState.class, classLoader);
125+
stateSerializer = new NullableBooleanStateSerializer(type);
126+
stateInterface = NullableBooleanState.class;
127+
inputFunction = BOOLEAN_INPUT_FUNCTION;
128+
}
129+
else {
130+
throw new PrestoException(StandardErrorCode.INVALID_FUNCTION_ARGUMENT, "Argument type to max/min unsupported");
131+
}
132+
133+
inputFunction = inputFunction.bindTo(compareMethodHandle);
134+
135+
Type intermediateType = stateSerializer.getSerializedType();
136+
List<ParameterMetadata> inputParameterMetadata = createInputParameterMetadata(type);
137+
AggregationMetadata metadata = new AggregationMetadata(
138+
generateAggregationName(name, type, inputTypes),
139+
inputParameterMetadata,
140+
inputFunction,
141+
inputParameterMetadata,
142+
inputFunction,
143+
null,
144+
null,
145+
stateInterface,
146+
stateSerializer,
147+
stateFactory,
148+
type,
149+
false);
150+
151+
GenericAccumulatorFactoryBinder factory = new AccumulatorCompiler().generateAccumulatorFactoryBinder(metadata, classLoader);
152+
return new InternalAggregationFunction(name, inputTypes, intermediateType, type, true, false, factory);
153+
}
154+
155+
private static List<ParameterMetadata> createInputParameterMetadata(Type type)
156+
{
157+
return ImmutableList.of(
158+
new ParameterMetadata(STATE),
159+
new ParameterMetadata(INPUT_CHANNEL, type));
160+
}
161+
162+
public static void input(MethodHandle methodHandle, NullableDoubleState state, double value)
163+
{
164+
if (state.isNull()) {
165+
state.setNull(false);
166+
state.setDouble(value);
167+
return;
168+
}
169+
try {
170+
if ((boolean) methodHandle.invokeExact(value, state.getDouble())) {
171+
state.setDouble(value);
172+
}
173+
}
174+
catch (Throwable t) {
175+
Throwables.propagateIfInstanceOf(t, Error.class);
176+
Throwables.propagateIfInstanceOf(t, PrestoException.class);
177+
throw new PrestoException(INTERNAL_ERROR, t);
178+
}
179+
}
180+
181+
public static void input(MethodHandle methodHandle, NullableLongState state, long value)
182+
{
183+
if (state.isNull()) {
184+
state.setNull(false);
185+
state.setLong(value);
186+
return;
187+
}
188+
try {
189+
if ((boolean) methodHandle.invokeExact(value, state.getLong())) {
190+
state.setLong(value);
191+
}
192+
}
193+
catch (Throwable t) {
194+
Throwables.propagateIfInstanceOf(t, Error.class);
195+
Throwables.propagateIfInstanceOf(t, PrestoException.class);
196+
throw new PrestoException(INTERNAL_ERROR, t);
197+
}
198+
}
199+
200+
public static void input(MethodHandle methodHandle, SliceState state, Slice value)
201+
{
202+
if (state.getSlice() == null) {
203+
state.setSlice(value);
204+
return;
205+
}
206+
try {
207+
if ((boolean) methodHandle.invokeExact(value, state.getSlice())) {
208+
state.setSlice(value);
209+
}
210+
}
211+
catch (Throwable t) {
212+
Throwables.propagateIfInstanceOf(t, Error.class);
213+
Throwables.propagateIfInstanceOf(t, PrestoException.class);
214+
throw new PrestoException(INTERNAL_ERROR, t);
215+
}
216+
}
217+
218+
public static void input(MethodHandle methodHandle, NullableBooleanState state, boolean value)
219+
{
220+
if (state.isNull()) {
221+
state.setNull(false);
222+
state.setBoolean(value);
223+
return;
224+
}
225+
try {
226+
if ((boolean) methodHandle.invokeExact(value, state.getBoolean())) {
227+
state.setBoolean(value);
228+
}
229+
}
230+
catch (Throwable t) {
231+
Throwables.propagateIfInstanceOf(t, Error.class);
232+
Throwables.propagateIfInstanceOf(t, PrestoException.class);
233+
throw new PrestoException(INTERNAL_ERROR, t);
234+
}
235+
}
236+
}

presto-main/src/main/java/com/facebook/presto/operator/aggregation/BooleanAndAggregation.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import static com.facebook.presto.operator.aggregation.state.TriStateBooleanState.NULL_VALUE;
2222
import static com.facebook.presto.operator.aggregation.state.TriStateBooleanState.TRUE_VALUE;
2323

24-
@AggregationFunction(value = "bool_and", alias = {"every", "min"})
24+
@AggregationFunction(value = "bool_and", alias = "every")
2525
public final class BooleanAndAggregation
2626
{
2727
private BooleanAndAggregation() {}

presto-main/src/main/java/com/facebook/presto/operator/aggregation/BooleanOrAggregation.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import static com.facebook.presto.operator.aggregation.state.TriStateBooleanState.NULL_VALUE;
2222
import static com.facebook.presto.operator.aggregation.state.TriStateBooleanState.TRUE_VALUE;
2323

24-
@AggregationFunction(value = "bool_or", alias = "max")
24+
@AggregationFunction(value = "bool_or")
2525
public final class BooleanOrAggregation
2626
{
2727
private BooleanOrAggregation() {}

presto-main/src/main/java/com/facebook/presto/operator/aggregation/DoubleMaxAggregation.java

Lines changed: 0 additions & 41 deletions
This file was deleted.

0 commit comments

Comments
 (0)