
Commit a44adc2

Support ML TF-IDF (#394)
1 parent 7cb6fea commit a44adc2

File tree

11 files changed: +923 −31 lines changed
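Taken together, the tests added in this commit exercise a full TF-IDF pipeline: Tokenizer splits text into words, HashingTF hashes the words into a fixed-size term-frequency vector, and IDF fits an IDFModel that rescales those frequencies. The following is a minimal sketch of that flow using only the calls exercised by the tests below; the TfIdfExample class, the Run method, and the column names are illustrative and assume an already-created SparkSession:

using Microsoft.Spark.ML.Feature;
using Microsoft.Spark.Sql;

public static class TfIdfExample
{
    // Sketch only: builds TF-IDF features with the classes covered by these tests.
    public static void Run(SparkSession spark)
    {
        DataFrame sentences = spark.Sql(
            "SELECT 0.0 as label, 'Hi I heard about Spark' as sentence");

        // Split each sentence into a list of words.
        DataFrame words = new Tokenizer()
            .SetInputCol("sentence")
            .SetOutputCol("words")
            .Transform(sentences);

        // Hash the words into a fixed-size term-frequency vector.
        DataFrame rawFeatures = new HashingTF()
            .SetInputCol("words")
            .SetOutputCol("rawFeatures")
            .SetNumFeatures(20)
            .Transform(words);

        // Fit the IDF weights, then rescale the term frequencies.
        IDFModel idfModel = new IDF()
            .SetInputCol("rawFeatures")
            .SetOutputCol("features")
            .Fit(rawFeatures);

        DataFrame tfIdf = idfModel.Transform(rawFeatures);
        tfIdf.Select("features").Show();
    }
}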

src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/BucketizerTests.cs

Lines changed: 11 additions & 0 deletions
@@ -3,6 +3,8 @@
 // See the LICENSE file in the project root for more information.
 
 using System.Collections.Generic;
+using System.IO;
+using Microsoft.Spark.E2ETest.Utils;
 using Microsoft.Spark.ML.Feature;
 using Microsoft.Spark.Sql;
 using Xunit;
@@ -47,6 +49,15 @@ public void TestBucketizer()
             Assert.Equal(expectedInputCol, bucketizer.GetInputCol());
             Assert.Equal(expectedOutputCol, bucketizer.GetOutputCol());
             Assert.Equal(expectedSplits, bucketizer.GetSplits());
+
+            using (var tempDirectory = new TemporaryDirectory())
+            {
+                string savePath = Path.Join(tempDirectory.Path, "bucket");
+                bucketizer.Save(savePath);
+
+                Bucketizer loadedBucketizer = Bucketizer.Load(savePath);
+                Assert.Equal(bucketizer.Uid(), loadedBucketizer.Uid());
+            }
         }
 
         [Fact]
src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/HashingTFTests.cs

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using Microsoft.Spark.E2ETest.Utils;
+using Microsoft.Spark.ML.Feature;
+using Microsoft.Spark.Sql;
+using Xunit;
+
+namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
+{
+    [Collection("Spark E2E Tests")]
+    public class HashingTFTests
+    {
+        private readonly SparkSession _spark;
+
+        public HashingTFTests(SparkFixture fixture)
+        {
+            _spark = fixture.Spark;
+        }
+
+        [Fact]
+        public void TestHashingTF()
+        {
+            string expectedInputCol = "input_col";
+            string expectedOutputCol = "output_col";
+            int expectedFeatures = 10;
+
+            Assert.IsType<HashingTF>(new HashingTF());
+
+            HashingTF hashingTf = new HashingTF("my-unique-id")
+                .SetNumFeatures(expectedFeatures)
+                .SetInputCol(expectedInputCol)
+                .SetOutputCol(expectedOutputCol);
+
+            Assert.Equal(expectedFeatures, hashingTf.GetNumFeatures());
+            Assert.Equal(expectedInputCol, hashingTf.GetInputCol());
+            Assert.Equal(expectedOutputCol, hashingTf.GetOutputCol());
+
+            DataFrame input = _spark.Sql("SELECT array('this', 'is', 'a', 'string', 'a', 'a')" +
+                " as input_col");
+
+            DataFrame output = hashingTf.Transform(input);
+            DataFrame outputVector = output.Select(expectedOutputCol);
+
+            Assert.Contains(expectedOutputCol, outputVector.Columns());
+
+            using (var tempDirectory = new TemporaryDirectory())
+            {
+                string savePath = Path.Join(tempDirectory.Path, "hashingTF");
+                hashingTf.Save(savePath);
+
+                HashingTF loadedHashingTf = HashingTF.Load(savePath);
+                Assert.Equal(hashingTf.Uid(), loadedHashingTf.Uid());
+            }
+
+            hashingTf.SetBinary(true);
+            Assert.True(hashingTf.GetBinary());
+        }
+    }
+}
src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/IDFModelTests.cs

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.IO;
+using Microsoft.Spark.E2ETest.Utils;
+using Microsoft.Spark.ML.Feature;
+using Microsoft.Spark.Sql;
+using Xunit;
+
+namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
+{
+    [Collection("Spark E2E Tests")]
+    public class IDFModelTests
+    {
+        private readonly SparkSession _spark;
+
+        public IDFModelTests(SparkFixture fixture)
+        {
+            _spark = fixture.Spark;
+        }
+
+        [Fact]
+        public void TestIDFModel()
+        {
+            int expectedDocFrequency = 1980;
+            string expectedInputCol = "rawFeatures";
+            string expectedOutputCol = "features";
+
+            DataFrame sentenceData =
+                _spark.Sql("SELECT 0.0 as label, 'Hi I heard about Spark' as sentence");
+
+            Tokenizer tokenizer = new Tokenizer()
+                .SetInputCol("sentence")
+                .SetOutputCol("words");
+
+            DataFrame wordsData = tokenizer.Transform(sentenceData);
+
+            HashingTF hashingTF = new HashingTF()
+                .SetInputCol("words")
+                .SetOutputCol(expectedInputCol)
+                .SetNumFeatures(20);
+
+            DataFrame featurizedData = hashingTF.Transform(wordsData);
+
+            IDF idf = new IDF()
+                .SetInputCol(expectedInputCol)
+                .SetOutputCol(expectedOutputCol)
+                .SetMinDocFreq(expectedDocFrequency);
+
+            IDFModel idfModel = idf.Fit(featurizedData);
+
+            DataFrame rescaledData = idfModel.Transform(featurizedData);
+            Assert.Contains(expectedOutputCol, rescaledData.Columns());
+
+            Assert.Equal(expectedInputCol, idfModel.GetInputCol());
+            Assert.Equal(expectedOutputCol, idfModel.GetOutputCol());
+            Assert.Equal(expectedDocFrequency, idfModel.GetMinDocFreq());
+
+            using (var tempDirectory = new TemporaryDirectory())
+            {
+                string modelPath = Path.Join(tempDirectory.Path, "idfModel");
+                idfModel.Save(modelPath);
+
+                IDFModel loadedModel = IDFModel.Load(modelPath);
+                Assert.Equal(idfModel.Uid(), loadedModel.Uid());
+            }
+        }
+    }
+}
src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/IDFTests.cs

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.IO;
+using Microsoft.Spark.E2ETest.Utils;
+using Microsoft.Spark.ML.Feature;
+using Microsoft.Spark.Sql;
+using Xunit;
+
+namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
+{
+    [Collection("Spark E2E Tests")]
+    public class IDFTests
+    {
+        private readonly SparkSession _spark;
+
+        public IDFTests(SparkFixture fixture)
+        {
+            _spark = fixture.Spark;
+        }
+
+        [Fact]
+        public void TestIDFModel()
+        {
+            string expectedInputCol = "rawFeatures";
+            string expectedOutputCol = "features";
+            int expectedDocFrequency = 100;
+
+            IDF idf = new IDF()
+                .SetInputCol(expectedInputCol)
+                .SetOutputCol(expectedOutputCol)
+                .SetMinDocFreq(expectedDocFrequency);
+
+            Assert.Equal(expectedInputCol, idf.GetInputCol());
+            Assert.Equal(expectedOutputCol, idf.GetOutputCol());
+            Assert.Equal(expectedDocFrequency, idf.GetMinDocFreq());
+
+            using (var tempDirectory = new TemporaryDirectory())
+            {
+                string savePath = Path.Join(tempDirectory.Path, "IDF");
+                idf.Save(savePath);
+
+                IDF loadedIdf = IDF.Load(savePath);
+                Assert.Equal(idf.Uid(), loadedIdf.Uid());
+            }
+        }
+    }
+}
src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/TokenizerTests.cs

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.IO;
+using Microsoft.Spark.E2ETest.Utils;
+using Microsoft.Spark.ML.Feature;
+using Microsoft.Spark.Sql;
+using Xunit;
+
+namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
+{
+    [Collection("Spark E2E Tests")]
+    public class TokenizerTests
+    {
+        private readonly SparkSession _spark;
+
+        public TokenizerTests(SparkFixture fixture)
+        {
+            _spark = fixture.Spark;
+        }
+
+        [Fact]
+        public void TestTokenizer()
+        {
+            string expectedUid = "theUid";
+            string expectedInputCol = "input_col";
+            string expectedOutputCol = "output_col";
+
+            DataFrame input = _spark.Sql("SELECT 'hello I AM a string TO, TOKENIZE' as input_col" +
+                " from range(100)");
+
+            Tokenizer tokenizer = new Tokenizer(expectedUid)
+                .SetInputCol(expectedInputCol)
+                .SetOutputCol(expectedOutputCol);
+
+            DataFrame output = tokenizer.Transform(input);
+
+            Assert.Contains(output.Schema().Fields, (f => f.Name == expectedOutputCol));
+            Assert.Equal(expectedInputCol, tokenizer.GetInputCol());
+            Assert.Equal(expectedOutputCol, tokenizer.GetOutputCol());
+
+            using (var tempDirectory = new TemporaryDirectory())
+            {
+                string savePath = Path.Join(tempDirectory.Path, "Tokenizer");
+                tokenizer.Save(savePath);
+
+                Tokenizer loadedTokenizer = Tokenizer.Load(savePath);
+                Assert.Equal(tokenizer.Uid(), loadedTokenizer.Uid());
+            }
+
+            Assert.Equal(expectedUid, tokenizer.Uid());
+        }
+    }
+}
