RandomForest Classification Example using Spark MLlib

RandomForest Classification Example using Spark MLlib – In this tutorial, we shall see how to train a model using a RandomForest classifier, and then use the generated model on test data to predict categories and calculate the Test Error and Accuracy of the model.

Training using Random Forest classifier

Spark MLlib understands only numbers. So, the training data should be prepared in a way that MLlib understands. Preparing the training data is the most important step, as it decides the accuracy of the model. It includes the following:

  1. Identify the categories, and index them.
  2. Identify the features, and index them.
  3. Transform the experiments/observations/examples using the indexes of categories and features.

Note: Feature values may be discrete or continuous. Comments have been provided in the program to mark some of the features as discrete and others as continuous. With this as reference, features can be configured as per your requirement. An indexing sketch follows below.
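
For example, if a raw observation were [category="spam", color="blue"], both the category and the feature value would have to be replaced by numeric indexes before MLlib can consume them. Below is a minimal sketch of such an indexing step; the raw values and the mappings are hypothetical illustrations, not part of the attached example.

import java.util.HashMap;
import java.util.Map;

public class CategoryIndexer {
    public static void main(String[] args) {
        // hypothetical mapping of raw categories (labels) to numeric indexes
        Map<String, Integer> categoryIndex = new HashMap<>();
        categoryIndex.put("ham", 0);
        categoryIndex.put("spam", 1);

        // hypothetical mapping of a discrete feature's raw values to numeric indexes
        Map<String, Integer> colorIndex = new HashMap<>();
        colorIndex.put("red", 0);
        colorIndex.put("green", 1);
        colorIndex.put("blue", 2);

        // transform one raw observation into the numeric row format shown below
        String rawCategory = "spam";
        String rawColor = "blue";
        System.out.println(categoryIndex.get(rawCategory) + " 1:" + colorIndex.get(rawColor));
        // prints: 1 1:2
    }
}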

Download the source code of this example here: RandomForestExampleAttachment. For setting up a Java project to work with Spark MLlib, please refer to Create Java Project with Apache Spark.

Sample Training Data for Random Forest

Below is a sample of the transformed data, ready to be fed to the RandomForest to train on. Each row represents an experiment/observation/example, and the format of each row is [category feature1:value feature2:value ..]

Training data: trainingValues.txt

0 1:1 2:1 3:1 4:1 5:1 6:1
0 1:1 2:1 3:1 4:1 5:1 6:1
1 1:2 2:1 3:5 4:1 5:1 6:1
0 1:1 2:1 3:1 4:1 5:1 6:1
1 1:1 2:3 3:1 4:1 5:1 6:1
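
This is the standard LibSVM format. MLUtils.loadLibSVMFile converts each row into a LabeledPoint with 0-based feature indices, which is why the model output later refers to feature 0 through feature 5 while the file uses 1: through 6:. As an illustration, the first row above is equivalent to constructing the point manually:

import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;

public class LibSvmRowExample {
    public static void main(String[] args) {
        // equivalent of the row "0 1:1 2:1 3:1 4:1 5:1 6:1":
        // label 0.0; file indices 1..6 become vector indices 0..5
        LabeledPoint p = new LabeledPoint(0.0,
                Vectors.dense(1.0, 1.0, 1.0, 1.0, 1.0, 1.0));
        System.out.println(p); // (0.0,[1.0,1.0,1.0,1.0,1.0,1.0])
    }
}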

Below is the Java class, RandomForestTrainerExample.java, that trains a model and saves it to the local file system.

Trainer Class: RandomForestTrainerExample.java

package com.tut;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;

import java.util.HashMap;

import org.apache.commons.io.FileUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.RandomForest;
import org.apache.spark.mllib.tree.model.RandomForestModel;
import org.apache.spark.mllib.util.MLUtils;

/** RandomForest Classification Example using Spark MLlib
 * @author tutorialkart.com
 */
public class RandomForestTrainerExample {
	
public static void main(String[] args) {
	// hadoop home dir [path to bin folder containing winutils.exe]
	System.setProperty("hadoop.home.dir", "D:\\Arjun\\ml\\hadoop\\");
	  
    // Configuring spark
    SparkConf sparkConf = new SparkConf().setAppName("RandomForestExample")
    		.setMaster("local[2]")
    		.set("spark.executor.memory","3g")
    		.set("spark.driver.memory", "3g");
    
    // initializing the spark context
    JavaSparkContext jsc = new JavaSparkContext(sparkConf);
    
    // Load and parse the data file.
    String datapath = "data"+File.separator+"trainingValues.txt";
    JavaRDD<LabeledPoint> trainingData;
	try {
		trainingData = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
	} catch (Exception e1) {
		System.out.println("No training data available.");
		e1.printStackTrace();
		return;
	}
	
    // Configuration/Hyper-parameters to train the random forest model
    Integer numClasses = 3; // number of distinct categories (labels) in the training data: 0, 1 and 2
    // categoricalFeaturesInfo maps a feature's index to its number of distinct values (arity);
    // features absent from the map (an empty map would mean all of them) are treated as continuous
    HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>(){{
    		put(0,3);	// feature 0 is considered discrete, with values from 0 to 2
    		put(1,7);	// feature 1 is considered discrete, with values from 0 to 6
    		put(2,10);	// feature 2 is considered discrete, with values from 0 to 9
    		// feature 3 is considered continuous valued
    		put(4,10);	// feature 4 is considered discrete, with values from 0 to 9
    		// feature 5 is considered continuous valued
    }};
    Integer numTrees = 3; // number of decision trees to be included in the Random Forest
    String featureSubsetStrategy = "auto"; // let the algorithm choose the number of features to consider at each split
    String impurity = "gini"; // impurity measure used to choose splits; "gini" and "entropy" are supported for classification
    Integer maxDepth = 30; // maximum depth to which each decision tree may grow
    Integer maxBins = 10; // maximum number of bins used when discretizing continuous features; must be at least as large as the largest categorical arity
    Integer seed = 12345; // fixed seed so that the randomization inside the classifier is reproducible across runs

    // training the classifier with all the hyper-parameters defined above
    final RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses,
      categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins,
      seed);

    // Delete the old model if already present, and save the new model
	try {
		FileUtils.forceDelete(new File("RandForestClsfrMdl"));
		System.out.println("\nDeleting old model completed.");
	} catch (FileNotFoundException e1) {
		// no old model directory present; nothing to delete
	} catch (IOException e) {
		e.printStackTrace();
	}
	// saving the random forest model that is generated
    model.save(jsc.sc(), "RandForestClsfrMdl"+File.separator+"model");
    System.out.println("\nRandForestClsfrMdl/model has been created and successfully saved.");
    
    // printing the random forest model (collection of decision trees)
    System.out.println(model.toDebugString());
    
    jsc.stop();
    
  }
}

When the above Java class is run, a model is generated with three decision trees, which are shown in the output below.

Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".                
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.
                                                                                
RandForestClsfrMdl/model has been created and successfully saved.
TreeEnsembleModel classifier with 3 trees

  Tree 0:
    If (feature 5 <= 6.0)
     If (feature 0 in {1.0})
      If (feature 1 in {3.0})
       Predict: 1.0
      Else (feature 1 not in {3.0})
       If (feature 5 <= 2.0)
        If (feature 2 in {1.0})
         Predict: 0.0
        Else (feature 2 not in {1.0})
         Predict: 1.0
       Else (feature 5 > 2.0)
        Predict: 0.0
     Else (feature 0 not in {1.0})
      Predict: 1.0
    Else (feature 5 > 6.0)
     Predict: 2.0
  Tree 1:
    If (feature 5 <= 6.0)
     If (feature 0 in {1.0})
      If (feature 2 in {6.0})
       Predict: 1.0
      Else (feature 2 not in {6.0})
       If (feature 4 in {3.0})
        Predict: 0.0
       Else (feature 4 not in {3.0})
        Predict: 0.0
     Else (feature 0 not in {1.0})
      Predict: 1.0
    Else (feature 5 > 6.0)
     If (feature 3 <= 1.0)
      Predict: 0.0
     Else (feature 3 > 1.0)
      Predict: 2.0
  Tree 2:
    If (feature 3 <= 1.0)
     If (feature 2 in {5.0,6.0})
      Predict: 1.0
     Else (feature 2 not in {5.0,6.0})
      If (feature 0 in {1.0})
       If (feature 1 in {1.0})
        Predict: 0.0
       Else (feature 1 not in {1.0})
        If (feature 1 in {3.0})
         Predict: 1.0
        Else (feature 1 not in {3.0})
         Predict: 0.0
      Else (feature 0 not in {1.0})
       Predict: 1.0
    Else (feature 3 > 1.0)
     Predict: 2.0

From the above random forest, the following observations can be made:

  1. Features 0, 1, 2 and 4 are treated as discrete, which shows up in conditions like [feature 2 not in {5.0,6.0}].
  2. Features 3 and 5 are treated as continuous, which shows up in conditions like [feature 5 > 6.0].
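
Besides toDebugString(), the model object can be queried for a quick summary of its size. A small sketch, to be placed after training in RandomForestTrainerExample.java:

// quick summary of the trained ensemble, without printing whole trees
System.out.println("Number of trees in forest : " + model.numTrees());
System.out.println("Total nodes in forest     : " + model.totalNumNodes());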

Possible exceptions during training:

One might come across some of the exceptions below, which have to be taken care of:

java.lang.IllegalArgumentException – requirement failed – DecisionTree requires maxBins

When maxBins = 2 and the maximum number of discrete values for a feature in the training data is 10, the following exception is thrown:

Exception in thread "main" java.lang.IllegalArgumentException: requirement failed: DecisionTree requires maxBins (=2) to be at least as large as the number of values in each categorical feature, but categorical feature 2 has 10 values. Considering remove this and other categorical features with a large number of values, or add more training examples.

Solution: Provide maxBins with a value greater than or equal to the largest number of discrete values (arity) among all the categorical features.
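
For example, instead of hard-coding the value, maxBins can be derived from the same map that declares the categorical features. A small sketch using the categoricalFeaturesInfo defined in the trainer above:

// largest arity among all discrete features; here max(3, 7, 10, 10) = 10
Integer maxBins = java.util.Collections.max(categoricalFeaturesInfo.values());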

java.lang.IllegalArgumentException: GiniAggregator given label

When numClasses = 2 and the training data has three categories [0, 1, 2], the following exception is thrown:

Caused by: java.lang.IllegalArgumentException: GiniAggregator given label 2.0 but requires label < numClasses (= 2).

Solution: Provide numClasses with a value greater than or equal to the number of categories in the training data.
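
Since MLlib expects the labels to be 0, 1, ..., numClasses-1, a safe value can also be computed from the training data itself. A sketch using the trainingData RDD from the trainer above (requires Java 8 lambdas):

// labels are 0..numClasses-1, so numClasses = (highest label) + 1; here 2 + 1 = 3
double maxLabel = trainingData
        .map(p -> p.label())
        .reduce((a, b) -> Math.max(a, b));
Integer numClasses = (int) maxLabel + 1;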

Prediction using the saved model from the training part of this Random Forest Classification Example:

A sample of the test data is shown below. A quick look reveals that the format of the test data is the same as that of the training data.

Test data: testValues.txt

0 1:1 2:4 3:1 4:1 5:1 6:3
0 1:1 2:1 3:1 4:1 5:1 6:6
1 1:2 2:1 3:5 4:1 5:1 6:6
0 1:1 2:1 3:1 4:1 5:1 6:1
1 1:2 2:3 3:1 4:1 5:1 6:1
2 1:2 2:6 3:9 4:6 5:1 6:8
2 1:2 2:6 3:9 4:4 5:1 6:8

Prediction using the model generated during training:

Predictor Class : RandomForestPredictor.java

package com.tut;

import scala.Tuple2;

import java.io.File;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.model.RandomForestModel;
import org.apache.spark.mllib.util.MLUtils;

/** RandomForest Classification Example using Spark MLlib
 * @author tutorialkart.com
 */
public class RandomForestPredictor {
	static RandomForestModel model;

	public static void main(String[] args) {
		// hadoop home dir [path to bin folder containing winutils.exe]
		System.setProperty("hadoop.home.dir", "D:\\Arjun\\ml\\hadoop\\");

		// Configuring spark
		SparkConf sparkConf1 = new SparkConf().setAppName("RandomForestExample")
				.setMaster("local[2]")
				.set("spark.executor.memory","3g")
				.set("spark.driver.memory", "3g");
		
		// initializing the spark context
		JavaSparkContext jsc = new JavaSparkContext(sparkConf1);
		
		// loading the model, that is generated during training
		model = RandomForestModel.load(jsc.sc(),"RandForestClsfrMdl"+File.separator+"model");
		
		// Load and parse the test data file.
		String datapath = "data"+File.separator+"testValues.txt";
		JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
		
		System.out.println("\nPredicted : Expected");
		
		// Evaluate model on test instances and compute test error
		JavaPairRDD<Double, Double> predictionAndLabel =
				data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
					@Override
					public Tuple2<Double, Double> call(LabeledPoint p) {
						System.out.println(model.predict(p.features())+" : "+p.label());
						return new Tuple2<>(model.predict(p.features()), p.label());
					}
				});
		
		// compute error of the model to predict the categories for test samples/experiments 
		Double testErr =
				1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
					@Override
					public Boolean call(Tuple2<Double, Double> pl) {
						return !pl._1().equals(pl._2());
					}
				}).count() / data.count();
		System.out.println("Test Error: " + testErr);

		jsc.stop();
	}

}
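
If the project targets Java 8 or later, the anonymous PairFunction and Function classes above can be collapsed into lambdas. A minimal sketch of the same evaluation:

// same evaluation as in main(), written with Java 8 lambdas
JavaPairRDD<Double, Double> predictionAndLabel =
        data.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label()));
double testErr = 1.0 * predictionAndLabel
        .filter(pl -> !pl._1().equals(pl._2()))
        .count() / data.count();
System.out.println("Test Error: " + testErr);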

Output

Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.
                                                                                
Predicted : Expected
1.0 : 1.0
0.0 : 0.0
1.0 : 1.0
0.0 : 0.0
0.0 : 0.0
1.0 : 1.0
0.0 : 0.0
0.0 : 0.0
1.0 : 1.0
1.0 : 1.0
1.0 : 1.0
2.0 : 2.0
0.0 : 0.0
2.0 : 1.0
2.0 : 2.0
0.0 : 0.0
2.0 : 2.0
0.0 : 0.0
2.0 : 2.0
1.0 : 1.0
0.0 : 0.0
Test Error: 0.047619047619047616

For the test data we provided, the test error is 0.0476, i.e., about 4.8% of the predictions are incorrect; hence, the model predicted with an accuracy of about 95.2%.

Conclusion

In this Apache Spark Tutorial – RandomForest Classification Example using Spark MLlib, we have learned how to train a RandomForest classification model and use it to predict categories for test data in Apache Spark MLlib.