package marytts.machinelearning;

import com.rapidminer.example.Example;
import java.io.IOException;
import marytts.util.MaryUtils;
import marytts.util.io.FileUtils;
import marytts.util.io.MaryRandomAccessFile;
import marytts.util.math.DoubleMatrix;
import marytts.util.math.MathUtils;
import marytts.util.string.StringUtils;
import org.eclipse.persistence.internal.helper.Helper;

/* loaded from: input_file:WEB-INF/lib/marytts-d4science-5.0.0.jar:marytts/machinelearning/GMMTrainer.class */
public class GMMTrainer {
    /**
     * Per-iteration log-likelihood values written by {@link #expectationMaximization};
     * {@code null} until that method has been called on this instance.
     */
    public double[] logLikelihoods = null;
    // Explicit copy of the flag javac normally synthesizes for classes that use `assert`
    // (this file appears to be decompiler output). It is assigned once in the static
    // initializer of this class and read by the hand-expanded assertion checks in
    // train() and expectationMaximization().
    static final /* synthetic */ boolean $assertionsDisabled;

    /**
     * Trains a Gaussian mixture model on the given observations.
     *
     * <p>When {@code gMMTrainerParams.useNativeCLibTrainer} is set and
     * {@link MaryUtils#isWindows()} returns {@code true}, the observations are written to a
     * temporary file and the external {@code GMMTrainer.exe} program is run on it. Otherwise
     * the model is initialised from a {@link KMeansClusteringTrainer} run and refined with
     * {@link #expectationMaximization}.
     *
     * <p>The elapsed time is printed to standard output in every case.
     *
     * @param dArr observations, one row per observation; the Java path checks (only when
     *        assertions are enabled) that all rows have the same length as the first one
     * @param gMMTrainerParams training settings; must not be {@code null}. This method does
     *        not modify it.
     * @return the trained model, or {@code null} when {@code dArr} is {@code null} or empty,
     *         when {@code totalComponents} is not positive, or when the native trainer exits
     *         with a non-zero code
     */
    public GMM train(double[][] dArr, GMMTrainerParams gMMTrainerParams) {
        long startMillis = System.currentTimeMillis();
        GMM trained = null;
        // An empty data set is rejected here as well: the Java path reads dArr[0] below.
        if (dArr != null && dArr.length > 0 && gMMTrainerParams.totalComponents > 0) {
            // Decide locally instead of overwriting the flag on the caller's parameter object.
            boolean useNative = gMMTrainerParams.useNativeCLibTrainer && MaryUtils.isWindows();
            if (useNative) {
                // NOTE(review): these two constants come from unrelated libraries (RapidMiner,
                // EclipseLink). Presumably the decompiler substituted them for a plain space and
                // a double-quote literal -- confirm, then replace them with literals.
                final String sep = Example.SEPARATOR;
                final String quote = Helper.DEFAULT_DATABASE_DELIMITER;

                // NOTE(review): the temp-file location is hard-coded to drive d:.
                String dataFile = StringUtils.getRandomFileName("d:/gmmTemp_", 8, ".dat");
                String gmmFile = StringUtils.modifyExtension(dataFile, ".gmm");
                String logFile = StringUtils.modifyExtension(dataFile, ".log");
                try {
                    new DoubleMatrix(dArr).write(dataFile);
                    String command = "GMMTrainer.exe \"" + dataFile + "\" " + quote + gmmFile + "\" "
                            + gMMTrainerParams.totalComponents + sep
                            + "1" + sep
                            + (gMMTrainerParams.isDiagonalCovariance ? 1 : 0) + sep
                            + gMMTrainerParams.kmeansMaxIterations + sep
                            + gMMTrainerParams.kmeansMinClusterChangePercent + sep
                            + gMMTrainerParams.kmeansMinSamplesInOneCluster + sep
                            + gMMTrainerParams.emMinIterations + sep
                            + gMMTrainerParams.emMaxIterations + sep
                            + (gMMTrainerParams.isUpdateCovariances ? 1 : 0) + sep
                            + gMMTrainerParams.tinyLogLikelihoodChangePercent + sep
                            + gMMTrainerParams.minCovarianceAllowed + sep
                            + quote + logFile + quote;
                    int exitCode = MaryUtils.shellExecute(command, true);
                    if (exitCode == 0) {
                        System.out.println("GMM training with native C library done...");
                        try {
                            trained = new GMM(gmmFile);
                        } finally {
                            // Remove the model file even if loading it failed.
                            FileUtils.delete(gmmFile);
                        }
                    } else {
                        System.out.println("Error executing native C library with exit code " + exitCode);
                    }
                } finally {
                    // Remove the exported data even if writing, executing or loading failed.
                    FileUtils.delete(dataFile);
                }
            } else {
                int featureDimension = dArr[0].length;
                for (int row = 1; row < dArr.length; row++) {
                    // Hand-expanded `assert`: only active when assertions are enabled for this class.
                    if (!$assertionsDisabled && dArr[row].length != featureDimension) {
                        throw new AssertionError("Row " + row + " has length " + dArr[row].length
                                + " but row 0 has length " + featureDimension);
                    }
                }
                KMeansClusteringTrainerParams kmeansParams = new KMeansClusteringTrainerParams(gMMTrainerParams);
                KMeansClusteringTrainer kmeansTrainer = new KMeansClusteringTrainer();
                kmeansTrainer.train(dArr, kmeansParams);
                GMM initialModel = new GMM(kmeansTrainer);
                trained = expectationMaximization(dArr, initialModel,
                        gMMTrainerParams.emMinIterations,
                        gMMTrainerParams.emMaxIterations,
                        gMMTrainerParams.isUpdateCovariances,
                        gMMTrainerParams.tinyLogLikelihoodChangePercent,
                        gMMTrainerParams.minCovarianceAllowed);
            }
        }
        double elapsedSeconds = (System.currentTimeMillis() - startMillis) / 1000.0d;
        System.out.println("GMM training took " + elapsedSeconds + " seconds...");
        return trained;
    }

    /**
     * Refines a Gaussian mixture model with the expectation-maximization algorithm.
     *
     * <p>Works on a new {@link GMM} constructed from {@code gmm}. The mixture weights of that
     * new model are reset to a uniform distribution before the first iteration. Each iteration
     * updates the weights and the component means, and the covariances as well when {@code z}
     * is {@code true}; a progress line is printed to standard output per iteration.
     *
     * <p>Iteration stops once {@code i2} iterations have run, or once more than {@code i}
     * iterations have run and the log-likelihood gain over the previous iteration is smaller
     * than {@code |logLikelihood / 100 * d|}.
     *
     * <p>Side effect: {@link #logLikelihoods} is replaced by an array holding the
     * log-likelihood of every completed iteration, in order.
     *
     * @param dArr observations, one row per observation; must not be {@code null}
     * @param gmm initial model; the new model is constructed from it
     * @param i minimum number of EM iterations before the convergence test may stop training
     * @param i2 maximum number of EM iterations; must be at least 1
     * @param z whether the covariances are re-estimated in each iteration
     * @param d convergence threshold, as a percentage of the current log-likelihood
     * @param d2 lower bound applied to every re-estimated covariance entry
     * @return the refined model
     * @throws IllegalArgumentException if {@code i2} is smaller than 1
     */
    public GMM expectationMaximization(double[][] dArr, GMM gmm, int i, int i2, boolean z, double d, double d2) {
        // Descriptive aliases for the positional parameters.
        final int emMinIterations = i;
        final int emMaxIterations = i2;
        final boolean updateCovariances = z;
        final double tinyLogLikelihoodChangePercent = d;
        final double minCovarianceAllowed = d2;

        // The loop below always runs at least once and records one value per iteration.
        if (emMaxIterations < 1) {
            throw new IllegalArgumentException("emMaxIterations must be at least 1 but was " + emMaxIterations);
        }

        final int numObservations = dArr.length;
        final GMM model = new GMM(gmm);
        for (double[] observation : dArr) {
            // Hand-expanded `assert`: only active when assertions are enabled for this class.
            if (!$assertionsDisabled && observation.length != model.featureDimension) {
                throw new AssertionError("Observation length " + observation.length
                        + " does not match feature dimension " + model.featureDimension);
            }
        }

        final int numComponents = model.totalComponents;
        final int dimension = model.featureDimension;
        final boolean diagonal = model.isDiagonalCovariance;

        // Start from uniform weights (double arithmetic; the float literal lost precision).
        for (int k = 0; k < numComponents; k++) {
            model.weights[k] = 1.0d / numComponents;
        }

        double[] weightedDensities = new double[numComponents];
        double[][] responsibilities = new double[numObservations][numComponents];
        double[] weightedSum = new double[dimension];
        double[][] scatter = new double[dimension][dimension];
        this.logLikelihoods = new double[emMaxIterations];

        int iteration = 1;
        while (true) {
            long iterationStart = System.currentTimeMillis();

            // E-step: responsibility of every component for every observation.
            // NOTE(review): a zero mixture density yields NaN responsibilities here.
            for (int n = 0; n < numObservations; n++) {
                double mixtureDensity = 0.0d;
                for (int k = 0; k < numComponents; k++) {
                    weightedDensities[k] = model.weights[k] * componentPdf(model, k, dArr[n]);
                    mixtureDensity = mixtureDensity + weightedDensities[k];
                }
                for (int k = 0; k < numComponents; k++) {
                    responsibilities[n][k] = weightedDensities[k] / mixtureDensity;
                }
            }

            // M-step, one component at a time.
            double meanShift = 0.0d;
            for (int k = 0; k < numComponents; k++) {
                for (int a = 0; a < dimension; a++) {
                    weightedSum[a] = 0.0d;
                    for (int b = 0; b < dimension; b++) {
                        scatter[a][b] = 0.0d;
                    }
                }

                double[] mean = model.components[k].meanVector;
                double responsibilitySum = 0.0d;
                for (int n = 0; n < numObservations; n++) {
                    double r = responsibilities[n][k];
                    responsibilitySum += r;
                    for (int a = 0; a < dimension; a++) {
                        weightedSum[a] = weightedSum[a] + (dArr[n][a] * r);
                        if (!updateCovariances) {
                            continue; // scatter is only read when covariances are re-estimated
                        }
                        // Deviations are taken from the mean of the previous iteration.
                        double deviation = dArr[n][a] - mean[a];
                        if (diagonal) {
                            // Only the diagonal is read for diagonal-covariance models.
                            scatter[a][a] = scatter[a][a] + (r * deviation * (dArr[n][a] - mean[a]));
                        } else {
                            for (int b = 0; b < dimension; b++) {
                                scatter[a][b] = scatter[a][b] + (r * deviation * (dArr[n][b] - mean[b]));
                            }
                        }
                    }
                }

                model.weights[k] = responsibilitySum / numObservations;

                // New mean, and the Euclidean distance it moved.
                double shiftSquared = 0.0d;
                for (int a = 0; a < dimension; a++) {
                    double updated = weightedSum[a] / responsibilitySum;
                    double diff = updated - mean[a];
                    shiftSquared += diff * diff;
                    mean[a] = updated;
                }
                meanShift += Math.sqrt(shiftSquared);

                if (updateCovariances) {
                    if (diagonal) {
                        // Diagonal models keep their variances in row 0 of covMatrix.
                        for (int a = 0; a < dimension; a++) {
                            model.components[k].covMatrix[0][a] = Math.max(scatter[a][a] / responsibilitySum, minCovarianceAllowed);
                        }
                    } else {
                        // NOTE(review): the floor is applied to off-diagonal entries too, which
                        // raises negative covariances to minCovarianceAllowed -- confirm intent.
                        for (int a = 0; a < dimension; a++) {
                            for (int b = 0; b < dimension; b++) {
                                model.components[k].covMatrix[a][b] = Math.max(scatter[a][b] / responsibilitySum, minCovarianceAllowed);
                            }
                        }
                    }
                    model.components[k].setDerivedValues();
                }
            }

            // Log-likelihood of the data under the updated model.
            double logLikelihood = 0.0d;
            for (double[] observation : dArr) {
                double density = 0.0d;
                for (int k = 0; k < numComponents; k++) {
                    density += model.weights[k] * componentPdf(model, k, observation);
                }
                logLikelihood = logLikelihood + Math.log(density);
            }
            this.logLikelihoods[iteration - 1] = logLikelihood;

            // The reported "avg. difference" is the sum of the per-component mean shifts.
            System.out.println("For " + numComponents + " mixes - EM iteration no: " + iteration
                    + " with avg. difference in means " + meanShift
                    + " log-likelihood=" + logLikelihood
                    + " in " + ((System.currentTimeMillis() - iterationStart) / 1000.0d) + " sec");

            if (iteration >= emMaxIterations) {
                break;
            }
            // The convergence test needs a previous iteration to compare against.
            if (iteration > emMinIterations && iteration > 1) {
                double gain = this.logLikelihoods[iteration - 1] - this.logLikelihoods[iteration - 2];
                if (gain < Math.abs((logLikelihood / 100.0d) * tinyLogLikelihoodChangePercent)) {
                    break;
                }
            }
            iteration++;
        }

        // Keep exactly the values of the completed iterations (including the last one).
        double[] completed = new double[iteration];
        System.arraycopy(this.logLikelihoods, 0, completed, 0, iteration);
        this.logLikelihoods = completed;
        System.out.println("GMM training completed...");
        return model;
    }

    /**
     * Evaluates {@code MathUtils.getGaussianPdfValue} for component {@code k} of {@code model}
     * at {@code x}, passing the covariance diagonal for diagonal-covariance models and the
     * inverse covariance matrix otherwise. The mixture weight is not applied.
     */
    private static double componentPdf(GMM model, int k, double[] x) {
        if (model.isDiagonalCovariance) {
            return MathUtils.getGaussianPdfValue(x, model.components[k].meanVector,
                    model.components[k].getCovMatrixDiagonal(), model.components[k].getConstantTerm());
        }
        return MathUtils.getGaussianPdfValue(x, model.components[k].meanVector,
                model.components[k].getInvCovMatrix(), model.components[k].getConstantTerm());
    }

    /**
     * Manual check of the {@code *Endian} read/write methods of {@link MaryRandomAccessFile}.
     *
     * <p>Writes one value of each supported primitive type to {@code d:/endianJava.tmp}. Then,
     * if {@code d:/endianC.tmp} exists, reads the same sequence of types from it and prints an
     * error line to standard output for every value that differs from the one written above.
     * If that file does not exist, a notice is printed and the method returns.
     *
     * @throws IOException if creating, writing, reading or closing either file fails
     */
    public static void testEndianFileIO() throws IOException {
        MaryRandomAccessFile javaFile = new MaryRandomAccessFile("d:/endianJava.tmp", "rw");
        try {
            javaFile.writeBooleanEndian(true);
            javaFile.writeCharEndian('c');
            javaFile.writeShortEndian((short) 111);
            javaFile.writeIntEndian(222);
            javaFile.writeDoubleEndian(33.3d);
            javaFile.writeFloatEndian(44.4f);
            javaFile.writeLongEndian(555L);
        } finally {
            // Close the file even when one of the writes fails.
            javaFile.close();
        }

        if (!FileUtils.exists("d:/endianC.tmp")) {
            System.out.println("C generated file not found...\n");
            return;
        }

        boolean boolValue;
        char charValue;
        short shortValue;
        int intValue;
        double doubleValue;
        float floatValue;
        long longValue;
        MaryRandomAccessFile cFile = new MaryRandomAccessFile("d:/endianC.tmp", "r");
        try {
            boolValue = cFile.readBooleanEndian();
            charValue = cFile.readCharEndian();
            shortValue = cFile.readShortEndian();
            intValue = cFile.readIntEndian();
            doubleValue = cFile.readDoubleEndian();
            floatValue = cFile.readFloatEndian();
            longValue = cFile.readLongEndian();
        } finally {
            // Close the file even when one of the reads fails.
            cFile.close();
        }

        // Exact comparisons are intended: the values are expected to round-trip bit for bit.
        if (!boolValue) {
            System.out.println("Error in bool!\n");
        }
        if (charValue != 'c') {
            System.out.println("Error in char!\n");
        }
        if (shortValue != 111) {
            System.out.println("Error in short!\n");
        }
        if (intValue != 222) {
            System.out.println("Error in int!\n");
        }
        if (doubleValue != 33.3d) {
            System.out.println("Error in double!\n");
        }
        if (floatValue != 44.4f) {
            System.out.println("Error in float!\n");
        }
        if (longValue != 555L) {
            System.out.println("Error in long!\n");
        }
    }

    /**
     * Demo entry point: builds a ten-column data set from {@link ClusteredDataGenerator}
     * instances, shuffles it with {@code MathUtils.randomSort}, prints the first entries of
     * its mean and variance, trains a 20-component diagonal-covariance model on it and prints
     * the first mean entry, the first covariance entry and the weight of every component.
     *
     * @param strArr command-line arguments (not used)
     */
    public static void main(String[] strArr) {
        // Fourth constructor argument per generator (presumably a variance -- confirm against
        // ClusteredDataGenerator); generators without an own entry reuse the first value.
        double[] spreadValues = {0.01d};
        ClusteredDataGenerator[] generators = new ClusteredDataGenerator[10];
        for (int col = 0; col < generators.length; col++) {
            double spread = col < spreadValues.length ? spreadValues[col] : spreadValues[0];
            generators[col] = new ClusteredDataGenerator(20, 2000, 10.0d * (col + 1), spread);
        }

        // Each generator supplies one column of the data matrix.
        int rowCount = generators[0].data.length;
        double[][] columns = new double[rowCount][generators.length];
        for (int col = 0; col < generators.length; col++) {
            for (int row = 0; row < generators[col].data.length; row++) {
                columns[row][col] = generators[col].data[row];
            }
        }

        double[][] shuffled = MathUtils.randomSort(columns);
        double[] mean = MathUtils.mean(shuffled);
        String firstMean = String.valueOf(mean[0]);
        String firstVariance = String.valueOf(MathUtils.variance(shuffled, mean)[0]);
        System.out.println(firstMean + Example.SEPARATOR + firstVariance);

        GMMTrainerParams params = new GMMTrainerParams();
        params.totalComponents = 20;
        params.isDiagonalCovariance = true;
        params.kmeansMaxIterations = 100;
        params.kmeansMinClusterChangePercent = 0.01d;
        params.kmeansMinSamplesInOneCluster = 10;
        params.emMinIterations = 100;
        params.emMaxIterations = 2000;
        params.isUpdateCovariances = true;
        params.tinyLogLikelihoodChangePercent = 0.001d;
        params.minCovarianceAllowed = 1.0E-5d;
        params.useNativeCLibTrainer = true;

        GMM model = new GMMTrainer().train(shuffled, params);
        if (model == null) {
            return;
        }
        int component = 0;
        while (component < model.totalComponents) {
            System.out.println("Gaussian #" + (component + 1)
                    + " mean=" + model.components[component].meanVector[0]
                    + " variance=" + model.components[component].covMatrix[0][0]
                    + " prior=" + model.weights[component]);
            component++;
        }
    }

    // Initializes the assertion flag the way javac-generated code does for classes that use
    // `assert`: the flag is true when assertions are NOT enabled for this class, so the
    // checks guarded by `!$assertionsDisabled` only run when assertions are enabled.
    static {
        $assertionsDisabled = !GMMTrainer.class.desiredAssertionStatus();
    }
}
