/*
 * Decompiled with CFR 0.152.
 */
package be.ac.ulg.montefiore.run.jahmm.learn;

import be.ac.ulg.montefiore.run.jahmm.Hmm;
import be.ac.ulg.montefiore.run.jahmm.Observation;
import be.ac.ulg.montefiore.run.jahmm.Opdf;
import be.ac.ulg.montefiore.run.jahmm.OpdfFactory;
import be.ac.ulg.montefiore.run.jahmm.ViterbiCalculator;
import be.ac.ulg.montefiore.run.jahmm.learn.Clusters;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Learns a hidden Markov model from observation sequences by alternating
 * hard state assignment and re-estimation.
 * <p>
 * The observations of all training sequences are first partitioned into
 * {@code nbStates} clusters by {@link Clusters} (presumably k-means, given this
 * class's name; the algorithm lives in {@code Clusters}). Cluster {@code i} is
 * treated as hidden state {@code i}. Each call to {@link #iterate()} then:
 * <ol>
 * <li>estimates the initial-state probabilities from the cluster of the first
 *     observation of every non-empty sequence,</li>
 * <li>estimates the transition probabilities by counting consecutive cluster
 *     pairs inside each sequence,</li>
 * <li>fits each state's observation distribution to the observations of its
 *     cluster, and</li>
 * <li>moves every observation to the state assigned to it by the Viterbi state
 *     sequence ({@link ViterbiCalculator}) of the newly built model.</li>
 * </ol>
 * Learning is considered finished when step 4 moves no observation.
 *
 * @param <O> the type of the observations.
 */
public class KMeansLearner<O extends Observation> {
    /** Current assignment of every observation to a cluster, i.e. to a state. */
    private final Clusters<O> clusters;
    /** Number of hidden states, which is also the number of clusters. */
    private final int nbStates;
    /** The training sequences, kept by reference (not copied). */
    private final List<? extends List<? extends O>> obsSeqs;
    /** Builds the observation distribution attached to each state. */
    private final OpdfFactory<? extends Opdf<O>> opdfFactory;
    /** True once an iteration left every observation in its cluster. */
    private boolean terminated;

    /**
     * Creates a learner and computes the initial partition of all observations.
     *
     * @param nbStates    the number of hidden states (and clusters) of the learned model.
     * @param opdfFactory builds the observation distribution of each state.
     * @param sequences   the training observation sequences; the list is stored by
     *                    reference, so it should not be modified while learning.
     */
    public KMeansLearner(int nbStates, OpdfFactory<? extends Opdf<O>> opdfFactory,
            List<? extends List<? extends O>> sequences) {
        this.obsSeqs = sequences;
        this.opdfFactory = opdfFactory;
        this.nbStates = nbStates;
        List<O> observations = flat(sequences);
        this.clusters = new Clusters<O>(nbStates, observations);
        this.terminated = false;
    }

    /**
     * Performs one learning step. It builds a fresh HMM from the current
     * clusters, then re-assigns the observations using that HMM.
     *
     * @return the HMM estimated from the clusters as they were before the
     *         re-assignment.
     */
    public Hmm<O> iterate() {
        Hmm<O> hmm = new Hmm<O>(this.nbStates, this.opdfFactory);
        learnPi(hmm);
        learnAij(hmm);
        learnOpdf(hmm);
        this.terminated = optimizeCluster(hmm);
        return hmm;
    }

    /**
     * Tells whether the last call to {@link #iterate()} moved no observation.
     *
     * @return {@code true} if learning has converged.
     */
    public boolean isTerminated() {
        return this.terminated;
    }

    /**
     * Calls {@link #iterate()} until {@link #isTerminated()} holds. There is
     * no iteration cap, so this keeps running for as long as observations keep
     * moving between clusters.
     *
     * @return the HMM produced by the last iteration.
     */
    public Hmm<O> learn() {
        Hmm<O> hmm;
        do {
            hmm = iterate();
        } while (!isTerminated());
        return hmm;
    }

    /**
     * Sets pi(i) to the fraction of non-empty sequences whose first
     * observation currently belongs to cluster {@code i}.
     * <p>
     * Empty sequences have no first observation and are ignored. If there is
     * no non-empty sequence at all, pi is set to the uniform distribution
     * rather than 0/0. This matches the uniform fallback used for
     * transition rows that have no evidence.
     */
    private void learnPi(Hmm<?> hmm) {
        double[] counts = new double[this.nbStates];
        int nonEmpty = 0;
        for (List<? extends O> sequence : this.obsSeqs) {
            if (sequence.isEmpty()) {
                continue;
            }
            counts[this.clusters.clusterNb(sequence.get(0))] += 1.0;
            nonEmpty++;
        }
        for (int i = 0; i < this.nbStates; i++) {
            double p = (nonEmpty == 0)
                    ? 1.0 / (double) this.nbStates
                    : counts[i] / (double) nonEmpty;
            hmm.setPi(i, p);
        }
    }

    /**
     * Sets a(i, j) to the relative frequency of cluster {@code i} being
     * directly followed by cluster {@code j} within a sequence. A row with no
     * outgoing transition at all (for example, a state seen only at sequence
     * ends) becomes uniform.
     */
    private void learnAij(Hmm<O> hmm) {
        int n = hmm.nbStates();
        double[][] counts = new double[n][n];
        for (List<? extends O> obsSeq : this.obsSeqs) {
            boolean first = true;
            int previous = 0;
            // Sequences shorter than 2 contribute no transition.
            for (O o : obsSeq) {
                int current = this.clusters.clusterNb(o);
                if (!first) {
                    counts[previous][current] += 1.0;
                }
                previous = current;
                first = false;
            }
        }
        for (int i = 0; i < n; i++) {
            double sum = 0.0;
            for (int j = 0; j < n; j++) {
                sum += counts[i][j];
            }
            for (int j = 0; j < n; j++) {
                hmm.setAij(i, j, (sum == 0.0) ? 1.0 / (double) n : counts[i][j] / sum);
            }
        }
    }

    /**
     * Fits each state's observation distribution to the observations of its
     * cluster. An empty cluster gets a freshly built distribution from the
     * factory instead, since there is nothing to fit it to.
     */
    private void learnOpdf(Hmm<O> hmm) {
        for (int i = 0; i < hmm.nbStates(); i++) {
            Collection<O> clusterObservations = this.clusters.cluster(i);
            if (clusterObservations.isEmpty()) {
                hmm.setOpdf(i, this.opdfFactory.factor());
            } else {
                hmm.getOpdf(i).fit(clusterObservations);
            }
        }
    }

    /**
     * Moves every observation to the cluster given by the Viterbi state
     * sequence of its sequence under {@code hmm}. Empty sequences hold no
     * observation and are skipped.
     *
     * @return {@code true} if no observation changed cluster.
     */
    private boolean optimizeCluster(Hmm<O> hmm) {
        boolean modified = false;
        for (List<? extends O> obsSeq : this.obsSeqs) {
            if (obsSeq.isEmpty()) {
                continue;
            }
            int[] states = new ViterbiCalculator(obsSeq, hmm).stateSequence();
            for (int t = 0; t < states.length; t++) {
                O o = obsSeq.get(t);
                int current = this.clusters.clusterNb(o);
                if (current != states[t]) {
                    modified = true;
                    this.clusters.remove(o, current);
                    this.clusters.put(o, states[t]);
                }
            }
        }
        return !modified;
    }

    /**
     * Concatenates a list of lists into a single new list, preserving order.
     *
     * @param lists the lists to concatenate.
     * @return a new mutable list holding every element of {@code lists}.
     */
    static <T> List<T> flat(List<? extends List<? extends T>> lists) {
        List<T> all = new ArrayList<T>();
        for (List<? extends T> list : lists) {
            all.addAll(list);
        }
        return all;
    }
}

