/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.functions;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.table.NominalMapping;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.AbstractLearner;
import com.rapidminer.operator.learner.LearnerCapability;
import com.rapidminer.operator.learner.functions.HyperplaneModel;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.tools.math.kernels.DotKernel;
import com.rapidminer.tools.math.kernels.Kernel;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Learner that fits a {@link HyperplaneModel} for a two-class label using
 * mistake-driven updates: whenever the current model mispredicts an example,
 * the intercept and the coefficients are moved towards that example's class.
 *
 * <p>The step size decays geometrically over the rounds, from the configured
 * learning rate down to one tenth of it (see {@link #getLearnRate}).
 */
public class Perceptron
extends AbstractLearner {
    /** Key of the parameter that controls how many rounds (passes over the data) are run. */
    public static final String PARAMETER_ROUNDS = "rounds";
    /** Key of the parameter holding the initial learning rate. */
    public static final String PARAMETER_LEARNING_RATE = "learning_rate";

    /** Ratio between the learning rate of the last round and the initial learning rate. */
    private static final double FINAL_LEARN_RATE_FACTOR = 0.1;

    public Perceptron(OperatorDescription description) {
        super(description);
    }

    /**
     * Trains a hyperplane on the given example set.
     *
     * <p>The model starts with all coefficients and the intercept at zero. The round
     * counter runs from 0 to {@code rounds} inclusive, so the data is scanned
     * {@code rounds + 1} times; this keeps the historical behaviour of the operator.
     *
     * @param exampleSet the training data; its label attribute must provide a nominal
     *                   mapping with a negative and a positive class
     * @return the trained {@link HyperplaneModel}
     * @throws OperatorException if a parameter cannot be read
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        Kernel kernel = this.getKernel();
        kernel.init(exampleSet);
        double initLearnRate = this.getParameterAsDouble(PARAMETER_LEARNING_RATE);
        // Read once: the value is loop-invariant and was previously looked up twice per round.
        int rounds = this.getParameterAsInt(PARAMETER_ROUNDS);
        Attributes attributes = exampleSet.getAttributes();
        NominalMapping labelMapping = attributes.getLabel().getMapping();
        String classNeg = labelMapping.getNegativeString();
        String classPos = labelMapping.getPositiveString();
        double classValueNeg = labelMapping.getNegativeIndex();
        HyperplaneModel model = new HyperplaneModel(exampleSet, classNeg, classPos, kernel);
        model.init(new double[attributes.size()], 0.0);
        for (int round = 0; round <= rounds; ++round) {
            double learnRate = this.getLearnRate(round, rounds, initLearnRate);
            for (Example example : exampleSet) {
                double prediction = model.predict(example);
                double label = example.getLabel();
                if (prediction == label) {
                    continue;
                }
                // Signed step: negative for the negative class, positive otherwise.
                double step = learnRate * (label == classValueNeg ? -1 : 1);
                model.setIntercept(model.getIntercept() + step);
                // The array returned by the model is updated in place.
                double[] coefficients = model.getCoefficients();
                int i = 0;
                for (Attribute attribute : attributes) {
                    coefficients[i] += step * example.getValue(attribute);
                    ++i;
                }
            }
        }
        return model;
    }

    /**
     * Supplies the kernel handed to the {@link HyperplaneModel}; this implementation
     * always returns a new {@link DotKernel}.
     *
     * @return the kernel to use
     * @throws UndefinedParameterError declared for overriding implementations; not thrown here
     */
    protected Kernel getKernel() throws UndefinedParameterError {
        return new DotKernel();
    }

    /**
     * Computes the learning rate for a given round as
     * {@code initLearnRate * 0.1^(time / maxtime)}: the initial rate at {@code time == 0},
     * one tenth of it at {@code time == maxtime}.
     *
     * <p>At {@code time == 0} the exponent is taken as 0 regardless of {@code maxtime}, so
     * {@code maxtime == 0} yields {@code initLearnRate} instead of NaN (from {@code 0/0}).
     * The decay base is the constant 0.1 rather than
     * {@code initLearnRate * 0.1 / initLearnRate}, so an initial rate of 0 yields 0
     * instead of NaN.
     *
     * @param time          index of the current round, starting at 0
     * @param maxtime       index of the last round
     * @param initLearnRate learning rate used in round 0
     * @return the learning rate for round {@code time}
     */
    public double getLearnRate(int time, int maxtime, double initLearnRate) {
        double progress = time == 0 ? 0.0 : (double)time / (double)maxtime;
        return initLearnRate * Math.pow(FINAL_LEARN_RATE_FACTOR, progress);
    }

    /**
     * Reports the capabilities of this learner: numerical attributes, a binominal
     * class, and weighted examples.
     *
     * @param lc the capability being queried
     * @return {@code true} if {@code lc} is one of the three capabilities above
     */
    @Override
    public boolean supportsCapability(LearnerCapability lc) {
        // NOTE(review): WEIGHTED_EXAMPLES is reported, but learn() in this class never
        // reads an example weight - confirm whether weights are meant to be honoured.
        return lc == LearnerCapability.NUMERICAL_ATTRIBUTES
            || lc == LearnerCapability.BINOMINAL_CLASS
            || lc == LearnerCapability.WEIGHTED_EXAMPLES;
    }

    /**
     * Extends the inherited parameter list with the {@code rounds} and
     * {@code learning_rate} parameters.
     *
     * @return the list obtained from the superclass with the two parameters appended
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        types.add(new ParameterTypeInt(PARAMETER_ROUNDS, "The number of datascans used to adapt the hyperplane.", 0, Integer.MAX_VALUE, 3));
        types.add(new ParameterTypeDouble(PARAMETER_LEARNING_RATE, "The hyperplane will adapt with this rate to each example.", 0.0, 1.0, 0.05));
        return types;
    }
}

