diff --git a/.github/actions/compile-commit/action.yml b/.github/actions/compile-commit/action.yml index ddda80c7b1c670..6eab2ec5a4f440 100644 --- a/.github/actions/compile-commit/action.yml +++ b/.github/actions/compile-commit/action.yml @@ -41,7 +41,7 @@ runs: # For building with Maven we need MAVEN_OPTS to equal MAVEN_INSTALL_OPTS export MAVEN_OPTS="${MAVEN_INSTALL_OPTS}" - $MAVEN package \ + $MAVEN install \ ${MAVEN_COMPILE_COMMITS} `# defaults, kept in sync with ci.yml` \ -Dair.check.skip-all=false -Dair.check.skip-basic=true -Dair.check.skip-extended=true -Dair.check.skip-checkstyle=false \ ${MAVEN_GIB} diff --git a/.github/config/labeler-config.yml b/.github/config/labeler-config.yml index e7ff2840486e01..2dc9f51511d89e 100644 --- a/.github/config/labeler-config.yml +++ b/.github/config/labeler-config.yml @@ -1,40 +1,44 @@ # Pull Request Labeler Github Action Configuration: https://github.com/marketplace/actions/labeler "tests:hive": - - lib/trino-orc/** - - lib/trino-parquet/** - - lib/trino-hive-formats/** - - plugin/trino-hive/** - - testing/trino-product-tests/** - - lib/trino-filesystem/** - - lib/trino-filesystem-*/** + - changed-files: + - any-glob-to-any-file: ['lib/trino-orc/**', 'lib/trino-parquet/**', 'lib/trino-hive-formats/**', 'plugin/trino-hive/**', 'testing/trino-product-tests/**', 'lib/trino-filesystem/**', 'lib/trino-filesystem-*/**'] jdbc: - - client/trino-jdbc/** + - changed-files: + - any-glob-to-any-file: 'client/trino-jdbc/**' bigquery: - - plugin/trino-bigquery/** + - changed-files: + - any-glob-to-any-file: 'plugin/trino-bigquery/**' delta-lake: - - plugin/trino-delta-lake/** + - changed-files: + - any-glob-to-any-file: 'plugin/trino-delta-lake/**' hive: - - plugin/trino-hive/** + - changed-files: + - any-glob-to-any-file: 'plugin/trino-hive/**' hudi: - - plugin/trino-hudi/** + - changed-files: + - any-glob-to-any-file: 'plugin/trino-hudi/**' iceberg: - - plugin/trino-iceberg/** + - changed-files: + - any-glob-to-any-file: 'plugin/trino-iceberg/**' mongodb: - - plugin/trino-mongodb/** + - changed-files: + - any-glob-to-any-file: 'plugin/trino-mongodb/**' docs: - - docs/** + - changed-files: + - any-glob-to-any-file: 'docs/**' release-notes: - - docs/src/main/sphinx/release/** - - docs/src/main/sphinx/release.rst + - changed-files: + - any-glob-to-any-file: ['docs/src/main/sphinx/release/**', 'docs/src/main/sphinx/release.rst'] ui: - - core/trino-main/src/main/resources/webapp/** + - changed-files: + - any-glob-to-any-file: 'core/trino-main/src/main/resources/webapp/**' diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index 7b15a2a5c3509b..1e954cee1b712f 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -12,11 +12,10 @@ jobs: triage: runs-on: ubuntu-latest steps: - - uses: actions/labeler@v4 + - uses: actions/labeler@v5 with: repo-token: "${{ secrets.GITHUB_TOKEN }}" # Do not sync labels as this reverts manually added "tests:hive" label. # Syncing labels requires that we define "components" labels. - # See https://github.com/actions/labeler/issues/112#issuecomment-999953377 for why an empty string instead of false. - sync-labels: '' + sync-labels: false configuration-path: .github/config/labeler-config.yml diff --git a/core/docker/Dockerfile b/core/docker/Dockerfile index f3c79a07f95662..d983034cad8acb 100644 --- a/core/docker/Dockerfile +++ b/core/docker/Dockerfile @@ -31,6 +31,7 @@ FROM registry.access.redhat.com/ubi9/ubi-minimal:latest ARG JDK_VERSION ENV JAVA_HOME="/usr/lib/jvm/jdk-${JDK_VERSION}" ENV PATH=$PATH:$JAVA_HOME/bin +ENV CATALOG_MANAGEMENT=static COPY --from=jdk-download $JAVA_HOME $JAVA_HOME RUN \ diff --git a/core/docker/default/etc/config.properties b/core/docker/default/etc/config.properties index a11cba39db4662..559b9a37ea3eeb 100644 --- a/core/docker/default/etc/config.properties +++ b/core/docker/default/etc/config.properties @@ -3,3 +3,4 @@ coordinator=true node-scheduler.include-coordinator=true http-server.http.port=8080 discovery.uri=http://localhost:8080 +catalog.management=${ENV:CATALOG_MANAGEMENT} diff --git a/core/docker/default/etc/jvm.config b/core/docker/default/etc/jvm.config index a5d08bcf8ce825..0a6bca1bc2fa92 100644 --- a/core/docker/default/etc/jvm.config +++ b/core/docker/default/etc/jvm.config @@ -15,3 +15,5 @@ # Reduce starvation of threads by GClocker, recommend to set about the number of cpu cores (JDK-8192647) -XX:+UnlockDiagnosticVMOptions -XX:GCLockerRetryAllocationCount=32 +# Allow loading dynamic agent used by JOL +-XX:+EnableDynamicAgentLoading diff --git a/core/trino-main/src/main/java/io/trino/execution/CallTask.java b/core/trino-main/src/main/java/io/trino/execution/CallTask.java index a1909df8613e31..6017e6ce4c4f45 100644 --- a/core/trino-main/src/main/java/io/trino/execution/CallTask.java +++ b/core/trino-main/src/main/java/io/trino/execution/CallTask.java @@ -62,7 +62,7 @@ import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.StandardErrorCode.PROCEDURE_CALL_FAILED; import static io.trino.spi.type.TypeUtils.writeNativeValue; -import static io.trino.sql.analyzer.ExpressionInterpreter.evaluateConstantExpression; +import static io.trino.sql.analyzer.ConstantEvaluator.evaluateConstant; import static io.trino.sql.analyzer.SemanticExceptions.semanticException; import static java.util.Arrays.asList; import static java.util.Objects.requireNonNull; @@ -164,7 +164,7 @@ else if (i < procedure.getArguments().size()) { Expression expression = ExpressionTreeRewriter.rewriteWith(new ParameterRewriter(parameterLookup), callArgument.getValue()); Type type = argument.getType(); - Object value = evaluateConstantExpression(expression, type, plannerContext, session, accessControl, parameterLookup); + Object value = evaluateConstant(expression, type, plannerContext, session, accessControl); values[index] = toTypeObjectValue(session, type, value); } diff --git a/core/trino-main/src/main/java/io/trino/execution/CreateMaterializedViewTask.java b/core/trino-main/src/main/java/io/trino/execution/CreateMaterializedViewTask.java index 331fdb655f729a..9bf3bfb5c747ec 100644 --- a/core/trino-main/src/main/java/io/trino/execution/CreateMaterializedViewTask.java +++ b/core/trino-main/src/main/java/io/trino/execution/CreateMaterializedViewTask.java @@ -56,7 +56,7 @@ import static io.trino.spi.StandardErrorCode.TYPE_MISMATCH; import static io.trino.spi.connector.ConnectorCapabilities.MATERIALIZED_VIEW_GRACE_PERIOD; import static io.trino.sql.SqlFormatterUtil.getFormattedSql; -import static io.trino.sql.analyzer.ExpressionInterpreter.evaluateConstantExpression; +import static io.trino.sql.analyzer.ConstantEvaluator.evaluateConstant; import static io.trino.sql.analyzer.SemanticExceptions.semanticException; import static io.trino.type.IntervalDayTimeType.INTERVAL_DAY_TIME; import static java.util.Locale.ENGLISH; @@ -145,13 +145,12 @@ Analysis executeInternal( if (type != INTERVAL_DAY_TIME) { throw new TrinoException(TYPE_MISMATCH, "Unsupported grace period type %s, expected %s".formatted(type.getDisplayName(), INTERVAL_DAY_TIME.getDisplayName())); } - Long milliseconds = (Long) evaluateConstantExpression( + Long milliseconds = (Long) evaluateConstant( expression, type, plannerContext, session, - accessControl, - parameterLookup); + accessControl); // Sanity check. Impossible per grammar. verify(milliseconds != null, "Grace period cannot be null"); return Duration.ofMillis(milliseconds); diff --git a/core/trino-main/src/main/java/io/trino/execution/SetTimeZoneTask.java b/core/trino-main/src/main/java/io/trino/execution/SetTimeZoneTask.java index 73ca08b034591e..e64389f36a16bc 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SetTimeZoneTask.java +++ b/core/trino-main/src/main/java/io/trino/execution/SetTimeZoneTask.java @@ -42,8 +42,8 @@ import static io.trino.spi.StandardErrorCode.TYPE_MISMATCH; import static io.trino.spi.type.TimeZoneKey.getTimeZoneKey; import static io.trino.spi.type.TimeZoneKey.getTimeZoneKeyForOffset; +import static io.trino.sql.analyzer.ConstantEvaluator.evaluateConstant; import static io.trino.sql.analyzer.ExpressionAnalyzer.createConstantAnalyzer; -import static io.trino.sql.analyzer.ExpressionInterpreter.evaluateConstantExpression; import static io.trino.util.Failures.checkCondition; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -105,13 +105,12 @@ private String getTimeZoneId( throw new TrinoException(TYPE_MISMATCH, format("Expected expression of varchar or interval day-time type, but '%s' has %s type", expression, type.getDisplayName())); } - Object timeZoneValue = evaluateConstantExpression( + Object timeZoneValue = evaluateConstant( expression, type, plannerContext, stateMachine.getSession(), - accessControl, - parameterLookup); + accessControl); TimeZoneKey timeZoneKey; if (timeZoneValue instanceof Slice) { diff --git a/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/TestBigQueryArrowConnectorSmokeTest.java b/core/trino-main/src/main/java/io/trino/execution/resourcegroups/ResourceGroupInfoProvider.java similarity index 51% rename from plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/TestBigQueryArrowConnectorSmokeTest.java rename to core/trino-main/src/main/java/io/trino/execution/resourcegroups/ResourceGroupInfoProvider.java index 1c9bd37a9c83b1..43a43b82120b9a 100644 --- a/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/TestBigQueryArrowConnectorSmokeTest.java +++ b/core/trino-main/src/main/java/io/trino/execution/resourcegroups/ResourceGroupInfoProvider.java @@ -11,21 +11,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.trino.plugin.bigquery; +package io.trino.execution.resourcegroups; -import com.google.common.collect.ImmutableMap; -import io.trino.testing.QueryRunner; +import io.trino.server.ResourceGroupInfo; +import io.trino.spi.resourcegroups.ResourceGroupId; -public class TestBigQueryArrowConnectorSmokeTest - extends BaseBigQueryConnectorSmokeTest +import java.util.List; +import java.util.Optional; + +public interface ResourceGroupInfoProvider { - @Override - protected QueryRunner createQueryRunner() - throws Exception - { - return BigQueryQueryRunner.createQueryRunner( - ImmutableMap.of(), - ImmutableMap.of("bigquery.experimental.arrow-serialization.enabled", "true"), - REQUIRED_TPCH_TABLES); - } + Optional tryGetResourceGroupInfo(ResourceGroupId id); + + Optional> tryGetPathToRoot(ResourceGroupId id); } diff --git a/core/trino-main/src/main/java/io/trino/execution/resourcegroups/ResourceGroupManager.java b/core/trino-main/src/main/java/io/trino/execution/resourcegroups/ResourceGroupManager.java index 43d23bae727f25..307855ee24363f 100644 --- a/core/trino-main/src/main/java/io/trino/execution/resourcegroups/ResourceGroupManager.java +++ b/core/trino-main/src/main/java/io/trino/execution/resourcegroups/ResourceGroupManager.java @@ -15,14 +15,10 @@ import com.google.errorprone.annotations.ThreadSafe; import io.trino.execution.ManagedQueryExecution; -import io.trino.server.ResourceGroupInfo; import io.trino.spi.resourcegroups.ResourceGroupConfigurationManagerFactory; -import io.trino.spi.resourcegroups.ResourceGroupId; import io.trino.spi.resourcegroups.SelectionContext; import io.trino.spi.resourcegroups.SelectionCriteria; -import java.util.List; -import java.util.Optional; import java.util.concurrent.Executor; /** @@ -31,15 +27,12 @@ */ @ThreadSafe public interface ResourceGroupManager + extends ResourceGroupInfoProvider { void submit(ManagedQueryExecution queryExecution, SelectionContext selectionContext, Executor executor); SelectionContext selectGroup(SelectionCriteria criteria); - Optional tryGetResourceGroupInfo(ResourceGroupId id); - - Optional> tryGetPathToRoot(ResourceGroupId id); - void addConfigurationManagerFactory(ResourceGroupConfigurationManagerFactory factory); void loadConfigurationManager() diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/ByEagerParentOutputStatsEstimator.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/ByEagerParentOutputStatsEstimator.java index 05b39597ac1785..90a1fbebf2434f 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/ByEagerParentOutputStatsEstimator.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/ByEagerParentOutputStatsEstimator.java @@ -49,6 +49,6 @@ public Optional getEstimatedOutputStats(StageExecutio for (int i = 0; i < outputPartitionsCount; ++i) { estimateBuilder.add(0); } - return Optional.of(new OutputStatsEstimateResult(estimateBuilder.build(), 0, OutputStatsEstimateStatus.ESTIMATED_FOR_EAGER_PARENT)); + return Optional.of(new OutputStatsEstimateResult(estimateBuilder.build(), 0, "FOR_EAGER_PARENT")); } } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/BySmallStageOutputStatsEstimator.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/BySmallStageOutputStatsEstimator.java index aa22dc010b3162..ca9f210c90c339 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/BySmallStageOutputStatsEstimator.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/BySmallStageOutputStatsEstimator.java @@ -143,6 +143,6 @@ public Optional getEstimatedOutputStats(StageExecutio estimateBuilder.add(inputSizeEstimate / outputPartitionsCount); } // TODO: For now we can skip calculating outputRowCountEstimate since we won't run adaptive planner in the case of small inputs - return Optional.of(new OutputStatsEstimateResult(estimateBuilder.build(), 0, OutputStatsEstimateStatus.ESTIMATED_BY_SMALL_INPUT)); + return Optional.of(new OutputStatsEstimateResult(estimateBuilder.build(), 0, "BY_SMALL_INPUT")); } } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/ByTaskProgressOutputStatsEstimator.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/ByTaskProgressOutputStatsEstimator.java index 4c37ee8ae8c070..25ea6f610f8e96 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/ByTaskProgressOutputStatsEstimator.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/ByTaskProgressOutputStatsEstimator.java @@ -72,6 +72,6 @@ public Optional getEstimatedOutputStats(StageExecutio estimateBuilder.add((long) (partitionSize / progress)); } long outputRowCountEstimate = (long) (stageExecution.getOutputRowCount() / progress); - return Optional.of(new OutputStatsEstimateResult(new OutputDataSizeEstimate(estimateBuilder.build()), outputRowCountEstimate, OutputStatsEstimateStatus.ESTIMATED_BY_PROGRESS)); + return Optional.of(new OutputStatsEstimateResult(new OutputDataSizeEstimate(estimateBuilder.build()), outputRowCountEstimate, "BY_PROGRESS")); } } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/EventDrivenFaultTolerantQueryScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/EventDrivenFaultTolerantQueryScheduler.java index 4749eca08b3ef3..221e0ded36dbc3 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/EventDrivenFaultTolerantQueryScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/EventDrivenFaultTolerantQueryScheduler.java @@ -70,7 +70,6 @@ import io.trino.execution.scheduler.TaskExecutionStats; import io.trino.execution.scheduler.faulttolerant.NodeAllocator.NodeLease; import io.trino.execution.scheduler.faulttolerant.OutputStatsEstimator.OutputStatsEstimateResult; -import io.trino.execution.scheduler.faulttolerant.OutputStatsEstimator.OutputStatsEstimateStatus; import io.trino.execution.scheduler.faulttolerant.PartitionMemoryEstimator.MemoryRequirements; import io.trino.execution.scheduler.faulttolerant.SplitAssigner.AssignmentResult; import io.trino.execution.scheduler.faulttolerant.SplitAssigner.Partition; @@ -123,6 +122,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Queue; import java.util.Set; @@ -1169,7 +1169,13 @@ private void updateStageExecutions() if (stageExecution == null) { IsReadyForExecutionResult result = isReadyForExecutionCache.computeIfAbsent(subPlan, ignored -> isReadyForExecution(subPlan)); if (result.isReadyForExecution()) { - createStageExecution(subPlan, fragmentId.equals(rootFragmentId), result.getSourceOutputSizeEstimates(), nextSchedulingPriority++, result.isEager()); + createStageExecution( + subPlan, + fragmentId.equals(rootFragmentId), + result.getSourceOutputSizeEstimates(), + nextSchedulingPriority++, + result.isSpeculative(), + result.isEager()); } } if (stageExecution != null && stageExecution.getState().equals(StageState.FINISHED) && !stageExecution.isExchangeClosed()) { @@ -1189,22 +1195,27 @@ private void updateStageExecutions() private static class IsReadyForExecutionResult { private final boolean readyForExecution; + private final boolean speculative; private final Optional> sourceOutputSizeEstimates; private final boolean eager; @CheckReturnValue - public static IsReadyForExecutionResult ready(Map sourceOutputSizeEstimates, boolean eager) + public static IsReadyForExecutionResult ready(Map sourceOutputSizeEstimates, boolean eager, boolean speculative) { - return new IsReadyForExecutionResult(true, Optional.of(sourceOutputSizeEstimates), eager); + return new IsReadyForExecutionResult(true, Optional.of(sourceOutputSizeEstimates), eager, speculative); } @CheckReturnValue public static IsReadyForExecutionResult notReady() { - return new IsReadyForExecutionResult(false, Optional.empty(), false); + return new IsReadyForExecutionResult(false, Optional.empty(), false, false); } - private IsReadyForExecutionResult(boolean readyForExecution, Optional> sourceOutputSizeEstimates, boolean eager) + private IsReadyForExecutionResult( + boolean readyForExecution, + Optional> sourceOutputSizeEstimates, + boolean eager, + boolean speculative) { requireNonNull(sourceOutputSizeEstimates, "sourceOutputSizeEstimates is null"); if (readyForExecution) { @@ -1214,6 +1225,7 @@ private IsReadyForExecutionResult(boolean readyForExecution, Optional estimateCountByKind = new HashMap<>(); ImmutableMap.Builder sourceOutputStatsEstimates = ImmutableMap.builder(); @@ -1281,7 +1296,7 @@ private IsReadyForExecutionResult isReadyForExecution(SubPlan subPlan) else { // source stage finished; no more checks needed OutputStatsEstimateResult result = sourceStageExecution.getOutputStats(stageExecutions::get, eager).orElseThrow(); - verify(result.status() == OutputStatsEstimateStatus.FINISHED, "expected FINISHED status but got %s", result.status()); + verify(Objects.equals(result.kind(), "FINISHED"), "expected FINISHED status but got %s", result.kind()); finishedSourcesCount++; sourceOutputStatsEstimates.put(sourceStageExecution.getStageId(), result.outputDataSizeEstimate()); someSourcesMadeProgress = true; @@ -1302,12 +1317,7 @@ private IsReadyForExecutionResult isReadyForExecution(SubPlan subPlan) return IsReadyForExecutionResult.notReady(); } - switch (result.orElseThrow().status()) { - case ESTIMATED_BY_PROGRESS -> estimatedByProgressSourcesCount++; - case ESTIMATED_BY_SMALL_INPUT -> estimatedBySmallInputSourcesCount++; - case ESTIMATED_FOR_EAGER_PARENT -> estimatedForEagerParent++; - default -> throw new IllegalStateException(format("unexpected status %s", result.orElseThrow().status())); // FINISHED handled above - } + estimateCountByKind.compute(result.orElseThrow().kind(), (k, v) -> v == null ? 0 : v + 1); sourceOutputStatsEstimates.put(sourceStageExecution.getStageId(), result.orElseThrow().outputDataSizeEstimate()); someSourcesMadeProgress = someSourcesMadeProgress || sourceStageExecution.isSomeProgressMade(); @@ -1318,15 +1328,13 @@ private IsReadyForExecutionResult isReadyForExecution(SubPlan subPlan) } if (speculative) { - log.debug("scheduling speculative %s/%s; sources: finished=%s; estimatedByProgress=%s; estimatedSmall=%s; estimatedForEagerParent=%s", + log.debug("scheduling speculative %s/%s; sources: finished=%s; kinds=%s", queryStateMachine.getQueryId(), subPlan.getFragment().getId(), finishedSourcesCount, - estimatedByProgressSourcesCount, - estimatedBySmallInputSourcesCount, - estimatedForEagerParent); + estimateCountByKind); } - return IsReadyForExecutionResult.ready(sourceOutputStatsEstimates.buildOrThrow(), eager); + return IsReadyForExecutionResult.ready(sourceOutputStatsEstimates.buildOrThrow(), eager, speculative); } private boolean shouldScheduleEagerly(SubPlan subPlan) @@ -1382,7 +1390,13 @@ private void closeSourceExchanges(SubPlan subPlan) } } - private void createStageExecution(SubPlan subPlan, boolean rootFragment, Map sourceOutputSizeEstimates, int schedulingPriority, boolean eager) + private void createStageExecution( + SubPlan subPlan, + boolean rootFragment, + Map sourceOutputSizeEstimates, + int schedulingPriority, + boolean speculative, + boolean eager) { Closer closer = Closer.create(); @@ -1481,6 +1495,7 @@ private void createStageExecution(SubPlan subPlan, boolean rootFragment, Map updatePartition( partition.addSplits(planNodeId, splits, noMoreSplits); if (readyForScheduling && !partition.isTaskScheduled()) { partition.setTaskScheduled(true); - return Optional.of(PrioritizedScheduledTask.createSpeculative(stage.getStageId(), taskPartitionId, schedulingPriority, eager)); + PrioritizedScheduledTask task = speculative ? + PrioritizedScheduledTask.createSpeculative(stage.getStageId(), taskPartitionId, schedulingPriority, eager) : + PrioritizedScheduledTask.create(stage.getStageId(), taskPartitionId, schedulingPriority); + return Optional.of(task); } return Optional.empty(); } @@ -2410,7 +2435,7 @@ public Optional getOutputStats(Function T accept(EventListener listener) { @@ -3047,7 +3073,8 @@ public T accept(EventListener listener) } }; - Event WAKE_UP = new Event() { + Event WAKE_UP = new Event() + { @Override public T accept(EventListener listener) { diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/OutputStatsEstimator.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/OutputStatsEstimator.java index 2002d74733bfd4..15a6b6d34b571d 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/OutputStatsEstimator.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/OutputStatsEstimator.java @@ -29,27 +29,20 @@ Optional getEstimatedOutputStats( Function stageExecutionLookup, boolean parentEager); - enum OutputStatsEstimateStatus { - FINISHED, - ESTIMATED_BY_PROGRESS, - ESTIMATED_BY_SMALL_INPUT, - ESTIMATED_FOR_EAGER_PARENT - } - record OutputStatsEstimateResult( OutputDataSizeEstimate outputDataSizeEstimate, long outputRowCountEstimate, - OutputStatsEstimateStatus status) + String kind) { - OutputStatsEstimateResult(ImmutableLongArray partitionDataSizes, long outputRowCountEstimate, OutputStatsEstimateStatus status) + OutputStatsEstimateResult(ImmutableLongArray partitionDataSizes, long outputRowCountEstimate, String kind) { - this(new OutputDataSizeEstimate(partitionDataSizes), outputRowCountEstimate, status); + this(new OutputDataSizeEstimate(partitionDataSizes), outputRowCountEstimate, kind); } public OutputStatsEstimateResult { requireNonNull(outputDataSizeEstimate, "outputDataSizeEstimate is null"); - requireNonNull(status, "status is null"); + requireNonNull(kind, "kind is null"); } } } diff --git a/core/trino-main/src/main/java/io/trino/metadata/PropertyUtil.java b/core/trino-main/src/main/java/io/trino/metadata/PropertyUtil.java index af206784ea5960..bc9e3ed7a6734b 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/PropertyUtil.java +++ b/core/trino-main/src/main/java/io/trino/metadata/PropertyUtil.java @@ -34,7 +34,7 @@ import java.util.Optional; import static io.trino.spi.type.TypeUtils.writeNativeValue; -import static io.trino.sql.analyzer.ExpressionInterpreter.evaluateConstantExpression; +import static io.trino.sql.analyzer.ConstantEvaluator.evaluateConstant; import static java.lang.String.format; import static java.util.Locale.ENGLISH; @@ -144,7 +144,7 @@ public static Object evaluateProperty( Object sqlObjectValue; try { Expression rewritten = ExpressionTreeRewriter.rewriteWith(new ParameterRewriter(parameters), expression); - Object value = evaluateConstantExpression(rewritten, propertyType, plannerContext, session, accessControl, parameters); + Object value = evaluateConstant(rewritten, propertyType, plannerContext, session, accessControl); // convert to object value type of SQL type Block block = writeNativeValue(propertyType, value); diff --git a/core/trino-main/src/main/java/io/trino/metadata/SessionPropertyManager.java b/core/trino-main/src/main/java/io/trino/metadata/SessionPropertyManager.java index 10634be3ba5ffd..d54e8c9dddf9f5 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/SessionPropertyManager.java +++ b/core/trino-main/src/main/java/io/trino/metadata/SessionPropertyManager.java @@ -56,7 +56,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.spi.StandardErrorCode.INVALID_SESSION_PROPERTY; import static io.trino.spi.type.TypeUtils.writeNativeValue; -import static io.trino.sql.analyzer.ExpressionInterpreter.evaluateConstantExpression; +import static io.trino.sql.analyzer.ConstantEvaluator.evaluateConstant; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -217,7 +217,7 @@ private static T decodePropertyValue(String fullPropertyName, @Nullable Stri public static Object evaluatePropertyValue(Expression expression, Type expectedType, Session session, PlannerContext plannerContext, AccessControl accessControl, Map, Expression> parameters) { Expression rewritten = ExpressionTreeRewriter.rewriteWith(new ParameterRewriter(parameters), expression); - Object value = evaluateConstantExpression(rewritten, expectedType, plannerContext, session, accessControl, parameters); + Object value = evaluateConstant(rewritten, expectedType, plannerContext, session, accessControl); // convert to object value type of SQL type Block block = writeNativeValue(expectedType, value); diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArraysOverlapFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArraysOverlapFunction.java index 0de29e4b1d7901..2ec2f47cd0c8b4 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArraysOverlapFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArraysOverlapFunction.java @@ -21,170 +21,67 @@ import io.trino.spi.function.SqlNullable; import io.trino.spi.function.SqlType; import io.trino.spi.function.TypeParameter; -import io.trino.spi.function.TypeParameterSpecialization; import io.trino.spi.type.StandardTypes; import io.trino.spi.type.Type; -import io.trino.type.BlockTypeOperators.BlockPositionComparison; -import it.unimi.dsi.fastutil.ints.IntArrays; -import it.unimi.dsi.fastutil.ints.IntComparator; -import it.unimi.dsi.fastutil.longs.LongArrays; +import io.trino.type.BlockTypeOperators; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; -import static io.trino.spi.function.OperatorType.COMPARISON_UNORDERED_LAST; +import static io.trino.spi.function.OperatorType.HASH_CODE; +import static io.trino.spi.function.OperatorType.IS_DISTINCT_FROM; @ScalarFunction("arrays_overlap") @Description("Returns true if arrays have common elements") public final class ArraysOverlapFunction { - private int[] leftPositions = new int[0]; - private int[] rightPositions = new int[0]; - - private long[] leftLongArray = new long[0]; - private long[] rightLongArray = new long[0]; - - @TypeParameter("E") - public ArraysOverlapFunction(@TypeParameter("E") Type elementType) {} + private ArraysOverlapFunction() {} @SqlNullable @TypeParameter("E") - @TypeParameterSpecialization(name = "E", nativeContainerType = long.class) @SqlType(StandardTypes.BOOLEAN) - public Boolean arraysOverlapInt( - @OperatorDependency( - operator = COMPARISON_UNORDERED_LAST, - argumentTypes = {"E", "E"}, - convention = @Convention(arguments = {NEVER_NULL, NEVER_NULL}, result = FAIL_ON_NULL)) LongComparison comparisonOperator, + public static Boolean arraysOverlap( @TypeParameter("E") Type type, - @SqlType("array(E)") Block leftArray, - @SqlType("array(E)") Block rightArray) - { - int leftSize = leftArray.getPositionCount(); - int rightSize = rightArray.getPositionCount(); - - if (leftSize == 0 || rightSize == 0) { - return false; - } - - if (leftLongArray.length < leftSize) { - leftLongArray = new long[leftSize * 2]; - } - if (rightLongArray.length < rightSize) { - rightLongArray = new long[rightSize * 2]; - } - - int leftNonNullSize = sortLongArray(leftArray, leftLongArray, type, comparisonOperator); - int rightNonNullSize = sortLongArray(rightArray, rightLongArray, type, comparisonOperator); - - int leftPosition = 0; - int rightPosition = 0; - while (leftPosition < leftNonNullSize && rightPosition < rightNonNullSize) { - long compareValue = comparisonOperator.compare(leftLongArray[leftPosition], rightLongArray[rightPosition]); - if (compareValue > 0) { - rightPosition++; - } - else if (compareValue < 0) { - leftPosition++; - } - else { - return true; - } - } - return (leftNonNullSize < leftSize) || (rightNonNullSize < rightSize) ? null : false; - } - - // Assumes buffer is long enough, returns count of non-null elements. - private static int sortLongArray(Block array, long[] buffer, Type type, LongComparison comparisonOperator) - { - int arraySize = array.getPositionCount(); - int nonNullSize = 0; - for (int i = 0; i < arraySize; i++) { - if (!array.isNull(i)) { - buffer[nonNullSize++] = type.getLong(array, i); - } - } - - LongArrays.unstableSort(buffer, 0, nonNullSize, (left, right) -> (int) comparisonOperator.compare(left, right)); - - return nonNullSize; - } - - @SqlNullable - @TypeParameter("E") - @SqlType(StandardTypes.BOOLEAN) - public Boolean arraysOverlap( @OperatorDependency( - operator = COMPARISON_UNORDERED_LAST, + operator = IS_DISTINCT_FROM, argumentTypes = {"E", "E"}, - convention = @Convention(arguments = {BLOCK_POSITION_NOT_NULL, BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) BlockPositionComparison comparisonOperator, - @TypeParameter("E") Type type, + convention = @Convention(arguments = {BLOCK_POSITION, + BLOCK_POSITION}, result = FAIL_ON_NULL)) BlockTypeOperators.BlockPositionIsDistinctFrom elementIsDistinctFrom, + @OperatorDependency( + operator = HASH_CODE, + argumentTypes = "E", + convention = @Convention(arguments = BLOCK_POSITION, result = FAIL_ON_NULL)) BlockTypeOperators.BlockPositionHashCode elementHashCode, @SqlType("array(E)") Block leftArray, @SqlType("array(E)") Block rightArray) { - int leftPositionCount = leftArray.getPositionCount(); - int rightPositionCount = rightArray.getPositionCount(); - - if (leftPositionCount == 0 || rightPositionCount == 0) { - return false; + Block smaller = rightArray; + Block larger = leftArray; + if (leftArray.getPositionCount() < rightArray.getPositionCount()) { + smaller = leftArray; + larger = rightArray; } - if (leftPositions.length < leftPositionCount) { - leftPositions = new int[leftPositionCount * 2]; - } + int largerPositionCount = larger.getPositionCount(); + int smallerPositionCount = smaller.getPositionCount(); - if (rightPositions.length < rightPositionCount) { - rightPositions = new int[rightPositionCount * 2]; + if (largerPositionCount == 0 || smallerPositionCount == 0) { + return false; } - for (int i = 0; i < leftPositionCount; i++) { - leftPositions[i] = i; + BlockSet smallerSet = new BlockSet(type, elementIsDistinctFrom, elementHashCode, smallerPositionCount); + for (int position = 0; position < smallerPositionCount; position++) { + smallerSet.add(smaller, position); } - for (int i = 0; i < rightPositionCount; i++) { - rightPositions[i] = i; - } - IntArrays.quickSort(leftPositions, 0, leftPositionCount, intBlockCompare(comparisonOperator, leftArray)); - IntArrays.quickSort(rightPositions, 0, rightPositionCount, intBlockCompare(comparisonOperator, rightArray)); - int leftPosition = 0; - int rightPosition = 0; - while (leftPosition < leftPositionCount && rightPosition < rightPositionCount) { - if (leftArray.isNull(leftPositions[leftPosition]) || rightArray.isNull(rightPositions[rightPosition])) { - // Nulls are in the end of the array. Non-null elements do not overlap. - return null; - } - long compareValue = comparisonOperator.compare(leftArray, leftPositions[leftPosition], rightArray, rightPositions[rightPosition]); - if (compareValue > 0) { - rightPosition++; - } - else if (compareValue < 0) { - leftPosition++; + boolean largerContainsNull = false; + for (int position = 0; position < largerPositionCount; position++) { + if (larger.isNull(position)) { + largerContainsNull = true; + continue; } - else { + if (smallerSet.contains(larger, position)) { return true; } } - return leftArray.isNull(leftPositions[leftPositionCount - 1]) || rightArray.isNull(rightPositions[rightPositionCount - 1]) ? null : false; - } - - private static IntComparator intBlockCompare(BlockPositionComparison comparisonOperator, Block block) - { - return (left, right) -> { - if (block.isNull(left) && block.isNull(right)) { - return 0; - } - if (block.isNull(left)) { - return 1; - } - if (block.isNull(right)) { - return -1; - } - return (int) comparisonOperator.compare(block, left, block, right); - }; - } - - public interface LongComparison - { - long compare(long left, long right); + return largerContainsNull || smallerSet.containsNullElement() ? null : false; } } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/BlockSet.java b/core/trino-main/src/main/java/io/trino/operator/scalar/BlockSet.java index c595dabcca7cca..be49ce12fae3b2 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/BlockSet.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/BlockSet.java @@ -136,6 +136,14 @@ public int size() return size; } + /** + * Returns whether this set contains a NULL element + */ + public boolean containsNullElement() + { + return containsNullElement; + } + /** * Return the position of the value within this set, or -1 if the value is not in this set. * This method can not get the position of a null value, and an exception will be thrown in that case. diff --git a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java index 88eac66c51cbdf..e24251a3895de8 100644 --- a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java +++ b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java @@ -64,6 +64,7 @@ import io.trino.execution.TaskStatus; import io.trino.execution.resourcegroups.InternalResourceGroupManager; import io.trino.execution.resourcegroups.LegacyResourceGroupConfigurationManager; +import io.trino.execution.resourcegroups.ResourceGroupInfoProvider; import io.trino.execution.resourcegroups.ResourceGroupManager; import io.trino.execution.scheduler.SplitSchedulerStats; import io.trino.execution.scheduler.TaskExecutionStats; @@ -198,6 +199,7 @@ protected void setup(Binder binder) binder.bind(QueryManager.class).to(SqlQueryManager.class).in(Scopes.SINGLETON); binder.bind(QueryPreparer.class).in(Scopes.SINGLETON); OptionalBinder.newOptionalBinder(binder, SessionSupplier.class).setDefault().to(QuerySessionSupplier.class).in(Scopes.SINGLETON); + binder.bind(ResourceGroupInfoProvider.class).to(ResourceGroupManager.class).in(Scopes.SINGLETON); binder.bind(InternalResourceGroupManager.class).in(Scopes.SINGLETON); newExporter(binder).export(InternalResourceGroupManager.class).withGeneratedName(); binder.bind(ResourceGroupManager.class).to(InternalResourceGroupManager.class); diff --git a/core/trino-main/src/main/java/io/trino/server/QueryStateInfoResource.java b/core/trino-main/src/main/java/io/trino/server/QueryStateInfoResource.java index df07b0bff9ee36..7c48362c393fad 100644 --- a/core/trino-main/src/main/java/io/trino/server/QueryStateInfoResource.java +++ b/core/trino-main/src/main/java/io/trino/server/QueryStateInfoResource.java @@ -15,7 +15,7 @@ import com.google.inject.Inject; import io.trino.dispatcher.DispatchManager; -import io.trino.execution.resourcegroups.ResourceGroupManager; +import io.trino.execution.resourcegroups.ResourceGroupInfoProvider; import io.trino.security.AccessControl; import io.trino.server.security.ResourceSecurity; import io.trino.spi.QueryId; @@ -53,19 +53,19 @@ public class QueryStateInfoResource { private final DispatchManager dispatchManager; - private final ResourceGroupManager resourceGroupManager; + private final ResourceGroupInfoProvider resourceGroupInfoProvider; private final AccessControl accessControl; private final HttpRequestSessionContextFactory sessionContextFactory; @Inject public QueryStateInfoResource( DispatchManager dispatchManager, - ResourceGroupManager resourceGroupManager, + ResourceGroupInfoProvider resourceGroupInfoProvider, AccessControl accessControl, HttpRequestSessionContextFactory sessionContextFactory) { this.dispatchManager = requireNonNull(dispatchManager, "dispatchManager is null"); - this.resourceGroupManager = requireNonNull(resourceGroupManager, "resourceGroupManager is null"); + this.resourceGroupInfoProvider = requireNonNull(resourceGroupInfoProvider, "resourceGroupInfoProvider is null"); this.accessControl = requireNonNull(accessControl, "accessControl is null"); this.sessionContextFactory = requireNonNull(sessionContextFactory, "sessionContextFactory is null"); } @@ -97,7 +97,7 @@ private QueryStateInfo getQueryStateInfo(BasicQueryInfo queryInfo) return createQueuedQueryStateInfo( queryInfo, groupId, - groupId.map(group -> resourceGroupManager.tryGetPathToRoot(group) + groupId.map(group -> resourceGroupInfoProvider.tryGetPathToRoot(group) .orElseThrow(() -> new IllegalStateException("Resource group not found: " + group)))); } return createQueryStateInfo(queryInfo, groupId); diff --git a/core/trino-main/src/main/java/io/trino/server/ResourceGroupStateInfoResource.java b/core/trino-main/src/main/java/io/trino/server/ResourceGroupStateInfoResource.java index dc31935bd0584c..de7c802d6921ef 100644 --- a/core/trino-main/src/main/java/io/trino/server/ResourceGroupStateInfoResource.java +++ b/core/trino-main/src/main/java/io/trino/server/ResourceGroupStateInfoResource.java @@ -14,7 +14,7 @@ package io.trino.server; import com.google.inject.Inject; -import io.trino.execution.resourcegroups.ResourceGroupManager; +import io.trino.execution.resourcegroups.ResourceGroupInfoProvider; import io.trino.server.security.ResourceSecurity; import io.trino.spi.resourcegroups.ResourceGroupId; import jakarta.ws.rs.Encoded; @@ -38,12 +38,12 @@ @Path("/v1/resourceGroupState") public class ResourceGroupStateInfoResource { - private final ResourceGroupManager resourceGroupManager; + private final ResourceGroupInfoProvider resourceGroupInfoProvider; @Inject - public ResourceGroupStateInfoResource(ResourceGroupManager resourceGroupManager) + public ResourceGroupStateInfoResource(ResourceGroupInfoProvider resourceGroupInfoProvider) { - this.resourceGroupManager = requireNonNull(resourceGroupManager, "resourceGroupManager is null"); + this.resourceGroupInfoProvider = requireNonNull(resourceGroupInfoProvider, "resourceGroupInfoProvider is null"); } @ResourceSecurity(MANAGEMENT_READ) @@ -54,7 +54,7 @@ public ResourceGroupStateInfoResource(ResourceGroupManager resourceGroupManag public ResourceGroupInfo getQueryStateInfos(@PathParam("resourceGroupId") String resourceGroupIdString) { if (!isNullOrEmpty(resourceGroupIdString)) { - return resourceGroupManager.tryGetResourceGroupInfo( + return resourceGroupInfoProvider.tryGetResourceGroupInfo( new ResourceGroupId( Arrays.stream(resourceGroupIdString.split("/")) .map(ResourceGroupStateInfoResource::urlDecode) diff --git a/core/trino-main/src/main/java/io/trino/server/Server.java b/core/trino-main/src/main/java/io/trino/server/Server.java index a2510514c9155f..c241ae06fd9340 100644 --- a/core/trino-main/src/main/java/io/trino/server/Server.java +++ b/core/trino-main/src/main/java/io/trino/server/Server.java @@ -39,6 +39,7 @@ import io.airlift.openmetrics.JmxOpenMetricsModule; import io.airlift.tracetoken.TraceTokenModule; import io.airlift.tracing.TracingModule; +import io.airlift.units.Duration; import io.trino.client.NodeVersion; import io.trino.connector.CatalogManagerConfig; import io.trino.connector.CatalogManagerConfig.CatalogMangerKind; @@ -93,6 +94,7 @@ public final void start(String trinoVersion) private void doStart(String trinoVersion) { + long startTime = System.nanoTime(); verifyJvmRequirements(); verifySystemTimeIsReasonable(); @@ -175,6 +177,7 @@ private void doStart(String trinoVersion) injector.getInstance(StartupStatus.class).startupComplete(); + log.info("Server startup completed in %s", Duration.nanosSince(startTime).convertToMostSuccinctTimeUnit()); log.info("======== SERVER STARTED ========"); } catch (ApplicationConfigurationException e) { diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/AggregationAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/AggregationAnalyzer.java index 318be337ea6e86..3b73dd5c1086e0 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/AggregationAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/AggregationAnalyzer.java @@ -563,11 +563,6 @@ protected Boolean visitIdentifier(Identifier node, Void context) @Override protected Boolean visitDereferenceExpression(DereferenceExpression node, Void context) { - ExpressionAnalyzer.LabelPrefixedReference labelDereference = analysis.getLabelDereference(node); - if (labelDereference != null) { - return labelDereference.getColumn().map(this::process).orElse(true); - } - if (!hasReferencesToScope(node, analysis, sourceScope)) { // reference to outer scope is group-invariant return true; diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java index 5640d809c4df85..9d76f544feb9ab 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java @@ -55,11 +55,11 @@ import io.trino.spi.security.Identity; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; -import io.trino.sql.analyzer.ExpressionAnalyzer.LabelPrefixedReference; import io.trino.sql.analyzer.JsonPathAnalyzer.JsonPathAnalysis; +import io.trino.sql.analyzer.PatternRecognitionAnalysis.PatternInputAnalysis; import io.trino.sql.planner.PartitioningHandle; import io.trino.sql.tree.AllColumns; -import io.trino.sql.tree.DereferenceExpression; +import io.trino.sql.tree.DataType; import io.trino.sql.tree.ExistsPredicate; import io.trino.sql.tree.Expression; import io.trino.sql.tree.FieldReference; @@ -87,6 +87,7 @@ import io.trino.sql.tree.SampledRelation; import io.trino.sql.tree.Statement; import io.trino.sql.tree.SubqueryExpression; +import io.trino.sql.tree.SubsetDefinition; import io.trino.sql.tree.Table; import io.trino.sql.tree.TableFunctionInvocation; import io.trino.sql.tree.Unnest; @@ -156,17 +157,22 @@ public class Analysis private final Map>> tableColumnReferences = new LinkedHashMap<>(); // Record fields prefixed with labels in row pattern recognition context - private final Map, LabelPrefixedReference> labelDereferences = new LinkedHashMap<>(); - - private final Set> patternRecognitionFunctions = new LinkedHashSet<>(); - + private final Map, Optional> labels = new LinkedHashMap<>(); private final Map, Range> ranges = new LinkedHashMap<>(); - private final Map, Set> undefinedLabels = new LinkedHashMap<>(); - private final Map, MeasureDefinition> measureDefinitions = new LinkedHashMap<>(); - private final Set> patternAggregations = new LinkedHashSet<>(); + // Pattern function analysis (classifier, match_number, aggregations and prev/next/first/last) in the context of the given node + private final Map, List> patternInputsAnalysis = new LinkedHashMap<>(); + + // FunctionCall nodes corresponding to any of the special pattern recognition functions + private final Set> patternRecognitionFunctionCalls = new LinkedHashSet<>(); + + // FunctionCall nodes corresponding to any of the navigation functions (prev/next/first/last) + private final Set> patternNavigationFunctions = new LinkedHashSet<>(); + + private final Map, String> resolvedLabels = new LinkedHashMap<>(); + private final Map, Set> subsets = new LinkedHashMap<>(); // for JSON features private final Map, JsonPathAnalysis> jsonPathAnalyses = new LinkedHashMap<>(); @@ -348,6 +354,11 @@ public Map, Type> getTypes() return unmodifiableMap(types); } + public boolean isAnalyzed(Expression expression) + { + return expression instanceof DataType || types.containsKey(NodeRef.of(expression)); + } + public Type getType(Expression expression) { Type type = types.get(NodeRef.of(expression)); @@ -372,6 +383,7 @@ public Map, Type> getCoercions() public Type getCoercion(Expression expression) { + checkArgument(isAnalyzed(expression), "Expression has not been analyzed (%s): %s", expression.getClass().getName(), expression); return coercions.get(NodeRef.of(expression)); } @@ -959,24 +971,37 @@ public void addEmptyColumnReferencesForTable(AccessControl accessControl, Identi tableColumnReferences.computeIfAbsent(accessControlInfo, k -> new LinkedHashMap<>()).computeIfAbsent(table, k -> new HashSet<>()); } - public void addLabelDereferences(Map, LabelPrefixedReference> dereferences) + public void addLabels(Map, Optional> labels) + { + this.labels.putAll(labels); + } + + public void addPatternRecognitionInputs(Map, List> functions) { - labelDereferences.putAll(dereferences); + patternInputsAnalysis.putAll(functions); + + functions.values().stream() + .flatMap(List::stream) + .map(PatternInputAnalysis::expression) + .filter(FunctionCall.class::isInstance) + .map(FunctionCall.class::cast) + .map(NodeRef::of) + .forEach(patternRecognitionFunctionCalls::add); } - public LabelPrefixedReference getLabelDereference(DereferenceExpression expression) + public void addPatternNavigationFunctions(Set> functions) { - return labelDereferences.get(NodeRef.of(expression)); + patternNavigationFunctions.addAll(functions); } - public void addPatternRecognitionFunctions(Set> functions) + public Optional getLabel(Expression expression) { - patternRecognitionFunctions.addAll(functions); + return labels.get(NodeRef.of(expression)); } public boolean isPatternRecognitionFunction(FunctionCall functionCall) { - return patternRecognitionFunctions.contains(NodeRef.of(functionCall)); + return patternRecognitionFunctionCalls.contains(NodeRef.of(functionCall)); } public void setRanges(Map, Range> quantifierRanges) @@ -1018,16 +1043,6 @@ public MeasureDefinition getMeasureDefinition(WindowOperation measure) return measureDefinitions.get(NodeRef.of(measure)); } - public void setPatternAggregations(Set> aggregations) - { - patternAggregations.addAll(aggregations); - } - - public boolean isPatternAggregation(FunctionCall function) - { - return patternAggregations.contains(NodeRef.of(function)); - } - public void setJsonPathAnalyses(Map, JsonPathAnalysis> pathAnalyses) { jsonPathAnalyses.putAll(pathAnalyses); @@ -1328,6 +1343,46 @@ private boolean isInsertTarget(Table table) .orElse(FALSE); } + public List getPatternInputsAnalysis(Expression expression) + { + return patternInputsAnalysis.get(NodeRef.of(expression)); + } + + public boolean isPatternNavigationFunction(FunctionCall node) + { + return patternNavigationFunctions.contains(NodeRef.of(node)); + } + + public String getResolvedLabel(Identifier identifier) + { + return resolvedLabels.get(NodeRef.of(identifier)); + } + + public Set getLabels(SubsetDefinition subset) + { + return subsets.get(NodeRef.of(subset)); + } + + public void addSubsetLabels(SubsetDefinition subset, Set labels) + { + subsets.put(NodeRef.of(subset), labels); + } + + public void addSubsetLabels(Map, Set> subsets) + { + this.subsets.putAll(subsets); + } + + public void addResolvedLabel(Identifier label, String resolved) + { + resolvedLabels.put(NodeRef.of(label), resolved); + } + + public void addResolvedLabels(Map, String> labels) + { + resolvedLabels.putAll(labels); + } + @Immutable public static final class SelectExpression { diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/ConstantEvaluator.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/ConstantEvaluator.java new file mode 100644 index 00000000000000..b8fbd2ff348398 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/ConstantEvaluator.java @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.analyzer; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.Session; +import io.trino.execution.warnings.WarningCollector; +import io.trino.security.AccessControl; +import io.trino.spi.type.Type; +import io.trino.sql.PlannerContext; +import io.trino.sql.planner.IrExpressionInterpreter; +import io.trino.sql.planner.IrTypeAnalyzer; +import io.trino.sql.planner.TranslationMap; +import io.trino.sql.planner.TypeProvider; +import io.trino.sql.tree.Cast; +import io.trino.sql.tree.Expression; +import io.trino.sql.tree.NodeRef; +import io.trino.type.TypeCoercion; + +import java.util.Map; +import java.util.Optional; + +import static io.trino.spi.StandardErrorCode.EXPRESSION_NOT_CONSTANT; +import static io.trino.spi.StandardErrorCode.TYPE_MISMATCH; +import static io.trino.sql.analyzer.SemanticExceptions.semanticException; +import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; + +public class ConstantEvaluator +{ + private ConstantEvaluator() {} + + public static Object evaluateConstant( + Expression expression, + Type expectedType, + PlannerContext plannerContext, + Session session, + AccessControl accessControl) + { + Analysis analysis = new Analysis(null, ImmutableMap.of(), QueryType.OTHERS); + Scope scope = Scope.create(); + ExpressionAnalyzer.analyzeExpressionWithoutSubqueries( + session, + plannerContext, + accessControl, + scope, + analysis, + expression, + EXPRESSION_NOT_CONSTANT, + "Constant expression cannot contain a subquery", + WarningCollector.NOOP, + CorrelationSupport.DISALLOWED); + + TranslationMap translationMap = new TranslationMap(Optional.empty(), scope, analysis, ImmutableMap.of(), ImmutableList.of(), session, plannerContext); + Expression rewritten = translationMap.rewrite(expression); + + IrTypeAnalyzer analyzer = new IrTypeAnalyzer(plannerContext); + Map, Type> types = analyzer.getTypes(session, TypeProvider.empty(), rewritten); + + Type actualType = types.get(NodeRef.of(rewritten)); + if (!new TypeCoercion(plannerContext.getTypeManager()::getType).canCoerce(actualType, expectedType)) { + throw semanticException(TYPE_MISMATCH, expression, "Cannot cast type %s to %s", actualType.getDisplayName(), expectedType.getDisplayName()); + } + + if (!actualType.equals(expectedType)) { + rewritten = new Cast(rewritten, toSqlType(expectedType), false); + types = analyzer.getTypes(session, TypeProvider.empty(), rewritten); + } + + return new IrExpressionInterpreter(rewritten, plannerContext, session, types).evaluate(); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java index 0071dd1e6333b1..41e4483c13f086 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java @@ -18,10 +18,10 @@ import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.LinkedHashMultimap; import com.google.common.collect.Multimap; -import com.google.common.collect.Streams; import io.trino.Session; import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.FunctionResolver; @@ -57,13 +57,20 @@ import io.trino.sql.analyzer.Analysis.Range; import io.trino.sql.analyzer.Analysis.ResolvedWindow; import io.trino.sql.analyzer.JsonPathAnalyzer.JsonPathAnalysis; -import io.trino.sql.analyzer.PatternRecognitionAnalyzer.PatternRecognitionAnalysis; +import io.trino.sql.analyzer.PatternRecognitionAnalysis.AggregationDescriptor; +import io.trino.sql.analyzer.PatternRecognitionAnalysis.ClassifierDescriptor; +import io.trino.sql.analyzer.PatternRecognitionAnalysis.MatchNumberDescriptor; +import io.trino.sql.analyzer.PatternRecognitionAnalysis.Navigation; +import io.trino.sql.analyzer.PatternRecognitionAnalysis.NavigationMode; +import io.trino.sql.analyzer.PatternRecognitionAnalysis.PatternInputAnalysis; +import io.trino.sql.analyzer.PatternRecognitionAnalysis.ScalarInputDescriptor; import io.trino.sql.planner.LiteralInterpreter; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.TypeProvider; import io.trino.sql.tree.ArithmeticBinaryExpression; import io.trino.sql.tree.ArithmeticUnaryExpression; import io.trino.sql.tree.Array; +import io.trino.sql.tree.AstVisitor; import io.trino.sql.tree.AtTimeZone; import io.trino.sql.tree.BetweenPredicate; import io.trino.sql.tree.BinaryLiteral; @@ -134,12 +141,13 @@ import io.trino.sql.tree.RowPattern; import io.trino.sql.tree.SearchedCaseExpression; import io.trino.sql.tree.SimpleCaseExpression; +import io.trino.sql.tree.SkipTo; import io.trino.sql.tree.SortItem; import io.trino.sql.tree.SortItem.Ordering; -import io.trino.sql.tree.StackableAstVisitor; import io.trino.sql.tree.StringLiteral; import io.trino.sql.tree.SubqueryExpression; import io.trino.sql.tree.SubscriptExpression; +import io.trino.sql.tree.SubsetDefinition; import io.trino.sql.tree.SymbolReference; import io.trino.sql.tree.Trim; import io.trino.sql.tree.TryExpression; @@ -156,15 +164,18 @@ import java.util.ArrayList; import java.util.Collection; +import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.function.BiFunction; import java.util.function.Function; +import java.util.stream.Collectors; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -244,6 +255,8 @@ import static io.trino.sql.analyzer.ExpressionTreeUtils.extractExpressions; import static io.trino.sql.analyzer.ExpressionTreeUtils.extractLocation; import static io.trino.sql.analyzer.ExpressionTreeUtils.extractWindowExpressions; +import static io.trino.sql.analyzer.PatternRecognitionAnalysis.NavigationAnchor.FIRST; +import static io.trino.sql.analyzer.PatternRecognitionAnalysis.NavigationAnchor.LAST; import static io.trino.sql.analyzer.SemanticExceptions.missingAttributeException; import static io.trino.sql.analyzer.SemanticExceptions.semanticException; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; @@ -331,13 +344,18 @@ public class ExpressionAnalyzer private final Multimap, Field> referencedFields = HashMultimap.create(); // Record fields prefixed with labels in row pattern recognition context - private final Map, LabelPrefixedReference> labelDereferences = new LinkedHashMap<>(); + private final Map, Optional> labels = new HashMap<>(); // Record functions specific to row pattern recognition context - private final Set> patternRecognitionFunctions = new LinkedHashSet<>(); private final Map, Range> ranges = new LinkedHashMap<>(); private final Map, Set> undefinedLabels = new LinkedHashMap<>(); + private final Map, String> resolvedLabels = new LinkedHashMap<>(); + private final Map, Set> subsets = new LinkedHashMap<>(); private final Map, MeasureDefinition> measureDefinitions = new LinkedHashMap<>(); - private final Set> patternAggregations = new LinkedHashSet<>(); + + // Pattern function analysis (classifier, match_number, aggregations and prev/next/first/last) in the context of the given node + private final Map, List> patternRecognitionInputs = new LinkedHashMap<>(); + + private final Set> patternNavigationFunctions = new LinkedHashSet<>(); // for JSON functions private final Map, JsonPathAnalysis> jsonPathAnalyses = new LinkedHashMap<>(); @@ -478,31 +496,44 @@ public Map, LambdaArgumentDeclaration> getLambdaArgumentRefe public Type analyze(Expression expression, Scope scope) { Visitor visitor = new Visitor(scope, warningCollector); - return visitor.process(expression, new StackableAstVisitor.StackableAstVisitorContext<>(Context.notInLambda(scope, CorrelationSupport.ALLOWED))); + + patternRecognitionInputs.put(NodeRef.of(expression), visitor.getPatternRecognitionInputs()); + + return visitor.process(expression, Context.notInLambda(scope, CorrelationSupport.ALLOWED)); } public Type analyze(Expression expression, Scope scope, CorrelationSupport correlationSupport) { Visitor visitor = new Visitor(scope, warningCollector); - return visitor.process(expression, new StackableAstVisitor.StackableAstVisitorContext<>(Context.notInLambda(scope, correlationSupport))); + + patternRecognitionInputs.put(NodeRef.of(expression), visitor.getPatternRecognitionInputs()); + + return visitor.process(expression, Context.notInLambda(scope, correlationSupport)); } - private Type analyze(Expression expression, Scope scope, Set labels) + private Type analyze(Expression expression, Scope scope, Set labels, boolean inWindow) { Visitor visitor = new Visitor(scope, warningCollector); - return visitor.process(expression, new StackableAstVisitor.StackableAstVisitorContext<>(Context.patternRecognition(scope, labels))); + Type type = visitor.process(expression, Context.patternRecognition(scope, labels, inWindow)); + + patternRecognitionInputs.put(NodeRef.of(expression), visitor.getPatternRecognitionInputs()); + + return type; } private Type analyze(Expression expression, Scope baseScope, Context context) { Visitor visitor = new Visitor(baseScope, warningCollector); - return visitor.process(expression, new StackableAstVisitor.StackableAstVisitorContext<>(context)); + + patternRecognitionInputs.put(NodeRef.of(expression), visitor.getPatternRecognitionInputs()); + + return visitor.process(expression, context); } private RowType analyzeJsonPathInvocation(JsonTable node, Scope scope, CorrelationSupport correlationSupport) { Visitor visitor = new Visitor(scope, warningCollector); - List inputTypes = visitor.analyzeJsonPathInvocation("JSON_TABLE", node, node.getJsonPathInvocation(), new StackableAstVisitor.StackableAstVisitorContext<>(Context.notInLambda(scope, correlationSupport))); + List inputTypes = visitor.analyzeJsonPathInvocation("JSON_TABLE", node, node.getJsonPathInvocation(), Context.notInLambda(scope, correlationSupport)); return (RowType) inputTypes.get(2); } @@ -519,7 +550,7 @@ private Type analyzeJsonValueExpression(ValueColumn column, JsonPathAnalysis pat column.getEmptyDefault(), column.getErrorBehavior(), column.getErrorDefault(), - new StackableAstVisitor.StackableAstVisitorContext<>(Context.notInLambda(scope, correlationSupport))); + Context.notInLambda(scope, correlationSupport)); } private Type analyzeJsonQueryExpression(QueryColumn column, Scope scope) @@ -538,7 +569,7 @@ private Type analyzeJsonQueryExpression(QueryColumn column, Scope scope) private void analyzeWindow(ResolvedWindow window, Scope scope, Node originalNode, CorrelationSupport correlationSupport) { Visitor visitor = new Visitor(scope, warningCollector); - visitor.analyzeWindow(window, new StackableAstVisitor.StackableAstVisitorContext<>(Context.notInLambda(scope, correlationSupport)), originalNode); + visitor.analyzeWindow(window, Context.inWindow(scope, correlationSupport), originalNode); } public Set> getSubqueries() @@ -571,14 +602,9 @@ public List getSourceFields() return sourceFields; } - public Map, LabelPrefixedReference> getLabelDereferences() - { - return labelDereferences; - } - - public Set> getPatternRecognitionFunctions() + public Map, Optional> getLabels() { - return patternRecognitionFunctions; + return labels; } public Map, Range> getRanges() @@ -591,14 +617,29 @@ public Map, Set> getUndefinedLabels() return undefinedLabels; } + public Map, String> getResolvedLabels() + { + return resolvedLabels; + } + + public Map, Set> getSubsetLabels() + { + return subsets; + } + public Map, MeasureDefinition> getMeasureDefinitions() { return measureDefinitions; } - public Set> getPatternAggregations() + public Map, List> getPatternRecognitionInputs() + { + return patternRecognitionInputs; + } + + public Set> getPatternNavigationFunctions() { - return patternAggregations; + return patternNavigationFunctions; } public Map, JsonPathAnalysis> getJsonPathAnalyses() @@ -617,20 +658,27 @@ public Map, ResolvedFunction> getJsonOutputFunctions() } private class Visitor - extends StackableAstVisitor + extends AstVisitor { // Used to resolve FieldReferences (e.g. during local execution planning) private final Scope baseScope; private final WarningCollector warningCollector; + private final List patternRecognitionInputs = new ArrayList<>(); + public Visitor(Scope baseScope, WarningCollector warningCollector) { this.baseScope = requireNonNull(baseScope, "baseScope is null"); this.warningCollector = requireNonNull(warningCollector, "warningCollector is null"); } + public List getPatternRecognitionInputs() + { + return patternRecognitionInputs; + } + @Override - public Type process(Node node, @Nullable StackableAstVisitorContext context) + public Type process(Node node, @Nullable Context context) { if (node instanceof Expression) { // don't double process a node @@ -643,7 +691,7 @@ public Type process(Node node, @Nullable StackableAstVisitorContext con } @Override - protected Type visitRow(Row node, StackableAstVisitorContext context) + protected Type visitRow(Row node, Context context) { List types = node.getItems().stream() .map(child -> process(child, context)) @@ -654,13 +702,13 @@ protected Type visitRow(Row node, StackableAstVisitorContext context) } @Override - protected Type visitCurrentDate(CurrentDate node, StackableAstVisitorContext context) + protected Type visitCurrentDate(CurrentDate node, Context context) { return setExpressionType(node, DATE); } @Override - protected Type visitCurrentTime(CurrentTime node, StackableAstVisitorContext context) + protected Type visitCurrentTime(CurrentTime node, Context context) { return setExpressionType( node, @@ -670,7 +718,7 @@ protected Type visitCurrentTime(CurrentTime node, StackableAstVisitorContext context) + protected Type visitCurrentTimestamp(CurrentTimestamp node, Context context) { return setExpressionType( node, @@ -680,7 +728,7 @@ protected Type visitCurrentTimestamp(CurrentTimestamp node, StackableAstVisitorC } @Override - protected Type visitLocalTime(LocalTime node, StackableAstVisitorContext context) + protected Type visitLocalTime(LocalTime node, Context context) { return setExpressionType( node, @@ -690,7 +738,7 @@ protected Type visitLocalTime(LocalTime node, StackableAstVisitorContext context) + protected Type visitLocalTimestamp(LocalTimestamp node, Context context) { return setExpressionType( node, @@ -700,11 +748,11 @@ protected Type visitLocalTimestamp(LocalTimestamp node, StackableAstVisitorConte } @Override - protected Type visitSymbolReference(SymbolReference node, StackableAstVisitorContext context) + protected Type visitSymbolReference(SymbolReference node, Context context) { - if (context.getContext().isInLambda()) { - Optional resolvedField = context.getContext().getScope().tryResolveField(node, QualifiedName.of(node.getName())); - if (resolvedField.isPresent() && context.getContext().getFieldToLambdaArgumentDeclaration().containsKey(FieldId.from(resolvedField.get()))) { + if (context.isInLambda()) { + Optional resolvedField = context.getScope().tryResolveField(node, QualifiedName.of(node.getName())); + if (resolvedField.isPresent() && context.getFieldToLambdaArgumentDeclaration().containsKey(FieldId.from(resolvedField.get()))) { return setExpressionType(node, resolvedField.get().getType()); } } @@ -713,22 +761,30 @@ protected Type visitSymbolReference(SymbolReference node, StackableAstVisitorCon } @Override - protected Type visitIdentifier(Identifier node, StackableAstVisitorContext context) + protected Type visitIdentifier(Identifier node, Context context) { - ResolvedField resolvedField = context.getContext().getScope().resolveField(node, QualifiedName.of(node.getValue())); + ResolvedField resolvedField = context.getScope().resolveField(node, QualifiedName.of(node.getValue())); + + if (context.isPatternRecognition()) { + labels.put(NodeRef.of(node), Optional.empty()); + patternRecognitionInputs.add(new PatternInputAnalysis( + node, + new ScalarInputDescriptor(Optional.empty(), context.getPatternRecognitionContext().navigation()))); + } + return handleResolvedField(node, resolvedField, context); } - private Type handleResolvedField(Expression node, ResolvedField resolvedField, StackableAstVisitorContext context) + private Type handleResolvedField(Expression node, ResolvedField resolvedField, Context context) { - if (!resolvedField.isLocal() && context.getContext().getCorrelationSupport() != CorrelationSupport.ALLOWED) { + if (!resolvedField.isLocal() && context.getCorrelationSupport() != CorrelationSupport.ALLOWED) { throw semanticException(NOT_SUPPORTED, node, "Reference to column '%s' from outer scope not allowed in this context", node); } FieldId fieldId = FieldId.from(resolvedField); Field field = resolvedField.getField(); - if (context.getContext().isInLambda()) { - LambdaArgumentDeclaration lambdaArgumentDeclaration = context.getContext().getFieldToLambdaArgumentDeclaration().get(fieldId); + if (context.isInLambda()) { + LambdaArgumentDeclaration lambdaArgumentDeclaration = context.getFieldToLambdaArgumentDeclaration().get(fieldId); if (lambdaArgumentDeclaration != null) { // Lambda argument reference is not a column reference lambdaArgumentReferences.put(NodeRef.of((Identifier) node), lambdaArgumentDeclaration); @@ -753,7 +809,7 @@ private Type handleResolvedField(Expression node, ResolvedField resolvedField, S } @Override - protected Type visitDereferenceExpression(DereferenceExpression node, StackableAstVisitorContext context) + protected Type visitDereferenceExpression(DereferenceExpression node, Context context) { if (isQualifiedAllFieldsReference(node)) { throw semanticException(NOT_SUPPORTED, node, ".* not allowed in this context"); @@ -764,9 +820,9 @@ protected Type visitDereferenceExpression(DereferenceExpression node, StackableA // If this Dereference looks like column reference, try match it to column first. if (qualifiedName != null) { // In the context of row pattern matching, fields are optionally prefixed with labels. Labels are irrelevant during type analysis. - if (context.getContext().isPatternRecognition()) { - String label = label(qualifiedName.getOriginalParts().get(0)); - if (context.getContext().getLabels().contains(label)) { + if (context.isPatternRecognition()) { + String label = label(qualifiedName.getOriginalParts().getFirst()); + if (context.getPatternRecognitionContext().labels().contains(label)) { // In the context of row pattern matching, the name of row pattern input table cannot be used to qualify column names. // (it can only be accessed in PARTITION BY and ORDER BY clauses of MATCH_RECOGNIZE). Consequentially, if a dereference // expression starts with a label, the next part must be a column. @@ -775,21 +831,25 @@ protected Type visitDereferenceExpression(DereferenceExpression node, StackableA if (qualifiedName.getOriginalParts().size() > 2) { throw semanticException(COLUMN_NOT_FOUND, node, "Column %s prefixed with label %s cannot be resolved", unlabeledName, label); } - Identifier unlabeled = qualifiedName.getOriginalParts().get(1); - if (context.getContext().getScope().tryResolveField(node, unlabeledName).isEmpty()) { + Optional resolvedField = context.getScope().tryResolveField(node, unlabeledName); + if (resolvedField.isEmpty()) { throw semanticException(COLUMN_NOT_FOUND, node, "Column %s prefixed with label %s cannot be resolved", unlabeledName, label); } // Correlation is not allowed in pattern recognition context. Visitor's context for pattern recognition has CorrelationSupport.DISALLOWED, // and so the following call should fail if the field is from outer scope. - Type type = process(unlabeled, new StackableAstVisitorContext<>(context.getContext().notExpectingLabels())); - labelDereferences.put(NodeRef.of(node), new LabelPrefixedReference(label, unlabeled)); - return setExpressionType(node, type); + + labels.put(NodeRef.of(node), Optional.of(label)); + patternRecognitionInputs.add(new PatternInputAnalysis( + node, + new ScalarInputDescriptor(Optional.of(label), context.getPatternRecognitionContext().navigation()))); + + return handleResolvedField(node, resolvedField.get(), context); } // In the context of row pattern matching, qualified column references are not allowed. throw missingAttributeException(node, qualifiedName); } - Scope scope = context.getContext().getScope(); + Scope scope = context.getScope(); Optional resolvedField = scope.tryResolveField(node, qualifiedName); if (resolvedField.isPresent()) { return handleResolvedField(node, resolvedField.get(), context); @@ -827,7 +887,7 @@ protected Type visitDereferenceExpression(DereferenceExpression node, StackableA } @Override - protected Type visitNotExpression(NotExpression node, StackableAstVisitorContext context) + protected Type visitNotExpression(NotExpression node, Context context) { coerceType(context, node.getValue(), BOOLEAN, "Value of logical NOT expression"); @@ -835,7 +895,7 @@ protected Type visitNotExpression(NotExpression node, StackableAstVisitorContext } @Override - protected Type visitLogicalExpression(LogicalExpression node, StackableAstVisitorContext context) + protected Type visitLogicalExpression(LogicalExpression node, Context context) { for (Expression term : node.getTerms()) { coerceType(context, term, BOOLEAN, "Logical expression term"); @@ -845,7 +905,7 @@ protected Type visitLogicalExpression(LogicalExpression node, StackableAstVisito } @Override - protected Type visitComparisonExpression(ComparisonExpression node, StackableAstVisitorContext context) + protected Type visitComparisonExpression(ComparisonExpression node, Context context) { OperatorType operatorType = switch (node.getOperator()) { case EQUAL, NOT_EQUAL -> OperatorType.EQUAL; @@ -858,7 +918,7 @@ protected Type visitComparisonExpression(ComparisonExpression node, StackableAst } @Override - protected Type visitIsNullPredicate(IsNullPredicate node, StackableAstVisitorContext context) + protected Type visitIsNullPredicate(IsNullPredicate node, Context context) { process(node.getValue(), context); @@ -866,7 +926,7 @@ protected Type visitIsNullPredicate(IsNullPredicate node, StackableAstVisitorCon } @Override - protected Type visitIsNotNullPredicate(IsNotNullPredicate node, StackableAstVisitorContext context) + protected Type visitIsNotNullPredicate(IsNotNullPredicate node, Context context) { process(node.getValue(), context); @@ -874,7 +934,7 @@ protected Type visitIsNotNullPredicate(IsNotNullPredicate node, StackableAstVisi } @Override - protected Type visitNullIfExpression(NullIfExpression node, StackableAstVisitorContext context) + protected Type visitNullIfExpression(NullIfExpression node, Context context) { Type firstType = process(node.getFirst(), context); Type secondType = process(node.getSecond(), context); @@ -887,7 +947,7 @@ protected Type visitNullIfExpression(NullIfExpression node, StackableAstVisitorC } @Override - protected Type visitIfExpression(IfExpression node, StackableAstVisitorContext context) + protected Type visitIfExpression(IfExpression node, Context context) { coerceType(context, node.getCondition(), BOOLEAN, "IF condition"); @@ -903,7 +963,7 @@ protected Type visitIfExpression(IfExpression node, StackableAstVisitorContext context) + protected Type visitSearchedCaseExpression(SearchedCaseExpression node, Context context) { for (WhenClause whenClause : node.getWhenClauses()) { coerceType(context, whenClause.getOperand(), BOOLEAN, "CASE WHEN clause"); @@ -923,7 +983,7 @@ protected Type visitSearchedCaseExpression(SearchedCaseExpression node, Stackabl } @Override - protected Type visitSimpleCaseExpression(SimpleCaseExpression node, StackableAstVisitorContext context) + protected Type visitSimpleCaseExpression(SimpleCaseExpression node, Context context) { coerceCaseOperandToToSingleType(node, context); @@ -940,7 +1000,7 @@ protected Type visitSimpleCaseExpression(SimpleCaseExpression node, StackableAst return type; } - private void coerceCaseOperandToToSingleType(SimpleCaseExpression node, StackableAstVisitorContext context) + private void coerceCaseOperandToToSingleType(SimpleCaseExpression node, Context context) { Type operandType = process(node.getOperand(), context); @@ -981,7 +1041,7 @@ private List getCaseResultExpressions(List whenClauses, } @Override - protected Type visitCoalesceExpression(CoalesceExpression node, StackableAstVisitorContext context) + protected Type visitCoalesceExpression(CoalesceExpression node, Context context) { Type type = coerceToSingleType(context, "All COALESCE operands", node.getOperands()); @@ -989,7 +1049,7 @@ protected Type visitCoalesceExpression(CoalesceExpression node, StackableAstVisi } @Override - protected Type visitArithmeticUnary(ArithmeticUnaryExpression node, StackableAstVisitorContext context) + protected Type visitArithmeticUnary(ArithmeticUnaryExpression node, Context context) { return switch (node.getSign()) { case PLUS -> { @@ -1007,13 +1067,13 @@ protected Type visitArithmeticUnary(ArithmeticUnaryExpression node, StackableAst } @Override - protected Type visitArithmeticBinary(ArithmeticBinaryExpression node, StackableAstVisitorContext context) + protected Type visitArithmeticBinary(ArithmeticBinaryExpression node, Context context) { return getOperator(context, node, OperatorType.valueOf(node.getOperator().name()), node.getLeft(), node.getRight()); } @Override - protected Type visitLikePredicate(LikePredicate node, StackableAstVisitorContext context) + protected Type visitLikePredicate(LikePredicate node, Context context) { Type valueType = process(node.getValue(), context); if (!(valueType instanceof CharType) && !(valueType instanceof VarcharType)) { @@ -1038,7 +1098,7 @@ protected Type visitLikePredicate(LikePredicate node, StackableAstVisitorContext } @Override - protected Type visitSubscriptExpression(SubscriptExpression node, StackableAstVisitorContext context) + protected Type visitSubscriptExpression(SubscriptExpression node, Context context) { Type baseType = process(node.getBase(), context); // Subscript on Row hasn't got a dedicated operator. Its Type is resolved by hand. @@ -1066,7 +1126,7 @@ protected Type visitSubscriptExpression(SubscriptExpression node, StackableAstVi } @Override - protected Type visitArray(Array node, StackableAstVisitorContext context) + protected Type visitArray(Array node, Context context) { Type type = coerceToSingleType(context, "All ARRAY elements", node.getValues()); Type arrayType = plannerContext.getTypeManager().getParameterizedType(ARRAY.getName(), ImmutableList.of(TypeSignatureParameter.typeParameter(type.getTypeSignature()))); @@ -1074,20 +1134,20 @@ protected Type visitArray(Array node, StackableAstVisitorContext contex } @Override - protected Type visitStringLiteral(StringLiteral node, StackableAstVisitorContext context) + protected Type visitStringLiteral(StringLiteral node, Context context) { VarcharType type = VarcharType.createVarcharType(node.length()); return setExpressionType(node, type); } @Override - protected Type visitBinaryLiteral(BinaryLiteral node, StackableAstVisitorContext context) + protected Type visitBinaryLiteral(BinaryLiteral node, Context context) { return setExpressionType(node, VARBINARY); } @Override - protected Type visitLongLiteral(LongLiteral node, StackableAstVisitorContext context) + protected Type visitLongLiteral(LongLiteral node, Context context) { if (node.getParsedValue() >= Integer.MIN_VALUE && node.getParsedValue() <= Integer.MAX_VALUE) { return setExpressionType(node, INTEGER); @@ -1097,13 +1157,13 @@ protected Type visitLongLiteral(LongLiteral node, StackableAstVisitorContext context) + protected Type visitDoubleLiteral(DoubleLiteral node, Context context) { return setExpressionType(node, DOUBLE); } @Override - protected Type visitDecimalLiteral(DecimalLiteral node, StackableAstVisitorContext context) + protected Type visitDecimalLiteral(DecimalLiteral node, Context context) { DecimalParseResult parseResult; try { @@ -1116,13 +1176,13 @@ protected Type visitDecimalLiteral(DecimalLiteral node, StackableAstVisitorConte } @Override - protected Type visitBooleanLiteral(BooleanLiteral node, StackableAstVisitorContext context) + protected Type visitBooleanLiteral(BooleanLiteral node, Context context) { return setExpressionType(node, BOOLEAN); } @Override - protected Type visitGenericLiteral(GenericLiteral node, StackableAstVisitorContext context) + protected Type visitGenericLiteral(GenericLiteral node, Context context) { return setExpressionType( node, @@ -1213,7 +1273,7 @@ private Type processTimestampLiteral(GenericLiteral node) } @Override - protected Type visitIntervalLiteral(IntervalLiteral node, StackableAstVisitorContext context) + protected Type visitIntervalLiteral(IntervalLiteral node, Context context) { Type type; if (node.isYearToMonth()) { @@ -1232,16 +1292,16 @@ protected Type visitIntervalLiteral(IntervalLiteral node, StackableAstVisitorCon } @Override - protected Type visitNullLiteral(NullLiteral node, StackableAstVisitorContext context) + protected Type visitNullLiteral(NullLiteral node, Context context) { return setExpressionType(node, UNKNOWN); } @Override - protected Type visitFunctionCall(FunctionCall node, StackableAstVisitorContext context) + protected Type visitFunctionCall(FunctionCall node, Context context) { boolean isAggregation = functionResolver.isAggregationFunction(session, node.getName(), accessControl); - boolean isRowPatternCount = context.getContext().isPatternRecognition() && + boolean isRowPatternCount = context.isPatternRecognition() && isAggregation && node.getName().getSuffix().equalsIgnoreCase("count"); // argument of the form `label.*` is only allowed for row pattern count function @@ -1253,16 +1313,39 @@ protected Type visitFunctionCall(FunctionCall node, StackableAstVisitorContext analyzeMatchNumber(node, context); + case "CLASSIFIER" -> analyzeClassifier(node, context); + case "FIRST", "LAST" -> analyzeLogicalNavigation(node, context, name); + case "PREV", "NEXT" -> analyzePhysicalNavigation(node, context, name); + default -> throw new IllegalStateException("unexpected pattern recognition function " + name); + }); + } + else if (isAggregation) { + if (node.getWindow().isPresent()) { + throw semanticException(NESTED_WINDOW, node, "Cannot use OVER with %s aggregate function in pattern recognition context", node.getName()); + } + if (node.getFilter().isPresent()) { + throw semanticException(NOT_SUPPORTED, node, "Cannot use FILTER with %s aggregate function in pattern recognition context", node.getName()); + } + if (node.getOrderBy().isPresent()) { + throw semanticException(NOT_SUPPORTED, node, "Cannot use ORDER BY with %s aggregate function in pattern recognition context", node.getName()); + } + if (node.isDistinct()) { + throw semanticException(NOT_SUPPORTED, node, "Cannot use DISTINCT with %s aggregate function in pattern recognition context", node.getName()); + } + } } + if (node.getProcessingMode().isPresent()) { ProcessingMode processingMode = node.getProcessingMode().get(); - if (!context.getContext().isPatternRecognition()) { + if (!context.isPatternRecognition()) { throw semanticException(INVALID_PROCESSING_MODE, processingMode, "%s semantics is not supported out of pattern recognition context", processingMode.getMode()); } if (!isAggregation) { @@ -1274,7 +1357,7 @@ protected Type visitFunctionCall(FunctionCall node, StackableAstVisitorContext arguments = node.getArguments(); - Expression expression = arguments.get(0); + Expression expression = arguments.getFirst(); Type expressionType = process(expression, context); if (!(expressionType instanceof VarcharType)) { throw semanticException(TYPE_MISMATCH, node, "Expected expression of varchar, but '%s' has %s type", expression, expressionType.getDisplayName()); } } - // must run after arguments are processed and labels are recorded - if (context.getContext().isPatternRecognition() && isAggregation) { - validateAggregationLabelConsistency(node); - } - ResolvedFunction function; try { function = functionResolver.resolveFunction(session, node.getName(), argumentTypes, accessControl); @@ -1357,7 +1435,7 @@ else if (node.getArguments().size() > 127) { } if (argumentTypes.get(i).hasDependency()) { FunctionType expectedFunctionType = (FunctionType) expectedType; - process(expression, new StackableAstVisitorContext<>(context.getContext().expectingLambda(expectedFunctionType.getArgumentTypes()))); + process(expression, context.expectingLambda(expectedFunctionType.getArgumentTypes())); } else { Type actualType = plannerContext.getTypeManager().getType(argumentTypes.get(i).getTypeSignature()); @@ -1366,11 +1444,16 @@ else if (node.getArguments().size() > 127) { } resolvedFunctions.put(NodeRef.of(node), function); + // must run after arguments are processed and labels are recorded + if (context.isPatternRecognition() && isAggregation) { + analyzePatternAggregation(node, function); + } + Type type = signature.getReturnType(); return setExpressionType(node, type); } - private void analyzeWindow(ResolvedWindow window, StackableAstVisitorContext context, Node originalNode) + private void analyzeWindow(ResolvedWindow window, Context context, Node originalNode) { // check no nested window functions ImmutableList.Builder childNodes = ImmutableList.builder(); @@ -1385,7 +1468,7 @@ private void analyzeWindow(ResolvedWindow window, StackableAstVisitorContext nestedWindowExpressions = extractWindowExpressions(childNodes.build()); if (!nestedWindowExpressions.isEmpty()) { - throw semanticException(NESTED_WINDOW, nestedWindowExpressions.get(0), "Cannot nest window functions or row pattern measures inside window specification"); + throw semanticException(NESTED_WINDOW, nestedWindowExpressions.getFirst(), "Cannot nest window functions or row pattern measures inside window specification"); } if (!window.isPartitionByInherited()) { @@ -1428,28 +1511,40 @@ private void analyzeWindow(ResolvedWindow window, StackableAstVisitorContext resolvedLabels.put(NodeRef.of(label), label.getCanonicalValue())); + + for (SubsetDefinition subset : frame.getSubsets()) { + resolvedLabels.put(NodeRef.of(subset.getName()), subset.getName().getCanonicalValue()); + subsets.put( + NodeRef.of(subset), + subset.getIdentifiers().stream() + .map(Identifier::getCanonicalValue) + .collect(Collectors.toSet())); + } + + ranges.putAll(analysis.ranges()); + undefinedLabels.put(NodeRef.of(frame.getPattern().get()), analysis.undefinedLabels()); PatternRecognitionAnalyzer.validateNoPatternAnchors(frame.getPattern().get()); // analyze expressions in MEASURES and DEFINE (with set of all labels passed as context) for (VariableDefinition variableDefinition : frame.getVariableDefinitions()) { Expression expression = variableDefinition.getExpression(); - Type type = process(expression, new StackableAstVisitorContext<>(context.getContext().patternRecognition(analysis.getAllLabels()))); + Type type = analyze(expression, context.getScope(), analysis.allLabels(), true); + resolvedLabels.put(NodeRef.of(variableDefinition.getName()), variableDefinition.getName().getCanonicalValue()); + if (!type.equals(BOOLEAN)) { throw semanticException(TYPE_MISMATCH, expression, "Expression defining a label must be boolean (actual type: %s)", type); } } for (MeasureDefinition measureDefinition : frame.getMeasures()) { Expression expression = measureDefinition.getExpression(); - process(expression, new StackableAstVisitorContext<>(context.getContext().patternRecognition(analysis.getAllLabels()))); + analyze(expression, context.getScope(), analysis.allLabels(), true); + resolvedLabels.put(NodeRef.of(measureDefinition.getName()), measureDefinition.getName().getCanonicalValue()); } - // validate pattern recognition expressions: MATCH_NUMBER() is not allowed in window - // this must run after the expressions in MEASURES and DEFINE are analyzed, and the patternRecognitionFunctions are recorded - PatternRecognitionAnalyzer.validateNoMatchNumber(frame.getMeasures(), frame.getVariableDefinitions(), patternRecognitionFunctions); - // TODO prohibited nesting: pattern recognition in frame end expression(?) } else { @@ -1463,10 +1558,10 @@ private void analyzeWindow(ResolvedWindow window, StackableAstVisitorContext context, ResolvedWindow window, Node originalNode) + private void analyzeFrameRangeOffset(Expression offsetValue, FrameBound.Type boundType, Context context, ResolvedWindow window, Node originalNode) { OrderBy orderBy = window.getOrderBy() .orElseThrow(() -> semanticException(MISSING_ORDER_BY, originalNode, "Window frame of type RANGE PRECEDING or FOLLOWING requires ORDER BY")); @@ -1597,7 +1692,7 @@ private void analyzeFrameRangeOffset(Expression offsetValue, FrameBound.Type bou throw e; } BoundSignature signature = function.getSignature(); - Type expectedSortKeyType = signature.getArgumentTypes().get(0); + Type expectedSortKeyType = signature.getArgumentTypes().getFirst(); if (!expectedSortKeyType.equals(sortKeyType)) { if (!typeCoercion.canCoerce(sortKeyType, expectedSortKeyType)) { throw semanticException(TYPE_MISMATCH, sortKey, "Sort key must evaluate to a %s (actual: %s)", expectedSortKeyType, sortKeyType); @@ -1620,7 +1715,7 @@ private void analyzeFrameRangeOffset(Expression offsetValue, FrameBound.Type bou } @Override - protected Type visitWindowOperation(WindowOperation node, StackableAstVisitorContext context) + protected Type visitWindowOperation(WindowOperation node, Context context) { ResolvedWindow window = getResolvedWindow.apply(node); checkState(window != null, "no resolved window for: " + node); @@ -1649,7 +1744,7 @@ protected Type visitWindowOperation(WindowOperation node, StackableAstVisitorCon return setExpressionType(node, type); } - public List getCallArgumentTypes(List arguments, StackableAstVisitorContext context) + public List getCallArgumentTypes(List arguments, Context context) { ImmutableList.Builder argumentTypesBuilder = ImmutableList.builder(); for (Expression argument : arguments) { @@ -1667,12 +1762,12 @@ public List getCallArgumentTypes(List argumen isDescribe, getPreanalyzedType, getResolvedWindow); - if (context.getContext().isInLambda()) { - for (LambdaArgumentDeclaration lambdaArgument : context.getContext().getFieldToLambdaArgumentDeclaration().values()) { + if (context.isInLambda()) { + for (LambdaArgumentDeclaration lambdaArgument : context.getFieldToLambdaArgumentDeclaration().values()) { innerExpressionAnalyzer.setExpressionType(lambdaArgument, getExpressionType(lambdaArgument)); } } - return innerExpressionAnalyzer.analyze(argument, baseScope, context.getContext().expectingLambda(types)).getTypeSignature(); + return innerExpressionAnalyzer.analyze(argument, baseScope, context.expectingLambda(types)).getTypeSignature(); })); } else { @@ -1681,10 +1776,10 @@ public List getCallArgumentTypes(List argumen // process the argument but do not include it in the list DereferenceExpression allRowsDereference = (DereferenceExpression) argument; String label = label((Identifier) allRowsDereference.getBase()); - if (!context.getContext().getLabels().contains(label)) { + if (!context.getPatternRecognitionContext().labels().contains(label)) { throw semanticException(INVALID_FUNCTION_ARGUMENT, allRowsDereference.getBase(), "%s is not a primary pattern variable or subset name", label); } - labelDereferences.put(NodeRef.of(allRowsDereference), new LabelPrefixedReference(label)); + labels.put(NodeRef.of(allRowsDereference), Optional.of(label)); } else { argumentTypesBuilder.add(new TypeSignatureProvider(process(argument, context).getTypeSignature())); @@ -1695,7 +1790,129 @@ public List getCallArgumentTypes(List argumen return argumentTypesBuilder.build(); } - private Type analyzePatternRecognitionFunction(FunctionCall node, StackableAstVisitorContext context) + private Type analyzeMatchNumber(FunctionCall node, Context context) + { + if (context.isInWindow()) { + throw semanticException(INVALID_PATTERN_RECOGNITION_FUNCTION, node, "MATCH_NUMBER function is not supported in window"); + } + + if (!node.getArguments().isEmpty()) { + throw semanticException(INVALID_FUNCTION_ARGUMENT, node, "MATCH_NUMBER pattern recognition function takes no arguments"); + } + + patternRecognitionInputs.add(new PatternInputAnalysis(node, new MatchNumberDescriptor())); + + return BIGINT; + } + + private Type analyzeClassifier(FunctionCall node, Context context) + { + if (node.getArguments().size() > 1) { + throw semanticException(INVALID_FUNCTION_ARGUMENT, node, "CLASSIFIER pattern recognition function takes no arguments or 1 argument"); + } + + Optional label = Optional.empty(); + if (node.getArguments().size() == 1) { + Node argument = node.getArguments().getFirst(); + if (!(argument instanceof Identifier identifier)) { + throw semanticException(TYPE_MISMATCH, argument, "CLASSIFIER function argument should be primary pattern variable or subset name. Actual: %s", argument.getClass().getSimpleName()); + } + label = Optional.of(label(identifier)); + if (!context.getPatternRecognitionContext().labels().contains(label.get())) { + throw semanticException(INVALID_FUNCTION_ARGUMENT, argument, "%s is not a primary pattern variable or subset name", identifier.getValue()); + } + } + + patternRecognitionInputs.add(new PatternInputAnalysis( + node, + new ClassifierDescriptor(label, context.getPatternRecognitionContext().navigation()))); + + return VARCHAR; + } + + private Type analyzePhysicalNavigation(FunctionCall node, Context context, String name) + { + validateNavigationFunctionArguments(node); + + // TODO: this should only be done at the root of a pattern recognition function call tree + checkNoNestedAggregations(node); + validateNavigationNesting(node); + + int offset = getNavigationOffset(node, 1); + if (name.equals("PREV")) { + offset = -offset; + } + + Navigation navigation = context.getPatternRecognitionContext().navigation(); + Type type = process( + node.getArguments().getFirst(), + context.withNavigation(new Navigation( + navigation.anchor(), + navigation.mode(), + navigation.logicalOffset(), + offset))); + + // TODO: this should only be done at the root of a pattern recognition function call tree + if (!validateLabelConsistency(node, 0).hasLabel()) { + throw semanticException(INVALID_ARGUMENTS, node, "Pattern navigation function '%s' must contain at least one column reference or CLASSIFIER()", name); + } + + patternNavigationFunctions.add(NodeRef.of(node)); + + return type; + } + + private Type analyzeLogicalNavigation(FunctionCall node, Context context, String name) + { + validateNavigationFunctionArguments(node); + + // TODO: this should only be done at the root of a pattern recognition function call tree + checkNoNestedAggregations(node); + validateNavigationNesting(node); + + PatternRecognitionAnalysis.NavigationAnchor anchor = switch (name) { + case "FIRST" -> FIRST; + case "LAST" -> LAST; + default -> throw new IllegalStateException("Unexpected navigation anchor: " + name); + }; + + Type type = process( + node.getArguments().getFirst(), + context.withNavigation(new Navigation( + anchor, + mapProcessingMode(node.getProcessingMode()), + getNavigationOffset(node, 0), + context.getPatternRecognitionContext().navigation().physicalOffset()))); + + // TODO: this should only be done at the root of a pattern recognition function call tree + if (!validateLabelConsistency(node, 0).hasLabel()) { + throw semanticException(INVALID_ARGUMENTS, node, "Pattern navigation function '%s' must contain at least one column reference or CLASSIFIER()", name); + } + + patternNavigationFunctions.add(NodeRef.of(node)); + + return type; + } + + private static NavigationMode mapProcessingMode(Optional processingMode) + { + return processingMode.map(mode -> switch (mode.getMode()) { + case FINAL -> NavigationMode.FINAL; + case RUNNING -> NavigationMode.RUNNING; + }) + .orElse(NavigationMode.RUNNING); + } + + private static int getNavigationOffset(FunctionCall node, int defaultOffset) + { + int offset = defaultOffset; + if (node.getArguments().size() == 2) { + offset = (int) ((LongLiteral) node.getArguments().get(1)).getParsedValue(); + } + return offset; + } + + private static void validatePatternRecognitionFunction(FunctionCall node) { if (node.getWindow().isPresent()) { throw semanticException(INVALID_PATTERN_RECOGNITION_FUNCTION, node, "Cannot use OVER with %s pattern recognition function", node.getName()); @@ -1716,60 +1933,26 @@ private Type analyzePatternRecognitionFunction(FunctionCall node, StackableAstVi throw semanticException(INVALID_PROCESSING_MODE, processingMode, "%s semantics is not supported with %s pattern recognition function", processingMode.getMode(), node.getName()); } } + } - patternRecognitionFunctions.add(NodeRef.of(node)); - - return switch (name.toUpperCase(ENGLISH)) { - case "FIRST", "LAST", "PREV", "NEXT" -> { - if (node.getArguments().size() != 1 && node.getArguments().size() != 2) { - throw semanticException(INVALID_FUNCTION_ARGUMENT, node, "%s pattern recognition function requires 1 or 2 arguments", node.getName()); - } - Type resultType = process(node.getArguments().get(0), context); - if (node.getArguments().size() == 2) { - process(node.getArguments().get(1), context); - // TODO the offset argument must be effectively constant, not necessarily a number. This could be extended with the use of ConstantAnalyzer. - if (!(node.getArguments().get(1) instanceof LongLiteral)) { - throw semanticException(INVALID_FUNCTION_ARGUMENT, node, "%s pattern recognition navigation function requires a number as the second argument", node.getName()); - } - long offset = ((LongLiteral) node.getArguments().get(1)).getParsedValue(); - if (offset < 0) { - throw semanticException(NUMERIC_VALUE_OUT_OF_RANGE, node, "%s pattern recognition navigation function requires a non-negative number as the second argument (actual: %s)", node.getName(), offset); - } - if (offset > Integer.MAX_VALUE) { - throw semanticException(NUMERIC_VALUE_OUT_OF_RANGE, node, "The second argument of %s pattern recognition navigation function must not exceed %s (actual: %s)", node.getName(), Integer.MAX_VALUE, offset); - } - } - validateNavigationNesting(node); - checkNoNestedAggregations(node); - - // must run after the argument is processed and labels in the argument are recorded - validateNavigationLabelConsistency(node); - yield setExpressionType(node, resultType); + private static void validateNavigationFunctionArguments(FunctionCall node) + { + if (node.getArguments().size() != 1 && node.getArguments().size() != 2) { + throw semanticException(INVALID_FUNCTION_ARGUMENT, node, "%s pattern recognition function requires 1 or 2 arguments", node.getName()); + } + if (node.getArguments().size() == 2) { + // TODO the offset argument must be effectively constant, not necessarily a number. This could be extended with the use of ConstantAnalyzer. + if (!(node.getArguments().get(1) instanceof LongLiteral)) { + throw semanticException(INVALID_FUNCTION_ARGUMENT, node, "%s pattern recognition navigation function requires a number as the second argument", node.getName()); } - case "MATCH_NUMBER" -> { - if (!node.getArguments().isEmpty()) { - throw semanticException(INVALID_FUNCTION_ARGUMENT, node, "MATCH_NUMBER pattern recognition function takes no arguments"); - } - yield setExpressionType(node, BIGINT); + long offset = ((LongLiteral) node.getArguments().get(1)).getParsedValue(); + if (offset < 0) { + throw semanticException(NUMERIC_VALUE_OUT_OF_RANGE, node, "%s pattern recognition navigation function requires a non-negative number as the second argument (actual: %s)", node.getName(), offset); } - case "CLASSIFIER" -> { - if (node.getArguments().size() > 1) { - throw semanticException(INVALID_FUNCTION_ARGUMENT, node, "CLASSIFIER pattern recognition function takes no arguments or 1 argument"); - } - if (node.getArguments().size() == 1) { - Node argument = node.getArguments().get(0); - if (!(argument instanceof Identifier identifier)) { - throw semanticException(TYPE_MISMATCH, argument, "CLASSIFIER function argument should be primary pattern variable or subset name. Actual: %s", argument.getClass().getSimpleName()); - } - String label = label(identifier); - if (!context.getContext().getLabels().contains(label)) { - throw semanticException(INVALID_FUNCTION_ARGUMENT, argument, "%s is not a primary pattern variable or subset name", identifier.getValue()); - } - } - yield setExpressionType(node, VARCHAR); + if (offset > Integer.MAX_VALUE) { + throw semanticException(NUMERIC_VALUE_OUT_OF_RANGE, node, "The second argument of %s pattern recognition navigation function must not exceed %s (actual: %s)", node.getName(), Integer.MAX_VALUE, offset); } - default -> throw new IllegalStateException("unexpected pattern recognition function " + node.getName()); - }; + } } private void validateNavigationNesting(FunctionCall node) @@ -1778,15 +1961,15 @@ private void validateNavigationNesting(FunctionCall node) String name = node.getName().getSuffix(); // It is allowed to nest FIRST and LAST functions within PREV and NEXT functions. Only immediate nesting is supported - List nestedNavigationFunctions = extractExpressions(ImmutableList.of(node.getArguments().get(0)), FunctionCall.class).stream() + List nestedNavigationFunctions = extractExpressions(ImmutableList.of(node.getArguments().getFirst()), FunctionCall.class).stream() .filter(this::isPatternNavigationFunction) .collect(toImmutableList()); if (!nestedNavigationFunctions.isEmpty()) { if (name.equalsIgnoreCase("FIRST") || name.equalsIgnoreCase("LAST")) { throw semanticException( INVALID_NAVIGATION_NESTING, - nestedNavigationFunctions.get(0), - "Cannot nest %s pattern navigation function inside %s pattern navigation function", nestedNavigationFunctions.get(0).getName(), name); + nestedNavigationFunctions.getFirst(), + "Cannot nest %s pattern navigation function inside %s pattern navigation function", nestedNavigationFunctions.getFirst().getName(), name); } if (nestedNavigationFunctions.size() > 1) { throw semanticException( @@ -1802,7 +1985,7 @@ private void validateNavigationNesting(FunctionCall node) nested, "Cannot nest %s pattern navigation function inside %s pattern navigation function", nestedName, name); } - if (nested != node.getArguments().get(0)) { + if (nested != node.getArguments().getFirst()) { throw semanticException( INVALID_NAVIGATION_NESTING, nested, @@ -1811,18 +1994,15 @@ private void validateNavigationNesting(FunctionCall node) } } - /** - * Check that all aggregation arguments refer consistently to the same label. - */ - private void validateAggregationLabelConsistency(FunctionCall node) + private Set analyzeAggregationLabels(FunctionCall node) { if (node.getArguments().isEmpty()) { - return; + return ImmutableSet.of(); } Set> argumentLabels = new HashSet<>(); for (int i = 0; i < node.getArguments().size(); i++) { - ArgumentLabel argumentLabel = validateLabelConsistency(node, false, i); + ArgumentLabel argumentLabel = validateLabelConsistency(node, i); if (argumentLabel.hasLabel()) { argumentLabels.add(argumentLabel.getLabel()); } @@ -1830,105 +2010,47 @@ private void validateAggregationLabelConsistency(FunctionCall node) if (argumentLabels.size() > 1) { throw semanticException(INVALID_ARGUMENTS, node, "All aggregate function arguments must apply to rows matched with the same label"); } - } - /** - * Check that the navigated expression refers consistently to the same label. - */ - private void validateNavigationLabelConsistency(FunctionCall node) - { - checkArgument(isPatternNavigationFunction(node)); - validateLabelConsistency(node, true, 0); + return argumentLabels.stream() + .filter(Optional::isPresent) + .map(Optional::get) + .collect(Collectors.toSet()); } - private ArgumentLabel validateLabelConsistency(FunctionCall node, boolean labelRequired, int argumentIndex) + private ArgumentLabel validateLabelConsistency(FunctionCall node, int argumentIndex) { - String name = node.getName().getSuffix(); + Set> referenceLabels = extractExpressions(node.getArguments(), Expression.class).stream() + .map(child -> labels.get(NodeRef.of(child))) + .filter(Objects::nonNull) + .collect(toImmutableSet()); - List unlabeledInputColumns = Streams.concat( - extractExpressions(ImmutableList.of(node.getArguments().get(argumentIndex)), Identifier.class).stream(), - extractExpressions(ImmutableList.of(node.getArguments().get(argumentIndex)), DereferenceExpression.class).stream()) - .filter(expression -> columnReferences.containsKey(NodeRef.of(expression))) - .collect(toImmutableList()); - List labeledInputColumns = extractExpressions(ImmutableList.of(node.getArguments().get(argumentIndex)), DereferenceExpression.class).stream() - .filter(expression -> labelDereferences.containsKey(NodeRef.of(expression))) - .collect(toImmutableList()); - List classifiers = extractExpressions(ImmutableList.of(node.getArguments().get(argumentIndex)), FunctionCall.class).stream() + Set> classifierLabels = extractExpressions(ImmutableList.of(node.getArguments().get(argumentIndex)), FunctionCall.class).stream() .filter(this::isClassifierFunction) - .collect(toImmutableList()); + .map(functionCall -> functionCall.getArguments().stream() + .findFirst() + .map(argument -> label((Identifier) argument))) + .collect(toImmutableSet()); - // Pattern navigation function must contain at least one column reference or CLASSIFIER() function. There is no such requirement for the argument of an aggregate function. - if (unlabeledInputColumns.isEmpty() && labeledInputColumns.isEmpty() && classifiers.isEmpty()) { - if (labelRequired) { - throw semanticException(INVALID_ARGUMENTS, node, "Pattern navigation function %s must contain at least one column reference or CLASSIFIER()", name); - } + Set> allLabels = ImmutableSet.>builder() + .addAll(referenceLabels) + .addAll(classifierLabels) + .build(); + + if (allLabels.isEmpty()) { return ArgumentLabel.noLabel(); } - // Label consistency rules: - // All column references must be prefixed with the same label. - // Alternatively, all column references can have no label. In such case they are considered as prefixed with universal row pattern variable. - // All CLASSIFIER() calls must have the same label or no label, respectively, as their argument. - if (!unlabeledInputColumns.isEmpty() && !labeledInputColumns.isEmpty()) { - throw semanticException( - INVALID_ARGUMENTS, - labeledInputColumns.get(0), - "Column references inside argument of function %s must all either be prefixed with the same label or be not prefixed", name); - } - Set inputColumnLabels = labeledInputColumns.stream() - .map(expression -> labelDereferences.get(NodeRef.of(expression))) - .map(LabelPrefixedReference::getLabel) - .collect(toImmutableSet()); - if (inputColumnLabels.size() > 1) { - throw semanticException( - INVALID_ARGUMENTS, - labeledInputColumns.get(0), - "Column references inside argument of function %s must all either be prefixed with the same label or be not prefixed", name); - } - Set> classifierLabels = classifiers.stream() - .map(functionCall -> { - if (functionCall.getArguments().isEmpty()) { - return Optional.empty(); - } - return Optional.of(label((Identifier) functionCall.getArguments().get(0))); - }) - .collect(toImmutableSet()); - if (classifierLabels.size() > 1) { + if (allLabels.size() > 1) { + String name = node.getName().getSuffix(); throw semanticException( INVALID_ARGUMENTS, node, - "CLASSIFIER() calls inside argument of function %s must all either have the same label as the argument or have no arguments", name); - } - if (!unlabeledInputColumns.isEmpty() && !classifiers.isEmpty()) { - if (!getOnlyElement(classifierLabels).equals(Optional.empty())) { - throw semanticException( - INVALID_ARGUMENTS, - node, - "Column references inside argument of function %s must all be prefixed with the same label that all CLASSIFIER() calls have as the argument", name); - } - } - if (!labeledInputColumns.isEmpty() && !classifiers.isEmpty()) { - if (!getOnlyElement(classifierLabels).equals(Optional.of(getOnlyElement(inputColumnLabels)))) { - throw semanticException( - INVALID_ARGUMENTS, - node, - "Column references inside argument of function %s must all be prefixed with the same label that all CLASSIFIER() calls have as the argument", name); - } + "All labels and classifiers inside the call to '%s' must match", name); } - // For aggregate functions: return the label for the current argument to check if all arguments apply to the same set of rows. - if (!inputColumnLabels.isEmpty()) { - return ArgumentLabel.explicitLabel(getOnlyElement(inputColumnLabels)); - } - if (!classifierLabels.isEmpty()) { - return getOnlyElement(classifierLabels) - .map(ArgumentLabel::explicitLabel) - .orElse(ArgumentLabel.universalLabel()); - } - if (!unlabeledInputColumns.isEmpty()) { - return ArgumentLabel.universalLabel(); - } - return ArgumentLabel.noLabel(); + Optional label = Iterables.getOnlyElement(allLabels); + return label.map(ArgumentLabel::explicitLabel) + .orElseGet(ArgumentLabel::universalLabel); } private boolean isPatternNavigationFunction(FunctionCall node) @@ -1951,28 +2073,42 @@ private boolean isClassifierFunction(FunctionCall node) return node.getName().getSuffix().toUpperCase(ENGLISH).equals("CLASSIFIER"); } + private boolean isMatchNumberFunction(FunctionCall node) + { + if (!isPatternRecognitionFunction(node)) { + return false; + } + return node.getName().getSuffix().toUpperCase(ENGLISH).equals("MATCH_NUMBER"); + } + private String label(Identifier identifier) { return identifier.getCanonicalValue(); } - private void analyzePatternAggregation(FunctionCall node) + private void analyzePatternAggregation(FunctionCall node, ResolvedFunction function) { - if (node.getWindow().isPresent()) { - throw semanticException(NESTED_WINDOW, node, "Cannot use OVER with %s aggregate function in pattern recognition context", node.getName()); - } - if (node.getFilter().isPresent()) { - throw semanticException(NOT_SUPPORTED, node, "Cannot use FILTER with %s aggregate function in pattern recognition context", node.getName()); - } - if (node.getOrderBy().isPresent()) { - throw semanticException(NOT_SUPPORTED, node, "Cannot use ORDER BY with %s aggregate function in pattern recognition context", node.getName()); - } - if (node.isDistinct()) { - throw semanticException(NOT_SUPPORTED, node, "Cannot use DISTINCT with %s aggregate function in pattern recognition context", node.getName()); - } - checkNoNestedAggregations(node); checkNoNestedNavigations(node); + Set labels = analyzeAggregationLabels(node); + + List matchNumberCalls = extractExpressions(node.getArguments(), FunctionCall.class).stream() + .filter(this::isMatchNumberFunction) + .collect(toImmutableList()); + + List classifierCalls = extractExpressions(node.getArguments(), FunctionCall.class).stream() + .filter(this::isClassifierFunction) + .collect(toImmutableList()); + + patternRecognitionInputs.add(new PatternInputAnalysis( + node, + new AggregationDescriptor( + function, + node.getArguments(), + mapProcessingMode(node.getProcessingMode()), + labels, + matchNumberCalls, + classifierCalls))); } private void checkNoNestedAggregations(FunctionCall node) @@ -2006,7 +2142,7 @@ private void checkNoNestedNavigations(FunctionCall node) } @Override - protected Type visitAtTimeZone(AtTimeZone node, StackableAstVisitorContext context) + protected Type visitAtTimeZone(AtTimeZone node, Context context) { Type valueType = process(node.getValue(), context); process(node.getTimeZone(), context); @@ -2025,31 +2161,31 @@ else if (valueType instanceof TimestampType) { } @Override - protected Type visitCurrentCatalog(CurrentCatalog node, StackableAstVisitorContext context) + protected Type visitCurrentCatalog(CurrentCatalog node, Context context) { return setExpressionType(node, VARCHAR); } @Override - protected Type visitCurrentSchema(CurrentSchema node, StackableAstVisitorContext context) + protected Type visitCurrentSchema(CurrentSchema node, Context context) { return setExpressionType(node, VARCHAR); } @Override - protected Type visitCurrentUser(CurrentUser node, StackableAstVisitorContext context) + protected Type visitCurrentUser(CurrentUser node, Context context) { return setExpressionType(node, VARCHAR); } @Override - protected Type visitCurrentPath(CurrentPath node, StackableAstVisitorContext context) + protected Type visitCurrentPath(CurrentPath node, Context context) { return setExpressionType(node, VARCHAR); } @Override - protected Type visitTrim(Trim node, StackableAstVisitorContext context) + protected Type visitTrim(Trim node, Context context) { ImmutableList.Builder argumentTypes = ImmutableList.builder(); @@ -2063,8 +2199,8 @@ protected Type visitTrim(Trim node, StackableAstVisitorContext context) List expectedTypes = function.getSignature().getArgumentTypes(); checkState(expectedTypes.size() == actualTypes.size(), "wrong argument number in the resolved signature"); - Type actualTrimSourceType = actualTypes.get(0); - Type expectedTrimSourceType = expectedTypes.get(0); + Type actualTrimSourceType = actualTypes.getFirst(); + Type expectedTrimSourceType = expectedTypes.getFirst(); coerceType(node.getTrimSource(), actualTrimSourceType, expectedTrimSourceType, "source argument of trim function"); if (node.getTrimCharacter().isPresent()) { @@ -2078,19 +2214,19 @@ protected Type visitTrim(Trim node, StackableAstVisitorContext context) } @Override - protected Type visitFormat(Format node, StackableAstVisitorContext context) + protected Type visitFormat(Format node, Context context) { List arguments = node.getArguments().stream() .map(expression -> process(expression, context)) .collect(toImmutableList()); - if (!(arguments.get(0) instanceof VarcharType)) { - throw semanticException(TYPE_MISMATCH, node.getArguments().get(0), "Type of first argument to format() must be VARCHAR (actual: %s)", arguments.get(0)); + if (!(arguments.getFirst() instanceof VarcharType)) { + throw semanticException(TYPE_MISMATCH, node.getArguments().getFirst(), "Type of first argument to format() must be VARCHAR (actual: %s)", arguments.getFirst()); } for (int i = 1; i < arguments.size(); i++) { try { - plannerContext.getMetadata().resolveBuiltinFunction(FormatFunction.NAME, fromTypes(arguments.get(0), RowType.anonymous(arguments.subList(1, arguments.size())))); + plannerContext.getMetadata().resolveBuiltinFunction(FormatFunction.NAME, fromTypes(arguments.getFirst(), RowType.anonymous(arguments.subList(1, arguments.size())))); } catch (TrinoException e) { ErrorCode errorCode = e.getErrorCode(); @@ -2105,12 +2241,12 @@ protected Type visitFormat(Format node, StackableAstVisitorContext cont } @Override - protected Type visitParameter(Parameter node, StackableAstVisitorContext context) + protected Type visitParameter(Parameter node, Context context) { if (isDescribe) { return setExpressionType(node, UNKNOWN); } - if (parameters.size() == 0) { + if (parameters.isEmpty()) { throw semanticException(INVALID_PARAMETER_USAGE, node, "Query takes no parameters"); } if (node.getId() >= parameters.size()) { @@ -2126,7 +2262,7 @@ protected Type visitParameter(Parameter node, StackableAstVisitorContext context) + protected Type visitExtract(Extract node, Context context) { Type type = process(node.getExpression(), context); Extract.Field field = node.getField(); @@ -2188,7 +2324,7 @@ private boolean isDateTimeType(Type type) } @Override - protected Type visitBetweenPredicate(BetweenPredicate node, StackableAstVisitorContext context) + protected Type visitBetweenPredicate(BetweenPredicate node, Context context) { Type valueType = process(node.getValue(), context); Type minType = process(node.getMin(), context); @@ -2219,10 +2355,10 @@ protected Type visitBetweenPredicate(BetweenPredicate node, StackableAstVisitorC } @Override - public Type visitTryExpression(TryExpression node, StackableAstVisitorContext context) + public Type visitTryExpression(TryExpression node, Context context) { // TRY is rewritten to lambda, and lambda is not supported in pattern recognition - if (context.getContext().isPatternRecognition()) { + if (context.isPatternRecognition()) { throw semanticException(NOT_SUPPORTED, node, "TRY expression in pattern recognition context is not yet supported"); } @@ -2231,7 +2367,7 @@ public Type visitTryExpression(TryExpression node, StackableAstVisitorContext context) + public Type visitCast(Cast node, Context context) { Type type; try { @@ -2259,7 +2395,7 @@ public Type visitCast(Cast node, StackableAstVisitorContext context) } @Override - protected Type visitInPredicate(InPredicate node, StackableAstVisitorContext context) + protected Type visitInPredicate(InPredicate node, Context context) { Expression value = node.getValue(); Expression valueList = node.getValueList(); @@ -2275,7 +2411,7 @@ protected Type visitInPredicate(InPredicate node, StackableAstVisitorContext { QualifiedName qualifiedName = DereferenceExpression.getQualifiedName(dereference); if (qualifiedName != null) { - String label = label(qualifiedName.getOriginalParts().get(0)); - if (context.getContext().getLabels().contains(label)) { + String label = label(qualifiedName.getOriginalParts().getFirst()); + if (context.getPatternRecognitionContext().labels().contains(label)) { throw semanticException(NOT_SUPPORTED, dereference, "IN-PREDICATE with labeled column reference is not yet supported"); } } }); + + patternRecognitionInputs.add(new PatternInputAnalysis( + node, + new ScalarInputDescriptor(Optional.empty(), context.getPatternRecognitionContext().navigation()))); } if (valueList instanceof InListExpression inListExpression) { @@ -2312,24 +2452,31 @@ else if (valueList instanceof SubqueryExpression) { } @Override - protected Type visitSubqueryExpression(SubqueryExpression node, StackableAstVisitorContext context) + protected Type visitSubqueryExpression(SubqueryExpression node, Context context) { Type type = analyzeSubquery(node, context); // the implied type of a scalar subquery is that of the unique field in the single-column row if (type instanceof RowType && ((RowType) type).getFields().size() == 1) { - type = type.getTypeParameters().get(0); + type = type.getTypeParameters().getFirst(); } setExpressionType(node, type); subqueries.add(NodeRef.of(node)); + + if (context.isPatternRecognition()) { + patternRecognitionInputs.add(new PatternInputAnalysis( + node, + new ScalarInputDescriptor(Optional.empty(), context.getPatternRecognitionContext().navigation()))); + } + return type; } /** * @return the common supertype between the value type and subquery type */ - private Type analyzePredicateWithSubquery(Expression node, Type declaredValueType, SubqueryExpression subquery, StackableAstVisitorContext context) + private Type analyzePredicateWithSubquery(Expression node, Type declaredValueType, SubqueryExpression subquery, Context context) { Type valueRowType = declaredValueType; if (!(declaredValueType instanceof RowType) && !(declaredValueType instanceof UnknownType)) { @@ -2360,14 +2507,14 @@ private Type analyzePredicateWithSubquery(Expression node, Type declaredValueTyp return commonType.get(); } - private Type analyzeSubquery(SubqueryExpression node, StackableAstVisitorContext context) + private Type analyzeSubquery(SubqueryExpression node, Context context) { - if (context.getContext().isInLambda()) { + if (context.isInLambda()) { throw semanticException(NOT_SUPPORTED, node, "Lambda expression cannot contain subqueries"); } - StatementAnalyzer analyzer = statementAnalyzerFactory.apply(node, context.getContext().getCorrelationSupport()); + StatementAnalyzer analyzer = statementAnalyzerFactory.apply(node, context.getCorrelationSupport()); Scope subqueryScope = Scope.builder() - .withParent(context.getContext().getScope()) + .withParent(context.getScope()) .build(); Scope queryScope = analyzer.analyze(node.getQuery(), subqueryScope); @@ -2389,11 +2536,11 @@ private Type analyzeSubquery(SubqueryExpression node, StackableAstVisitorContext } @Override - protected Type visitExists(ExistsPredicate node, StackableAstVisitorContext context) + protected Type visitExists(ExistsPredicate node, Context context) { - StatementAnalyzer analyzer = statementAnalyzerFactory.apply(node, context.getContext().getCorrelationSupport()); + StatementAnalyzer analyzer = statementAnalyzerFactory.apply(node, context.getCorrelationSupport()); Scope subqueryScope = Scope.builder() - .withParent(context.getContext().getScope()) + .withParent(context.getScope()) .build(); List fields = analyzer.analyze(node.getSubquery(), subqueryScope) @@ -2413,11 +2560,17 @@ protected Type visitExists(ExistsPredicate node, StackableAstVisitorContext context) + protected Type visitQuantifiedComparisonExpression(QuantifiedComparisonExpression node, Context context) { quantifiedComparisons.add(NodeRef.of(node)); @@ -2443,25 +2596,25 @@ protected Type visitQuantifiedComparisonExpression(QuantifiedComparisonExpressio } @Override - public Type visitFieldReference(FieldReference node, StackableAstVisitorContext context) + public Type visitFieldReference(FieldReference node, Context context) { ResolvedField field = baseScope.getField(node.getFieldIndex()); return handleResolvedField(node, field, context); } @Override - protected Type visitLambdaExpression(LambdaExpression node, StackableAstVisitorContext context) + protected Type visitLambdaExpression(LambdaExpression node, Context context) { - if (context.getContext().isPatternRecognition()) { + if (context.isPatternRecognition()) { throw semanticException(NOT_SUPPORTED, node, "Lambda expression in pattern recognition context is not yet supported"); } verifyNoAggregateWindowOrGroupingFunctions(session, functionResolver, accessControl, node.getBody(), "Lambda expression"); - if (!context.getContext().isExpectingLambda()) { + if (!context.isExpectingLambda()) { throw semanticException(TYPE_MISMATCH, node, "Lambda expression should always be used inside a function"); } - List types = context.getContext().getFunctionInputTypes(); + List types = context.getFunctionInputTypes(); List lambdaArguments = node.getArguments(); if (types.size() != lambdaArguments.size()) { @@ -2473,43 +2626,43 @@ protected Type visitLambdaExpression(LambdaExpression node, StackableAstVisitorC for (int i = 0; i < lambdaArguments.size(); i++) { LambdaArgumentDeclaration lambdaArgument = lambdaArguments.get(i); Type type = types.get(i); - fields.add(io.trino.sql.analyzer.Field.newUnqualified(lambdaArgument.getName().getValue(), type)); + fields.add(Field.newUnqualified(lambdaArgument.getName().getValue(), type)); setExpressionType(lambdaArgument, type); } Scope lambdaScope = Scope.builder() - .withParent(context.getContext().getScope()) + .withParent(context.getScope()) .withRelationType(RelationId.of(node), new RelationType(fields.build())) .build(); ImmutableMap.Builder fieldToLambdaArgumentDeclaration = ImmutableMap.builder(); - if (context.getContext().isInLambda()) { - fieldToLambdaArgumentDeclaration.putAll(context.getContext().getFieldToLambdaArgumentDeclaration()); + if (context.isInLambda()) { + fieldToLambdaArgumentDeclaration.putAll(context.getFieldToLambdaArgumentDeclaration()); } for (LambdaArgumentDeclaration lambdaArgument : lambdaArguments) { ResolvedField resolvedField = lambdaScope.resolveField(lambdaArgument, QualifiedName.of(lambdaArgument.getName().getValue())); fieldToLambdaArgumentDeclaration.put(FieldId.from(resolvedField), lambdaArgument); } - Type returnType = process(node.getBody(), new StackableAstVisitorContext<>(context.getContext().inLambda(lambdaScope, fieldToLambdaArgumentDeclaration.buildOrThrow()))); + Type returnType = process(node.getBody(), context.inLambda(lambdaScope, fieldToLambdaArgumentDeclaration.buildOrThrow())); FunctionType functionType = new FunctionType(types, returnType); return setExpressionType(node, functionType); } @Override - protected Type visitBindExpression(BindExpression node, StackableAstVisitorContext context) + protected Type visitBindExpression(BindExpression node, Context context) { - verify(context.getContext().isExpectingLambda(), "bind expression found when lambda is not expected"); + verify(context.isExpectingLambda(), "bind expression found when lambda is not expected"); - StackableAstVisitorContext innerContext = new StackableAstVisitorContext<>(context.getContext().notExpectingLambda()); + Context innerContext = context.notExpectingLambda(); ImmutableList.Builder functionInputTypesBuilder = ImmutableList.builder(); for (Expression value : node.getValues()) { functionInputTypesBuilder.add(process(value, innerContext)); } - functionInputTypesBuilder.addAll(context.getContext().getFunctionInputTypes()); + functionInputTypesBuilder.addAll(context.getFunctionInputTypes()); List functionInputTypes = functionInputTypesBuilder.build(); - FunctionType functionType = (FunctionType) process(node.getFunction(), new StackableAstVisitorContext<>(context.getContext().expectingLambda(functionInputTypes))); + FunctionType functionType = (FunctionType) process(node.getFunction(), context.expectingLambda(functionInputTypes)); List argumentTypes = functionType.getArgumentTypes(); int numCapturedValues = node.getValues().size(); @@ -2523,19 +2676,19 @@ protected Type visitBindExpression(BindExpression node, StackableAstVisitorConte } @Override - protected Type visitExpression(Expression node, StackableAstVisitorContext context) + protected Type visitExpression(Expression node, Context context) { throw semanticException(NOT_SUPPORTED, node, "not yet implemented: %s", node.getClass().getName()); } @Override - protected Type visitNode(Node node, StackableAstVisitorContext context) + protected Type visitNode(Node node, Context context) { throw semanticException(NOT_SUPPORTED, node, "not yet implemented: %s", node.getClass().getName()); } @Override - public Type visitGroupingOperation(GroupingOperation node, StackableAstVisitorContext context) + public Type visitGroupingOperation(GroupingOperation node, Context context) { if (node.getGroupingColumns().size() > MAX_NUMBER_GROUPING_ARGUMENTS_BIGINT) { throw semanticException(TOO_MANY_ARGUMENTS, node, "GROUPING supports up to %d column arguments", MAX_NUMBER_GROUPING_ARGUMENTS_BIGINT); @@ -2552,7 +2705,7 @@ public Type visitGroupingOperation(GroupingOperation node, StackableAstVisitorCo } @Override - public Type visitJsonExists(JsonExists node, StackableAstVisitorContext context) + public Type visitJsonExists(JsonExists node, Context context) { List pathInvocationArgumentTypes = analyzeJsonPathInvocation("JSON_EXISTS", node, node.getJsonPathInvocation(), context); @@ -2580,7 +2733,7 @@ public Type visitJsonExists(JsonExists node, StackableAstVisitorContext } @Override - public Type visitJsonValue(JsonValue node, StackableAstVisitorContext context) + public Type visitJsonValue(JsonValue node, Context context) { List pathInvocationArgumentTypes = analyzeJsonPathInvocation("JSON_VALUE", node, node.getJsonPathInvocation(), context); Type returnedType = analyzeJsonValueExpression( @@ -2605,7 +2758,7 @@ private Type analyzeJsonValueExpression( Optional declaredEmptyDefault, Optional errorBehavior, Optional declaredErrorDefault, - StackableAstVisitorContext context) + Context context) { // validate returned type Type returnedType = VARCHAR; // default @@ -2687,7 +2840,7 @@ private Type analyzeJsonValueExpression( } @Override - public Type visitJsonQuery(JsonQuery node, StackableAstVisitorContext context) + public Type visitJsonQuery(JsonQuery node, Context context) { List pathInvocationArgumentTypes = analyzeJsonPathInvocation("JSON_QUERY", node, node.getJsonPathInvocation(), context); Type returnedType = analyzeJsonQueryExpression( @@ -2765,7 +2918,7 @@ private Type analyzeJsonQueryExpression( return returnedType; } - private List analyzeJsonPathInvocation(String functionName, Node node, JsonPathInvocation jsonPathInvocation, StackableAstVisitorContext context) + private List analyzeJsonPathInvocation(String functionName, Node node, JsonPathInvocation jsonPathInvocation, Context context) { jsonPathInvocation.getPathName().ifPresent(pathName -> { if (!(node instanceof JsonTable)) { @@ -2946,7 +3099,7 @@ private ResolvedFunction getOutputFunction(Type type, JsonFormat format, Node no } @Override - protected Type visitJsonObject(JsonObject node, StackableAstVisitorContext context) + protected Type visitJsonObject(JsonObject node, Context context) { // TODO verify parameter count? Is there a limit on Row size? @@ -3072,7 +3225,7 @@ protected Type visitJsonObject(JsonObject node, StackableAstVisitorContext context) + protected Type visitJsonArray(JsonArray node, Context context) { // TODO verify parameter count? Is there a limit on Row size? @@ -3181,7 +3334,7 @@ protected Type visitJsonArray(JsonArray node, StackableAstVisitorContext context, Expression node, OperatorType operatorType, Expression... arguments) + private Type getOperator(Context context, Expression node, OperatorType operatorType, Expression... arguments) { ImmutableList.Builder argumentTypes = ImmutableList.builder(); for (Expression expression : arguments) { @@ -3216,13 +3369,13 @@ private void coerceType(Expression expression, Type actualType, Type expectedTyp } } - private void coerceType(StackableAstVisitorContext context, Expression expression, Type expectedType, String message) + private void coerceType(Context context, Expression expression, Type expectedType, String message) { Type actualType = process(expression, context); coerceType(expression, actualType, expectedType, message); } - private Type coerceToSingleType(StackableAstVisitorContext context, Node node, String message, Expression first, Expression second) + private Type coerceToSingleType(Context context, Node node, String message, Expression first, Expression second) { Type firstType = UNKNOWN; if (first != null) { @@ -3251,7 +3404,7 @@ private Type coerceToSingleType(StackableAstVisitorContext context, Nod throw semanticException(TYPE_MISMATCH, node, "%s: %s vs %s", message, firstType, secondType); } - private Type coerceToSingleType(StackableAstVisitorContext context, String description, List expressions) + private Type coerceToSingleType(Context context, String description, List expressions) { // determine super type Type superType = UNKNOWN; @@ -3323,59 +3476,73 @@ private static class Context // Empty map means that the all lambda expressions surrounding the current node has no arguments. private final Map fieldToLambdaArgumentDeclaration; - // Primary row pattern variables and named unions (subsets) of variables - // necessary for the analysis of expressions in the context of row pattern recognition - private final Set labels; + private final Optional patternRecognitionContext; private final CorrelationSupport correlationSupport; + private final boolean inWindow; + private Context( Scope scope, List functionInputTypes, Map fieldToLambdaArgumentDeclaration, - Set labels, - CorrelationSupport correlationSupport) + Optional patternRecognitionContext, + CorrelationSupport correlationSupport, + boolean inWindow) { this.scope = requireNonNull(scope, "scope is null"); this.functionInputTypes = functionInputTypes; this.fieldToLambdaArgumentDeclaration = fieldToLambdaArgumentDeclaration; - this.labels = labels; + this.patternRecognitionContext = requireNonNull(patternRecognitionContext, "patternRecognitionContext is null"); this.correlationSupport = requireNonNull(correlationSupport, "correlationSupport is null"); + this.inWindow = inWindow; } public static Context notInLambda(Scope scope, CorrelationSupport correlationSupport) { - return new Context(scope, null, null, null, correlationSupport); + return new Context(scope, null, null, Optional.empty(), correlationSupport, false); } public Context inLambda(Scope scope, Map fieldToLambdaArgumentDeclaration) { - return new Context(scope, null, requireNonNull(fieldToLambdaArgumentDeclaration, "fieldToLambdaArgumentDeclaration is null"), labels, correlationSupport); + return new Context(scope, null, requireNonNull(fieldToLambdaArgumentDeclaration, "fieldToLambdaArgumentDeclaration is null"), patternRecognitionContext, correlationSupport, inWindow); } - public Context expectingLambda(List functionInputTypes) + public static Context inWindow(Scope scope, CorrelationSupport correlationSupport) { - return new Context(scope, requireNonNull(functionInputTypes, "functionInputTypes is null"), this.fieldToLambdaArgumentDeclaration, labels, correlationSupport); + return new Context(scope, null, null, Optional.empty(), correlationSupport, true); } - public Context notExpectingLambda() + public Context inWindow() + { + return new Context(scope, functionInputTypes, fieldToLambdaArgumentDeclaration, patternRecognitionContext, correlationSupport, true); + } + + public Context expectingLambda(List functionInputTypes) { - return new Context(scope, null, this.fieldToLambdaArgumentDeclaration, labels, correlationSupport); + return new Context(scope, requireNonNull(functionInputTypes, "functionInputTypes is null"), this.fieldToLambdaArgumentDeclaration, patternRecognitionContext, correlationSupport, inWindow); } - public static Context patternRecognition(Scope scope, Set labels) + public Context notExpectingLambda() { - return new Context(scope, null, null, requireNonNull(labels, "labels is null"), CorrelationSupport.DISALLOWED); + return new Context(scope, null, this.fieldToLambdaArgumentDeclaration, patternRecognitionContext, correlationSupport, inWindow); } - public Context patternRecognition(Set labels) + public static Context patternRecognition(Scope scope, Set labels, boolean inWindow) { - return new Context(scope, functionInputTypes, fieldToLambdaArgumentDeclaration, requireNonNull(labels, "labels is null"), CorrelationSupport.DISALLOWED); + return new Context(scope, null, null, Optional.of(new PatternRecognitionContext(labels, Navigation.DEFAULT)), CorrelationSupport.DISALLOWED, inWindow); } - public Context notExpectingLabels() + public Context withNavigation(Navigation navigation) { - return new Context(scope, functionInputTypes, fieldToLambdaArgumentDeclaration, null, correlationSupport); + PatternRecognitionContext patternRecognitionContext = new PatternRecognitionContext(this.patternRecognitionContext.get().labels, navigation); + return new Context( + scope, + functionInputTypes, + fieldToLambdaArgumentDeclaration, + Optional.of(patternRecognitionContext), + correlationSupport, + inWindow); } Scope getScope() @@ -3388,6 +3555,11 @@ public boolean isInLambda() return fieldToLambdaArgumentDeclaration != null; } + public boolean isInWindow() + { + return inWindow; + } + public boolean isExpectingLambda() { return functionInputTypes != null; @@ -3395,7 +3567,7 @@ public boolean isExpectingLambda() public boolean isPatternRecognition() { - return labels != null; + return patternRecognitionContext.isPresent(); } public Map getFieldToLambdaArgumentDeclaration() @@ -3410,16 +3582,17 @@ public List getFunctionInputTypes() return functionInputTypes; } - public Set getLabels() + public PatternRecognitionContext getPatternRecognitionContext() { - checkState(isPatternRecognition()); - return labels; + return patternRecognitionContext.get(); } public CorrelationSupport getCorrelationSupport() { return correlationSupport; } + + record PatternRecognitionContext(Set labels, Navigation navigation) {} } public static boolean isPatternRecognitionFunction(FunctionCall node) @@ -3428,7 +3601,7 @@ public static boolean isPatternRecognitionFunction(FunctionCall node) if (qualifiedName.getParts().size() > 1) { return false; } - Identifier identifier = qualifiedName.getOriginalParts().get(0); + Identifier identifier = qualifiedName.getOriginalParts().getFirst(); if (identifier.isDelimited()) { return false; } @@ -3453,7 +3626,7 @@ public static ExpressionAnalysis analyzePatternRecognitionExpression( Set labels) { ExpressionAnalyzer analyzer = new ExpressionAnalyzer(plannerContext, accessControl, statementAnalyzerFactory, analysis, session, TypeProvider.empty(), warningCollector); - analyzer.analyze(expression, scope, labels); + analyzer.analyze(expression, scope, labels, false); updateAnalysis(analysis, analyzer, session, accessControl); @@ -3670,12 +3843,14 @@ private static void updateAnalysis(Analysis analysis, ExpressionAnalyzer analyze analysis.addColumnReferences(analyzer.getColumnReferences()); analysis.addLambdaArgumentReferences(analyzer.getLambdaArgumentReferences()); analysis.addTableColumnReferences(accessControl, session.getIdentity(), analyzer.getTableColumnReferences()); - analysis.addLabelDereferences(analyzer.getLabelDereferences()); - analysis.addPatternRecognitionFunctions(analyzer.getPatternRecognitionFunctions()); + analysis.addLabels(analyzer.getLabels()); + analysis.addPatternRecognitionInputs(analyzer.getPatternRecognitionInputs()); + analysis.addPatternNavigationFunctions(analyzer.getPatternNavigationFunctions()); analysis.setRanges(analyzer.getRanges()); analysis.setUndefinedLabels(analyzer.getUndefinedLabels()); + analysis.addResolvedLabels(analyzer.getResolvedLabels()); + analysis.addSubsetLabels(analyzer.getSubsetLabels()); analysis.setMeasureDefinitions(analyzer.getMeasureDefinitions()); - analysis.setPatternAggregations(analyzer.getPatternAggregations()); analysis.setJsonPathAnalyses(analyzer.getJsonPathAnalyses()); analysis.setJsonInputFunctions(analyzer.getJsonInputFunctions()); analysis.setJsonOutputFunctions(analyzer.getJsonOutputFunctions()); @@ -3799,38 +3974,6 @@ public static boolean isCharacterStringType(Type type) return type instanceof VarcharType || type instanceof CharType; } - public static class LabelPrefixedReference - { - private final String label; - private final Optional column; - - public LabelPrefixedReference(String label, Identifier column) - { - this(label, Optional.of(requireNonNull(column, "column is null"))); - } - - public LabelPrefixedReference(String label) - { - this(label, Optional.empty()); - } - - private LabelPrefixedReference(String label, Optional column) - { - this.label = requireNonNull(label, "label is null"); - this.column = requireNonNull(column, "column is null"); - } - - public String getLabel() - { - return label; - } - - public Optional getColumn() - { - return column; - } - } - private static class ArgumentLabel { private final boolean hasLabel; diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionInterpreter.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionInterpreter.java deleted file mode 100644 index 0b43271ccc7348..00000000000000 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionInterpreter.java +++ /dev/null @@ -1,1241 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.sql.analyzer; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.common.primitives.Primitives; -import io.airlift.slice.Slice; -import io.airlift.slice.Slices; -import io.trino.Session; -import io.trino.execution.warnings.WarningCollector; -import io.trino.metadata.Metadata; -import io.trino.metadata.ResolvedFunction; -import io.trino.operator.scalar.ArraySubscriptOperator; -import io.trino.operator.scalar.FormatFunction; -import io.trino.security.AccessControl; -import io.trino.spi.TrinoException; -import io.trino.spi.block.BlockBuilder; -import io.trino.spi.block.SqlRow; -import io.trino.spi.connector.ConnectorSession; -import io.trino.spi.function.FunctionNullability; -import io.trino.spi.function.InvocationConvention; -import io.trino.spi.function.OperatorType; -import io.trino.spi.type.ArrayType; -import io.trino.spi.type.CharType; -import io.trino.spi.type.RowType; -import io.trino.spi.type.RowType.Field; -import io.trino.spi.type.TimeType; -import io.trino.spi.type.TimeWithTimeZoneType; -import io.trino.spi.type.TimestampType; -import io.trino.spi.type.TimestampWithTimeZoneType; -import io.trino.spi.type.Type; -import io.trino.spi.type.VarcharType; -import io.trino.sql.InterpretedFunctionInvoker; -import io.trino.sql.PlannerContext; -import io.trino.sql.planner.BuiltinFunctionCallBuilder; -import io.trino.sql.planner.Coercer; -import io.trino.sql.planner.LiteralInterpreter; -import io.trino.sql.planner.Symbol; -import io.trino.sql.planner.SymbolResolver; -import io.trino.sql.planner.TranslationMap; -import io.trino.sql.tree.ArithmeticBinaryExpression; -import io.trino.sql.tree.ArithmeticUnaryExpression; -import io.trino.sql.tree.Array; -import io.trino.sql.tree.AstVisitor; -import io.trino.sql.tree.AtTimeZone; -import io.trino.sql.tree.BetweenPredicate; -import io.trino.sql.tree.BindExpression; -import io.trino.sql.tree.BooleanLiteral; -import io.trino.sql.tree.Cast; -import io.trino.sql.tree.CoalesceExpression; -import io.trino.sql.tree.ComparisonExpression; -import io.trino.sql.tree.ComparisonExpression.Operator; -import io.trino.sql.tree.CurrentCatalog; -import io.trino.sql.tree.CurrentDate; -import io.trino.sql.tree.CurrentPath; -import io.trino.sql.tree.CurrentSchema; -import io.trino.sql.tree.CurrentTime; -import io.trino.sql.tree.CurrentTimestamp; -import io.trino.sql.tree.CurrentUser; -import io.trino.sql.tree.DereferenceExpression; -import io.trino.sql.tree.ExistsPredicate; -import io.trino.sql.tree.Expression; -import io.trino.sql.tree.Extract; -import io.trino.sql.tree.FieldReference; -import io.trino.sql.tree.Format; -import io.trino.sql.tree.FunctionCall; -import io.trino.sql.tree.Identifier; -import io.trino.sql.tree.IfExpression; -import io.trino.sql.tree.InListExpression; -import io.trino.sql.tree.InPredicate; -import io.trino.sql.tree.IsNotNullPredicate; -import io.trino.sql.tree.IsNullPredicate; -import io.trino.sql.tree.LambdaArgumentDeclaration; -import io.trino.sql.tree.LambdaExpression; -import io.trino.sql.tree.LikePredicate; -import io.trino.sql.tree.Literal; -import io.trino.sql.tree.LocalTime; -import io.trino.sql.tree.LocalTimestamp; -import io.trino.sql.tree.LogicalExpression; -import io.trino.sql.tree.Node; -import io.trino.sql.tree.NodeRef; -import io.trino.sql.tree.NotExpression; -import io.trino.sql.tree.NullIfExpression; -import io.trino.sql.tree.NullLiteral; -import io.trino.sql.tree.Parameter; -import io.trino.sql.tree.QuantifiedComparisonExpression; -import io.trino.sql.tree.Row; -import io.trino.sql.tree.SearchedCaseExpression; -import io.trino.sql.tree.SimpleCaseExpression; -import io.trino.sql.tree.StringLiteral; -import io.trino.sql.tree.SubqueryExpression; -import io.trino.sql.tree.SubscriptExpression; -import io.trino.sql.tree.SymbolReference; -import io.trino.sql.tree.WhenClause; -import io.trino.type.FunctionType; -import io.trino.type.LikeFunctions; -import io.trino.type.LikePattern; -import io.trino.type.TypeCoercion; -import io.trino.util.FastutilSetHelper; - -import java.lang.invoke.MethodHandle; -import java.lang.invoke.MethodHandles; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.IdentityHashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; -import java.util.Set; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkState; -import static com.google.common.base.Throwables.throwIfInstanceOf; -import static com.google.common.base.Verify.verify; -import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.trino.spi.StandardErrorCode.EXPRESSION_NOT_CONSTANT; -import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; -import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; -import static io.trino.spi.StandardErrorCode.TYPE_MISMATCH; -import static io.trino.spi.block.RowValueBuilder.buildRowValue; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; -import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; -import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; -import static io.trino.spi.function.InvocationConvention.simpleConvention; -import static io.trino.spi.function.OperatorType.EQUAL; -import static io.trino.spi.function.OperatorType.HASH_CODE; -import static io.trino.spi.type.RowType.anonymous; -import static io.trino.spi.type.TimeWithTimeZoneType.createTimeWithTimeZoneType; -import static io.trino.spi.type.TimestampWithTimeZoneType.createTimestampWithTimeZoneType; -import static io.trino.spi.type.TypeUtils.readNativeValue; -import static io.trino.spi.type.TypeUtils.writeNativeValue; -import static io.trino.spi.type.VarcharType.VARCHAR; -import static io.trino.sql.analyzer.ExpressionAnalyzer.createConstantAnalyzer; -import static io.trino.sql.analyzer.SemanticExceptions.semanticException; -import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; -import static io.trino.sql.analyzer.TypeSignatureTranslator.toTypeSignature; -import static io.trino.sql.gen.VarArgsToMapAdapterGenerator.generateVarArgsToMapAdapter; -import static io.trino.sql.planner.iterative.rule.CanonicalizeExpressionRewriter.canonicalizeExpression; -import static io.trino.sql.tree.DereferenceExpression.isQualifiedAllFieldsReference; -import static java.lang.Math.toIntExact; -import static java.util.Collections.singletonList; -import static java.util.Objects.requireNonNull; - -public class ExpressionInterpreter -{ - private final Expression expression; - private final PlannerContext plannerContext; - private final Metadata metadata; - private final Map, ResolvedFunction> resolvedFunctions; - private final LiteralInterpreter literalInterpreter; - private final ConnectorSession connectorSession; - private final Map, Type> expressionTypes; - private final InterpretedFunctionInvoker functionInvoker; - private final TypeCoercion typeCoercion; - - // identity-based cache for LIKE expressions with constant pattern and escape char - private final IdentityHashMap likePatternCache = new IdentityHashMap<>(); - private final IdentityHashMap> inListCache = new IdentityHashMap<>(); - - public ExpressionInterpreter(Expression expression, PlannerContext plannerContext, Session session, Map, Type> expressionTypes, Map, ResolvedFunction> resolvedFunctions) - { - this.expression = requireNonNull(expression, "expression is null"); - this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); - this.metadata = plannerContext.getMetadata(); - this.resolvedFunctions = requireNonNull(resolvedFunctions, "resolvedFunctions is null"); - this.literalInterpreter = new LiteralInterpreter(plannerContext, session); - this.connectorSession = session.toConnectorSession(); - this.expressionTypes = ImmutableMap.copyOf(requireNonNull(expressionTypes, "expressionTypes is null")); - verify(expressionTypes.containsKey(NodeRef.of(expression))); - this.functionInvoker = new InterpretedFunctionInvoker(plannerContext.getFunctionManager()); - this.typeCoercion = new TypeCoercion(plannerContext.getTypeManager()::getType); - } - - public static Object evaluateConstantExpression( - Expression expression, - Type expectedType, - PlannerContext plannerContext, - Session session, - AccessControl accessControl, - Map, Expression> parameters) - { - Analysis analysis = new Analysis(null, ImmutableMap.of(), QueryType.OTHERS); - Scope scope = Scope.create(); - ExpressionAnalyzer.analyzeExpressionWithoutSubqueries( - session, - plannerContext, - accessControl, - scope, - analysis, - expression, - EXPRESSION_NOT_CONSTANT, - "Constant expression cannot contain a subquery", - WarningCollector.NOOP, - CorrelationSupport.DISALLOWED); - - // Apply casts, desugar expression, and preform other rewrites - TranslationMap translationMap = new TranslationMap(Optional.empty(), scope, analysis, ImmutableMap.of(), ImmutableList.of(), session, plannerContext); - expression = coerceIfNecessary(analysis, expression, translationMap.rewrite(expression)); - - // The expression tree has been rewritten which breaks all the identity maps, so redo the analysis - // to re-analyze coercions that might be necessary - ExpressionAnalyzer analyzer = createConstantAnalyzer(plannerContext, accessControl, session, parameters, WarningCollector.NOOP); - analyzer.analyze(expression, scope); - - Type actualType = analyzer.getExpressionTypes().get(NodeRef.of(expression)); - if (!new TypeCoercion(plannerContext.getTypeManager()::getType).canCoerce(actualType, expectedType)) { - throw semanticException(TYPE_MISMATCH, expression, "Cannot cast type %s to %s", actualType.getDisplayName(), expectedType.getDisplayName()); - } - - Map, Type> coercions = ImmutableMap., Type>builder() - .putAll(analyzer.getExpressionCoercions()) - .put(NodeRef.of(expression), expectedType) - .buildOrThrow(); - - // add coercions - Expression rewrite = Coercer.addCoercions(expression, coercions); - - // redo the analysis since above expression rewriter might create new expressions which do not have entries in the type map - analyzer = createConstantAnalyzer(plannerContext, accessControl, session, parameters, WarningCollector.NOOP); - analyzer.analyze(rewrite, Scope.create()); - - // The optimization above may have rewritten the expression tree which breaks all the identity maps, so redo the analysis - // to re-analyze coercions that might be necessary - analyzer = createConstantAnalyzer(plannerContext, accessControl, session, parameters, WarningCollector.NOOP); - analyzer.analyze(rewrite, Scope.create()); - - // expressionInterpreter/optimizer only understands a subset of expression types - // TODO: remove this when the new expression tree is implemented - Expression canonicalized = canonicalizeExpression(rewrite, analyzer.getExpressionTypes(), plannerContext, session); - - // The optimization above may have rewritten the expression tree which breaks all the identity maps, so redo the analysis - // to re-analyze coercions that might be necessary - analyzer = createConstantAnalyzer(plannerContext, accessControl, session, parameters, WarningCollector.NOOP); - analyzer.analyze(canonicalized, Scope.create()); - - // evaluate the expression - return new ExpressionInterpreter(canonicalized, plannerContext, session, analyzer.getExpressionTypes(), analyzer.getResolvedFunctions()).evaluate(); - } - - private static Expression coerceIfNecessary(Analysis analysis, Expression original, Expression rewritten) - { - Type coercion = analysis.getCoercion(original); - if (coercion == null) { - return rewritten; - } - - return new Cast(rewritten, toSqlType(coercion), false); - } - - private Object evaluate() - { - Object result = new Visitor().processWithExceptionHandling(expression, null); - verify(!(result instanceof Expression), "Expression interpreter returned an unresolved expression"); - return result; - } - - private class Visitor - extends AstVisitor - { - private Object processWithExceptionHandling(Expression expression, Object context) - { - if (expression == null) { - return null; - } - - Object result = process(expression, context); - verify(!(result instanceof Expression), "Expression interpreter returned an unresolved expression"); - return result; - } - - @Override - public Object visitFieldReference(FieldReference node, Object context) - { - throw new UnsupportedOperationException("Field references not supported in interpreter"); - } - - @Override - protected Object visitDereferenceExpression(DereferenceExpression node, Object context) - { - checkArgument(!isQualifiedAllFieldsReference(node), "unexpected expression: all fields labeled reference %s", node); - Identifier fieldIdentifier = node.getField().orElseThrow(); - - // Row dereference: process dereference base eagerly, and only then pick the expected field - Object base = processWithExceptionHandling(node.getBase(), context); - // if the base part is evaluated to be null, the dereference expression should also be null - if (base == null) { - return null; - } - - Type type = type(node.getBase()); - RowType rowType = (RowType) type; - SqlRow row = (SqlRow) base; - Type returnType = type(node); - String fieldName = fieldIdentifier.getValue(); - List fields = rowType.getFields(); - int index = -1; - for (int i = 0; i < fields.size(); i++) { - Field field = fields.get(i); - if (field.getName().isPresent() && field.getName().get().equalsIgnoreCase(fieldName)) { - checkArgument(index < 0, "Ambiguous field %s in type %s", field, rowType.getDisplayName()); - index = i; - } - } - - checkState(index >= 0, "could not find field name: %s", fieldName); - return readNativeValue(returnType, row.getRawFieldBlock(index), row.getRawIndex()); - } - - @Override - protected Object visitIdentifier(Identifier node, Object context) - { - // Identifier only exists before planning. - // ExpressionInterpreter should only be invoked after planning. - // As a result, this method should be unreachable. - // However, RelationPlanner.visitUnnest and visitValues invokes evaluateConstantExpression. - return ((SymbolResolver) context).getValue(new Symbol(node.getValue())); - } - - @Override - protected Object visitSymbolReference(SymbolReference node, Object context) - { - return ((SymbolResolver) context).getValue(Symbol.from(node)); - } - - @Override - protected Object visitLiteral(Literal node, Object context) - { - return literalInterpreter.evaluate(node, type(node)); - } - - @Override - protected Object visitIsNullPredicate(IsNullPredicate node, Object context) - { - Object value = processWithExceptionHandling(node.getValue(), context); - return value == null; - } - - @Override - protected Object visitIsNotNullPredicate(IsNotNullPredicate node, Object context) - { - Object value = processWithExceptionHandling(node.getValue(), context); - return value != null; - } - - @Override - protected Object visitSearchedCaseExpression(SearchedCaseExpression node, Object context) - { - Object newDefault = null; - boolean foundNewDefault = false; - - for (WhenClause whenClause : node.getWhenClauses()) { - Object whenOperand = processWithExceptionHandling(whenClause.getOperand(), context); - - if (Boolean.TRUE.equals(whenOperand)) { - // condition is true, use this as default - foundNewDefault = true; - newDefault = processWithExceptionHandling(whenClause.getResult(), context); - break; - } - } - - Object defaultResult; - if (foundNewDefault) { - defaultResult = newDefault; - } - else { - defaultResult = processWithExceptionHandling(node.getDefaultValue().orElse(null), context); - } - - return defaultResult; - } - - @Override - protected Object visitIfExpression(IfExpression node, Object context) - { - Object condition = processWithExceptionHandling(node.getCondition(), context); - - if (Boolean.TRUE.equals(condition)) { - return processWithExceptionHandling(node.getTrueValue(), context); - } - return processWithExceptionHandling(node.getFalseValue().orElse(null), context); - } - - @Override - protected Object visitSimpleCaseExpression(SimpleCaseExpression node, Object context) - { - Object operand = processWithExceptionHandling(node.getOperand(), context); - Type operandType = type(node.getOperand()); - - // if operand is null, return defaultValue - if (operand == null) { - return processWithExceptionHandling(node.getDefaultValue().orElse(null), context); - } - - Object newDefault = null; - boolean foundNewDefault = false; - - for (WhenClause whenClause : node.getWhenClauses()) { - Object whenOperand = processWithExceptionHandling(whenClause.getOperand(), context); - - if (whenOperand != null && isEqual(operand, operandType, whenOperand, type(whenClause.getOperand()))) { - // condition is true, use this as default - foundNewDefault = true; - newDefault = processWithExceptionHandling(whenClause.getResult(), context); - break; - } - } - - Object defaultResult; - if (foundNewDefault) { - defaultResult = newDefault; - } - else { - defaultResult = processWithExceptionHandling(node.getDefaultValue().orElse(null), context); - } - - return defaultResult; - } - - private boolean isEqual(Object operand1, Type type1, Object operand2, Type type2) - { - return Boolean.TRUE.equals(invokeOperator(OperatorType.EQUAL, ImmutableList.of(type1, type2), ImmutableList.of(operand1, operand2))); - } - - private Type type(Expression expression) - { - return expressionTypes.get(NodeRef.of(expression)); - } - - @Override - protected Object visitCoalesceExpression(CoalesceExpression node, Object context) - { - for (Expression operand : node.getOperands()) { - Object value = processWithExceptionHandling(operand, context); - if (value != null) { - return value; - } - } - - return null; - } - - @Override - protected Object visitInPredicate(InPredicate node, Object context) - { - Object value = processWithExceptionHandling(node.getValue(), context); - - Expression valueListExpression = node.getValueList(); - if (!(valueListExpression instanceof InListExpression valueList)) { - throw new UnsupportedOperationException("IN predicate value list type not yet implemented: " + valueListExpression.getClass().getName()); - } - // `NULL IN ()` would be false, but InListExpression cannot be empty by construction - if (value == null) { - return null; - } - - Set set = inListCache.get(valueList); - - // We use the presence of the node in the map to indicate that we've already done - // the analysis below. If the value is null, it means that we can't apply the HashSet - // optimization - if (!inListCache.containsKey(valueList)) { - if (valueList.getValues().stream().allMatch(Literal.class::isInstance) && - valueList.getValues().stream().noneMatch(NullLiteral.class::isInstance)) { - Set objectSet = valueList.getValues().stream().map(expression -> processWithExceptionHandling(expression, context)).collect(Collectors.toSet()); - Type type = type(node.getValue()); - set = FastutilSetHelper.toFastutilHashSet( - objectSet, - type, - plannerContext.getFunctionManager().getScalarFunctionImplementation(metadata.resolveOperator(HASH_CODE, ImmutableList.of(type)), simpleConvention(FAIL_ON_NULL, NEVER_NULL)).getMethodHandle(), - plannerContext.getFunctionManager().getScalarFunctionImplementation(metadata.resolveOperator(EQUAL, ImmutableList.of(type, type)), simpleConvention(NULLABLE_RETURN, NEVER_NULL, NEVER_NULL)).getMethodHandle()); - } - inListCache.put(valueList, set); - } - - if (set != null) { - return set.contains(value); - } - - boolean hasNullValue = false; - boolean found = false; - - ResolvedFunction equalsOperator = metadata.resolveOperator(OperatorType.EQUAL, types(node.getValue(), valueList)); - for (Expression expression : valueList.getValues()) { - // Use process() instead of processWithExceptionHandling() for processing in-list items. - // Do not handle exceptions thrown while processing a single in-list expression, - // but fail the whole in-predicate evaluation. - // According to in-predicate semantics, all in-list items must be successfully evaluated - // before a check for the match is performed. - Object inValue = process(expression, context); - if (inValue == null) { - hasNullValue = true; - } - else { - Boolean result = (Boolean) functionInvoker.invoke(equalsOperator, connectorSession, ImmutableList.of(value, inValue)); - if (result == null) { - hasNullValue = true; - } - else if (!found && result) { - // in does not short-circuit so we must evaluate all value in the list - found = true; - } - } - } - if (found) { - return true; - } - - if (hasNullValue) { - return null; - } - return false; - } - - @Override - protected Object visitExists(ExistsPredicate node, Object context) - { - throw new UnsupportedOperationException("Exists subquery not yet implemented"); - } - - @Override - protected Object visitSubqueryExpression(SubqueryExpression node, Object context) - { - throw new UnsupportedOperationException("Subquery not yet implemented"); - } - - @Override - protected Object visitArithmeticUnary(ArithmeticUnaryExpression node, Object context) - { - Object value = processWithExceptionHandling(node.getValue(), context); - if (value == null) { - return null; - } - - return switch (node.getSign()) { - case PLUS -> value; - case MINUS -> { - ResolvedFunction resolvedOperator = metadata.resolveOperator(OperatorType.NEGATION, types(node.getValue())); - InvocationConvention invocationConvention = new InvocationConvention(ImmutableList.of(NEVER_NULL), FAIL_ON_NULL, true, false); - MethodHandle handle = plannerContext.getFunctionManager().getScalarFunctionImplementation(resolvedOperator, invocationConvention).getMethodHandle(); - - if (handle.type().parameterCount() > 0 && handle.type().parameterType(0) == ConnectorSession.class) { - handle = handle.bindTo(connectorSession); - } - try { - yield handle.invokeWithArguments(value); - } - catch (Throwable throwable) { - throwIfInstanceOf(throwable, RuntimeException.class); - throwIfInstanceOf(throwable, Error.class); - throw new RuntimeException(throwable.getMessage(), throwable); - } - } - }; - } - - @Override - protected Object visitArithmeticBinary(ArithmeticBinaryExpression node, Object context) - { - Object left = processWithExceptionHandling(node.getLeft(), context); - if (left == null) { - return null; - } - Object right = processWithExceptionHandling(node.getRight(), context); - if (right == null) { - return null; - } - - return invokeOperator(OperatorType.valueOf(node.getOperator().name()), types(node.getLeft(), node.getRight()), ImmutableList.of(left, right)); - } - - @Override - protected Object visitComparisonExpression(ComparisonExpression node, Object context) - { - Operator operator = node.getOperator(); - Expression left = node.getLeft(); - Expression right = node.getRight(); - - if (operator == Operator.IS_DISTINCT_FROM) { - return processIsDistinctFrom(context, left, right); - } - // Execution engine does not have not equal and greater than operators, so interpret with - // equal or less than, but do not flip operator in result, as many optimizers depend on - // operators not flipping - if (node.getOperator() == Operator.NOT_EQUAL) { - Object result = visitComparisonExpression(flipComparison(node), context); - if (result == null) { - return null; - } - return !(Boolean) result; - } - if (node.getOperator() == Operator.GREATER_THAN || node.getOperator() == Operator.GREATER_THAN_OR_EQUAL) { - return visitComparisonExpression(flipComparison(node), context); - } - - return processComparisonExpression(context, operator, left, right); - } - - private Object processIsDistinctFrom(Object context, Expression leftExpression, Expression rightExpression) - { - Object left = processWithExceptionHandling(leftExpression, context); - Object right = processWithExceptionHandling(rightExpression, context); - - return invokeOperator(OperatorType.valueOf(Operator.IS_DISTINCT_FROM.name()), types(leftExpression, rightExpression), Arrays.asList(left, right)); - } - - private Object processComparisonExpression(Object context, Operator operator, Expression leftExpression, Expression rightExpression) - { - Object left = processWithExceptionHandling(leftExpression, context); - if (left == null) { - return null; - } - - Object right = processWithExceptionHandling(rightExpression, context); - if (right == null) { - return null; - } - - return invokeOperator(OperatorType.valueOf(operator.name()), types(leftExpression, rightExpression), ImmutableList.of(left, right)); - } - - // TODO define method contract or split into separate methods, as flip(EQUAL) is a negation, while flip(LESS_THAN) is just flipping sides - private ComparisonExpression flipComparison(ComparisonExpression comparisonExpression) - { - return switch (comparisonExpression.getOperator()) { - case EQUAL -> new ComparisonExpression(Operator.NOT_EQUAL, comparisonExpression.getLeft(), comparisonExpression.getRight()); - case NOT_EQUAL -> new ComparisonExpression(Operator.EQUAL, comparisonExpression.getLeft(), comparisonExpression.getRight()); - case LESS_THAN -> new ComparisonExpression(Operator.GREATER_THAN, comparisonExpression.getRight(), comparisonExpression.getLeft()); - case LESS_THAN_OR_EQUAL -> new ComparisonExpression(Operator.GREATER_THAN_OR_EQUAL, comparisonExpression.getRight(), comparisonExpression.getLeft()); - case GREATER_THAN -> new ComparisonExpression(Operator.LESS_THAN, comparisonExpression.getRight(), comparisonExpression.getLeft()); - case GREATER_THAN_OR_EQUAL -> new ComparisonExpression(Operator.LESS_THAN_OR_EQUAL, comparisonExpression.getRight(), comparisonExpression.getLeft()); - default -> throw new IllegalStateException("Unexpected value: " + comparisonExpression.getOperator()); - }; - } - - @Override - protected Object visitBetweenPredicate(BetweenPredicate node, Object context) - { - Object value = processWithExceptionHandling(node.getValue(), context); - if (value == null) { - return null; - } - Object min = processWithExceptionHandling(node.getMin(), context); - Object max = processWithExceptionHandling(node.getMax(), context); - - Boolean greaterOrEqualToMin = null; - if (min != null) { - greaterOrEqualToMin = (Boolean) invokeOperator(OperatorType.LESS_THAN_OR_EQUAL, types(node.getMin(), node.getValue()), ImmutableList.of(min, value)); - } - Boolean lessThanOrEqualToMax = null; - if (max != null) { - lessThanOrEqualToMax = (Boolean) invokeOperator(OperatorType.LESS_THAN_OR_EQUAL, types(node.getValue(), node.getMax()), ImmutableList.of(value, max)); - } - - if (greaterOrEqualToMin == null) { - return Objects.equals(lessThanOrEqualToMax, Boolean.FALSE) ? false : null; - } - if (lessThanOrEqualToMax == null) { - return Objects.equals(greaterOrEqualToMin, Boolean.FALSE) ? false : null; - } - return greaterOrEqualToMin && lessThanOrEqualToMax; - } - - @Override - protected Object visitNullIfExpression(NullIfExpression node, Object context) - { - Object first = processWithExceptionHandling(node.getFirst(), context); - if (first == null) { - return null; - } - Object second = processWithExceptionHandling(node.getSecond(), context); - if (second == null) { - return first; - } - - Type firstType = type(node.getFirst()); - Type secondType = type(node.getSecond()); - - Type commonType = typeCoercion.getCommonSuperType(firstType, secondType).get(); - - ResolvedFunction firstCast = metadata.getCoercion(firstType, commonType); - ResolvedFunction secondCast = metadata.getCoercion(secondType, commonType); - - // cast(first as ) == cast(second as ) - boolean equal = Boolean.TRUE.equals(invokeOperator( - OperatorType.EQUAL, - ImmutableList.of(commonType, commonType), - ImmutableList.of( - functionInvoker.invoke(firstCast, connectorSession, ImmutableList.of(first)), - functionInvoker.invoke(secondCast, connectorSession, ImmutableList.of(second))))); - - if (equal) { - return null; - } - return first; - } - - @Override - protected Object visitNotExpression(NotExpression node, Object context) - { - Object value = processWithExceptionHandling(node.getValue(), context); - if (value == null) { - return null; - } - - return !(Boolean) value; - } - - @Override - protected Object visitLogicalExpression(LogicalExpression node, Object context) - { - boolean hasNull = false; - for (Expression term : node.getTerms()) { - Object processed = processWithExceptionHandling(term, context); - - if (processed == null) { - hasNull = true; - } - else { - switch (node.getOperator()) { - case AND -> { - if (Boolean.FALSE.equals(processed)) { - return false; - } - } - case OR -> { - if (Boolean.TRUE.equals(processed)) { - return true; - } - } - } - } - } - - if (hasNull) { - return null; - } - - return switch (node.getOperator()) { - case AND -> true; - case OR -> false; - }; - } - - @Override - protected Object visitBooleanLiteral(BooleanLiteral node, Object context) - { - return node.equals(BooleanLiteral.TRUE_LITERAL); - } - - @Override - protected Object visitFunctionCall(FunctionCall node, Object context) - { - List argumentValues = new ArrayList<>(); - for (Expression expression : node.getArguments()) { - Object value = processWithExceptionHandling(expression, context); - argumentValues.add(value); - } - - ResolvedFunction resolvedFunction = resolvedFunctions.get(NodeRef.of(node)); - FunctionNullability functionNullability = resolvedFunction.getFunctionNullability(); - for (int i = 0; i < argumentValues.size(); i++) { - Object value = argumentValues.get(i); - if (value == null && !functionNullability.isArgumentNullable(i)) { - return null; - } - } - - // do not optimize non-deterministic functions - return functionInvoker.invoke(resolvedFunction, connectorSession, argumentValues); - } - - @Override - protected Object visitLambdaExpression(LambdaExpression node, Object context) - { - Expression body = node.getBody(); - List argumentNames = node.getArguments().stream() - .map(LambdaArgumentDeclaration::getName) - .map(Identifier::getValue) - .collect(toImmutableList()); - FunctionType functionType = (FunctionType) expressionTypes.get(NodeRef.of(node)); - checkArgument(argumentNames.size() == functionType.getArgumentTypes().size()); - - return generateVarArgsToMapAdapter( - Primitives.wrap(functionType.getReturnType().getJavaType()), - functionType.getArgumentTypes().stream() - .map(Type::getJavaType) - .map(Primitives::wrap) - .collect(toImmutableList()), - argumentNames, - map -> processWithExceptionHandling(body, new LambdaSymbolResolver(map))); - } - - @Override - protected Object visitBindExpression(BindExpression node, Object context) - { - Object[] values = node.getValues().stream() - .map(value -> processWithExceptionHandling(value, context)) - .toArray(); // values are nullable - Object function = processWithExceptionHandling(node.getFunction(), context); - - return MethodHandles.insertArguments((MethodHandle) function, 0, values); - } - - @Override - protected Object visitLikePredicate(LikePredicate node, Object context) - { - Slice value = (Slice) processWithExceptionHandling(node.getValue(), context); - - if (value == null) { - return null; - } - - if (node.getPattern() instanceof StringLiteral && (node.getEscape().isEmpty() || node.getEscape().get() instanceof StringLiteral)) { - // fast path when we know the pattern and escape are constant - LikePattern pattern = getConstantPattern(node); - if (type(node.getValue()) instanceof VarcharType) { - return LikeFunctions.likeVarchar(value, pattern); - } - - Type type = type(node.getValue()); - checkState(type instanceof CharType, "LIKE value is neither VARCHAR or CHAR"); - return LikeFunctions.likeChar((long) ((CharType) type).getLength(), value, pattern); - } - - Slice pattern = (Slice) processWithExceptionHandling(node.getPattern(), context); - - if (pattern == null) { - return null; - } - - Slice escape = null; - if (node.getEscape().isPresent()) { - escape = (Slice) processWithExceptionHandling(node.getEscape().get(), context); - - if (escape == null) { - return null; - } - } - - LikePattern likePattern; - if (escape == null) { - likePattern = LikePattern.compile(pattern.toStringUtf8(), Optional.empty()); - } - else { - likePattern = LikeFunctions.likePattern(pattern, escape); - } - - if (type(node.getValue()) instanceof VarcharType) { - return LikeFunctions.likeVarchar(value, likePattern); - } - - Type type = type(node.getValue()); - checkState(type instanceof CharType, "LIKE value is neither VARCHAR or CHAR"); - return LikeFunctions.likeChar((long) ((CharType) type).getLength(), value, likePattern); - } - - private LikePattern getConstantPattern(LikePredicate node) - { - LikePattern result = likePatternCache.get(node); - - if (result == null) { - StringLiteral pattern = (StringLiteral) node.getPattern(); - - if (node.getEscape().isPresent()) { - Slice escape = Slices.utf8Slice(((StringLiteral) node.getEscape().get()).getValue()); - result = LikeFunctions.likePattern(Slices.utf8Slice(pattern.getValue()), escape); - } - else { - result = LikePattern.compile(pattern.getValue(), Optional.empty()); - } - - likePatternCache.put(node, result); - } - - return result; - } - - @Override - public Object visitCast(Cast node, Object context) - { - Object value = processWithExceptionHandling(node.getExpression(), context); - Type targetType = plannerContext.getTypeManager().getType(toTypeSignature(node.getType())); - Type sourceType = type(node.getExpression()); - - if (value == null) { - return null; - } - - ResolvedFunction operator = metadata.getCoercion(sourceType, targetType); - - try { - return functionInvoker.invoke(operator, connectorSession, ImmutableList.of(value)); - } - catch (RuntimeException e) { - if (node.isSafe()) { - return null; - } - throw e; - } - } - - @Override - protected Object visitArray(Array node, Object context) - { - Type elementType = ((ArrayType) type(node)).getElementType(); - BlockBuilder arrayBlockBuilder = elementType.createBlockBuilder(null, node.getValues().size()); - - for (Expression expression : node.getValues()) { - Object value = processWithExceptionHandling(expression, context); - writeNativeValue(elementType, arrayBlockBuilder, value); - } - - return arrayBlockBuilder.build(); - } - - @Override - protected Object visitCurrentCatalog(CurrentCatalog node, Object context) - { - FunctionCall function = BuiltinFunctionCallBuilder.resolve(metadata) - .setName("$current_catalog") - .build(); - - return visitFunctionCall(function, context); - } - - @Override - protected Object visitCurrentSchema(CurrentSchema node, Object context) - { - FunctionCall function = BuiltinFunctionCallBuilder.resolve(metadata) - .setName("$current_schema") - .build(); - - return visitFunctionCall(function, context); - } - - @Override - protected Object visitCurrentUser(CurrentUser node, Object context) - { - FunctionCall function = BuiltinFunctionCallBuilder.resolve(metadata) - .setName("$current_user") - .build(); - - return visitFunctionCall(function, context); - } - - @Override - protected Object visitCurrentPath(CurrentPath node, Object context) - { - FunctionCall function = BuiltinFunctionCallBuilder.resolve(metadata) - .setName("$current_path") - .build(); - - return visitFunctionCall(function, context); - } - - @Override - protected Object visitAtTimeZone(AtTimeZone node, Object context) - { - Object value = processWithExceptionHandling(node.getValue(), context); - if (value == null) { - return null; - } - - Object timeZone = processWithExceptionHandling(node.getTimeZone(), context); - if (timeZone == null) { - return null; - } - - Type valueType = type(node.getValue()); - Type timeZoneType = type(node.getTimeZone()); - - if (valueType instanceof TimeType type) { - //