/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.performance;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.performance.ForecastingCriterion;
import com.rapidminer.operator.performance.MeasuredPerformance;
import com.rapidminer.tools.math.Averagable;

/**
 * Forecasting performance criterion that measures how often a regression prediction gets the
 * direction of the trend right.
 *
 * <p>For every position {@code i >= horizon} of the example set (in iteration order) the actual
 * trend {@code label[i] - label[i - horizon]} is compared with the predicted trend
 * {@code prediction[i] - label[i - horizon]}. A position counts as correct when the product of
 * the two trends is {@code >= 0}, i.e. both have the same sign or at least one of them is zero.
 * The criterion value is the weighted fraction of correct positions; when no position can be
 * evaluated the value is {@code NaN} (0/0).
 *
 * <p>The horizon is read from the {@code "horizon"} parameter of the parent operator, which must
 * be supplied via {@link #setParent(Operator)} before {@link #startCounting} is called.
 */
public class PredictionTrendAccuracy
extends MeasuredPerformance
implements ForecastingCriterion {
    private static final long serialVersionUID = 4275593122138248581L;

    /**
     * Sum of the weights of all evaluated positions (the denominator of the accuracy). Starts at
     * zero so that the accuracy is unbiased and merged instances pool exactly.
     */
    private double length = 0.0;

    /** Sum of the weights of all positions whose trend direction was predicted correctly. */
    private double correctCounter = 0.0;

    /** Operator whose {@code "horizon"} parameter defines the distance over which trends are taken. */
    private Operator parent;

    /** Creates an empty criterion; {@link #setParent(Operator)} must be called before counting. */
    public PredictionTrendAccuracy() {
    }

    /**
     * Copy constructor: copies the accumulated statistics and the parent operator reference.
     *
     * @param pta the criterion to copy
     */
    public PredictionTrendAccuracy(PredictionTrendAccuracy pta) {
        super(pta);
        this.length = pta.length;
        this.correctCounter = pta.correctCounter;
        this.parent = pta.parent;
    }

    /**
     * Sets the operator from which the {@code "horizon"} parameter is read.
     *
     * @param parent the operator providing the horizon
     */
    public void setParent(Operator parent) {
        this.parent = parent;
    }

    public String getName() {
        return "prediction_trend_accuracy";
    }

    public String getDescription() {
        return "Measures the average of times a regression prediction was able to correctly predict the trend of the regression.";
    }

    /**
     * Reads label, prediction and (optionally) weight of every example and accumulates the trend
     * statistics. All work happens here because the criterion needs the whole sequence;
     * {@link #countExample(Example)} is therefore a no-op.
     *
     * @param eSet the example set to evaluate; expected to have a label and a predicted label
     * @param useExampleWeights if {@code true} and a weight attribute exists, positions are
     *     weighted by it; otherwise every position has weight 1
     * @throws OperatorException propagated from {@code super.startCounting} or from reading the
     *     parent's {@code "horizon"} parameter
     * @throws IllegalStateException if no parent operator has been set
     */
    public void startCounting(ExampleSet eSet, boolean useExampleWeights) throws OperatorException {
        super.startCounting(eSet, useExampleWeights);
        if (this.parent == null) {
            throw new IllegalStateException(
                    "No parent operator set for prediction_trend_accuracy; cannot read the 'horizon' parameter.");
        }
        Attribute labelAttribute = eSet.getAttributes().getLabel();
        Attribute predictedLabelAttribute = eSet.getAttributes().getPredictedLabel();
        Attribute weightAttribute = useExampleWeights ? eSet.getAttributes().getWeight() : null;

        int size = eSet.size();
        double[] weights = new double[size];
        double[] labels = new double[size];
        double[] predictions = new double[size];
        int index = 0;
        for (Example example : eSet) {
            weights[index] = weightAttribute != null ? example.getValue(weightAttribute) : 1.0;
            labels[index] = example.getValue(labelAttribute);
            predictions[index] = example.getValue(predictedLabelAttribute);
            index++;
        }

        // NOTE(review): assumes horizon >= 0 (a negative value would index before the array start);
        // presumably enforced by the parent's parameter definition — verify.
        int horizon = this.parent.getParameterAsInt("horizon");
        countTrends(labels, predictions, weights, horizon);
    }

    /**
     * Accumulates the weighted trend statistics for every position {@code i} with
     * {@code horizon <= i < labels.length}.
     */
    private void countTrends(double[] labels, double[] predictions, double[] weights, int horizon) {
        for (int i = horizon; i < labels.length; i++) {
            int base = i - horizon;
            double actualTrend = labels[i] - labels[base];
            // The predicted trend is measured from the actual earlier label, not the earlier prediction.
            double predictionTrend = predictions[i] - labels[base];
            // NOTE(review): the weight of the earlier example (base) is used, not that of example i — confirm intended.
            double weight = weights[base];
            // A zero trend on either side counts as correct; a NaN product compares false and counts as incorrect.
            if (actualTrend * predictionTrend >= 0.0) {
                this.correctCounter += weight;
            }
            this.length += weight;
        }
    }

    /**
     * Returns the total weight of all evaluated positions.
     *
     * @return the accumulated denominator of the accuracy
     */
    public double getExampleCount() {
        return this.length;
    }

    /** Intentionally empty: all counting is done in {@link #startCounting}. */
    public void countExample(Example example) {
    }

    /**
     * Returns the averaged criterion value as fitness.
     *
     * @return the value of {@code getAverage()}
     */
    public double getFitness() {
        return this.getAverage();
    }

    /**
     * Returns the weighted fraction of correctly predicted trend directions.
     *
     * @return {@code correctCounter / length}; {@code NaN} if no position was evaluated
     */
    public double getMikroAverage() {
        return this.correctCounter / this.length;
    }

    /**
     * Variance is not computed for this criterion.
     *
     * @return always {@link Double#NaN}
     */
    public double getMikroVariance() {
        return Double.NaN;
    }

    /**
     * Merges another instance into this one by summing both weight totals, so the merged micro
     * average is the pooled weighted accuracy.
     *
     * @param averagable another {@code PredictionTrendAccuracy}; any other type causes a
     *     {@link ClassCastException}
     */
    public void buildSingleAverage(Averagable averagable) {
        PredictionTrendAccuracy other = (PredictionTrendAccuracy) averagable;
        this.length += other.length;
        this.correctCounter += other.correctCounter;
    }
}

