[FLINK-27826] Support training very high dimensional logistic regression #237

Status: Closed. Wants to merge 18 commits.
2 changes: 1 addition & 1 deletion docs/content/docs/operators/classification/knn.md
@@ -67,7 +67,7 @@ Below are the parameters required by `KnnModel`.
 ```java
 import org.apache.flink.ml.classification.knn.Knn;
 import org.apache.flink.ml.classification.knn.KnnModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
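The same one-line import rename recurs in most of the doc changes below. As a minimal, hedged sketch of what the renamed type looks like in user code (assuming `Vectors.dense` returns the renamed `DenseIntDoubleVector`, as the updated imports suggest):

```java
import org.apache.flink.ml.linalg.DenseIntDoubleVector;
import org.apache.flink.ml.linalg.Vectors;

// Only the declared type changes; vector construction stays the same.
DenseIntDoubleVector features = Vectors.dense(2.0, 3.0);
```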
2 changes: 1 addition & 1 deletion docs/content/docs/operators/classification/linearsvc.md
@@ -77,7 +77,7 @@ Below are the parameters required by `LinearSVCModel`.
 ```java
 import org.apache.flink.ml.classification.linearsvc.LinearSVC;
 import org.apache.flink.ml.classification.linearsvc.LinearSVCModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -74,7 +74,7 @@ Below are the parameters required by `LogisticRegressionModel`.
 ```java
 import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
 import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -251,9 +251,9 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo;
 import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegression;
 import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel;
 import org.apache.flink.ml.examples.util.PeriodicSourceFunction;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
@@ -323,7 +323,7 @@ public class OnlineLogisticRegressionExample {

 // Creates an online LogisticRegression object and initializes its parameters and initial
 // model data.
-Row initModelData = Row.of(Vectors.dense(0.41233679404769874, -0.18088118293232122), 0L);
+Row initModelData = Row.of(Vectors.dense(0.41233679404769874, -0.18088118293232122), 0L, 2L, 0L);
 Table initModelDataTable = tEnv.fromDataStream(env.fromElements(initModelData));
 OnlineLogisticRegression olr =
 new OnlineLogisticRegression()
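Note that the `OnlineLogisticRegression` initial model data row above now carries three trailing long-valued fields after the coefficient vector instead of one; their exact semantics come from the updated model data format and are not spelled out in this diff. A sketch of building the initial model table, following the updated example (field meanings assumed, not documented here):

```java
// Coefficient vector plus three long fields, taken verbatim from the updated example above.
Row initModelData = Row.of(Vectors.dense(0.41233679404769874, -0.18088118293232122), 0L, 2L, 0L);
Table initModelDataTable = tEnv.fromDataStream(env.fromElements(initModelData));
```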
2 changes: 1 addition & 1 deletion docs/content/docs/operators/classification/naivebayes.md
@@ -66,7 +66,7 @@ Below are parameters required by `NaiveBayesModel`.
 ```java
 import org.apache.flink.ml.classification.naivebayes.NaiveBayes;
 import org.apache.flink.ml.classification.naivebayes.NaiveBayesModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -69,7 +69,7 @@ format of the merging information is
 import org.apache.flink.ml.clustering.agglomerativeclustering.AgglomerativeClustering;
 import org.apache.flink.ml.clustering.agglomerativeclustering.AgglomerativeClusteringParams;
 import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
6 changes: 3 additions & 3 deletions docs/content/docs/operators/clustering/kmeans.md
@@ -67,7 +67,7 @@ Below are the parameters required by `KMeansModel`.
 ```java
 import org.apache.flink.ml.clustering.kmeans.KMeans;
 import org.apache.flink.ml.clustering.kmeans.KMeansModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -228,9 +228,9 @@ import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
 import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
 import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
 import org.apache.flink.ml.examples.util.PeriodicSourceFunction;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
2 changes: 1 addition & 1 deletion docs/content/docs/operators/feature/countvectorizer.md
@@ -72,7 +72,7 @@ Below are the parameters required by `CountVectorizerModel`.
 ```java
 import org.apache.flink.ml.feature.countvectorizer.CountVectorizer;
 import org.apache.flink.ml.feature.countvectorizer.CountVectorizerModel;
-import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
2 changes: 1 addition & 1 deletion docs/content/docs/operators/feature/dct.md
@@ -60,7 +60,7 @@ that the transform matrix is unitary (aka scaled DCT-II).

 ```java
 import org.apache.flink.ml.feature.dct.DCT;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
2 changes: 1 addition & 1 deletion docs/content/docs/operators/feature/elementwiseproduct.md
@@ -58,7 +58,7 @@ scaling vector, the transformer will throw an IllegalArgumentException.

 ```java
 import org.apache.flink.ml.feature.elementwiseproduct.ElementwiseProduct;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
2 changes: 1 addition & 1 deletion docs/content/docs/operators/feature/featurehasher.md
@@ -69,7 +69,7 @@ For the hashing trick, see https://en.wikipedia.org/wiki/Feature_hashing for det

 ```java
 import org.apache.flink.ml.feature.featurehasher.FeatureHasher;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
2 changes: 1 addition & 1 deletion docs/content/docs/operators/feature/hashingtf.md
@@ -63,7 +63,7 @@ the output values are accumulated by default.
 ```java

 import org.apache.flink.ml.feature.hashingtf.HashingTF;
-import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
2 changes: 1 addition & 1 deletion docs/content/docs/operators/feature/idf.md
@@ -73,7 +73,7 @@ Below are the parameters required by `IDFModel`.
 ```java
 import org.apache.flink.ml.feature.idf.IDF;
 import org.apache.flink.ml.feature.idf.IDFModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
2 changes: 1 addition & 1 deletion docs/content/docs/operators/feature/interaction.md
@@ -62,7 +62,7 @@ be Vector(3, 6, 4, 8).

 ```java
 import org.apache.flink.ml.feature.interaction.Interaction;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
2 changes: 1 addition & 1 deletion docs/content/docs/operators/feature/kbinsdiscretizer.md
@@ -70,7 +70,7 @@ Below are the parameters required by `KBinsDiscretizerModel`.
 import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizer;
 import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModel;
 import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerParams;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
2 changes: 1 addition & 1 deletion docs/content/docs/operators/feature/maxabsscaler.md
@@ -59,7 +59,7 @@ It does not shift/center the data and thus does not destroy any sparsity.
 ```java
 import org.apache.flink.ml.feature.maxabsscaler.MaxAbsScaler;
 import org.apache.flink.ml.feature.maxabsscaler.MaxAbsScalerModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
6 changes: 3 additions & 3 deletions docs/content/docs/operators/feature/minhashlsh.md
@@ -75,9 +75,9 @@ import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.common.typeinfo.Types;
 import org.apache.flink.ml.feature.lsh.MinHashLSH;
 import org.apache.flink.ml.feature.lsh.MinHashLSHModel;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.SparseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
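For reference, a small sketch of the renamed vector hierarchy used by the MinHashLSH example (assuming both factory methods return the renamed concrete types and `IntDoubleVector` is the common interface they implement):

```java
import org.apache.flink.ml.linalg.IntDoubleVector;
import org.apache.flink.ml.linalg.Vectors;

// Dense and sparse vectors can both be handled through the IntDoubleVector interface.
IntDoubleVector dense = Vectors.dense(1.0, 0.0, 3.0);
IntDoubleVector sparse = Vectors.sparse(3, new int[] {0, 2}, new double[] {1.0, 3.0});
```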
2 changes: 1 addition & 1 deletion docs/content/docs/operators/feature/minmaxscaler.md
@@ -59,7 +59,7 @@ MinMaxScaler is an algorithm that rescales feature values to a common range
 ```java
 import org.apache.flink.ml.feature.minmaxscaler.MinMaxScaler;
 import org.apache.flink.ml.feature.minmaxscaler.MinMaxScalerModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
2 changes: 1 addition & 1 deletion docs/content/docs/operators/feature/normalizer.md
@@ -57,7 +57,7 @@ A Transformer that normalizes a vector to have unit norm using the given p-norm.

 ```java
 import org.apache.flink.ml.feature.normalizer.Normalizer;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
2 changes: 1 addition & 1 deletion docs/content/docs/operators/feature/onehotencoder.md
@@ -64,7 +64,7 @@ vector column for each input column.
 ```java
 import org.apache.flink.ml.feature.onehotencoder.OneHotEncoder;
 import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModel;
-import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
10 changes: 5 additions & 5 deletions docs/content/docs/operators/feature/onlinestandardscaler.md
@@ -91,9 +91,9 @@ import org.apache.flink.api.common.time.Time;
 import org.apache.flink.ml.common.window.EventTimeTumblingWindows;
 import org.apache.flink.ml.feature.standardscaler.OnlineStandardScaler;
 import org.apache.flink.ml.feature.standardscaler.OnlineStandardScalerModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.DataTypes;
@@ -202,13 +202,13 @@ t_env = StreamTableEnvironment.create(env)

 # Generates input data.
 dense_vector_serializer = get_gateway().jvm.org.apache.flink.table.types.logical.RawType(
-get_gateway().jvm.org.apache.flink.ml.linalg.DenseVector(0).getClass(),
-get_gateway().jvm.org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer()
+get_gateway().jvm.org.apache.flink.ml.linalg.DenseIntDoubleVector(0).getClass(),
+get_gateway().jvm.org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorSerializer()
 ).getSerializerString()

 schema = Schema.new_builder()
 .column("ts", "TIMESTAMP_LTZ(3)")
-.column("input", "RAW('org.apache.flink.ml.linalg.DenseVector', '{serializer}')"
+.column("input", "RAW('org.apache.flink.ml.linalg.DenseIntDoubleVector', '{serializer}')"
 .format(serializer=dense_vector_serializer))
 .watermark("ts", "ts - INTERVAL '1' SECOND")
 .build()
2 changes: 1 addition & 1 deletion docs/content/docs/operators/feature/polynomialexpansion.md
@@ -63,7 +63,7 @@ http://en.wikipedia.org/wiki/Polynomial_expansion.

 ```java
 import org.apache.flink.ml.feature.polynomialexpansion.PolynomialExpansion;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
2 changes: 1 addition & 1 deletion docs/content/docs/operators/feature/robustscaler.md
@@ -87,7 +87,7 @@ Below are the parameters required by `RobustScalerModel`.
 ```java
 import org.apache.flink.ml.feature.robustscaler.RobustScaler;
 import org.apache.flink.ml.feature.robustscaler.RobustScalerModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
2 changes: 1 addition & 1 deletion docs/content/docs/operators/feature/standardscaler.md
@@ -59,7 +59,7 @@ the mean and scaling each dimension to unit variance.
 ```java
 import org.apache.flink.ml.feature.standardscaler.StandardScaler;
 import org.apache.flink.ml.feature.standardscaler.StandardScalerModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -105,7 +105,7 @@ Below are the parameters required by `UnivariateFeatureSelectorModel`.
 ```java
 import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelector;
 import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -69,7 +69,7 @@ Below are the parameters required by `VarianceThresholdSelectorModel`.
 ```java
 import org.apache.flink.ml.feature.variancethresholdselector.VarianceThresholdSelector;
 import org.apache.flink.ml.feature.variancethresholdselector.VarianceThresholdSelectorModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
2 changes: 1 addition & 1 deletion docs/content/docs/operators/feature/vectorassembler.md
@@ -69,7 +69,7 @@ the strategy specified by the {@link HasHandleInvalid} parameter as follows:

 ```java
 import org.apache.flink.ml.feature.vectorassembler.VectorAssembler;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
2 changes: 1 addition & 1 deletion docs/content/docs/operators/feature/vectorslicer.md
@@ -61,7 +61,7 @@ it throws an IllegalArgumentException.

 ```java
 import org.apache.flink.ml.feature.vectorslicer.VectorSlicer;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
4 changes: 2 additions & 2 deletions docs/content/docs/operators/functions.md
@@ -38,7 +38,7 @@ of double arrays.

 {{< tab "Java">}}
 ```java
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -145,7 +145,7 @@ DenseVector instances.

 {{< tab "Java">}}
 ```java
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
2 changes: 1 addition & 1 deletion docs/content/docs/operators/regression/linearregression.md
@@ -72,7 +72,7 @@ Below are the parameters required by `LinearRegressionModel`.
 {{< tab "Java">}}

 ```java
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.regression.linearregression.LinearRegression;
 import org.apache.flink.ml.regression.linearregression.LinearRegressionModel;
@@ -176,7 +176,7 @@ package myflinkml;

 import org.apache.flink.ml.clustering.kmeans.KMeans;
 import org.apache.flink.ml.clustering.kmeans.KMeansModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -23,8 +23,8 @@
 import org.apache.flink.ml.benchmark.datagenerator.common.DenseVectorArrayGenerator;
 import org.apache.flink.ml.benchmark.datagenerator.param.HasArraySize;
 import org.apache.flink.ml.benchmark.datagenerator.param.HasVectorDim;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.table.api.DataTypes;
@@ -77,8 +77,8 @@ public Table[] getData(StreamTableEnvironment tEnv) {
 * information.
 */
 public static class GenerateWeightsFunction extends ScalarFunction {
-public DenseVector eval(DenseVector[] centroids) {
-return new DenseVector(centroids.length);
+public DenseIntDoubleVector eval(DenseIntDoubleVector[] centroids) {
+return new DenseIntDoubleVector(centroids.length);
 }

 @Override
@@ -87,7 +87,7 @@ public TypeInference getTypeInference(DataTypeFactory typeFactory) {
 .outputTypeStrategy(
 callContext ->
 Optional.of(
-DataTypes.of(DenseVectorTypeInfo.INSTANCE)
+DataTypes.of(DenseIntDoubleVectorTypeInfo.INSTANCE)
 .toDataType(typeFactory)))
 .build();
 }
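The `GenerateWeightsFunction` change above keeps the original behavior and only swaps the vector type. A hedged usage sketch, assuming `new DenseIntDoubleVector(n)` still allocates an n-dimensional all-zero vector like the old `DenseVector(int)` constructor:

```java
// Hypothetical call mirroring how the benchmark evaluates the scalar function.
DenseIntDoubleVector[] centroids = {Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0)};
DenseIntDoubleVector weights = new GenerateWeightsFunction().eval(centroids);
// weights has 2 entries, all initialized to 0.0.
```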