Skip to content

Commit c5b378f

Browse files
committed
[FLINK-37780] predict sql function type inference and validation
1 parent cc11668 commit c5b378f

File tree

9 files changed

+600
-220
lines changed

9 files changed

+600
-220
lines changed

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlCumulateTableFunction.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.flink.table.planner.functions.sql;
1919

20+
import org.apache.flink.table.planner.functions.utils.SqlValidatorUtils;
21+
2022
import org.apache.flink.shaded.guava33.com.google.common.collect.ImmutableList;
2123

2224
import org.apache.calcite.sql.SqlCallBinding;
@@ -51,14 +53,16 @@ private static class OperandMetadataImpl extends AbstractOperandMetadata {
5153

5254
@Override
5355
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
54-
if (!checkTableAndDescriptorOperands(callBinding, 1)) {
55-
return throwValidationSignatureErrorOrReturnFalse(callBinding, throwOnFailure);
56+
if (!SqlValidatorUtils.checkTableAndDescriptorOperands(callBinding, 1, 1)) {
57+
return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse(
58+
callBinding, throwOnFailure);
5659
}
5760
if (!checkIntervalOperands(callBinding, 2)) {
58-
return throwValidationSignatureErrorOrReturnFalse(callBinding, throwOnFailure);
61+
return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse(
62+
callBinding, throwOnFailure);
5963
}
6064
// check time attribute
61-
return throwExceptionOrReturnFalse(
65+
return SqlValidatorUtils.throwExceptionOrReturnFalse(
6266
checkTimeColumnDescriptorOperand(callBinding, 1), throwOnFailure);
6367
}
6468

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlHopTableFunction.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.flink.table.planner.functions.sql;
1919

20+
import org.apache.flink.table.planner.functions.utils.SqlValidatorUtils;
21+
2022
import org.apache.flink.shaded.guava33.com.google.common.collect.ImmutableList;
2123

2224
import org.apache.calcite.sql.SqlCallBinding;
@@ -51,14 +53,16 @@ private static class OperandMetadataImpl extends AbstractOperandMetadata {
5153

5254
@Override
5355
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
54-
if (!checkTableAndDescriptorOperands(callBinding, 1)) {
55-
return throwValidationSignatureErrorOrReturnFalse(callBinding, throwOnFailure);
56+
if (!SqlValidatorUtils.checkTableAndDescriptorOperands(callBinding, 1, 1)) {
57+
return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse(
58+
callBinding, throwOnFailure);
5659
}
5760
if (!checkIntervalOperands(callBinding, 2)) {
58-
return throwValidationSignatureErrorOrReturnFalse(callBinding, throwOnFailure);
61+
return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse(
62+
callBinding, throwOnFailure);
5963
}
6064
// check time attribute
61-
return throwExceptionOrReturnFalse(
65+
return SqlValidatorUtils.throwExceptionOrReturnFalse(
6266
checkTimeColumnDescriptorOperand(callBinding, 1), throwOnFailure);
6367
}
6468

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlSessionTableFunction.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.flink.table.planner.functions.sql;
1919

20+
import org.apache.flink.table.planner.functions.utils.SqlValidatorUtils;
21+
2022
import org.apache.flink.shaded.guava33.com.google.common.collect.ImmutableList;
2123
import org.apache.flink.shaded.guava33.com.google.common.collect.ImmutableMap;
2224

@@ -73,18 +75,20 @@ private static class OperandMetadataImpl extends AbstractOperandMetadata {
7375

7476
@Override
7577
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
76-
if (!checkTableAndDescriptorOperands(callBinding, 1)) {
77-
return throwValidationSignatureErrorOrReturnFalse(callBinding, throwOnFailure);
78+
if (!SqlValidatorUtils.checkTableAndDescriptorOperands(callBinding, 1, 1)) {
79+
return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse(
80+
callBinding, throwOnFailure);
7881
}
7982

8083
final SqlValidator validator = callBinding.getValidator();
8184
final SqlNode operand2 = callBinding.operand(2);
8285
final RelDataType type2 = validator.getValidatedNodeType(operand2);
8386
if (!SqlTypeUtil.isInterval(type2)) {
84-
return throwValidationSignatureErrorOrReturnFalse(callBinding, throwOnFailure);
87+
return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse(
88+
callBinding, throwOnFailure);
8589
}
8690

87-
return throwExceptionOrReturnFalse(
91+
return SqlValidatorUtils.throwExceptionOrReturnFalse(
8892
checkTimeColumnDescriptorOperand(callBinding, 1), throwOnFailure);
8993
}
9094

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlTumbleTableFunction.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.flink.table.planner.functions.sql;
1919

20+
import org.apache.flink.table.planner.functions.utils.SqlValidatorUtils;
21+
2022
import org.apache.flink.shaded.guava33.com.google.common.collect.ImmutableList;
2123

2224
import org.apache.calcite.sql.SqlCallBinding;
@@ -49,14 +51,16 @@ private static class OperandMetadataImpl extends AbstractOperandMetadata {
4951
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
5052
// There should only be three operands, and number of operands are checked before
5153
// this call.
52-
if (!checkTableAndDescriptorOperands(callBinding, 1)) {
53-
return throwValidationSignatureErrorOrReturnFalse(callBinding, throwOnFailure);
54+
if (!SqlValidatorUtils.checkTableAndDescriptorOperands(callBinding, 1, 1)) {
55+
return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse(
56+
callBinding, throwOnFailure);
5457
}
5558
if (!checkIntervalOperands(callBinding, 2)) {
56-
return throwValidationSignatureErrorOrReturnFalse(callBinding, throwOnFailure);
59+
return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse(
60+
callBinding, throwOnFailure);
5761
}
5862
// check time attribute
59-
return throwExceptionOrReturnFalse(
63+
return SqlValidatorUtils.throwExceptionOrReturnFalse(
6064
checkTimeColumnDescriptorOperand(callBinding, 1), throwOnFailure);
6165
}
6266

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlWindowTableFunction.java

Lines changed: 0 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
import org.apache.calcite.sql.SqlNode;
3535
import org.apache.calcite.sql.SqlOperandCountRange;
3636
import org.apache.calcite.sql.SqlOperatorBinding;
37-
import org.apache.calcite.sql.SqlUtil;
3837
import org.apache.calcite.sql.type.SqlOperandCountRanges;
3938
import org.apache.calcite.sql.type.SqlOperandMetadata;
4039
import org.apache.calcite.sql.type.SqlReturnTypeInference;
@@ -48,7 +47,6 @@
4847
import java.util.List;
4948
import java.util.Optional;
5049

51-
import static org.apache.calcite.util.Static.RESOURCE;
5250
import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.canBeTimeAttributeType;
5351

5452
/**
@@ -194,56 +192,6 @@ public boolean isOptional(int i) {
194192
return i > getOperandCountRange().getMin() && i <= getOperandCountRange().getMax();
195193
}
196194

197-
boolean throwValidationSignatureErrorOrReturnFalse(
198-
SqlCallBinding callBinding, boolean throwOnFailure) {
199-
if (throwOnFailure) {
200-
throw callBinding.newValidationSignatureError();
201-
} else {
202-
return false;
203-
}
204-
}
205-
206-
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
207-
boolean throwExceptionOrReturnFalse(Optional<RuntimeException> e, boolean throwOnFailure) {
208-
if (e.isPresent()) {
209-
if (throwOnFailure) {
210-
throw e.get();
211-
} else {
212-
return false;
213-
}
214-
} else {
215-
return true;
216-
}
217-
}
218-
219-
/**
220-
* Checks whether the heading operands are in the form {@code (ROW, DESCRIPTOR, DESCRIPTOR
221-
* ..., other params)}, returning whether successful, and throwing if any columns are not
222-
* found.
223-
*
224-
* @param callBinding The call binding
225-
* @param descriptorCount The number of descriptors following the first operand (e.g. the
226-
* table)
227-
* @return true if validation passes; throws if any columns are not found
228-
*/
229-
boolean checkTableAndDescriptorOperands(SqlCallBinding callBinding, int descriptorCount) {
230-
final SqlNode operand0 = callBinding.operand(0);
231-
final SqlValidator validator = callBinding.getValidator();
232-
final RelDataType type = validator.getValidatedNodeType(operand0);
233-
if (type.getSqlTypeName() != SqlTypeName.ROW) {
234-
return false;
235-
}
236-
for (int i = 1; i < descriptorCount + 1; i++) {
237-
final SqlNode operand = callBinding.operand(i);
238-
if (operand.getKind() != SqlKind.DESCRIPTOR) {
239-
return false;
240-
}
241-
validateColumnNames(
242-
validator, type.getFieldNames(), ((SqlCall) operand).getOperandList());
243-
}
244-
return true;
245-
}
246-
247195
/**
248196
* Checks whether the type that the operand of time col descriptor refers to is valid.
249197
*
@@ -310,17 +258,5 @@ boolean checkIntervalOperands(SqlCallBinding callBinding, int startPos) {
310258
}
311259
return true;
312260
}
313-
314-
void validateColumnNames(
315-
SqlValidator validator, List<String> fieldNames, List<SqlNode> columnNames) {
316-
final SqlNameMatcher matcher = validator.getCatalogReader().nameMatcher();
317-
for (SqlNode columnName : columnNames) {
318-
final String name = ((SqlIdentifier) columnName).getSimple();
319-
if (matcher.indexOf(fieldNames, name) < 0) {
320-
throw SqlUtil.newContextException(
321-
columnName.getParserPosition(), RESOURCE.unknownIdentifier(name));
322-
}
323-
}
324-
}
325261
}
326262
}

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlMLPredictTableFunction.java

Lines changed: 126 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,35 @@
1717

1818
package org.apache.flink.table.planner.functions.sql.ml;
1919

20+
import org.apache.flink.table.api.ValidationException;
21+
import org.apache.flink.table.planner.functions.utils.SqlValidatorUtils;
22+
import org.apache.flink.table.types.logical.LogicalType;
23+
import org.apache.flink.table.types.logical.utils.LogicalTypeCasts;
24+
2025
import org.apache.calcite.rel.type.RelDataType;
2126
import org.apache.calcite.rel.type.RelDataTypeFactory;
27+
import org.apache.calcite.sql.SqlCall;
2228
import org.apache.calcite.sql.SqlCallBinding;
29+
import org.apache.calcite.sql.SqlCharStringLiteral;
30+
import org.apache.calcite.sql.SqlIdentifier;
31+
import org.apache.calcite.sql.SqlKind;
32+
import org.apache.calcite.sql.SqlModelCall;
33+
import org.apache.calcite.sql.SqlNode;
2334
import org.apache.calcite.sql.SqlOperandCountRange;
2435
import org.apache.calcite.sql.SqlOperator;
2536
import org.apache.calcite.sql.SqlOperatorBinding;
2637
import org.apache.calcite.sql.type.SqlOperandCountRanges;
2738
import org.apache.calcite.sql.type.SqlOperandMetadata;
2839
import org.apache.calcite.sql.type.SqlTypeName;
40+
import org.apache.calcite.sql.validate.SqlNameMatcher;
41+
import org.apache.calcite.sql.validate.SqlValidator;
42+
import org.apache.calcite.util.Util;
2943

3044
import java.util.Collections;
3145
import java.util.List;
46+
import java.util.Optional;
47+
48+
import static org.apache.flink.table.planner.calcite.FlinkTypeFactory.toLogicalType;
3249

3350
/**
3451
* SqlMlPredictTableFunction implements an operator for prediction.
@@ -61,9 +78,16 @@ public boolean argumentMustBeScalar(int ordinal) {
6178

6279
@Override
6380
protected RelDataType inferRowType(SqlOperatorBinding opBinding) {
64-
// TODO: FLINK-37780 output type based on table schema and model output schema
65-
// model output schema to be available after integrated with SqlExplicitModelCall
66-
return opBinding.getOperandType(1);
81+
final RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
82+
final RelDataType inputRowType = opBinding.getOperandType(0);
83+
final RelDataType modelOutputRowType = opBinding.getOperandType(1);
84+
85+
return typeFactory
86+
.builder()
87+
.kind(inputRowType.getStructKind())
88+
.addAll(inputRowType.getFieldList())
89+
.addAll(modelOutputRowType.getFieldList())
90+
.build();
6791
}
6892

6993
private static class PredictOperandMetadata implements SqlOperandMetadata {
@@ -87,21 +111,25 @@ public List<String> paramNames() {
87111

88112
@Override
89113
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
90-
// TODO: FLINK-37780 Check operand types after integrated with SqlExplicitModelCall in
91-
// validator
92-
return false;
114+
if (!SqlValidatorUtils.checkTableAndDescriptorOperands(callBinding, 2, 1)) {
115+
return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse(
116+
callBinding, throwOnFailure);
117+
}
118+
119+
if (!SqlValidatorUtils.throwExceptionOrReturnFalse(
120+
checkModelSignature(callBinding), throwOnFailure)) {
121+
return false;
122+
}
123+
124+
return SqlValidatorUtils.throwExceptionOrReturnFalse(
125+
checkConfig(callBinding), throwOnFailure);
93126
}
94127

95128
@Override
96129
public SqlOperandCountRange getOperandCountRange() {
97130
return SqlOperandCountRanges.between(MANDATORY_PARAM_NAMES.size(), PARAM_NAMES.size());
98131
}
99132

100-
@Override
101-
public Consistency getConsistency() {
102-
return Consistency.NONE;
103-
}
104-
105133
@Override
106134
public boolean isOptional(int i) {
107135
return i > getOperandCountRange().getMin() && i <= getOperandCountRange().getMax();
@@ -112,5 +140,92 @@ public String getAllowedSignatures(SqlOperator op, String opName) {
112140
return opName
113141
+ "(TABLE table_name, MODEL model_name, DESCRIPTOR(input_columns), [MAP[]]";
114142
}
143+
144+
private static Optional<RuntimeException> checkModelSignature(SqlCallBinding callBinding) {
145+
SqlValidator validator = callBinding.getValidator();
146+
147+
// Check second operand is SqlModelCall
148+
if (!(callBinding.operand(1) instanceof SqlModelCall)) {
149+
return Optional.of(
150+
new ValidationException("Second operand must be a model identifier."));
151+
}
152+
153+
// Get descriptor columns
154+
SqlCall descriptorCall = (SqlCall) callBinding.operand(2);
155+
List<SqlNode> descriptCols = descriptorCall.getOperandList();
156+
157+
// Get model input size
158+
SqlModelCall modelCall = (SqlModelCall) callBinding.operand(1);
159+
RelDataType modelInputType = modelCall.getInputType(validator);
160+
161+
// Check sizes match
162+
if (descriptCols.size() != modelInputType.getFieldCount()) {
163+
return Optional.of(
164+
new ValidationException(
165+
String.format(
166+
"Number of descriptor input columns (%d) does not match model input size (%d)",
167+
descriptCols.size(), modelInputType.getFieldCount())));
168+
}
169+
170+
// Check types match
171+
final RelDataType tableType = validator.getValidatedNodeType(callBinding.operand(0));
172+
final SqlNameMatcher matcher = validator.getCatalogReader().nameMatcher();
173+
for (int i = 0; i < descriptCols.size(); i++) {
174+
SqlIdentifier columnName = (SqlIdentifier) descriptCols.get(i);
175+
String descriptColName =
176+
columnName.isSimple()
177+
? columnName.getSimple()
178+
: Util.last(columnName.names);
179+
int index = matcher.indexOf(tableType.getFieldNames(), descriptColName);
180+
RelDataType sourceType = tableType.getFieldList().get(index).getType();
181+
RelDataType targetType = modelInputType.getFieldList().get(i).getType();
182+
183+
LogicalType sourceLogicalType = toLogicalType(sourceType);
184+
LogicalType targetLogicalType = toLogicalType(targetType);
185+
186+
if (!LogicalTypeCasts.supportsImplicitCast(sourceLogicalType, targetLogicalType)) {
187+
return Optional.of(
188+
new ValidationException(
189+
String.format(
190+
"Descriptor column type %s cannot be assigned to model input type %s at position %d",
191+
sourceLogicalType, targetLogicalType, i)));
192+
}
193+
}
194+
195+
return Optional.empty();
196+
}
197+
198+
private static Optional<RuntimeException> checkConfig(SqlCallBinding callBinding) {
199+
if (callBinding.getOperandCount() < PARAM_NAMES.size()) {
200+
return Optional.empty();
201+
}
202+
203+
SqlNode configNode = callBinding.operand(3);
204+
if (!configNode.getKind().equals(SqlKind.MAP_VALUE_CONSTRUCTOR)) {
205+
return Optional.of(new ValidationException("Config param should be a MAP."));
206+
}
207+
208+
// Map operands can only be SqlCharStringLiteral or cast of SqlCharStringLiteral
209+
SqlCall mapCall = (SqlCall) configNode;
210+
for (int i = 0; i < mapCall.operandCount(); i++) {
211+
SqlNode operand = mapCall.operand(i);
212+
if (operand instanceof SqlCharStringLiteral) {
213+
continue;
214+
}
215+
if (operand.getKind().equals(SqlKind.CAST)) {
216+
SqlCall castCall = (SqlCall) operand;
217+
if (castCall.operand(0) instanceof SqlCharStringLiteral) {
218+
continue;
219+
}
220+
}
221+
return Optional.of(
222+
new ValidationException(
223+
String.format(
224+
"ML_PREDICT config param can only be a MAP of string literals. The item at position %d is %s.",
225+
i, operand)));
226+
}
227+
228+
return Optional.empty();
229+
}
115230
}
116231
}

0 commit comments

Comments
 (0)