Machine Learning with R: A Complete Guide to Decision Trees
Updated: August 22, 2022.
Decision trees are among the most fundamental algorithms in supervised machine learning, used to handle both regression and classification tasks. In a nutshell, you can think of a decision tree as a glorified collection of if-else statements. What makes these if-else statements different from traditional programming is that the logical conditions are learned from the data by the machine learning algorithm instead of being written by hand, but more on that later.
Interested in more basic machine learning guides? Check out our detailed guide on Logistic Regression with R.
Today you’ll learn the basic theory behind the decision tree algorithm and how to implement it in R.
Table of contents:
- Introduction to R Decision Trees
- Dataset Loading and Preparation
- Predictive Modeling with R Decision Trees
- Generating Predictions
- Summary of R Decision Trees
Introduction to R Decision Trees
Decision trees are intuitive. All they do is ask questions such as “Is the gender male?” or “Is the value of a particular variable higher than some threshold?”. Based on the answers, either more questions are asked, or the classification is made. Simple!
To predict class labels, the decision tree starts from the root (the root node). Deciding which attribute should represent the root node is straightforward: it boils down to figuring out which attribute best separates the training records. For classification, this is commonly measured with the Gini impurity, which is also the default split criterion in rpart. It’s simple math, but it can get tedious to do manually if you have many attributes.
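To make the Gini impurity concrete: for a node whose records fall into classes with proportions p_1 to p_K, the impurity is 1 - Σ p_k², and a candidate split is scored by the size-weighted average impurity of its two child nodes. The sketch below computes both in base R on the full iris data and brute-forces the best root split over numeric features. The helpers gini(), split_gini(), and best_split() are hypothetical names for illustration; rpart’s real implementation does considerably more (categorical predictors, surrogate splits, pruning).

```r
# Gini impurity of a set of class labels: 1 - sum(p_k^2)
gini <- function(labels) {
  p <- table(labels) / length(labels)
  1 - sum(p^2)
}

# Weighted Gini impurity of the binary split "feature < threshold"
split_gini <- function(feature, labels, threshold) {
  left <- labels[feature < threshold]
  right <- labels[feature >= threshold]
  (length(left) * gini(left) + length(right) * gini(right)) / length(labels)
}

# Try every numeric feature and every midpoint between consecutive unique values,
# keeping the split with the lowest weighted impurity
best_split <- function(data, target) {
  labels <- data[[target]]
  best <- list(feature = NA, threshold = NA, impurity = Inf)
  for (feat in setdiff(names(data), target)) {
    values <- sort(unique(data[[feat]]))
    thresholds <- (head(values, -1) + tail(values, -1)) / 2
    for (thr in thresholds) {
      imp <- split_gini(data[[feat]], labels, thr)
      if (imp < best$impurity) {
        best <- list(feature = feat, threshold = thr, impurity = imp)
      }
    }
  }
  best
}

gini(iris$Species)           # impurity of the root node (three balanced classes)
best_split(iris, "Species")  # best single split on the full iris data
```

On iris, the root impurity is 2/3 (three classes of 50 flowers each), and the best split isolates setosa using a petal measurement, which cuts the weighted impurity to 1/3.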
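After determining the root node, the tree “branches out”: each resulting subset is split again in the same way to reduce the impurity left over from the root split, until the nodes are pure enough or a stopping rule (such as a minimum node size) kicks in.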
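One way to watch the tree branch out is to cap its depth with rpart.control(maxdepth = ...) and count the terminal nodes. This sketch fits the full iris data (no train/test split, purely for illustration): a depth-1 tree (a “stump”) makes only the root split, while a depth-2 tree may split the root’s children again.

```r
library(rpart)

# Depth-limited trees on the full iris data
stump <- rpart(Species ~ ., data = iris, method = "class",
               control = rpart.control(maxdepth = 1))
two_level <- rpart(Species ~ ., data = iris, method = "class",
                   control = rpart.control(maxdepth = 2))

# Each row of $frame is a node; "<leaf>" marks terminal nodes
sum(stump$frame$var == "<leaf>")
sum(two_level$frame$var == "<leaf>")

# The printed tree lists each node's split rule, size, and majority class
print(two_level)
```

On iris, the root split already separates setosa perfectly, so only the other child node gets split again at depth 2.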
That’s why the “decision tree = multiple if-else statements” analogy is so common. The analogy makes sense to a degree, but the conditional statements are calculated automatically. In simple words, the machine learns the best conditions for your data.
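To see the analogy in code, here is the kind of nested if-else a small tree boils down to. The thresholds are typed in by hand for illustration (they are the two splits rpart finds on the full iris data); the point of the algorithm is that it discovers them for you. classify_iris() is a hypothetical helper, not part of any package.

```r
# Hand-written rules with the same shape as a learned two-level tree
classify_iris <- function(petal_length, petal_width) {
  if (petal_length < 2.45) {
    "setosa"
  } else if (petal_width < 1.75) {
    "versicolor"
  } else {
    "virginica"
  }
}

# Apply the rules row by row and measure how often they match the true species
preds <- mapply(classify_iris, iris$Petal.Length, iris$Petal.Width)
mean(preds == iris$Species)
```

Two hand-written conditions already classify most flowers correctly, which shows how compact a learned tree can be.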
Let’s take a look at the following decision tree representation to drive these points home:
As you can see, the variables Outlook, Humidity, and Windy are used to predict the dependent variable, Play.
You now know the basic theory behind the algorithm, and you’ll learn how to implement it in R next.
Dataset Loading and Preparation
There’s no machine learning without data, and there’s no working with data without libraries. You’ll need these ones to follow along:
library(caTools)
library(rpart)
library(rpart.plot)
library(caret)
library(Boruta)
library(cvms)
library(dplyr)

head(iris)
As you can see, we’ll use the Iris dataset to build our decision tree classifier. This is what the first couple of rows look like (output from the head() function call):
The dataset is familiar to anyone with a week of experience in data science and machine learning, so it doesn’t need further introduction. It’s also as clean as datasets come, which will save us a lot of time in this section.
The only thing we have to do before continuing to predictive modeling is to split this dataset randomly into training and testing subsets. You can use the following code snippet to split it in a 75:25 ratio. sample.split() from caTools also preserves the relative proportions of each species in both subsets:
set.seed(42)
sample_split <- sample.split(Y = iris$Species, SplitRatio = 0.75)
train_set <- subset(x = iris, sample_split == TRUE)
test_set <- subset(x = iris, sample_split == FALSE)
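If you’d rather not depend on caTools, a stratified 75:25 split can be sketched in base R by sampling 75% of the row indices within each species. This is an alternative, not what the rest of the article uses, so the exact row assignments will differ from the caTools split.

```r
set.seed(42)

# Group row indices by species, then sample 75% of each group
rows_by_species <- split(seq_len(nrow(iris)), iris$Species)
train_idx <- unlist(lapply(rows_by_species, function(idx) {
  sample(idx, size = round(0.75 * length(idx)))
}))

train_set <- iris[train_idx, ]
test_set <- iris[-train_idx, ]

# Both subsets should contain every species in the same proportions
table(train_set$Species)
table(test_set$Species)
```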
And that’s it! Let’s start with modeling next.
Predictive Modeling with R Decision Trees
We’re using the rpart library to build the model. The formula syntax is the same as for linear and logistic regression with lm() and glm(): the target variable goes on the left, the features go on the right, and the two are separated with the ~ sign. If you want to use all features, put a dot (.) instead of feature names.
Also, don’t forget to specify method = "class" since we’re dealing with a classification dataset here. rpart() would infer this from a factor target anyway, but being explicit makes the intent clear.
Here’s how to train the model:
model <- rpart(Species ~ ., data = train_set, method = "class")
model
The output of calling model is shown in the following image:
From this image alone, you can see the “rules” the decision tree model uses to make classifications. If you’d like a more visual representation, you can use the rpart.plot package to visualize the tree:
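A minimal sketch of that step, assuming rpart.plot is installed: rpart.plot() takes the fitted model directly. The block refits a model on a simple random 75:25 split (a stand-in for the caTools split, so the tree may differ slightly from the one in the image) and falls back to rpart’s own plot() and text() methods if rpart.plot is missing.

```r
library(rpart)

set.seed(42)
# Simple random 75:25 split, a stand-in for the stratified caTools split
train_idx <- sample(nrow(iris), size = round(0.75 * nrow(iris)))
model <- rpart(Species ~ ., data = iris[train_idx, ], method = "class")

if (requireNamespace("rpart.plot", quietly = TRUE)) {
  # Draw the tree; node labels are chosen automatically for a classification model
  rpart.plot::rpart.plot(model)
} else {
  # Base rpart fallback: draw the branches, then label nodes with per-class counts
  plot(model, uniform = TRUE, margin = 0.1)
  text(model, use.n = TRUE)
}
```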
You can see how many training-set classifications were correct by examining the bottom (leaf) nodes. Setosa was classified correctly every time, versicolor was misclassified as virginica 5% of the time, and virginica was misclassified as versicolor 3% of the time. It’s a simple graph, but you can read everything from it.
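The model above relies on rpart’s defaults for the split criterion and the stopping rules. The sketch below writes those defaults out explicitly (Gini splitting, minsplit = 20, cp = 0.01, maxdepth = 30, per the rpart documentation) to show which knobs you would turn to grow a larger or smaller tree. It again uses a simple random split as a stand-in for the caTools one.

```r
library(rpart)

set.seed(42)
# Simple random 75:25 split, a stand-in for the stratified caTools split
train_idx <- sample(nrow(iris), size = round(0.75 * nrow(iris)))
train_set <- iris[train_idx, ]

model <- rpart(
  Species ~ .,
  data = train_set,
  method = "class",
  parms = list(split = "gini"),  # split = "information" would use entropy instead
  control = rpart.control(
    minsplit = 20,  # a node needs at least 20 rows before a split is attempted
    cp = 0.01,      # skip splits that improve the relative fit by less than 1%
    maxdepth = 30   # maximum depth of any node, counting the root as depth 0
  )
)

# Complexity-parameter table: how the error changes as the tree grows
printcp(model)
```

Lowering cp or minsplit lets the tree grow deeper (at the risk of overfitting); the printcp() table together with prune() helps you cut it back.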
Decision trees are also useful for examining feature importance, that is, how much predictive power lies in each feature. You can use the varImp() function from caret to find out. The following snippet calculates the importances and sorts them in descending order:
importances <- varImp(model)
importances %>% arrange(desc(Overall))
The results are shown in the image below:
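As a cross-check that doesn’t need caret, rpart stores its own importance scores in the variable.importance element of the fitted model. caret’s varImp() computes its score differently, so don’t expect identical numbers. The sketch again uses a simple random split as a stand-in for the caTools one.

```r
library(rpart)

set.seed(42)
# Simple random 75:25 split, a stand-in for the stratified caTools split
train_idx <- sample(nrow(iris), size = round(0.75 * nrow(iris)))
model <- rpart(Species ~ ., data = iris[train_idx, ], method = "class")

# rpart's built-in importance scores, one per feature used in a split or surrogate
imp <- model$variable.importance
sort(imp, decreasing = TRUE)

# Rescaled to sum to 100, which is how summary(model) reports them
round(100 * imp / sum(imp))
```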
If varImp() doesn’t do it for you and you’re looking for something more advanced, look no further than Boruta.
Feature Importances with Boruta
Boruta is a feature ranking and selection algorithm built on top of random forests. It tells you whether each feature in your dataset is relevant for making predictions. There are ways to adjust this “relevance” test, such as tweaking the p-value threshold (the pValue argument) and other parameters, but that’s not something we’ll go over today.
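Boruta’s core idea is to compare each real feature against “shadow” features: shuffled copies that cannot carry any real signal. The sketch below shows that idea in a single pass, swapping Boruta’s random forest for a single rpart tree. It is a caricature for intuition only, not a replacement for the Boruta package, which repeats the comparison over many runs and applies a statistical test.

```r
library(rpart)
set.seed(42)

features <- iris[, 1:4]
# Shadow features: each column shuffled independently, breaking any link to Species
shadows <- as.data.frame(lapply(features, sample))
names(shadows) <- paste0("shadow_", names(features))

augmented <- cbind(features, shadows, Species = iris$Species)
fit <- rpart(Species ~ ., data = augmented, method = "class")

imp <- fit$variable.importance
is_shadow <- grepl("^shadow_", names(imp))
# Best score achieved by pure noise (0 if no shadow feature was used at all)
max_shadow <- max(c(0, imp[is_shadow]))

# Real features that beat the best shadow feature
real <- imp[!is_shadow]
names(real)[real > max_shadow]
```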
A call to the Boruta() function looks just like a call to rpart(), with the additional doTrace parameter for controlling console output (doTrace = 0 keeps it silent). The code snippet below shows you how to find the importances and how to print them sorted in descending order:
library(Boruta)
# Run Boruta without console output
boruta_output <- Boruta(Species ~ ., data = train_set, doTrace = 0)
# Resolve any features still marked "Tentative" into Confirmed or Rejected
rough_fix_mod <- TentativeRoughFix(boruta_output)
boruta_signif <- getSelectedAttributes(rough_fix_mod)
# Keep the non-rejected features and sort them by mean importance
importances <- attStats(rough_fix_mod)
importances <- importances[importances$decision != "Rejected", c("meanImp", "decision")]
importances[order(-importances$meanImp), ]
In case you want to present these results visually, the package has you covered:
plot(boruta_output, cex.axis = 0.7, las = 2, xlab = "", main = "Feature importance")
Look for the green box plots: green means the feature was confirmed as important. Red would mark a rejected feature and yellow a tentative one, while the blue box plots are Boruta’s shadow attributes (shuffled copies of the real features used as a reference), so you can ignore them. The higher a box plot sits on the Y-axis, the more important the feature. It’s that easy!
You’ve built and explored the model so far, but there’s no use in it yet. The next section shows you how to make predictions on previously unseen data and evaluate the model.
Generating Predictions
Predicting new instances is now a trivial task. All you have to do is use the predict() function and pass in the testing subset. Also, make sure to specify type = "class" for everything to work correctly. Here’s an example:
preds <- predict(model, newdata = test_set, type = "class")
preds
The results are shown in the following image:
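predict() can also return class probabilities instead of labels via type = "prob". The sketch below (on a simple random split standing in for the caTools one) shows that picking the most probable column for each row should reproduce the labels from type = "class" under rpart’s default priors and loss.

```r
library(rpart)

set.seed(42)
# Simple random 75:25 split, a stand-in for the stratified caTools split
train_idx <- sample(nrow(iris), size = round(0.75 * nrow(iris)))
train_set <- iris[train_idx, ]
test_set <- iris[-train_idx, ]
model <- rpart(Species ~ ., data = train_set, method = "class")

# One column of class probabilities per species, one row per test flower
probs <- predict(model, newdata = test_set, type = "prob")
head(probs)

# Take the most probable class per row and compare with type = "class"
labels_from_probs <- colnames(probs)[max.col(probs, ties.method = "first")]
preds <- predict(model, newdata = test_set, type = "class")
mean(labels_from_probs == as.character(preds))
```

Probabilities are handy when you want to rank flowers by confidence or pick a decision threshold other than “most probable class”.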
But how good are these predictions? Let’s evaluate. The confusion matrix is one of the most common tools for evaluating classification models. In R, caret’s confusionMatrix() also reports other metrics alongside it, such as accuracy, sensitivity, and specificity.
Here’s how you can print the confusion matrix:
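With caret, the call takes the predictions as data and the true labels as reference: confusionMatrix(data = preds, reference = test_set$Species). The sketch below builds the same core table in base R, so the counts and overall accuracy are available even without caret. It uses a simple random split as a stand-in for the caTools one, so its numbers will differ from the results below.

```r
library(rpart)

set.seed(42)
# Simple random 75:25 split, a stand-in for the stratified caTools split
train_idx <- sample(nrow(iris), size = round(0.75 * nrow(iris)))
train_set <- iris[train_idx, ]
test_set <- iris[-train_idx, ]

model <- rpart(Species ~ ., data = train_set, method = "class")
preds <- predict(model, newdata = test_set, type = "class")

# Rows are predicted classes, columns are true classes (the same layout caret uses)
cm_table <- table(Prediction = preds, Reference = test_set$Species)
cm_table

# Overall accuracy: correct predictions sit on the diagonal
sum(diag(cm_table)) / sum(cm_table)
```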
And here are the results:
As you can see, there are some misclassifications in the versicolor and virginica classes, similar to what we saw on the training set. Overall, the model’s accuracy is just short of 90%, which is more than acceptable for a simple decision tree classifier.
But let’s be honest: the amount of detail in the previous image is overwhelming. What if you want to display only the confusion matrix, and display it visually as a heatmap? That’s where the cvms package comes in. Its plot_confusion_matrix() function draws a confusion matrix stored as a tibble, which is just what we need.
Keep in mind the parameters of the plot_confusion_matrix() function. They are all intuitive, and the column names they refer to come from cfm, so yours might be different:
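library(cvms)
# caret's confusionMatrix() takes the predictions as data and the true labels as reference
cm <- confusionMatrix(data = preds, reference = test_set$Species)
# as_tibble() (re-exported by dplyr) reshapes the table into Prediction, Reference, and n columns
cfm <- as_tibble(cm$table)
plot_confusion_matrix(cfm, target_col = "Reference", prediction_col = "Prediction", counts_col = "n")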
Much better, isn’t it? Now you have something to present. Let’s wrap things up in the following section.
Summary of R Decision Trees
Decision trees are an excellent introduction to the whole family of tree-based algorithms. A single decision tree is commonly used as a baseline model that more sophisticated tree-based algorithms (such as random forests and gradient boosting) need to outperform.
Today you’ve learned the basic logic and intuition behind decision trees, and how to implement and evaluate the algorithm in R. You can expect the whole suite of tree-based algorithms to be covered soon, so stay tuned to the Appsilon blog if you want to learn more.
If you want to implement machine learning in your organization, you can always reach out to Appsilon for help.