/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.visualization;

import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.IOContainer;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorChain;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.condition.AllInnerOperatorCondition;
import com.rapidminer.operator.condition.InnerOperatorCondition;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.visualization.ROCComparison;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.tools.math.ROCData;
import com.rapidminer.tools.math.ROCDataGenerator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Compares the ROC curves of several learners on the same binominal classification task.
 *
 * <p>Every inner operator receives an {@link ExampleSet} and must deliver a {@link Model}.
 * Depending on the {@code number_of_folds} parameter, each model is evaluated either on a
 * single training/test split (negative value) or by cross-validation. The resulting
 * {@link ROCData}, keyed by inner operator name, is delivered as a {@link ROCComparison}
 * together with the original input example set. The evaluation itself works on a clone.
 */
public class ROCBasedComparisonOperator extends OperatorChain {

    public static final String PARAMETER_NUMBER_OF_FOLDS = "number_of_folds";
    public static final String PARAMETER_SPLIT_RATIO = "split_ratio";
    public static final String PARAMETER_SAMPLING_TYPE = "sampling_type";
    public static final String PARAMETER_LOCAL_RANDOM_SEED = "local_random_seed";
    public static final String PARAMETER_USE_EXAMPLE_WEIGHTS = "use_example_weights";

    public ROCBasedComparisonOperator(OperatorDescription description) {
        super(description);
    }

    /**
     * Evaluates every inner operator and collects one ROC curve per evaluation run.
     *
     * @return the input example set and a {@link ROCComparison} mapping each inner operator
     *         name to one {@link ROCData} (simple split) or one per fold (cross-validation)
     * @throws UserError if the input has no label (105), the label is not nominal (101),
     *         the label does not have exactly two values (114), or a learned model did not
     *         create a predicted label (107)
     * @throws OperatorException if an inner operator or a model application fails
     */
    @Override
    public IOObject[] apply() throws OperatorException {
        ExampleSet exampleSet = getInput(ExampleSet.class);
        if (exampleSet.getAttributes().getLabel() == null) {
            throw new UserError(this, 105);
        }
        if (!exampleSet.getAttributes().getLabel().isNominal()) {
            throw new UserError(this, 101, "ROC Comparison", exampleSet.getAttributes().getLabel());
        }
        if (exampleSet.getAttributes().getLabel().getMapping().getValues().size() != 2) {
            throw new UserError(this, 114, "ROC Comparison", exampleSet.getAttributes().getLabel());
        }

        // Parameters that do not change between evaluation runs are read once up front.
        int numberOfFolds = getParameterAsInt(PARAMETER_NUMBER_OF_FOLDS);
        int samplingType = getParameterAsInt(PARAMETER_SAMPLING_TYPE);
        int randomSeed = getParameterAsInt(PARAMETER_LOCAL_RANDOM_SEED);
        boolean useExampleWeights = getParameterAsBoolean(PARAMETER_USE_EXAMPLE_WEIGHTS);

        // Insertion-ordered map: curves are stored in inner-operator order instead of hash order.
        HashMap<String, List<ROCData>> rocData = new LinkedHashMap<String, List<ROCData>>();

        if (numberOfFolds < 0) {
            // Simple split: subset 0 is used for training, subset 1 for testing.
            double splitRatio = getParameterAsDouble(PARAMETER_SPLIT_RATIO);
            SplittedExampleSet splittedSet = new SplittedExampleSet(
                    (ExampleSet) exampleSet.clone(), splitRatio, samplingType, randomSeed);
            PredictionModel.removePredictedLabel(splittedSet);
            for (int i = 0; i < getNumberOfOperators(); i++) {
                Operator innerOperator = getOperator(i);
                splittedSet.selectSingleSubset(0);
                List<ROCData> dataList = new LinkedList<ROCData>();
                dataList.add(trainAndEvaluate(innerOperator, splittedSet, 1, useExampleWeights));
                rocData.put(innerOperator.getName(), dataList);
                // Same per-iteration hook the cross-validation branch calls after each fold.
                inApplyLoop();
            }
        } else {
            // NOTE(review): the parameter range also admits 0 and 1. Those presumably yield
            // no folds at all or an empty training set; consider rejecting them explicitly.
            SplittedExampleSet splittedSet = new SplittedExampleSet(
                    (ExampleSet) exampleSet.clone(), numberOfFolds, samplingType, randomSeed);
            PredictionModel.removePredictedLabel(splittedSet);
            for (int i = 0; i < getNumberOfOperators(); i++) {
                Operator innerOperator = getOperator(i);
                List<ROCData> dataList = new LinkedList<ROCData>();
                for (int fold = 0; fold < numberOfFolds; fold++) {
                    // Train on every subset except the current fold, then test on that fold.
                    splittedSet.selectAllSubsetsBut(fold);
                    dataList.add(trainAndEvaluate(innerOperator, splittedSet, fold, useExampleWeights));
                    inApplyLoop();
                }
                rocData.put(innerOperator.getName(), dataList);
            }
        }
        return new IOObject[]{exampleSet, new ROCComparison(rocData)};
    }

    /**
     * Trains {@code learner} on the currently selected subsets of {@code splittedSet}, applies
     * the learned model to subset {@code testSubset} and returns the ROC curve of the result.
     * The predicted label attribute is removed again before returning.
     *
     * @param learner inner operator that must deliver a {@link Model}
     * @param splittedSet example set whose training subsets are already selected by the caller
     * @param testSubset index of the subset to test on; it remains selected on return
     * @param useExampleWeights passed through to the ROC computation
     * @return the ROC data computed on the test subset
     * @throws UserError (107) if the model application did not create a predicted label
     * @throws OperatorException if training or model application fails
     */
    private ROCData trainAndEvaluate(Operator learner, SplittedExampleSet splittedSet,
            int testSubset, boolean useExampleWeights) throws OperatorException {
        IOContainer learnerOutput = learner.apply(new IOContainer(splittedSet));
        Model model = learnerOutput.remove(Model.class);
        splittedSet.selectSingleSubset(testSubset);
        ExampleSet testSet = model.apply(splittedSet);
        if (testSet.getAttributes().getPredictedLabel() == null) {
            throw new UserError(this, 107);
        }
        // A fresh generator per evaluation, as before.
        // NOTE(review): the meaning of the two 1.0 arguments is not visible here
        // (presumably misclassification costs); confirm.
        ROCData curve = new ROCDataGenerator(1.0, 1.0).createROCData(testSet, useExampleWeights);
        PredictionModel.removePredictedLabel(testSet);
        return curve;
    }

    /** Requires a single {@link ExampleSet} as input. */
    @Override
    public Class<?>[] getInputClasses() {
        return new Class[]{ExampleSet.class};
    }

    /** Delivers the input {@link ExampleSet} and the computed {@link ROCComparison}. */
    @Override
    public Class<?>[] getOutputClasses() {
        return new Class[]{ExampleSet.class, ROCComparison.class};
    }

    /** Every inner operator must turn an {@link ExampleSet} into a {@link Model}. */
    @Override
    public InnerOperatorCondition getInnerOperatorCondition() {
        return new AllInnerOperatorCondition(new Class[]{ExampleSet.class}, new Class[]{Model.class});
    }

    /** At least one learner is needed to produce a comparison. */
    @Override
    public int getMinNumberOfInnerOperators() {
        return 1;
    }

    /** Any number of learners may be compared. */
    @Override
    public int getMaxNumberOfInnerOperators() {
        return Integer.MAX_VALUE;
    }

    /**
     * Declares the parameters: fold count (-1 selects a simple split), split ratio, sampling
     * type, local random seed and whether example weights are used.
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        ParameterTypeInt type = new ParameterTypeInt(PARAMETER_NUMBER_OF_FOLDS, "The number of folds used for a cross validation evaluation (-1: use simple split ratio).", -1, Integer.MAX_VALUE, 10);
        type.setExpert(false);
        types.add(type);
        types.add(new ParameterTypeDouble(PARAMETER_SPLIT_RATIO, "Relative size of the training set", 0.0, 1.0, 0.7));
        types.add(new ParameterTypeCategory(PARAMETER_SAMPLING_TYPE, "Defines the sampling type of the cross validation (linear = consecutive subsets, shuffled = random subsets, stratified = random subsets with class distribution kept constant)", SplittedExampleSet.SAMPLING_NAMES, 2));
        types.add(new ParameterTypeInt(PARAMETER_LOCAL_RANDOM_SEED, "Use the given random seed instead of global random numbers (-1: use global)", -1, Integer.MAX_VALUE, -1));
        types.add(new ParameterTypeBoolean(PARAMETER_USE_EXAMPLE_WEIGHTS, "Indicates if example weights should be regarded (use weight 1 for each example otherwise).", true));
        return types;
    }
}

