Skip to content
This repository has been archived by the owner on Oct 8, 2019. It is now read-only.

Commit

Permalink
Merge pull request #224 from myui/hotfix/fix_fm_predict_bug
Browse files Browse the repository at this point in the history
Hotfix/fix fm predict bug
  • Loading branch information
myui committed Oct 30, 2015
2 parents d0dc744 + d1b054d commit 2b1995c
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 13 deletions.
2 changes: 1 addition & 1 deletion build.properties
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ test.result.dir=${target.dir}/test-results
user.name=myui
#java.version=1.6

project.version=0.4.0
project.version=0.4.0-1
project.name=Hivemall
project.groupId=hivemall
project.organization.name=Treasure Data, Inc.
Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

<groupId>io.github.myui</groupId>
<artifactId>hivemall</artifactId>
<version>0.4.0</version>
<version>0.4.0-1</version>

<name>Hivemall</name>
<description>Scalable Machine Learning Library for Apache Hive</description>
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/hivemall/HivemallConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

public final class HivemallConstants {

public static final String VERSION = "0.4.0";
public static final String VERSION = "0.4.0-1";

public static final String BIAS_CLAUSE = "0";
public static final String CONFKEY_RAND_AMPLIFY_SEED = "hivemall.amplify.seed";
Expand Down
12 changes: 10 additions & 2 deletions src/main/java/hivemall/fm/FMPredictUDAF.java
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,18 @@ void iterate(@Nullable DoubleWritable Wj, @Nullable List<FloatWritable> Vjf, @Nu
if(Xj == null) {
throw new HiveException("Xj should not be null");
}
final int factor = Vjf.size();
if(factor == 0) {// workaround for TD
return;
}

if(sumVjXj == null) {
int factors = Vjf.size();
this.sumVjXj = Arrays.asList(MutableDouble.initArray(factors, 0.d));
this.sumV2X2 = Arrays.asList(MutableDouble.initArray(factors, 0.d));
}

final double x = Xj.get();
final int factor = Vjf.size();
if(factor < 1) {
throw new HiveException("# of Factor should be more than 0: " + Vjf.toString());
}
Expand All @@ -141,6 +145,7 @@ void merge(PartialResult other) {
this.sumV2X2 = other.sumV2X2;
} else {
add(other.sumVjXj, sumVjXj);
assert (sumV2X2 != null);
add(other.sumV2X2, sumV2X2);
}
}
Expand All @@ -163,7 +168,10 @@ void merge(PartialResult other) {
return ret;
}

private static void add(@Nonnull final List<MutableDouble> src, @Nonnull final List<MutableDouble> dst) {
private static void add(@Nullable final List<MutableDouble> src, @Nonnull final List<MutableDouble> dst) {
if(src == null) {
return;
}
for(int i = 0, size = src.size(); i < size; i++) {
MutableDouble s = src.get(i);
assert (s != null);
Expand Down
26 changes: 18 additions & 8 deletions src/main/java/hivemall/mf/MFPredictionUDF.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,17 @@ public FloatWritable evaluate(List<Float> Pu, List<Float> Qi, double mu) throws
if(Pu == null || Qi == null) {
return null; //throw new HiveException("Pu should not be NULL");
}
final int factor = Pu.size();
if(Qi.size() != factor) {
throw new HiveException("|Pu| " + factor + " was not equal to |Qi| " + Qi.size());
final int PuSize = Pu.size();
final int QiSize = Qi.size();
if(QiSize != PuSize) {
throw new HiveException("|Pu| " + PuSize + " was not equal to |Qi| " + QiSize);
}
if(PuSize == 0) {// workaround for TD
return null;
}

float ret = (float) mu;
for(int k = 0; k < factor; k++) {
for(int k = 0; k < PuSize; k++) {
ret += Pu.get(k) * Qi.get(k);
}
return new FloatWritable(ret);
Expand All @@ -68,12 +73,17 @@ public FloatWritable evaluate(List<Float> Pu, List<Float> Qi, double Bu, double
return new FloatWritable(ret);
}

final int factor = Pu.size();
if(Qi.size() != factor) {
throw new HiveException("|Pu| " + factor + " was not equal to |Qi| " + Qi.size());
final int PuSize = Pu.size();
final int QiSize = Qi.size();
if(QiSize != PuSize) {
throw new HiveException("|Pu| " + PuSize + " was not equal to |Qi| " + QiSize);
}
if(PuSize == 0) {// workaround for TD
return null;
}

float ret = (float) (mu + Bu + Bi);
for(int k = 0; k < factor; k++) {
for(int k = 0; k < PuSize; k++) {
ret += Pu.get(k) * Qi.get(k);
}
return new FloatWritable(ret);
Expand Down

0 comments on commit 2b1995c

Please sign in to comment.