package com.rapidminer.operator.learner.functions.neuralnet;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Statistics;
import com.rapidminer.example.Tools;
import com.rapidminer.gui.renderer.AbstractGraphRenderer;
import com.rapidminer.gui.tools.ExtendedJScrollPane;
import com.rapidminer.gui.tools.JRadioSelectionPanel;
import com.rapidminer.operator.IOContainer;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.tools.RandomGenerator;
import java.awt.Component;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:WEB-INF/lib/rapidMiner-1.0.0.jar:com/rapidminer/operator/learner/functions/neuralnet/ImprovedNeuralNetModel.class */
public class ImprovedNeuralNetModel extends PredictionModel {
    private static final long serialVersionUID = -2206598483097451366L;
    private static final ActivationFunction SIGMOID_FUNCTION = new SigmoidFunction();
    private static final ActivationFunction LINEAR_FUNCTION = new LinearFunction();

    /**
     * Layer index given to the {@link InnerNode}s that feed the output nodes; hidden layers use
     * indices 0, 1, 2, ...
     */
    private static final int OUTPUT_LAYER_INDEX = -2;

    private String[] attributeNames;
    private InputNode[] inputNodes;
    /**
     * All inner nodes. The first {@code numberOfClasses} entries are the output-layer inner nodes
     * (added by {@link #initOutputLayer}); the hidden layers follow in order. {@link #initHiddenLayers}
     * relies on this layout.
     */
    private InnerNode[] innerNodes;
    private OutputNode[] outputNodes;

    /**
     * Creates an untrained model with empty layers for the regular attributes of the given set.
     * NOTE(review): the decompiler reports this constructor was originally {@code protected}.
     *
     * @param exampleSet the training example set (used for the attribute names)
     */
    public ImprovedNeuralNetModel(ExampleSet exampleSet) {
        super(exampleSet);
        this.inputNodes = new InputNode[0];
        this.innerNodes = new InnerNode[0];
        this.outputNodes = new OutputNode[0];
        this.attributeNames = Tools.getRegularAttributeNames(exampleSet);
    }

    /**
     * Builds and trains the network.
     *
     * <p>Each cycle presents every example once (optionally in a shuffled order), accumulates the
     * weighted error and updates the nodes. Training stops when the weighted mean error drops below
     * {@code maxError} or after {@code maxCycles} cycles. If the error becomes infinite or NaN, the
     * network is rebuilt from scratch and training restarts with half the learning rate.
     *
     * @param exampleSet      the training data
     * @param hiddenLayers    hidden layer definitions; element {@code [1]} of each entry is the layer
     *                        size (values &lt;= 0 select a default size). If empty, one default
     *                        hidden layer is used.
     * @param maxCycles       maximum number of passes over the data
     * @param maxError        training stops once the weighted mean error is below this value
     * @param learningRate    initial learning rate
     * @param momentum        momentum passed on to the node updates
     * @param decay           if true, the learning rate is divided by (cycle + 1)
     * @param shuffle         if true, examples are presented in a random (but fixed) order
     * @param normalize       if true, input attributes are normalized using their min/max statistics
     * @param randomGenerator source of randomness for weight initialization and shuffling
     * @throws RuntimeException if training diverges and the learning rate cannot be reduced further
     */
    public void train(ExampleSet exampleSet, List<String[]> hiddenLayers, int maxCycles, double maxError, double learningRate, double momentum, boolean decay, boolean shuffle, boolean normalize, RandomGenerator randomGenerator) {
        Attribute label = exampleSet.getAttributes().getLabel();
        int numberOfClasses = getNumberOfClasses(label);
        if (normalize) {
            exampleSet.recalculateAllAttributeStatistics();
        } else {
            exampleSet.recalculateAttributeStatistics(label);
        }
        double labelMin = exampleSet.getStatistics(label, Statistics.MINIMUM);
        double labelMax = exampleSet.getStatistics(label, Statistics.MAXIMUM);
        Attribute weight = exampleSet.getAttributes().getWeight();
        double totalWeight = getTotalWeight(exampleSet, weight);

        double currentLearningRate = learningRate;
        while (true) {
            // Every attempt must start from an empty inner layer: addNode only appends, and
            // initHiddenLayers expects the output-layer inner nodes at indices 0..numberOfClasses-1.
            this.innerNodes = new InnerNode[0];
            initInputLayer(exampleSet, normalize);
            initOutputLayer(label, numberOfClasses, labelMin, labelMax, randomGenerator);
            initHiddenLayers(exampleSet, label, hiddenLayers, randomGenerator);
            int[] order = shuffle ? createShuffledOrder(exampleSet.size(), randomGenerator) : null;

            if (runTrainingCycles(exampleSet, weight, totalWeight, order, numberOfClasses, maxCycles, maxError, currentLearningRate, momentum, decay)) {
                return;
            }
            // Diverged: rebuild the network and retry with a smaller learning rate.
            if (com.rapidminer.tools.Tools.isLessEqual(currentLearningRate, 0.0d)) {
                throw new RuntimeException("Cannot reset network to a smaller learning rate.");
            }
            currentLearningRate /= 2.0d;
        }
    }

    /**
     * Runs up to {@code maxCycles} training passes over the current network.
     *
     * @return {@code true} if training finished (error below {@code maxError} or all cycles done),
     *         {@code false} if the error became infinite or NaN
     */
    private boolean runTrainingCycles(ExampleSet exampleSet, Attribute weight, double totalWeight, int[] order, int numberOfClasses, int maxCycles, double maxError, double learningRate, double momentum, boolean decay) {
        int size = exampleSet.size();
        for (int cycle = 0; cycle < maxCycles; cycle++) {
            double error = 0.0d;
            for (int position = 0; position < size; position++) {
                Example example = exampleSet.getExample(order != null ? order[position] : position);
                resetNetwork();
                calculateValue(example);
                double exampleWeight = weight != null ? example.getValue(weight) : 1.0d;
                double effectiveRate = learningRate * exampleWeight;
                if (decay) {
                    effectiveRate /= cycle + 1;
                }
                error += (calculateError(example) / numberOfClasses) * exampleWeight;
                update(example, effectiveRate, momentum);
            }
            error /= totalWeight;
            if (error < maxError) {
                return true;
            }
            if (Double.isInfinite(error) || Double.isNaN(error)) {
                return false;
            }
        }
        return true;
    }

    /** Sums the example weights; each example counts as 1.0 when there is no weight attribute. */
    private static double getTotalWeight(ExampleSet exampleSet, Attribute weight) {
        double total = 0.0d;
        for (Example example : exampleSet) {
            total += weight != null ? example.getValue(weight) : 1.0d;
        }
        return total;
    }

    /** Returns the indices {@code 0..size-1} in an order shuffled with the given generator. */
    private static int[] createShuffledOrder(int size, RandomGenerator randomGenerator) {
        List<Integer> indices = new ArrayList<Integer>(size);
        for (int index = 0; index < size; index++) {
            indices.add(Integer.valueOf(index));
        }
        Collections.shuffle(indices, randomGenerator);
        int[] order = new int[size];
        for (int position = 0; position < size; position++) {
            order[position] = indices.get(position).intValue();
        }
        return order;
    }

    /**
     * Applies the network to every example.
     *
     * <p>For a nominal prediction attribute, the output values are normalized to sum to one and
     * stored as confidences. The class with the highest confidence (the first on ties) becomes the
     * prediction. Otherwise the value of the single output node is stored.
     */
    @Override // com.rapidminer.operator.learner.PredictionModel
    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute attribute) throws OperatorException {
        if (attribute.isNominal()) {
            Attribute label = getLabel();
            int numberOfClasses = getNumberOfClasses(label);
            double[] confidences = new double[numberOfClasses];
            for (Example example : exampleSet) {
                resetNetwork();
                double sum = 0.0d;
                for (int c = 0; c < numberOfClasses; c++) {
                    confidences[c] = this.outputNodes[c].calculateValue(true, example);
                    sum += confidences[c];
                }
                int bestClass = 0;
                double bestConfidence = Double.NEGATIVE_INFINITY;
                for (int c = 0; c < numberOfClasses; c++) {
                    confidences[c] /= sum;
                    if (confidences[c] > bestConfidence) {
                        bestClass = c;
                        bestConfidence = confidences[c];
                    }
                }
                example.setValue(attribute, attribute.getMapping().mapString(label.getMapping().mapIndex(bestClass)));
                for (int c = 0; c < numberOfClasses; c++) {
                    example.setConfidence(label.getMapping().mapIndex(c), confidences[c]);
                }
            }
        } else {
            for (Example example : exampleSet) {
                resetNetwork();
                example.setValue(attribute, this.outputNodes[0].calculateValue(true, example));
            }
        }
        return exampleSet;
    }

    /** Returns the names of the regular attributes of the training set. */
    public String[] getAttributeNames() {
        return this.attributeNames;
    }

    /** Returns the input layer (one node per regular attribute). */
    public InputNode[] getInputNodes() {
        return this.inputNodes;
    }

    /** Returns the output nodes (one per class, or one for a numerical label). */
    public OutputNode[] getOutputNodes() {
        return this.outputNodes;
    }

    /** Returns all inner nodes: output-layer inner nodes first, then the hidden layers. */
    public InnerNode[] getInnerNodes() {
        return this.innerNodes;
    }

    /** Returns the mapping size for a nominal attribute, 1 otherwise. */
    private int getNumberOfClasses(Attribute attribute) {
        return attribute.isNominal() ? attribute.getMapping().size() : 1;
    }

    /** Appends a node to {@link #innerNodes}. */
    private void addNode(InnerNode innerNode) {
        InnerNode[] grown = new InnerNode[this.innerNodes.length + 1];
        System.arraycopy(this.innerNodes, 0, grown, 0, this.innerNodes.length);
        grown[grown.length - 1] = innerNode;
        this.innerNodes = grown;
    }

    /** Calls {@code reset()} on every output node. */
    private void resetNetwork() {
        for (OutputNode node : this.outputNodes) {
            node.reset();
        }
    }

    /** Calls {@code update} on every output node with the given rate and momentum. */
    private void update(Example example, double learningRate, double momentum) {
        for (OutputNode node : this.outputNodes) {
            node.update(example, learningRate, momentum);
        }
    }

    /** Evaluates every output node for the given example. */
    private void calculateValue(Example example) {
        for (OutputNode node : this.outputNodes) {
            node.calculateValue(true, example);
        }
    }

    /**
     * Triggers the error calculation from the input nodes, then returns the sum of squared errors
     * reported by the output nodes.
     */
    private double calculateError(Example example) {
        for (InputNode node : this.inputNodes) {
            node.calculateError(true, example);
        }
        double squaredErrorSum = 0.0d;
        for (OutputNode node : this.outputNodes) {
            double error = node.calculateError(false, example);
            squaredErrorSum += error * error;
        }
        return squaredErrorSum;
    }

    /** Default hidden layer size: round((#attributes + #classes) / 2) + 1. */
    private int getDefaultLayerSize(ExampleSet exampleSet, Attribute label) {
        return ((int) Math.round((exampleSet.getAttributes().size() + getNumberOfClasses(label)) / 2.0d)) + 1;
    }

    /**
     * Creates one input node per regular attribute. When {@code normalize} is set, each node gets
     * half the value range as range and the range midpoint as base; otherwise range 1 and base 0.
     */
    private void initInputLayer(ExampleSet exampleSet, boolean normalize) {
        this.inputNodes = new InputNode[exampleSet.getAttributes().size()];
        int index = 0;
        for (Attribute attribute : exampleSet.getAttributes()) {
            InputNode node = new InputNode(attribute.getName());
            double range = 1.0d;
            double base = 0.0d;
            if (normalize) {
                double min = exampleSet.getStatistics(attribute, Statistics.MINIMUM);
                double max = exampleSet.getStatistics(attribute, Statistics.MAXIMUM);
                range = (max - min) / 2.0d;
                base = (max + min) / 2.0d;
            }
            node.setAttribute(attribute, range, base, normalize);
            this.inputNodes[index++] = node;
        }
    }

    /**
     * Creates the output nodes and, for each, an inner node feeding it. Nominal labels get one
     * sigmoid inner node per class; numerical labels get a single linear one. The inner nodes are
     * appended to {@link #innerNodes} with layer index {@link #OUTPUT_LAYER_INDEX}.
     */
    private void initOutputLayer(Attribute label, int numberOfClasses, double labelMin, double labelMax, RandomGenerator randomGenerator) {
        double labelRange = (labelMax - labelMin) / 2.0d;
        double labelBase = (labelMax + labelMin) / 2.0d;
        this.outputNodes = new OutputNode[numberOfClasses];
        for (int c = 0; c < numberOfClasses; c++) {
            InnerNode innerNode;
            if (label.isNominal()) {
                this.outputNodes[c] = new OutputNode(label.getMapping().mapIndex(c), label, labelRange, labelBase);
                this.outputNodes[c].setClassIndex(c);
                innerNode = new InnerNode("Class '" + label.getMapping().mapIndex(c) + "'", OUTPUT_LAYER_INDEX, randomGenerator, SIGMOID_FUNCTION);
            } else {
                this.outputNodes[c] = new OutputNode(label.getName(), label, labelRange, labelBase);
                innerNode = new InnerNode("Regression", OUTPUT_LAYER_INDEX, randomGenerator, LINEAR_FUNCTION);
            }
            addNode(innerNode);
            Node.connect(innerNode, this.outputNodes[c]);
        }
    }

    /**
     * Creates the hidden layers and connects them.
     *
     * <p>Inputs connect to the first hidden layer, each hidden layer to the next, and the last hidden
     * layer to the output-layer inner nodes. Must run after {@link #initInputLayer} and
     * {@link #initOutputLayer}, because the output-layer inner nodes are expected at indices
     * {@code 0..numberOfClasses-1} of {@link #innerNodes}.
     */
    private void initHiddenLayers(ExampleSet exampleSet, Attribute label, List<String[]> hiddenLayers, RandomGenerator randomGenerator) {
        int[] layerSizes;
        if (!hiddenLayers.isEmpty()) {
            layerSizes = new int[hiddenLayers.size()];
            int layer = 0;
            for (String[] definition : hiddenLayers) {
                int size = Integer.parseInt(definition[1]);
                layerSizes[layer++] = size > 0 ? size : getDefaultLayerSize(exampleSet, label);
            }
        } else {
            log("No hidden layers defined. Using default hidden layer.");
            layerSizes = new int[] { getDefaultLayerSize(exampleSet, label) };
        }

        int previousLayerSize = 0;
        for (int layer = 0; layer < layerSizes.length; layer++) {
            int layerSize = layerSizes[layer];
            // The previous layer occupies [layerStart - previousLayerSize, layerStart).
            int layerStart = this.innerNodes.length;
            for (int n = 0; n < layerSize; n++) {
                InnerNode node = new InnerNode("Node " + (n + 1), layer, randomGenerator, SIGMOID_FUNCTION);
                addNode(node);
                if (layer > 0) {
                    for (int p = layerStart - previousLayerSize; p < layerStart; p++) {
                        Node.connect(this.innerNodes[p], node);
                    }
                }
            }
            previousLayerSize = layerSize;
        }

        int firstLayerSize = layerSizes[0];
        int numberOfInputs = exampleSet.getAttributes().size();
        int numberOfClasses = getNumberOfClasses(label);
        if (firstLayerSize == 0) {
            // Defensive fallback (sizes <= 0 are replaced above): wire inputs straight to the outputs.
            for (int in = 0; in < numberOfInputs; in++) {
                for (int c = 0; c < numberOfClasses; c++) {
                    Node.connect(this.inputNodes[in], this.innerNodes[c]);
                }
            }
            return;
        }
        // Inputs -> first hidden layer (stored directly after the output-layer inner nodes).
        for (int in = 0; in < numberOfInputs; in++) {
            for (int h = numberOfClasses; h < numberOfClasses + firstLayerSize; h++) {
                Node.connect(this.inputNodes[in], this.innerNodes[h]);
            }
        }
        // Last hidden layer (the tail of innerNodes) -> output-layer inner nodes.
        for (int h = this.innerNodes.length - previousLayerSize; h < this.innerNodes.length; h++) {
            for (int c = 0; c < numberOfClasses; c++) {
                Node.connect(this.innerNodes[h], this.innerNodes[c]);
            }
        }
    }

    /** Offers a graphical network view and the textual description as selectable panels. */
    @Override // com.rapidminer.operator.ResultObjectAdapter, com.rapidminer.operator.ResultObject
    public Component getVisualizationComponent(IOContainer iOContainer) {
        JRadioSelectionPanel selectionPanel = new JRadioSelectionPanel();
        ExtendedJScrollPane graphView = new ExtendedJScrollPane(new ImprovedNeuralNetVisualizer(this, this.attributeNames));
        Component textView = super.getVisualizationComponent(iOContainer);
        selectionPanel.addComponent(AbstractGraphRenderer.RENDERER_NAME, graphView, "Changes to a graphical view of this model.");
        selectionPanel.addComponent("Text View", textView, "Changes to a textual description of this model.");
        return selectionPanel;
    }

    /**
     * Describes the network as text. Hidden layers come first, each under a "Hidden k" heading,
     * followed by an "Output" section. Every inner node is listed with its activation function,
     * incoming weights and threshold.
     */
    @Override // com.rapidminer.operator.learner.PredictionModel, com.rapidminer.report.Readable
    public String toString() {
        StringBuilder result = new StringBuilder();
        boolean firstHiddenLayer = true;
        int lastLayerIndex = -99;
        for (InnerNode node : this.innerNodes) {
            int layerIndex = node.getLayerIndex();
            if (layerIndex == OUTPUT_LAYER_INDEX) {
                continue;
            }
            if (firstHiddenLayer || layerIndex != lastLayerIndex) {
                if (!firstHiddenLayer) {
                    result.append(com.rapidminer.tools.Tools.getLineSeparators(2));
                }
                firstHiddenLayer = false;
                appendHeading(result, "Hidden " + (layerIndex + 1), '=');
                lastLayerIndex = layerIndex;
            }
            appendNodeDescription(result, node);
        }
        boolean firstOutputNode = true;
        for (InnerNode node : this.innerNodes) {
            if (node.getLayerIndex() != OUTPUT_LAYER_INDEX) {
                continue;
            }
            if (firstOutputNode) {
                result.append(com.rapidminer.tools.Tools.getLineSeparators(2));
                appendHeading(result, "Output", '=');
                firstOutputNode = false;
            }
            appendNodeDescription(result, node);
        }
        return result.toString();
    }

    /** Appends {@code title}, a line break, an underline of equal length and another line break. */
    private static void appendHeading(StringBuilder target, String title, char underline) {
        target.append(title).append(com.rapidminer.tools.Tools.getLineSeparator());
        for (int i = 0; i < title.length(); i++) {
            target.append(underline);
        }
        target.append(com.rapidminer.tools.Tools.getLineSeparator());
    }

    /**
     * Appends a node's heading, one "input: weight" line per input, and its threshold. Weight index 0
     * is the threshold; index k+1 belongs to input k.
     */
    private static void appendNodeDescription(StringBuilder target, InnerNode node) {
        String lineSeparator = com.rapidminer.tools.Tools.getLineSeparator();
        target.append(lineSeparator);
        appendHeading(target, node.getNodeName() + " (" + node.getActivationFunction().getTypeName() + ")", '-');
        double[] weights = node.getWeights();
        Node[] inputs = node.getInputNodes();
        for (int i = 0; i < inputs.length; i++) {
            target.append(inputs[i].getNodeName() + ": " + com.rapidminer.tools.Tools.formatNumber(weights[i + 1]) + lineSeparator);
        }
        target.append("Threshold: " + com.rapidminer.tools.Tools.formatNumber(weights[0]) + lineSeparator);
    }
}
