Classification is the task of identifying the features of an entity and assigning the entity to one of a set of predefined classes (categories), based on knowledge learned from previously labelled examples.

A decision tree has a tree-like structure. Its root is a decision node, and the classification of a problem instance starts there. Each decision node tests a feature and branches out; each branch represents one possible outcome of that test. A branch leads either to another decision node or to a leaf holding a class label, which ends the classification and gives the result.
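
To make this structure concrete, here is a minimal sketch in plain Java (no Spark). TinyTree and its methods are hypothetical names used only for illustration, not part of MLlib. A decision node compares one feature with a threshold and follows the left branch when the value is less than or equal to it; a leaf returns its class label.

```java
/**
 * Hypothetical minimal decision tree (plain Java, not Spark's API).
 * A decision node tests one feature against a threshold; a leaf holds a class label.
 */
final class TinyTree {

    static final class Node {
        final int feature;       // index of the feature tested here (-1 for a leaf)
        final double threshold;  // the left branch is taken when value <= threshold
        final Node left;         // branch for "value <= threshold"
        final Node right;        // branch for "value > threshold"
        final double label;      // class label, used only by leaves

        private Node(int feature, double threshold, Node left, Node right, double label) {
            this.feature = feature;
            this.threshold = threshold;
            this.left = left;
            this.right = right;
            this.label = label;
        }

        // A leaf ends the classification with a class label.
        static Node leaf(double label) {
            return new Node(-1, 0.0, null, null, label);
        }

        // A decision node branches into two possible outcomes.
        static Node decision(int feature, double threshold, Node left, Node right) {
            return new Node(feature, threshold, left, right, Double.NaN);
        }

        boolean isLeaf() {
            return left == null;
        }
    }

    // Start at the root and follow one branch per decision until a leaf is reached.
    static double classify(Node root, double[] features) {
        Node node = root;
        while (!node.isLeaf()) {
            node = features[node.feature] <= node.threshold ? node.left : node.right;
        }
        return node.label;
    }
}
```

The model printed at the end of this tutorial has exactly this shape: one decision node on feature 434 and two leaves.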

In this Apache Spark tutorial, we shall build a decision tree like the one described above from training data, using the Decision Tree algorithm in Apache Spark MLlib.

Classification using Decision Trees in Apache Spark MLlib with Java

Following is a step-by-step process to build a classifier using the Decision Tree algorithm of MLlib.
Prerequisite: set up a Java project with Apache Spark.

1. Configure Spark.

SparkConf sparkConf = new SparkConf().setAppName("DecisionTreeExample");

2. Start a Spark context.

JavaSparkContext jsc = new JavaSparkContext(sparkConf);

3. Load the data and split it into a training set and a test set. The data file used in this example, sample_libsvm_data.txt, is in the data/mllib folder of the Apache Spark distribution downloaded from the official website.

// provide path to data in LIBSVM format (label followed by index:value feature pairs)
String path = "data/mllib/sample_libsvm_data.txt";
JavaRDD<LabeledPoint> inputData = MLUtils.loadLibSVMFile(jsc.sc(), path).toJavaRDD();
 
// split the data for training (60%) and testing (40%)
JavaRDD<LabeledPoint>[] tmp = inputData.randomSplit(new double[]{0.6, 0.4});
JavaRDD<LabeledPoint> trainingData = tmp[0]; // training set
JavaRDD<LabeledPoint> testData = tmp[1]; // test set
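
The file is in LIBSVM format: each line holds a label followed by sparse index:value pairs with 1-based, ascending indices, and MLUtils.loadLibSVMFile turns every line into a LabeledPoint whose feature indices are 0-based. As a rough illustration of that conversion only, the sketch below parses one line by hand. LibSvmLine is a hypothetical class, not Spark's parser, and it skips the validation Spark performs.

```java
/**
 * Hypothetical, simplified parser for one LIBSVM line, for illustration only.
 * Format: "label index1:value1 index2:value2 ..." with 1-based, ascending indices.
 */
final class LibSvmLine {
    final double label;
    final int[] indices;    // 0-based indices of the non-zero features
    final double[] values;  // values of those features

    private LibSvmLine(double label, int[] indices, double[] values) {
        this.label = label;
        this.indices = indices;
        this.values = values;
    }

    static LibSvmLine parse(String line) {
        String[] tokens = line.trim().split("\\s+");
        double label = Double.parseDouble(tokens[0]);
        int n = tokens.length - 1;
        int[] indices = new int[n];
        double[] values = new double[n];
        for (int i = 0; i < n; i++) {
            String[] pair = tokens[i + 1].split(":");
            indices[i] = Integer.parseInt(pair[0]) - 1; // shift from 1-based to 0-based
            values[i] = Double.parseDouble(pair[1]);
        }
        return new LibSvmLine(label, indices, values);
    }

    // Expand the sparse features into a dense array; missing features are 0.0.
    double[] toDense(int numFeatures) {
        double[] dense = new double[numFeatures];
        for (int i = 0; i < indices.length; i++) {
            dense[indices[i]] = values[i];
        }
        return dense;
    }
}
```

For example, the line "1 3:0.5 5:2.0" becomes label 1.0 with non-zero features at indices 2 and 4.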

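4. Set the hyperparameters required by the Decision Tree algorithm.
numClasses : Number of classes. The sample data has two labels, 0 and 1, so it is 2.
categoricalFeaturesInfo : Map from the index of each categorical feature to its number of categories. An empty map means all features are treated as continuous.
impurity : Measure used to evaluate candidate splits at each node, either "gini" or "entropy" for classification. The algorithm picks the split that reduces impurity the most.
maxDepth : Maximum depth of the tree. Depth 0 means a single leaf node; depth 1 means one decision node and two leaf nodes. Deeper trees can fit the training data more closely but are more prone to over-fitting.
maxBins : Maximum number of bins used to discretize continuous features when searching for split thresholds. More bins allow finer thresholds at a higher computational cost. maxBins must be at least the largest number of categories of any categorical feature.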

int numClasses = 2;
Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
String impurity = "gini";
int maxDepth = 5;
int maxBins = 32;
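
The impurity setting selects how the quality of a split is measured. The sketch below computes the two classification measures from per-class counts using their textbook definitions, Gini = 1 - sum(p_i^2) and entropy = -sum(p_i * log2(p_i)). The Impurity class is a hypothetical helper, not MLlib's internal implementation.

```java
/**
 * Hypothetical helper: textbook impurity measures computed from per-class counts.
 */
final class Impurity {

    // Gini impurity: 1 - sum(p_i^2). 0 for a pure node, largest when classes are evenly mixed.
    static double gini(long[] classCounts) {
        long total = 0;
        for (long c : classCounts) total += c;
        if (total == 0) return 0.0;
        double sumSq = 0.0;
        for (long c : classCounts) {
            double p = (double) c / total;
            sumSq += p * p;
        }
        return 1.0 - sumSq;
    }

    // Entropy: -sum(p_i * log2(p_i)); empty classes are skipped (0 * log 0 is taken as 0).
    static double entropy(long[] classCounts) {
        long total = 0;
        for (long c : classCounts) total += c;
        if (total == 0) return 0.0;
        double h = 0.0;
        for (long c : classCounts) {
            if (c == 0) continue;
            double p = (double) c / total;
            h -= p * (Math.log(p) / Math.log(2));
        }
        return h;
    }
}
```

For example, a node with 5 samples of each class has Gini impurity 0.5 and entropy 1.0, while a pure node scores 0 on both.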

5. Train a Decision Tree model.

DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses,
          categoricalFeaturesInfo, impurity, maxDepth, maxBins);
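
Roughly speaking, trainClassifier grows the tree by choosing, at each node, the split that decreases impurity the most. The following is a simplified, hypothetical sketch of that step for one continuous feature and binary labels. At most maxBins - 1 candidate thresholds are taken at evenly spaced positions of the sorted values (a crude stand-in for MLlib's quantile-based binning over a sample of the data), and each candidate is scored by the decrease in weighted Gini impurity. SplitSketch is an illustrative name, not an MLlib class.

```java
import java.util.Arrays;

/**
 * Hypothetical, simplified split search for ONE continuous feature and binary
 * labels (0.0 / 1.0). Not MLlib's implementation, only the same idea.
 */
final class SplitSketch {

    // Gini impurity of a node holding `zeros` label-0 and `ones` label-1 examples.
    static double gini(long zeros, long ones) {
        long n = zeros + ones;
        if (n == 0) return 0.0;
        double p0 = (double) zeros / n;
        double p1 = (double) ones / n;
        return 1.0 - p0 * p0 - p1 * p1;
    }

    // At most maxBins - 1 thresholds, taken at evenly spaced positions of the
    // sorted values (a crude stand-in for quantile-based binning).
    static double[] candidateThresholds(double[] values, int maxBins) {
        double[] sorted = values.clone();
        Arrays.sort(sorted);
        int n = sorted.length;
        int k = Math.max(0, Math.min(maxBins - 1, n - 1));
        double[] thresholds = new double[k];
        for (int i = 1; i <= k; i++) {
            thresholds[i - 1] = sorted[(int) ((long) i * n / (k + 1)) - 1];
        }
        return thresholds;
    }

    // Returns {threshold, gain} for the candidate with the largest impurity decrease
    // ("value <= threshold" goes left); threshold is NaN if no candidate helps.
    static double[] bestSplit(double[] values, double[] labels, int maxBins) {
        long n = labels.length;
        long ones = 0;
        for (double y : labels) if (y == 1.0) ones++;
        double parentImpurity = gini(n - ones, ones);

        double bestThreshold = Double.NaN;
        double bestGain = 0.0;
        for (double t : candidateThresholds(values, maxBins)) {
            long l0 = 0, l1 = 0, r0 = 0, r1 = 0;
            for (int i = 0; i < values.length; i++) {
                boolean isOne = labels[i] == 1.0;
                if (values[i] <= t) {
                    if (isOne) l1++; else l0++;
                } else {
                    if (isOne) r1++; else r0++;
                }
            }
            // Child impurities weighted by the fraction of examples each child receives.
            double weighted = ((l0 + l1) * gini(l0, l1) + (r0 + r1) * gini(r0, r1)) / n;
            double gain = parentImpurity - weighted;
            if (gain > bestGain) {
                bestGain = gain;
                bestThreshold = t;
            }
        }
        return new double[]{bestThreshold, bestGain};
    }
}
```

For example, for feature values {0, 0, 1, 2} with labels {0, 0, 1, 1}, bestSplit returns threshold 0.0 with gain 0.5, because "value <= 0.0" separates the two classes perfectly. This mirrors the "feature 434 <= 0.0" test in the output of this tutorial.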

6. Use the model to predict labels for the test data and calculate the accuracy. The trained tree can be inspected by converting it to a readable string with toDebugString().

// Predict for the test data using the model trained
JavaPairRDD<Double, Double> predictionAndLabel =
        testData.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label()));
// calculate the accuracy
double accuracy =
        predictionAndLabel.filter(pl -> pl._1().equals(pl._2())).count() / (double) testData.count();
 
System.out.println("Accuracy is : "+accuracy);
System.out.println("Trained Decision Tree model:\n" + model.toDebugString());
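
Accuracy is the number of correct predictions divided by the number of test examples. To show the same arithmetic without Spark, here is a hypothetical helper that counts a binary confusion matrix from (prediction, label) pairs and derives accuracy from its diagonal, which holds the correct predictions. It assumes labels and predictions are 0.0 or 1.0, as in the sample data.

```java
/**
 * Hypothetical evaluation helper on plain arrays (not Spark).
 * matrix[label][prediction] counts how often each combination occurs.
 */
final class Evaluation {

    static long[][] confusionMatrix(double[] predictions, double[] labels) {
        if (predictions.length != labels.length) {
            throw new IllegalArgumentException("predictions and labels differ in length");
        }
        long[][] m = new long[2][2];
        for (int i = 0; i < labels.length; i++) {
            m[(int) labels[i]][(int) predictions[i]]++;
        }
        return m;
    }

    // Correct predictions (the diagonal) divided by all predictions.
    static double accuracy(long[][] m) {
        long correct = m[0][0] + m[1][1];
        long total = m[0][0] + m[0][1] + m[1][0] + m[1][1];
        return total == 0 ? 0.0 : (double) correct / total;
    }
}
```

For predictions {0, 1, 1, 0} against labels {0, 1, 0, 0}, three of four are correct, so the accuracy is 0.75.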

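7. Save the trained classifier model for future use. When Spark runs in local mode with the default configuration, the model is written to the directory myDecisionTreeClassificationModel on the local file system; saving fails if that directory already exists. The model can be loaded later with DecisionTreeModel.load(jsc.sc(), "myDecisionTreeClassificationModel").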

model.save(jsc.sc(), "myDecisionTreeClassificationModel");

8. Stop the Spark context.

jsc.stop();

The complete Java program is given below.

DecisionTreeClassifierExample.java

import scala.Tuple2;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;

import java.util.HashMap;
import java.util.Map;

import org.apache.spark.SparkConf;

/**
 * Classification using Decision Trees in Apache Spark MLlib with Java Example 
 */
public class DecisionTreeClassifierExample {

	public static void main(String[] args) {

		// configure spark
		SparkConf sparkConf = new SparkConf().setAppName("DecisionTreeExample")
										.setMaster("local[2]").set("spark.executor.memory","2g");
		// start a spark context
		JavaSparkContext jsc = new JavaSparkContext(sparkConf);
		
		// provide path to data in LIBSVM format
		String path = "data/mllib/sample_libsvm_data.txt";
		JavaRDD<LabeledPoint> inputData = MLUtils.loadLibSVMFile(jsc.sc(), path).toJavaRDD();
		
		// split the data for training (60%) and testing (40%)
		JavaRDD<LabeledPoint>[] tmp = inputData.randomSplit(new double[]{0.6, 0.4});
		JavaRDD<LabeledPoint> trainingData = tmp[0]; // training set
		JavaRDD<LabeledPoint> testData = tmp[1]; // test set
		
		// Set hyperparameters for the Decision Tree algorithm
		// Empty categoricalFeaturesInfo indicates all features are continuous.
		int numClasses = 2;
		Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
		String impurity = "gini";
		int maxDepth = 5;
		int maxBins = 32;
		
		// Train a Decision Tree model
		DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses,
				categoricalFeaturesInfo, impurity, maxDepth, maxBins);

		
		// Predict for the test data using the model trained
		JavaPairRDD<Double, Double> predictionAndLabel =
				testData.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label()));
		// calculate the accuracy
		double accuracy =
				predictionAndLabel.filter(pl -> pl._1().equals(pl._2())).count() / (double) testData.count();
		
		System.out.println("Accuracy is : "+accuracy);
		System.out.println("Trained Decision Tree model:\n" + model.toDebugString());

		// Save model to local for future use
		model.save(jsc.sc(), "myDecisionTreeClassificationModel");

		// stop the spark context
		jsc.stop();
	}
}

Output

Accuracy is : 0.9787234042553191
Trained Decision Tree model:
DecisionTreeModel classifier of depth 1 with 3 nodes
  If (feature 434 <= 0.0) Predict: 0.0 Else (feature 434 > 0.0)
   Predict: 1.0

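The hyperparameters only set upper limits. The algorithm stops splitting a node when the best available split does not reduce impurity (for example, when the node already contains a single class), so the learned tree can be much smaller than the limits allow. In this example, despite maxDepth = 5, the tree has a depth of only 1: a single test on feature 434 with two leaves. Also, because randomSplit is called without a seed, the training and test sets, and therefore the accuracy and the printed tree, may differ from run to run.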
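
To illustrate why maxDepth is only an upper bound, here is a hypothetical, simplified tree grower in plain Java (binary labels, threshold splits, Gini gain, every observed value tried as a threshold). It is not MLlib's algorithm; it only shares the stopping rule described above: a node becomes a leaf when the depth limit is reached or when no split reduces impurity.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical, simplified tree grower (binary labels 0.0/1.0, threshold splits,
 * Gini gain). Illustrates the stopping rule only; not MLlib's algorithm.
 */
final class GrowSketch {

    // Gini impurity of the labels of the given rows.
    static double gini(double[] y, List<Integer> rows) {
        if (rows.isEmpty()) return 0.0;
        long ones = 0;
        for (int r : rows) if (y[r] == 1.0) ones++;
        double p1 = (double) ones / rows.size();
        double p0 = 1.0 - p1;
        return 1.0 - p0 * p0 - p1 * p1;
    }

    // Grows a subtree over `rows` and returns the depth of its deepest leaf
    // (the root is at depth 0, so a single leaf has depth 0).
    static int grow(double[][] x, double[] y, List<Integer> rows, int depth, int maxDepth) {
        double parentImpurity = gini(y, rows);
        if (depth >= maxDepth || parentImpurity == 0.0) {
            return depth; // depth limit reached or node is pure: make a leaf
        }
        double bestGain = 0.0;
        List<Integer> bestLeft = null;
        List<Integer> bestRight = null;
        for (int f = 0; f < x[0].length; f++) {
            for (int candidate : rows) {
                double threshold = x[candidate][f]; // try each observed value
                List<Integer> left = new ArrayList<>();
                List<Integer> right = new ArrayList<>();
                for (int r : rows) {
                    if (x[r][f] <= threshold) left.add(r); else right.add(r);
                }
                double weighted = (left.size() * gini(y, left)
                        + right.size() * gini(y, right)) / rows.size();
                double gain = parentImpurity - weighted;
                if (gain > bestGain) {
                    bestGain = gain;
                    bestLeft = left;
                    bestRight = right;
                }
            }
        }
        if (bestLeft == null) {
            return depth; // no split reduces impurity: stop early even below maxDepth
        }
        return Math.max(grow(x, y, bestLeft, depth + 1, maxDepth),
                grow(x, y, bestRight, depth + 1, maxDepth));
    }
}
```

For example, with rows {0, 5}, {0, 3}, {1, 4}, {2, 1} labelled 0, 0, 1, 1, calling grow with maxDepth 5 returns 1, because the split "feature 0 <= 0.0" already yields two pure children, mirroring the depth-1 tree in the output above.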