/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.bayes;

import Jama.Matrix;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.learner.bayes.DiscriminantModel;
import com.rapidminer.operator.learner.bayes.LinearDiscriminantAnalysis;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.tools.math.matrix.CovarianceMatrix;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class RegularizedDiscriminantAnalysis
extends LinearDiscriminantAnalysis {
    public static final String PARAMETER_ALPHA = "alpha";

    public RegularizedDiscriminantAnalysis(OperatorDescription description) {
        super(description);
    }

    @Override
    protected Matrix[] getInverseCovarianceMatrices(ExampleSet exampleSet, String[] labels) throws UndefinedParameterError {
        double alpha = this.getParameterAsDouble(PARAMETER_ALPHA);
        Matrix[] globalInverseCovariances = super.getInverseCovarianceMatrices(exampleSet, labels);
        Matrix[] classInverseCovariances = new Matrix[labels.length];
        Attribute labelAttribute = exampleSet.getAttributes().getLabel();
        SplittedExampleSet labelSet = SplittedExampleSet.splitByAttribute(exampleSet, exampleSet.getAttributes().getLabel());
        int labelIndex = 0;
        String[] stringArray = labels;
        int n = labels.length;
        int n2 = 0;
        while (n2 < n) {
            Matrix inverse;
            String label = stringArray[n2];
            int i = 0;
            while (i < labels.length) {
                labelSet.selectSingleSubset(i);
                if (labelSet.getExample(0).getNominalValue(labelAttribute).equals(label)) break;
                ++i;
            }
            classInverseCovariances[labelIndex] = inverse = CovarianceMatrix.getCovarianceMatrix(labelSet).inverse();
            ++labelIndex;
            ++n2;
        }
        Matrix[] regularizedMatrices = new Matrix[classInverseCovariances.length];
        int i = 0;
        while (i < labels.length) {
            regularizedMatrices[i] = globalInverseCovariances[i].times(alpha).plus(classInverseCovariances[i].times(1.0 - alpha));
            ++i;
        }
        return classInverseCovariances;
    }

    @Override
    protected DiscriminantModel getModel(ExampleSet exampleSet, String[] labels, Matrix[] meanVectors, Matrix[] inverseCovariances, double[] aprioriProbabilities) throws UndefinedParameterError {
        return new DiscriminantModel(exampleSet, labels, meanVectors, inverseCovariances, aprioriProbabilities, this.getParameterAsDouble(PARAMETER_ALPHA));
    }

    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> list = super.getParameterTypes();
        list.add(new ParameterTypeDouble(PARAMETER_ALPHA, "This is the strength of regularization. 1: Only global covariance is used, 0: Only per class covariance is used.", 0.0, 1.0, 0.5));
        return list;
    }
}

