/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.classification;

import java.util.HashMap;
import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.classification.LogisticRegressionSuite;
import org.apache.spark.ml.classification.RandomForestClassificationModel;
import org.apache.spark.ml.classification.RandomForestClassifier;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.junit.Assert;
import org.junit.Test;

/**
 * Java-API smoke test for {@link RandomForestClassifier}.
 *
 * <p>Exercises the fluent setters, the validation performed by
 * {@link RandomForestClassifier#setFeatureSubsetStrategy(String)}, and the accessors of the
 * fitted {@link RandomForestClassificationModel}. Prediction quality is not checked here.
 */
public class JavaRandomForestClassifierSuite extends SharedSparkSession {

    /**
     * Fits a small random forest on generated logistic data after driving every setter,
     * then checks basic structural properties of the resulting model.
     */
    @Test
    public void runDT() {
        int nPoints = 20;
        double A = 2.0;
        double B = -1.5;
        int numTrees = 3;

        JavaRDD<LabeledPoint> data = jsc.parallelize(
            LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
        // Empty map: no feature is declared categorical.
        HashMap<Integer, Integer> categoricalFeatures = new HashMap<>();
        Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);

        // setMaxDepth appears twice; presumably intentional, to check that a param can be
        // overwritten through the fluent chain.
        RandomForestClassifier rf = new RandomForestClassifier()
            .setMaxDepth(2)
            .setMaxBins(10)
            .setMinInstancesPerNode(5)
            .setMinInfoGain(0.0)
            .setMaxMemoryInMB(256)
            .setCacheNodeIds(false)
            .setCheckpointInterval(10)
            .setSubsamplingRate(1.0)
            .setSeed(1234L)
            .setNumTrees(numTrees)
            .setMaxDepth(2);

        for (String impurity : RandomForestClassifier.supportedImpurities()) {
            rf.setImpurity(impurity);
        }
        for (String featureSubsetStrategy : RandomForestClassifier.supportedFeatureSubsetStrategies()) {
            rf.setFeatureSubsetStrategy(featureSubsetStrategy);
        }

        // Fractional strategies are accepted.
        String[] realStrategies = {".1", ".10", "0.10", "0.1", "0.9", "1.0"};
        for (String strategy : realStrategies) {
            rf.setFeatureSubsetStrategy(strategy);
        }

        // Integer strategies are accepted.
        String[] integerStrategies = {"1", "10", "100", "1000", "10000"};
        for (String strategy : integerStrategies) {
            rf.setFeatureSubsetStrategy(strategy);
        }

        // Negative, zero, or greater-than-one fractions, and the integer 0, must be rejected.
        String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"};
        for (String strategy : invalidStrategies) {
            Assert.assertThrows(IllegalArgumentException.class,
                () -> rf.setFeatureSubsetStrategy(strategy));
        }
        // NOTE(review): presumably the rejected values leave the param untouched, so fitting
        // uses the last accepted strategy ("10000") — confirm against Param validation.

        RandomForestClassificationModel model = rf.fit(dataFrame);

        model.transform(dataFrame);
        model.toDebugString();

        Assert.assertEquals(numTrees, model.trees().length);
        Assert.assertEquals(numTrees, model.treeWeights().length);
        Assert.assertTrue(model.totalNumNodes() > 0);

        Vector importances = model.featureImportances();
        Assert.assertNotNull(importances);
    }
}

