/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.lazy;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.tools.container.Tupel;
import com.rapidminer.tools.math.container.GeometricDataCollection;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Prediction model for k-nearest-neighbour classification.
 *
 * <p>The training samples are held in a {@link GeometricDataCollection} whose stored values are
 * used as label indices. For every example to classify, the {@code k} nearest stored samples are
 * looked up. Their labels are combined into per-label confidences in one of two ways:
 * <ul>
 * <li>an equal vote per neighbour, or
 * <li>a vote weighted by each neighbour's distance, when {@code weightByDistance} is set.
 * </ul>
 * The label with the highest confidence becomes the prediction.
 */
public class KNNClassificationModel extends PredictionModel {
    private static final long serialVersionUID = -6292869962412072573L;

    /** Number of neighbours requested from {@link #samples} per prediction. */
    private final int k;

    /** Training samples; the value stored with each sample is used as an index into the label mapping. */
    private final GeometricDataCollection<Integer> samples;

    /**
     * Names of the attributes iterated from the training set, in iteration order. Used to build
     * prediction-time value vectors in the same order as the stored sample vectors.
     */
    private final ArrayList<String> sampleAttributeNames;

    /** If {@code true}, neighbour votes are weighted by distance; otherwise each counts equally. */
    private final boolean weightByDistance;

    /**
     * Creates a model from an already-populated sample collection.
     *
     * @param trainingSet the training data; its attribute names are recorded so prediction can
     *     look up attributes in the same order
     * @param samples the training samples, storing label indices as values
     * @param k the number of neighbours to consult per prediction
     * @param weightByDistance {@code true} to weight each neighbour's vote by its distance
     */
    public KNNClassificationModel(ExampleSet trainingSet, GeometricDataCollection<Integer> samples, int k, boolean weightByDistance) {
        super(trainingSet);
        this.k = k;
        this.samples = samples;
        this.weightByDistance = weightByDistance;
        Attributes attributes = trainingSet.getAttributes();
        this.sampleAttributeNames = new ArrayList<String>(attributes.size());
        for (Attribute attribute : attributes) {
            this.sampleAttributeNames.add(attribute.getName());
        }
    }

    /**
     * Classifies every example of {@code exampleSet} in place.
     *
     * <p>Sets the predicted label and one confidence value per label of the predicted label's
     * mapping.
     *
     * @param exampleSet the examples to classify
     * @param predictedLabel the attribute receiving the prediction
     * @return {@code exampleSet}, with predictions and confidences written into it
     * @throws OperatorException declared by the overridden method
     */
    @Override
    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabel) throws OperatorException {
        // Resolve prediction-time attributes in the order used when the samples were stored.
        Attributes attributes = exampleSet.getAttributes();
        ArrayList<Attribute> sampleAttributes = new ArrayList<Attribute>(this.sampleAttributeNames.size());
        for (String attributeName : this.sampleAttributeNames) {
            sampleAttributes.add(attributes.get(attributeName));
        }
        double[] values = new double[sampleAttributes.size()];
        // Loop-invariant: the number of labels does not depend on the example being classified.
        int labelCount = predictedLabel.getMapping().size();
        for (Example example : exampleSet) {
            int i = 0;
            for (Attribute attribute : sampleAttributes) {
                values[i++] = example.getValue(attribute);
            }
            double[] counter = new double[labelCount];
            if (this.weightByDistance) {
                addDistanceWeightedVotes(values, counter);
            } else {
                addUnweightedVotes(values, counter);
            }
            example.setValue(predictedLabel, indexOfMaximum(counter));
            for (int index = 0; index < counter.length; ++index) {
                example.setConfidence(predictedLabel.getMapping().mapIndex(index), counter[index]);
            }
        }
        return exampleSet;
    }

    /**
     * Adds an equal vote of {@code 1 / m} to the label of each neighbour returned for
     * {@code values}, where {@code m} is the number of neighbours actually returned.
     *
     * <p>{@code m} may be smaller than {@code k}, presumably when fewer than {@code k} samples are
     * stored. Normalising by {@code m} instead of {@code k} keeps the confidences summing to 1.
     *
     * @param values the query point
     * @param counter per-label accumulator, indexed by label index
     */
    private void addUnweightedVotes(double[] values, double[] counter) {
        Collection<Integer> neighbourLabels = this.samples.getNearestValues(this.k, values);
        if (neighbourLabels.isEmpty()) {
            return;
        }
        double vote = 1.0 / neighbourLabels.size();
        for (Integer label : neighbourLabels) {
            counter[label.intValue()] += vote;
        }
    }

    /**
     * Adds distance-weighted votes for the neighbours returned for {@code values}.
     *
     * <p>With {@code m} neighbours at distances {@code d_i} summing to {@code D}, neighbour
     * {@code i} contributes {@code (1 - d_i / D) / (m - 1)}. These weights sum to 1.
     *
     * <p>Two degenerate cases instead get an equal vote of {@code 1 / m} per neighbour:
     * <ul>
     * <li>all distances are zero ({@code D == 0});
     * <li>only one neighbour is returned. The formula would then give it weight 0, leaving every
     * confidence at 0 and forcing the prediction to label index 0.
     * </ul>
     *
     * @param values the query point
     * @param counter per-label accumulator, indexed by label index
     */
    private void addDistanceWeightedVotes(double[] values, double[] counter) {
        Collection<Tupel<Double, Integer>> neighbours = this.samples.getNearestValueDistances(this.k, values);
        int neighbourCount = neighbours.size();
        if (neighbourCount == 0) {
            return;
        }
        double totalDistance = 0.0;
        for (Tupel<Double, Integer> neighbour : neighbours) {
            totalDistance += neighbour.getFirst().doubleValue();
        }
        if (totalDistance == 0.0 || neighbourCount == 1) {
            double vote = 1.0 / neighbourCount;
            for (Tupel<Double, Integer> neighbour : neighbours) {
                counter[neighbour.getSecond().intValue()] += vote;
            }
            return;
        }
        double normaliser = neighbourCount - 1;
        for (Tupel<Double, Integer> neighbour : neighbours) {
            double weight = 1.0 - neighbour.getFirst().doubleValue() / totalDistance;
            counter[neighbour.getSecond().intValue()] += weight / normaliser;
        }
    }

    /**
     * Returns the index of the largest entry of {@code counter}, preferring the lowest index on
     * ties.
     *
     * <p>NOTE(review): an empty array yields {@link Integer#MIN_VALUE}, which is then written as
     * the prediction. Confirm whether an empty label mapping can reach this point.
     *
     * @param counter per-label confidences
     * @return the index of the maximum entry
     */
    private static int indexOfMaximum(double[] counter) {
        int best = Integer.MIN_VALUE;
        double bestValue = Double.NEGATIVE_INFINITY;
        for (int index = 0; index < counter.length; ++index) {
            if (bestValue < counter[index]) {
                bestValue = counter[index];
                best = index;
            }
        }
        return best;
    }
}

