package com.rapidminer.operator.learner.functions;

import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.AbstractLearner;
import com.rapidminer.operator.learner.LearnerCapability;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.tools.RandomGenerator;
import com.rapidminer.tools.math.optimization.ec.es.ESOptimization;
import java.util.List;

/* loaded from: input_file:com/rapidminer/operator/learner/functions/LogisticRegression.class */
/**
 * Learner that fits a logistic regression model by running a {@link LogisticRegressionOptimization}
 * configured from this operator's evolution-strategy parameters.
 *
 * <p>Supports numerical attributes, a binominal label and weighted examples. If
 * {@value #PARAMETER_RETURN_PERFORMANCE} is enabled, the performance reported by the optimization
 * of the last successful {@link #learn(ExampleSet)} call is exposed through
 * {@link #getEstimatedPerformance()}.
 */
public class LogisticRegression extends AbstractLearner {
    public static final String PARAMETER_ADD_INTERCEPT = "add_intercept";
    public static final String PARAMETER_RETURN_PERFORMANCE = "return_model_performance";
    public static final String PARAMETER_START_POPULATION_TYPE = "start_population_type";
    public static final String PARAMETER_MAX_GENERATIONS = "max_generations";
    public static final String PARAMETER_GENERATIONS_WITHOUT_IMPROVAL = "generations_without_improval";
    public static final String PARAMETER_POPULATION_SIZE = "population_size";
    public static final String PARAMETER_TOURNAMENT_FRACTION = "tournament_fraction";
    public static final String PARAMETER_KEEP_BEST = "keep_best";
    public static final String PARAMETER_MUTATION_TYPE = "mutation_type";
    public static final String PARAMETER_SELECTION_TYPE = "selection_type";
    public static final String PARAMETER_CROSSOVER_PROB = "crossover_prob";
    public static final String PARAMETER_LOCAL_RANDOM_SEED = "local_random_seed";
    public static final String PARAMETER_SHOW_CONVERGENCE_PLOT = "show_convergence_plot";

    /** Performance reported by the optimization of the most recent successful {@link #learn} call; may be null. */
    private PerformanceVector estimatedPerformance;

    public LogisticRegression(OperatorDescription operatorDescription) {
        super(operatorDescription);
    }

    /**
     * Trains a logistic regression model on the given examples.
     *
     * <p>The optimizer is built from this operator's parameters. Its performance is stored so that
     * {@link #getEstimatedPerformance()} can return it afterwards.
     *
     * @param exampleSet the training data
     * @return the trained model
     * @throws OperatorException if a parameter cannot be read or the optimization fails
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        // Drop the result of any previous run so a failed run never exposes stale performance.
        this.estimatedPerformance = null;

        // Argument order must match the LogisticRegressionOptimization constructor.
        LogisticRegressionOptimization optimization = new LogisticRegressionOptimization(
                exampleSet,
                getParameterAsBoolean(PARAMETER_ADD_INTERCEPT),
                getParameterAsInt(PARAMETER_START_POPULATION_TYPE),
                getParameterAsInt(PARAMETER_MAX_GENERATIONS),
                getParameterAsInt(PARAMETER_GENERATIONS_WITHOUT_IMPROVAL),
                getParameterAsInt(PARAMETER_POPULATION_SIZE),
                getParameterAsInt(PARAMETER_SELECTION_TYPE),
                getParameterAsDouble(PARAMETER_TOURNAMENT_FRACTION),
                getParameterAsBoolean(PARAMETER_KEEP_BEST),
                getParameterAsInt(PARAMETER_MUTATION_TYPE),
                getParameterAsDouble(PARAMETER_CROSSOVER_PROB),
                getParameterAsBoolean(PARAMETER_SHOW_CONVERGENCE_PLOT),
                RandomGenerator.getRandomGenerator(getParameterAsInt(PARAMETER_LOCAL_RANDOM_SEED)),
                this);
        LogisticRegressionModel model = optimization.train();
        this.estimatedPerformance = optimization.getPerformance();
        return model;
    }

    /** Returns whether the user asked for the optimization performance to be delivered. */
    @Override
    public boolean shouldEstimatePerformance() {
        return getParameterAsBoolean(PARAMETER_RETURN_PERFORMANCE);
    }

    /**
     * Returns the performance reported by the last optimization run.
     *
     * @return the stored performance vector
     * @throws OperatorException a {@link UserError} (code 912) if performance delivery is disabled
     *     or no performance is available
     */
    @Override
    public PerformanceVector getEstimatedPerformance() throws OperatorException {
        if (!getParameterAsBoolean(PARAMETER_RETURN_PERFORMANCE) || this.estimatedPerformance == null) {
            throw new UserError(this, 912, getName(), "could not deliver optimization performance.");
        }
        return this.estimatedPerformance;
    }

    /** Accepts numerical attributes, a binominal class and weighted examples only. */
    @Override
    public boolean supportsCapability(LearnerCapability learnerCapability) {
        return learnerCapability == LearnerCapability.NUMERICAL_ATTRIBUTES
                || learnerCapability == LearnerCapability.BINOMINAL_CLASS
                || learnerCapability == LearnerCapability.WEIGHTED_EXAMPLES;
    }

    /**
     * Declares this operator's parameters on top of those of the superclass.
     *
     * <p>The parameter keys are the {@code PARAMETER_*} constants, so lookups in {@link #learn} and
     * the declarations here cannot drift apart.
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();
        parameterTypes.add(new ParameterTypeBoolean(PARAMETER_ADD_INTERCEPT, "Determines whether to include an intercept.", true));
        parameterTypes.add(new ParameterTypeBoolean(PARAMETER_RETURN_PERFORMANCE, "Determines whether to return the performance.", false));
        parameterTypes.add(new ParameterTypeCategory(PARAMETER_START_POPULATION_TYPE, "The type of start population initialization.", ESOptimization.POPULATION_INIT_TYPES, 0));
        parameterTypes.add(new ParameterTypeInt(PARAMETER_MAX_GENERATIONS, "Stop after this many evaluations", 1, Integer.MAX_VALUE, 10000));
        // The help text previously referred to a nonexistent "max_iterations" parameter.
        parameterTypes.add(new ParameterTypeInt(PARAMETER_GENERATIONS_WITHOUT_IMPROVAL, "Stop after this number of generations without improvement (-1: optimize until max_generations).", -1, Integer.MAX_VALUE, 300));
        parameterTypes.add(new ParameterTypeInt(PARAMETER_POPULATION_SIZE, "The population size (-1: number of examples)", -1, Integer.MAX_VALUE, 3));
        parameterTypes.add(new ParameterTypeDouble(PARAMETER_TOURNAMENT_FRACTION, "The fraction of the population used for tournament selection.", 0.0d, Double.POSITIVE_INFINITY, 0.75d));
        parameterTypes.add(new ParameterTypeBoolean(PARAMETER_KEEP_BEST, "Indicates if the best individual should survive (elitist selection).", true));
        parameterTypes.add(new ParameterTypeCategory(PARAMETER_MUTATION_TYPE, "The type of the mutation operator.", ESOptimization.MUTATION_TYPES, 1));
        parameterTypes.add(new ParameterTypeCategory(PARAMETER_SELECTION_TYPE, "The type of the selection operator.", ESOptimization.SELECTION_TYPES, 6));
        parameterTypes.add(new ParameterTypeDouble(PARAMETER_CROSSOVER_PROB, "The probability for crossovers.", 0.0d, 1.0d, 1.0d));
        parameterTypes.add(new ParameterTypeInt(PARAMETER_LOCAL_RANDOM_SEED, "Use the given random seed instead of global random numbers (-1: use global).", -1, Integer.MAX_VALUE, -1));
        parameterTypes.add(new ParameterTypeBoolean(PARAMETER_SHOW_CONVERGENCE_PLOT, "Indicates if a dialog with a convergence plot should be drawn.", false));
        return parameterTypes;
    }
}
