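Introduction
A system can be in one of N different, unobservable states (that is, we never know the system's actual state). The system also has a finite set of possible observable "outputs", and which output appears depends on the system's actual (unobservable) state.
The input to the Viterbi algorithm is a list of observations over time. From it, the algorithm computes the most likely sequence of states, that is, one state for each time frame.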
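Besides the list of observations, the following are given: the initial distribution over the states, the transition probabilities between states, and the emission probabilities (the probability of each observation in each state).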
For more details, see e.g. Wikipedia.
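For reference, here is the recurrence that initialize(), nextStep() and finish() appear to implement. The notation is an assumption made for illustration. π(s) is the initial probability of state s (initialDistributions in the code), a(s', s) is the probability of moving from s' to s (transitionProbabilities), b(s, o) is the probability of observing o in state s (emissionProbabilities), and o_0, …, o_{n-1} are the observations. δ_t(s) corresponds to stateProbsForObservations and ψ_t(s) to previousStatesForObservations.

$$
\begin{aligned}
\delta_0(s) &= \pi(s)\, b(s, o_0), \\
\delta_t(s) &= \max_{s'} \bigl[\delta_{t-1}(s')\, a(s', s)\bigr]\, b(s, o_t), \qquad t = 1, \dots, n-1, \\
\psi_t(s) &= \operatorname*{arg\,max}_{s'} \; \delta_{t-1}(s')\, a(s', s), \\
s^*_{n-1} &= \operatorname*{arg\,max}_{s} \; \delta_{n-1}(s), \qquad s^*_{t-1} = \psi_t(s^*_t).
\end{aligned}
$$

As a worked example, take the numbers from the wikipediaSample test (states HEALTHY and FEVER, observations OK, COLD, DIZZY). The values of (δ_t(HEALTHY), δ_t(FEVER)) are (0.3, 0.04) at t = 0, (0.084, 0.027) at t = 1 and (0.00588, 0.01512) at t = 2. FEVER wins at the last step, and backtracking through ψ gives HEALTHY, HEALTHY, FEVER, which is the result that test expects.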
Some notes on the design
Besides implementing the algorithm correctly, I had the following goals in mind:
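- Being able to advance the computation one step at a time and inspect the intermediate results (the nextStep, getProbabilitiesForObservations and getPreviousStatesObservations methods).
- Being able to run the whole computation at once, which is what the calculate method does.
Focus of this review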
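All suggestions and comments are welcome, but these are the points I am most interested in:
Note: I have only included the relevant parts. A complete working version is available on GitHub. The code uses Guava as an external library (plus JUnit / Hamcrest for testing).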
public static class ViterbiModel<S extends Enum<S>, T extends Enum<T>> {
public final ImmutableMap<S, Double> initialDistributions;
public final ImmutableTable<S, S, Double> transitionProbabilities;
public final ImmutableTable<S, T, Double> emissionProbabilities;
private ViterbiModel(ImmutableMap<S, Double> initialDistributions,
ImmutableTable<S, S, Double> transitionProbabilities,
ImmutableTable<S, T, Double> emissionProbabilities) {
this.initialDistributions = checkNotNull(initialDistributions);
this.transitionProbabilities = checkNotNull(transitionProbabilities);
this.emissionProbabilities = checkNotNull(emissionProbabilities);
}
public static <S extends Enum<S>, T extends Enum<T>> Builder<S, T> builder() {
return new Builder<>();
}
public static class Builder<S extends Enum<S>, T extends Enum<T>> {
private ImmutableMap<S, Double> initialDistributions;
private ImmutableTable.Builder<S, S, Double> transitionProbabilities = ImmutableTable.builder();
private ImmutableTable.Builder<S, T, Double> emissionProbabilities = ImmutableTable.builder();
public ViterbiModel<S, T> build() {
return new ViterbiModel<S, T>(immutableEnumMap(initialDistributions), transitionProbabilities.build(), emissionProbabilities.build());
}
public Builder<S, T> withInitialDistributions(ImmutableMap<S, Double> initialDistributions) {
this.initialDistributions = initialDistributions;
return this;
}
public Builder<S, T> withTransitionProbability(S src, S dest, Double prob) {
transitionProbabilities.put(src, dest, prob);
return this;
}
public Builder<S, T> withEmissionProbability(S state, T emission, Double prob) {
emissionProbabilities.put(state, emission, prob);
return this;
}
}
}
public static class ViterbiMachine<S extends Enum<S>, T extends Enum<T>> {
private final List<S> possibleStates;
private final List<T> possibleObservations;
private final ViterbiModel<S, T> model;
private final ImmutableList<T> observations;
private Table<S, Integer, Double> stateProbsForObservations = HashBasedTable.create();
private Table<S, Integer, Optional<S>> previousStatesForObservations = HashBasedTable.create();
private int step;
public ViterbiMachine(ViterbiModel<S, T> model, ImmutableList<T> observations) {
this.model = checkNotNull(model);
this.observations = checkNotNull(observations);
try {
possibleStates = ImmutableList.copyOf(getPossibleStates());
} catch (IllegalStateException ise) {
throw new IllegalArgumentException("empty states enum, or no explicit initial distribution provided", ise);
}
try {
possibleObservations = ImmutableList.copyOf(getPossibleObservations());
} catch (IllegalStateException ise) {
throw new IllegalArgumentException("empty observations enum, or no explicit observations provided", ise);
}
validate();
initialize();
}
private void validate() {
if (model.initialDistributions.size() != possibleStates.size()) {
throw new IllegalArgumentException("model.initialDistributions.size() = " + model.initialDistributions.size());
}
double sumInitProbs = 0.0;
for (double prob: model.initialDistributions.values()) {
sumInitProbs += prob;
}
if (!doublesEqual(sumInitProbs, 1.0)) {
throw new IllegalArgumentException("the sum of initial distributions should be 1.0, was " + sumInitProbs);
}
if (observations.size() < 1) {
// should not happen (observations size already checked when retrieving possible enum values),
// only added for the sake of completeness
throw new IllegalArgumentException("at least one observation should be provided, " + observations.size() + " given");
}
if (model.transitionProbabilities.size() < 1) {
throw new IllegalArgumentException("at least one transition probability should be provided, " + model.transitionProbabilities.size() + " given");
}
for (S row : possibleStates) {
double sumRowProbs = 0.0;
for (double prob : rowOrDefault(model.transitionProbabilities, row, ImmutableMap.<S, Double>of()).values()) {
sumRowProbs += prob;
}
if (!doublesEqual(sumRowProbs, 1.0)) {
throw new IllegalArgumentException("sum of transition probabilities for each state should be one, was " + sumRowProbs + " for state " + row);
}
}
if (model.emissionProbabilities.size() < 1) {
throw new IllegalArgumentException("at least one emission probability should be provided, " + model.emissionProbabilities.size() + " given");
}
for (S row : possibleStates) {
double sumRowProbs = 0.0;
for (double prob : rowOrDefault(model.emissionProbabilities, row, ImmutableMap.<T, Double>of()).values()) {
sumRowProbs += prob;
}
if (!doublesEqual(sumRowProbs, 1.0)) {
throw new IllegalArgumentException("sum of emission probabilities for each state should be one, was " + sumRowProbs + " for state " + row);
}
}
}
private static <S, T, V> V getOrDefault(Table<S, T, V> table, S key1, T key2, V defaultValue) {
V ret = table.get(key1, key2);
if (ret == null) {
ret = defaultValue;
}
return ret;
}
private static <S, T, V> Map<T, V> rowOrDefault(Table<S, T, V> table, S key, Map<T, V> defaultValue) {
Map<T, V> ret = table.row(key);
if (ret == null) {
ret = defaultValue;
}
return ret;
}
private void initialize() {
final T firstObservation = observations.get(0);
for (S state : possibleStates) {
stateProbsForObservations.put(state, 0, model.initialDistributions.getOrDefault(state, 0.0) * getOrDefault(model.emissionProbabilities, state, firstObservation, 0.0));
previousStatesForObservations.put(state, 0, Optional.<S>empty());
}
step = 1;
}
public void nextStep() {
if (step >= observations.size()) {
throw new IllegalStateException("already finished last step");
}
for (S state : possibleStates) {
double maxProb = 0.0;
Optional<S> prevStateWithMaxProb = Optional.empty();
for (S state2 : possibleStates) {
double prob = getOrDefault(stateProbsForObservations, state2, step - 1, 0.0) * getOrDefault(model.transitionProbabilities, state2, state, 0.0);
if (prob > maxProb) {
maxProb = prob;
prevStateWithMaxProb = Optional.of(state2);
}
}
stateProbsForObservations.put(state, step, maxProb * getOrDefault(model.emissionProbabilities, state, observations.get(step), 0.0));
previousStatesForObservations.put(state, step, prevStateWithMaxProb);
}
++step;
}
public ImmutableTable<S, Integer, Double> getProbabilitiesForObservations() {
return ImmutableTable.copyOf(stateProbsForObservations);
}
public ImmutableTable<S, Integer, Optional<S>> getPreviousStatesObservations() {
return ImmutableTable.copyOf(previousStatesForObservations);
}
public List<S> finish() {
if (step != observations.size()) {
throw new IllegalStateException("step = " + step);
}
S stateWithMaxProb = possibleStates.get(0);
double maxProb = stateProbsForObservations.get(stateWithMaxProb, observations.size() - 1);
for (S state : possibleStates) {
double prob = stateProbsForObservations.get(state, observations.size() - 1);
if (prob > maxProb) {
maxProb = prob;
stateWithMaxProb = state;
}
}
List<S> result = new ArrayList<>();
for (int i = observations.size() - 1; i >= 0; --i) {
result.add(stateWithMaxProb);
stateWithMaxProb = previousStatesForObservations.get(stateWithMaxProb, i).orElse(null);
}
return Lists.reverse(result);
}
public List<S> calculate() {
for (int i = 0; i < observations.size() - 1; ++i) {
nextStep();
}
return finish();
}
private S[] getPossibleStates() {
return getEnumsFromIterator(model.initialDistributions.keySet().iterator());
}
private T[] getPossibleObservations() {
return getEnumsFromIterator(observations.iterator());
}
private static <X extends Enum<X>> X[] getEnumsFromIterator(Iterator<X> it) {
if (!it.hasNext()) {
throw new IllegalStateException("iterator should have at least one element");
}
Enum<X> val1 = it.next();
return val1.getDeclaringClass().getEnumConstants();
}
private static boolean doublesEqual(double d1, double d2) {
return Math.abs(d1 - d2) < 0.0000001;
}
}

public class ViterbiTest {
@Rule
public ExpectedException thrown = ExpectedException.none();
enum ZeroStatesZeroObservationsState { };
enum ZeroStatesZeroObservationsObservation { };
@Test
public void zeroStatesZeroObservationsIsNotOk() {
ViterbiModel<ZeroStatesZeroObservationsState, ZeroStatesZeroObservationsObservation> model = ViterbiModel.<ZeroStatesZeroObservationsState, ZeroStatesZeroObservationsObservation>builder()
.withInitialDistributions(ImmutableMap.<ZeroStatesZeroObservationsState, Double>builder()
.build())
.build();
ImmutableList<ZeroStatesZeroObservationsObservation> observations = ImmutableList.of();
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("empty states enum, or no explicit initial distribution provided");
new ViterbiMachine<>(model, observations);
}
enum ZeroStatesOneObservationState { };
enum ZeroStatesOneObservationObservation { OBSERVATION0 };
@Test
public void zeroStatesOneObservationIsNotOk() {
ViterbiModel<ZeroStatesOneObservationState, ZeroStatesOneObservationObservation> model = ViterbiModel.<ZeroStatesOneObservationState, ZeroStatesOneObservationObservation>builder()
.withInitialDistributions(ImmutableMap.<ZeroStatesOneObservationState, Double>builder()
.build())
.build();
ImmutableList<ZeroStatesOneObservationObservation> observations = ImmutableList.of();
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("empty states enum, or no explicit initial distribution provided");
new ViterbiMachine<>(model, observations);
}
enum OneStateZeroObservationsState { STATE0 };
enum OneStateZeroObservationsObservation { };
@Test
public void oneStateZeroObservationsIsNotOk() {
ViterbiModel<OneStateZeroObservationsState, OneStateZeroObservationsObservation> model = ViterbiModel.<OneStateZeroObservationsState, OneStateZeroObservationsObservation>builder()
.withInitialDistributions(ImmutableMap.<OneStateZeroObservationsState, Double>builder()
.put(OneStateZeroObservationsState.STATE0, 1.0)
.build())
.build();
ImmutableList<OneStateZeroObservationsObservation> observations = ImmutableList.of();
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("empty observations enum, or no explicit observations provided");
new ViterbiMachine<>(model, observations);
}
enum OneStateOneObservationState { STATE0 };
enum OneStateOneObservationObservation { OBSERVATION0 };
@Test
public void oneStateOneObservationIsOk() {
ViterbiModel<OneStateOneObservationState, OneStateOneObservationObservation> model = ViterbiModel.<OneStateOneObservationState, OneStateOneObservationObservation>builder()
.withInitialDistributions(ImmutableMap.<OneStateOneObservationState, Double>builder()
.put(OneStateOneObservationState.STATE0, 1.0)
.build())
.withTransitionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationState.STATE0, 1.0)
.withEmissionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationObservation.OBSERVATION0, 1.0)
.build();
ImmutableList<OneStateOneObservationObservation> observations = ImmutableList.of(OneStateOneObservationObservation.OBSERVATION0);
ViterbiMachine<OneStateOneObservationState, OneStateOneObservationObservation> machine = new ViterbiMachine<>(model, observations);
List<OneStateOneObservationState> states = machine.calculate();
final List<OneStateOneObservationState> expected = ImmutableList.of(OneStateOneObservationState.STATE0);
assertThat(states, is(expected));
}
@Test
public void oneStateOneObservationMissingInitialDistributionIsNotOk() {
ViterbiModel<OneStateOneObservationState, OneStateOneObservationObservation> model = ViterbiModel.<OneStateOneObservationState, OneStateOneObservationObservation>builder()
.withInitialDistributions(ImmutableMap.<OneStateOneObservationState, Double>builder()
.build())
.withTransitionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationState.STATE0, 1.0)
.withEmissionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationObservation.OBSERVATION0, 1.0)
.build();
ImmutableList<OneStateOneObservationObservation> observations = ImmutableList.of(OneStateOneObservationObservation.OBSERVATION0);
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("empty states enum, or no explicit initial distribution provided");
new ViterbiMachine<>(model, observations);
}
@Test
public void oneStateOneObservationMissingObservationsIsNotOk() {
ViterbiModel<OneStateOneObservationState, OneStateOneObservationObservation> model = ViterbiModel.<OneStateOneObservationState, OneStateOneObservationObservation>builder()
.withInitialDistributions(ImmutableMap.<OneStateOneObservationState, Double>builder()
.put(OneStateOneObservationState.STATE0, 1.0)
.build())
.withTransitionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationState.STATE0, 1.0)
.withEmissionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationObservation.OBSERVATION0, 1.0)
.build();
ImmutableList<OneStateOneObservationObservation> observations = ImmutableList.of();
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("empty observations enum, or no explicit observations provided");
new ViterbiMachine<>(model, observations);
}
@Test
public void oneStateOneObservationSumInitialDistribNotOneIsNotOk() {
ViterbiModel<OneStateOneObservationState, OneStateOneObservationObservation> model = ViterbiModel.<OneStateOneObservationState, OneStateOneObservationObservation>builder()
.withInitialDistributions(ImmutableMap.<OneStateOneObservationState, Double>builder()
.put(OneStateOneObservationState.STATE0, 1.1)
.build())
.withTransitionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationState.STATE0, 1.0)
.withEmissionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationObservation.OBSERVATION0, 1.0)
.build();
ImmutableList<OneStateOneObservationObservation> observations = ImmutableList.of(OneStateOneObservationObservation.OBSERVATION0);
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("the sum of initial distributions should be 1.0, was 1.1");
new ViterbiMachine<>(model, observations);
}
@Test
public void oneStateOneObservationNoTransitionProbabilitiesIsNotOk() {
ViterbiModel<OneStateOneObservationState, OneStateOneObservationObservation> model = ViterbiModel.<OneStateOneObservationState, OneStateOneObservationObservation>builder()
.withInitialDistributions(ImmutableMap.<OneStateOneObservationState, Double>builder()
.put(OneStateOneObservationState.STATE0, 1.0)
.build())
.withEmissionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationObservation.OBSERVATION0, 1.0)
.build();
ImmutableList<OneStateOneObservationObservation> observations = ImmutableList.of(OneStateOneObservationObservation.OBSERVATION0);
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("at least one transition probability should be provided, 0 given");
new ViterbiMachine<>(model, observations);
}
@Test
public void oneStateOneObservationSumTransitionProbabilitiesNotOneIsNotOk() {
ViterbiModel<OneStateOneObservationState, OneStateOneObservationObservation> model = ViterbiModel.<OneStateOneObservationState, OneStateOneObservationObservation>builder()
.withInitialDistributions(ImmutableMap.<OneStateOneObservationState, Double>builder()
.put(OneStateOneObservationState.STATE0, 1.0)
.build())
.withTransitionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationState.STATE0, 1.1)
.withEmissionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationObservation.OBSERVATION0, 1.0)
.build();
ImmutableList<OneStateOneObservationObservation> observations = ImmutableList.of(OneStateOneObservationObservation.OBSERVATION0);
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("sum of transition probabilities for each state should be one, was 1.1 for state STATE0");
new ViterbiMachine<>(model, observations);
}
@Test
public void oneStateOneObservationZeroEmissionProbabilitiesIsNotOk() {
ViterbiModel<OneStateOneObservationState, OneStateOneObservationObservation> model = ViterbiModel.<OneStateOneObservationState, OneStateOneObservationObservation>builder()
.withInitialDistributions(ImmutableMap.<OneStateOneObservationState, Double>builder()
.put(OneStateOneObservationState.STATE0, 1.0)
.build())
.withTransitionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationState.STATE0, 1.0)
.build();
ImmutableList<OneStateOneObservationObservation> observations = ImmutableList.of(OneStateOneObservationObservation.OBSERVATION0);
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("at least one emission probability should be provided, 0 given");
new ViterbiMachine<>(model, observations);
}
@Test
public void oneStateOneObservationSumEmissionProbabilitiesNotOneIsNotOk() {
ViterbiModel<OneStateOneObservationState, OneStateOneObservationObservation> model = ViterbiModel.<OneStateOneObservationState, OneStateOneObservationObservation>builder()
.withInitialDistributions(ImmutableMap.<OneStateOneObservationState, Double>builder()
.put(OneStateOneObservationState.STATE0, 1.0)
.build())
.withTransitionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationState.STATE0, 1.0)
.withEmissionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationObservation.OBSERVATION0, 1.1)
.build();
ImmutableList<OneStateOneObservationObservation> observations = ImmutableList.of(OneStateOneObservationObservation.OBSERVATION0);
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("sum of emission probabilities for each state should be one, was 1.1 for state STATE0");
new ViterbiMachine<>(model, observations);
}
enum OneStateTwoObservationsState { STATE0 };
enum OneStateTwoObservationsObservation { OBSERVATION0, OBSERVATION1 };
@Test
public void oneStateTwoObservationsIsOk() {
ViterbiModel<OneStateTwoObservationsState, OneStateTwoObservationsObservation> model = ViterbiModel.<OneStateTwoObservationsState, OneStateTwoObservationsObservation>builder()
.withInitialDistributions(ImmutableMap.<OneStateTwoObservationsState, Double>builder()
.put(OneStateTwoObservationsState.STATE0, 1.0)
.build())
.withTransitionProbability(OneStateTwoObservationsState.STATE0, OneStateTwoObservationsState.STATE0, 1.0)
.withEmissionProbability(OneStateTwoObservationsState.STATE0, OneStateTwoObservationsObservation.OBSERVATION0, 0.4)
.withEmissionProbability(OneStateTwoObservationsState.STATE0, OneStateTwoObservationsObservation.OBSERVATION1, 0.6)
.build();
ImmutableList<OneStateTwoObservationsObservation> observations = ImmutableList.of(OneStateTwoObservationsObservation.OBSERVATION1, OneStateTwoObservationsObservation.OBSERVATION1);
ViterbiMachine<OneStateTwoObservationsState, OneStateTwoObservationsObservation> machine = new ViterbiMachine<>(model, observations);
List<OneStateTwoObservationsState> states = machine.calculate();
final List<OneStateTwoObservationsState> expected = ImmutableList.of(OneStateTwoObservationsState.STATE0, OneStateTwoObservationsState.STATE0);
assertThat(states, is(expected));
}
enum TwoStatesOneObservationState { STATE0, STATE1 };
enum TwoStatesOneObservationObservation { OBSERVATION0 };
@Test
public void twoStatesOneObservationIsOk() {
ViterbiModel<TwoStatesOneObservationState, TwoStatesOneObservationObservation> model = ViterbiModel.<TwoStatesOneObservationState, TwoStatesOneObservationObservation>builder()
.withInitialDistributions(ImmutableMap.<TwoStatesOneObservationState, Double>builder()
.put(TwoStatesOneObservationState.STATE0, 0.6)
.put(TwoStatesOneObservationState.STATE1, 0.4)
.build())
.withTransitionProbability(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationState.STATE0, 0.7)
.withTransitionProbability(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationState.STATE1, 0.3)
.withTransitionProbability(TwoStatesOneObservationState.STATE1, TwoStatesOneObservationState.STATE0, 0.4)
.withTransitionProbability(TwoStatesOneObservationState.STATE1, TwoStatesOneObservationState.STATE1, 0.6)
.withEmissionProbability(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationObservation.OBSERVATION0, 1.0)
.withEmissionProbability(TwoStatesOneObservationState.STATE1, TwoStatesOneObservationObservation.OBSERVATION0, 1.0)
.build();
ImmutableList<TwoStatesOneObservationObservation> observations = ImmutableList.of(TwoStatesOneObservationObservation.OBSERVATION0, TwoStatesOneObservationObservation.OBSERVATION0);
ViterbiMachine<TwoStatesOneObservationState, TwoStatesOneObservationObservation> machine = new ViterbiMachine<>(model, observations);
List<TwoStatesOneObservationState> states = machine.calculate();
final List<TwoStatesOneObservationState> expected = ImmutableList.of(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationState.STATE0);
assertThat(states, is(expected));
}
@Test
public void twoStatesOneObservationTransitionsOmittedForOneStateIsNotOk() {
ViterbiModel<TwoStatesOneObservationState, TwoStatesOneObservationObservation> model = ViterbiModel.<TwoStatesOneObservationState, TwoStatesOneObservationObservation>builder()
.withInitialDistributions(ImmutableMap.<TwoStatesOneObservationState, Double>builder()
.put(TwoStatesOneObservationState.STATE0, 0.6)
.put(TwoStatesOneObservationState.STATE1, 0.4)
.build())
.withTransitionProbability(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationState.STATE0, 0.7)
.withTransitionProbability(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationState.STATE1, 0.3)
.withEmissionProbability(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationObservation.OBSERVATION0, 1.0)
.withEmissionProbability(TwoStatesOneObservationState.STATE1, TwoStatesOneObservationObservation.OBSERVATION0, 1.0)
.build();
ImmutableList<TwoStatesOneObservationObservation> observations = ImmutableList.of(TwoStatesOneObservationObservation.OBSERVATION0, TwoStatesOneObservationObservation.OBSERVATION0);
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("sum of transition probabilities for each state should be one, was 0.0 for state STATE1");
new ViterbiMachine<>(model, observations);
}
@Test
public void twoStatesOneObservationEmissionsOmittedForOneStateIsNotOk() {
ViterbiModel<TwoStatesOneObservationState, TwoStatesOneObservationObservation> model = ViterbiModel.<TwoStatesOneObservationState, TwoStatesOneObservationObservation>builder()
.withInitialDistributions(ImmutableMap.<TwoStatesOneObservationState, Double>builder()
.put(TwoStatesOneObservationState.STATE0, 0.6)
.put(TwoStatesOneObservationState.STATE1, 0.4)
.build())
.withTransitionProbability(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationState.STATE0, 0.7)
.withTransitionProbability(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationState.STATE1, 0.3)
.withTransitionProbability(TwoStatesOneObservationState.STATE1, TwoStatesOneObservationState.STATE0, 0.4)
.withTransitionProbability(TwoStatesOneObservationState.STATE1, TwoStatesOneObservationState.STATE1, 0.6)
.withEmissionProbability(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationObservation.OBSERVATION0, 1.0)
.build();
ImmutableList<TwoStatesOneObservationObservation> observations = ImmutableList.of(TwoStatesOneObservationObservation.OBSERVATION0, TwoStatesOneObservationObservation.OBSERVATION0);
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("sum of emission probabilities for each state should be one, was 0.0 for state STATE1");
new ViterbiMachine<>(model, observations);
}
enum TwoStatesTwoObservationsState { STATE0, STATE1 };
enum TwoStatesTwoObservationsObservation { OBSERVATION0, OBSERVATION1 };
@Test
public void twoStatesTwoObservationsIsOk() {
ViterbiModel<TwoStatesTwoObservationsState, TwoStatesTwoObservationsObservation> model = ViterbiModel.<TwoStatesTwoObservationsState, TwoStatesTwoObservationsObservation>builder()
.withInitialDistributions(ImmutableMap.<TwoStatesTwoObservationsState, Double>builder()
.put(TwoStatesTwoObservationsState.STATE0, 0.6)
.put(TwoStatesTwoObservationsState.STATE1, 0.4)
.build())
.withTransitionProbability(TwoStatesTwoObservationsState.STATE0, TwoStatesTwoObservationsState.STATE0, 0.7)
.withTransitionProbability(TwoStatesTwoObservationsState.STATE0, TwoStatesTwoObservationsState.STATE1, 0.3)
.withTransitionProbability(TwoStatesTwoObservationsState.STATE1, TwoStatesTwoObservationsState.STATE0, 0.4)
.withTransitionProbability(TwoStatesTwoObservationsState.STATE1, TwoStatesTwoObservationsState.STATE1, 0.6)
.withEmissionProbability(TwoStatesTwoObservationsState.STATE0, TwoStatesTwoObservationsObservation.OBSERVATION0, 0.6)
.withEmissionProbability(TwoStatesTwoObservationsState.STATE0, TwoStatesTwoObservationsObservation.OBSERVATION1, 0.4)
.withEmissionProbability(TwoStatesTwoObservationsState.STATE1, TwoStatesTwoObservationsObservation.OBSERVATION0, 0.6)
.withEmissionProbability(TwoStatesTwoObservationsState.STATE1, TwoStatesTwoObservationsObservation.OBSERVATION1, 0.4)
.build();
ImmutableList<TwoStatesTwoObservationsObservation> observations = ImmutableList.of(TwoStatesTwoObservationsObservation.OBSERVATION0, TwoStatesTwoObservationsObservation.OBSERVATION0);
ViterbiMachine<TwoStatesTwoObservationsState, TwoStatesTwoObservationsObservation> machine = new ViterbiMachine<>(model, observations);
List<TwoStatesTwoObservationsState> states = machine.calculate();
final List<TwoStatesTwoObservationsState> expected = ImmutableList.of(TwoStatesTwoObservationsState.STATE0, TwoStatesTwoObservationsState.STATE0);
assertThat(states, is(expected));
}
enum WikipediaState { HEALTHY, FEVER };
enum WikipediaObservation { OK, COLD, DIZZY };
@Test
public void wikipediaSample() {
ViterbiModel<WikipediaState, WikipediaObservation> model = ViterbiModel.<WikipediaState, WikipediaObservation>builder()
.withInitialDistributions(ImmutableMap.<WikipediaState, Double>builder()
.put(WikipediaState.HEALTHY, 0.6)
.put(WikipediaState.FEVER, 0.4)
.build())
.withTransitionProbability(WikipediaState.HEALTHY, WikipediaState.HEALTHY, 0.7)
.withTransitionProbability(WikipediaState.HEALTHY, WikipediaState.FEVER, 0.3)
.withTransitionProbability(WikipediaState.FEVER, WikipediaState.HEALTHY, 0.4)
.withTransitionProbability(WikipediaState.FEVER, WikipediaState.FEVER, 0.6)
.withEmissionProbability(WikipediaState.HEALTHY, WikipediaObservation.OK, 0.5)
.withEmissionProbability(WikipediaState.HEALTHY, WikipediaObservation.COLD, 0.4)
.withEmissionProbability(WikipediaState.HEALTHY, WikipediaObservation.DIZZY, 0.1)
.withEmissionProbability(WikipediaState.FEVER, WikipediaObservation.OK, 0.1)
.withEmissionProbability(WikipediaState.FEVER, WikipediaObservation.COLD, 0.3)
.withEmissionProbability(WikipediaState.FEVER, WikipediaObservation.DIZZY, 0.6)
.build();
ImmutableList<WikipediaObservation> observations = ImmutableList.of(WikipediaObservation.OK, WikipediaObservation.COLD, WikipediaObservation.DIZZY);
ViterbiMachine<WikipediaState, WikipediaObservation> machine = new ViterbiMachine<>(model, observations);
List<WikipediaState> states = machine.calculate();
final List<WikipediaState> expected = ImmutableList.of(WikipediaState.HEALTHY, WikipediaState.HEALTHY, WikipediaState.FEVER);
assertThat(states, is(expected));
}
// ... SNIP
}

Comments on the API
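This API may look verbose, but it is the best one I have come up with so far. I tried more concise approaches before, but they were more error-prone and harder to manage once there were more than 4-5 states/observations.
For reference, here are my previous attempts at the API: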
public static int [] viterbi(int numStates, int numObservations,
double [] initialDistrib,
double [][] transitionProbs, double [][] emissionProbs,
int [] observations) // --> causes huge/unmanageable arrays
public static List<String> viterbi(Set<String> states,
Set<String> emissions,
Map<Key<String>, Double> transitionProbs,
Map<Key<String>, Double> emissionProbs,
Map<String, Double> initProbs,
List<String> observations) // --> a bit better, but not type safe

Posted on 2019-04-03 17:02:16
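If you updated the repo to include a make/ant/maven/gradle build file, I could easily change and run your code. Since I can't reproduce your build environment, I can only make some general comments.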
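Consider using Google's CallBuilder library to save a lot of boilerplate code. It lets you create a builder simply by annotating a constructor. You may need to implement a custom "style" class to reproduce the exact behavior of your hand-written builder, but I think it's worth it. Generating builders saves a lot of repetitive, error-prone code and helps enforce a consistent builder interface across the whole project.
In fact, writing CallBuilder style classes for all the Guava data structures would be a very cool and useful project, but that is beyond the scope of this algorithm.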
The ViterbiModel constructor could accept a wider range of inputs with a signature like:
private ViterbiModel(Map<? extends S, Double> initialDistributions,
Table<? extends S, ? extends S, Double> transitionProbabilities,
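Table<? extends S, ? extends T, Double> emissionProbabilities)

Then, inside the constructor, use ImmutableMap.copyOf and ImmutableTable.copyOf to create and store immutable copies. The same changes should be carried over to the builder as appropriate.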
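Here is a minimal sketch of the idea: accept wildcard-typed inputs and store immutable copies. To keep it runnable with the JDK alone, it makes two substitutions. Map.copyOf (Java 10+) replaces Guava's ImmutableMap.copyOf and ImmutableTable.copyOf, and a nested Map<R, Map<C, Double>> replaces Guava's Table. The class name ModelSketch is hypothetical.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch: accept wildcard-typed inputs, store immutable copies.
// Map.copyOf (Java 10+) replaces Guava's ImmutableMap/ImmutableTable.copyOf,
// and a nested Map replaces Guava's Table.
final class ModelSketch<S, T> {
    final Map<S, Double> initialDistributions;
    final Map<S, Map<S, Double>> transitionProbabilities;
    final Map<S, Map<T, Double>> emissionProbabilities;

    ModelSketch(Map<? extends S, Double> initialDistributions,
                Map<? extends S, ? extends Map<? extends S, Double>> transitionProbabilities,
                Map<? extends S, ? extends Map<? extends T, Double>> emissionProbabilities) {
        this.initialDistributions = Map.copyOf(initialDistributions);
        this.transitionProbabilities = copyTable(transitionProbabilities);
        this.emissionProbabilities = copyTable(emissionProbabilities);
    }

    // Copies the outer map and every row, so later changes by the caller do not leak in.
    private static <R, C> Map<R, Map<C, Double>> copyTable(
            Map<? extends R, ? extends Map<? extends C, Double>> table) {
        Map<R, Map<C, Double>> rows = new HashMap<>();
        table.forEach((row, columns) -> rows.put(row, Map.copyOf(columns)));
        return Map.copyOf(rows);
    }
}
```

Because the parameters use `? extends`, a caller holding a `Map<String, Double>` can build a `ModelSketch<CharSequence, Object>`. The copies also mean that later changes to the caller's maps do not affect the model.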
Add a ViterbiObservations class. It should hold the list of observations and provide a builder, for consistency with the ViterbiModel class.
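One possible shape for such a class is sketched below. It assumes the class only needs to wrap a list. The names ViterbiObservations, withObservation and asList are hypothetical, chosen to mirror the existing ViterbiModel builder.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical ViterbiObservations value class: an immutable list of observations
// that can only be created through its builder.
final class ViterbiObservations<T> {
    private final List<T> observations;

    private ViterbiObservations(List<T> observations) {
        // Fail fast: any instance that exists is known to be non-empty.
        if (observations.isEmpty()) {
            throw new IllegalArgumentException("at least one observation should be provided");
        }
        this.observations = List.copyOf(observations); // immutable copy; also rejects null elements
    }

    List<T> asList() {
        return observations;
    }

    static <T> Builder<T> builder() {
        return new Builder<>();
    }

    static final class Builder<T> {
        private final List<T> observations = new ArrayList<>();

        Builder<T> withObservation(T observation) {
            observations.add(observation);
            return this;
        }

        ViterbiObservations<T> build() {
            return new ViterbiObservations<>(observations);
        }
    }
}
```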
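Perform validation in the constructors
Validate ViterbiModel and ViterbiObservations objects in their respective constructors. Failing early is an important way to communicate with the user here: if they can create a ViterbiModel without any exception being thrown, it should be valid.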
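To illustrate failing early, here is a sketch of a model that checks its probabilities when it is constructed. ValidatedModel and requireSumOne are hypothetical names. The tolerance mirrors doublesEqual in the question, and plain nested maps replace Guava tables so the example has no dependencies.

```java
import java.util.Map;

// Hypothetical sketch: the model validates itself in its constructor,
// so that any instance that exists is consistent.
final class ValidatedModel<S, T> {
    private static final double EPSILON = 1e-7; // same tolerance as doublesEqual in the question

    final Map<S, Double> initial;
    final Map<S, Map<S, Double>> transitions;
    final Map<S, Map<T, Double>> emissions;

    ValidatedModel(Map<S, Double> initial, Map<S, Map<S, Double>> transitions,
                   Map<S, Map<T, Double>> emissions) {
        if (initial.isEmpty()) throw new IllegalArgumentException("no initial distribution given");
        requireSumOne(initial, "initial distributions");
        // Every state with an initial probability needs complete transition and emission rows.
        for (S state : initial.keySet()) {
            requireSumOne(transitions.getOrDefault(state, Map.of()), "transition probabilities for " + state);
            requireSumOne(emissions.getOrDefault(state, Map.of()), "emission probabilities for " + state);
        }
        // Outer maps are copied; the inner row maps are not deep-copied in this sketch.
        this.initial = Map.copyOf(initial);
        this.transitions = Map.copyOf(transitions);
        this.emissions = Map.copyOf(emissions);
    }

    private static void requireSumOne(Map<?, Double> probs, String what) {
        double sum = 0.0;
        for (double p : probs.values()) sum += p;
        if (Math.abs(sum - 1.0) >= EPSILON) {
            throw new IllegalArgumentException("sum of " + what + " should be 1.0, was " + sum);
        }
    }
}
```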
You should use
ViterbiMachine(ViterbiModel<S, ? extends T> model, ImmutableList<T> observations)
because a sequence of child observations can be emitted in a model made up of parent types.
ImmutableTable
The getOrDefault and rowOrDefault methods you wrote are nice. However, they belong on the table class itself, so extend ImmutableTable into a class that has these methods.
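One way to attach these helpers to the table type itself is a small wrapper, which uses composition rather than subclassing. In this sketch DefaultingTable is a hypothetical name, and a nested Map stands in for Guava's ImmutableTable so that it runs with the JDK alone.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical wrapper that gives a two-key table getOrDefault/rowOrDefault.
// A nested Map stands in for Guava's ImmutableTable.
final class DefaultingTable<R, C, V> {
    private final Map<R, Map<C, V>> rows;

    DefaultingTable(Map<R, Map<C, V>> rows) {
        Map<R, Map<C, V>> copy = new HashMap<>();
        rows.forEach((row, columns) -> copy.put(row, Map.copyOf(columns)));
        this.rows = Map.copyOf(copy); // immutable, like ImmutableTable
    }

    // Value at (row, column), or defaultValue if either key is absent.
    V getOrDefault(R row, C column, V defaultValue) {
        return rowOrDefault(row, Map.of()).getOrDefault(column, defaultValue);
    }

    // The whole row, or defaultValue if the row key is absent.
    Map<C, V> rowOrDefault(R row, Map<C, V> defaultValue) {
        return rows.getOrDefault(row, defaultValue);
    }
}
```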
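initialize()
It is not clear why the initialize() method is not simply part of the constructor.
Some of your smaller functions have little to do with ViterbiMachine. Move them to another class.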
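For example, the numeric helpers could move into a small utility class like the one below. ProbabilityUtils, sum and isDistribution are hypothetical names. doublesEqual and its tolerance come from the question.

```java
import java.util.Collection;

// Hypothetical home for small, Viterbi-agnostic helpers, so that
// ViterbiMachine only contains the algorithm itself.
final class ProbabilityUtils {
    static final double EPSILON = 0.0000001; // tolerance used by doublesEqual in the question

    private ProbabilityUtils() {} // static helpers only

    static boolean doublesEqual(double d1, double d2) {
        return Math.abs(d1 - d2) < EPSILON;
    }

    static double sum(Collection<Double> probabilities) {
        double total = 0.0;
        for (double p : probabilities) total += p;
        return total;
    }

    // True if the values sum to 1.0 within the tolerance.
    static boolean isDistribution(Collection<Double> probabilities) {
        return doublesEqual(sum(probabilities), 1.0);
    }
}
```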
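S, T as enum types
I don't see why these have to be enums. Wouldn't someone want to create a ViterbiMachine where, say, the states are integers and the outputs are strings? Surely your code could allow that.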
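To illustrate that the algorithm does not depend on enums, here is a sketch of Viterbi over arbitrary state and observation types. GenericViterbi and mostLikelyPath are hypothetical names. Probabilities are passed as nested maps, where a missing entry counts as 0, and the recurrence is the same one nextStep() uses.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch: Viterbi with no Enum bound on the state or observation types.
final class GenericViterbi {
    private GenericViterbi() {}

    static <S, T> List<S> mostLikelyPath(List<S> states, Map<S, Double> initial,
                                         Map<S, Map<S, Double>> transitions,
                                         Map<S, Map<T, Double>> emissions,
                                         List<T> observations) {
        if (states.isEmpty() || observations.isEmpty()) {
            throw new IllegalArgumentException("need at least one state and one observation");
        }
        int n = observations.size();
        List<Map<S, Double>> prob = new ArrayList<>(); // prob.get(t).get(s)
        List<Map<S, S>> back = new ArrayList<>();      // best predecessor of s at time t
        Map<S, Double> first = new HashMap<>();
        for (S s : states) {
            first.put(s, initial.getOrDefault(s, 0.0) * p(emissions, s, observations.get(0)));
        }
        prob.add(first);
        back.add(new HashMap<>());
        for (int t = 1; t < n; t++) {
            Map<S, Double> current = new HashMap<>();
            Map<S, S> pointers = new HashMap<>();
            for (S s : states) {
                double best = 0.0;
                S bestPrev = null;
                for (S prev : states) {
                    double v = prob.get(t - 1).get(prev) * p(transitions, prev, s);
                    // Unlike the question's code, always record some predecessor (the first one on ties at 0).
                    if (bestPrev == null || v > best) {
                        best = v;
                        bestPrev = prev;
                    }
                }
                current.put(s, best * p(emissions, s, observations.get(t)));
                pointers.put(s, bestPrev);
            }
            prob.add(current);
            back.add(pointers);
        }
        S state = states.get(0);
        for (S s : states) {
            if (prob.get(n - 1).get(s) > prob.get(n - 1).get(state)) state = s;
        }
        List<S> path = new ArrayList<>();
        for (int t = n - 1; t >= 0; t--) {
            path.add(state);
            if (t > 0) state = back.get(t).get(state);
        }
        Collections.reverse(path);
        return path;
    }

    // Nested-map lookup that treats a missing entry as probability 0.
    private static <R, C> double p(Map<R, Map<C, Double>> table, R row, C column) {
        return table.getOrDefault(row, Map.of()).getOrDefault(column, 0.0);
    }
}
```

Use states 0 (healthy) and 1 (fever), observations "ok", "cold" and "dizzy", and the probabilities from the wikipediaSample test. With these inputs, this sketch returns [0, 0, 1], which corresponds to HEALTHY, HEALTHY, FEVER.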
Posted on 2019-06-08 12:56:54
A few nitpicks on top of Benjamin's already good review:
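- Use Objects.requireNonNull instead of Guava's checkNotNull.
- … values() instead. It is a bit annoying that Java's generics are so fragile that you can't "just" call X.values() (which is guaranteed to exist).
- … the ImmutableMap dependency (compare this answer). This would let you use EnumMap for better performance.
- stateProbsForObservations and previousStatesForObservations could be replaced with maps of type Map<S, Double[]> and Map<S, Optional<S>[]> respectively. These could again be backed by EnumMap, which further reduces the memory footprint and improves performance. Again, for most uses the difference is negligible.
- The constructor of ViterbiMachine uses exceptions for flow control and validation. To avoid that, you can explicitly check the preconditions of the operations you perform instead of relying on downstream methods to fail with some exception. YMMV :)
- … getOrDefault and rowOrDefault, but that is not something you can fix :/
- After nextStep() has been called once, calculate() throws an IllegalStateException. I guess I would try to avoid making the retrieval of the result prone to IllegalStateExceptions.
- … calculate(), but that is because I like caching, smart and lazy calculator classes. I just like implementing those...
https://codereview.stackexchange.com/questions/184896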
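The sketch below combines several of the suggestions in the list above. Some things are assumptions. EnumViterbi, hasNextStep and the demo enums Health and Symptom are hypothetical names. Per-state primitive double[] arrays in an EnumMap take the place of the suggested Map<S, Double[]>. Back-pointers are stored as enum ordinals in int[] arrays, because Java cannot create a generic array such as Optional<S>[] without an unchecked cast. The constructor checks its preconditions directly. calculate() finishes whatever steps remain, so an earlier manual nextStep() call no longer breaks it, and it caches its result.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

// Hypothetical sketch, not the reviewed ViterbiMachine.
final class EnumViterbi<S extends Enum<S>, T extends Enum<T>> {
    private final S[] states;                      // enum constants in ordinal order
    private final List<T> observations;
    private final Map<S, Map<S, Double>> transitions;
    private final Map<S, Map<T, Double>> emissions;
    private final Map<S, double[]> probs;          // probs.get(s)[t]
    private final Map<S, int[]> back;              // ordinal of best predecessor, -1 if none
    private int step;
    private List<S> result;                        // cached once calculate() has run

    EnumViterbi(Class<S> stateType, Map<S, Double> initial, Map<S, Map<S, Double>> transitions,
                Map<S, Map<T, Double>> emissions, List<T> observations) {
        // Explicit precondition checks instead of catching exceptions from helper methods.
        this.states = stateType.getEnumConstants();
        if (states.length == 0) throw new IllegalArgumentException("empty states enum");
        if (observations.isEmpty()) throw new IllegalArgumentException("no observations given");
        this.transitions = Objects.requireNonNull(transitions);
        this.emissions = Objects.requireNonNull(emissions);
        this.observations = List.copyOf(observations);
        this.probs = new EnumMap<>(stateType);
        this.back = new EnumMap<>(stateType);
        int n = observations.size();
        for (S s : states) {
            double[] p = new double[n];
            p[0] = initial.getOrDefault(s, 0.0) * emission(s, observations.get(0));
            probs.put(s, p);
            int[] b = new int[n];
            b[0] = -1;
            back.put(s, b);
        }
        this.step = 1;
    }

    boolean hasNextStep() {
        return step < observations.size();
    }

    void nextStep() {
        if (!hasNextStep()) throw new IllegalStateException("already finished last step");
        for (S s : states) {
            double best = 0.0;
            int bestPrev = -1;
            for (S prev : states) {
                double p = probs.get(prev)[step - 1] * transition(prev, s);
                if (p > best) {
                    best = p;
                    bestPrev = prev.ordinal();
                }
            }
            probs.get(s)[step] = best * emission(s, observations.get(step));
            back.get(s)[step] = bestPrev;
        }
        step++;
    }

    List<S> calculate() {
        if (result != null) return result;         // cached from an earlier call
        while (hasNextStep()) nextStep();          // resume wherever manual stepping stopped
        int last = observations.size() - 1;
        S current = states[0];
        for (S s : states) {
            if (probs.get(s)[last] > probs.get(current)[last]) current = s;
        }
        List<S> path = new ArrayList<>();
        for (int t = last; t >= 0; t--) {
            path.add(current);
            int prev = back.get(current)[t];
            if (prev >= 0) current = states[prev]; // -1: no predecessor (t == 0 or all-zero column)
        }
        Collections.reverse(path);
        result = List.copyOf(path);
        return result;
    }

    private double transition(S from, S to) {
        return transitions.getOrDefault(from, Map.of()).getOrDefault(to, 0.0);
    }

    private double emission(S state, T observation) {
        return emissions.getOrDefault(state, Map.of()).getOrDefault(observation, 0.0);
    }
}

// Demo types for the usage example.
enum Health { HEALTHY, FEVER }
enum Symptom { OK, COLD, DIZZY }
```

Use the probabilities from the wikipediaSample test and call nextStep() once by hand. A subsequent calculate() then returns [HEALTHY, HEALTHY, FEVER] instead of throwing, and calling it again returns the same cached list.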