diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlCumulateTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlCumulateTableFunction.java index 8467747ed59ee..1c73f48c2812f 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlCumulateTableFunction.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlCumulateTableFunction.java @@ -17,6 +17,8 @@ package org.apache.flink.table.planner.functions.sql; +import org.apache.flink.table.planner.functions.utils.SqlValidatorUtils; + import org.apache.flink.shaded.guava33.com.google.common.collect.ImmutableList; import org.apache.calcite.sql.SqlCallBinding; @@ -51,14 +53,16 @@ private static class OperandMetadataImpl extends AbstractOperandMetadata { @Override public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) { - if (!checkTableAndDescriptorOperands(callBinding, 1)) { - return throwValidationSignatureErrorOrReturnFalse(callBinding, throwOnFailure); + if (!SqlValidatorUtils.checkTableAndDescriptorOperands(callBinding, 1)) { + return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse( + callBinding, throwOnFailure); } if (!checkIntervalOperands(callBinding, 2)) { - return throwValidationSignatureErrorOrReturnFalse(callBinding, throwOnFailure); + return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse( + callBinding, throwOnFailure); } // check time attribute - return throwExceptionOrReturnFalse( + return SqlValidatorUtils.throwExceptionOrReturnFalse( checkTimeColumnDescriptorOperand(callBinding, 1), throwOnFailure); } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlHopTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlHopTableFunction.java index b957e94d54556..fa16f037f2007 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlHopTableFunction.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlHopTableFunction.java @@ -17,6 +17,8 @@ package org.apache.flink.table.planner.functions.sql; +import org.apache.flink.table.planner.functions.utils.SqlValidatorUtils; + import org.apache.flink.shaded.guava33.com.google.common.collect.ImmutableList; import org.apache.calcite.sql.SqlCallBinding; @@ -51,14 +53,16 @@ private static class OperandMetadataImpl extends AbstractOperandMetadata { @Override public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) { - if (!checkTableAndDescriptorOperands(callBinding, 1)) { - return throwValidationSignatureErrorOrReturnFalse(callBinding, throwOnFailure); + if (!SqlValidatorUtils.checkTableAndDescriptorOperands(callBinding, 1)) { + return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse( + callBinding, throwOnFailure); } if (!checkIntervalOperands(callBinding, 2)) { - return throwValidationSignatureErrorOrReturnFalse(callBinding, throwOnFailure); + return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse( + callBinding, throwOnFailure); } // check time attribute - return throwExceptionOrReturnFalse( + return SqlValidatorUtils.throwExceptionOrReturnFalse( checkTimeColumnDescriptorOperand(callBinding, 1), throwOnFailure); } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlSessionTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlSessionTableFunction.java index c120811eef499..ab1267f765704 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlSessionTableFunction.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlSessionTableFunction.java @@ -17,6 +17,8 @@ package org.apache.flink.table.planner.functions.sql; +import org.apache.flink.table.planner.functions.utils.SqlValidatorUtils; + import org.apache.flink.shaded.guava33.com.google.common.collect.ImmutableList; import org.apache.flink.shaded.guava33.com.google.common.collect.ImmutableMap; @@ -73,18 +75,20 @@ private static class OperandMetadataImpl extends AbstractOperandMetadata { @Override public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) { - if (!checkTableAndDescriptorOperands(callBinding, 1)) { - return throwValidationSignatureErrorOrReturnFalse(callBinding, throwOnFailure); + if (!SqlValidatorUtils.checkTableAndDescriptorOperands(callBinding, 1)) { + return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse( + callBinding, throwOnFailure); } final SqlValidator validator = callBinding.getValidator(); final SqlNode operand2 = callBinding.operand(2); final RelDataType type2 = validator.getValidatedNodeType(operand2); if (!SqlTypeUtil.isInterval(type2)) { - return throwValidationSignatureErrorOrReturnFalse(callBinding, throwOnFailure); + return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse( + callBinding, throwOnFailure); } - return throwExceptionOrReturnFalse( + return SqlValidatorUtils.throwExceptionOrReturnFalse( checkTimeColumnDescriptorOperand(callBinding, 1), throwOnFailure); } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlTumbleTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlTumbleTableFunction.java index 7e054d8f76c5f..563bae81c3a62 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlTumbleTableFunction.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlTumbleTableFunction.java @@ -17,6 +17,8 @@ package org.apache.flink.table.planner.functions.sql; +import org.apache.flink.table.planner.functions.utils.SqlValidatorUtils; + import org.apache.flink.shaded.guava33.com.google.common.collect.ImmutableList; import org.apache.calcite.sql.SqlCallBinding; @@ -49,14 +51,16 @@ private static class OperandMetadataImpl extends AbstractOperandMetadata { public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) { // There should only be three operands, and number of operands are checked before // this call. - if (!checkTableAndDescriptorOperands(callBinding, 1)) { - return throwValidationSignatureErrorOrReturnFalse(callBinding, throwOnFailure); + if (!SqlValidatorUtils.checkTableAndDescriptorOperands(callBinding, 1)) { + return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse( + callBinding, throwOnFailure); } if (!checkIntervalOperands(callBinding, 2)) { - return throwValidationSignatureErrorOrReturnFalse(callBinding, throwOnFailure); + return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse( + callBinding, throwOnFailure); } // check time attribute - return throwExceptionOrReturnFalse( + return SqlValidatorUtils.throwExceptionOrReturnFalse( checkTimeColumnDescriptorOperand(callBinding, 1), throwOnFailure); } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlWindowTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlWindowTableFunction.java index a4e55dc7a1595..d9571d599a0b1 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlWindowTableFunction.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlWindowTableFunction.java @@ -34,7 +34,6 @@ import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlOperandCountRange; import org.apache.calcite.sql.SqlOperatorBinding; -import org.apache.calcite.sql.SqlUtil; import org.apache.calcite.sql.type.SqlOperandCountRanges; import org.apache.calcite.sql.type.SqlOperandMetadata; import org.apache.calcite.sql.type.SqlReturnTypeInference; @@ -48,7 +47,6 @@ import java.util.List; import java.util.Optional; -import static org.apache.calcite.util.Static.RESOURCE; import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.canBeTimeAttributeType; /** @@ -194,56 +192,6 @@ public boolean isOptional(int i) { return i > getOperandCountRange().getMin() && i <= getOperandCountRange().getMax(); } - boolean throwValidationSignatureErrorOrReturnFalse( - SqlCallBinding callBinding, boolean throwOnFailure) { - if (throwOnFailure) { - throw callBinding.newValidationSignatureError(); - } else { - return false; - } - } - - @SuppressWarnings("OptionalUsedAsFieldOrParameterType") - boolean throwExceptionOrReturnFalse(Optional e, boolean throwOnFailure) { - if (e.isPresent()) { - if (throwOnFailure) { - throw e.get(); - } else { - return false; - } - } else { - return true; - } - } - - /** - * Checks whether the heading operands are in the form {@code (ROW, DESCRIPTOR, DESCRIPTOR - * ..., other params)}, returning whether successful, and throwing if any columns are not - * found. - * - * @param callBinding The call binding - * @param descriptorCount The number of descriptors following the first operand (e.g. the - * table) - * @return true if validation passes; throws if any columns are not found - */ - boolean checkTableAndDescriptorOperands(SqlCallBinding callBinding, int descriptorCount) { - final SqlNode operand0 = callBinding.operand(0); - final SqlValidator validator = callBinding.getValidator(); - final RelDataType type = validator.getValidatedNodeType(operand0); - if (type.getSqlTypeName() != SqlTypeName.ROW) { - return false; - } - for (int i = 1; i < descriptorCount + 1; i++) { - final SqlNode operand = callBinding.operand(i); - if (operand.getKind() != SqlKind.DESCRIPTOR) { - return false; - } - validateColumnNames( - validator, type.getFieldNames(), ((SqlCall) operand).getOperandList()); - } - return true; - } - /** * Checks whether the type that the operand of time col descriptor refers to is valid. * @@ -310,17 +258,5 @@ boolean checkIntervalOperands(SqlCallBinding callBinding, int startPos) { } return true; } - - void validateColumnNames( - SqlValidator validator, List fieldNames, List columnNames) { - final SqlNameMatcher matcher = validator.getCatalogReader().nameMatcher(); - for (SqlNode columnName : columnNames) { - final String name = ((SqlIdentifier) columnName).getSimple(); - if (matcher.indexOf(fieldNames, name) < 0) { - throw SqlUtil.newContextException( - columnName.getParserPosition(), RESOURCE.unknownIdentifier(name)); - } - } - } } } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlMLPredictTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlMLPredictTableFunction.java index 2bc3516ecefcc..96fde75762f8a 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlMLPredictTableFunction.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlMLPredictTableFunction.java @@ -17,18 +17,35 @@ package org.apache.flink.table.planner.functions.sql.ml; +import org.apache.flink.table.api.ValidationException; +import org.apache.flink.table.planner.functions.utils.SqlValidatorUtils; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.utils.LogicalTypeCasts; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.SqlCall; import org.apache.calcite.sql.SqlCallBinding; +import org.apache.calcite.sql.SqlCharStringLiteral; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlModelCall; +import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlOperandCountRange; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.SqlOperatorBinding; import org.apache.calcite.sql.type.SqlOperandCountRanges; import org.apache.calcite.sql.type.SqlOperandMetadata; import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.validate.SqlNameMatcher; +import org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.util.Util; import java.util.Collections; import java.util.List; +import java.util.Optional; + +import static org.apache.flink.table.planner.calcite.FlinkTypeFactory.toLogicalType; /** * SqlMlPredictTableFunction implements an operator for prediction. @@ -51,8 +68,8 @@ public SqlMLPredictTableFunction() { /** * {@inheritDoc} * - *

Overrides because the first parameter of table-value function windowing is an explicit - * TABLE parameter, which is not scalar. + *

Overrides because the first parameter of ML table-value function is an explicit TABLE + * parameter, which is not scalar. */ @Override public boolean argumentMustBeScalar(int ordinal) { @@ -61,9 +78,18 @@ public boolean argumentMustBeScalar(int ordinal) { @Override protected RelDataType inferRowType(SqlOperatorBinding opBinding) { - // TODO: FLINK-37780 output type based on table schema and model output schema - // model output schema to be available after integrated with SqlExplicitModelCall - return opBinding.getOperandType(1); + final RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); + final RelDataType inputRowType = opBinding.getOperandType(0); + final RelDataType modelOutputRowType = opBinding.getOperandType(1); + + return typeFactory + .builder() + .kind(inputRowType.getStructKind()) + .addAll(inputRowType.getFieldList()) + .addAll( + SqlValidatorUtils.makeOutputUnique( + inputRowType.getFieldList(), modelOutputRowType.getFieldList())) + .build(); } private static class PredictOperandMetadata implements SqlOperandMetadata { @@ -87,9 +113,18 @@ public List paramNames() { @Override public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) { - // TODO: FLINK-37780 Check operand types after integrated with SqlExplicitModelCall in - // validator - return false; + if (!SqlValidatorUtils.checkTableAndDescriptorOperands(callBinding, 2)) { + return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse( + callBinding, throwOnFailure); + } + + if (!SqlValidatorUtils.throwExceptionOrReturnFalse( + checkModelSignature(callBinding), throwOnFailure)) { + return false; + } + + return SqlValidatorUtils.throwExceptionOrReturnFalse( + checkConfig(callBinding), throwOnFailure); } @Override @@ -97,11 +132,6 @@ public SqlOperandCountRange getOperandCountRange() { return SqlOperandCountRanges.between(MANDATORY_PARAM_NAMES.size(), PARAM_NAMES.size()); } - @Override - public Consistency getConsistency() { - return Consistency.NONE; - } - @Override public boolean isOptional(int i) { return i > getOperandCountRange().getMin() && i <= getOperandCountRange().getMax(); @@ -110,7 +140,94 @@ public boolean isOptional(int i) { @Override public String getAllowedSignatures(SqlOperator op, String opName) { return opName - + "(TABLE table_name, MODEL model_name, DESCRIPTOR(input_columns), [MAP[]]"; + + "(TABLE table_name, MODEL model_name, DESCRIPTOR(input_columns), [MAP[]])"; + } + + private static Optional checkModelSignature(SqlCallBinding callBinding) { + SqlValidator validator = callBinding.getValidator(); + + // Check second operand is SqlModelCall + if (!(callBinding.operand(1) instanceof SqlModelCall)) { + return Optional.of( + new ValidationException("Second operand must be a model identifier.")); + } + + // Get descriptor columns + SqlCall descriptorCall = (SqlCall) callBinding.operand(2); + List descriptCols = descriptorCall.getOperandList(); + + // Get model input size + SqlModelCall modelCall = (SqlModelCall) callBinding.operand(1); + RelDataType modelInputType = modelCall.getInputType(validator); + + // Check sizes match + if (descriptCols.size() != modelInputType.getFieldCount()) { + return Optional.of( + new ValidationException( + String.format( + "Number of descriptor input columns (%d) does not match model input size (%d)", + descriptCols.size(), modelInputType.getFieldCount()))); + } + + // Check types match + final RelDataType tableType = validator.getValidatedNodeType(callBinding.operand(0)); + final SqlNameMatcher matcher = validator.getCatalogReader().nameMatcher(); + for (int i = 0; i < descriptCols.size(); i++) { + SqlIdentifier columnName = (SqlIdentifier) descriptCols.get(i); + String descriptColName = + columnName.isSimple() + ? columnName.getSimple() + : Util.last(columnName.names); + int index = matcher.indexOf(tableType.getFieldNames(), descriptColName); + RelDataType sourceType = tableType.getFieldList().get(index).getType(); + RelDataType targetType = modelInputType.getFieldList().get(i).getType(); + + LogicalType sourceLogicalType = toLogicalType(sourceType); + LogicalType targetLogicalType = toLogicalType(targetType); + + if (!LogicalTypeCasts.supportsImplicitCast(sourceLogicalType, targetLogicalType)) { + return Optional.of( + new ValidationException( + String.format( + "Descriptor column type %s cannot be assigned to model input type %s at position %d", + sourceLogicalType, targetLogicalType, i))); + } + } + + return Optional.empty(); + } + + private static Optional checkConfig(SqlCallBinding callBinding) { + if (callBinding.getOperandCount() < PARAM_NAMES.size()) { + return Optional.empty(); + } + + SqlNode configNode = callBinding.operand(3); + if (!configNode.getKind().equals(SqlKind.MAP_VALUE_CONSTRUCTOR)) { + return Optional.of(new ValidationException("Config param should be a MAP.")); + } + + // Map operands can only be SqlCharStringLiteral or cast of SqlCharStringLiteral + SqlCall mapCall = (SqlCall) configNode; + for (int i = 0; i < mapCall.operandCount(); i++) { + SqlNode operand = mapCall.operand(i); + if (operand instanceof SqlCharStringLiteral) { + continue; + } + if (operand.getKind().equals(SqlKind.CAST)) { + SqlCall castCall = (SqlCall) operand; + if (castCall.operand(0) instanceof SqlCharStringLiteral) { + continue; + } + } + return Optional.of( + new ValidationException( + String.format( + "ML_PREDICT config param can only be a MAP of string literals. The item at position %d is %s.", + i, operand))); + } + + return Optional.empty(); } } } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlMLTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlMLTableFunction.java index aa8c46e7ab516..986385c6c9fab 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlMLTableFunction.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlMLTableFunction.java @@ -26,7 +26,6 @@ import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlOperatorBinding; -import org.apache.calcite.sql.SqlSelect; import org.apache.calcite.sql.SqlTableFunction; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlOperandMetadata; @@ -70,18 +69,14 @@ public void validateCall( final List operandList = call.getOperandList(); // ML table function should take only one table as input and use descriptor to reference - // columns in the table. The scope for descriptor validation should be the input table. - // Since the input table will be rewritten as select query. We get the select query's scope - // and use it for descriptor validation. - SqlValidatorScope selectScope = null; + // columns in the table. The scope for descriptor validation should be the input table which + // is also an operand of the call. We defer the validation of the descriptor since + // validation here will quality the descriptor columns to be NOT simple name which + // complicates checks in later stages. We validate the descriptor columns appear in table + // column in SqlOperandMetadata. boolean foundSelect = false; for (SqlNode operand : operandList) { if (operand.getKind().equals(SqlKind.DESCRIPTOR)) { - if (selectScope == null) { - throw new ValidationException(TABLE_INPUT_ERROR); - } - // Set scope to table when validating descriptor columns - operand.validate(validator, selectScope); continue; } if (operand.getKind().equals(SqlKind.SET_SEMANTICS_TABLE)) { @@ -90,7 +85,6 @@ public void validateCall( throw new ValidationException(TABLE_INPUT_ERROR); } foundSelect = true; - selectScope = validator.getSelectScope((SqlSelect) operand); } if (operand.getKind().equals(SqlKind.SELECT)) { @@ -98,7 +92,6 @@ public void validateCall( throw new ValidationException(TABLE_INPUT_ERROR); } foundSelect = true; - selectScope = validator.getSelectScope((SqlSelect) operand); } operand.validate(validator, scope); } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/utils/SqlValidatorUtils.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/utils/SqlValidatorUtils.java index 279541eec739f..42e381a606290 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/utils/SqlValidatorUtils.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/utils/SqlValidatorUtils.java @@ -19,16 +19,30 @@ package org.apache.flink.table.planner.functions.utils; import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rel.type.RelDataTypeFieldImpl; import org.apache.calcite.sql.SqlCall; import org.apache.calcite.sql.SqlCallBinding; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlOperatorBinding; +import org.apache.calcite.sql.SqlUtil; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeUtil; +import org.apache.calcite.sql.validate.SqlNameMatcher; +import org.apache.calcite.sql.validate.SqlValidator; import org.apache.calcite.util.Pair; +import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Optional; +import java.util.Set; + +import static org.apache.calcite.util.Static.RESOURCE; /** Utility methods related to SQL validation. */ public class SqlValidatorUtils { @@ -49,6 +63,74 @@ public static void adjustTypeForMapConstructor( } } + public static boolean throwValidationSignatureErrorOrReturnFalse( + SqlCallBinding callBinding, boolean throwOnFailure) { + if (throwOnFailure) { + throw callBinding.newValidationSignatureError(); + } else { + return false; + } + } + + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") + public static boolean throwExceptionOrReturnFalse( + Optional e, boolean throwOnFailure) { + if (e.isPresent()) { + if (throwOnFailure) { + throw e.get(); + } else { + return false; + } + } else { + return true; + } + } + + /** + * Checks whether the heading operands are in the form {@code (ROW, DESCRIPTOR, DESCRIPTOR ..., + * other params)}, returning whether successful, and throwing if any columns are not found. + * + * @param callBinding The call binding + * @param descriptorLocations position of the descriptor operands + * @return true if validation passes; throws if any columns are not found + */ + public static boolean checkTableAndDescriptorOperands( + SqlCallBinding callBinding, Integer... descriptorLocations) { + final SqlNode operand0 = callBinding.operand(0); + final SqlValidator validator = callBinding.getValidator(); + final RelDataType type = validator.getValidatedNodeType(operand0); + if (type.getSqlTypeName() != SqlTypeName.ROW) { + return false; + } + for (Integer location : descriptorLocations) { + final SqlNode operand = callBinding.operand(location); + if (operand.getKind() != SqlKind.DESCRIPTOR) { + return false; + } + validateColumnNames( + validator, type.getFieldNames(), ((SqlCall) operand).getOperandList()); + } + return true; + } + + private static void validateColumnNames( + SqlValidator validator, List fieldNames, List columnNames) { + final SqlNameMatcher matcher = validator.getCatalogReader().nameMatcher(); + for (SqlNode columnName : columnNames) { + SqlIdentifier columnIdentifier = (SqlIdentifier) columnName; + if (!columnIdentifier.isSimple()) { + throw SqlUtil.newContextException( + columnName.getParserPosition(), RESOURCE.aliasMustBeSimpleIdentifier()); + } + + final String name = columnIdentifier.getSimple(); + if (matcher.indexOf(fieldNames, name) < 0) { + throw SqlUtil.newContextException( + columnName.getParserPosition(), RESOURCE.unknownIdentifier(name)); + } + } + } + /** * When the element element does not equal with the component type, making explicit casting. * @@ -75,6 +157,34 @@ private static void adjustTypeForMultisetConstructor( } } + /** + * Make output field names unique from input field names by appending index. For example, Input + * has field names {@code a, b, c} and output has field names {@code b, c, d}. After calling + * this function, new output field names will be {@code b0, c0, d}. Duplicate names are not + * checked inside input and output itself. + * + * @param input Input fields + * @param output Output fields + * @return + */ + public static List makeOutputUnique( + List input, List output) { + final Set inputFieldNames = new HashSet<>(); + for (RelDataTypeField field : input) { + inputFieldNames.add(field.getName()); + } + + List result = new ArrayList<>(); + for (RelDataTypeField field : output) { + String fieldName = field.getName(); + if (inputFieldNames.contains(fieldName)) { + fieldName += "0"; // Append index to make it unique + } + result.add(new RelDataTypeFieldImpl(fieldName, field.getIndex(), field.getType())); + } + return result; + } + private static SqlNode castTo(SqlNode node, RelDataType type) { return SqlStdOperatorTable.CAST.createCall( SqlParserPos.ZERO, diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/MLPredictTableFunctionTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/MLPredictTableFunctionTest.java new file mode 100644 index 0000000000000..20bf939a807dd --- /dev/null +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/MLPredictTableFunctionTest.java @@ -0,0 +1,388 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.stream.sql; + +import org.apache.flink.table.api.TableConfig; +import org.apache.flink.table.api.ValidationException; +import org.apache.flink.table.planner.utils.TableTestBase; +import org.apache.flink.table.planner.utils.TableTestUtil; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for model table value function. */ +public class MLPredictTableFunctionTest extends TableTestBase { + + private TableTestUtil util; + + @BeforeEach + public void setup() { + util = streamTestUtil(TableConfig.getDefault()); + + // Create test table + util.tableEnv() + .executeSql( + "CREATE TABLE MyTable (\n" + + " a INT,\n" + + " b BIGINT,\n" + + " c STRING,\n" + + " d DECIMAL(10, 3),\n" + + " rowtime TIMESTAMP(3),\n" + + " proctime as PROCTIME(),\n" + + " WATERMARK FOR rowtime AS rowtime - INTERVAL '1' SECOND\n" + + ") with (\n" + + " 'connector' = 'values'\n" + + ")"); + + // Create test model + util.tableEnv() + .executeSql( + "CREATE MODEL MyModel\n" + + "INPUT (a INT, b BIGINT)\n" + + "OUTPUT(e STRING, f ARRAY)\n" + + "with (\n" + + " 'provider' = 'openai'\n" + + ")"); + } + + @Test + public void testNamedArguments() { + String sql = + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(INPUT => TABLE MyTable, " + + "MODEL => MODEL MyModel, " + + "ARGS => DESCRIPTOR(a, b)))"; + assertReachesRelConverter(sql); + } + + @Test + public void testOptionalNamedArguments() { + String sql = + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(INPUT => TABLE MyTable, " + + "MODEL => MODEL MyModel, " + + "ARGS => DESCRIPTOR(a, b)," + + "CONFIG => MAP['key', 'value']))"; + assertReachesRelConverter(sql); + } + + @Test + public void testSimple() { + String sql = + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(a, b)))"; + assertReachesRelConverter(sql); + } + + @Test + public void testConfigWithCast() { + // 'async' and 'timeout' in the map are both cast to VARCHAR(7) + String sql = + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(a, b), MAP['async', 'true', 'timeout', '100s']))"; + assertReachesRelConverter(sql); + } + + @Test + public void testTooFewArguments() { + String sql = "SELECT *\n" + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel))"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .hasMessageContaining( + "Invalid number of arguments to function 'ML_PREDICT'. Was expecting 3 arguments"); + } + + @Test + public void testTooManyArguments() { + String sql = + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(a, b), MAP['key', 'value'], 'arg0'))"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .hasMessageContaining( + "Invalid number of arguments to function 'ML_PREDICT'. Was expecting 3 arguments"); + } + + @Test + public void testNonExistModel() { + String sql = + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL NonExistModel, DESCRIPTOR(a, b), MAP['key', 'value'], 'arg0'))"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .isInstanceOf(ValidationException.class) + .hasMessageContaining("Object 'NonExistModel' not found"); + } + + @Test + public void testConflictOutputColumnName() { + util.tableEnv() + .executeSql( + "CREATE MODEL ConflictModel\n" + + "INPUT (a INT, b BIGINT)\n" + + "OUTPUT(c STRING, d ARRAY)\n" + + "with (\n" + + " 'provider' = 'openai'\n" + + ")"); + + String sql = + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL ConflictModel, DESCRIPTOR(a, b)))"; + assertReachesRelConverter(sql); + } + + @Test + public void testMissingModelParam() { + String sql = + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE MyTable, DESCRIPTOR(a, b), DESCRIPTOR(a, b)))"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .isInstanceOf(ValidationException.class) + .hasMessageContaining( + "SQL validation failed. Second operand must be a model identifier."); + } + + @Test + public void testMismatchInputSize() { + String sql = + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(a, b, c)))"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .isInstanceOf(ValidationException.class) + .hasMessageContaining( + "SQL validation failed. Number of descriptor input columns (3) does not match model input size (2)"); + } + + @Test + public void testNonExistColumn() { + String sql = + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(no_col)))"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .isInstanceOf(ValidationException.class) + .hasMessageContaining("Unknown identifier 'no_col'"); + } + + @Test + public void testNonSimpleColumn() { + String sql = + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(MyTable.a)))"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .isInstanceOf(ValidationException.class) + .hasMessageContaining("Table or column alias must be a simple identifier"); + } + + @ParameterizedTest + @MethodSource("compatibleTypeProvider") + public void testCompatibleInputTypes(String tableType, String modelType) { + // Create test table with dynamic type + util.tableEnv() + .executeSql( + String.format( + "CREATE TABLE TypeTable (\n" + + " col %s\n" + + ") with (\n" + + " 'connector' = 'values'\n" + + ")", + tableType)); + + // Create test model with dynamic type + util.tableEnv() + .executeSql( + String.format( + "CREATE MODEL TypeModel\n" + + "INPUT (x %s)\n" + + "OUTPUT (res STRING)\n" + + "with (\n" + + " 'provider' = 'openai'\n" + + ")", + modelType)); + + String sql = + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE TypeTable, MODEL TypeModel, DESCRIPTOR(col)))"; + assertReachesRelConverter(sql); + } + + @ParameterizedTest + @MethodSource("incompatibleTypeProvider") + public void testIncompatibleInputTypes(String tableType, String modelType) { + // Create test table with dynamic type + util.tableEnv() + .executeSql( + String.format( + "CREATE TABLE TypeTable (\n" + + " col %s\n" + + ") with (\n" + + " 'connector' = 'values'\n" + + ")", + tableType)); + + // Create test model with dynamic type + util.tableEnv() + .executeSql( + String.format( + "CREATE MODEL TypeModel\n" + + "INPUT (x %s)\n" + + "OUTPUT (res STRING)\n" + + "with (\n" + + " 'provider' = 'openai'\n" + + ")", + modelType)); + + String sql = + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE TypeTable, MODEL TypeModel, DESCRIPTOR(col)))"; + + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .isInstanceOf(ValidationException.class) + .hasMessageContaining("cannot be assigned to model input type"); + } + + @Test + public void testWrongConfigType() { + String sql = + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(a, b), MAP['async', true]))"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .isInstanceOf(ValidationException.class) + .hasMessageContaining( + "ML_PREDICT config param can only be a MAP of string literals. The item at position 1 is TRUE."); + } + + private void assertReachesRelConverter(String sql) { + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .hasMessageContaining("while converting MODEL"); + } + + private static Stream compatibleTypeProvider() { + return Stream.of( + // NOT NULL to NULLABLE type + Arguments.of("STRING NOT NULL", "STRING"), + + // Exact matches - primitive types + Arguments.of("BOOLEAN", "BOOLEAN"), + Arguments.of("TINYINT", "TINYINT"), + Arguments.of("SMALLINT", "SMALLINT"), + Arguments.of("INT", "INT"), + Arguments.of("BIGINT", "BIGINT"), + Arguments.of("FLOAT", "FLOAT"), + Arguments.of("DOUBLE", "DOUBLE"), + Arguments.of("DECIMAL(10,2)", "DECIMAL(10,2)"), + Arguments.of("STRING", "STRING"), + Arguments.of("BINARY(10)", "BINARY(10)"), + Arguments.of("VARBINARY(10)", "VARBINARY(10)"), + Arguments.of("DATE", "DATE"), + Arguments.of("TIME(3)", "TIME(3)"), + Arguments.of("TIMESTAMP(3)", "TIMESTAMP(3)"), + Arguments.of("TIMESTAMP_LTZ(3)", "TIMESTAMP_LTZ(3)"), + + // Numeric type promotions + Arguments.of("TINYINT", "SMALLINT"), + Arguments.of("SMALLINT", "INT"), + Arguments.of("INT", "BIGINT"), + Arguments.of("FLOAT", "DOUBLE"), + Arguments.of("DECIMAL(5,2)", "DECIMAL(10,2)"), + Arguments.of( + "DECIMAL(10,2)", "DECIMAL(5,2)"), // This is also allowed, is this a bug? + + // String type compatibility + Arguments.of("CHAR(10)", "STRING"), + Arguments.of("VARCHAR(20)", "STRING"), + + // Temporal types + Arguments.of("TIMESTAMP(3)", "TIMESTAMP(3)"), + Arguments.of("DATE", "DATE"), + Arguments.of("TIME(3)", "TIME(3)"), + + // Array types + Arguments.of("ARRAY", "ARRAY"), + Arguments.of("ARRAY", "ARRAY"), + Arguments.of("ARRAY", "ARRAY"), + Arguments.of("ARRAY", "ARRAY"), + + // Map types + Arguments.of("MAP", "MAP"), + Arguments.of("MAP", "MAP"), + Arguments.of("MAP>", "MAP>"), + + // Row types + Arguments.of("ROW", "ROW"), + Arguments.of( + "ROW", "ROW"), // Different field name + Arguments.of( + "ROW>", "ROW>"), + Arguments.of( + "ROW>", + "ROW>"), + + // Nested complex types + Arguments.of( + "ROW, b MAP>>", + "ROW, b MAP>>"), + Arguments.of( + "MAP>>", + "MAP>>")); + } + + private static Stream incompatibleTypeProvider() { + return Stream.of( + // NULLABLE to NOT NULL type + Arguments.of("STRING", "STRING NOT NULL"), + + // Incompatible primitive types + Arguments.of("BOOLEAN", "INT"), + Arguments.of("STRING", "INT"), + Arguments.of("INT", "STRING"), + Arguments.of("TIMESTAMP(3)", "INT"), + Arguments.of("DATE", "TIMESTAMP(3)"), + Arguments.of("BINARY(10)", "STRING"), + + // Incompatible numeric types (wrong direction) + Arguments.of("BIGINT", "INT"), // Cannot downcast + Arguments.of("DOUBLE", "FLOAT"), // Cannot downcast + + // Incompatible array types + Arguments.of("ARRAY", "ARRAY"), + Arguments.of("ARRAY", "ARRAY"), + Arguments.of("INT", "ARRAY"), + + // Incompatible map types + Arguments.of("MAP", "MAP"), // Key type mismatch + Arguments.of("MAP", "MAP"), // Value type mismatch + Arguments.of("MAP", "MAP"), // Cannot downcast value + Arguments.of("MAP", "MAP"), // Cannot downcast key + + // Incompatible row types + Arguments.of("ROW", "ROW"), // Field type mismatch + Arguments.of("ROW", "ROW"), // Field count mismatch + + // Incompatible nested types + Arguments.of( + "ROW, b MAP>", + "ROW, b MAP>"), + Arguments.of("MAP>", "MAP>"), + Arguments.of("ARRAY>", "ARRAY>")); + } +} diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/ModelTableFunctionTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/ModelTableFunctionTest.java deleted file mode 100644 index 74e5a284573b9..0000000000000 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/ModelTableFunctionTest.java +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.planner.plan.stream.sql; - -import org.apache.flink.table.api.TableConfig; -import org.apache.flink.table.api.ValidationException; -import org.apache.flink.table.planner.utils.TableTestBase; -import org.apache.flink.table.planner.utils.TableTestUtil; - -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** Tests for model table-valued function. */ -public class ModelTableFunctionTest extends TableTestBase { - - private TableTestUtil util; - - @BeforeEach - public void setup() { - util = streamTestUtil(TableConfig.getDefault()); - - // Create test table - util.tableEnv() - .executeSql( - "CREATE TABLE MyTable (\n" - + " a INT,\n" - + " b BIGINT,\n" - + " c STRING,\n" - + " d DECIMAL(10, 3),\n" - + " rowtime TIMESTAMP(3),\n" - + " proctime as PROCTIME(),\n" - + " WATERMARK FOR rowtime AS rowtime - INTERVAL '1' SECOND\n" - + ") with (\n" - + " 'connector' = 'values'\n" - + ")"); - - // Create test model - util.tableEnv() - .executeSql( - "CREATE MODEL MyModel\n" - + "INPUT (a INT, b BIGINT)\n" - + "OUTPUT(c STRING, D ARRAY)\n" - + "with (\n" - + " 'provider' = 'openai'\n" - + ")"); - } - - @Test - public void testMLPredictTVFWithNamedArguments() { - String sql = - "SELECT *\n" - + "FROM TABLE(ML_PREDICT(INPUT => TABLE MyTable, " - + "MODEL => MODEL MyModel, " - + "ARGS => DESCRIPTOR(a, b)))"; - assertReachesRelConverter(sql); - } - - @Test - public void testMLPredictTVFWithOptionalNamedArguments() { - String sql = - "SELECT *\n" - + "FROM TABLE(ML_PREDICT(INPUT => TABLE MyTable, " - + "MODEL => MODEL MyModel, " - + "ARGS => DESCRIPTOR(a, b)," - + "CONFIG => MAP['key', 'value']))"; - assertReachesRelConverter(sql); - } - - @Test - public void testMLPredictTVF() { - String sql = - "SELECT *\n" - + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(a, b)))"; - assertReachesRelConverter(sql); - } - - @Test - public void testMLPredictTVFWithTooManyArguments() { - String sql = - "SELECT *\n" - + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(a, b), MAP['key', 'value'], 'arg0'))"; - assertThatThrownBy(() -> util.verifyRelPlan(sql)) - .hasMessageContaining( - "Invalid number of arguments to function 'ML_PREDICT'. Was expecting 3 arguments"); - } - - @Test - public void testMLPredictTVFWithNonExistModel() { - String sql = - "SELECT *\n" - + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL NonExistModel, DESCRIPTOR(a, b), MAP['key', 'value'], 'arg0'))"; - assertThatThrownBy(() -> util.verifyRelPlan(sql)) - .isInstanceOf(ValidationException.class) - .hasMessageContaining("Object 'NonExistModel' not found"); - } - - @Test - public void testMLPredictTVFWithNonExistColumn() { - String sql = - "SELECT *\n" - + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(no_col)))"; - assertThatThrownBy(() -> util.verifyRelPlan(sql)) - .isInstanceOf(ValidationException.class) - .hasMessageContaining("Column 'no_col' not found in any table"); - } - - private void assertReachesRelConverter(String sql) { - assertThatThrownBy(() -> util.verifyRelPlan(sql)) - .hasMessageContaining("while converting MODEL"); - } -}