/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.validation;

import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.IOContainer;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.operator.validation.Tools;
import com.rapidminer.operator.validation.ValidationChain;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.tools.math.AverageVector;
import java.util.LinkedList;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Validation chain that splits the input example set into a training subset and a test subset of
 * user-specified absolute sizes; any remaining examples are left unused.
 *
 * <p>The inner learner is trained on the training subset and evaluated on the test subset. Either
 * size may be set to a value below 1 (e.g. -1) to mean "use all remaining examples" for that
 * subset, in which case no examples are left unused. At least one of the two sizes must be given
 * explicitly.
 */
public class FixedSplitValidationChain extends ValidationChain {
    public static final String PARAMETER_TRAINING_SET_SIZE = "training_set_size";
    public static final String PARAMETER_TEST_SET_SIZE = "test_set_size";
    public static final String PARAMETER_SAMPLING_TYPE = "sampling_type";
    public static final String PARAMETER_LOCAL_RANDOM_SEED = "local_random_seed";

    /**
     * Creates the operator.
     *
     * @param description the operator description passed through to {@link ValidationChain}
     */
    public FixedSplitValidationChain(OperatorDescription description) {
        super(description);
    }

    /**
     * Trains on a fixed-size training subset of {@code inputSet} and evaluates on a disjoint
     * fixed-size test subset.
     *
     * @param inputSet the examples to split
     * @return the average vectors collected from the evaluation result
     * @throws UserError with code 116 if neither size is at least 1, or with code 110 if
     *     {@code inputSet} is too small to hold the requested subsets (each subset that takes
     *     "the rest" must receive at least one example)
     * @throws OperatorException if learning or evaluation fails
     */
    @Override
    public IOObject[] estimatePerformance(ExampleSet inputSet) throws OperatorException {
        int trainingSetSize = getParameterAsInt(PARAMETER_TRAINING_SET_SIZE);
        int testSetSize = getParameterAsInt(PARAMETER_TEST_SET_SIZE);
        int inputSetSize = inputSet.size();

        if (trainingSetSize < 1 && testSetSize < 1) {
            throw new UserError(this, 116, "training_set_size / test_set_size",
                    "either training_set_size or test_set_size or both must be greater than 0.");
        }

        // Validate BEFORE substituting "the rest": otherwise a -1 size lowers the required total
        // and a too-small input yields an empty or negative-sized subset.
        long requiredSize = requiredExampleCount(trainingSetSize, testSetSize);
        if (inputSetSize < requiredSize) {
            throw new UserError(this, 110, requiredSize + " (" + trainingSetSize + " for training, "
                    + testSetSize + " for testing)");
        }

        int rest;
        if (testSetSize < 1) {
            testSetSize = inputSetSize - trainingSetSize;
            rest = 0;
        } else if (trainingSetSize < 1) {
            trainingSetSize = inputSetSize - testSetSize;
            rest = 0;
        } else {
            // Safe from overflow: the check above guarantees training + test <= inputSetSize.
            rest = inputSetSize - trainingSetSize - testSetSize;
        }
        log("Using " + trainingSetSize + " examples for learning and " + testSetSize
                + " examples for testing. " + rest + " examples are not used.");

        // Subset 0 = training, 1 = test, 2 = unused; sizes are expressed as fractions of the input.
        double total = inputSetSize;
        double[] ratios = new double[] {trainingSetSize / total, testSetSize / total, rest / total};
        SplittedExampleSet eSet = new SplittedExampleSet(inputSet, ratios,
                getParameterAsInt(PARAMETER_SAMPLING_TYPE), getParameterAsInt(PARAMETER_LOCAL_RANDOM_SEED));

        eSet.selectSingleSubset(0);
        learn(eSet);
        eSet.selectSingleSubset(1);
        IOContainer evalRes = evaluate(eSet);

        LinkedList<AverageVector> averageVectors = new LinkedList<AverageVector>();
        Tools.handleAverages(evalRes, averageVectors);
        PerformanceVector performanceVector = Tools.getPerformanceVector(averageVectors);
        if (performanceVector != null) {
            setResult(performanceVector);
        }
        return averageVectors.toArray(new AverageVector[averageVectors.size()]);
    }

    /**
     * Returns the minimum number of examples needed for the requested split. An explicit size
     * (at least 1) counts as itself; a size below 1 means "use the rest" and must receive at least
     * one example. The sum is computed as {@code long} so that two sizes near
     * {@link Integer#MAX_VALUE} cannot overflow.
     */
    private static long requiredExampleCount(int trainingSetSize, int testSetSize) {
        long training = trainingSetSize < 1 ? 1L : trainingSetSize;
        long test = testSetSize < 1 ? 1L : testSetSize;
        return training + test;
    }

    /**
     * Adds this operator's parameters (absolute training and test set sizes, sampling type and
     * local random seed) to those of {@link ValidationChain}.
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        ParameterTypeInt type = new ParameterTypeInt(PARAMETER_TRAINING_SET_SIZE, "Absolute size required for the training set (-1: use rest for training)", -1, Integer.MAX_VALUE, 100);
        type.setExpert(false);
        types.add(type);
        type = new ParameterTypeInt(PARAMETER_TEST_SET_SIZE, "Absolute size required for the test set (-1: use rest for testing)", -1, Integer.MAX_VALUE, -1);
        type.setExpert(false);
        types.add(type);
        types.add(new ParameterTypeCategory(PARAMETER_SAMPLING_TYPE, "Defines the sampling type of the validation (linear = consecutive subsets, shuffled = random subsets, stratified = random subsets with class distribution kept constant)", SplittedExampleSet.SAMPLING_NAMES, 1));
        types.add(new ParameterTypeInt(PARAMETER_LOCAL_RANDOM_SEED, "Use the given random seed instead of global random numbers (-1: use global)", -1, Integer.MAX_VALUE, -1));
        return types;
    }
}

