/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.features.weighting;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.AttributeWeights;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.features.weighting.AbstractWeighting;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.ParameterTypeNumber;
import com.rapidminer.tools.Tools;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.SortedSet;
import java.util.TreeSet;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Calculates attribute relevance weights with the Relief family of algorithms.
 *
 * <p>For a nominal label, every (optionally sampled) example is compared with its
 * {@code number_of_neighbors} nearest neighbors of every class. Differences to neighbors of the
 * example's own class (hits) decrease an attribute's weight. Differences to neighbors of another
 * class (misses) increase it, scaled by that class's prior relative to the prior mass of all
 * classes other than the example's own (ReliefF style).
 *
 * <p>For a numerical label, the statistics of RReliefF are accumulated over the nearest neighbors,
 * each neighbor weighted uniformly with {@code 1 / number_of_neighbors}. They are turned into
 * weights once all examples have been processed.
 */
public class ReliefWeighting
extends AbstractWeighting {
    public static final String PARAMETER_NUMBER_OF_NEIGHBORS = "number_of_neighbors";
    public static final String PARAMETER_SAMPLE_RATIO = "sample_ratio";
    public static final String PARAMETER_LOCAL_RANDOM_SEED = "local_random_seed";

    /** Key of the single neighbor set used when the label is numerical. */
    private static final String REGRESSION_NEIGHBORS = "regression";

    /** RReliefF statistic: accumulated neighbor-weighted label differences. */
    private double differentLabelWeight;
    /** RReliefF statistic: per-attribute accumulated neighbor-weighted attribute differences. */
    private double[] differentAttributesWeights;
    /** RReliefF statistic: per-attribute accumulated products of label and attribute differences. */
    private double[] differentLabelAndAttributesWeights;
    /** Relative frequency of each nominal label value in mapping order; null for numerical labels. */
    private double[] classProbabilities;

    public ReliefWeighting(OperatorDescription description) {
        super(description);
    }

    /**
     * Calculates one Relief weight per attribute of the given example set.
     *
     * @param inputSet the examples to weight; must contain a label attribute
     * @return the calculated attribute weights
     * @throws OperatorException a {@link UserError} (code 105) if the example set has no label
     */
    @Override
    public AttributeWeights calculateWeights(ExampleSet inputSet) throws OperatorException {
        inputSet.recalculateAllAttributeStatistics();
        Attribute label = inputSet.getAttributes().getLabel();
        if (label == null) {
            throw new UserError(this, 105);
        }
        AttributeWeights weights = new AttributeWeights(inputSet);
        for (Attribute attribute : inputSet.getAttributes()) {
            weights.setWeight(attribute.getName(), 0.0);
        }

        int numberOfAttributes = inputSet.getAttributes().size();
        this.differentLabelWeight = 0.0;
        this.differentAttributesWeights = new double[numberOfAttributes];
        this.differentLabelAndAttributesWeights = new double[numberOfAttributes];
        this.classProbabilities = null;
        if (label.isNominal()) {
            // Class priors in mapping order. They are later indexed both by mapping position and by
            // the (int) internal label value, which assumes both orders agree.
            this.classProbabilities = new double[label.getMapping().size()];
            int classIndex = 0;
            for (String value : label.getMapping().getValues()) {
                this.classProbabilities[classIndex++] = inputSet.getStatistics(label, "count", value) / (double) inputSet.size();
            }
        }

        int numberOfNeighbors = this.getParameterAsInt(PARAMETER_NUMBER_OF_NEIGHBORS);
        double sampleRatio = this.getParameterAsDouble(PARAMETER_SAMPLE_RATIO);
        ExampleSet exampleSet = inputSet;
        if (sampleRatio < 1.0) {
            SplittedExampleSet sample = new SplittedExampleSet(inputSet, sampleRatio, 2, this.getParameterAsInt(PARAMETER_LOCAL_RANDOM_SEED));
            sample.selectSingleSubset(0);
            exampleSet = sample;
        }

        int exampleIndex = 0;
        int processedExamples = 0;
        for (Example example : exampleSet) {
            // An example without a label value carries no class or target information. For nominal
            // labels it would otherwise index the class priors with (int) NaN == 0.
            if (!Double.isNaN(example.getValue(label))) {
                Map<String, SortedSet<IndexDistance>> neighborSets = this.searchNeighbors(exampleSet, example, exampleIndex, label, numberOfNeighbors);
                if (label.isNominal()) {
                    this.updateWeightsClassification(neighborSets, exampleSet, example, weights, label);
                } else {
                    this.updateWeightsRegression(neighborSets, exampleSet, example, label, numberOfNeighbors);
                }
                ++processedExamples;
            }
            // The index must advance for every example so it matches exampleSet.getExample(index).
            ++exampleIndex;
        }

        if (!label.isNominal()) {
            // RReliefF: W[A] = N_dC&dA / N_dC - (N_dA - N_dC&dA) / (m - N_dC), with m the number of
            // processed examples. The result may be non-finite if N_dC is 0 or equals m.
            int attributeIndex = 0;
            for (Attribute attribute : exampleSet.getAttributes()) {
                double weight = this.differentLabelAndAttributesWeights[attributeIndex] / this.differentLabelWeight
                        - (this.differentAttributesWeights[attributeIndex] - this.differentLabelAndAttributesWeights[attributeIndex])
                        / ((double) processedExamples - this.differentLabelWeight);
                weights.setWeight(attribute.getName(), weight);
                ++attributeIndex;
            }
        }
        return weights;
    }

    /**
     * Accumulates the RReliefF statistics for one example and its nearest neighbors.
     * Neighbors whose label difference is undefined are ignored.
     */
    private void updateWeightsRegression(Map<String, SortedSet<IndexDistance>> neighborSets, ExampleSet exampleSet, Example example, Attribute label, int numberOfNeighbors) {
        for (IndexDistance indexDistance : neighborSets.get(REGRESSION_NEIGHBORS)) {
            Example neighbor = exampleSet.getExample(indexDistance.getIndex());
            double labelDiff = this.normedDifference(example, neighbor, exampleSet, label);
            if (Double.isNaN(labelDiff)) {
                continue;
            }
            this.differentLabelWeight += labelDiff / (double) numberOfNeighbors;
            int attributeIndex = 0;
            for (Attribute attribute : exampleSet.getAttributes()) {
                int unknownCount = (int) exampleSet.getStatistics(attribute, "unknown");
                if (unknownCount < exampleSet.size()) {
                    double diff = this.normedDifference(example, neighbor, exampleSet, attribute);
                    if (!Double.isNaN(diff)) {
                        this.differentAttributesWeights[attributeIndex] += diff / (double) numberOfNeighbors;
                        this.differentLabelAndAttributesWeights[attributeIndex] += labelDiff * diff / (double) numberOfNeighbors;
                    }
                }
                // Advance for every attribute, skipped or not, so the slots stay aligned with the
                // attribute order used when the final weights are assigned.
                ++attributeIndex;
            }
        }
    }

    /**
     * Updates the attribute weights for one example with a nominal label.
     * Hits (same class) subtract the normalized difference. Misses add it, scaled by
     * P(class of miss) / (1 - P(class of example)). Every contribution is divided by the number
     * of known values of the attribute.
     */
    private void updateWeightsClassification(Map<String, SortedSet<IndexDistance>> neighborSets, ExampleSet exampleSet, Example example, AttributeWeights weights, Attribute label) {
        String exampleClass = example.getValueAsString(label);
        double missNormalization = 1.0 - this.classProbabilities[(int) example.getValue(label)];
        int classIndex = 0;
        for (String classValue : label.getMapping().getValues()) {
            boolean isHit = classValue.equals(exampleClass);
            for (IndexDistance indexDistance : neighborSets.get(classValue)) {
                Example neighbor = exampleSet.getExample(indexDistance.getIndex());
                for (Attribute attribute : exampleSet.getAttributes()) {
                    int unknownCount = (int) exampleSet.getStatistics(attribute, "unknown");
                    if (unknownCount >= exampleSet.size()) {
                        continue;
                    }
                    double diff = this.normedDifference(example, neighbor, exampleSet, attribute);
                    if (Double.isNaN(diff)) {
                        continue;
                    }
                    double knownValues = (double) (exampleSet.size() - unknownCount);
                    double weight = weights.getWeight(attribute.getName());
                    if (isHit) {
                        weight -= diff / knownValues;
                    } else {
                        weight += this.classProbabilities[classIndex] / missNormalization * diff / knownValues;
                    }
                    weights.setWeight(attribute.getName(), weight);
                }
            }
            ++classIndex;
        }
    }

    /**
     * Returns the difference of two examples in one attribute, scaled to [0, 1].
     * For nominal attributes it is 0 for equal values and 1 otherwise. For numerical
     * attributes it is |a - b| divided by the attribute's range (max - min). Returns NaN
     * if either value is missing, and 0 for an attribute with zero range.
     */
    private double normedDifference(Example first, Example second, ExampleSet exampleSet, Attribute attribute) {
        double diff = Math.abs(first.getValue(attribute) - second.getValue(attribute));
        if (Double.isNaN(diff)) {
            return Double.NaN;
        }
        if (attribute.isNominal()) {
            return diff == 0.0 ? 0.0 : 1.0;
        }
        double min = exampleSet.getStatistics(attribute, "minimum");
        double max = exampleSet.getStatistics(attribute, "maximum");
        double range = max - min;
        if (range <= 0.0) {
            // A constant attribute cannot distinguish examples; avoid dividing by zero.
            return 0.0;
        }
        return diff / range;
    }

    /**
     * Collects the {@code numberOfNeighbors} nearest neighbors of {@code example}.
     * For a nominal label there is one set per label value. For a numerical label there is a
     * single set under {@link #REGRESSION_NEIGHBORS}. The example itself and candidates with a
     * missing label are skipped.
     */
    private Map<String, SortedSet<IndexDistance>> searchNeighbors(ExampleSet exampleSet, Example example, int exampleIndex, Attribute label, int numberOfNeighbors) {
        Map<String, SortedSet<IndexDistance>> neighborSets = new HashMap<String, SortedSet<IndexDistance>>();
        if (label.isNominal()) {
            for (String value : label.getMapping().getValues()) {
                neighborSets.put(value, new TreeSet<IndexDistance>());
            }
        } else {
            neighborSets.put(REGRESSION_NEIGHBORS, new TreeSet<IndexDistance>());
        }
        int candidateIndex = 0;
        for (Example candidate : exampleSet) {
            if (candidateIndex != exampleIndex && !Double.isNaN(candidate.getValue(label))) {
                String key = label.isNominal() ? candidate.getValueAsString(label) : REGRESSION_NEIGHBORS;
                SortedSet<IndexDistance> currentSet = neighborSets.get(key);
                // Defensive: a label string without a neighbor set must not cause an NPE.
                if (currentSet != null) {
                    currentSet.add(new IndexDistance(candidateIndex, this.calculateDistance(example, candidate)));
                    if (currentSet.size() > numberOfNeighbors) {
                        currentSet.remove(currentSet.last());
                    }
                }
            }
            ++candidateIndex;
        }
        return neighborSets;
    }

    /**
     * Euclidean distance over the raw values of the attributes iterated by
     * {@code first.getAttributes()} (presumably the regular attributes only; the label is fetched
     * separately). Values are not normalized, so attributes with larger scales dominate. The
     * result is NaN if any value is missing.
     */
    private double calculateDistance(Example first, Example second) {
        double distance = 0.0;
        for (Attribute attribute : first.getAttributes()) {
            double diff = first.getValue(attribute) - second.getValue(attribute);
            distance += diff * diff;
        }
        return Math.sqrt(distance);
    }

    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        ParameterTypeNumber type = new ParameterTypeInt(PARAMETER_NUMBER_OF_NEIGHBORS, "Number of nearest neighbors for relevance calculation.", 1, Integer.MAX_VALUE, 10);
        type.setExpert(false);
        types.add(type);
        type = new ParameterTypeDouble(PARAMETER_SAMPLE_RATIO, "Ratio of examples used for determining the weights.", 0.0, 1.0, 1.0);
        type.setExpert(false);
        types.add(type);
        types.add(new ParameterTypeInt(PARAMETER_LOCAL_RANDOM_SEED, "Use the given random seed instead of global random numbers (-1: use global)", -1, Integer.MAX_VALUE, -1));
        return types;
    }

    /**
     * A candidate neighbor: its index in the example set and its distance to the reference
     * example. Ordering is by distance, with ties broken by index. Equidistant candidates
     * therefore stay distinct elements of a TreeSet instead of being merged as duplicates.
     */
    private static class IndexDistance
    implements Comparable<IndexDistance> {
        private final int exampleIndex;
        private final double distance;

        public IndexDistance(int index, double distance) {
            this.exampleIndex = index;
            this.distance = distance;
        }

        public int getIndex() {
            return this.exampleIndex;
        }

        public double getDistance() {
            return this.distance;
        }

        @Override
        public int hashCode() {
            return 31 * Double.valueOf(this.distance).hashCode() + this.exampleIndex;
        }

        @Override
        public boolean equals(Object other) {
            if (this == other) {
                return true;
            }
            if (!(other instanceof IndexDistance)) {
                return false;
            }
            IndexDistance o = (IndexDistance) other;
            return this.exampleIndex == o.exampleIndex && Double.compare(this.distance, o.distance) == 0;
        }

        @Override
        public int compareTo(IndexDistance o) {
            int byDistance = Double.compare(this.distance, o.distance);
            if (byDistance != 0) {
                return byDistance;
            }
            // Integer.compare is Java 7+; this file targets Java 6.
            return this.exampleIndex < o.exampleIndex ? -1 : (this.exampleIndex == o.exampleIndex ? 0 : 1);
        }

        @Override
        public String toString() {
            return String.valueOf(this.exampleIndex) + " (d: " + Tools.formatNumber(this.distance) + ")";
        }
    }
}

