Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
language: java
jdk:
- oraclejdk8
- oraclejdk9
os:
- linux
- osx
Expand Down
9 changes: 8 additions & 1 deletion src/main/cpp/fasttext_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,14 @@ namespace FastTextWrapper {

std::vector<real> FastTextApi::getVector(const std::string& word) {
Vector vec(privateMembers->args_->dim);
fastText.getVector(vec, word);
fastText.getWordVector(vec, word);
return std::vector<real>(vec.data(), vec.data() + vec.size());
}

std::vector<real> FastTextApi::getSentenceVector(const std::string& sentence) {
Vector vec(privateMembers->args_->dim);
std::istringstream in(sentence);
fastText.getSentenceVector(in, vec);
return std::vector<real>(vec.data(), vec.data() + vec.size());
}

Expand Down
1 change: 1 addition & 0 deletions src/main/cpp/fasttext_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ namespace FastTextWrapper {
std::vector<std::string> predict(const std::string&, int32_t);
std::vector<std::pair<real,std::string>> predictProba(const std::string&, int32_t);
std::vector<real> getVector(const std::string&);
std::vector<real> getSentenceVector(const std::string&);
std::vector<std::string> getWords();
std::vector<std::string> getLabels();
std::string getWord(int32_t);
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/com/github/jfasttext/FastTextWrapper.java
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ public DoubleIntPair put(float firstValue, int secondValue) {
public native @ByVal FloatStringPairVector predictProba(@StdString String arg0, int arg1);
public native @ByVal RealVector getVector(@StdString BytePointer arg0);
public native @ByVal RealVector getVector(@StdString String arg0);
public native @ByVal RealVector getSentenceVector(@StdString BytePointer arg0);
public native @ByVal RealVector getSentenceVector(@StdString String arg0);
public native @ByVal StringVector getWords();
public native @ByVal StringVector getLabels();
public native @StdString BytePointer getWord(int arg0);
Expand Down
12 changes: 12 additions & 0 deletions src/main/java/com/github/jfasttext/JFastText.java
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,18 @@ public List<Float> getVector(String word) {
return wordVec;
}

public List<Float> getSentenceVector(String sentence) {
if (!sentence.endsWith("\n")) {
sentence += "\n";
}
FastTextWrapper.RealVector rv = fta.getSentenceVector(sentence);
List<Float> wordVec = new ArrayList<>();
for (int i = 0; i < rv.size(); i++) {
wordVec.add(rv.get(i));
}
return wordVec;
}

public int getNWords() {
return fta.getNWords();
}
Expand Down
17 changes: 14 additions & 3 deletions src/test/java/com/github/jfasttext/JFastTextTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ public void test01TrainSupervisedCmd() {
jft.runCmd(new String[] {
"supervised",
"-input", "src/test/resources/data/labeled_data.txt",
"-output", "src/test/resources/models/supervised.model"
"-output", "src/test/resources/models/supervised.model",
"-wordNgrams", "3",
"-bucket", "100"
});
}

Expand Down Expand Up @@ -86,11 +88,20 @@ public void test07GetVector() throws Exception {
System.out.printf("\nWord embedding vector of '%s': %s\n", word, vec);
}

@Test
public void test08GetSentenceVector() throws Exception {
JFastText jft = new JFastText();
jft.loadModel("src/test/resources/models/supervised.model.bin");
String word = "soccers";
List<Float> vec = jft.getSentenceVector(word);
System.out.printf("\nSentence embedding vector of '%s': %s\n", word, vec);
}

/**
* Test retrieving model's information: words, labels, learning rate, etc.
*/
@Test
public void test08ModelInfo() throws Exception {
public void test09ModelInfo() throws Exception {
System.out.printf("\nSupervised model information:\n");
JFastText jft = new JFastText();
jft.loadModel("src/test/resources/models/supervised.model.bin");
Expand All @@ -113,7 +124,7 @@ public void test08ModelInfo() throws Exception {
* allocated by native function calls).
*/
@Test
public void test09ModelUnloading() throws Exception {
public void test10ModelUnloading() throws Exception {
JFastText jft = new JFastText();
System.out.println("\nLoading model ...");
jft.loadModel("src/test/resources/models/supervised.model.bin");
Expand Down