/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.bayes;

import Jama.Matrix;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.SimplePredictionModel;
import com.rapidminer.tools.Tools;

public class DiscriminantModel
extends SimplePredictionModel {
    private static final long serialVersionUID = 3793343069512113817L;
    private double alpha;
    private String[] labels;
    private Matrix[] meanVectors;
    private Matrix[] inverseCovariances;
    private double[] aprioriProbabilities;
    private double[] constClassValues;

    public DiscriminantModel(ExampleSet exampleSet, String[] labels, Matrix[] meanVectors, Matrix[] inverseCovariances, double[] aprioriProbabilities, double alpha) {
        super(exampleSet);
        this.alpha = alpha;
        this.labels = labels;
        this.meanVectors = meanVectors;
        this.inverseCovariances = inverseCovariances;
        this.aprioriProbabilities = aprioriProbabilities;
        this.constClassValues = new double[labels.length];
        int i = 0;
        while (i < labels.length) {
            this.constClassValues[i] = -0.5 * meanVectors[i].times(inverseCovariances[i]).times(meanVectors[i].transpose()).get(0, 0) + Math.log(aprioriProbabilities[i]);
            ++i;
        }
    }

    public double predict(Example example) throws OperatorException {
        int numberOfAttributes = this.meanVectors[0].getColumnDimension();
        double[] vector = new double[numberOfAttributes];
        int i = 0;
        for (Attribute attribute : example.getAttributes()) {
            if (!attribute.isNumerical()) continue;
            vector[i] = example.getValue(attribute);
            ++i;
        }
        Matrix xVector = new Matrix(vector, 1);
        double[] labelFunction = new double[this.labels.length];
        int labelIndex = 0;
        while (labelIndex < this.labels.length) {
            labelFunction[labelIndex] = xVector.times(this.inverseCovariances[labelIndex]).times(this.meanVectors[labelIndex].transpose()).get(0, 0) + this.constClassValues[labelIndex];
            ++labelIndex;
        }
        double maximalValue = Double.NEGATIVE_INFINITY;
        int bestValue = 0;
        int labelIndex2 = 0;
        while (labelIndex2 < this.labels.length) {
            if (labelFunction[labelIndex2] >= maximalValue) {
                bestValue = labelIndex2;
                maximalValue = labelFunction[labelIndex2];
            }
            ++labelIndex2;
        }
        return bestValue;
    }

    public String getName() {
        if (this.alpha == 0.0) {
            return "Quadratic Discriminant Model";
        }
        if (this.alpha == 1.0) {
            return "Linear Discriminant Model";
        }
        return "Regularized Discriminant Model";
    }

    public String toString() {
        StringBuffer buffer = new StringBuffer();
        buffer.append("Apriori probabilities:\n");
        int i = 0;
        while (i < this.labels.length) {
            buffer.append(String.valueOf(this.labels[i]) + "\t");
            buffer.append(String.valueOf(Tools.formatNumber(this.aprioriProbabilities[i], 4)) + "\n");
            ++i;
        }
        return buffer.toString();
    }
}

