17
17
18
18
package org .apache .flink .table .planner .functions .sql .ml ;
19
19
20
+ import org .apache .flink .table .api .ValidationException ;
21
+ import org .apache .flink .table .planner .functions .utils .SqlValidatorUtils ;
22
+ import org .apache .flink .table .types .logical .LogicalType ;
23
+ import org .apache .flink .table .types .logical .utils .LogicalTypeCasts ;
24
+
20
25
import org .apache .calcite .rel .type .RelDataType ;
21
26
import org .apache .calcite .rel .type .RelDataTypeFactory ;
27
+ import org .apache .calcite .sql .SqlCall ;
22
28
import org .apache .calcite .sql .SqlCallBinding ;
29
+ import org .apache .calcite .sql .SqlCharStringLiteral ;
30
+ import org .apache .calcite .sql .SqlIdentifier ;
31
+ import org .apache .calcite .sql .SqlKind ;
32
+ import org .apache .calcite .sql .SqlModelCall ;
33
+ import org .apache .calcite .sql .SqlNode ;
23
34
import org .apache .calcite .sql .SqlOperandCountRange ;
24
35
import org .apache .calcite .sql .SqlOperator ;
25
36
import org .apache .calcite .sql .SqlOperatorBinding ;
26
37
import org .apache .calcite .sql .type .SqlOperandCountRanges ;
27
38
import org .apache .calcite .sql .type .SqlOperandMetadata ;
28
39
import org .apache .calcite .sql .type .SqlTypeName ;
40
+ import org .apache .calcite .sql .validate .SqlNameMatcher ;
41
+ import org .apache .calcite .sql .validate .SqlValidator ;
42
+ import org .apache .calcite .util .Util ;
29
43
30
44
import java .util .Collections ;
31
45
import java .util .List ;
46
+ import java .util .Optional ;
47
+
48
+ import static org .apache .flink .table .planner .calcite .FlinkTypeFactory .toLogicalType ;
32
49
33
50
/**
34
51
* SqlMlPredictTableFunction implements an operator for prediction.
@@ -61,9 +78,16 @@ public boolean argumentMustBeScalar(int ordinal) {
61
78
62
79
@ Override
63
80
protected RelDataType inferRowType (SqlOperatorBinding opBinding ) {
64
- // TODO: FLINK-37780 output type based on table schema and model output schema
65
- // model output schema to be available after integrated with SqlExplicitModelCall
66
- return opBinding .getOperandType (1 );
81
+ final RelDataTypeFactory typeFactory = opBinding .getTypeFactory ();
82
+ final RelDataType inputRowType = opBinding .getOperandType (0 );
83
+ final RelDataType modelOutputRowType = opBinding .getOperandType (1 );
84
+
85
+ return typeFactory
86
+ .builder ()
87
+ .kind (inputRowType .getStructKind ())
88
+ .addAll (inputRowType .getFieldList ())
89
+ .addAll (modelOutputRowType .getFieldList ())
90
+ .build ();
67
91
}
68
92
69
93
private static class PredictOperandMetadata implements SqlOperandMetadata {
@@ -87,21 +111,25 @@ public List<String> paramNames() {
87
111
88
112
@ Override
89
113
public boolean checkOperandTypes (SqlCallBinding callBinding , boolean throwOnFailure ) {
90
- // TODO: FLINK-37780 Check operand types after integrated with SqlExplicitModelCall in
91
- // validator
92
- return false ;
114
+ if (!SqlValidatorUtils .checkTableAndDescriptorOperands (callBinding , 2 , 1 )) {
115
+ return SqlValidatorUtils .throwValidationSignatureErrorOrReturnFalse (
116
+ callBinding , throwOnFailure );
117
+ }
118
+
119
+ if (!SqlValidatorUtils .throwExceptionOrReturnFalse (
120
+ checkModelSignature (callBinding ), throwOnFailure )) {
121
+ return false ;
122
+ }
123
+
124
+ return SqlValidatorUtils .throwExceptionOrReturnFalse (
125
+ checkConfig (callBinding ), throwOnFailure );
93
126
}
94
127
95
128
@ Override
96
129
public SqlOperandCountRange getOperandCountRange () {
97
130
return SqlOperandCountRanges .between (MANDATORY_PARAM_NAMES .size (), PARAM_NAMES .size ());
98
131
}
99
132
100
- @ Override
101
- public Consistency getConsistency () {
102
- return Consistency .NONE ;
103
- }
104
-
105
133
@ Override
106
134
public boolean isOptional (int i ) {
107
135
return i > getOperandCountRange ().getMin () && i <= getOperandCountRange ().getMax ();
@@ -112,5 +140,92 @@ public String getAllowedSignatures(SqlOperator op, String opName) {
112
140
return opName
113
141
+ "(TABLE table_name, MODEL model_name, DESCRIPTOR(input_columns), [MAP[]]" ;
114
142
}
143
+
144
+ private static Optional <RuntimeException > checkModelSignature (SqlCallBinding callBinding ) {
145
+ SqlValidator validator = callBinding .getValidator ();
146
+
147
+ // Check second operand is SqlModelCall
148
+ if (!(callBinding .operand (1 ) instanceof SqlModelCall )) {
149
+ return Optional .of (
150
+ new ValidationException ("Second operand must be a model identifier." ));
151
+ }
152
+
153
+ // Get descriptor columns
154
+ SqlCall descriptorCall = (SqlCall ) callBinding .operand (2 );
155
+ List <SqlNode > descriptCols = descriptorCall .getOperandList ();
156
+
157
+ // Get model input size
158
+ SqlModelCall modelCall = (SqlModelCall ) callBinding .operand (1 );
159
+ RelDataType modelInputType = modelCall .getInputType (validator );
160
+
161
+ // Check sizes match
162
+ if (descriptCols .size () != modelInputType .getFieldCount ()) {
163
+ return Optional .of (
164
+ new ValidationException (
165
+ String .format (
166
+ "Number of descriptor input columns (%d) does not match model input size (%d)" ,
167
+ descriptCols .size (), modelInputType .getFieldCount ())));
168
+ }
169
+
170
+ // Check types match
171
+ final RelDataType tableType = validator .getValidatedNodeType (callBinding .operand (0 ));
172
+ final SqlNameMatcher matcher = validator .getCatalogReader ().nameMatcher ();
173
+ for (int i = 0 ; i < descriptCols .size (); i ++) {
174
+ SqlIdentifier columnName = (SqlIdentifier ) descriptCols .get (i );
175
+ String descriptColName =
176
+ columnName .isSimple ()
177
+ ? columnName .getSimple ()
178
+ : Util .last (columnName .names );
179
+ int index = matcher .indexOf (tableType .getFieldNames (), descriptColName );
180
+ RelDataType sourceType = tableType .getFieldList ().get (index ).getType ();
181
+ RelDataType targetType = modelInputType .getFieldList ().get (i ).getType ();
182
+
183
+ LogicalType sourceLogicalType = toLogicalType (sourceType );
184
+ LogicalType targetLogicalType = toLogicalType (targetType );
185
+
186
+ if (!LogicalTypeCasts .supportsImplicitCast (sourceLogicalType , targetLogicalType )) {
187
+ return Optional .of (
188
+ new ValidationException (
189
+ String .format (
190
+ "Descriptor column type %s cannot be assigned to model input type %s at position %d" ,
191
+ sourceLogicalType , targetLogicalType , i )));
192
+ }
193
+ }
194
+
195
+ return Optional .empty ();
196
+ }
197
+
198
+ private static Optional <RuntimeException > checkConfig (SqlCallBinding callBinding ) {
199
+ if (callBinding .getOperandCount () < PARAM_NAMES .size ()) {
200
+ return Optional .empty ();
201
+ }
202
+
203
+ SqlNode configNode = callBinding .operand (3 );
204
+ if (!configNode .getKind ().equals (SqlKind .MAP_VALUE_CONSTRUCTOR )) {
205
+ return Optional .of (new ValidationException ("Config param should be a MAP." ));
206
+ }
207
+
208
+ // Map operands can only be SqlCharStringLiteral or cast of SqlCharStringLiteral
209
+ SqlCall mapCall = (SqlCall ) configNode ;
210
+ for (int i = 0 ; i < mapCall .operandCount (); i ++) {
211
+ SqlNode operand = mapCall .operand (i );
212
+ if (operand instanceof SqlCharStringLiteral ) {
213
+ continue ;
214
+ }
215
+ if (operand .getKind ().equals (SqlKind .CAST )) {
216
+ SqlCall castCall = (SqlCall ) operand ;
217
+ if (castCall .operand (0 ) instanceof SqlCharStringLiteral ) {
218
+ continue ;
219
+ }
220
+ }
221
+ return Optional .of (
222
+ new ValidationException (
223
+ String .format (
224
+ "ML_PREDICT config param can only be a MAP of string literals. The item at position %d is %s." ,
225
+ i , operand )));
226
+ }
227
+
228
+ return Optional .empty ();
229
+ }
115
230
}
116
231
}
0 commit comments