package com.rapidminer.operator.validation;

import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.IOContainer;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.tools.math.AverageVector;
import java.util.LinkedList;
import java.util.List;

/* loaded from: input_file:com/rapidminer/operator/validation/FixedSplitValidationChain.class */
/**
 * Validation chain that performs a single hold-out validation in which the training and the test
 * partition are specified by absolute example counts rather than by ratios.
 *
 * <p>A size parameter below 1 means "use all remaining examples" for that partition; at most one of
 * the two sizes may be given that way. Examples that fall into neither partition are placed in a
 * third, unused subset.
 */
public class FixedSplitValidationChain extends ValidationChain {
    public static final String PARAMETER_TRAINING_SET_SIZE = "training_set_size";
    public static final String PARAMETER_TEST_SET_SIZE = "test_set_size";
    public static final String PARAMETER_SAMPLING_TYPE = "sampling_type";
    public static final String PARAMETER_LOCAL_RANDOM_SEED = "local_random_seed";

    public FixedSplitValidationChain(OperatorDescription operatorDescription) {
        super(operatorDescription);
    }

    /**
     * Splits the example set into a training, a test and an unused subset, learns on the training
     * subset and evaluates on the test subset.
     *
     * @param exampleSet the data to validate on
     * @return the average vectors collected from the evaluation output
     * @throws UserError if neither size is at least 1 (error 116), or if the example set is too
     *     small for the requested sizes (error 110)
     * @throws OperatorException if learning or evaluation fails
     */
    @Override
    public IOObject[] estimatePerformance(ExampleSet exampleSet) throws OperatorException {
        int trainingSetSize = getParameterAsInt(PARAMETER_TRAINING_SET_SIZE);
        int testSetSize = getParameterAsInt(PARAMETER_TEST_SET_SIZE);
        int size = exampleSet.size();

        // "Use the rest" is only meaningful when the other partition has a fixed size.
        if (trainingSetSize < 1 && testSetSize < 1) {
            throw new UserError(this, 116, "training_set_size / test_set_size", "either training_set_size or test_set_size or both must be at least 1.");
        }

        // A "rest" partition (size < 1) must still receive at least one example; otherwise its
        // resolved size could become zero or negative. The sum is computed as long so that two
        // sizes close to Integer.MAX_VALUE cannot overflow and slip past the check.
        long minTrainingSize = trainingSetSize < 1 ? 1 : trainingSetSize;
        long minTestSize = testSetSize < 1 ? 1 : testSetSize;
        long requiredSize = minTrainingSize + minTestSize;
        if (size < requiredSize) {
            throw new UserError(this, 110, String.valueOf(requiredSize) + " (" + minTrainingSize + " for training, " + minTestSize + " for testing)");
        }

        // Resolve the "rest" partition; from here on size >= trainingSetSize + testSetSize >= 2.
        int unusedSize;
        if (testSetSize < 1) {
            testSetSize = size - trainingSetSize;
            unusedSize = 0;
        } else if (trainingSetSize < 1) {
            trainingSetSize = size - testSetSize;
            unusedSize = 0;
        } else {
            unusedSize = size - (trainingSetSize + testSetSize);
        }
        log("Using " + trainingSetSize + " examples for learning and " + testSetSize + " examples for testing. " + unusedSize + " examples are not used.");

        // Floating-point division: integer division would truncate every ratio to 0.
        double[] ratios = {
            (double) trainingSetSize / size,
            (double) testSetSize / size,
            (double) unusedSize / size
        };
        SplittedExampleSet splittedExampleSet = new SplittedExampleSet(exampleSet, ratios, getParameterAsInt(PARAMETER_SAMPLING_TYPE), getParameterAsInt(PARAMETER_LOCAL_RANDOM_SEED));

        // Subset 0 is the training partition, subset 1 the test partition (order of `ratios`).
        splittedExampleSet.selectSingleSubset(0);
        learn(splittedExampleSet);
        splittedExampleSet.selectSingleSubset(1);
        IOContainer evaluation = evaluate(splittedExampleSet);

        List<AverageVector> averages = new LinkedList<AverageVector>();
        Tools.handleAverages(evaluation, averages);
        PerformanceVector performanceVector = Tools.getPerformanceVector(averages);
        if (performanceVector != null) {
            setResult(performanceVector);
        }
        return averages.toArray(new AverageVector[averages.size()]);
    }

    /**
     * Adds the fixed split parameters (training/test set size, sampling type, local random seed)
     * to the parameters inherited from {@link ValidationChain}.
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();

        ParameterTypeInt trainingSize = new ParameterTypeInt(PARAMETER_TRAINING_SET_SIZE, "Absolute size required for the training set (-1: use rest for training)", -1, Integer.MAX_VALUE, 100);
        trainingSize.setExpert(false);
        parameterTypes.add(trainingSize);

        ParameterTypeInt testSize = new ParameterTypeInt(PARAMETER_TEST_SET_SIZE, "Absolute size required for the test set (-1: use rest for testing)", -1, Integer.MAX_VALUE, -1);
        testSize.setExpert(false);
        parameterTypes.add(testSize);

        parameterTypes.add(new ParameterTypeCategory(PARAMETER_SAMPLING_TYPE, "Defines the sampling type of the cross validation (linear = consecutive subsets, shuffled = random subsets, stratified = random subsets with class distribution kept constant)", SplittedExampleSet.SAMPLING_NAMES, 1));
        parameterTypes.add(new ParameterTypeInt(PARAMETER_LOCAL_RANDOM_SEED, "Use the given random seed instead of global random numbers (-1: use global)", -1, Integer.MAX_VALUE, -1));
        return parameterTypes;
    }
}
