RandomForest Classification Example using Spark MLlib
In this tutorial, we shall see how to train a model using the RandomForest classifier in Spark MLlib and save it. We then use the saved model to predict the categories of test data and calculate the test error and accuracy of the model.
Training using Random Forest classifier
Spark MLlib understands only numbers, so the training data has to be prepared in a form that MLlib can consume. Preparing the training data is the most important step in deciding the accuracy of a model. It includes the following steps:
- Identify the categories and assign each one an index.
- Identify the features and assign each one an index.
- Transform the experiments/observations/examples using the indexes of the categories and features.
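The preparation steps above can be sketched in plain Java. This is an illustration under stated assumptions, not code from the tutorial's attachment: the class ObservationIndexer and the sample observations are hypothetical, categories are numbered 0, 1, 2, ... in order of first appearance, each feature's distinct raw values are numbered the same way, and each observation is written in the [category feature1:value ...] row format this tutorial trains on.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical helper (not part of the tutorial's code): indexes categories and feature values. */
class ObservationIndexer {
    // category name -> label index 0, 1, 2, ...
    private final Map<String, Integer> categoryIndex = new HashMap<>();
    // one map per feature position: raw value -> discrete value index 0, 1, 2, ...
    private final List<Map<String, Integer>> valueIndex = new ArrayList<>();

    /** Returns the index already assigned to key, or assigns the next free index. */
    private static int indexOf(Map<String, Integer> index, String key) {
        Integer existing = index.get(key);
        if (existing != null) {
            return existing;
        }
        int next = index.size();
        index.put(key, next);
        return next;
    }

    /** Converts one observation to "label 1:v 2:v ..." with one-based feature numbers. */
    String toLibSvmLine(String category, String[] rawFeatures) {
        StringBuilder line = new StringBuilder().append(indexOf(categoryIndex, category));
        for (int i = 0; i < rawFeatures.length; i++) {
            if (valueIndex.size() <= i) {
                valueIndex.add(new HashMap<>()); // first time this feature position is seen
            }
            int value = indexOf(valueIndex.get(i), rawFeatures[i]);
            line.append(' ').append(i + 1).append(':').append(value);
        }
        return line.toString();
    }

    /** Number of distinct values seen for a zero-based feature position, i.e. its arity. */
    int arityOf(int feature) {
        return valueIndex.get(feature).size();
    }

    public static void main(String[] args) {
        ObservationIndexer indexer = new ObservationIndexer();
        System.out.println(indexer.toLibSvmLine("spam", new String[] {"red", "small"}));  // 0 1:0 2:0
        System.out.println(indexer.toLibSvmLine("ham", new String[] {"blue", "small"}));  // 1 1:1 2:0
        System.out.println("arity of feature 0: " + indexer.arityOf(0));                 // 2
    }
}
```

arityOf returns the number of distinct values of a feature, which is the kind of number a discrete feature needs in categoricalFeaturesInfo. The feature numbers written to the line start at 1, while arityOf takes zero-based positions, the same numbering the trained model uses.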
Note: Feature values can be discrete or continuous. Comments in the trainer program mark which features are treated as discrete and which as continuous. Using these as a reference, you can configure the features as per your requirement.
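In the trainer, a discrete feature is declared in categoricalFeaturesInfo as zero-based feature index mapped to its number of values (arity k), and such a feature may only take the values 0 to k - 1. A quick check of the data against that map before training can catch mistakes early. The sketch below is a hypothetical helper working on plain arrays of zero-based feature values, not an MLlib API; rejecting fractional values is a conservative choice of this sketch.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical pre-training check for discrete (categorical) feature values. */
class CategoricalCheck {
    /**
     * Returns a description of the first violation, or null if every declared
     * discrete feature f has a whole-number value in 0 .. arity(f) - 1 in every row.
     */
    static String firstViolation(double[][] rows, Map<Integer, Integer> categoricalFeaturesInfo) {
        for (int r = 0; r < rows.length; r++) {
            for (Map.Entry<Integer, Integer> e : categoricalFeaturesInfo.entrySet()) {
                int feature = e.getKey();
                int arity = e.getValue();
                double value = rows[r][feature];
                if (value < 0 || value >= arity || value != Math.floor(value)) {
                    return "row " + r + ": feature " + feature + " has value " + value
                            + " but arity " + arity + " allows 0.." + (arity - 1);
                }
            }
        }
        return null;
    }

    public static void main(String[] args) {
        Map<Integer, Integer> info = new HashMap<>();
        info.put(0, 3); // feature 0: values 0, 1, 2
        info.put(1, 7); // feature 1: values 0 .. 6
        double[][] rows = {{1, 1, 5}, {2, 6, 9}, {3, 1, 1}}; // feature 2 is continuous here
        System.out.println(firstViolation(rows, info));
    }
}
```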
Download the source code of this example here: RandomForestExampleAttachment. For setting up a Java project to work with Spark MLlib, please refer to Create Java Project with Apache Spark.
Sample Training Data for Random Forest
Below is a sample of the training data, already transformed and ready to be fed to the RandomForest classifier. Each row represents an experiment/observation/example. The format of each row is [category feature1:value feature2:value ...], which is the LibSVM format read by MLUtils.loadLibSVMFile in the trainer.
Training data: trainingValues.txt
0 1:1 2:1 3:1 4:1 5:1 6:1
0 1:1 2:1 3:1 4:1 5:1 6:1
1 1:2 2:1 3:5 4:1 5:1 6:1
0 1:1 2:1 3:1 4:1 5:1 6:1
1 1:1 2:3 3:1 4:1 5:1 6:1
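MLUtils.loadLibSVMFile turns each such row into a LabeledPoint and converts the one-based feature numbers in the file into zero-based indexes, which is why the value written as 1:value appears as feature 0 in the trained model. The sketch below mimics that conversion for a single row in plain Java. It is an illustration only (LibSvmLine is a hypothetical class, not MLlib's parser) and assumes the row's feature numbers are in ascending order, as the LibSVM format requires.

```java
/** Hypothetical minimal parser for one "label index:value ..." row. */
class LibSvmLine {
    final double label;
    final double[] features; // zero-based: features[0] holds the value written as "1:..."

    LibSvmLine(double label, double[] features) {
        this.label = label;
        this.features = features;
    }

    /** Parses a row whose feature numbers are one-based; numFeatures sets the vector size. */
    static LibSvmLine parse(String line, int numFeatures) {
        String[] tokens = line.trim().split("\\s+");
        double label = Double.parseDouble(tokens[0]);
        double[] features = new double[numFeatures]; // features absent from the row stay 0.0
        for (int i = 1; i < tokens.length; i++) {
            String[] pair = tokens[i].split(":");
            int oneBased = Integer.parseInt(pair[0]);
            features[oneBased - 1] = Double.parseDouble(pair[1]);
        }
        return new LibSvmLine(label, features);
    }

    public static void main(String[] args) {
        LibSvmLine p = parse("1 1:2 2:1 3:5 4:1 5:1 6:1", 6);
        // prints: 1.0 [2.0, 1.0, 5.0, 1.0, 1.0, 1.0]
        System.out.println(p.label + " " + java.util.Arrays.toString(p.features));
    }
}
```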
Below is the Java class, RandomForestTrainerExample.java, that trains a model and saves it to the local file system.
Trainer Class: RandomForestTrainerExample.java
package com.tut;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.HashMap;
import org.apache.commons.io.FileUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.RandomForest;
import org.apache.spark.mllib.tree.model.RandomForestModel;
import org.apache.spark.mllib.util.MLUtils;
/** RandomForest Classification Example using Spark MLlib
 * @author tutorialkart.com
 */
public class RandomForestTrainerExample {
    public static void main(String[] args) {
        // Windows only: hadoop.home.dir is the folder whose bin subfolder contains winutils.exe
        System.setProperty("hadoop.home.dir", "D:\\Arjun\\ml\\hadoop\\");
        // Configuring spark
        // (spark.driver.memory set here has no effect because the driver JVM is already
        // running; pass --driver-memory to spark-submit or use -Xmx instead)
        SparkConf sparkConf = new SparkConf().setAppName("RandomForestExample")
                .setMaster("local[2]")
                .set("spark.executor.memory", "3g")
                .set("spark.driver.memory", "3g");
        // initializing the spark context
        JavaSparkContext jsc = new JavaSparkContext(sparkConf);
        // Load and parse the data file (LibSVM format; feature numbers 1..n become indexes 0..n-1)
        String datapath = "data" + File.separator + "trainingValues.txt";
        JavaRDD<LabeledPoint> trainingData;
        try {
            trainingData = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
        } catch (Exception e1) {
            System.out.println("No training data available.");
            e1.printStackTrace();
            jsc.stop();
            return;
        }
        // Configuration/hyper-parameters to train the random forest model
        Integer numClasses = 3; // labels must be 0 .. numClasses - 1 (here 0, 1 and 2)
        // Keys are zero-based feature indexes, values are the number of discrete values (arity).
        // Features not in this map are continuous; an empty map makes all features continuous.
        HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
        categoricalFeaturesInfo.put(0, 3);  // feature 0 is discrete, with values from 0 to 2
        categoricalFeaturesInfo.put(1, 7);  // feature 1 is discrete, with values from 0 to 6
        categoricalFeaturesInfo.put(2, 10); // feature 2 is discrete, with values from 0 to 9
        // feature 3 is continuous
        categoricalFeaturesInfo.put(4, 10); // feature 4 is discrete, with values from 0 to 9
        // feature 5 is continuous
        Integer numTrees = 3; // number of decision trees in the random forest
        String featureSubsetStrategy = "auto"; // features considered per split; "auto" means "sqrt" for a classification forest
        String impurity = "gini"; // split criterion for classification: "gini" or "entropy"
        Integer maxDepth = 30; // maximum depth a decision tree can grow to (MLlib allows at most 30)
        Integer maxBins = 10; // maximum number of bins used when splitting features; must be >= the largest arity above
        Integer seed = 12345; // random seed for bootstrapping and feature subset selection, for reproducible models
        // training the classifier with all the hyper-parameters defined above
        final RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses,
                categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins,
                seed);
        // Delete the old model if it is already present, then save the new model
        try {
            FileUtils.forceDelete(new File("RandForestClsfrMdl"));
            System.out.println("\nDeleting old model completed.");
        } catch (FileNotFoundException e1) {
            // no old model to delete
        } catch (IOException e) {
            System.out.println("Could not delete the old model: " + e.getMessage());
        }
        // saving the random forest model that is generated
        model.save(jsc.sc(), "RandForestClsfrMdl" + File.separator + "model");
        System.out.println("\nRandForestClsfrMdl/model has been created and successfully saved.");
        // printing the random forest model (collection of decision trees)
        System.out.println(model.toDebugString());
        jsc.stop();
    }
}
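The trainer passes impurity = "gini", the measure the trees use to score candidate splits. The Gini impurity of a node is 1 - Σ p_i², where p_i is the fraction of the node's examples with label i, and a split is preferred when it lowers the size-weighted impurity of the two children. The standalone sketch below (GiniExample is a hypothetical class) computes both quantities from label counts; it illustrates the formula, not MLlib's internal implementation.

```java
/** Standalone illustration of Gini impurity and the impurity decrease of a binary split. */
class GiniExample {
    /** Gini impurity 1 - sum(p_i^2) for a node with the given count per label. */
    static double gini(long[] labelCounts) {
        long total = 0;
        for (long c : labelCounts) {
            total += c;
        }
        if (total == 0) {
            return 0.0;
        }
        double sumSquares = 0.0;
        for (long c : labelCounts) {
            double p = (double) c / total;
            sumSquares += p * p;
        }
        return 1.0 - sumSquares;
    }

    /** Parent impurity minus the size-weighted impurity of the left and right children. */
    static double impurityDecrease(long[] left, long[] right) {
        long[] parent = new long[left.length];
        long nLeft = 0, nRight = 0;
        for (int i = 0; i < left.length; i++) {
            parent[i] = left[i] + right[i];
            nLeft += left[i];
            nRight += right[i];
        }
        double n = nLeft + nRight;
        return gini(parent) - (nLeft / n) * gini(left) - (nRight / n) * gini(right);
    }

    public static void main(String[] args) {
        // a node holding labels 0, 1, 2 with counts 3, 1, 0
        System.out.println(gini(new long[] {3, 1, 0})); // 0.375
        // a split that separates label 0 from label 1 perfectly removes all impurity
        System.out.println(impurityDecrease(new long[] {3, 0, 0}, new long[] {0, 1, 0})); // 0.375
    }
}
```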
When the above Java class is run, it generates a model with three decision trees and prints them, as shown in the output below.
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.
RandForestClsfrMdl/model has been created and successfully saved.
TreeEnsembleModel classifier with 3 trees
Tree 0:
  If (feature 5 <= 6.0)
    If (feature 0 in {1.0})
      If (feature 1 in {3.0})
        Predict: 1.0
      Else (feature 1 not in {3.0})
        If (feature 5 <= 2.0)
          If (feature 2 in {1.0})
            Predict: 0.0
          Else (feature 2 not in {1.0})
            Predict: 1.0
        Else (feature 5 > 2.0)
          Predict: 0.0
    Else (feature 0 not in {1.0})
      Predict: 1.0
  Else (feature 5 > 6.0)
    Predict: 2.0
Tree 1:
  If (feature 5 <= 6.0)
    If (feature 0 in {1.0})
      If (feature 2 in {6.0})
        Predict: 1.0
      Else (feature 2 not in {6.0})
        If (feature 4 in {3.0})
          Predict: 0.0
        Else (feature 4 not in {3.0})
          Predict: 0.0
    Else (feature 0 not in {1.0})
      Predict: 1.0
  Else (feature 5 > 6.0)
    If (feature 3 <= 1.0)
      Predict: 0.0
    Else (feature 3 > 1.0)
      Predict: 2.0
Tree 2:
  If (feature 3 <= 1.0)
    If (feature 2 in {5.0,6.0})
      Predict: 1.0
    Else (feature 2 not in {5.0,6.0})
      If (feature 0 in {1.0})
        If (feature 1 in {1.0})
          Predict: 0.0
        Else (feature 1 not in {1.0})
          If (feature 1 in {3.0})
            Predict: 1.0
          Else (feature 1 not in {3.0})
            Predict: 0.0
      Else (feature 0 not in {1.0})
        Predict: 1.0
  Else (feature 3 > 1.0)
    Predict: 2.0
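From the above random forest, the following observations can be made:
- Features 0, 1, 2 and 4 are treated as discrete, because their splits test set membership, for example (feature 2 not in {5.0,6.0}).
- Features 3 and 5 are treated as continuous, because their splits compare against a threshold, for example (feature 5 > 6.0).
Note that the model numbers features from 0: feature 0 in the model is the value written as 1:value in the data file, because loadLibSVMFile converts the one-based feature numbers to zero-based indexes.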
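To see how these trees produce one prediction, they can be hand-translated into plain Java and combined by majority vote, which is how a Spark MLlib RandomForestModel combines the trees of a classification forest. The sketch takes zero-based feature values, so the test row 0 1:1 2:4 3:1 4:1 5:1 6:3 becomes {1, 4, 1, 1, 1, 3}. PrintedForest is a hypothetical transcription of the printed output for illustration, not code generated by MLlib, and breaking ties toward the smallest label is an assumption of this sketch.

```java
/** Hand transcription of the three printed trees, combined by majority vote. */
class PrintedForest {
    static double tree0(double[] f) {
        if (f[5] <= 6.0) {
            if (f[0] == 1.0) {
                if (f[1] == 3.0) return 1.0;
                if (f[5] <= 2.0) return f[2] == 1.0 ? 0.0 : 1.0;
                return 0.0;
            }
            return 1.0;
        }
        return 2.0;
    }

    static double tree1(double[] f) {
        if (f[5] <= 6.0) {
            if (f[0] == 1.0) {
                if (f[2] == 6.0) return 1.0;
                return 0.0; // both branches on feature 4 predict 0.0
            }
            return 1.0;
        }
        return f[3] <= 1.0 ? 0.0 : 2.0;
    }

    static double tree2(double[] f) {
        if (f[3] <= 1.0) {
            if (f[2] == 5.0 || f[2] == 6.0) return 1.0;
            if (f[0] == 1.0) {
                if (f[1] == 1.0) return 0.0;
                return f[1] == 3.0 ? 1.0 : 0.0;
            }
            return 1.0;
        }
        return 2.0;
    }

    /** Majority vote over the three trees; ties go to the smallest label (an assumption). */
    static double predict(double[] f) {
        int[] votes = new int[3]; // vote counts for labels 0.0, 1.0, 2.0
        votes[(int) tree0(f)]++;
        votes[(int) tree1(f)]++;
        votes[(int) tree2(f)]++;
        int best = 0;
        for (int label = 1; label < votes.length; label++) {
            if (votes[label] > votes[best]) best = label;
        }
        return best;
    }

    public static void main(String[] args) {
        System.out.println(predict(new double[] {1, 4, 1, 1, 1, 3})); // row 0 1:1 2:4 3:1 4:1 5:1 6:3 -> 0.0
        System.out.println(predict(new double[] {2, 6, 9, 6, 1, 8})); // row 2 1:2 2:6 3:9 4:6 5:1 6:8 -> 2.0
    }
}
```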
Possible exceptions during training:
You might come across the following exceptions during training, which have to be taken care of:
java.lang.IllegalArgumentException – requirement failed – DecisionTree requires maxBins
When maxBins = 2 and the largest number of discrete values for a feature in the training data is 10:
Exception in thread "main" java.lang.IllegalArgumentException: requirement failed: DecisionTree requires maxBins (=2) to be at least as large as the number of values in each categorical feature, but categorical feature 2 has 10 values. Considering remove this and other categorical features with a large number of values, or add more training examples.
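Solution: Set maxBins to a value greater than or equal to the largest number of discrete values declared for any feature in categoricalFeaturesInfo, that is, the maximum discrete value + 1 among all discrete features. In this example that is 10.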
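The requirement behind this error can be checked before training. The error message compares maxBins with the number of values of each categorical feature, and its "add more training examples" hint suggests the usable number of bins is also limited by the number of training examples; this sketch assumes that limit is min(maxBins, numExamples). MaxBinsCheck is a hypothetical helper modeled on the message, not MLlib source code.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical pre-training check modeled on the maxBins error message (not MLlib code). */
class MaxBinsCheck {
    /** Largest arity declared in categoricalFeaturesInfo, or 0 when every feature is continuous. */
    static int minimumMaxBins(Map<Integer, Integer> categoricalFeaturesInfo) {
        int largestArity = 0;
        for (int arity : categoricalFeaturesInfo.values()) {
            largestArity = Math.max(largestArity, arity);
        }
        return largestArity;
    }

    /** Assumption: the usable number of bins is capped at the number of training examples. */
    static boolean isEnough(int maxBins, long numExamples, Map<Integer, Integer> categoricalFeaturesInfo) {
        long usableBins = Math.min(maxBins, numExamples);
        return usableBins >= minimumMaxBins(categoricalFeaturesInfo);
    }

    public static void main(String[] args) {
        // the categoricalFeaturesInfo used by the trainer in this tutorial
        Map<Integer, Integer> info = new HashMap<>();
        info.put(0, 3);
        info.put(1, 7);
        info.put(2, 10);
        info.put(4, 10);
        System.out.println("maxBins must be at least " + minimumMaxBins(info)); // 10
        System.out.println("maxBins = 2 enough? " + isEnough(2, 1000, info));    // false
        System.out.println("maxBins = 10 enough? " + isEnough(10, 1000, info));  // true
    }
}
```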
java.lang.IllegalArgumentException: GiniAggregator given label
When numClasses = 2 and the training data has three categories [0, 1, 2]:
Caused by: java.lang.IllegalArgumentException: GiniAggregator given label 2.0 but requires label < numClasses (= 2).
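Solution: Set numClasses to at least the number of categories in the training data. Labels must be whole numbers from 0 to numClasses - 1, so with categories [0, 1, 2] numClasses must be at least 3.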
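Because the labels must be the whole numbers 0 to numClasses - 1, numClasses can be derived from the training labels before calling trainClassifier. The sketch below does this for an array of labels in plain Java; NumClassesCheck and requiredNumClasses are hypothetical names, not part of MLlib.

```java
/** Hypothetical helper: derives numClasses from the labels and rejects invalid labels. */
class NumClassesCheck {
    static int requiredNumClasses(double[] labels) {
        int maxLabel = -1;
        for (double label : labels) {
            if (label < 0 || label != Math.floor(label)) {
                throw new IllegalArgumentException("label must be a whole number >= 0, got " + label);
            }
            maxLabel = Math.max(maxLabel, (int) label);
        }
        return maxLabel + 1; // labels 0 .. maxLabel need maxLabel + 1 classes
    }

    public static void main(String[] args) {
        double[] labels = {0, 0, 1, 0, 1, 2};
        System.out.println(requiredNumClasses(labels)); // 3
    }
}
```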
Prediction using the saved model from the training part
A sample of the test data, testValues.txt, is shown below. Its format is the same as that of the training data.
0 1:1 2:4 3:1 4:1 5:1 6:3
0 1:1 2:1 3:1 4:1 5:1 6:6
1 1:2 2:1 3:5 4:1 5:1 6:6
0 1:1 2:1 3:1 4:1 5:1 6:1
1 1:2 2:3 3:1 4:1 5:1 6:1
2 1:2 2:6 3:9 4:6 5:1 6:8
2 1:2 2:6 3:9 4:4 5:1 6:8
Prediction using the model generated during training:
Predictor Class: RandomForestPredictor.java
package com.tut;
import scala.Tuple2;
import java.io.File;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.model.RandomForestModel;
import org.apache.spark.mllib.util.MLUtils;
/** RandomForest Classification Example using Spark MLlib
 * @author tutorialkart.com
 */
public class RandomForestPredictor {
    public static void main(String[] args) {
        // Windows only: hadoop.home.dir is the folder whose bin subfolder contains winutils.exe
        System.setProperty("hadoop.home.dir", "D:\\Arjun\\ml\\hadoop\\");
        // Configuring spark
        SparkConf sparkConf = new SparkConf().setAppName("RandomForestExample")
                .setMaster("local[2]")
                .set("spark.executor.memory", "3g")
                .set("spark.driver.memory", "3g");
        // initializing the spark context
        JavaSparkContext jsc = new JavaSparkContext(sparkConf);
        // loading the model that was generated during training; as a final local variable it is
        // serialized with the functions below (a static field would be null on remote executors)
        final RandomForestModel model =
                RandomForestModel.load(jsc.sc(), "RandForestClsfrMdl" + File.separator + "model");
        // Load and parse the test data file.
        String datapath = "data" + File.separator + "testValues.txt";
        JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
        System.out.println("\nPredicted : Expected");
        // Pair the prediction for each test instance with its expected label
        JavaPairRDD<Double, Double> predictionAndLabel =
                data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
                    @Override
                    public Tuple2<Double, Double> call(LabeledPoint p) {
                        double prediction = model.predict(p.features());
                        // printed by the executor; shows up in the console in local mode
                        System.out.println(prediction + " : " + p.label());
                        return new Tuple2<>(prediction, p.label());
                    }
                });
        // Test error = fraction of test samples whose prediction differs from the expected label
        double testErr =
                1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
                    @Override
                    public Boolean call(Tuple2<Double, Double> pl) {
                        return !pl._1().equals(pl._2());
                    }
                }).count() / data.count();
        System.out.println("Test Error: " + testErr);
        jsc.stop();
    }
}
Output
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.
Predicted : Expected
1.0 : 1.0
0.0 : 0.0
1.0 : 1.0
0.0 : 0.0
0.0 : 0.0
1.0 : 1.0
0.0 : 0.0
0.0 : 0.0
1.0 : 1.0
1.0 : 1.0
1.0 : 1.0
2.0 : 2.0
0.0 : 0.0
2.0 : 1.0
2.0 : 2.0
0.0 : 0.0
2.0 : 2.0
0.0 : 0.0
2.0 : 2.0
1.0 : 1.0
0.0 : 0.0
Test Error: 0.047619047619047616
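For the test data we provided, only 1 of the 21 test samples is misclassified (predicted 2.0, expected 1.0), so the test error is 1/21 ≈ 0.0476, or 4.76%. The accuracy of the model on this test data is therefore 1 - 0.0476 ≈ 0.952, or about 95.2%.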
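The same arithmetic can be checked outside Spark. The sketch below takes the 21 (predicted, expected) pairs printed in the output above, counts the mismatches the way the filter in RandomForestPredictor does, and derives accuracy as 1 - test error. TestErrorExample is a hypothetical class in plain Java with no Spark dependency.

```java
/** Recomputes test error and accuracy from (predicted, expected) pairs. */
class TestErrorExample {
    /** Fraction of pairs whose prediction differs from the expected label. */
    static double testError(double[][] predictedAndExpected) {
        long wrong = 0;
        for (double[] pair : predictedAndExpected) {
            if (pair[0] != pair[1]) wrong++;
        }
        return 1.0 * wrong / predictedAndExpected.length;
    }

    public static void main(String[] args) {
        // the 21 "Predicted : Expected" lines from the output above
        double[][] pairs = {
            {1, 1}, {0, 0}, {1, 1}, {0, 0}, {0, 0}, {1, 1}, {0, 0},
            {0, 0}, {1, 1}, {1, 1}, {1, 1}, {2, 2}, {0, 0}, {2, 1},
            {2, 2}, {0, 0}, {2, 2}, {0, 0}, {2, 2}, {1, 1}, {0, 0}
        };
        double err = testError(pairs);
        System.out.println("Test Error: " + err);
        System.out.println("Accuracy: " + (1 - err));
    }
}
```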
Conclusion
In this Apache Spark tutorial, RandomForest Classification Example using Spark MLlib, we have learned how to train a RandomForest classifier in Apache Spark MLlib, save the model, and use it to predict the categories of test data and measure its test error and accuracy.