/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.functions;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.FastExample2SparseTransform;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.AbstractLearner;
import com.rapidminer.operator.learner.LearnerCapability;
import com.rapidminer.operator.learner.functions.FastMarginModel;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeList;
import com.rapidminer.parameter.ParameterTypeSingle;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import liblinear.FeatureNode;
import liblinear.Linear;
import liblinear.Parameter;
import liblinear.Problem;
import liblinear.SolverType;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Binominal classification learner for numerical attributes that trains a linear large-margin
 * model with the LibLinear library.
 *
 * <p>The solver is selected via {@link #PARAMETER_SOLVER} (one of {@link #SOLVER}). The cost
 * {@link #PARAMETER_C}, the termination tolerance {@link #PARAMETER_EPSILON}, optional per-class
 * weights and an optional constant bias feature are passed on to LibLinear. The trained LibLinear
 * model is wrapped in a {@link FastMarginModel}.
 */
public class FastLargeMargin
extends AbstractLearner {
    public static final String PARAMETER_SOLVER = "solver";
    public static final String PARAMETER_C = "C";
    public static final String PARAMETER_EPSILON = "epsilon";
    public static final String PARAMETER_CLASS_WEIGHTS = "class_weights";
    public static final String PARAMETER_USE_BIAS = "use_bias";
    /** Display names of the solvers; the position of each entry is its parameter value. */
    public static final String[] SOLVER = new String[]{"L2 SVM Dual", "L2 SVM Primal", "L2 Logistic Regression", "L1 SVM Dual"};
    public static final int SOLVER_L2_SVM_DUAL = 0;
    public static final int SOLVER_L2_SVM_PRIMAL = 1;
    public static final int SOLVER_L2_LR = 2;
    public static final int SOLVER_L1_SVM_DUAL = 3;

    public FastLargeMargin(OperatorDescription description) {
        super(description);
    }

    /**
     * Reports the supported learner capabilities: numerical attributes and a binominal label.
     *
     * @param lc the capability to query
     * @return {@code true} only for {@code NUMERICAL_ATTRIBUTES} and {@code BINOMINAL_CLASS}
     */
    @Override
    public boolean supportsCapability(LearnerCapability lc) {
        return lc == LearnerCapability.NUMERICAL_ATTRIBUTES || lc == LearnerCapability.BINOMINAL_CLASS;
    }

    /**
     * Converts one example into a sparse LibLinear feature vector.
     *
     * <p>Only the attribute values reported as non-default by {@code ripper} are emitted. Each is
     * stored at its attribute index + 1, i.e. regular attributes occupy indices
     * {@code 1 .. #attributes}. When {@code useBias} is set, a constant feature with value 1.0 is
     * appended at index {@code #attributes + 1}. This matches {@code problem.n} as computed for
     * training.
     *
     * @param e the example to convert
     * @param ripper helper that extracts the non-default attribute indices and values of {@code e}
     * @param useBias whether to append the constant bias feature
     * @return the feature nodes of {@code e}, with the bias node (if any) last
     */
    public static FeatureNode[] makeNodes(Example e, FastExample2SparseTransform ripper, boolean useBias) {
        int[] nonDefaultIndices = ripper.getNonDefaultAttributeIndices(e);
        double[] nonDefaultValues = ripper.getNonDefaultAttributeValues(e, nonDefaultIndices);
        int featureCount = nonDefaultIndices.length;
        FeatureNode[] nodeArray = new FeatureNode[useBias ? featureCount + 1 : featureCount];
        for (int a = 0; a < featureCount; a++) {
            // shift attribute index by one (index 0 is never used for a feature)
            nodeArray[a] = new FeatureNode(nonDefaultIndices[a] + 1, nonDefaultValues[a]);
        }
        if (useBias) {
            // The bias must always live at the same index, directly after the last regular
            // attribute. Deriving it from the array length (as before) gave sparse examples a
            // varying bias index that could collide with a real attribute's index.
            nodeArray[featureCount] = new FeatureNode(e.getAttributes().size() + 1, 1.0);
        }
        return nodeArray;
    }

    /**
     * Builds the LibLinear training problem from the example set.
     *
     * <p>Labels are encoded as 0 for the label mapping's negative class and 1 for any other value.
     *
     * @param exampleSet the training data
     * @return the populated LibLinear problem
     * @throws UserError if a parameter cannot be read
     */
    private Problem getProblem(ExampleSet exampleSet) throws UserError {
        this.log("Creating LibLinear problem.");
        FastExample2SparseTransform ripper = new FastExample2SparseTransform(exampleSet);
        boolean useBias = this.getParameterAsBoolean(PARAMETER_USE_BIAS);
        int numberOfExamples = exampleSet.size();
        int numberOfAttributes = exampleSet.getAttributes().size();

        Problem problem = new Problem();
        problem.l = numberOfExamples;
        // one extra dimension for the constant bias feature appended by makeNodes
        problem.n = useBias ? numberOfAttributes + 1 : numberOfAttributes;
        problem.y = new int[numberOfExamples];
        problem.x = new FeatureNode[numberOfExamples][];

        Attribute label = exampleSet.getAttributes().getLabel();
        int negativeIndex = label.getMapping().getNegativeIndex();
        int nodeCount = 0;
        int j = 0;
        Iterator<Example> i = exampleSet.iterator();
        while (i.hasNext()) {
            Example e = i.next();
            problem.x[j] = FastLargeMargin.makeNodes(e, ripper, useBias);
            problem.y[j] = (int)e.getValue(label) == negativeIndex ? 0 : 1;
            nodeCount += problem.x[j].length;
            ++j;
        }
        this.log("Created " + nodeCount + " nodes for " + j + " examples.");
        return problem;
    }

    /**
     * Translates the operator parameters into a LibLinear {@link Parameter}.
     *
     * <p>Unknown solver values fall back to the L2-loss SVM dual solver. Class weights default to
     * 1.0. Entries naming a class that is not in the label mapping are ignored.
     *
     * @param exampleSet the training data, used to resolve class names in the class weights
     * @return the LibLinear parameter object
     * @throws OperatorException if a parameter cannot be read
     */
    private Parameter getParameters(ExampleSet exampleSet) throws OperatorException {
        SolverType solverType;
        switch (this.getParameterAsInt(PARAMETER_SOLVER)) {
            case SOLVER_L2_SVM_PRIMAL:
                solverType = SolverType.L2LOSS_SVM;
                break;
            case SOLVER_L2_LR:
                solverType = SolverType.L2_LR;
                break;
            case SOLVER_L1_SVM_DUAL:
                solverType = SolverType.L1LOSS_SVM_DUAL;
                break;
            case SOLVER_L2_SVM_DUAL:
            default:
                solverType = SolverType.L2LOSS_SVM_DUAL;
                break;
        }
        double c = this.getParameterAsDouble(PARAMETER_C);
        double epsilon = this.getParameterAsDouble(PARAMETER_EPSILON);
        Parameter parameter = new Parameter(solverType, c, epsilon);
        if (this.isParameterSet(PARAMETER_CLASS_WEIGHTS)) {
            // Indexed by LibLinear label: 0 = negative class, 1 = positive class (see getProblem).
            double[] weights = {1.0, 1.0};
            int[] weightLabelIndices = {0, 1};
            Attribute label = exampleSet.getAttributes().getLabel();
            int negativeIndex = label.getMapping().getNegativeIndex();
            List<String[]> classWeights = this.getParameterList(PARAMETER_CLASS_WEIGHTS);
            for (String[] classWeightArray : classWeights) {
                String className = classWeightArray[0];
                double classWeight = Double.parseDouble(classWeightArray[1]);
                int index = label.getMapping().getIndex(className);
                if (index < 0) {
                    continue; // class name not present in the label mapping
                }
                // Use the same mapping-index -> label encoding as getProblem instead of assuming
                // that the mapping index equals the LibLinear label.
                weights[index == negativeIndex ? 0 : 1] = classWeight;
            }
            this.log(this.getName() + ": used class weights --> " + Arrays.toString(weights));
            parameter.setWeights(weights, weightLabelIndices);
        }
        return parameter;
    }

    /**
     * Trains a LibLinear model on the given example set.
     *
     * @param exampleSet the training data; must contain at least two examples
     * @return a {@link FastMarginModel} wrapping the trained LibLinear model
     * @throws OperatorException if a parameter cannot be read or fewer than two examples are given
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        Parameter params = this.getParameters(exampleSet);
        if (exampleSet.size() < 2) {
            throw new UserError((Operator)this, 110, 2);
        }
        // NOTE(review): presumably resets LibLinear's internal RNG so training is reproducible
        Linear.resetRandom();
        Linear.disableDebugOutput();
        Problem problem = this.getProblem(exampleSet);
        liblinear.Model model = Linear.train(problem, params);
        return new FastMarginModel(exampleSet, model, this.getParameterAsBoolean(PARAMETER_USE_BIAS));
    }

    /**
     * Declares the operator's parameters: solver, C, epsilon, class weights and bias usage.
     *
     * @return the parameter types of the superclass followed by this operator's parameters
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        ParameterTypeSingle type = new ParameterTypeCategory(PARAMETER_SOLVER, "The solver type for this fast margin method.", SOLVER, 0);
        type.setExpert(false);
        types.add(type);
        type = new ParameterTypeDouble(PARAMETER_C, "The cost parameter C (trade-off between training error and model complexity).", 0.0, Double.POSITIVE_INFINITY, 1.0);
        type.setExpert(false);
        types.add(type);
        types.add(new ParameterTypeDouble(PARAMETER_EPSILON, "Tolerance of termination criterion.", Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 0.01));
        types.add(new ParameterTypeList(PARAMETER_CLASS_WEIGHTS, "The weights w for all classes (first column: class name, second column: weight), i.e. set the parameters C of each class w * C (empty: using 1 for all classes where the weight was not defined).", new ParameterTypeDouble("weight", "The weight for the specified class.", 0.0, Double.POSITIVE_INFINITY, 1.0)));
        types.add(new ParameterTypeBoolean(PARAMETER_USE_BIAS, "Indicates if an intercept value should be calculated.", true));
        return types;
    }
}

