diff --git a/pom.xml b/pom.xml
index b7aa61b0..70282a76 100755
--- a/pom.xml
+++ b/pom.xml
@@ -2,7 +2,7 @@
4.0.0
com.bmw-carit
barefoot
- 0.1.1
+ 0.1.2
diff --git a/src/main/java/com/bmwcarit/barefoot/markov/KState.java b/src/main/java/com/bmwcarit/barefoot/markov/KState.java
index 594e38a7..95e323e9 100755
--- a/src/main/java/com/bmwcarit/barefoot/markov/KState.java
+++ b/src/main/java/com/bmwcarit/barefoot/markov/KState.java
@@ -27,7 +27,7 @@
import org.json.JSONException;
import org.json.JSONObject;
-import com.bmwcarit.barefoot.util.Tuple;
+import com.bmwcarit.barefoot.util.Triple;
/**
* k-State data structure for organizing state memory in HMM inference.
@@ -40,7 +40,7 @@ public class KState, T extends StateTransition
extends StateMemory {
private final int k;
private final long t;
- private final LinkedList, S>> sequence;
+ private final LinkedList, S, C>> sequence;
private final Map counters;
/**
@@ -100,13 +100,15 @@ public KState(JSONObject json, Factory factory) throws JSONException {
}
S sample = factory.sample(jsonseqelement.getJSONObject("sample"));
+ String kestid = jsonseqelement.getString("kestid");
+ C kestimate = candidates.get(kestid);
- sequence.add(new Tuple<>(vector, sample));
+ sequence.add(new Triple<>(vector, sample, kestimate));
}
- Collections.sort(sequence, new Comparator, S>>() {
+ Collections.sort(sequence, new Comparator, S, C>>() {
@Override
- public int compare(Tuple, S> left, Tuple, S> right) {
+ public int compare(Triple, S, C> left, Triple, S, C> right) {
if (left.two().time() < right.two().time()) {
return -1;
} else if (left.two().time() > right.two().time()) {
@@ -167,7 +169,7 @@ public S sample() {
*/
public List samples() {
LinkedList samples = new LinkedList<>();
- for (Tuple, S> element : sequence) {
+ for (Triple, S, C> element : sequence) {
samples.add(element.two());
}
return samples;
@@ -183,6 +185,7 @@ public void update(Set vector, S sample) {
throw new RuntimeException("out-of-order state update is prohibited");
}
+ C kestimate = null;
for (C candidate : vector) {
counters.put(candidate, 0);
if (candidate.predecessor() != null) {
@@ -192,16 +195,16 @@ public void update(Set vector, S sample) {
}
counters.put(candidate.predecessor(), counters.get(candidate.predecessor()) + 1);
}
+ if (kestimate == null || candidate.seqprob() > kestimate.seqprob()) {
+ kestimate = candidate;
+ }
}
if (!sequence.isEmpty()) {
+ Triple, S, C> last = sequence.peekLast();
Set deletes = new HashSet<>();
- C estimate = null;
- for (C candidate : sequence.peekLast().one()) {
- if (estimate == null || candidate.seqprob() > estimate.seqprob()) {
- estimate = candidate;
- }
+ for (C candidate : last.one()) {
if (counters.get(candidate) == 0) {
deletes.add(candidate);
}
@@ -210,13 +213,13 @@ public void update(Set vector, S sample) {
int size = sequence.peekLast().one().size();
for (C candidate : deletes) {
- if (deletes.size() != size || candidate != estimate) {
+ if (deletes.size() != size || candidate != last.three()) {
remove(candidate, sequence.size() - 1);
}
}
}
- sequence.add(new Tuple<>(vector, sample));
+ sequence.add(new Triple<>(vector, sample, kestimate));
while ((t > 0 && sample.time() - sequence.peekFirst().two().time() > t)
|| (k >= 0 && sequence.size() > k + 1)) {
@@ -234,6 +237,10 @@ public void update(Set vector, S sample) {
}
protected void remove(C candidate, int index) {
+ if (sequence.get(index).three() == candidate) {
+ return;
+ }
+
Set vector = sequence.get(index).one();
counters.remove(candidate);
vector.remove(candidate);
@@ -282,14 +289,7 @@ public List sequence() {
return null;
}
- C kestimate = null;
-
- for (C candidate : sequence.peekLast().one()) {
- if (kestimate == null || candidate.seqprob() > kestimate.seqprob()) {
- kestimate = candidate;
- }
- }
-
+ C kestimate = sequence.peekLast().three();
LinkedList ksequence = new LinkedList<>();
for (int i = sequence.size() - 1; i >= 0; --i) {
@@ -297,8 +297,8 @@ public List sequence() {
ksequence.push(kestimate);
kestimate = kestimate.predecessor();
} else {
- ksequence.push(sequence.get(i).one().iterator().next());
- assert (sequence.get(i).one().size() == 1);
+ ksequence.push(sequence.get(i).three());
+ kestimate = sequence.get(i).three().predecessor();
}
}
@@ -309,7 +309,7 @@ public List sequence() {
public JSONObject toJSON() throws JSONException {
JSONObject json = new JSONObject();
JSONArray jsonsequence = new JSONArray();
- for (Tuple, S> element : sequence) {
+ for (Triple, S, C> element : sequence) {
JSONObject jsonseqelement = new JSONObject();
JSONArray jsonvector = new JSONArray();
for (C candidate : element.one()) {
@@ -321,6 +321,7 @@ public JSONObject toJSON() throws JSONException {
}
jsonseqelement.put("vector", jsonvector);
jsonseqelement.put("sample", element.two().toJSON());
+ jsonseqelement.put("kestid", element.three().id());
jsonsequence.put(jsonseqelement);
}
diff --git a/src/test/java/com/bmwcarit/barefoot/markov/KStateTest.java b/src/test/java/com/bmwcarit/barefoot/markov/KStateTest.java
index 8bbf1381..35b3656d 100644
--- a/src/test/java/com/bmwcarit/barefoot/markov/KStateTest.java
+++ b/src/test/java/com/bmwcarit/barefoot/markov/KStateTest.java
@@ -72,11 +72,10 @@ public void TestKStateUnbound() {
elements.put(1, new MockElem(1, Math.log10(0.2), 0.2, null));
elements.put(2, new MockElem(2, Math.log10(0.5), 0.5, null));
- KState state =
- new KState<>();
+ KState state = new KState<>();
{
- Set vector = new HashSet<>(
- Arrays.asList(elements.get(0), elements.get(1), elements.get(2)));
+ Set vector =
+ new HashSet<>(Arrays.asList(elements.get(0), elements.get(1), elements.get(2)));
state.update(vector, new Sample(0));
@@ -90,8 +89,8 @@ public void TestKStateUnbound() {
elements.put(6, new MockElem(6, Math.log10(0.1), 0.1, elements.get(2)));
{
- Set vector = new HashSet<>(Arrays.asList(elements.get(3),
- elements.get(4), elements.get(5), elements.get(6)));
+ Set vector = new HashSet<>(Arrays.asList(elements.get(3), elements.get(4),
+ elements.get(5), elements.get(6)));
state.update(vector, new Sample(1));
@@ -110,8 +109,8 @@ public void TestKStateUnbound() {
elements.put(10, new MockElem(10, Math.log10(0.1), 0.1, elements.get(6)));
{
- Set vector = new HashSet<>(Arrays.asList(elements.get(7),
- elements.get(8), elements.get(9), elements.get(10)));
+ Set vector = new HashSet<>(Arrays.asList(elements.get(7), elements.get(8),
+ elements.get(9), elements.get(10)));
state.update(vector, new Sample(2));
@@ -130,12 +129,12 @@ public void TestKStateUnbound() {
elements.put(14, new MockElem(14, Math.log10(0.1), 0.1, null));
{
- Set vector = new HashSet<>(Arrays.asList(elements.get(11),
- elements.get(12), elements.get(13), elements.get(14)));
+ Set vector = new HashSet<>(Arrays.asList(elements.get(11), elements.get(12),
+ elements.get(13), elements.get(14)));
state.update(vector, new Sample(3));
- assertEquals(7, state.size());
+ assertEquals(8, state.size());
assertEquals(13, state.estimate().numid());
List sequence = new LinkedList<>(Arrays.asList(2, 6, 9, 13));
@@ -148,7 +147,7 @@ public void TestKStateUnbound() {
state.update(vector, new Sample(4));
- assertEquals(7, state.size());
+ assertEquals(8, state.size());
assertEquals(13, state.estimate().numid());
List sequence = new LinkedList<>(Arrays.asList(2, 6, 9, 13));
@@ -158,6 +157,36 @@ public void TestKStateUnbound() {
}
}
+ @Test
+ public void TestBreak() {
+ // Test k-state in case of HMM break 'no transition' as reported in barefoot issue #83.
+ // Tests only 'no transitions', no emissions is empty vector and, hence, input to update
+ // operation.
+
+ KState state = new KState<>();
+ Map elements = new HashMap<>();
+ elements.put(0, new MockElem(0, Math.log10(0.4), 0.4, null));
+ {
+ Set vector = new HashSet<>(Arrays.asList(elements.get(0)));
+ state.update(vector, new Sample(0));
+ }
+ elements.put(1, new MockElem(1, Math.log(0.7), 0.6, null));
+ elements.put(2, new MockElem(2, Math.log(0.3), 0.4, elements.get(0)));
+ {
+ Set vector = new HashSet<>(Arrays.asList(elements.get(1), elements.get(2)));
+ state.update(vector, new Sample(1));
+ }
+ elements.put(3, new MockElem(3, Math.log(0.5), 0.6, null));
+ {
+ Set vector = new HashSet<>(Arrays.asList(elements.get(3)));
+ state.update(vector, new Sample(2));
+ }
+ List seq = state.sequence();
+ assertEquals(seq.get(0).numid(), 0);
+ assertEquals(seq.get(1).numid(), 1);
+ assertEquals(seq.get(2).numid(), 3);
+ }
+
@Test
public void TestKState() {
Map elements = new HashMap<>();
@@ -165,11 +194,10 @@ public void TestKState() {
elements.put(1, new MockElem(1, Math.log10(0.2), 0.2, null));
elements.put(2, new MockElem(2, Math.log10(0.5), 0.5, null));
- KState state =
- new KState<>(1, -1);
+ KState state = new KState<>(1, -1);
{
- Set vector = new HashSet<>(
- Arrays.asList(elements.get(0), elements.get(1), elements.get(2)));
+ Set vector =
+ new HashSet<>(Arrays.asList(elements.get(0), elements.get(1), elements.get(2)));
state.update(vector, new Sample(0));
@@ -183,8 +211,8 @@ public void TestKState() {
elements.put(6, new MockElem(6, Math.log10(0.1), 0.1, elements.get(2)));
{
- Set vector = new HashSet<>(Arrays.asList(elements.get(3),
- elements.get(4), elements.get(5), elements.get(6)));
+ Set vector = new HashSet<>(Arrays.asList(elements.get(3), elements.get(4),
+ elements.get(5), elements.get(6)));
state.update(vector, new Sample(1));
@@ -203,8 +231,8 @@ public void TestKState() {
elements.put(10, new MockElem(10, Math.log10(0.1), 0.1, elements.get(6)));
{
- Set vector = new HashSet<>(Arrays.asList(elements.get(7),
- elements.get(8), elements.get(9), elements.get(10)));
+ Set vector = new HashSet<>(Arrays.asList(elements.get(7), elements.get(8),
+ elements.get(9), elements.get(10)));
state.update(vector, new Sample(2));
@@ -223,8 +251,8 @@ public void TestKState() {
elements.put(14, new MockElem(14, Math.log10(0.1), 0.1, null));
{
- Set vector = new HashSet<>(Arrays.asList(elements.get(11),
- elements.get(12), elements.get(13), elements.get(14)));
+ Set vector = new HashSet<>(Arrays.asList(elements.get(11), elements.get(12),
+ elements.get(13), elements.get(14)));
state.update(vector, new Sample(3));
@@ -258,11 +286,10 @@ public void TestTState() {
elements.put(1, new MockElem(1, Math.log10(0.2), 0.2, null));
elements.put(2, new MockElem(2, Math.log10(0.5), 0.5, null));
- KState state =
- new KState<>(-1, 1);
+ KState state = new KState<>(-1, 1);
{
- Set vector = new HashSet<>(
- Arrays.asList(elements.get(0), elements.get(1), elements.get(2)));
+ Set vector =
+ new HashSet<>(Arrays.asList(elements.get(0), elements.get(1), elements.get(2)));
state.update(vector, new Sample(0));
@@ -276,8 +303,8 @@ public void TestTState() {
elements.put(6, new MockElem(6, Math.log10(0.1), 0.1, elements.get(2)));
{
- Set vector = new HashSet<>(Arrays.asList(elements.get(3),
- elements.get(4), elements.get(5), elements.get(6)));
+ Set vector = new HashSet<>(Arrays.asList(elements.get(3), elements.get(4),
+ elements.get(5), elements.get(6)));
state.update(vector, new Sample(1));
@@ -296,8 +323,8 @@ public void TestTState() {
elements.put(10, new MockElem(10, Math.log10(0.1), 0.1, elements.get(6)));
{
- Set vector = new HashSet<>(Arrays.asList(elements.get(7),
- elements.get(8), elements.get(9), elements.get(10)));
+ Set vector = new HashSet<>(Arrays.asList(elements.get(7), elements.get(8),
+ elements.get(9), elements.get(10)));
state.update(vector, new Sample(2));
@@ -316,8 +343,8 @@ public void TestTState() {
elements.put(14, new MockElem(14, Math.log10(0.1), 0.1, null));
{
- Set vector = new HashSet<>(Arrays.asList(elements.get(11),
- elements.get(12), elements.get(13), elements.get(14)));
+ Set vector = new HashSet<>(Arrays.asList(elements.get(11), elements.get(12),
+ elements.get(13), elements.get(14)));
state.update(vector, new Sample(3));
@@ -348,8 +375,7 @@ public void TestTState() {
public void TestKStateJSON() throws JSONException {
Map elements = new HashMap<>();
- KState state =
- new KState<>(1, -1);
+ KState state = new KState<>(1, -1);
{
JSONObject json = state.toJSON();
@@ -361,8 +387,7 @@ public void TestKStateJSON() throws JSONException {
elements.put(2, new MockElem(2, Math.log10(0.5), 0.5, null));
state.update(
- new HashSet<>(
- Arrays.asList(elements.get(0), elements.get(1), elements.get(2))),
+ new HashSet<>(Arrays.asList(elements.get(0), elements.get(1), elements.get(2))),
new Sample(0));
{
diff --git a/util/submit/batch.py b/util/submit/batch.py
index 504ae61c..ad06c551 100644
--- a/util/submit/batch.py
+++ b/util/submit/batch.py
@@ -54,7 +54,7 @@
tmp = "batch-%s" % random.randint(0, sys.maxint)
file = open(tmp, "w")
-file.write("{\"format\": \"%s\", \"request\": %s}" % (options.format, json.dumps(samples)))
+file.write("{\"format\": \"%s\", \"request\": %s}\n" % (options.format, json.dumps(samples)))
file.close()
subprocess.call("cat %s | netcat %s %s" % (tmp, options.host, options.port), shell=True)