package com.rapidminer.operator.features.weighting;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.AttributeWeights;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Statistics;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.generator.MultipleLabelGenerator;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.tools.Tools;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.SortedSet;
import java.util.TreeSet;

/**
 * Computes attribute relevance weights with the Relief family of algorithms.
 *
 * <p>For a nominal label the weights follow the ReliefF update. For every (optionally
 * sampled) example, the {@code number_of_neighbors} nearest neighbors of each label value
 * are searched. Differences to neighbors of the example's own class (hits) decrease an
 * attribute's weight. Differences to neighbors of other classes (misses) increase it,
 * scaled by the prior probability of the miss class relative to all other classes.
 *
 * <p>For a numerical label the operator accumulates the RReliefF quantities over the
 * nearest neighbors and combines them once all examples were processed:
 * <ul>
 * <li>the mass of "different label" (N_dC),</li>
 * <li>the per-attribute mass of "different value" (N_dA),</li>
 * <li>the per-attribute mass of "both different" (N_dC&amp;dA).</li>
 * </ul>
 * No rank- or distance-based weighting of the neighbors is applied.
 *
 * <p>Neighbors are found with the Euclidean distance over the raw attribute values; nominal
 * attributes contribute the difference of their internal numeric values. Per-run accumulators
 * are kept in instance fields, so one instance must not run
 * {@link #calculateWeights(ExampleSet)} concurrently.
 */
public class ReliefWeighting extends AbstractWeighting {
    public static final String PARAMETER_NUMBER_OF_NEIGHBORS = "number_of_neighbors";
    public static final String PARAMETER_SAMPLE_RATIO = "sample_ratio";
    public static final String PARAMETER_LOCAL_RANDOM_SEED = "local_random_seed";

    /** Key of the single neighbor set used when the label is numerical. */
    private static final String REGRESSION_NEIGHBORS_KEY = "regression";

    /** Accumulated "different label" mass over all processed examples (RReliefF N_dC). */
    private double differentLabelWeight;
    /** Per regular attribute: accumulated "different attribute value" mass (RReliefF N_dA). */
    private double[] differentAttributesWeights;
    /** Per regular attribute: accumulated "different label and value" mass (RReliefF N_dC&dA). */
    private double[] differentLabelAndAttributesWeights;
    /** Relative frequency of each label value in mapping order; {@code null} for numerical labels. */
    private double[] classProbabilities;

    /**
     * A candidate neighbor: the index of an example inside the searched example set together
     * with its distance to the reference example.
     *
     * <p>Ordering is by distance, ties broken by example index. The tie-break matters because
     * neighbors are collected in a {@link TreeSet}, which treats {@code compareTo == 0} as
     * "same element". Without it, two neighbors at equal distance would collapse into one.
     * {@code equals}/{@code hashCode} are consistent with this ordering.
     */
    public static class IndexDistance implements Comparable<IndexDistance> {
        private final int exampleIndex;
        private final double distance;

        public IndexDistance(int i, double d) {
            this.exampleIndex = i;
            this.distance = d;
        }

        /** Returns the index of the neighbor inside the searched example set. */
        public int getIndex() {
            return this.exampleIndex;
        }

        /** Returns the distance of the neighbor to the reference example. */
        public double getDistance() {
            return this.distance;
        }

        @Override
        public int hashCode() {
            return 31 * Double.valueOf(this.distance).hashCode() + this.exampleIndex;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof IndexDistance)) {
                return false;
            }
            IndexDistance other = (IndexDistance) obj;
            return this.exampleIndex == other.exampleIndex
                    && Double.compare(this.distance, other.distance) == 0;
        }

        @Override
        public int compareTo(IndexDistance indexDistance) {
            int byDistance = Double.compare(this.distance, indexDistance.distance);
            if (byDistance != 0) {
                return byDistance;
            }
            if (this.exampleIndex == indexDistance.exampleIndex) {
                return 0;
            }
            return this.exampleIndex < indexDistance.exampleIndex ? -1 : 1;
        }

        @Override
        public String toString() {
            return String.valueOf(this.exampleIndex) + " (d: " + Tools.formatNumber(this.distance) + ")";
        }
    }

    public ReliefWeighting(OperatorDescription operatorDescription) {
        super(operatorDescription);
    }

    /**
     * Computes Relief weights for all regular attributes of the given example set.
     *
     * @param exampleSet the data; its attribute statistics are recalculated because the
     *                   normalized differences rely on the attribute minimum/maximum
     * @return weights keyed by attribute name
     * @throws OperatorException if the example set has no label (user error 105) or a
     *                           parameter cannot be read
     */
    @Override
    public AttributeWeights calculateWeights(ExampleSet exampleSet) throws OperatorException {
        exampleSet.recalculateAllAttributeStatistics();
        Attribute label = exampleSet.getAttributes().getLabel();
        if (label == null) {
            throw new UserError(this, 105);
        }

        AttributeWeights attributeWeights = new AttributeWeights(exampleSet);
        for (Attribute attribute : exampleSet.getAttributes()) {
            attributeWeights.setWeight(attribute.getName(), 0.0d);
        }

        int numberOfAttributes = exampleSet.getAttributes().size();
        this.differentLabelWeight = 0.0d;
        this.differentAttributesWeights = new double[numberOfAttributes];
        this.differentLabelAndAttributesWeights = new double[numberOfAttributes];
        this.classProbabilities = label.isNominal() ? computeClassProbabilities(exampleSet, label) : null;

        int numberOfNeighbors = getParameterAsInt(PARAMETER_NUMBER_OF_NEIGHBORS);
        double sampleRatio = getParameterAsDouble(PARAMETER_SAMPLE_RATIO);
        ExampleSet sample = exampleSet;
        if (sampleRatio < 1.0d) {
            SplittedExampleSet splitted = new SplittedExampleSet(exampleSet, sampleRatio, 2,
                    getParameterAsInt(PARAMETER_LOCAL_RANDOM_SEED));
            // Subset 0 holds the sampled fraction; neighbors are also searched only within it.
            splitted.selectSingleSubset(0);
            sample = splitted;
        }

        int exampleIndex = 0;
        for (Example example : sample) {
            Map<String, SortedSet<IndexDistance>> neighborSets =
                    searchNeighbors(sample, example, exampleIndex, label, numberOfNeighbors);
            if (label.isNominal()) {
                updateWeightsClassification(neighborSets, sample, example, attributeWeights, label);
            } else {
                updateWeightsRegression(neighborSets, sample, example, label, numberOfNeighbors);
            }
            exampleIndex++;
        }

        if (!label.isNominal()) {
            // RReliefF combination: W[A] = N_dC&dA / N_dC - (N_dA - N_dC&dA) / (m - N_dC)
            double m = sample.size();
            int attributeIndex = 0;
            for (Attribute attribute : sample.getAttributes()) {
                double both = this.differentLabelAndAttributesWeights[attributeIndex];
                double differentValue = this.differentAttributesWeights[attributeIndex];
                attributeWeights.setWeight(attribute.getName(),
                        both / this.differentLabelWeight
                                - (differentValue - both) / (m - this.differentLabelWeight));
                attributeIndex++;
            }
        }
        return attributeWeights;
    }

    /** Returns the relative frequency of every label value, indexed in label-mapping order. */
    private double[] computeClassProbabilities(ExampleSet exampleSet, Attribute label) {
        double[] probabilities = new double[label.getMapping().size()];
        int classIndex = 0;
        for (String className : label.getMapping().getValues()) {
            probabilities[classIndex++] = exampleSet.getStatistics(label, "count", className) / exampleSet.size();
        }
        return probabilities;
    }

    /**
     * Accumulates the RReliefF quantities for one example and its nearest neighbors.
     *
     * <p>Each neighbor contributes {@code 1 / numberOfNeighbors} of its normalized label
     * difference to N_dC. For each attribute it adds its normalized attribute difference to
     * N_dA, and the product of both differences to N_dC&dA. Neighbors with an unknown label
     * are skipped. So are attributes that are entirely unknown, or unknown for this pair.
     */
    private void updateWeightsRegression(Map<String, SortedSet<IndexDistance>> neighborSets, ExampleSet exampleSet,
                                         Example example, Attribute label, int numberOfNeighbors) {
        for (IndexDistance neighborDistance : neighborSets.get(REGRESSION_NEIGHBORS_KEY)) {
            Example neighbor = exampleSet.getExample(neighborDistance.getIndex());
            double labelDiff = normedDifference(example, neighbor, exampleSet, label);
            if (Double.isNaN(labelDiff)) {
                continue;
            }
            this.differentLabelWeight += labelDiff / numberOfNeighbors;
            int attributeIndex = 0;
            for (Attribute attribute : exampleSet.getAttributes()) {
                if (((int) exampleSet.getStatistics(attribute, "unknown")) < exampleSet.size()) {
                    double attributeDiff = normedDifference(example, neighbor, exampleSet, attribute);
                    if (!Double.isNaN(attributeDiff)) {
                        this.differentAttributesWeights[attributeIndex] += attributeDiff / numberOfNeighbors;
                        this.differentLabelAndAttributesWeights[attributeIndex] +=
                                (labelDiff * attributeDiff) / numberOfNeighbors;
                    }
                }
                // Advance for EVERY attribute so accumulator slots stay aligned with attribute order,
                // which calculateWeights relies on when it reads the arrays back.
                attributeIndex++;
            }
        }
    }

    /**
     * Applies the ReliefF update of one example to the attribute weights.
     *
     * <p>Hits subtract {@code diff / (m * k)}. Misses of class C add
     * {@code P(C) / (1 - P(own class)) * diff / (m * k)}. Here m is the example-set size and
     * k is the number of neighbors actually found for that class. Examples with an unknown
     * label are skipped, because neither hits nor misses are defined for them.
     */
    private void updateWeightsClassification(Map<String, SortedSet<IndexDistance>> neighborSets, ExampleSet exampleSet,
                                             Example example, AttributeWeights attributeWeights, Attribute label) {
        double labelValue = example.getValue(label);
        if (Double.isNaN(labelValue)) {
            return;
        }
        // Assumes the label's internal value is its index in getMapping().getValues() order,
        // matching how classProbabilities is filled.
        double otherClassesProbability = 1.0d - this.classProbabilities[(int) labelValue];
        String ownClass = example.getValueAsString(label);

        int classIndex = 0;
        for (String className : label.getMapping().getValues()) {
            SortedSet<IndexDistance> neighbors = neighborSets.get(className);
            boolean hit = className.equals(ownClass);
            // Computed in double so that m * k cannot overflow int.
            double normalizer = (double) exampleSet.size() * neighbors.size();
            for (IndexDistance neighborDistance : neighbors) {
                Example neighbor = exampleSet.getExample(neighborDistance.getIndex());
                for (Attribute attribute : exampleSet.getAttributes()) {
                    if (((int) exampleSet.getStatistics(attribute, "unknown")) >= exampleSet.size()) {
                        continue;
                    }
                    double diff = normedDifference(example, neighbor, exampleSet, attribute);
                    if (Double.isNaN(diff)) {
                        continue;
                    }
                    double weight = attributeWeights.getWeight(attribute.getName());
                    if (hit) {
                        weight -= diff / normalizer;
                    } else {
                        weight += ((this.classProbabilities[classIndex] / otherClassesProbability) * diff) / normalizer;
                    }
                    attributeWeights.setWeight(attribute.getName(), weight);
                }
            }
            classIndex++;
        }
    }

    /**
     * Returns the Relief difference of two examples on one attribute.
     *
     * <p>The result is {@code NaN} if either value is unknown. For nominal attributes it is 0
     * for equal values and 1 otherwise. For numerical attributes it is
     * {@code |a - b| / (max - min)}, in [0, 1] for values within the attribute's range.
     */
    private double normedDifference(Example example, Example example2, ExampleSet exampleSet, Attribute attribute) {
        double diff = Math.abs(example.getValue(attribute) - example2.getValue(attribute));
        if (Double.isNaN(diff)) {
            return Double.NaN;
        }
        if (diff == 0.0d) {
            // Also avoids 0/0 for constant numerical attributes (max == min).
            return 0.0d;
        }
        if (attribute.isNominal()) {
            return 1.0d;
        }
        double min = exampleSet.getStatistics(attribute, Statistics.MINIMUM);
        double max = exampleSet.getStatistics(attribute, Statistics.MAXIMUM);
        return diff / (max - min);
    }

    /**
     * Collects, for the example at position {@code exampleIndex}, its nearest neighbors inside
     * {@code exampleSet} (the example itself excluded).
     *
     * <p>For a nominal label, one set is kept per label value, holding at most
     * {@code numberOfNeighbors} nearest examples of that class. Candidates with an unknown
     * label are ignored. For a numerical label, a single set of at most
     * {@code numberOfNeighbors} entries is stored under an internal key.
     */
    private Map<String, SortedSet<IndexDistance>> searchNeighbors(ExampleSet exampleSet, Example example,
                                                                 int exampleIndex, Attribute label,
                                                                 int numberOfNeighbors) {
        boolean nominal = label.isNominal();
        Map<String, SortedSet<IndexDistance>> neighborSets = new HashMap<String, SortedSet<IndexDistance>>();
        if (nominal) {
            for (String className : label.getMapping().getValues()) {
                neighborSets.put(className, new TreeSet<IndexDistance>());
            }
        } else {
            neighborSets.put(REGRESSION_NEIGHBORS_KEY, new TreeSet<IndexDistance>());
        }

        int candidateIndex = 0;
        for (Example candidate : exampleSet) {
            if (candidateIndex != exampleIndex) {
                SortedSet<IndexDistance> neighbors;
                if (!nominal) {
                    neighbors = neighborSets.get(REGRESSION_NEIGHBORS_KEY);
                } else if (Double.isNaN(candidate.getValue(label))) {
                    neighbors = null;
                } else {
                    neighbors = neighborSets.get(candidate.getValueAsString(label));
                }
                if (neighbors != null) {
                    neighbors.add(new IndexDistance(candidateIndex, calculateDistance(example, candidate)));
                    if (neighbors.size() > numberOfNeighbors) {
                        neighbors.remove(neighbors.last());
                    }
                }
            }
            candidateIndex++;
        }
        return neighborSets;
    }

    /** Returns the Euclidean distance of two examples over the regular attributes of the first one. */
    private double calculateDistance(Example example, Example example2) {
        double sumOfSquares = 0.0d;
        for (Attribute attribute : example.getAttributes()) {
            double delta = example.getValue(attribute) - example2.getValue(attribute);
            sumOfSquares += delta * delta;
        }
        return Math.sqrt(sumOfSquares);
    }

    /** Adds the neighbor count, sample ratio and local random seed parameters. */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();
        ParameterTypeInt neighborsType = new ParameterTypeInt(PARAMETER_NUMBER_OF_NEIGHBORS,
                "Number of nearest neigbors for relevance calculation.", 1, Integer.MAX_VALUE, 10);
        neighborsType.setExpert(false);
        parameterTypes.add(neighborsType);
        ParameterTypeDouble sampleRatioType = new ParameterTypeDouble(PARAMETER_SAMPLE_RATIO,
                "Number of examples used for determining the weights.", 0.0d, 1.0d, 1.0d);
        sampleRatioType.setExpert(false);
        parameterTypes.add(sampleRatioType);
        parameterTypes.add(new ParameterTypeInt(PARAMETER_LOCAL_RANDOM_SEED,
                "Use the given random seed instead of global random numbers (-1: use global)", -1, Integer.MAX_VALUE, -1));
        return parameterTypes;
    }
}
