/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.functions.kernel;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.FormulaProvider;
import com.rapidminer.operator.learner.functions.kernel.KernelModel;
import com.rapidminer.operator.learner.functions.kernel.SupportVector;
import com.rapidminer.operator.learner.functions.kernel.jmysvm.examples.SVMExample;
import com.rapidminer.operator.learner.functions.kernel.jmysvm.examples.SVMExamples;
import com.rapidminer.operator.learner.functions.kernel.jmysvm.kernel.Kernel;
import com.rapidminer.operator.learner.functions.kernel.jmysvm.kernel.KernelDot;
import com.rapidminer.operator.learner.functions.kernel.jmysvm.svm.SVMInterface;
import com.rapidminer.tools.Tools;
import java.util.Iterator;
import java.util.Map;

/**
 * Base class for kernel models backed by a mySVM {@link SVMExamples} container holding the
 * support vectors, their alphas and labels, and the bias.
 *
 * <p>For a {@link KernelDot} kernel the constructor pre-computes an explicit weight vector
 * {@code w = sum_i y_i * alpha_i * x_i}. {@link #performPrediction} then uses it as a fast path
 * instead of running the SVM against every support vector. Subclasses supply the concrete SVM
 * via {@link #createSVM()} and decide how a raw decision value is stored on an example via
 * {@link #setPrediction(Example, double)}.
 */
public abstract class AbstractMySVMModel extends KernelModel implements FormulaProvider {
    private static final long serialVersionUID = 2812901947459843681L;

    /** Support vectors with their alphas, labels and the bias. */
    private SVMExamples model;

    /** Kernel used for training and (non-linear) prediction. */
    private Kernel kernel;

    /**
     * Explicit weight vector for {@link KernelDot} kernels. It is {@code null} for other kernels
     * or when a support vector could not be obtained.
     */
    private double[] weights = null;

    /**
     * Creates a model from the given support vectors and kernel.
     *
     * @param exampleSet the example set passed to the superclass constructor
     * @param model the support vectors (with alphas, labels and bias)
     * @param kernel the kernel
     * @param kernelType the kernel type; not used by this class
     */
    public AbstractMySVMModel(ExampleSet exampleSet, SVMExamples model, Kernel kernel, int kernelType) {
        super(exampleSet);
        this.model = model;
        this.kernel = kernel;
        if (this.kernel instanceof KernelDot) {
            this.weights = this.computeLinearWeights();
        }
    }

    /**
     * Computes the weight vector {@code w = sum_i y_i * alpha_i * x_i} over all support vectors.
     *
     * @return the weight vector of length {@link #getNumberOfAttributes()}, or {@code null} if
     *     any support vector is {@code null}
     */
    private double[] computeLinearWeights() {
        double[] w = new double[this.getNumberOfAttributes()];
        for (int i = 0; i < this.getNumberOfSupportVectors(); ++i) {
            SupportVector sv = this.getSupportVector(i);
            if (sv == null) {
                // Without every support vector the weights would be incomplete; disable the fast path.
                return null;
            }
            double[] x = sv.getX();
            // (y * alpha) * x[j] matches the left-to-right evaluation of y * alpha * x[j].
            double factor = sv.getY() * sv.getAlpha();
            for (int j = 0; j < w.length; ++j) {
                w[j] += factor * x[j];
            }
        }
        return w;
    }

    /**
     * Creates the SVM implementation used for non-linear prediction and function values.
     *
     * @return a fresh, uninitialized SVM
     */
    public abstract SVMInterface createSVM();

    /**
     * Reports whether this model performs classification.
     *
     * @return {@code true} if the label attribute is nominal
     */
    public boolean isClassificationModel() {
        return this.getLabel().isNominal();
    }

    /**
     * Returns the bias term stored in the support-vector container.
     *
     * @return the bias
     */
    public double getBias() {
        return this.model.get_b();
    }

    /**
     * Builds a dense support vector for the given index.
     *
     * <p>The returned alpha is the stored alpha divided by the label {@code y} when {@code y} is
     * non-zero. Otherwise the stored alpha is returned unchanged.
     *
     * @param index support vector index
     * @return a new {@link SupportVector}; never {@code null} in this implementation
     */
    public SupportVector getSupportVector(int index) {
        double alpha = this.model.get_alpha(index);
        double y = this.model.get_y(index);
        if (y != 0.0) {
            alpha /= y;
        }
        return new SupportVector(this.model.get_example(index).toDense(this.getNumberOfAttributes()), y, alpha);
    }

    /**
     * Returns the alpha exactly as stored in the container (not divided by the label).
     *
     * @param index support vector index
     * @return the stored alpha
     */
    public double getAlpha(int index) {
        return this.model.get_alpha(index);
    }

    /**
     * Returns the id of the support vector at the given index.
     *
     * @param index support vector index
     * @return the id as provided by the container
     */
    public String getId(int index) {
        return this.model.getId(index);
    }

    /**
     * Returns the number of examples held in the support-vector container.
     *
     * @return number of support vectors
     */
    public int getNumberOfSupportVectors() {
        return this.model.count_examples();
    }

    /**
     * Returns the dimension of the support vectors.
     *
     * @return number of attributes
     */
    public int getNumberOfAttributes() {
        return this.model.get_dim();
    }

    /**
     * Returns one attribute value of a support vector.
     *
     * @param exampleIndex support vector index
     * @param attributeIndex attribute index
     * @return the value. It is {@code 0.0} if {@code attributeIndex} is outside the dense vector
     *     or if densifying the example fails with an index error.
     */
    public double getAttributeValue(int exampleIndex, int attributeIndex) {
        SVMExample svmExample = this.model.get_example(exampleIndex);
        double[] values;
        try {
            values = svmExample.toDense(this.getNumberOfAttributes());
        } catch (ArrayIndexOutOfBoundsException ignored) {
            // Deliberate best effort (kept from the original behavior). Presumably this is a
            // sparse index beyond the model dimension -- TODO confirm against SVMExample.toDense.
            return 0.0;
        }
        if (attributeIndex < 0 || attributeIndex >= values.length) {
            return 0.0;
        }
        return values[attributeIndex];
    }

    /**
     * Returns the nominal label of a support vector.
     *
     * @param index support vector index
     * @return the label mapping's negative string if {@code y < 0}, otherwise its positive string
     */
    public String getClassificationLabel(int index) {
        double y = this.model.get_y(index);
        if (y < 0.0) {
            return this.getLabel().getMapping().getNegativeString();
        }
        return this.getLabel().getMapping().getPositiveString();
    }

    /**
     * Returns the numeric label of a support vector.
     *
     * @param index support vector index
     * @return the stored label value
     */
    public double getRegressionLabel(int index) {
        return this.model.get_y(index);
    }

    /**
     * Evaluates the SVM decision function on the support vector at {@code index}.
     *
     * <p>A new SVM is created and initialized on every call.
     *
     * @param index support vector index
     * @return the predicted function value
     */
    public double getFunctionValue(int index) {
        SVMInterface svm = this.createSVM();
        svm.init(this.kernel, this.model);
        return svm.predict(this.model.get_example(index));
    }

    /**
     * Returns the kernel of this model.
     *
     * @return the kernel
     */
    public Kernel getKernel() {
        return this.kernel;
    }

    /**
     * Returns the support-vector container backing this model.
     *
     * @return the support vectors
     */
    public SVMExamples getExampleSet() {
        return this.model;
    }

    /**
     * Stores a raw decision value on an example.
     *
     * @param example the example to annotate
     * @param prediction the raw decision value computed by this model
     */
    public abstract void setPrediction(Example example, double prediction);

    /**
     * Predicts every example of {@code exampleSet} and stores the results via
     * {@link #setPrediction(Example, double)}.
     *
     * <p>For a {@link KernelDot} kernel with pre-computed weights, a direct dot product is used.
     * Otherwise the examples are converted to {@link SVMExamples} and run through a fresh SVM.
     *
     * @param exampleSet the examples to predict; they are modified in place
     * @param predictedLabelAttribute the predicted label attribute (not read here; the SVM path
     *     uses {@code exampleSet.getAttributes().getPredictedLabel()})
     * @return {@code exampleSet}
     * @throws OperatorException as declared by the superclass contract
     */
    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabelAttribute) throws OperatorException {
        if (this.kernel instanceof KernelDot && this.weights != null) {
            this.predictLinear(exampleSet);
            return exampleSet;
        }
        SVMExamples toPredict = new SVMExamples(exampleSet, exampleSet.getAttributes().getPredictedLabel(), this.model.getMeanVariances());
        SVMInterface svm = this.createSVM();
        svm.init(this.kernel, this.model);
        svm.predict(toPredict);
        int k = 0;
        for (Example example : exampleSet) {
            this.setPrediction(example, toPredict.get_y(k++));
        }
        return exampleSet;
    }

    /**
     * Linear fast path: computes {@code bias + w . x} for every example.
     *
     * <p>Where training mean/variance is available for an attribute position, the value is first
     * standardized with it. Zero variance maps the value to {@code 0.0}.
     *
     * <p>NOTE(review): attributes are matched to {@link #weights} by their position among the
     * regular attributes. An example set with more regular attributes than the training
     * dimension would index past the end of {@code weights} -- confirm callers guarantee this.
     *
     * @param exampleSet the examples to predict; modified in place
     */
    private void predictLinear(ExampleSet exampleSet) {
        Map<Integer, SVMExamples.MeanVariance> meanVariances = this.model.getMeanVariances();
        double bias = this.getBias();
        for (Example example : exampleSet) {
            double prediction = bias;
            int a = 0;
            for (Attribute attribute : exampleSet.getAttributes()) {
                double value = example.getValue(attribute);
                SVMExamples.MeanVariance meanVariance = meanVariances.get(a);
                if (meanVariance != null) {
                    value = meanVariance.getVariance() == 0.0 ? 0.0 : (value - meanVariance.getMean()) / Math.sqrt(meanVariance.getVariance());
                }
                prediction += this.weights[a] * value;
                ++a;
            }
            this.setPrediction(example, prediction);
        }
    }

    /**
     * Renders the decision function as text.
     *
     * <p>Each support vector with a non-zero alpha contributes one line of the form
     * {@code (+|-) |y*alpha| * (<kernel distance formula>)}. A non-zero bias adds a final line.
     * Every term, including the first, is preceded by a line separator.
     *
     * @return the formula; empty if there are no non-zero terms
     */
    public String getFormula() {
        StringBuilder result = new StringBuilder();
        Kernel svmKernel = this.getKernel();
        boolean first = true;
        for (int i = 0; i < this.getNumberOfSupportVectors(); ++i) {
            SupportVector sv = this.getSupportVector(i);
            if (sv == null) {
                continue;
            }
            double alpha = sv.getAlpha();
            if (Tools.isZero(alpha)) {
                continue;
            }
            result.append(Tools.getLineSeparator());
            double[] x = sv.getX();
            double factor = sv.getY() * alpha;
            if (factor < 0.0) {
                result.append("- " + Math.abs(factor));
            } else {
                // Leading blanks keep the first positive term aligned with later "+ " terms.
                result.append((first ? "  " : "+ ") + factor);
            }
            result.append(" * (" + svmKernel.getDistanceFormula(x, this.getAttributeConstructions()) + ")");
            first = false;
        }
        double bias = this.getBias();
        if (!Tools.isZero(bias)) {
            result.append(Tools.getLineSeparator());
            if (bias < 0.0) {
                result.append("- " + Math.abs(bias));
            } else if (first) {
                // A lone positive bias is written without a sign prefix (unlike support vector terms).
                result.append(bias);
            } else {
                result.append("+ " + bias);
            }
        }
        return result.toString();
    }
}

