package com.rapidminer.operator.learner.meta;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Tools;
import com.rapidminer.example.set.Condition;
import com.rapidminer.example.set.ConditionedExampleSet;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.ValueDouble;
import com.rapidminer.operator.learner.LearnerCapability;
import com.rapidminer.operator.performance.EstimatedPerformance;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.tools.RandomGenerator;
import com.rapidminer.tools.math.RunVector;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;
import marytts.signalproc.adaptation.codebook.WeightedCodebookMapperParams;

/* JADX WARN: Classes with same name are omitted:
  input_file:builds/deps.jar:com/rapidminer/operator/learner/meta/BayBoostStream.class
  input_file:builds/deps.jar:rapidMiner.jar:com/rapidminer/operator/learner/meta/BayBoostStream.class
  input_file:builds/deps.jar:tmp-src.zip:rapidMiner.jar:com/rapidminer/operator/learner/meta/BayBoostStream.class
  input_file:com/rapidminer/operator/learner/meta/BayBoostStream.class
  input_file:rapidMiner.jar:com/rapidminer/operator/learner/meta/BayBoostStream.class
  input_file:rapidMiner.jar:com/rapidminer/operator/learner/meta/BayBoostStream.class
 */
/* loaded from: input_file:tmp-src.zip:rapidMiner.jar:com/rapidminer/operator/learner/meta/BayBoostStream.class */
/**
 * Stream-oriented Bayesian boosting meta learner.
 *
 * <p>The training set is consumed in consecutive batches of {@code batch_size} examples. Each
 * batch is tagged with its (1-based) batch number in the special attribute
 * {@value #STREAM_CONTROL_ATTRIB_NAME}. A fresh batch is first classified by the ensemble(s)
 * built so far, and the resulting accuracy is recorded. The batch is then used to update the
 * ensembles: contingency matrices of existing base models are re-estimated and, while still
 * worthwhile, an additional base model is trained.
 *
 * <p>Besides the full ensemble, a "reduced" ensemble (all but the newest base model, refitted
 * on an extended range of batches) is maintained. After the last batch the reduced ensemble is
 * returned if one exists, otherwise the full one.
 *
 * <p>Label priors are tracked in two-element arrays, so the label is assumed to be binary. The
 * per-batch accuracies are delivered as an additional {@link RunVector} output.
 */
public class BayBoostStream extends AbstractMetaLearner {
    /** Parameter: number of examples per batch. */
    public static final String PARAMETER_BATCH_SIZE = "batch_size";
    /** Parameter: whether example weights are rescaled according to the label priors. */
    public static final String PARAMETER_RESCALE_LABEL_PRIORS = "rescale_label_priors";
    /** Parameter: fraction of each batch held out for ensemble selection (0 disables it). */
    public static final String PARAMETER_FRACTION_HOLD_OUT_SET = "fraction_hold_out_set";
    /** Minimum deviation of a lift value from 1 for a base model to be considered useful. */
    public static final double MIN_ADVANTAGE = 0.02d;
    /** Name of the special attribute that records which batch an example belongs to. */
    public static final String STREAM_CONTROL_ATTRIB_NAME = "BayBoostStream.StreamControl";
    /** Not referenced within this class. */
    public static final double MIN_LIFT_RATIO_SOFT_CLASSIFIER = 0.2d;

    /** Accuracy per processed batch; appended to the outputs by {@link #apply()}. */
    private RunVector runVector;
    /** Number of the batch currently being processed (1-based; 0 before the first batch). */
    private int currentIteration;
    /** Most recent accuracy estimate, exposed as the "performance" value. */
    private double performance;
    /**
     * Weights the examples had before learning, or {@code null} if a weight attribute had to be
     * created by {@link #prepareWeights(ExampleSet)}.
     */
    private double[] oldWeights;

    /**
     * Accepts examples whose stream-control value is at least the given batch number, i.e.
     * examples of that batch or of any later-numbered batch.
     */
    public static class BatchFilterCondition implements Condition {
        private static final long serialVersionUID = 7910713773299060449L;
        private final int batchNumber;
        private final Attribute attribute;

        /**
         * @param attribute   the stream-control attribute to read
         * @param batchNumber smallest batch number accepted
         */
        public BatchFilterCondition(Attribute attribute, int batchNumber) {
            this.batchNumber = batchNumber;
            this.attribute = attribute;
        }

        @Override
        public boolean conditionOk(Example example) {
            return example.getValue(this.attribute) >= this.batchNumber;
        }

        /** Returns this instance; all fields are final, so sharing it is safe. */
        @Override
        @Deprecated
        public Condition duplicate() {
            return this;
        }
    }

    /**
     * Creates the operator and registers the loggable values "performance" and "iteration".
     *
     * @param operatorDescription description passed on to the super class
     */
    public BayBoostStream(OperatorDescription operatorDescription) {
        super(operatorDescription);
        this.performance = 0.0d;
        addValue(new ValueDouble("performance", "The performance.") {
            @Override
            public double getDoubleValue() {
                return BayBoostStream.this.performance;
            }
        });
        addValue(new ValueDouble("iteration", "The current iteration.") {
            @Override
            public double getDoubleValue() {
                return BayBoostStream.this.currentIteration;
            }
        });
    }

    /** Numerical labels are not supported; everything else is delegated to the super class. */
    @Override
    public boolean supportsCapability(LearnerCapability learnerCapability) {
        if (learnerCapability == LearnerCapability.NUMERICAL_CLASS) {
            return false;
        }
        return super.supportsCapability(learnerCapability);
    }

    /**
     * Prepares example weights for learning.
     *
     * <p>If the example set has no weight attribute, one is created and {@link #oldWeights} is
     * set to {@code null}. Otherwise the current weights are saved in {@link #oldWeights} (in
     * iteration order) and every weight is set to 1.
     *
     * @param exampleSet the training set
     */
    protected void prepareWeights(ExampleSet exampleSet) {
        if (exampleSet.getAttributes().getWeight() == null) {
            this.oldWeights = null;
            Tools.createWeightAttribute(exampleSet);
            return;
        }
        this.oldWeights = new double[exampleSet.size()];
        Iterator<Example> reader = exampleSet.iterator();
        for (int index = 0; reader.hasNext() && index < this.oldWeights.length; index++) {
            Example example = reader.next();
            if (example != null) {
                this.oldWeights[index] = example.getWeight();
                example.setWeight(1.0d);
            }
        }
    }

    /**
     * Undoes {@link #prepareWeights(ExampleSet)}: removes the weight attribute it created, or
     * writes the saved weights back in iteration order.
     */
    private void restoreOldWeights(ExampleSet exampleSet) {
        if (this.oldWeights == null) {
            Attribute weight = exampleSet.getAttributes().getWeight();
            if (weight != null) {
                exampleSet.getAttributes().remove(weight);
                exampleSet.getExampleTable().removeAttribute(weight);
            }
            return;
        }
        Iterator<Example> reader = exampleSet.iterator();
        int index = 0;
        while (reader.hasNext() && index < this.oldWeights.length) {
            reader.next().setWeight(this.oldWeights[index++]);
        }
    }

    /**
     * Returns the stream-control attribute, creating it if absent. If an attribute with the
     * reserved name already exists, it is reused and its value is reset to 0 for every example.
     */
    private Attribute prepareStreamControlAttribute(ExampleSet exampleSet) throws OperatorException {
        Attribute streamControl = exampleSet.getAttributes().get(STREAM_CONTROL_ATTRIB_NAME);
        if (streamControl == null) {
            // NOTE(review): 3 is passed through as the attribute value type — confirm its meaning.
            return Tools.createSpecialAttribute(exampleSet, STREAM_CONTROL_ATTRIB_NAME, 3);
        }
        logWarning("Attribute with the (reserved) name of the stream control attribute exists. It is probably an old version created by this operator. Trying to recycle it... ");
        for (Example example : exampleSet) {
            example.setValue(streamControl, 0.0d);
        }
        return streamControl;
    }

    /**
     * Trains the stream ensemble batch by batch.
     *
     * <p>For every batch:
     * <ol>
     *   <li>The ensembles from the previous iteration are evaluated on the batch, and the
     *       accuracy is appended to {@link #runVector}.</li>
     *   <li>The ensembles are updated with the batch.</li>
     * </ol>
     * If {@code fraction_hold_out_set} is positive, a random part of the batch is held out.
     * It is used to choose between the full and the reduced ensemble and to re-estimate the
     * chosen ensemble's last contingency matrix.
     *
     * @param exampleSet training data; its weights are overwritten while learning and restored
     *                   (or the temporary weight attribute removed) before this method returns
     * @return the reduced ensemble if one exists after the last batch, otherwise the full one
     * @throws OperatorException if a base learner fails or yields a non-nominal prediction
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        this.runVector = new RunVector();
        this.currentIteration = 0;

        // Full ensemble (all base models) and reduced ensemble (all but the newest base model).
        BayBoostModel currentModel = null;
        BayBoostModel reducedModel = null;
        Vector<BayBoostBaseModelInfo> modelInfos = new Vector<>();
        Vector<BayBoostBaseModelInfo> reducedInfos = new Vector<>();
        // First batch number of the extended batch used to refit the reduced ensemble.
        int extendedBatchStart = 1;
        // Whether the reduced ensemble is currently considered the better one.
        boolean preferReduced = true;

        Attribute streamControl = prepareStreamControlAttribute(exampleSet);
        // Always save (or create) weights so that restoreOldWeights() sees matching state;
        // otherwise existing user weights would be dropped or overwritten with stale values.
        prepareWeights(exampleSet);
        try {
            Iterator<Example> reader = exampleSet.iterator();
            while (reader.hasNext()) {
                this.currentIteration++;
                double[] batchPriors = prepareBatch(this.currentIteration, reader, streamControl);
                ExampleSet batch = new ConditionedExampleSet(exampleSet,
                        new BatchFilterCondition(streamControl, this.currentIteration));

                // 1) Evaluate the ensembles of the previous iteration on the fresh batch.
                EstimatedPerformance estimatedPerformance = null;
                if (reducedModel != null) {
                    ExampleSet reducedPredictions = reducedModel.apply(batch);
                    this.performance = evaluatePredictions(reducedPredictions);
                    batch = currentModel.apply(reducedPredictions);
                    double currentAccuracy = evaluatePredictions(batch);
                    estimatedPerformance = new EstimatedPerformance("accuracy",
                            preferReduced ? this.performance : currentAccuracy, batch.size(), false);
                    if (currentAccuracy > this.performance) {
                        this.performance = currentAccuracy;
                        extendedBatchStart = Math.max(1, this.currentIteration - 1);
                    } else {
                        // The full ensemble was not more accurate: continue from the reduced one.
                        modelInfos.clear();
                        modelInfos.addAll(reducedInfos);
                    }
                } else if (currentModel != null) {
                    batch = currentModel.apply(batch);
                    this.performance = evaluatePredictions(batch);
                    extendedBatchStart = Math.max(1, this.currentIteration - 1);
                    estimatedPerformance = new EstimatedPerformance("accuracy", this.performance, batch.size(), false);
                }
                if (estimatedPerformance != null) {
                    PerformanceVector performanceVector = new PerformanceVector();
                    performanceVector.addAveragable(estimatedPerformance);
                    this.runVector.addVector(performanceVector);
                }

                // 2) Update the ensembles with the batch.
                if (getParameterAsBoolean(PARAMETER_RESCALE_LABEL_PRIORS)) {
                    rescalePriors(batch, batchPriors);
                }
                preferReduced = true;
                if (modelInfos.isEmpty()) {
                    trainAdditionalModel(batch, modelInfos);
                    currentModel = new BayBoostModel(exampleSet, modelInfos, batchPriors);
                    reducedModel = null;
                    preferReduced = false;
                } else {
                    reducedInfos = new Vector<>(modelInfos);
                    double holdOutFraction = getParameterAsDouble(PARAMETER_FRACTION_HOLD_OUT_SET);
                    List<Example> holdOut = new Vector<>();
                    if (holdOutFraction > 0.0d) {
                        RandomGenerator random = RandomGenerator.getRandomGenerator(0);
                        for (Example example : batch) {
                            if (random.nextDoubleInRange(0.0d, 1.0d) <= holdOutFraction) {
                                // Batch number 0 keeps the example out of the extended batch below.
                                example.setValue(streamControl, 0.0d);
                                holdOut.add(example);
                            }
                        }
                    }

                    if (adjustBaseModelWeights(batch, modelInfos)) {
                        trainAdditionalModel(batch, modelInfos);
                    }
                    currentModel = new BayBoostModel(exampleSet, modelInfos, batchPriors);

                    ExampleSet extendedBatch = new ConditionedExampleSet(exampleSet,
                            new BatchFilterCondition(streamControl, extendedBatchStart));
                    double[] extendedPriors = prepareExtendedBatch(extendedBatch);
                    if (getParameterAsBoolean(PARAMETER_RESCALE_LABEL_PRIORS)) {
                        rescalePriors(extendedBatch, extendedPriors);
                    }
                    reducedInfos.remove(reducedInfos.size() - 1);
                    if (!adjustBaseModelWeights(extendedBatch, reducedInfos)
                            || trainAdditionalModel(extendedBatch, reducedInfos)) {
                        reducedModel = new BayBoostModel(exampleSet, reducedInfos, extendedPriors);
                    } else {
                        reducedModel = null;
                        preferReduced = false;
                    }

                    if (holdOutFraction > 0.0d) {
                        for (Example example : holdOut) {
                            example.setValue(streamControl, this.currentIteration);
                        }
                        // An empty hold-out set gives no error estimate (and no data to retrain on).
                        if (!holdOut.isEmpty()) {
                            if (reducedModel != null) {
                                ExampleSet predictions = currentModel.apply(batch);
                                double currentError = holdOutErrorRate(holdOut);
                                predictions = reducedModel.apply(predictions);
                                double reducedError = holdOutErrorRate(holdOut);
                                preferReduced = reducedError <= currentError;
                                if (preferReduced) {
                                    reducedModel = retrainLastWeight(reducedModel, predictions, holdOut);
                                } else {
                                    currentModel = retrainLastWeight(currentModel, predictions, holdOut);
                                }
                            } else {
                                currentModel = retrainLastWeight(currentModel, batch, holdOut);
                            }
                        }
                    }
                }
            }
            return reducedModel == null ? currentModel : reducedModel;
        } finally {
            restoreOldWeights(exampleSet);
        }
    }

    /**
     * Rebuilds {@code model} with the contingency matrix of its last base model re-estimated on
     * the hold-out examples only.
     *
     * <p>Weights are reset to 1. The earlier base models keep their stored matrices and re-weight
     * the examples in sequence. Then all weights are set to 0 except those of the hold-out
     * examples, which keep their re-weighted values. The last model's matrix is computed from
     * these weights.
     *
     * @param model      ensemble to rebuild (must contain at least one base model)
     * @param exampleSet examples to apply the base models to
     * @param holdOut    hold-out examples whose weights determine the last matrix
     * @return a new ensemble with the same priors and base models
     */
    private BayBoostModel retrainLastWeight(BayBoostModel model, ExampleSet exampleSet, List<Example> holdOut) throws OperatorException {
        // Resets all weights to 1; the returned priors are not needed here.
        prepareExtendedBatch(exampleSet);
        int numberOfModels = model.getNumberOfModels();
        Vector<BayBoostBaseModelInfo> infos = new Vector<>();
        double[] priors = model.getPriors();
        for (int i = 0; i < numberOfModels - 1; i++) {
            Model baseModel = model.getModel(i);
            ContingencyMatrix matrix = model.getContingencyMatrix(i);
            infos.add(new BayBoostBaseModelInfo(baseModel, matrix));
            exampleSet = baseModel.apply(exampleSet);
            WeightedPerformanceMeasures.reweightExamples(exampleSet, matrix, false);
        }
        Model lastModel = model.getModel(numberOfModels - 1);
        ExampleSet predicted = lastModel.apply(exampleSet);

        double[] holdOutWeights = new double[holdOut.size()];
        for (int i = 0; i < holdOutWeights.length; i++) {
            holdOutWeights[i] = holdOut.get(i).getWeight();
        }
        for (Example example : predicted) {
            example.setWeight(0.0d);
        }
        for (int i = 0; i < holdOutWeights.length; i++) {
            holdOut.get(i).setWeight(holdOutWeights[i]);
        }
        infos.add(new BayBoostBaseModelInfo(lastModel, new WeightedPerformanceMeasures(predicted).getContingencyMatrix()));
        return new BayBoostModel(predicted, infos, priors);
    }

    /** Runs the meta learner and appends {@link #runVector} to the inherited outputs. */
    @Override
    public IOObject[] apply() throws OperatorException {
        IOObject[] inherited = super.apply();
        int inheritedCount = inherited == null ? 0 : inherited.length;
        IOObject[] results = new IOObject[inheritedCount + 1];
        if (inherited != null) {
            System.arraycopy(inherited, 0, results, 0, inheritedCount);
        }
        results[inheritedCount] = this.runVector;
        return results;
    }

    /** The inherited output classes plus {@link RunVector}. */
    @Override
    public Class<?>[] getOutputClasses() {
        Class<?>[] inherited = super.getOutputClasses();
        Class<?>[] classes = new Class<?>[inherited.length + 1];
        System.arraycopy(inherited, 0, classes, 0, inherited.length);
        classes[inherited.length] = RunVector.class;
        return classes;
    }

    /**
     * Sets each example's weight to {@code 1 / (2 * prior[label])}.
     *
     * <p>When {@code priors} are the label frequencies of {@code exampleSet}, this gives both
     * labels the same total weight. A zero prior yields an infinite weight, but then no example
     * carries that label.
     */
    private void rescalePriors(ExampleSet exampleSet, double[] priors) {
        double[] labelWeights = new double[2];
        for (int i = 0; i < labelWeights.length; i++) {
            labelWeights[i] = 1.0d / (labelWeights.length * priors[i]);
        }
        for (Example example : exampleSet) {
            example.setWeight(labelWeights[(int) example.getLabel()]);
        }
    }

    /**
     * Trains a base model with the inner learner and removes any existing predicted-label
     * attribute from {@code exampleSet}.
     */
    private Model trainBaseModel(ExampleSet exampleSet) throws OperatorException {
        Model model = applyInnerLearner(exampleSet);
        createOrReplacePredictedLabelFor(exampleSet, model);
        return model;
    }

    /**
     * Assigns the next (up to {@code batch_size}) examples from {@code it} to the given batch.
     * Their stream-control value is set to {@code batchNumber} and their weight to 1.
     *
     * @return the relative frequency of label 0 and label 1 among these examples
     */
    private double[] prepareBatch(int batchNumber, Iterator<Example> it, Attribute attribute) throws UndefinedParameterError {
        int batchSize = getParameterAsInt(PARAMETER_BATCH_SIZE);
        int[] labelCounts = new int[2];
        int taken = 0;
        while (taken < batchSize && it.hasNext()) {
            Example example = it.next();
            example.setValue(attribute, batchNumber);
            example.setWeight(1.0d);
            labelCounts[(int) example.getLabel()]++;
            taken++;
        }
        return labelPriors(labelCounts);
    }

    /**
     * Resets every example weight to 1.
     *
     * @return the relative frequency of label 0 and label 1 in {@code exampleSet}
     */
    private double[] prepareExtendedBatch(ExampleSet exampleSet) {
        int[] labelCounts = new int[2];
        for (Example example : exampleSet) {
            example.setWeight(1.0d);
            labelCounts[(int) example.getLabel()]++;
        }
        return labelPriors(labelCounts);
    }

    /** Converts per-label counts into relative frequencies (all zeros if there are no counts). */
    private static double[] labelPriors(int[] counts) {
        int total = 0;
        for (int count : counts) {
            total += count;
        }
        double[] priors = new double[counts.length];
        if (total > 0) {
            for (int i = 0; i < counts.length; i++) {
                // Floating-point division: integer division would truncate every prior to 0 or 1.
                priors[i] = (double) counts[i] / total;
            }
        }
        return priors;
    }

    /** Returns the fraction of examples whose predicted label equals the label (0 if empty). */
    private double evaluatePredictions(ExampleSet exampleSet) {
        int total = 0;
        int correct = 0;
        for (Example example : exampleSet) {
            total++;
            if (example.getLabel() == example.getPredictedLabel()) {
                correct++;
            }
        }
        return total == 0 ? 0.0d : (double) correct / total;
    }

    /** Returns the fraction of {@code examples} (must be non-empty) that are misclassified. */
    private static double holdOutErrorRate(List<Example> examples) {
        int errors = 0;
        for (Example example : examples) {
            if (example.getPredictedLabel() != example.getLabel()) {
                errors++;
            }
        }
        return (double) errors / examples.size();
    }

    /**
     * Trains one more base model and appends it to {@code modelInfos} if it is useful on the
     * training data.
     *
     * @return {@code true} if the model was added, {@code false} if it was discarded
     */
    private boolean trainAdditionalModel(ExampleSet exampleSet, Vector<BayBoostBaseModelInfo> modelInfos) throws OperatorException {
        Model model = trainBaseModel(exampleSet);
        WeightedPerformanceMeasures measures = new WeightedPerformanceMeasures(model.apply(exampleSet));
        if (isModelUseful(measures.getContingencyMatrix())) {
            modelInfos.add(new BayBoostBaseModelInfo(model, measures.getContingencyMatrix()));
            return true;
        }
        log("Discard model because of low advantage on training data.");
        return false;
    }

    /**
     * Re-estimates the contingency matrices of the base models on {@code exampleSet}, in order.
     *
     * <p>Each useful model gets its freshly computed matrix and re-weights the examples for the
     * next model. Models judged not useful are removed.
     *
     * <p>NOTE(review): usefulness is judged on the previously stored matrix, not the new one;
     * confirm this is intended.
     *
     * @return {@code false} as soon as {@code reweightExamples} returns a non-positive value,
     *         {@code true} otherwise
     * @throws UserError if a base model produces a non-nominal predicted label
     */
    private boolean adjustBaseModelWeights(ExampleSet exampleSet, Vector<BayBoostBaseModelInfo> modelInfos) throws OperatorException {
        int i = 0;
        while (i < modelInfos.size()) {
            BayBoostBaseModelInfo info = modelInfos.get(i);
            Model model = info.getModel();
            ContingencyMatrix storedMatrix = info.getContingencyMatrix();
            createOrReplacePredictedLabelFor(exampleSet, model);
            exampleSet = model.apply(exampleSet);
            if (!exampleSet.getAttributes().getPredictedLabel().isNominal()) {
                throw new UserError(this, 101, exampleSet.getAttributes().getLabel(), "BayBoostStream base learners");
            }
            ContingencyMatrix newMatrix = new WeightedPerformanceMeasures(exampleSet).getContingencyMatrix();
            if (isModelUseful(storedMatrix)) {
                modelInfos.set(i, new BayBoostBaseModelInfo(model, newMatrix));
                if (!(WeightedPerformanceMeasures.reweightExamples(exampleSet, newMatrix, false) > 0.0d)) {
                    return false;
                }
                i++;
            } else {
                // Removal shifts the next model into slot i, so the index is not advanced.
                modelInfos.remove(i);
                log("Discard base model because of low advantage.");
            }
        }
        return true;
    }

    /** A model is useful if any lift value deviates from 1 by more than {@link #MIN_ADVANTAGE}. */
    private boolean isModelUseful(ContingencyMatrix contingencyMatrix) {
        for (int prediction = 0; prediction < contingencyMatrix.getNumberOfPredictions(); prediction++) {
            for (int label = 0; label < contingencyMatrix.getNumberOfClasses(); label++) {
                if (Math.abs(contingencyMatrix.getLift(prediction, label) - 1.0d) > MIN_ADVANTAGE) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Removes an existing predicted-label attribute from the example set and its example table.
     *
     * <p>Despite the name, no new attribute is created here and {@code model} is not used.
     * Presumably the subsequent {@code model.apply} creates the prediction; confirm.
     */
    private static void createOrReplacePredictedLabelFor(ExampleSet exampleSet, Model model) {
        Attribute predictedLabel = exampleSet.getAttributes().getPredictedLabel();
        if (predictedLabel != null) {
            exampleSet.getAttributes().remove(predictedLabel);
            exampleSet.getExampleTable().removeAttribute(predictedLabel);
        }
    }

    /** Adds the rescaling, batch size and hold-out fraction parameters to the inherited ones. */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();
        parameterTypes.add(new ParameterTypeBoolean(PARAMETER_RESCALE_LABEL_PRIORS, "Specifies whether the proportion of labels should be equal by construction after first iteration .", false));
        parameterTypes.add(new ParameterTypeInt(PARAMETER_BATCH_SIZE, "Size of the batches. Minimum number of examples used to train a model.", 1, Integer.MAX_VALUE, 100));
        parameterTypes.add(new ParameterTypeDouble(PARAMETER_FRACTION_HOLD_OUT_SET, "Rel. size of hold out set for ensemble selection. Set to 0 to turn off.", 0.0d, 1.0d, 0.0d));
        return parameterTypes;
    }
}
