Skip to content

Add lambda function and array related functions #3584

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 25 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.opensearch.sql.ast.expression.HighlightFunction;
import org.opensearch.sql.ast.expression.In;
import org.opensearch.sql.ast.expression.Interval;
import org.opensearch.sql.ast.expression.LambdaFunction;
import org.opensearch.sql.ast.expression.Let;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.Map;
Expand Down Expand Up @@ -232,6 +233,10 @@ public T visitSort(Sort node, C context) {
return visitChildren(node, context);
}

/**
 * Visits a {@link LambdaFunction} node. This implementation simply delegates to {@code
 * visitChildren}.
 *
 * @param node lambda function node being visited
 * @param context visitor context passed through unchanged
 * @return whatever {@code visitChildren} returns for this node
 */
public T visitLambdaFunction(LambdaFunction node, C context) {
return visitChildren(node, context);
}

/**
 * Visits a {@code Dedupe} node. This implementation simply delegates to {@code visitChildren}.
 *
 * @param node dedupe node being visited
 * @param context visitor context passed through unchanged
 * @return whatever {@code visitChildren} returns for this node
 */
public T visitDedupe(Dedupe node, C context) {
return visitChildren(node, context);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.expression;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.ast.AbstractNodeVisitor;

/**
* Expression node of lambda function. Params include function name (@funcName) and function
* arguments (@funcArgs)
*/
@Getter
@EqualsAndHashCode(callSuper = false)
@RequiredArgsConstructor
public class LambdaFunction extends UnresolvedExpression {
private final UnresolvedExpression function;
private final List<QualifiedName> funcArgs;

@Override
public List<UnresolvedExpression> getChild() {
List<UnresolvedExpression> children = new ArrayList<>();
children.add(function);
children.addAll(funcArgs);
return children;
}

@Override
public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
return nodeVisitor.visitLambdaFunction(this, context);
}

@Override
public String toString() {
return String.format(
"(%s) -> %s",
funcArgs.stream().map(Object::toString).collect(Collectors.joining(", ")),
function.toString());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.TYPE_FACTORY;

import java.sql.Connection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Stack;
import java.util.function.BiFunction;
import lombok.Getter;
import lombok.Setter;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexLambdaRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.tools.FrameworkConfig;
import org.apache.calcite.tools.RelBuilder;
Expand Down Expand Up @@ -47,6 +50,8 @@ public class CalcitePlanContext {
private final Stack<RexCorrelVariable> correlVar = new Stack<>();
private final Stack<List<RexNode>> windowPartitions = new Stack<>();

@Getter public Map<String, RexLambdaRef> rexLambdaRefMap;

private CalcitePlanContext(FrameworkConfig config, Integer querySizeLimit, QueryType queryType) {
this.config = config;
this.querySizeLimit = querySizeLimit;
Expand All @@ -55,6 +60,7 @@ private CalcitePlanContext(FrameworkConfig config, Integer querySizeLimit, Query
this.relBuilder = CalciteToolsHelper.create(config, TYPE_FACTORY, connection);
this.rexBuilder = new ExtendedRexBuilder(relBuilder.getRexBuilder());
this.functionProperties = new FunctionProperties(QueryType.PPL);
this.rexLambdaRefMap = new HashMap<>();
}

public RexNode resolveJoinCondition(
Expand Down Expand Up @@ -86,8 +92,16 @@ public Optional<RexCorrelVariable> peekCorrelVar() {
}
}

/**
 * Creates a new context from this context's {@code config}, {@code querySizeLimit} and {@code
 * queryType} by invoking the private constructor.
 *
 * <p>NOTE(review): no other state is passed to the constructor, and the constructor assigns a
 * new empty {@code HashMap} to {@code rexLambdaRefMap}, so lambda references registered on this
 * context are not present in the returned one. Confirm this is intended for nested lambdas.
 *
 * @return a freshly constructed context sharing only the three constructor arguments
 */
public CalcitePlanContext clone() {
return new CalcitePlanContext(config, querySizeLimit, queryType);
}

/**
 * Static factory that wraps the private constructor.
 *
 * @param config framework configuration handed to the constructor
 * @param querySizeLimit query size limit handed to the constructor
 * @param queryType query type handed to the constructor
 * @return a new {@code CalcitePlanContext}
 */
public static CalcitePlanContext create(
FrameworkConfig config, Integer querySizeLimit, QueryType queryType) {
return new CalcitePlanContext(config, querySizeLimit, queryType);
}

/**
 * Copies every entry of {@code candidateMap} into this context's {@code rexLambdaRefMap} via
 * {@code putAll}; entries already present under the same key are replaced.
 *
 * @param candidateMap lambda references keyed by name to merge into this context
 */
public void putRexLambdaRefMap(Map<String, RexLambdaRef> candidateMap) {
this.rexLambdaRefMap.putAll(candidateMap);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,29 @@
import static org.apache.calcite.sql.SqlKind.AS;
import static org.opensearch.sql.ast.expression.SpanUnit.NONE;
import static org.opensearch.sql.ast.expression.SpanUnit.UNKNOWN;
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.TYPE_FACTORY;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.annotation.Nullable;
import lombok.RequiredArgsConstructor;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLambdaRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlIntervalQualifier;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.ArraySqlType;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.DateString;
import org.apache.calcite.util.TimeString;
Expand All @@ -39,6 +48,7 @@
import org.opensearch.sql.ast.expression.Function;
import org.opensearch.sql.ast.expression.In;
import org.opensearch.sql.ast.expression.Interval;
import org.opensearch.sql.ast.expression.LambdaFunction;
import org.opensearch.sql.ast.expression.Let;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.Not;
Expand Down Expand Up @@ -278,6 +288,9 @@ public RexNode visitQualifiedName(QualifiedName node, CalcitePlanContext context
// TODO: Need to support nested fields https://github.com/opensearch-project/sql/issues/3459
// 2. resolve QualifiedName in non-join condition
String qualifiedName = node.toString();
if (context.getRexLambdaRefMap().containsKey(qualifiedName)) {
return context.getRexLambdaRefMap().get(qualifiedName);
}
List<String> currentFields = context.relBuilder.peek().getRowType().getFieldNames();
if (currentFields.contains(qualifiedName)) {
// 2.1 resolve QualifiedName from stack top
Expand Down Expand Up @@ -331,16 +344,132 @@ private boolean isTimeBased(SpanUnit unit) {
return !(unit == NONE || unit == UNKNOWN);
}

/**
 * Translates a lambda AST node into a lambda call built by the context's rex builder.
 *
 * <p>Every declared parameter is looked up by name in the context's lambda-reference map. A
 * parameter without an entry falls back to a new {@link RexLambdaRef} at its positional index
 * with SQL type {@code ANY}. The body is visited with the same context and then combined with
 * the parameter references through {@code makeLambdaCall}.
 *
 * @param node lambda AST node
 * @param context planning context holding the known lambda references
 * @return the node returned by {@code makeLambdaCall}
 * @throws RuntimeException wrapping any exception raised while the lambda is built
 */
@Override
public RexNode visitLambdaFunction(LambdaFunction node, CalcitePlanContext context) {
  try {
    List<QualifiedName> params = node.getFuncArgs();
    List<RexLambdaRef> refs = new ArrayList<>();
    for (int position = 0; position < params.size(); position++) {
      String paramName = params.get(position).toString();
      // Used only when the context has no reference registered under this name.
      RexLambdaRef fallback =
          new RexLambdaRef(
              position, paramName, TYPE_FACTORY.createSqlType(SqlTypeName.ANY));
      refs.add(context.rexLambdaRefMap.getOrDefault(paramName, fallback));
    }
    RexNode body = node.getFunction().accept(this, context);
    return context.rexBuilder.makeLambdaCall(body, refs);
  } catch (Exception e) {
    throw new RuntimeException("Cannot create lambda function", e);
  }
}

/**
 * Analyzes the expression of a {@code Let} node and aliases the result with the string form of
 * the node's target field.
 *
 * @param node let node holding the expression and the target variable
 * @param context planning context whose {@code relBuilder} creates the alias
 * @return the aliased expression
 */
@Override
public RexNode visitLet(Let node, CalcitePlanContext context) {
  RexNode analyzed = analyze(node.getExpression(), context);
  String aliasName = node.getVar().getField().toString();
  return context.relBuilder.alias(analyzed, aliasName);
}

/**
 * Builds a separate planning context in which the parameters of a lambda are registered as
 * {@link RexLambdaRef}s, so that they can be resolved by name while the lambda body is visited.
 *
 * <p>Parameter types are assigned positionally. The candidate types are the component type of
 * the first preceding argument (which must be an array type) followed by the types of the
 * remaining preceding arguments, after adjustment by {@code modifyLambdaTypeByFunction}. A
 * parameter with no candidate type left is typed {@code INTEGER} (per the original author's
 * note, this covers the index parameter of {@code transform}, which has no matching input).
 *
 * @param context context to clone
 * @param node lambda whose parameters are registered
 * @param previousArgument already analyzed arguments that precede the lambda in the call
 * @param functionName name of the enclosing function, used to adjust the candidate types
 * @param defaultTypeForReduceAcc forwarded to {@code modifyLambdaTypeByFunction}; may be null
 * @return the cloned context with the lambda references added
 * @throws RuntimeException wrapping any exception raised during preparation, e.g. when {@code
 *     previousArgument} is empty or its first element is not of an array type
 */
public CalcitePlanContext prepareLambdaContext(
    CalcitePlanContext context,
    LambdaFunction node,
    List<RexNode> previousArgument,
    String functionName,
    @Nullable RelDataType defaultTypeForReduceAcc) {
  try {
    CalcitePlanContext lambdaContext = context.clone();

    // The first preceding argument has to be an array; the lambda sees its element type.
    ArraySqlType arrayType = (ArraySqlType) previousArgument.get(0).getType();
    List<RelDataType> positionalTypes = new ArrayList<>();
    positionalTypes.add(arrayType.getComponentType());
    for (RexNode extra : previousArgument.subList(1, previousArgument.size())) {
      positionalTypes.add(extra.getType());
    }

    List<RelDataType> resolvedTypes =
        modifyLambdaTypeByFunction(functionName, positionalTypes, defaultTypeForReduceAcc);

    List<QualifiedName> argNames = node.getFuncArgs();
    Map<String, RexLambdaRef> refsByName = new HashMap<>();
    for (int i = 0; i < argNames.size(); i++) {
      String argName = argNames.get(i).toString();
      RelDataType argType;
      if (i < resolvedTypes.size()) {
        argType = resolvedTypes.get(i);
      } else {
        // No candidate type left for this position; fall back to INTEGER.
        argType = TYPE_FACTORY.createSqlType(SqlTypeName.INTEGER);
      }
      refsByName.put(argName, new RexLambdaRef(i, argName, argType));
    }

    lambdaContext.putRexLambdaRefMap(refsByName);
    return lambdaContext;
  } catch (Exception e) {
    throw new RuntimeException("Fail to prepare lambda context", e);
  }
}

/**
 * Reorders or replaces the positional candidate types of a lambda according to the enclosing
 * function. Only {@code reduce} (matched case-insensitively) is treated specially; every other
 * function gets {@code originalType} back unchanged.
 *
 * <p>For {@code reduce} with exactly two candidate types, the result has two entries. The first
 * is the second candidate when {@code defaultTypeForReduceAcc} is null or equal to it, otherwise
 * a nullable {@code ANY}. The second is the first candidate. For any other candidate count the
 * result is a single-entry list holding the third candidate.
 *
 * @param functionName function name
 * @param originalType the candidate types by position
 * @param defaultTypeForReduceAcc type compared against the second candidate for {@code reduce};
 *     may be null
 * @return the candidate types to assign to the lambda parameters, in order
 */
private List<RelDataType> modifyLambdaTypeByFunction(
    String functionName,
    List<RelDataType> originalType,
    @Nullable RelDataType defaultTypeForReduceAcc) {
  boolean isReduce = "REDUCE".equals(functionName.toUpperCase(Locale.ROOT));
  if (!isReduce) {
    return originalType;
  }

  if (originalType.size() != 2) {
    return List.of(originalType.get(2));
  }

  RelDataType firstCandidate = originalType.get(0);
  RelDataType secondCandidate = originalType.get(1);
  boolean keepSecondCandidate =
      defaultTypeForReduceAcc == null || defaultTypeForReduceAcc.equals(secondCandidate);
  if (keepSecondCandidate) {
    return List.of(secondCandidate, firstCandidate);
  }
  // The accumulator type differs from the second candidate, so it is widened to nullable ANY.
  return List.of(TYPE_FACTORY.createSqlType(SqlTypeName.ANY, true), firstCandidate);
}

@Override
public RexNode visitFunction(Function node, CalcitePlanContext context) {
List<RexNode> arguments =
node.getFuncArgs().stream().map(arg -> analyze(arg, context)).toList();
List<UnresolvedExpression> args = node.getFuncArgs();
List<RexNode> arguments = new ArrayList<>();
for (UnresolvedExpression arg : args) {
if (arg instanceof LambdaFunction) {
CalcitePlanContext lambdaContext =
prepareLambdaContext(
context, (LambdaFunction) arg, arguments, node.getFuncName(), null);
RexNode lambdaNode = analyze(arg, lambdaContext);
if (node.getFuncName().equalsIgnoreCase("reduce")) { // analyze again with calculate type
lambdaContext =
prepareLambdaContext(
context,
(LambdaFunction) arg,
arguments,
node.getFuncName(),
lambdaNode.getType());
lambdaNode = analyze(arg, lambdaContext);
}
arguments.add(lambdaNode);
} else {
arguments.add(analyze(arg, context));
}
}
RexNode resolvedNode =
PPLFuncImpTable.INSTANCE.resolve(
context.rexBuilder, node.getFuncName(), arguments.toArray(new RexNode[0]));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
import org.opensearch.sql.expression.function.ImplementorUDF;

public class UserDefinedFunctionUtils {

public static final RelDataType NULLABLE_DATE_UDT = TYPE_FACTORY.createUDT(EXPR_DATE, true);
public static final RelDataType NULLABLE_TIME_UDT = TYPE_FACTORY.createUDT(EXPR_TIME, true);
public static final RelDataType NULLABLE_TIMESTAMP_UDT =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,15 @@ public enum BuiltinFunctionName {
TAN(FunctionName.of("tan")),
SPAN(FunctionName.of("span")),

/** Collection functions */
ARRAY(FunctionName.of("array")),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Already added.

ARRAY_LENGTH(FunctionName.of("array_length")),
FORALL(FunctionName.of("forall")),
EXISTS(FunctionName.of("exists")),
FILTER(FunctionName.of("filter")),
TRANSFORM(FunctionName.of("transform")),
REDUCE(FunctionName.of("reduce")),

/** Date and Time Functions. */
ADDDATE(FunctionName.of("adddate")),
ADDTIME(FunctionName.of("addtime")),
Expand Down
Loading
Loading