diff --git a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java
index 28cbd17b903..f276ee774e3 100644
--- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java
+++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java
@@ -292,15 +292,38 @@ public Expression visitIn(In node, AnalysisContext context) {
 
   private Expression visitIn(
       UnresolvedExpression field, List valueList, AnalysisContext context) {
-    if (valueList.size() == 1) {
-      return visitCompare(new Compare("=", field, valueList.get(0)), context);
-    } else if (valueList.size() > 1) {
-      return DSL.or(
-          visitCompare(new Compare("=", field, valueList.get(0)), context),
-          visitIn(field, valueList.subList(1, valueList.size()), context));
-    } else {
+    if (valueList.isEmpty()) {
       throw new SemanticCheckException("Values in In clause should not be empty");
     }
+
+    Expression[] expressions = new Expression[valueList.size()];
+
+    for (int i = 0; i < expressions.length; i++) {
+      expressions[i] = visitCompare(new Compare("=", field, valueList.get(i)), context);
+    }
+
+    return buildOrTree(expressions, 0, expressions.length);
+  }
+
+  /**
+   * Merges the given expressions into a single balanced tree of two-argument {@code DSL.or} calls,
+   * so large IN lists yield a shallow tree instead of a chain nested once per value.
+   *
+   * @param children The expressions to merge; must contain an element at {@code start}.
+   * @param start The starting position (inclusive) for the current combination step.
+   * @param end The ending position (exclusive) for the current combination step. When the range
+   *     holds at most one element, {@code children[start]} is returned unchanged.
+   * @return The final {@code DSL.or} expression.
+   */
+  private Expression buildOrTree(Expression[] children, int start, int end) {
+    if (end - start <= 1) {
+      return children[start];
+    }
+    if (end - start == 2) {
+      return DSL.or(children[start], children[end - 1]);
+    }
+    int split = start + (end - start) / 2;
+    return DSL.or(buildOrTree(children, start, split), buildOrTree(children, split, end));
   }
 
   @Override
diff --git a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java
index b27b8348e2f..4b096a6ed9a 100644
--- a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java
+++ b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java
@@ -27,6 +27,7 @@
 import static org.opensearch.sql.expression.DSL.ref;
 
 import com.google.common.collect.ImmutableMap;
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.LinkedHashMap;
 import java.util.List;
@@ -401,6 +402,17 @@ void visit_in() {
         () -> analyze(AstDSL.in(field("integer_value"), Collections.emptyList())));
   }
 
+  @Test
+  void visit_in_large_list() {
+    List ints = new ArrayList<>();
+    for (int i = 0; i < 10000; i++) {
+      ints.add(intLiteral(i));
+    }
+
+    // Regression check: analyzing a 10,000-value IN list must complete without throwing.
+    analyze(AstDSL.in(field("integer_value"), ints));
+  }
+
   @Test
   void multi_match_expression() {
     assertAnalyzeEqual(