package com.rapidminer.operator.learner.functions;

import Jama.Matrix;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Statistics;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.AbstractLearner;
import com.rapidminer.operator.learner.LearnerCapability;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.UndefinedParameterError;
import java.util.List;
import org.hibernate.id.enhanced.OptimizerFactory;

/* loaded from: input_file:com/rapidminer/operator/learner/functions/LinearRegression.class */
/**
 * Learner that fits a linear regression model on the numerical attributes of an example set.
 *
 * <p>A nominal label with exactly two values is temporarily replaced by a numeric 0/1 label while the
 * model is fitted. Optionally, attributes whose standardized coefficient exceeds a threshold are
 * removed iteratively ("eliminate colinear features"), and afterwards one of the feature selection
 * strategies listed in {@link #FEATURE_SELECTION_METHODS} is applied.
 */
public class LinearRegression extends AbstractLearner {
    /** Parameter key: index into {@link #FEATURE_SELECTION_METHODS}. */
    public static final String PARAMETER_FEATURE_SELECTION = "feature_selection";
    /** Parameter key: whether attributes with a too large standardized coefficient are removed. */
    public static final String PARAMETER_ELIMINATE_COLINEAR_FEATURES = "eliminate_colinear_features";
    /** Parameter key: whether the intercept is added when predicting. */
    public static final String PARAMETER_USE_BIAS = "use_bias";
    /** Parameter key: threshold used by the colinear feature elimination. */
    public static final String PARAMETER_MIN_STANDARDIZED_COEFFICIENT = "min_standardized_coefficient";
    /** Parameter key: ridge value handed to the underlying regression routine. */
    public static final String PARAMETER_RIDGE = "ridge";
    /** Display names of the feature selection methods, indexed by the constants below. */
    public static final String[] FEATURE_SELECTION_METHODS = {"none", "M5 prime", "greedy"};
    public static final int NO_SELECTION = 0;
    public static final int M5_PRIME = 1;
    public static final int GREEDY = 2;

    public LinearRegression(OperatorDescription operatorDescription) {
        super(operatorDescription);
    }

    /**
     * Fits the regression model for the given example set.
     *
     * @param exampleSet the training data; its label must be set
     * @return a {@link LinearRegressionModel} holding the attribute selection and the coefficients
     * @throws OperatorException if a parameter cannot be read or the fit fails
     */
    @Override // com.rapidminer.operator.learner.Learner
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        Attribute originalLabel = exampleSet.getAttributes().getLabel();
        Attribute workingLabel = originalLabel;
        boolean labelReplaced = false;
        String negativeClassName = null;
        String positiveClassName = null;
        boolean useBias = getParameterAsBoolean(PARAMETER_USE_BIAS);

        if (originalLabel.isNominal() && originalLabel.getMapping().size() == 2) {
            // Two-class label: regress on a temporary numeric label (negative class -> 0, other -> 1).
            negativeClassName = originalLabel.getMapping().getNegativeString();
            positiveClassName = originalLabel.getMapping().getPositiveString();
            int negativeIndex = originalLabel.getMapping().getNegativeIndex();
            // NOTE(review): value type 4 is presumably the "real" ontology type -- confirm.
            workingLabel = AttributeFactory.createAttribute("regression_label", 4);
            exampleSet.getExampleTable().addAttribute(workingLabel);
            for (Example example : exampleSet) {
                if (example.getValue(originalLabel) == negativeIndex) {
                    example.setValue(workingLabel, 0.0d);
                } else {
                    example.setValue(workingLabel, 1.0d);
                }
            }
            exampleSet.getAttributes().setLabel(workingLabel);
            labelReplaced = true;
        }

        boolean[] isUsedAttribute = new boolean[exampleSet.getAttributes().size()];
        double[] coefficients;
        try {
            coefficients = fitCoefficients(exampleSet, workingLabel, isUsedAttribute, useBias);
        } finally {
            // Always undo the label substitution, even when the fit throws, so the caller's
            // example set is not left with the temporary label and an extra table column.
            if (labelReplaced) {
                exampleSet.getAttributes().remove(workingLabel);
                exampleSet.getExampleTable().removeAttribute(workingLabel);
                exampleSet.getAttributes().setLabel(originalLabel);
            }
        }
        return new LinearRegressionModel(exampleSet, isUsedAttribute, coefficients, useBias, negativeClassName, positiveClassName);
    }

    /**
     * Computes the attribute selection and the matching coefficients.
     *
     * @param exampleSet the training data, with {@code label} already installed as its label
     * @param label the (numeric) label attribute to regress on
     * @param isUsedAttribute output: one flag per attribute, filled in and narrowed by this method
     * @param useBias whether the intercept is added when computing the squared error
     * @return one coefficient per selected attribute followed by the intercept
     */
    private double[] fitCoefficients(ExampleSet exampleSet, Attribute label, boolean[] isUsedAttribute, boolean useBias) throws OperatorException {
        int attributeCount = isUsedAttribute.length;

        // Start with every numerical attribute selected.
        int index = 0;
        for (Attribute attribute : exampleSet.getAttributes()) {
            isUsedAttribute[index] = attribute.isNumerical();
            index++;
        }

        exampleSet.recalculateAllAttributeStatistics();
        double[] means = new double[attributeCount];
        double[] standardDeviations = new double[attributeCount];
        index = 0;
        for (Attribute attribute : exampleSet.getAttributes()) {
            if (isUsedAttribute[index]) {
                means[index] = exampleSet.getStatistics(attribute, Statistics.AVERAGE_WEIGHTED);
                standardDeviations[index] = Math.sqrt(exampleSet.getStatistics(attribute, Statistics.VARIANCE_WEIGHTED));
                if (standardDeviations[index] == 0.0d) {
                    // Constant attributes are deselected.
                    isUsedAttribute[index] = false;
                }
            }
            index++;
        }
        double labelMean = exampleSet.getStatistics(label, Statistics.AVERAGE_WEIGHTED);
        double labelStandardDeviation = Math.sqrt(exampleSet.getStatistics(label, Statistics.VARIANCE_WEIGHTED));
        int numberOfExamples = exampleSet.size();

        // Fit once; if requested, keep dropping the attribute with the largest standardized
        // coefficient above the threshold and refit until none is left above it.
        double[] coefficients;
        do {
            coefficients = performRegression(exampleSet, isUsedAttribute, means, labelMean);
            if (!getParameterAsBoolean(PARAMETER_ELIMINATE_COLINEAR_FEATURES)) {
                break;
            }
        } while (deselectAttributeWithHighestCoefficient(isUsedAttribute, coefficients, standardDeviations, labelStandardDeviation));

        // Number of selected attributes plus one (presumably accounting for the intercept).
        int expectedAttributes = countSelected(isUsedAttribute) + 1;
        double fullError = getSquaredError(exampleSet, isUsedAttribute, coefficients, useBias);
        // Score of the current model; a candidate model is accepted only if its score
        //   (error / fullError) * (n - expectedAttributes) + 2 * candidateAttributeCount
        // is strictly smaller. NOTE(review): this looks like an Akaike-style criterion -- confirm.
        double bestScore = (numberOfExamples - expectedAttributes) + (2 * expectedAttributes);
        int currentAttributes = expectedAttributes;
        boolean improved;

        switch (getParameterAsInt(PARAMETER_FEATURE_SELECTION)) {
            case M5_PRIME:
                // Repeatedly try to drop the attribute with the smallest standardized coefficient.
                do {
                    improved = false;
                    currentAttributes--;
                    double smallestCoefficient = 0.0d;
                    int candidate = -1;
                    int coefficientIndex = 0;
                    for (int i = 0; i < isUsedAttribute.length; i++) {
                        if (isUsedAttribute[i]) {
                            double standardized = Math.abs((coefficients[coefficientIndex] * standardDeviations[i]) / labelStandardDeviation);
                            if (coefficientIndex == 0 || standardized < smallestCoefficient) {
                                smallestCoefficient = standardized;
                                candidate = i;
                            }
                            coefficientIndex++;
                        }
                    }
                    if (candidate >= 0) {
                        isUsedAttribute[candidate] = false;
                        double[] reducedCoefficients = performRegression(exampleSet, isUsedAttribute, means, labelMean);
                        double reducedError = getSquaredError(exampleSet, isUsedAttribute, reducedCoefficients, useBias);
                        double score = ((reducedError / fullError) * (numberOfExamples - expectedAttributes)) + (2 * currentAttributes);
                        if (score < bestScore) {
                            improved = true;
                            bestScore = score;
                            coefficients = reducedCoefficients;
                        } else {
                            // No improvement: put the attribute back and stop.
                            isUsedAttribute[candidate] = true;
                        }
                    }
                } while (improved);
                break;
            case GREEDY:
                // In each round try dropping every selected attribute in turn and keep the
                // best-scoring single removal; repeat while a round improves the score.
                do {
                    boolean[] trialSelection = isUsedAttribute.clone();
                    improved = false;
                    currentAttributes--;
                    for (int i = 0; i < isUsedAttribute.length; i++) {
                        if (trialSelection[i]) {
                            trialSelection[i] = false;
                            double[] trialCoefficients = performRegression(exampleSet, trialSelection, means, labelMean);
                            double trialError = getSquaredError(exampleSet, trialSelection, trialCoefficients, useBias);
                            double score = ((trialError / fullError) * (numberOfExamples - expectedAttributes)) + (2 * currentAttributes);
                            if (score < bestScore) {
                                improved = true;
                                bestScore = score;
                                System.arraycopy(trialSelection, 0, isUsedAttribute, 0, isUsedAttribute.length);
                                coefficients = trialCoefficients;
                            }
                            trialSelection[i] = true;
                        }
                    }
                } while (improved);
                break;
            default:
                // NO_SELECTION (or any other value): keep the current selection.
                break;
        }
        return coefficients;
    }

    /** Returns how many entries of {@code flags} are {@code true}. */
    private static int countSelected(boolean[] flags) {
        int count = 0;
        for (boolean flag : flags) {
            if (flag) {
                count++;
            }
        }
        return count;
    }

    /**
     * Deselects the selected attribute with the largest standardized coefficient, provided that
     * coefficient exceeds the {@link #PARAMETER_MIN_STANDARDIZED_COEFFICIENT} threshold.
     *
     * @param isUsedAttribute selection flags; at most one entry is switched to {@code false}
     * @param coefficients coefficients of the selected attributes, in selection order
     * @param standardDeviations standard deviation per attribute (indexed like the flags)
     * @param labelStandardDeviation standard deviation of the label
     * @return {@code true} if an attribute was deselected, {@code false} otherwise
     */
    private boolean deselectAttributeWithHighestCoefficient(boolean[] isUsedAttribute, double[] coefficients, double[] standardDeviations, double labelStandardDeviation) throws UndefinedParameterError {
        double largestSoFar = getParameterAsDouble(PARAMETER_MIN_STANDARDIZED_COEFFICIENT);
        int toDeselect = -1;
        int coefficientIndex = 0;
        for (int i = 0; i < isUsedAttribute.length; i++) {
            if (isUsedAttribute[i]) {
                double standardized = Math.abs((coefficients[coefficientIndex] * standardDeviations[i]) / labelStandardDeviation);
                if (standardized > largestSoFar) {
                    largestSoFar = standardized;
                    toDeselect = i;
                }
                coefficientIndex++;
            }
        }
        if (toDeselect < 0) {
            return false;
        }
        isUsedAttribute[toDeselect] = false;
        return true;
    }

    /** Sums the squared differences between prediction and label over all examples. */
    private double getSquaredError(ExampleSet exampleSet, boolean[] isUsedAttribute, double[] coefficients, boolean useBias) {
        double sum = 0.0d;
        for (Example example : exampleSet) {
            double difference = regressionPrediction(example, isUsedAttribute, coefficients, useBias) - example.getLabel();
            sum += difference * difference;
        }
        return sum;
    }

    /**
     * Computes the linear prediction for one example: the coefficient-weighted sum of the selected
     * attribute values, plus the trailing intercept coefficient if {@code useBias} is set.
     */
    private double regressionPrediction(Example example, boolean[] isUsedAttribute, double[] coefficients, boolean useBias) {
        double prediction = 0.0d;
        int coefficientIndex = 0;
        int attributeIndex = 0;
        for (Attribute attribute : example.getAttributes()) {
            if (isUsedAttribute[attributeIndex]) {
                prediction += coefficients[coefficientIndex] * example.getValue(attribute);
                coefficientIndex++;
            }
            attributeIndex++;
        }
        if (useBias) {
            // After the loop coefficientIndex points at the intercept slot.
            prediction += coefficients[coefficientIndex];
        }
        return prediction;
    }

    /**
     * Runs the regression on the selected attributes, using mean-centered attribute values.
     *
     * @param exampleSet the training data
     * @param isUsedAttribute selection flags, one per attribute
     * @param means mean per attribute (indexed like the flags), subtracted from the values
     * @param labelMean mean of the label, used as the starting value of the intercept
     * @return one coefficient per selected attribute followed by the intercept, which is
     *         {@code labelMean} minus the sum of {@code coefficient * mean} over selected attributes
     */
    private double[] performRegression(ExampleSet exampleSet, boolean[] isUsedAttribute, double[] means, double labelMean) throws UndefinedParameterError {
        int usedCount = countSelected(isUsedAttribute);
        double[] coefficients = new double[usedCount + 1];

        if (usedCount > 0) {
            Matrix independent = new Matrix(exampleSet.size(), usedCount);
            Matrix dependent = new Matrix(exampleSet.size(), 1);
            double[] weights = new double[exampleSet.size()];
            Attribute weightAttribute = exampleSet.getAttributes().getWeight();
            int row = 0;
            for (Example example : exampleSet) {
                dependent.set(row, 0, example.getLabel());
                int column = 0;
                int attributeIndex = 0;
                for (Attribute attribute : exampleSet.getAttributes()) {
                    if (isUsedAttribute[attributeIndex]) {
                        independent.set(row, column, example.getValue(attribute) - means[attributeIndex]);
                        column++;
                    }
                    attributeIndex++;
                }
                // Examples without a weight attribute all count with weight 1.
                weights[row] = weightAttribute != null ? example.getValue(weightAttribute) : 1.0d;
                row++;
            }
            double[] slopes = com.rapidminer.tools.math.LinearRegression.performRegression(independent, dependent, weights, getParameterAsDouble(PARAMETER_RIDGE));
            System.arraycopy(slopes, 0, coefficients, 0, usedCount);
        }

        // The attributes were centered, so shift the intercept back to the original scale.
        coefficients[usedCount] = labelMean;
        int coefficientIndex = 0;
        for (int i = 0; i < isUsedAttribute.length; i++) {
            if (isUsedAttribute[i]) {
                coefficients[usedCount] -= coefficients[coefficientIndex] * means[i];
                coefficientIndex++;
            }
        }
        return coefficients;
    }

    /** Reports support for numerical attributes, numerical and binominal labels, and example weights. */
    @Override // com.rapidminer.operator.learner.Learner
    public boolean supportsCapability(LearnerCapability learnerCapability) {
        return learnerCapability.equals(LearnerCapability.NUMERICAL_ATTRIBUTES)
                || learnerCapability.equals(LearnerCapability.NUMERICAL_CLASS)
                || learnerCapability.equals(LearnerCapability.BINOMINAL_CLASS)
                || learnerCapability == LearnerCapability.WEIGHTED_EXAMPLES;
    }

    /** Appends this operator's parameters to those declared by the superclass. */
    @Override // com.rapidminer.operator.Operator, com.rapidminer.parameter.ParameterHandler
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();
        parameterTypes.add(new ParameterTypeCategory(PARAMETER_FEATURE_SELECTION, "The feature selection method used during regression.", FEATURE_SELECTION_METHODS, M5_PRIME));
        parameterTypes.add(new ParameterTypeBoolean(PARAMETER_ELIMINATE_COLINEAR_FEATURES, "Indicates if the algorithm should try to delete colinear features during the regression.", true));
        parameterTypes.add(new ParameterTypeBoolean(PARAMETER_USE_BIAS, "Indicates if an intercept value should be calculated.", true));
        parameterTypes.add(new ParameterTypeDouble(PARAMETER_MIN_STANDARDIZED_COEFFICIENT, "The minimum standardized coefficient for the removal of colinear feature elimination.", 0.0d, Double.POSITIVE_INFINITY, 1.5d));
        parameterTypes.add(new ParameterTypeDouble(PARAMETER_RIDGE, "The ridge parameter used during ridge regression.", 0.0d, Double.POSITIVE_INFINITY, 1.0E-8d));
        return parameterTypes;
    }
}
