/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.functions.neuralnet;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Tools;
import com.rapidminer.gui.tools.ExtendedJScrollPane;
import com.rapidminer.gui.tools.JRadioSelectionPanel;
import com.rapidminer.operator.IOContainer;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.functions.neuralnet.ActivationFunction;
import com.rapidminer.operator.learner.functions.neuralnet.ImprovedNeuralNetVisualizer;
import com.rapidminer.operator.learner.functions.neuralnet.InnerNode;
import com.rapidminer.operator.learner.functions.neuralnet.InputNode;
import com.rapidminer.operator.learner.functions.neuralnet.LinearFunction;
import com.rapidminer.operator.learner.functions.neuralnet.Node;
import com.rapidminer.operator.learner.functions.neuralnet.OutputNode;
import com.rapidminer.operator.learner.functions.neuralnet.SigmoidFunction;
import com.rapidminer.tools.RandomGenerator;
import java.awt.Component;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Prediction model backed by a feed-forward neural network: one input node per regular attribute,
 * one or more hidden layers of sigmoid nodes, and one output node per class (a single output fed by
 * a linear node for non-nominal labels).
 */
public class ImprovedNeuralNetModel
extends PredictionModel {
    private static final long serialVersionUID = -2206598483097451366L;
    private static final ActivationFunction SIGMOID_FUNCTION = new SigmoidFunction();
    private static final ActivationFunction LINEAR_FUNCTION = new LinearFunction();
    /** Layer index assigned to the inner nodes that feed the output nodes directly. */
    private static final int OUTPUT_LAYER_INDEX = -2;
    /** Names of the regular attributes of the training data (used by the visualizer). */
    private String[] attributeNames;
    /** One input node per regular attribute, in attribute iteration order. */
    private InputNode[] inputNodes = new InputNode[0];
    /**
     * All computing nodes. The first entries (one per class, layer index -2) feed the output nodes;
     * the hidden-layer nodes follow in layer order. {@link #initHiddenLayers} relies on this layout.
     */
    private InnerNode[] innerNodes = new InnerNode[0];
    /** One output node per class (exactly one for non-nominal labels). */
    private OutputNode[] outputNodes = new OutputNode[0];

    protected ImprovedNeuralNetModel(ExampleSet trainingExampleSet) {
        super(trainingExampleSet);
        this.attributeNames = Tools.getRegularAttributeNames(trainingExampleSet);
    }

    /**
     * Builds a fresh network for the given data and trains it.
     *
     * <p>Each cycle presents every example once, optionally in a random order that is drawn once per
     * training attempt. Training stops early when the weighted error of a cycle drops below
     * {@code maxError}. If that error becomes infinite or NaN, the attempt is discarded and the
     * complete network is rebuilt and retrained with half the learning rate.
     *
     * @param exampleSet the training data; attribute statistics are recalculated on it
     * @param hiddenLayers (name, size) pairs describing the hidden layers; sizes &lt;= 0 select a
     *        default size, and an empty list (or {@code null}) creates one default hidden layer
     * @param maxCycles maximum number of passes over the training data per attempt
     * @param maxError training stops once the weighted error of a cycle is below this value
     * @param learningRate initial learning rate (multiplied by each example's weight)
     * @param momentum momentum passed to the weight updates
     * @param decay if true, the per-example learning rate is divided by (cycle + 1)
     * @param shuffle if true, examples are presented in a randomly shuffled order
     * @param normalize if true, input nodes receive a range/offset derived from each attribute's
     *        minimum and maximum
     * @param randomGenerator randomness source passed to node construction and used for shuffling
     * @throws RuntimeException if training diverges and the learning rate cannot be reduced further
     */
    public void train(ExampleSet exampleSet, List<String[]> hiddenLayers, int maxCycles, double maxError, double learningRate, double momentum, boolean decay, boolean shuffle, boolean normalize, RandomGenerator randomGenerator) {
        Attribute label = exampleSet.getAttributes().getLabel();
        int numberOfClasses = this.getNumberOfClasses(label);
        if (normalize) {
            exampleSet.recalculateAllAttributeStatistics();
        } else {
            exampleSet.recalculateAttributeStatistics(label);
        }
        double labelMin = exampleSet.getStatistics(label, "minimum");
        double labelMax = exampleSet.getStatistics(label, "maximum");
        Attribute weightAttribute = exampleSet.getAttributes().getWeight();
        double totalWeight = 0.0;
        for (Example example : exampleSet) {
            totalWeight += getExampleWeight(example, weightAttribute);
        }
        double currentLearningRate = learningRate;
        while (true) {
            // Every attempt must start from an empty node set: initOutputLayer/initHiddenLayers only
            // append, and initHiddenLayers addresses nodes by absolute index. Leftovers from a
            // diverged attempt would otherwise leave the new output nodes disconnected.
            this.innerNodes = new InnerNode[0];
            this.initInputLayer(exampleSet, normalize);
            this.initOutputLayer(label, numberOfClasses, labelMin, labelMax, randomGenerator);
            this.initHiddenLayers(exampleSet, label, hiddenLayers, randomGenerator);
            int[] exampleIndices = shuffle ? createShuffledIndices(exampleSet.size(), randomGenerator) : null;
            if (this.runTrainingCycles(exampleSet, exampleIndices, weightAttribute, totalWeight, numberOfClasses, maxCycles, maxError, currentLearningRate, momentum, decay)) {
                return;
            }
            if (com.rapidminer.tools.Tools.isLessEqual(currentLearningRate, 0.0)) {
                throw new RuntimeException("Cannot reset network to a smaller learning rate.");
            }
            currentLearningRate /= 2.0;
        }
    }

    /**
     * Runs up to {@code maxCycles} training passes over the current network.
     *
     * @param exampleIndices presentation order, or {@code null} for the natural order
     * @return {@code true} if training finished (error below {@code maxError} or all cycles done),
     *         {@code false} if the cycle error became infinite or NaN
     */
    private boolean runTrainingCycles(ExampleSet exampleSet, int[] exampleIndices, Attribute weightAttribute, double totalWeight, int numberOfClasses, int maxCycles, double maxError, double learningRate, double momentum, boolean decay) {
        for (int cycle = 0; cycle < maxCycles; cycle++) {
            double error = 0.0;
            int size = exampleSet.size();
            for (int index = 0; index < size; index++) {
                int exampleIndex = exampleIndices != null ? exampleIndices[index] : index;
                Example example = exampleSet.getExample(exampleIndex);
                this.resetNetwork();
                this.calculateValue(example);
                double weight = getExampleWeight(example, weightAttribute);
                double tempRate = learningRate * weight;
                if (decay) {
                    tempRate /= (double)(cycle + 1);
                }
                error += this.calculateError(example) / (double)numberOfClasses * weight;
                this.update(example, tempRate, momentum);
            }
            error /= totalWeight;
            if (error < maxError) {
                return true;
            }
            if (Double.isInfinite(error) || Double.isNaN(error)) {
                return false;
            }
        }
        return true;
    }

    /** Returns the example's weight, or 1.0 when the data has no weight attribute. */
    private static double getExampleWeight(Example example, Attribute weightAttribute) {
        return weightAttribute == null ? 1.0 : example.getValue(weightAttribute);
    }

    /** Returns a random permutation of {@code 0 .. size-1} drawn from the given generator. */
    private static int[] createShuffledIndices(int size, RandomGenerator randomGenerator) {
        List<Integer> indices = new ArrayList<Integer>(size);
        for (int i = 0; i < size; i++) {
            indices.add(i);
        }
        Collections.shuffle(indices, randomGenerator);
        int[] result = new int[size];
        for (int i = 0; i < size; i++) {
            result[i] = indices.get(i);
        }
        return result;
    }

    /**
     * Applies the network to every example.
     *
     * <p>For nominal labels, the output values are normalized to sum to 1 and stored as
     * confidences, and the class with the highest confidence becomes the prediction. Otherwise
     * the value of the single output node is stored.
     */
    @Override
    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabel) throws OperatorException {
        int numberOfClasses = this.getNumberOfClasses(this.getLabel());
        for (Example example : exampleSet) {
            this.resetNetwork();
            if (predictedLabel.isNominal()) {
                this.predictNominal(example, predictedLabel, numberOfClasses);
            } else {
                example.setValue(predictedLabel, this.outputNodes[0].calculateValue(true, example));
            }
        }
        return exampleSet;
    }

    /** Sets the predicted class and per-class confidences of a single example. */
    private void predictNominal(Example example, Attribute predictedLabel, int numberOfClasses) {
        Attribute trainingLabel = this.getLabel();
        double[] confidences = new double[numberOfClasses];
        double total = 0.0;
        for (int c = 0; c < numberOfClasses; c++) {
            confidences[c] = this.outputNodes[c].calculateValue(true, example);
            total += confidences[c];
        }
        double maxConfidence = Double.NEGATIVE_INFINITY;
        int maxIndex = 0;
        for (int c = 0; c < numberOfClasses; c++) {
            // All outputs being exactly 0 would turn every confidence into NaN (0/0);
            // fall back to a uniform distribution in that case.
            confidences[c] = total != 0.0 ? confidences[c] / total : 1.0 / numberOfClasses;
            if (confidences[c] > maxConfidence) {
                maxIndex = c;
                maxConfidence = confidences[c];
            }
        }
        example.setValue(predictedLabel, predictedLabel.getMapping().mapString(trainingLabel.getMapping().mapIndex(maxIndex)));
        for (int c = 0; c < numberOfClasses; c++) {
            example.setConfidence(trainingLabel.getMapping().mapIndex(c), confidences[c]);
        }
    }

    /** Returns the names of the regular attributes of the training data. */
    public String[] getAttributeNames() {
        return this.attributeNames;
    }

    /** Returns the input nodes (the internal array, not a copy). */
    public InputNode[] getInputNodes() {
        return this.inputNodes;
    }

    /** Returns the output nodes (the internal array, not a copy). */
    public OutputNode[] getOutputNodes() {
        return this.outputNodes;
    }

    /** Returns all inner nodes (the internal array, not a copy). */
    public InnerNode[] getInnerNodes() {
        return this.innerNodes;
    }

    /** Returns the mapping size for nominal labels, otherwise 1. */
    private int getNumberOfClasses(Attribute label) {
        return label.isNominal() ? label.getMapping().size() : 1;
    }

    /** Appends a node to {@link #innerNodes}. */
    private void addNode(InnerNode node) {
        InnerNode[] newInnerNodes = new InnerNode[this.innerNodes.length + 1];
        System.arraycopy(this.innerNodes, 0, newInnerNodes, 0, this.innerNodes.length);
        newInnerNodes[newInnerNodes.length - 1] = node;
        this.innerNodes = newInnerNodes;
    }

    /** Calls {@code reset()} on every output node. */
    private void resetNetwork() {
        for (OutputNode outputNode : this.outputNodes) {
            outputNode.reset();
        }
    }

    /** Calls {@code update} on every output node with the given rate and momentum. */
    private void update(Example example, double learningRate, double momentum) {
        for (OutputNode outputNode : this.outputNodes) {
            outputNode.update(example, learningRate, momentum);
        }
    }

    /** Computes the value of every output node for the example. */
    private void calculateValue(Example example) {
        for (OutputNode outputNode : this.outputNodes) {
            outputNode.calculateValue(true, example);
        }
    }

    /**
     * Triggers the error calculation starting from the input nodes, then returns the sum of
     * squared errors of the output nodes.
     */
    private double calculateError(Example example) {
        for (InputNode inputNode : this.inputNodes) {
            inputNode.calculateError(true, example);
        }
        double totalError = 0.0;
        for (OutputNode outputNode : this.outputNodes) {
            double error = outputNode.calculateError(false, example);
            totalError += error * error;
        }
        return totalError;
    }

    /** Default hidden layer size: round((#attributes + #classes) / 2) + 1. */
    private int getDefaultLayerSize(ExampleSet exampleSet, Attribute label) {
        return (int)Math.round((double)(exampleSet.getAttributes().size() + this.getNumberOfClasses(label)) / 2.0) + 1;
    }

    /**
     * Creates one input node per regular attribute. With {@code normalize}, each node gets half
     * the attribute's value range as range and the range midpoint as offset; otherwise 1 and 0.
     */
    private void initInputLayer(ExampleSet exampleSet, boolean normalize) {
        this.inputNodes = new InputNode[exampleSet.getAttributes().size()];
        int a = 0;
        for (Attribute attribute : exampleSet.getAttributes()) {
            this.inputNodes[a] = new InputNode(attribute.getName());
            double range = 1.0;
            double offset = 0.0;
            if (normalize) {
                double min = exampleSet.getStatistics(attribute, "minimum");
                double max = exampleSet.getStatistics(attribute, "maximum");
                range = (max - min) / 2.0;
                offset = (max + min) / 2.0;
            }
            this.inputNodes[a].setAttribute(attribute, range, offset, normalize);
            ++a;
        }
    }

    /**
     * Creates one output node per class plus the inner node feeding it: sigmoid for nominal labels,
     * linear otherwise. The inner nodes are appended to {@link #innerNodes}, which must be empty
     * beforehand so that they occupy indices {@code 0 .. numberOfClasses-1}.
     */
    private void initOutputLayer(Attribute label, int numberOfClasses, double min, double max, RandomGenerator randomGenerator) {
        double range = (max - min) / 2.0;
        double offset = (max + min) / 2.0;
        this.outputNodes = new OutputNode[numberOfClasses];
        for (int o = 0; o < numberOfClasses; o++) {
            InnerNode actualOutput;
            if (label.isNominal()) {
                String classValue = label.getMapping().mapIndex(o);
                this.outputNodes[o] = new OutputNode(classValue, label, range, offset);
                this.outputNodes[o].setClassIndex(o);
                actualOutput = new InnerNode("Class '" + classValue + "'", OUTPUT_LAYER_INDEX, randomGenerator, SIGMOID_FUNCTION);
            } else {
                this.outputNodes[o] = new OutputNode(label.getName(), label, range, offset);
                actualOutput = new InnerNode("Regression", OUTPUT_LAYER_INDEX, randomGenerator, LINEAR_FUNCTION);
            }
            this.addNode(actualOutput);
            Node.connect(actualOutput, this.outputNodes[o]);
        }
    }

    /**
     * Creates the hidden layers and wires the network together: every input node feeds every node
     * of the first hidden layer, consecutive hidden layers are fully connected, and every node of
     * the last hidden layer feeds every output inner node.
     *
     * <p>Must run after {@link #initOutputLayer}: the output inner nodes are expected at
     * {@code innerNodes[0 .. numberOfClasses-1]}.
     */
    private void initHiddenLayers(ExampleSet exampleSet, Attribute label, List<String[]> hiddenLayerList, RandomGenerator randomGenerator) {
        int[] layerSizes;
        if (hiddenLayerList != null && !hiddenLayerList.isEmpty()) {
            layerSizes = new int[hiddenLayerList.size()];
            int index = 0;
            for (String[] nameSizePair : hiddenLayerList) {
                int layerSize = Integer.parseInt(nameSizePair[1].trim());
                // Non-positive sizes select the default, so every layer has at least one node.
                layerSizes[index++] = layerSize > 0 ? layerSize : this.getDefaultLayerSize(exampleSet, label);
            }
        } else {
            this.log("No hidden layers defined. Using default hidden layer.");
            layerSizes = new int[]{this.getDefaultLayerSize(exampleSet, label)};
        }
        int lastLayerSize = 0;
        for (int layerIndex = 0; layerIndex < layerSizes.length; layerIndex++) {
            int numberOfNodes = layerSizes[layerIndex];
            for (int nodeIndex = 0; nodeIndex < numberOfNodes; nodeIndex++) {
                InnerNode innerNode = new InnerNode("Node " + (nodeIndex + 1), layerIndex, randomGenerator, SIGMOID_FUNCTION);
                this.addNode(innerNode);
                if (layerIndex > 0) {
                    // The previous layer ends right before the nodes already added to this layer.
                    int previousLayerEnd = this.innerNodes.length - nodeIndex - 1;
                    for (int i = previousLayerEnd - lastLayerSize; i < previousLayerEnd; i++) {
                        Node.connect(this.innerNodes[i], innerNode);
                    }
                }
            }
            lastLayerSize = numberOfNodes;
        }
        int firstLayerSize = layerSizes[0];
        int numberOfAttributes = exampleSet.getAttributes().size();
        int numberOfClasses = this.getNumberOfClasses(label);
        // Inputs -> first hidden layer (stored directly after the output inner nodes).
        for (int i = 0; i < numberOfAttributes; i++) {
            for (int o = numberOfClasses; o < numberOfClasses + firstLayerSize; o++) {
                Node.connect(this.inputNodes[i], this.innerNodes[o]);
            }
        }
        // Last hidden layer (the tail of innerNodes) -> output inner nodes.
        for (int i = this.innerNodes.length - lastLayerSize; i < this.innerNodes.length; i++) {
            for (int o = 0; o < numberOfClasses; o++) {
                Node.connect(this.innerNodes[i], this.innerNodes[o]);
            }
        }
    }

    /** Returns a panel that switches between a graph view and the textual description. */
    @Override
    public Component getVisualizationComponent(IOContainer ioContainer) {
        JRadioSelectionPanel mainPanel = new JRadioSelectionPanel();
        ExtendedJScrollPane graphView = new ExtendedJScrollPane(new ImprovedNeuralNetVisualizer(this, this.attributeNames));
        Component textView = super.getVisualizationComponent(ioContainer);
        mainPanel.addComponent("Graph View", graphView, "Changes to a graphical view of this model.");
        mainPanel.addComponent("Text View", textView, "Changes to a textual description of this model.");
        return mainPanel;
    }

    /**
     * Lists every inner node with its input weights and threshold. Hidden nodes come first,
     * grouped under "Hidden n" headings; the output nodes follow under an "Output" heading.
     */
    @Override
    public String toString() {
        StringBuilder result = new StringBuilder();
        boolean first = true;
        int lastLayerIndex = 0;
        for (InnerNode innerNode : this.innerNodes) {
            int layerIndex = innerNode.getLayerIndex();
            if (layerIndex == OUTPUT_LAYER_INDEX) {
                continue;
            }
            if (first || layerIndex != lastLayerIndex) {
                if (!first) {
                    result.append(com.rapidminer.tools.Tools.getLineSeparators(2));
                }
                first = false;
                appendUnderlined(result, "Hidden " + (layerIndex + 1), '=');
                lastLayerIndex = layerIndex;
            }
            appendNodeDescription(result, innerNode);
        }
        boolean outputHeadingWritten = false;
        for (InnerNode innerNode : this.innerNodes) {
            if (innerNode.getLayerIndex() != OUTPUT_LAYER_INDEX) {
                continue;
            }
            if (!outputHeadingWritten) {
                result.append(com.rapidminer.tools.Tools.getLineSeparators(2));
                appendUnderlined(result, "Output", '=');
                outputHeadingWritten = true;
            }
            appendNodeDescription(result, innerNode);
        }
        return result.toString();
    }

    /** Appends {@code title}, a line break, a line of {@code underline} characters and a line break. */
    private static void appendUnderlined(StringBuilder result, String title, char underline) {
        String lineSeparator = com.rapidminer.tools.Tools.getLineSeparator();
        result.append(title).append(lineSeparator);
        for (int t = 0; t < title.length(); t++) {
            result.append(underline);
        }
        result.append(lineSeparator);
    }

    /**
     * Appends a node's name and activation type, then one "input: weight" line per input node,
     * then its threshold. Weight 0 is the threshold; the input weights start at index 1.
     */
    private static void appendNodeDescription(StringBuilder result, InnerNode innerNode) {
        String lineSeparator = com.rapidminer.tools.Tools.getLineSeparator();
        String nodeName = innerNode.getNodeName() + " (" + innerNode.getActivationFunction().getTypeName() + ")";
        result.append(lineSeparator);
        appendUnderlined(result, nodeName, '-');
        double[] weights = innerNode.getWeights();
        Node[] inputs = innerNode.getInputNodes();
        for (int i = 0; i < inputs.length; i++) {
            result.append(inputs[i].getNodeName()).append(": ").append(com.rapidminer.tools.Tools.formatNumber(weights[i + 1])).append(lineSeparator);
        }
        result.append("Threshold: ").append(com.rapidminer.tools.Tools.formatNumber(weights[0])).append(lineSeparator);
    }
}

