Skip to content

Commit

Permalink
Fixed bug in k-state for hmm break 'no transition'.
Browse files Browse the repository at this point in the history
Change-Id: Id6a3b7a7642f1552c7ea7ef9fd3c833a9f9a79fd
  • Loading branch information
Sebastian Mattheis committed Jan 3, 2018
1 parent eb16240 commit f2a3824
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 58 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>com.bmw-carit</groupId>
<artifactId>barefoot</artifactId>
<version>0.1.1</version>
<version>0.1.2</version>
<build>
<plugins>
<plugin>
Expand Down
42 changes: 22 additions & 20 deletions src/main/java/com/bmwcarit/barefoot/markov/KState.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.json.JSONException;
import org.json.JSONObject;

import com.bmwcarit.barefoot.util.Triple;
import com.bmwcarit.barefoot.util.Tuple;

/**
Expand All @@ -40,7 +41,7 @@ public class KState<C extends StateCandidate<C, T, S>, T extends StateTransition
extends StateMemory<C, T, S> {
private final int k;
private final long t;
private final LinkedList<Tuple<Set<C>, S>> sequence;
private final LinkedList<Triple<Set<C>, S, C>> sequence;
private final Map<C, Integer> counters;

/**
Expand Down Expand Up @@ -100,8 +101,10 @@ public KState(JSONObject json, Factory<C, T, S> factory) throws JSONException {
}

S sample = factory.sample(jsonseqelement.getJSONObject("sample"));
String kestid = jsonseqelement.getString("kestid");
C kestimate = candidates.get(kestid);

sequence.add(new Tuple<>(vector, sample));
sequence.add(new Triple<>(vector, sample, kestimate));
}

Collections.sort(sequence, new Comparator<Tuple<Set<C>, S>>() {
Expand Down Expand Up @@ -183,6 +186,7 @@ public void update(Set<C> vector, S sample) {
throw new RuntimeException("out-of-order state update is prohibited");
}

C kestimate = null;
for (C candidate : vector) {
counters.put(candidate, 0);
if (candidate.predecessor() != null) {
Expand All @@ -192,16 +196,16 @@ public void update(Set<C> vector, S sample) {
}
counters.put(candidate.predecessor(), counters.get(candidate.predecessor()) + 1);
}
if (kestimate == null || candidate.seqprob() > kestimate.seqprob()) {
kestimate = candidate;
}
}

if (!sequence.isEmpty()) {
Triple<Set<C>, S, C> last = sequence.peekLast();
Set<C> deletes = new HashSet<>();
C estimate = null;

for (C candidate : sequence.peekLast().one()) {
if (estimate == null || candidate.seqprob() > estimate.seqprob()) {
estimate = candidate;
}
for (C candidate : last.one()) {
if (counters.get(candidate) == 0) {
deletes.add(candidate);
}
Expand All @@ -210,13 +214,13 @@ public void update(Set<C> vector, S sample) {
int size = sequence.peekLast().one().size();

for (C candidate : deletes) {
if (deletes.size() != size || candidate != estimate) {
if (deletes.size() != size || candidate != last.three()) {
remove(candidate, sequence.size() - 1);
}
}
}

sequence.add(new Tuple<>(vector, sample));
sequence.add(new Triple<>(vector, sample, kestimate));

while ((t > 0 && sample.time() - sequence.peekFirst().two().time() > t)
|| (k >= 0 && sequence.size() > k + 1)) {
Expand All @@ -234,6 +238,10 @@ public void update(Set<C> vector, S sample) {
}

protected void remove(C candidate, int index) {
if (sequence.get(index).three() == candidate) {
return;
}

Set<C> vector = sequence.get(index).one();
counters.remove(candidate);
vector.remove(candidate);
Expand Down Expand Up @@ -282,23 +290,16 @@ public List<C> sequence() {
return null;
}

C kestimate = null;

for (C candidate : sequence.peekLast().one()) {
if (kestimate == null || candidate.seqprob() > kestimate.seqprob()) {
kestimate = candidate;
}
}

C kestimate = sequence.peekLast().three();
LinkedList<C> ksequence = new LinkedList<>();

for (int i = sequence.size() - 1; i >= 0; --i) {
if (kestimate != null) {
ksequence.push(kestimate);
kestimate = kestimate.predecessor();
} else {
ksequence.push(sequence.get(i).one().iterator().next());
assert (sequence.get(i).one().size() == 1);
ksequence.push(sequence.get(i).three());
kestimate = sequence.get(i).three().predecessor();
}
}

Expand All @@ -309,7 +310,7 @@ public List<C> sequence() {
public JSONObject toJSON() throws JSONException {
JSONObject json = new JSONObject();
JSONArray jsonsequence = new JSONArray();
for (Tuple<Set<C>, S> element : sequence) {
for (Triple<Set<C>, S, C> element : sequence) {
JSONObject jsonseqelement = new JSONObject();
JSONArray jsonvector = new JSONArray();
for (C candidate : element.one()) {
Expand All @@ -321,6 +322,7 @@ public JSONObject toJSON() throws JSONException {
}
jsonseqelement.put("vector", jsonvector);
jsonseqelement.put("sample", element.two().toJSON());
jsonseqelement.put("kestid", element.three().id());
jsonsequence.put(jsonseqelement);
}

Expand Down
97 changes: 61 additions & 36 deletions src/test/java/com/bmwcarit/barefoot/markov/KStateTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,10 @@ public void TestKStateUnbound() {
elements.put(1, new MockElem(1, Math.log10(0.2), 0.2, null));
elements.put(2, new MockElem(2, Math.log10(0.5), 0.5, null));

KState<MockElem, StateTransition, Sample> state =
new KState<>();
KState<MockElem, StateTransition, Sample> state = new KState<>();
{
Set<MockElem> vector = new HashSet<>(
Arrays.asList(elements.get(0), elements.get(1), elements.get(2)));
Set<MockElem> vector =
new HashSet<>(Arrays.asList(elements.get(0), elements.get(1), elements.get(2)));

state.update(vector, new Sample(0));

Expand All @@ -90,8 +89,8 @@ public void TestKStateUnbound() {
elements.put(6, new MockElem(6, Math.log10(0.1), 0.1, elements.get(2)));

{
Set<MockElem> vector = new HashSet<>(Arrays.asList(elements.get(3),
elements.get(4), elements.get(5), elements.get(6)));
Set<MockElem> vector = new HashSet<>(Arrays.asList(elements.get(3), elements.get(4),
elements.get(5), elements.get(6)));

state.update(vector, new Sample(1));

Expand All @@ -110,8 +109,8 @@ public void TestKStateUnbound() {
elements.put(10, new MockElem(10, Math.log10(0.1), 0.1, elements.get(6)));

{
Set<MockElem> vector = new HashSet<>(Arrays.asList(elements.get(7),
elements.get(8), elements.get(9), elements.get(10)));
Set<MockElem> vector = new HashSet<>(Arrays.asList(elements.get(7), elements.get(8),
elements.get(9), elements.get(10)));

state.update(vector, new Sample(2));

Expand All @@ -130,12 +129,12 @@ public void TestKStateUnbound() {
elements.put(14, new MockElem(14, Math.log10(0.1), 0.1, null));

{
Set<MockElem> vector = new HashSet<>(Arrays.asList(elements.get(11),
elements.get(12), elements.get(13), elements.get(14)));
Set<MockElem> vector = new HashSet<>(Arrays.asList(elements.get(11), elements.get(12),
elements.get(13), elements.get(14)));

state.update(vector, new Sample(3));

assertEquals(7, state.size());
assertEquals(8, state.size());
assertEquals(13, state.estimate().numid());

List<Integer> sequence = new LinkedList<>(Arrays.asList(2, 6, 9, 13));
Expand All @@ -148,7 +147,7 @@ public void TestKStateUnbound() {

state.update(vector, new Sample(4));

assertEquals(7, state.size());
assertEquals(8, state.size());
assertEquals(13, state.estimate().numid());

List<Integer> sequence = new LinkedList<>(Arrays.asList(2, 6, 9, 13));
Expand All @@ -158,18 +157,47 @@ public void TestKStateUnbound() {
}
}

@Test
public void TestBreak() {
// Test k-state in case of HMM break 'no transition' as reported in barefoot issue #83.
// Tests only 'no transitions', no emissions is empty vector and, hence, input to update
// operation.

KState<MockElem, StateTransition, Sample> state = new KState<>();
Map<Integer, MockElem> elements = new HashMap<>();
elements.put(0, new MockElem(0, Math.log10(0.4), 0.4, null));
{
Set<MockElem> vector = new HashSet<>(Arrays.asList(elements.get(0)));
state.update(vector, new Sample(0));
}
elements.put(1, new MockElem(1, Math.log(0.7), 0.6, null));
elements.put(2, new MockElem(2, Math.log(0.3), 0.4, elements.get(0)));
{
Set<MockElem> vector = new HashSet<>(Arrays.asList(elements.get(1), elements.get(2)));
state.update(vector, new Sample(1));
}
elements.put(3, new MockElem(3, Math.log(0.5), 0.6, null));
{
Set<MockElem> vector = new HashSet<>(Arrays.asList(elements.get(3)));
state.update(vector, new Sample(2));
}
List<MockElem> seq = state.sequence();
assertEquals(seq.get(0).numid(), 0);
assertEquals(seq.get(1).numid(), 1);
assertEquals(seq.get(2).numid(), 3);
}

@Test
public void TestKState() {
Map<Integer, MockElem> elements = new HashMap<>();
elements.put(0, new MockElem(0, Math.log10(0.3), 0.3, null));
elements.put(1, new MockElem(1, Math.log10(0.2), 0.2, null));
elements.put(2, new MockElem(2, Math.log10(0.5), 0.5, null));

KState<MockElem, StateTransition, Sample> state =
new KState<>(1, -1);
KState<MockElem, StateTransition, Sample> state = new KState<>(1, -1);
{
Set<MockElem> vector = new HashSet<>(
Arrays.asList(elements.get(0), elements.get(1), elements.get(2)));
Set<MockElem> vector =
new HashSet<>(Arrays.asList(elements.get(0), elements.get(1), elements.get(2)));

state.update(vector, new Sample(0));

Expand All @@ -183,8 +211,8 @@ public void TestKState() {
elements.put(6, new MockElem(6, Math.log10(0.1), 0.1, elements.get(2)));

{
Set<MockElem> vector = new HashSet<>(Arrays.asList(elements.get(3),
elements.get(4), elements.get(5), elements.get(6)));
Set<MockElem> vector = new HashSet<>(Arrays.asList(elements.get(3), elements.get(4),
elements.get(5), elements.get(6)));

state.update(vector, new Sample(1));

Expand All @@ -203,8 +231,8 @@ public void TestKState() {
elements.put(10, new MockElem(10, Math.log10(0.1), 0.1, elements.get(6)));

{
Set<MockElem> vector = new HashSet<>(Arrays.asList(elements.get(7),
elements.get(8), elements.get(9), elements.get(10)));
Set<MockElem> vector = new HashSet<>(Arrays.asList(elements.get(7), elements.get(8),
elements.get(9), elements.get(10)));

state.update(vector, new Sample(2));

Expand All @@ -223,8 +251,8 @@ public void TestKState() {
elements.put(14, new MockElem(14, Math.log10(0.1), 0.1, null));

{
Set<MockElem> vector = new HashSet<>(Arrays.asList(elements.get(11),
elements.get(12), elements.get(13), elements.get(14)));
Set<MockElem> vector = new HashSet<>(Arrays.asList(elements.get(11), elements.get(12),
elements.get(13), elements.get(14)));

state.update(vector, new Sample(3));

Expand Down Expand Up @@ -258,11 +286,10 @@ public void TestTState() {
elements.put(1, new MockElem(1, Math.log10(0.2), 0.2, null));
elements.put(2, new MockElem(2, Math.log10(0.5), 0.5, null));

KState<MockElem, StateTransition, Sample> state =
new KState<>(-1, 1);
KState<MockElem, StateTransition, Sample> state = new KState<>(-1, 1);
{
Set<MockElem> vector = new HashSet<>(
Arrays.asList(elements.get(0), elements.get(1), elements.get(2)));
Set<MockElem> vector =
new HashSet<>(Arrays.asList(elements.get(0), elements.get(1), elements.get(2)));

state.update(vector, new Sample(0));

Expand All @@ -276,8 +303,8 @@ public void TestTState() {
elements.put(6, new MockElem(6, Math.log10(0.1), 0.1, elements.get(2)));

{
Set<MockElem> vector = new HashSet<>(Arrays.asList(elements.get(3),
elements.get(4), elements.get(5), elements.get(6)));
Set<MockElem> vector = new HashSet<>(Arrays.asList(elements.get(3), elements.get(4),
elements.get(5), elements.get(6)));

state.update(vector, new Sample(1));

Expand All @@ -296,8 +323,8 @@ public void TestTState() {
elements.put(10, new MockElem(10, Math.log10(0.1), 0.1, elements.get(6)));

{
Set<MockElem> vector = new HashSet<>(Arrays.asList(elements.get(7),
elements.get(8), elements.get(9), elements.get(10)));
Set<MockElem> vector = new HashSet<>(Arrays.asList(elements.get(7), elements.get(8),
elements.get(9), elements.get(10)));

state.update(vector, new Sample(2));

Expand All @@ -316,8 +343,8 @@ public void TestTState() {
elements.put(14, new MockElem(14, Math.log10(0.1), 0.1, null));

{
Set<MockElem> vector = new HashSet<>(Arrays.asList(elements.get(11),
elements.get(12), elements.get(13), elements.get(14)));
Set<MockElem> vector = new HashSet<>(Arrays.asList(elements.get(11), elements.get(12),
elements.get(13), elements.get(14)));

state.update(vector, new Sample(3));

Expand Down Expand Up @@ -348,8 +375,7 @@ public void TestTState() {
public void TestKStateJSON() throws JSONException {
Map<Integer, MockElem> elements = new HashMap<>();

KState<MockElem, StateTransition, Sample> state =
new KState<>(1, -1);
KState<MockElem, StateTransition, Sample> state = new KState<>(1, -1);

{
JSONObject json = state.toJSON();
Expand All @@ -361,8 +387,7 @@ public void TestKStateJSON() throws JSONException {
elements.put(2, new MockElem(2, Math.log10(0.5), 0.5, null));

state.update(
new HashSet<>(
Arrays.asList(elements.get(0), elements.get(1), elements.get(2))),
new HashSet<>(Arrays.asList(elements.get(0), elements.get(1), elements.get(2))),
new Sample(0));

{
Expand Down
2 changes: 1 addition & 1 deletion util/submit/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@

tmp = "batch-%s" % random.randint(0, sys.maxint)
file = open(tmp, "w")
file.write("{\"format\": \"%s\", \"request\": %s}" % (options.format, json.dumps(samples)))
file.write("{\"format\": \"%s\", \"request\": %s}\n" % (options.format, json.dumps(samples)))
file.close()

subprocess.call("cat %s | netcat %s %s" % (tmp, options.host, options.port), shell=True)
Expand Down

0 comments on commit f2a3824

Please sign in to comment.