package com.rapidminer.operator.learner.meta;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.MissingIOObjectException;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.ValueDouble;
import com.rapidminer.operator.learner.LearnerCapability;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.tools.Tools;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;

/**
 * Boosting meta learner that repeatedly trains the inner learner on reweighted copies of the
 * training data and combines the resulting base models (each stored together with its contingency
 * matrix) into a {@link BayBoostModel}.
 *
 * <p>Only nominal labels that are not polynominal are supported (see
 * {@link #supportsCapability(LearnerCapability)}). The example weights of the input example set are
 * restored after learning; if the input had no weight attribute, the one created for training is
 * removed again.
 */
public class BayesianBoosting extends AbstractMetaLearner {

    /** Name of the parameter holding the maximum number of boosting iterations. */
    public static final String PARAMETER_ITERATIONS = "iterations";

    /** Name of the parameter holding the fraction of examples used to train each base model. */
    public static final String PARAMETER_USE_SUBSET_FOR_TRAINING = "use_subset_for_training";

    /** Name of the parameter that rescales the label distribution to equal priors before learning. */
    public static final String PARAMETER_RESCALE_LABEL_PRIORS = "rescale_label_priors";

    /** Name of the parameter that allows skewing the marginal distribution P(x) while reweighting. */
    public static final String PARAMETER_ALLOW_MARGINAL_SKEWS = "allow_marginal_skews";

    /**
     * Minimal advantage constant. NOTE(review): not referenced anywhere in this class;
     * presumably intended for {@link #isModelUseful(ContingencyMatrix)} — confirm.
     */
    public static final double MIN_ADVANTAGE = 0.001d;

    /** Optional model read from the operator input; applied before the first boosting iteration. */
    private Model startModel;

    /** Index of the boosting iteration currently executing; exposed as the "iteration" value. */
    protected int currentIteration;

    /**
     * Exposed as the "performance" value: initialized by {@link #prepareWeights(ExampleSet)} and
     * then overwritten with the result of reweighting in every boosting iteration.
     */
    private double performance;

    /** Example weights present before learning, or {@code null} if a weight attribute was created. */
    private double[] oldWeights;

    /**
     * Creates the operator and registers the "performance" and "iteration" values.
     *
     * @param operatorDescription description of this operator
     */
    public BayesianBoosting(OperatorDescription operatorDescription) {
        super(operatorDescription);
        this.performance = 0.0d;
        addValue(new ValueDouble("performance", "The performance.") {
            @Override
            public double getDoubleValue() {
                return BayesianBoosting.this.performance;
            }
        });
        addValue(new ValueDouble("iteration", "The current iteration.") {
            @Override
            public double getDoubleValue() {
                return BayesianBoosting.this.currentIteration;
            }
        });
    }

    /**
     * Rejects numerical and polynominal labels, always accepts weighted examples, and delegates
     * every other capability to the superclass.
     */
    @Override
    public boolean supportsCapability(LearnerCapability learnerCapability) {
        if (learnerCapability == LearnerCapability.NUMERICAL_CLASS
                || learnerCapability == LearnerCapability.POLYNOMINAL_CLASS) {
            return false;
        }
        if (learnerCapability == LearnerCapability.WEIGHTED_EXAMPLES) {
            return true;
        }
        return super.supportsCapability(learnerCapability);
    }

    /**
     * Learns a boosted model from the given example set.
     *
     * <p>If all label weight falls into a single class, a model without any base models is
     * returned; otherwise the boosting loop is run. In both cases — and also when training
     * throws — the input's original weights are restored (or the temporary weight attribute is
     * removed) before this method returns.
     *
     * @param exampleSet the training data; its weights are modified temporarily
     * @return the learned {@link BayBoostModel}
     * @throws OperatorException if training a base model or applying a model fails
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        readOptionalParameters();
        // Kept outside the try block: if this fails part-way, oldWeights may be incomplete and
        // "restoring" it would overwrite valid weights with zeros.
        double[] labelPriors = prepareWeights(exampleSet);
        try {
            if (allWeightInOneClass(labelPriors)) {
                return new BayBoostModel(exampleSet, new Vector<BayBoostBaseModelInfo>(), labelPriors);
            }
            return trainBoostingModel(exampleSet, labelPriors);
        } finally {
            // Cleanup must also run when training fails, so the caller's example set is not left mutated.
            restoreOriginalWeights(exampleSet);
        }
    }

    /**
     * Returns whether the largest prior equals (per {@link Tools#isEqual}) the sum of all priors,
     * i.e. for non-negative priors whether a single class carries all the weight.
     */
    private static boolean allWeightInOneClass(double[] priors) {
        double max = Double.NEGATIVE_INFINITY;
        double sum = 0.0d;
        for (double prior : priors) {
            if (prior > max) {
                max = prior;
            }
            sum += prior;
        }
        return Tools.isEqual(sum, max);
    }

    /**
     * Writes the weights saved in {@link #oldWeights} back to the examples, or — when no weights
     * were saved — removes the weight attribute created by {@link #createNewWeightAttribute}.
     */
    private void restoreOriginalWeights(ExampleSet exampleSet) {
        if (this.oldWeights != null) {
            Iterator<Example> reader = exampleSet.iterator();
            for (int i = 0; reader.hasNext() && i < this.oldWeights.length; i++) {
                reader.next().setWeight(this.oldWeights[i]);
            }
        } else {
            Attribute weight = exampleSet.getAttributes().getWeight();
            if (weight != null) {
                exampleSet.getAttributes().remove(weight);
                exampleSet.getExampleTable().removeAttribute(weight);
            }
        }
    }

    /**
     * Prepares the example weights used for boosting and returns the label priors.
     *
     * <p>Without a weight attribute, one is created via {@link #createNewWeightAttribute},
     * {@link #oldWeights} is set to {@code null}, and {@link #performance} is set to the number of
     * examples. Otherwise the current weights are saved in {@link #oldWeights}, and examples whose
     * label index lies outside the label mapping get weight 0. The priors are then the per-label
     * weight sums divided by the total weight of the remaining examples, which also becomes
     * {@link #performance}.
     *
     * @param exampleSet the training data
     * @return one prior per label value, indexed like the label mapping
     */
    protected double[] prepareWeights(ExampleSet exampleSet) {
        if (exampleSet.getAttributes().getWeight() == null) {
            this.oldWeights = null;
            this.performance = exampleSet.size();
            return createNewWeightAttribute(exampleSet);
        }
        this.oldWeights = new double[exampleSet.size()];
        double[] labelPriors = new double[exampleSet.getAttributes().getLabel().getMapping().size()];
        double totalWeight = 0.0d;
        Iterator<Example> reader = exampleSet.iterator();
        for (int i = 0; reader.hasNext() && i < this.oldWeights.length; i++) {
            Example example = reader.next();
            if (example != null) {
                double weight = example.getWeight();
                this.oldWeights[i] = weight;
                int label = (int) example.getLabel();
                if (label < 0 || label >= labelPriors.length) {
                    // Unknown label index: exclude the example from training (restored later).
                    example.setWeight(0.0d);
                } else {
                    labelPriors[label] += weight;
                    totalWeight += weight;
                }
            }
        }
        this.performance = totalWeight;
        // NOTE(review): a total weight of 0 yields NaN priors here; the original code has the same behavior.
        for (int i = 0; i < labelPriors.length; i++) {
            labelPriors[i] /= totalWeight;
        }
        return labelPriors;
    }

    /**
     * Creates a weight attribute and returns the relative label frequencies.
     *
     * <p>If {@link #PARAMETER_RESCALE_LABEL_PRIORS} is set, the weights are assigned by
     * {@link #rescaleToEqualPriors} so that every class has the same total weight. Otherwise every
     * example gets weight 1.0. In both cases the returned array holds the unrescaled frequencies.
     */
    private double[] createNewWeightAttribute(ExampleSet exampleSet) {
        com.rapidminer.example.Tools.createWeightAttribute(exampleSet);
        double[] labelFrequencies =
                new double[exampleSet.getAttributes().getLabel().getMapping().getValues().size()];
        double share = 1.0d / exampleSet.size();
        boolean rescale = getParameterAsBoolean(PARAMETER_RESCALE_LABEL_PRIORS);
        for (Example example : exampleSet) {
            if (!rescale) {
                example.setWeight(1.0d);
            }
            labelFrequencies[(int) example.getLabel()] += share;
        }
        if (rescale) {
            rescaleToEqualPriors(exampleSet, labelFrequencies);
        }
        return labelFrequencies;
    }

    /**
     * Sets the weight of each example of class {@code i} to {@code 1 / (k * priors[i])}, where k is
     * the number of classes, so that every class has the same total weight.
     */
    private void rescaleToEqualPriors(ExampleSet exampleSet, double[] priors) {
        double[] classWeights = new double[priors.length];
        for (int i = 0; i < classWeights.length; i++) {
            classWeights[i] = 1.0d / (classWeights.length * priors[i]);
        }
        for (Example example : exampleSet) {
            example.setWeight(classWeights[(int) example.getLabel()]);
        }
    }

    /**
     * Trains one base model by delegating to the inner learner.
     *
     * @param exampleSet the (reweighted) training data
     * @return the trained base model
     * @throws OperatorException if the inner learner fails
     */
    protected Model trainBaseModel(ExampleSet exampleSet) throws OperatorException {
        return applyInnerLearner(exampleSet);
    }

    /** Reads the optional start model from the operator input, or clears it if none is present. */
    private void readOptionalParameters() {
        // Reset first: otherwise a model read in an earlier execution would be silently reused.
        this.startModel = null;
        try {
            this.startModel = (Model) getInput(Model.class);
        } catch (MissingIOObjectException e) {
            log(getName() + ": No model found in input.");
        }
    }

    /**
     * If a start model is present, applies it to a clone of the example set, reweights the
     * examples based on its contingency matrix, and records it as the first entry of
     * {@code modelInfo}.
     */
    private void applyPriorModel(ExampleSet exampleSet, List<BayBoostBaseModelInfo> modelInfo)
            throws OperatorException {
        if (this.startModel != null) {
            ExampleSet resultSet = this.startModel.apply((ExampleSet) exampleSet.clone());
            WeightedPerformanceMeasures measures = new WeightedPerformanceMeasures(resultSet);
            reweightExamples(measures, resultSet);
            modelInfo.add(new BayBoostBaseModelInfo(this.startModel, measures.getContingencyMatrix()));
            PredictionModel.removePredictedLabel(resultSet);
        }
    }

    /**
     * Runs up to {@link #PARAMETER_ITERATIONS} boosting iterations and builds the final model.
     *
     * <p>If the subset fraction lies strictly between 0 and 1, each base model is trained on
     * subset 0 of a split, and the contingency matrix used for reweighting is computed on subset
     * 1. The loop stops early when fewer than two classes are non-empty, when a model is judged
     * not useful, or when reweighting returns 0.
     */
    private BayBoostModel trainBoostingModel(ExampleSet exampleSet, double[] labelPriors)
            throws OperatorException {
        Vector<BayBoostBaseModelInfo> modelInfo = new Vector<BayBoostBaseModelInfo>();
        applyPriorModel(exampleSet, modelInfo);

        double subsetRatio = getParameterAsDouble(PARAMETER_USE_SUBSET_FOR_TRAINING);
        boolean bootstrap = subsetRatio > 0.0d && subsetRatio < 1.0d;
        log(bootstrap ? "Bootstrapping enabled." : "Bootstrapping disabled.");
        boolean allowMarginalSkews = getParameterAsBoolean(PARAMETER_ALLOW_MARGINAL_SKEWS);

        SplittedExampleSet splittedSet = null;
        if (bootstrap) {
            splittedSet = new SplittedExampleSet(exampleSet, subsetRatio, 1, -1);
        }

        int iterations = getParameterAsInt(PARAMETER_ITERATIONS);
        for (int i = 0; i < iterations; i++) {
            this.currentIteration = i;
            ExampleSet trainingSet = (ExampleSet) exampleSet.clone();
            Model model;
            ExampleSet resultSet;
            WeightedPerformanceMeasures measures;
            if (bootstrap) {
                splittedSet.selectSingleSubset(0);
                model = trainBaseModel(splittedSet);
                resultSet = model.apply(trainingSet);
                WeightedPerformanceMeasures.reweightExamples(splittedSet,
                        new WeightedPerformanceMeasures(splittedSet).getContingencyMatrix(), allowMarginalSkews);
                splittedSet.selectSingleSubset(1);
                measures = new WeightedPerformanceMeasures(splittedSet);
                this.performance = WeightedPerformanceMeasures.reweightExamples(splittedSet,
                        measures.getContingencyMatrix(), allowMarginalSkews);
            } else {
                model = trainBaseModel(trainingSet);
                resultSet = model.apply(trainingSet);
                measures = new WeightedPerformanceMeasures(resultSet);
                this.performance = reweightExamples(measures, resultSet);
            }
            PredictionModel.removePredictedLabel(resultSet);

            if (measures.getNumberOfNonEmptyClasses() < 2) {
                // Degenerate data: keep this model but stop boosting.
                modelInfo.add(new BayBoostBaseModelInfo(model, measures.getContingencyMatrix()));
                break;
            }
            ContingencyMatrix matrix = measures.getContingencyMatrix();
            if (!isModelUseful(matrix)) {
                log("Discard model because of low advantage on training data.");
                break;
            }
            modelInfo.add(new BayBoostBaseModelInfo(model, matrix));
            if (this.performance == 0.0d) {
                break;
            }
            inApplyLoop();
        }
        return new BayBoostModel(exampleSet, modelInfo, labelPriors);
    }

    /**
     * Reweights the examples based on the contingency matrix in {@code measures}.
     *
     * @return the value returned by {@link WeightedPerformanceMeasures#reweightExamples}
     * @throws OperatorException if reading the parameter or reweighting fails
     */
    protected double reweightExamples(WeightedPerformanceMeasures measures, ExampleSet exampleSet)
            throws OperatorException {
        return WeightedPerformanceMeasures.reweightExamples(exampleSet, measures.getContingencyMatrix(),
                getParameterAsBoolean(PARAMETER_ALLOW_MARGINAL_SKEWS));
    }

    /** Decides whether a base model is kept. Currently accepts every model. */
    private boolean isModelUseful(ContingencyMatrix matrix) {
        return true;
    }

    /** Adds this operator's parameters to those of the superclass. */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        types.add(new ParameterTypeBoolean(PARAMETER_RESCALE_LABEL_PRIORS,
                "Specifies whether the proportion of labels should be equal by construction after first iteration .",
                false));
        types.add(new ParameterTypeDouble(PARAMETER_USE_SUBSET_FOR_TRAINING,
                "Fraction of examples used for training, remaining ones are used to estimate the confusion matrix. Set to 1 to turn off test set.",
                0.0d, 1.0d, 1.0d));
        types.add(new ParameterTypeInt(PARAMETER_ITERATIONS, "The maximum number of iterations.", 1,
                Integer.MAX_VALUE, 10));
        types.add(new ParameterTypeBoolean(PARAMETER_ALLOW_MARGINAL_SKEWS,
                "Allow to skew the marginal distribution (P(x)) during learning.", true));
        return types;
    }
}
