Decision trees are among the most fundamental algorithms in supervised machine learning, used to handle both regression and classification tasks. In a nutshell, you can think of them as a glorified collection of if-else statements, but more on that later.
Interested in more basic machine learning guides? Check our detailed guide on Logistic Regression with R.
Today you’ll learn the basic theory behind the decision tree algorithm and how to implement it in R.
Decision trees are intuitive. All they do is ask questions, such as “Is the gender male?” or “Is the value of a particular variable higher than some threshold?” Based on the answers, either more questions are asked or a classification is made. Simple!
To predict class labels, the decision tree starts at the root node. Choosing which attribute the root node should split on is straightforward: it boils down to finding the attribute (and split point) that best separates the training records. For classification trees, this is commonly measured with Gini impurity, which is also the default splitting criterion in rpart. It’s simple math, but it gets tedious to do by hand when you have many attributes.
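To make the Gini calculation concrete, here’s a minimal base-R sketch. For a node whose records fall into classes with proportions p_k, the Gini impurity is 1 − Σ p_k². A candidate split is scored by the size-weighted impurity of its two child nodes, and the best split is the one with the lowest score. The helper names gini(), split_gini() and best_threshold() are hypothetical and exist only for this illustration. rpart performs this search, and more, for you.

```r
# Hypothetical helpers illustrating Gini impurity (not part of rpart's API)

# Gini impurity of a set of class labels: 1 - sum(p_k^2)
gini <- function(labels) {
  p <- table(labels) / length(labels)
  1 - sum(p^2)
}

# Size-weighted Gini impurity of the two child nodes produced by a logical split
split_gini <- function(labels, goes_left) {
  n <- length(labels)
  left <- labels[goes_left]
  right <- labels[!goes_left]
  (length(left) / n) * gini(left) + (length(right) / n) * gini(right)
}

# Scan midpoints between consecutive unique values of a numeric feature
# and return the threshold with the lowest weighted impurity
best_threshold <- function(x, labels) {
  values <- sort(unique(x))
  candidates <- (head(values, -1) + tail(values, -1)) / 2
  scores <- vapply(candidates, function(t) split_gini(labels, x < t), numeric(1))
  c(threshold = candidates[which.min(scores)], gini = min(scores))
}

gini(iris$Species)                                  # about 0.667: three equally sized classes
split_gini(iris$Species, iris$Petal.Length < 2.45)  # about 0.333: setosa split off perfectly
best_threshold(iris$Petal.Length, iris$Species)
```

Splitting the full iris data at Petal.Length < 2.45 separates setosa perfectly and halves the impurity. That is why this condition tends to appear at the root of iris trees.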
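After determining the root node, the tree “branches out”: each resulting subset is split again on whichever attribute best reduces the impurity that remains. This continues until the nodes are pure enough or a stopping rule kicks in.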
That’s why decision trees are often compared to a set of nested if-else statements. The analogy holds to a degree, but nobody writes the conditions by hand; they’re calculated automatically. In simple words, the machine learns the best conditions for your data.
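To make the analogy concrete, here’s what a small tree looks like once it’s written out by hand. classify_iris() is a hypothetical function, and its thresholds are illustrative. They mirror the kind of splits rpart finds on the iris data, but in practice you’d let the algorithm learn them instead of hard-coding them.

```r
# Hypothetical hand-written equivalent of a small classification tree for iris.
# The thresholds are illustrative; a real tree learns them from the data.
classify_iris <- function(petal_length, petal_width) {
  if (petal_length < 2.45) {
    "setosa"                      # first (root) question
  } else if (petal_width < 1.75) {
    "versicolor"                  # second question, asked only of the remaining records
  } else {
    "virginica"
  }
}

classify_iris(1.4, 0.2)  # "setosa"
classify_iris(4.5, 1.5)  # "versicolor"
classify_iris(5.8, 2.2)  # "virginica"
```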
Let’s take a look at the following decision tree representation to drive these points home:
As you can see, the variables Outlook, Humidity, and Windy are used to predict the dependent variable, Play.
You now know the basic theory behind the algorithm, and you’ll learn how to implement it in R next.
There’s no machine learning without data, and there’s no working with data without libraries. You’ll need these to follow along:
As you can see, we’ll use the Iris dataset to build our decision tree classifier. This is what the first couple of rows look like (output from the head() function call):
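As a minimal sketch, loading the modeling package and taking a first look at the data could look like this. It assumes only rpart, which ships with standard R installations. rpart.plot and caret are used later on and are loaded the same way with library() once installed.

```r
library(rpart)        # recursive partitioning (decision tree) models

data(iris)            # built-in dataset: 150 flowers, 4 numeric features + Species
head(iris)            # first six rows
str(iris)             # column types; Species is a factor with 3 levels
table(iris$Species)   # number of records per class
```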
The dataset is familiar to anyone with a week of experience in data science and machine learning, so it doesn’t need further introduction. It’s also as clean as datasets come, which will save us a lot of time in this section.
The only thing left to do before predictive modeling is to split the dataset randomly into training and testing subsets. You can use the following code snippet to do a 75:25 split:
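A minimal base-R sketch of the 75:25 split is shown below. It uses sample() instead of a dedicated splitting package. The seed value (42) is an arbitrary assumption, so your exact rows, and therefore the results later on, will differ from the images in this article.

```r
set.seed(42)  # arbitrary seed so the random split is reproducible

n <- nrow(iris)
train_idx <- sample(seq_len(n), size = floor(0.75 * n))  # 112 of 150 row indices

train <- iris[train_idx, ]   # 75% for training
test  <- iris[-train_idx, ]  # remaining 25% for testing

c(train = nrow(train), test = nrow(test))
```

A plain random split doesn’t guarantee equal class proportions in both subsets. If that matters, a stratified helper such as caTools::sample.split() is one option.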
And that’s it! Let’s start with modeling next.
We’re using the rpart library to build the model. The syntax is the same as for linear and logistic regression: put the target variable on the left and the features on the right, separated by the ~ sign. If you want to use all features, put a dot (.) instead of feature names.
Also, don’t forget to specify method = "class", since we’re dealing with a classification dataset here.
Here’s how to train the model:
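Here’s a minimal sketch of the training step. It assumes a base-R split with an arbitrary seed, repeated here so the snippet runs on its own. The rpart() call uses the formula interface exactly as described above.

```r
library(rpart)

set.seed(42)  # arbitrary seed for a reproducible 75:25 split
train_idx <- sample(seq_len(nrow(iris)), size = floor(0.75 * nrow(iris)))
train <- iris[train_idx, ]
test  <- iris[-train_idx, ]

# Species is the target; "." means "use every other column as a feature"
model <- rpart(Species ~ ., data = train, method = "class")

model  # prints the learned splits ("rules") node by node
```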
The output of calling model is shown in the following image:
From this output alone, you can see the “rules” the decision tree model uses to make classifications. If you’d like a more visual representation, you can use the rpart.plot package to visualize the tree:
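A minimal plotting sketch follows. It calls rpart.plot() from the rpart.plot package when that package is installed. Otherwise it falls back to the plot() and text() methods that come with rpart itself, which draw a plainer version of the same tree. The split and seed are assumptions.

```r
library(rpart)

set.seed(42)  # arbitrary seed for a reproducible 75:25 split
train_idx <- sample(seq_len(nrow(iris)), size = floor(0.75 * nrow(iris)))
train <- iris[train_idx, ]
model <- rpart(Species ~ ., data = train, method = "class")

if (requireNamespace("rpart.plot", quietly = TRUE)) {
  # Default boxes for a multi-class tree show the predicted class,
  # per-class proportions, and the percentage of records in each node
  rpart.plot::rpart.plot(model)
} else {
  # Base rpart fallback: draw the branches, then label splits and leaves
  plot(model, margin = 0.1)
  text(model, use.n = TRUE)  # use.n adds per-class counts at each leaf
}
```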
You can see how many classifications were correct (on the training set) by examining the bottom (leaf) nodes. Setosa was classified correctly every time, versicolor was misclassified as virginica 5% of the time, and virginica was misclassified as versicolor 3% of the time. It’s a simple graph, but you can read everything from it.
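If you’d rather check the training-set numbers in code than read them off the plot, one minimal sketch is to predict on the training data and cross-tabulate the results with base R’s table(). The split and seed are assumptions, so the counts won’t exactly match the figure.

```r
library(rpart)

set.seed(42)  # arbitrary seed for a reproducible 75:25 split
train_idx <- sample(seq_len(nrow(iris)), size = floor(0.75 * nrow(iris)))
train <- iris[train_idx, ]
model <- rpart(Species ~ ., data = train, method = "class")

# Predicted vs. actual labels on the data the tree was trained on
train_pred <- predict(model, newdata = train, type = "class")
train_cm <- table(Predicted = train_pred, Actual = train$Species)
train_cm

# Share of each actual class assigned to each predicted class (columns sum to 1)
round(prop.table(train_cm, margin = 2), 2)
```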
Decision trees are also useful for examining feature importance, that is, how much predictive power lies in each feature. You can use the varImp() function to find out. The following snippet calculates the importances and sorts them in descending order:
The results are shown in the image below:
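If you don’t want to depend on caret for this step, one minimal sketch is to read the importances that rpart already stores on the fitted model in variable.importance. The numbers may not match varImp() exactly, because the two functions aggregate split improvements differently (rpart’s version also credits surrogate splits). The split and seed are assumptions.

```r
library(rpart)

set.seed(42)  # arbitrary seed for a reproducible 75:25 split
train_idx <- sample(seq_len(nrow(iris)), size = floor(0.75 * nrow(iris)))
train <- iris[train_idx, ]
model <- rpart(Species ~ ., data = train, method = "class")

# rpart stores a named importance vector on the fitted model;
# sort it explicitly so the most important feature comes first
importance <- sort(model$variable.importance, decreasing = TRUE)
round(importance, 2)

# Rescaled to percentages of the total, which is often easier to compare
round(100 * importance / sum(importance), 1)
```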
You’ve built and explored the model, but it isn’t of much use yet. The next section shows you how to make predictions on previously unseen data and evaluate the model.
Predicting new instances is now a trivial task. All you have to do is call the predict() function and pass in the testing subset. Make sure to specify type = "class" to get predicted labels; otherwise, predict() returns a matrix of class probabilities. Here’s an example:
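Here’s a minimal sketch of the prediction step. It repeats an assumed base-R split and seed so it runs on its own, and it shows both kinds of output that predict() can return for a classification tree.

```r
library(rpart)

set.seed(42)  # arbitrary seed for a reproducible 75:25 split
train_idx <- sample(seq_len(nrow(iris)), size = floor(0.75 * nrow(iris)))
train <- iris[train_idx, ]
test  <- iris[-train_idx, ]
model <- rpart(Species ~ ., data = train, method = "class")

# type = "class" returns a factor of predicted labels, one per test row
preds <- predict(model, newdata = test, type = "class")
head(preds)

# Without it, a classification tree returns a matrix of class probabilities
probs <- predict(model, newdata = test)
head(round(probs, 2))
```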
The results are shown in the following image:
But how good are these predictions? Let’s evaluate. The confusion matrix is one of the most common tools for evaluating classification models: it cross-tabulates predicted classes against actual classes. In R, functions such as caret’s confusionMatrix() also report related metrics, such as accuracy, sensitivity, and specificity.
Here’s how you can print the confusion matrix:
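Here’s a minimal sketch that uses base R’s table(), so it runs without extra packages. It also works out accuracy, per-class sensitivity and per-class specificity by hand (each class versus the rest), which are among the statistics caret’s confusionMatrix() reports. The split and seed are assumptions, so your numbers will differ from the results shown.

```r
library(rpart)

set.seed(42)  # arbitrary seed for a reproducible 75:25 split
train_idx <- sample(seq_len(nrow(iris)), size = floor(0.75 * nrow(iris)))
train <- iris[train_idx, ]
test  <- iris[-train_idx, ]
model <- rpart(Species ~ ., data = train, method = "class")
preds <- predict(model, newdata = test, type = "class")

# Rows: predicted class, columns: actual class
cm <- table(Predicted = preds, Actual = test$Species)
cm

# Accuracy: correctly classified records (the diagonal) over all records
accuracy <- sum(diag(cm)) / sum(cm)

# Per-class sensitivity (recall): correct predictions / actual records of that class
sensitivity <- diag(cm) / colSums(cm)

# Per-class specificity: correct "not this class" calls / actual records of other classes
specificity <- sapply(seq_len(nrow(cm)), function(k) sum(cm[-k, -k]) / sum(cm[, -k]))
names(specificity) <- colnames(cm)

round(c(accuracy = accuracy), 3)
round(rbind(sensitivity, specificity), 3)

# With caret installed, caret::confusionMatrix(preds, test$Species)
# reports these statistics (and more) in one call.
```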
And here are the results:
As you can see, there are some misclassifications in the versicolor and virginica classes, similar to what we saw on the training set. Overall, the model’s accuracy is just short of 90%, which is more than acceptable for a simple decision tree classifier.
Decision trees are an excellent introduction to the whole family of tree-based algorithms. They’re commonly used as a baseline model that more sophisticated tree-based algorithms (such as random forests and gradient boosting) need to outperform.
Today you’ve learned the basic logic and intuition behind decision trees, and how to implement and evaluate the algorithm in R. The rest of the tree-based family will be covered soon, so stay tuned to the Appsilon blog if you want to learn more.
If you want to implement machine learning in your organization, you can always reach out to Appsilon for help.