/*
 * Decompiled with CFR 0.152.
 */
package be.ac.ulg.montefiore.run.jahmm.learn;

import be.ac.ulg.montefiore.run.jahmm.ForwardBackwardCalculator;
import be.ac.ulg.montefiore.run.jahmm.Hmm;
import be.ac.ulg.montefiore.run.jahmm.Observation;
import be.ac.ulg.montefiore.run.jahmm.Opdf;
import be.ac.ulg.montefiore.run.jahmm.learn.KMeansLearner;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.Iterator;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Re-estimates the parameters of a hidden Markov model from observation
 * sequences using Baum-Welch re-estimation.
 * <p>
 * {@link #iterate} performs a single re-estimation step; {@link #learn} chains
 * {@link #getNbIterations()} such steps starting from an initial model. The
 * {@code protected} estimation methods are hooks that subclasses may override.
 */
public class BaumWelchLearner {
    /** Number of re-estimation steps performed by {@link #learn}. */
    private int nbIterations = 9;

    /**
     * Performs one Baum-Welch re-estimation step.
     * <p>
     * For every sequence, the forward-backward quantities are computed on
     * {@code hmm}. From them, {@code xi} (expected transitions) and
     * {@code gamma} (expected state occupation) are derived. The returned
     * model, obtained from {@code hmm.clone()}, is then updated as follows:
     * <ul>
     * <li>each transition probability a(i,j) becomes the expected number of
     * i-&gt;j transitions divided by the expected number of departures from i;
     * rows with a zero denominator are copied from {@code hmm};</li>
     * <li>each initial probability pi(i) becomes the mean of gamma at t = 0
     * over all sequences;</li>
     * <li>each state's Opdf is refit on all observations, weighted by that
     * state's normalised gamma; it is left unchanged when the total weight is 0.</li>
     * </ul>
     *
     * @param hmm the current model; the result is built on a clone of it
     * @param sequences the training sequences; with the default
     *        {@link #estimateXi} each must hold at least two observations
     * @return the re-estimated model
     * @throws IllegalArgumentException if {@link #estimateXi} rejects a sequence
     */
    public <O extends Observation> Hmm<O> iterate(Hmm<O> hmm, List<? extends List<? extends O>> sequences) {
        Hmm<O> nhmm = cloneHmm(hmm);
        int nbStates = hmm.nbStates();

        // gamma of every sequence, indexed [sequence][t][state].
        double[][][] allGamma = new double[sequences.size()][][];
        // Accumulators for the transition estimate; Java zero-initialises new arrays.
        double[][] aijNum = new double[nbStates][nbStates];
        double[] aijDen = new double[nbStates];

        int g = 0;
        for (List<? extends O> obsSeq : sequences) {
            ForwardBackwardCalculator fbc = generateForwardBackwardCalculator(obsSeq, hmm);
            double[][][] xi = estimateXi(obsSeq, fbc, hmm);
            double[][] gamma = estimateGamma(xi, fbc);
            allGamma[g++] = gamma;

            for (int i = 0; i < nbStates; i++) {
                // The last time step has no outgoing transition, hence size() - 1.
                for (int t = 0; t < obsSeq.size() - 1; t++) {
                    aijDen[i] += gamma[t][i];
                    for (int j = 0; j < nbStates; j++) {
                        aijNum[i][j] += xi[t][i][j];
                    }
                }
            }
        }

        reestimateTransitions(hmm, nhmm, aijNum, aijDen);
        reestimateInitialProbabilities(nhmm, allGamma);
        reestimateOpdfs(nhmm, sequences, allGamma);
        return nhmm;
    }

    /** Clones {@code hmm}, turning the (unexpected) checked exception into an error that keeps its cause. */
    private static <O extends Observation> Hmm<O> cloneHmm(Hmm<O> hmm) {
        try {
            return hmm.clone();
        } catch (CloneNotSupportedException e) {
            InternalError error = new InternalError("Hmm.clone() failed: " + e);
            error.initCause(e);
            throw error;
        }
    }

    /** Writes a(i,j) = aijNum[i][j] / aijDen[i] into {@code nhmm}, copying {@code hmm}'s row when aijDen[i] is 0. */
    private static <O extends Observation> void reestimateTransitions(Hmm<O> hmm, Hmm<O> nhmm,
            double[][] aijNum, double[] aijDen) {
        int nbStates = hmm.nbStates();
        for (int i = 0; i < nbStates; i++) {
            // A zero denominator means no expected departures from i: dividing would yield NaN,
            // so the previous row is kept instead.
            boolean keepPrevious = aijDen[i] == 0.0;
            for (int j = 0; j < nbStates; j++) {
                double aij = keepPrevious ? hmm.getAij(i, j) : aijNum[i][j] / aijDen[i];
                nhmm.setAij(i, j, aij);
            }
        }
    }

    /** Sets each pi(i) of {@code nhmm} to the average over sequences of gamma at t = 0. */
    private static <O extends Observation> void reestimateInitialProbabilities(Hmm<O> nhmm, double[][][] allGamma) {
        for (int i = 0; i < nhmm.nbStates(); i++) {
            double pi = 0.0;
            for (double[][] gamma : allGamma) {
                pi += gamma[0][i] / (double) allGamma.length;
            }
            nhmm.setPi(i, pi);
        }
    }

    /** Refits every state's Opdf of {@code nhmm} on all observations, weighted by normalised gamma. */
    private static <O extends Observation> void reestimateOpdfs(Hmm<O> nhmm,
            List<? extends List<? extends O>> sequences, double[][][] allGamma) {
        // The flattened observation list is the same for every state: build it once.
        List<O> observations = KMeansLearner.flat(sequences);

        for (int i = 0; i < nhmm.nbStates(); i++) {
            // Weights are laid out in the same order as the flattened observations.
            double[] weights = new double[observations.size()];
            double sum = 0.0;
            int k = 0;
            int o = 0;
            for (List<? extends O> obsSeq : sequences) {
                for (int t = 0; t < obsSeq.size(); t++) {
                    weights[k] = allGamma[o][t][i];
                    sum += weights[k];
                    k++;
                }
                o++;
            }

            if (sum == 0.0) {
                // No gamma mass for this state (or no observations at all): normalising would turn
                // every weight into NaN, so keep the previous distribution, as is done for a(i,j).
                continue;
            }
            for (int w = 0; w < weights.length; w++) {
                weights[w] /= sum;
            }

            Opdf<O> opdf = nhmm.getOpdf(i);
            opdf.fit(observations, weights);
        }
    }

    /**
     * Builds the forward-backward calculator used by {@link #iterate} for one
     * sequence. All computations in {@code ForwardBackwardCalculator.Computation}
     * are requested.
     *
     * @param sequence the observation sequence
     * @param hmm the model to evaluate the sequence against
     * @return a calculator for {@code sequence} under {@code hmm}
     */
    protected <O extends Observation> ForwardBackwardCalculator generateForwardBackwardCalculator(
            List<? extends O> sequence, Hmm<O> hmm) {
        return new ForwardBackwardCalculator(sequence, hmm,
                EnumSet.allOf(ForwardBackwardCalculator.Computation.class));
    }

    /**
     * Runs {@link #getNbIterations()} successive {@link #iterate} steps.
     *
     * @param initialHmm the starting model; returned as-is when the iteration count is 0
     * @param sequences the training sequences
     * @return the model after the last iteration
     */
    public <O extends Observation> Hmm<O> learn(Hmm<O> initialHmm, List<? extends List<? extends O>> sequences) {
        Hmm<O> hmm = initialHmm;
        for (int i = 0; i < nbIterations; i++) {
            hmm = iterate(hmm, sequences);
        }
        return hmm;
    }

    /**
     * Computes xi[t][i][j] = alpha_t(i) * a(i,j) * b_j(o_{t+1}) * beta_{t+1}(j) / P(O)
     * for t in [0, size - 2].
     *
     * @param sequence the observation sequence; must hold at least two observations
     * @param fbc forward-backward calculator for {@code sequence} under {@code hmm}
     * @param hmm the model
     * @return xi, of dimensions [size - 1][nbStates][nbStates]
     * @throws IllegalArgumentException if the sequence has fewer than two observations
     */
    protected <O extends Observation> double[][][] estimateXi(List<? extends O> sequence,
            ForwardBackwardCalculator fbc, Hmm<O> hmm) {
        if (sequence.size() <= 1) {
            throw new IllegalArgumentException("Observation sequence too short");
        }
        int nbStates = hmm.nbStates();
        int nbSteps = sequence.size() - 1;
        double[][][] xi = new double[nbSteps][nbStates][nbStates];
        double probability = fbc.probability();
        double[] emission = new double[nbStates];

        Iterator<? extends O> seqIterator = sequence.iterator();
        seqIterator.next(); // xi at step t uses observation t + 1, so skip the first one.

        for (int t = 0; t < nbSteps; t++) {
            O next = seqIterator.next();
            // b_j(o_{t+1}) depends only on (t, j): evaluate it once per state, not once per (i, j).
            for (int j = 0; j < nbStates; j++) {
                emission[j] = hmm.getOpdf(j).probability(next);
            }
            for (int i = 0; i < nbStates; i++) {
                for (int j = 0; j < nbStates; j++) {
                    xi[t][i][j] = fbc.alphaElement(t, i) * hmm.getAij(i, j) * emission[j]
                            * fbc.betaElement(t + 1, j) / probability;
                }
            }
        }
        return xi;
    }

    /**
     * Derives gamma from xi. For t &lt; T - 1, gamma[t][i] is the sum over j of
     * xi[t][i][j]. The last row, gamma[T - 1][j], is the sum over i of
     * xi[T - 2][i][j], because xi has no entry for the final time step.
     *
     * @param xi the values produced by {@link #estimateXi}; must be non-empty
     * @param fbc unused by this implementation; available to overriding subclasses
     * @return gamma, of dimensions [xi.length + 1][nbStates]
     * @throws IllegalArgumentException if {@code xi} is empty
     */
    protected double[][] estimateGamma(double[][][] xi, ForwardBackwardCalculator fbc) {
        if (xi.length == 0) {
            throw new IllegalArgumentException("xi must cover at least one time step");
        }
        int nbSteps = xi.length;
        int nbStates = xi[0].length;
        double[][] gamma = new double[nbSteps + 1][nbStates]; // zero-initialised

        for (int t = 0; t < nbSteps; t++) {
            for (int i = 0; i < nbStates; i++) {
                for (int j = 0; j < nbStates; j++) {
                    gamma[t][i] += xi[t][i][j];
                }
            }
        }

        double[][] lastXi = xi[nbSteps - 1];
        for (int j = 0; j < nbStates; j++) {
            for (int i = 0; i < nbStates; i++) {
                gamma[nbSteps][j] += lastXi[i][j];
            }
        }
        return gamma;
    }

    /**
     * Returns the number of iterations performed by {@link #learn}.
     *
     * @return the iteration count
     */
    public int getNbIterations() {
        return nbIterations;
    }

    /**
     * Sets the number of iterations performed by {@link #learn}.
     *
     * @param nb the iteration count; must be non-negative
     * @throws IllegalArgumentException if {@code nb} is negative
     */
    public void setNbIterations(int nb) {
        if (nb < 0) {
            throw new IllegalArgumentException("Non-negative number expected, got " + nb);
        }
        nbIterations = nb;
    }
}

