Skip to content

[FLINK-37780][5/N] predict sql function type inference and validation #26583

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 28, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -17,6 +17,8 @@

package org.apache.flink.table.planner.functions.sql;

import org.apache.flink.table.planner.functions.utils.SqlValidatorUtils;

import org.apache.flink.shaded.guava33.com.google.common.collect.ImmutableList;

import org.apache.calcite.sql.SqlCallBinding;
@@ -51,14 +53,16 @@ private static class OperandMetadataImpl extends AbstractOperandMetadata {

@Override
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
if (!checkTableAndDescriptorOperands(callBinding, 1)) {
return throwValidationSignatureErrorOrReturnFalse(callBinding, throwOnFailure);
if (!SqlValidatorUtils.checkTableAndDescriptorOperands(callBinding, 1)) {
return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse(
callBinding, throwOnFailure);
}
if (!checkIntervalOperands(callBinding, 2)) {
return throwValidationSignatureErrorOrReturnFalse(callBinding, throwOnFailure);
return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse(
callBinding, throwOnFailure);
}
// check time attribute
return throwExceptionOrReturnFalse(
return SqlValidatorUtils.throwExceptionOrReturnFalse(
checkTimeColumnDescriptorOperand(callBinding, 1), throwOnFailure);
}

Original file line number Diff line number Diff line change
@@ -17,6 +17,8 @@

package org.apache.flink.table.planner.functions.sql;

import org.apache.flink.table.planner.functions.utils.SqlValidatorUtils;

import org.apache.flink.shaded.guava33.com.google.common.collect.ImmutableList;

import org.apache.calcite.sql.SqlCallBinding;
@@ -51,14 +53,16 @@ private static class OperandMetadataImpl extends AbstractOperandMetadata {

@Override
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
if (!checkTableAndDescriptorOperands(callBinding, 1)) {
return throwValidationSignatureErrorOrReturnFalse(callBinding, throwOnFailure);
if (!SqlValidatorUtils.checkTableAndDescriptorOperands(callBinding, 1)) {
return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse(
callBinding, throwOnFailure);
}
if (!checkIntervalOperands(callBinding, 2)) {
return throwValidationSignatureErrorOrReturnFalse(callBinding, throwOnFailure);
return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse(
callBinding, throwOnFailure);
}
// check time attribute
return throwExceptionOrReturnFalse(
return SqlValidatorUtils.throwExceptionOrReturnFalse(
checkTimeColumnDescriptorOperand(callBinding, 1), throwOnFailure);
}

Original file line number Diff line number Diff line change
@@ -17,6 +17,8 @@

package org.apache.flink.table.planner.functions.sql;

import org.apache.flink.table.planner.functions.utils.SqlValidatorUtils;

import org.apache.flink.shaded.guava33.com.google.common.collect.ImmutableList;
import org.apache.flink.shaded.guava33.com.google.common.collect.ImmutableMap;

@@ -73,18 +75,20 @@ private static class OperandMetadataImpl extends AbstractOperandMetadata {

@Override
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
if (!checkTableAndDescriptorOperands(callBinding, 1)) {
return throwValidationSignatureErrorOrReturnFalse(callBinding, throwOnFailure);
if (!SqlValidatorUtils.checkTableAndDescriptorOperands(callBinding, 1)) {
return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse(
callBinding, throwOnFailure);
}

final SqlValidator validator = callBinding.getValidator();
final SqlNode operand2 = callBinding.operand(2);
final RelDataType type2 = validator.getValidatedNodeType(operand2);
if (!SqlTypeUtil.isInterval(type2)) {
return throwValidationSignatureErrorOrReturnFalse(callBinding, throwOnFailure);
return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse(
callBinding, throwOnFailure);
}

return throwExceptionOrReturnFalse(
return SqlValidatorUtils.throwExceptionOrReturnFalse(
checkTimeColumnDescriptorOperand(callBinding, 1), throwOnFailure);
}

Original file line number Diff line number Diff line change
@@ -17,6 +17,8 @@

package org.apache.flink.table.planner.functions.sql;

import org.apache.flink.table.planner.functions.utils.SqlValidatorUtils;

import org.apache.flink.shaded.guava33.com.google.common.collect.ImmutableList;

import org.apache.calcite.sql.SqlCallBinding;
@@ -49,14 +51,16 @@ private static class OperandMetadataImpl extends AbstractOperandMetadata {
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
// There should only be three operands, and number of operands are checked before
// this call.
if (!checkTableAndDescriptorOperands(callBinding, 1)) {
return throwValidationSignatureErrorOrReturnFalse(callBinding, throwOnFailure);
if (!SqlValidatorUtils.checkTableAndDescriptorOperands(callBinding, 1)) {
return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse(
callBinding, throwOnFailure);
}
if (!checkIntervalOperands(callBinding, 2)) {
return throwValidationSignatureErrorOrReturnFalse(callBinding, throwOnFailure);
return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse(
callBinding, throwOnFailure);
}
// check time attribute
return throwExceptionOrReturnFalse(
return SqlValidatorUtils.throwExceptionOrReturnFalse(
checkTimeColumnDescriptorOperand(callBinding, 1), throwOnFailure);
}

Original file line number Diff line number Diff line change
@@ -34,7 +34,6 @@
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperandCountRange;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.SqlUtil;
import org.apache.calcite.sql.type.SqlOperandCountRanges;
import org.apache.calcite.sql.type.SqlOperandMetadata;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
@@ -48,7 +47,6 @@
import java.util.List;
import java.util.Optional;

import static org.apache.calcite.util.Static.RESOURCE;
import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.canBeTimeAttributeType;

/**
@@ -194,56 +192,6 @@ public boolean isOptional(int i) {
return i > getOperandCountRange().getMin() && i <= getOperandCountRange().getMax();
}

boolean throwValidationSignatureErrorOrReturnFalse(
SqlCallBinding callBinding, boolean throwOnFailure) {
if (throwOnFailure) {
throw callBinding.newValidationSignatureError();
} else {
return false;
}
}

@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
boolean throwExceptionOrReturnFalse(Optional<RuntimeException> e, boolean throwOnFailure) {
if (e.isPresent()) {
if (throwOnFailure) {
throw e.get();
} else {
return false;
}
} else {
return true;
}
}

/**
* Checks whether the heading operands are in the form {@code (ROW, DESCRIPTOR, DESCRIPTOR
* ..., other params)}, returning whether successful, and throwing if any columns are not
* found.
*
* @param callBinding The call binding
* @param descriptorCount The number of descriptors following the first operand (e.g. the
* table)
* @return true if validation passes; throws if any columns are not found
*/
boolean checkTableAndDescriptorOperands(SqlCallBinding callBinding, int descriptorCount) {
final SqlNode operand0 = callBinding.operand(0);
final SqlValidator validator = callBinding.getValidator();
final RelDataType type = validator.getValidatedNodeType(operand0);
if (type.getSqlTypeName() != SqlTypeName.ROW) {
return false;
}
for (int i = 1; i < descriptorCount + 1; i++) {
final SqlNode operand = callBinding.operand(i);
if (operand.getKind() != SqlKind.DESCRIPTOR) {
return false;
}
validateColumnNames(
validator, type.getFieldNames(), ((SqlCall) operand).getOperandList());
}
return true;
}

/**
* Checks whether the type that the operand of time col descriptor refers to is valid.
*
@@ -310,17 +258,5 @@ boolean checkIntervalOperands(SqlCallBinding callBinding, int startPos) {
}
return true;
}

void validateColumnNames(
SqlValidator validator, List<String> fieldNames, List<SqlNode> columnNames) {
final SqlNameMatcher matcher = validator.getCatalogReader().nameMatcher();
for (SqlNode columnName : columnNames) {
final String name = ((SqlIdentifier) columnName).getSimple();
if (matcher.indexOf(fieldNames, name) < 0) {
throw SqlUtil.newContextException(
columnName.getParserPosition(), RESOURCE.unknownIdentifier(name));
}
}
}
}
}
Original file line number Diff line number Diff line change
@@ -17,18 +17,35 @@

package org.apache.flink.table.planner.functions.sql.ml;

import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.planner.functions.utils.SqlValidatorUtils;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.utils.LogicalTypeCasts;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlCharStringLiteral;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlModelCall;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperandCountRange;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.type.SqlOperandCountRanges;
import org.apache.calcite.sql.type.SqlOperandMetadata;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.validate.SqlNameMatcher;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.util.Util;

import java.util.Collections;
import java.util.List;
import java.util.Optional;

import static org.apache.flink.table.planner.calcite.FlinkTypeFactory.toLogicalType;

/**
* SqlMlPredictTableFunction implements an operator for prediction.
@@ -51,8 +68,8 @@ public SqlMLPredictTableFunction() {
/**
* {@inheritDoc}
*
* <p>Overrides because the first parameter of table-value function windowing is an explicit
* TABLE parameter, which is not scalar.
* <p>Overrides because the first parameter of ML table-value function is an explicit TABLE
* parameter, which is not scalar.
*/
@Override
public boolean argumentMustBeScalar(int ordinal) {
@@ -61,9 +78,18 @@ public boolean argumentMustBeScalar(int ordinal) {

@Override
protected RelDataType inferRowType(SqlOperatorBinding opBinding) {
// TODO: FLINK-37780 output type based on table schema and model output schema
// model output schema to be available after integrated with SqlExplicitModelCall
return opBinding.getOperandType(1);
final RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
final RelDataType inputRowType = opBinding.getOperandType(0);
final RelDataType modelOutputRowType = opBinding.getOperandType(1);

return typeFactory
.builder()
.kind(inputRowType.getStructKind())
.addAll(inputRowType.getFieldList())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Take a look at SystemOutputStrategy#inferType.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean we need to make field names unique? I'm following SqlWindowTableFunction which doesn't check if input table column has window_start etc. I'm on the fence here

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Otherwise, you will get an error here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.addAll(
SqlValidatorUtils.makeOutputUnique(
inputRowType.getFieldList(), modelOutputRowType.getFieldList()))
.build();
}

private static class PredictOperandMetadata implements SqlOperandMetadata {
@@ -87,21 +113,25 @@ public List<String> paramNames() {

@Override
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
// TODO: FLINK-37780 Check operand types after integrated with SqlExplicitModelCall in
// validator
return false;
if (!SqlValidatorUtils.checkTableAndDescriptorOperands(callBinding, 2)) {
return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse(
callBinding, throwOnFailure);
}

if (!SqlValidatorUtils.throwExceptionOrReturnFalse(
checkModelSignature(callBinding), throwOnFailure)) {
return false;
}

return SqlValidatorUtils.throwExceptionOrReturnFalse(
checkConfig(callBinding), throwOnFailure);
}

@Override
public SqlOperandCountRange getOperandCountRange() {
    // The trailing config map is optional, so the arity spans from the
    // mandatory parameter count up to the full parameter list size.
    final int minOperands = MANDATORY_PARAM_NAMES.size();
    final int maxOperands = PARAM_NAMES.size();
    return SqlOperandCountRanges.between(minOperands, maxOperands);
}

@Override
public Consistency getConsistency() {
    // NONE: operand types are not coerced to be consistent with each other.
    return Consistency.NONE;
}

@Override
public boolean isOptional(int i) {
return i > getOperandCountRange().getMin() && i <= getOperandCountRange().getMax();
@@ -110,7 +140,94 @@ public boolean isOptional(int i) {
@Override
public String getAllowedSignatures(SqlOperator op, String opName) {
return opName
+ "(TABLE table_name, MODEL model_name, DESCRIPTOR(input_columns), [MAP[]]";
+ "(TABLE table_name, MODEL model_name, DESCRIPTOR(input_columns), [MAP[]])";
}

/**
 * Validates that the model operand and the descriptor columns are compatible: the second
 * operand must be a model call, the descriptor must supply exactly one column per model
 * input field, and each column type must be implicitly castable to the corresponding model
 * input type.
 *
 * @param callBinding the call binding of the ML_PREDICT invocation
 * @return an exception describing the mismatch, or {@link Optional#empty()} when valid
 */
private static Optional<RuntimeException> checkModelSignature(SqlCallBinding callBinding) {
    SqlValidator validator = callBinding.getValidator();

    // Check second operand is SqlModelCall
    if (!(callBinding.operand(1) instanceof SqlModelCall)) {
        return Optional.of(
                new ValidationException("Second operand must be a model identifier."));
    }

    // Get descriptor columns
    SqlCall descriptorCall = (SqlCall) callBinding.operand(2);
    List<SqlNode> descriptCols = descriptorCall.getOperandList();

    // Get model input size
    SqlModelCall modelCall = (SqlModelCall) callBinding.operand(1);
    RelDataType modelInputType = modelCall.getInputType(validator);

    // Check sizes match
    if (descriptCols.size() != modelInputType.getFieldCount()) {
        return Optional.of(
                new ValidationException(
                        String.format(
                                "Number of descriptor input columns (%d) does not match model input size (%d)",
                                descriptCols.size(), modelInputType.getFieldCount())));
    }

    // Check types match
    final RelDataType tableType = validator.getValidatedNodeType(callBinding.operand(0));
    final SqlNameMatcher matcher = validator.getCatalogReader().nameMatcher();
    for (int i = 0; i < descriptCols.size(); i++) {
        SqlIdentifier columnName = (SqlIdentifier) descriptCols.get(i);
        String descriptColName =
                columnName.isSimple()
                        ? columnName.getSimple()
                        : Util.last(columnName.names);
        int index = matcher.indexOf(tableType.getFieldNames(), descriptColName);
        // Guard against an unresolved column instead of letting getFieldList().get(-1)
        // throw a raw IndexOutOfBoundsException. Normally unreachable because descriptor
        // columns are validated against the table beforehand, but kept defensive.
        if (index < 0) {
            return Optional.of(
                    new ValidationException(
                            String.format(
                                    "Descriptor column %s not found in input table at position %d",
                                    descriptColName, i)));
        }
        RelDataType sourceType = tableType.getFieldList().get(index).getType();
        RelDataType targetType = modelInputType.getFieldList().get(i).getType();

        LogicalType sourceLogicalType = toLogicalType(sourceType);
        LogicalType targetLogicalType = toLogicalType(targetType);

        if (!LogicalTypeCasts.supportsImplicitCast(sourceLogicalType, targetLogicalType)) {
            return Optional.of(
                    new ValidationException(
                            String.format(
                                    "Descriptor column type %s cannot be assigned to model input type %s at position %d",
                                    sourceLogicalType, targetLogicalType, i)));
        }
    }

    return Optional.empty();
}

/**
 * Validates the optional config map operand: when present it must be a MAP constructor whose
 * entries are string literals (possibly wrapped in a CAST of a string literal).
 *
 * @param callBinding the call binding of the ML_PREDICT invocation
 * @return an exception describing the problem, or {@link Optional#empty()} when valid
 */
private static Optional<RuntimeException> checkConfig(SqlCallBinding callBinding) {
    // The config map is the optional last operand; nothing to check when absent.
    if (callBinding.getOperandCount() < PARAM_NAMES.size()) {
        return Optional.empty();
    }

    final SqlNode configNode = callBinding.operand(3);
    if (!configNode.getKind().equals(SqlKind.MAP_VALUE_CONSTRUCTOR)) {
        return Optional.of(new ValidationException("Config param should be a MAP."));
    }

    // Every key and value must be a character string literal, or a CAST whose
    // inner operand is one.
    final SqlCall mapCall = (SqlCall) configNode;
    for (int i = 0; i < mapCall.operandCount(); i++) {
        final SqlNode operand = mapCall.operand(i);
        final boolean isStringLiteral = operand instanceof SqlCharStringLiteral;
        final boolean isCastOfStringLiteral =
                operand.getKind().equals(SqlKind.CAST)
                        && ((SqlCall) operand).operand(0) instanceof SqlCharStringLiteral;
        if (!isStringLiteral && !isCastOfStringLiteral) {
            return Optional.of(
                    new ValidationException(
                            String.format(
                                    "ML_PREDICT config param can only be a MAP of string literals. The item at position %d is %s.",
                                    i, operand)));
        }
    }

    return Optional.empty();
}
}
}
Original file line number Diff line number Diff line change
@@ -26,7 +26,6 @@
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.SqlTableFunction;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlOperandMetadata;
@@ -70,18 +69,14 @@ public void validateCall(
final List<SqlNode> operandList = call.getOperandList();

// ML table function should take only one table as input and use descriptor to reference
// columns in the table. The scope for descriptor validation should be the input table.
// Since the input table will be rewritten as select query. We get the select query's scope
// and use it for descriptor validation.
SqlValidatorScope selectScope = null;
// columns in the table. The scope for descriptor validation should be the input table which
// is also an operand of the call. We defer the validation of the descriptor since
// validation here will qualify the descriptor columns to NOT be simple names, which
// complicates checks in later stages. We validate the descriptor columns appear in table
// column in SqlOperandMetadata.
boolean foundSelect = false;
for (SqlNode operand : operandList) {
if (operand.getKind().equals(SqlKind.DESCRIPTOR)) {
if (selectScope == null) {
throw new ValidationException(TABLE_INPUT_ERROR);
}
// Set scope to table when validating descriptor columns
operand.validate(validator, selectScope);
continue;
}
if (operand.getKind().equals(SqlKind.SET_SEMANTICS_TABLE)) {
@@ -90,15 +85,13 @@ public void validateCall(
throw new ValidationException(TABLE_INPUT_ERROR);
}
foundSelect = true;
selectScope = validator.getSelectScope((SqlSelect) operand);
}

if (operand.getKind().equals(SqlKind.SELECT)) {
if (foundSelect) {
throw new ValidationException(TABLE_INPUT_ERROR);
}
foundSelect = true;
selectScope = validator.getSelectScope((SqlSelect) operand);
}
operand.validate(validator, scope);
}
Original file line number Diff line number Diff line change
@@ -19,16 +19,30 @@
package org.apache.flink.table.planner.functions.utils;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rel.type.RelDataTypeFieldImpl;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.SqlUtil;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.sql.validate.SqlNameMatcher;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.util.Pair;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;

import static org.apache.calcite.util.Static.RESOURCE;

/** Utility methods related to SQL validation. */
public class SqlValidatorUtils {
@@ -49,6 +63,74 @@ public static void adjustTypeForMapConstructor(
}
}

/**
 * Reports a failed operand check either by throwing the standard validation signature error
 * (when {@code throwOnFailure} is set) or by returning {@code false}.
 *
 * @param callBinding the binding whose signature error is raised
 * @param throwOnFailure whether to throw instead of returning {@code false}
 * @return {@code false} when not throwing
 */
public static boolean throwValidationSignatureErrorOrReturnFalse(
        SqlCallBinding callBinding, boolean throwOnFailure) {
    if (!throwOnFailure) {
        return false;
    }
    throw callBinding.newValidationSignatureError();
}

/**
 * Converts an optional failure into the validator's boolean/throw contract: an empty optional
 * means success ({@code true}); a present exception is thrown when {@code throwOnFailure} is
 * set, otherwise {@code false} is returned.
 *
 * @param e the failure, if any
 * @param throwOnFailure whether to throw the failure instead of returning {@code false}
 * @return {@code true} on success, {@code false} on non-throwing failure
 */
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public static boolean throwExceptionOrReturnFalse(
        Optional<RuntimeException> e, boolean throwOnFailure) {
    if (!e.isPresent()) {
        return true;
    }
    if (throwOnFailure) {
        throw e.get();
    }
    return false;
}

/**
 * Checks whether the heading operands are in the form {@code (ROW, DESCRIPTOR, DESCRIPTOR ...,
 * other params)}, returning whether successful, and throwing if any columns are not found.
 *
 * @param callBinding The call binding
 * @param descriptorLocations position of the descriptor operands
 * @return true if validation passes; throws if any columns are not found
 */
public static boolean checkTableAndDescriptorOperands(
        SqlCallBinding callBinding, Integer... descriptorLocations) {
    final SqlValidator validator = callBinding.getValidator();
    final RelDataType tableType = validator.getValidatedNodeType(callBinding.operand(0));
    // The first operand must be a table, i.e. validate to a ROW type.
    if (tableType.getSqlTypeName() != SqlTypeName.ROW) {
        return false;
    }
    for (final Integer position : descriptorLocations) {
        final SqlNode descriptor = callBinding.operand(position);
        if (descriptor.getKind() != SqlKind.DESCRIPTOR) {
            return false;
        }
        // Throws when a referenced column is missing or not a simple identifier.
        validateColumnNames(
                validator, tableType.getFieldNames(), ((SqlCall) descriptor).getOperandList());
    }
    return true;
}

/**
 * Verifies that every descriptor column is a simple identifier naming an existing table field;
 * throws a context exception otherwise.
 *
 * @param validator the validator supplying the catalog's name matcher
 * @param fieldNames field names of the input table
 * @param columnNames descriptor column identifiers to check
 */
private static void validateColumnNames(
        SqlValidator validator, List<String> fieldNames, List<SqlNode> columnNames) {
    final SqlNameMatcher matcher = validator.getCatalogReader().nameMatcher();
    for (final SqlNode columnName : columnNames) {
        final SqlIdentifier identifier = (SqlIdentifier) columnName;
        // Qualified names are rejected so later stages can rely on simple names.
        if (!identifier.isSimple()) {
            throw SqlUtil.newContextException(
                    columnName.getParserPosition(), RESOURCE.aliasMustBeSimpleIdentifier());
        }
        final String simpleName = identifier.getSimple();
        if (matcher.indexOf(fieldNames, simpleName) < 0) {
            throw SqlUtil.newContextException(
                    columnName.getParserPosition(), RESOURCE.unknownIdentifier(simpleName));
        }
    }
}

/**
* When the element type does not equal the component type, make an explicit cast.
*
@@ -75,6 +157,34 @@ private static void adjustTypeForMultisetConstructor(
}
}

/**
 * Make output field names unique from input field names by appending a numeric suffix. For
 * example, input with field names {@code a, b, c} and output with field names {@code b, c, d}
 * produces new output field names {@code b0, c0, d}. The suffix is incremented until the
 * candidate name no longer clashes with an input name, so an input that already contains
 * {@code b0} cannot be shadowed. Duplicate names are not checked inside input and output
 * itself.
 *
 * @param input Input fields whose names are preserved as-is
 * @param output Output fields to be renamed on a clash
 * @return output fields with names unique from the input (indices and types unchanged)
 */
public static List<RelDataTypeField> makeOutputUnique(
        List<RelDataTypeField> input, List<RelDataTypeField> output) {
    final Set<String> inputFieldNames = new HashSet<>();
    for (RelDataTypeField field : input) {
        inputFieldNames.add(field.getName());
    }

    List<RelDataTypeField> result = new ArrayList<>(output.size());
    for (RelDataTypeField field : output) {
        String fieldName = field.getName();
        if (inputFieldNames.contains(fieldName)) {
            // Append an increasing index until the name no longer collides with any
            // input name; a fixed "0" suffix could itself clash (e.g. input has both
            // "b" and "b0").
            int suffix = 0;
            String candidate = fieldName + suffix;
            while (inputFieldNames.contains(candidate)) {
                candidate = fieldName + (++suffix);
            }
            fieldName = candidate;
        }
        result.add(new RelDataTypeFieldImpl(fieldName, field.getIndex(), field.getType()));
    }
    return result;
}

private static SqlNode castTo(SqlNode node, RelDataType type) {
return SqlStdOperatorTable.CAST.createCall(
SqlParserPos.ZERO,

Large diffs are not rendered by default.

This file was deleted.