/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.lazy;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.InputDescription;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.AbstractLearner;
import com.rapidminer.operator.learner.LearnerCapability;
import com.rapidminer.operator.learner.lazy.KNNClassificationModel;
import com.rapidminer.operator.learner.lazy.KNNRegressionModel;
import com.rapidminer.operator.similarity.SimilarityMeasure;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.tools.math.container.LinearList;
import com.rapidminer.tools.math.similarity.DistanceMeasure;
import com.rapidminer.tools.math.similarity.DistanceMeasures;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Learner operator that builds a k-nearest-neighbour model.
 *
 * <p>Every example of the training set is stored, together with its label
 * value, in a {@link LinearList} that is backed by the distance measure
 * obtained from {@link DistanceMeasures}. A nominal label yields a
 * {@link KNNClassificationModel}; any other label yields a
 * {@link KNNRegressionModel}.
 */
public class KNNLearner
extends AbstractLearner {
    /** Key of the parameter holding the number of neighbours. */
    public static final String PARAMETER_K = "k";
    /** Key of the boolean parameter that switches on weighted voting. */
    public static final String PARAMETER_WEIGHTED_VOTE = "weighted_vote";

    /**
     * Creates the operator from its description.
     *
     * @param description operator description, passed on to the superclass
     */
    public KNNLearner(OperatorDescription description) {
        super(description);
    }

    /**
     * Stores all examples of the given set and wraps them in a kNN model.
     *
     * @param exampleSet the training data
     * @return a {@link KNNClassificationModel} if the label is nominal,
     *         otherwise a {@link KNNRegressionModel}
     * @throws OperatorException propagated from the calls made while building
     *         the model (measure creation, stop check, parameter lookup)
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        DistanceMeasure distance = DistanceMeasures.createMeasure(this, exampleSet, this.getInput());
        Attribute labelAttribute = exampleSet.getAttributes().getLabel();
        boolean classification = labelAttribute.isNominal();

        Attributes featureAttributes = exampleSet.getAttributes();
        int featureCount = featureAttributes.size();

        if (!classification) {
            // Non-nominal label: keep the raw double label value per sample.
            LinearList<Double> regressionSamples = new LinearList<Double>(distance);
            for (Example row : exampleSet) {
                double[] featureVector = KNNLearner.extractRegularValues(row, featureAttributes, featureCount);
                double target = row.getValue(labelAttribute);
                regressionSamples.add(featureVector, target);
                this.checkForStop();
            }
            int k = this.getParameterAsInt(PARAMETER_K);
            boolean weighted = this.getParameterAsBoolean(PARAMETER_WEIGHTED_VOTE);
            return new KNNRegressionModel(exampleSet, regressionSamples, k, weighted);
        }

        // Nominal label: the label value is truncated to an int before it is stored.
        LinearList<Integer> classificationSamples = new LinearList<Integer>(distance);
        for (Example row : exampleSet) {
            double[] featureVector = KNNLearner.extractRegularValues(row, featureAttributes, featureCount);
            int classIndex = (int)row.getValue(labelAttribute);
            classificationSamples.add(featureVector, classIndex);
            this.checkForStop();
        }
        int k = this.getParameterAsInt(PARAMETER_K);
        boolean weighted = this.getParameterAsBoolean(PARAMETER_WEIGHTED_VOTE);
        return new KNNClassificationModel(exampleSet, classificationSamples, k, weighted);
    }

    /**
     * Copies the values of one example into a fresh array, in the iteration
     * order of the given attributes.
     *
     * @param example    the example to read from
     * @param attributes the attributes whose values are copied
     * @param length     size of the array to allocate
     * @return a new array holding the example's value for each attribute
     */
    private static double[] extractRegularValues(Example example, Attributes attributes, int length) {
        double[] result = new double[length];
        int position = 0;
        for (Attribute current : attributes) {
            result[position++] = example.getValue(current);
        }
        return result;
    }

    /**
     * Reports which learner capabilities this operator supports.
     *
     * @param lc the capability in question
     * @return {@code true} for polynominal, binominal and numerical
     *         attributes, for polynominal, binominal and numerical classes,
     *         and for weighted examples; {@code false} for anything else
     */
    @Override
    public boolean supportsCapability(LearnerCapability lc) {
        return lc == LearnerCapability.POLYNOMINAL_ATTRIBUTES
            || lc == LearnerCapability.BINOMINAL_ATTRIBUTES
            || lc == LearnerCapability.NUMERICAL_ATTRIBUTES
            || lc == LearnerCapability.POLYNOMINAL_CLASS
            || lc == LearnerCapability.BINOMINAL_CLASS
            || lc == LearnerCapability.NUMERICAL_CLASS
            || lc == LearnerCapability.WEIGHTED_EXAMPLES;
    }

    /**
     * Returns the input description for the given input class.
     *
     * <p>Classes assignable to {@link SimilarityMeasure} get their own
     * description; every other class is delegated to the superclass.
     *
     * @param cls the input class
     * @return the input description for {@code cls}
     */
    @Override
    public InputDescription getInputDescription(Class cls) {
        if (!SimilarityMeasure.class.isAssignableFrom(cls)) {
            return super.getInputDescription(cls);
        }
        // NOTE(review): the meaning of the two boolean flags is defined by
        // InputDescription and is not visible here - confirm before changing.
        return new InputDescription(cls, false, true);
    }

    /**
     * Extends the inherited parameter list with the kNN-specific parameters
     * and the parameters contributed by {@link DistanceMeasures}.
     *
     * @return the parameter list obtained from the superclass, with the
     *         additional parameters appended
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameters = super.getParameterTypes();

        ParameterTypeInt neighbours = new ParameterTypeInt(PARAMETER_K, "The used number of nearest neighbors.", 1, Integer.MAX_VALUE, 1);
        neighbours.setExpert(false);
        parameters.add(neighbours);

        ParameterTypeBoolean weightedVote = new ParameterTypeBoolean(PARAMETER_WEIGHTED_VOTE, "Indicates if the votes should be weighted by similarity.", false);
        parameters.add(weightedVote);

        parameters.addAll(DistanceMeasures.getParameterTypes(this));
        return parameters;
    }
}

