In this tutorial, we’ll explain the decision tree algorithm/model in machine learning.
Decision trees are powerful yet easy to implement and visualize. They’re widely used in machine learning for both supervised classification and regression problems.
Within this tutorial, you’ll learn:
- What decision tree models/algorithms are in machine learning.
- How the popular CART algorithm works, step-by-step, including splitting (impurity, information gain), stop conditions, and pruning.
- How to create a predictive decision tree model in Python scikit-learn with an example.
- The advantages and disadvantages of decision trees.
- And other tips.
If you want to apply machine learning and present easily interpretable results, the decision tree model could be a good option.
Let’s jump in!
Before beginning our decision tree algorithm tutorial, if you are not familiar with machine learning algorithms, please take a look at Machine Learning for Beginners: Overview of Algorithm Types.
If you are new to Python, please take our FREE Python crash course for data science to get a good foundation.
- What are Decision Tree models/algorithms in Machine Learning?
- Classification And Regression Tree (CART) Algorithm
- Recursive Binary Partitions
- Greedy Algorithm
- Cost/Impurity Function
- Information Gain
- Stop Condition
- Tree Pruning
- Classification Tree Example: Step-by-Step
- Step #1: Set up the training dataset based on the tasks.
- Step #2: Go through each feature and the possible splits.
- Step #3: Based on the impurity measures, choose the single best split.
- Step #4: Partition using the best splits recursively until the stopping condition is met.
- Step #5: Prune the decision tree.
- Decision Tree model Advantages and Disadvantages
- Python Example: sklearn DecisionTreeClassifier
What are Decision Tree models/algorithms in Machine Learning?
Decision trees are a non-parametric supervised learning algorithm for both classification and regression tasks. The algorithm creates decision tree models that predict the target variable based on a set of features/input variables. Whether the target variable takes a discrete set of values or continuous values determines whether it’s a classification or a regression tree. It’s one of the most popular machine learning algorithms.
The decision tree models built by decision tree algorithms consist of nodes in a tree-like structure. The tree starts at the root node, which holds the entire training dataset, and grows down into branches of internal nodes through a splitting process. Each internal node holds a decision rule that determines which path to take next. For each observation, the prediction of the output or decision is made at the terminal node (leaf) it ends up in.
Given that we all have made countless decisions in our lives, the idea of decision trees should be intuitive.
For example, imagine we are deciding on today’s dinner. We would think logically and sequentially, checking the options based on our situation. What we’ll have for dinner depends on different “variables” such as our feelings, the weather, and the ingredients in the fridge.
So we could have a decision tree like the one below.
If someone collected all our past data on dinner decisions, they could build a decision tree model to predict our choices!
Now with the basic understanding of decision trees, let’s dive into more details.
Classification And Regression Tree (CART) Algorithm
There are a couple of methods for implementing a decision tree in machine learning. We’re going to focus on the Classification And Regression Tree (CART), which was introduced by Breiman et al. in 1984. It’s also the method used by the Python scikit-learn (sklearn) library.
Before looking at the algorithm in steps with an example, let’s cover some background knowledge.
The critical procedure for growing a tree is splitting, which is partitioning the dataset into subsets.
How to split the dataset based on features?
Recursive Binary Partitions
Assume we have a target variable Y and two features X1 and X2. To make the resulting tree easy to interpret, we use a method called recursive binary partitions.
Let’s see an example.
As shown in the charts below, we can first binary split at X1 = t1 as either X1 <= t1 or X1 > t1, then we can binary split these two partitions further with X2 = t2 and X1 = t3. Finally, the partition of X1 > t3 is further split by X2 = t4. Each of these resulting five regions R1, …, R5 is easy to represent in the tree structure.
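As a small illustration, the partition described above can be written as nested binary decisions. This is a hypothetical sketch: the thresholds t1 to t4 are placeholders, and which of the five regions gets which label R1 to R5 is our own assumption, since that depends on how the charts number them.

```python
def region(x1, x2, t1, t2, t3, t4):
    """Return the region (leaf) that the point (x1, x2) falls into.

    Each `if` is one binary split, and the nesting mirrors the recursive partition.
    The R1..R5 labels are an illustrative assumption.
    """
    if x1 <= t1:                           # first split: X1 at t1
        return "R1" if x2 <= t2 else "R2"  # left partition split on X2 at t2
    if x1 <= t3:                           # right partition split on X1 at t3
        return "R3"
    return "R4" if x2 <= t4 else "R5"      # X1 > t3 split further on X2 at t4


cut = dict(t1=0.3, t2=0.5, t3=0.7, t4=0.4)  # hypothetical thresholds
print(region(0.1, 0.2, **cut))  # R1
print(region(0.9, 0.9, **cut))  # R5
```

Because every point follows exactly one path of decisions, the five regions never overlap, which is what makes the tree easy to read.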
Now that we know the splitting method, what are the criteria for finding the best pair of splitting feature and split threshold (X and t)?
How does the tree algorithm determine the best split?
The decision tree algorithm finds the best split by minimizing a cost function (node impurity measure) using a so-called greedy algorithm. It’s “greedy” because, at each node, it evaluates every possible split and picks the one that is best at that step (the best “local” split), without looking ahead to later splits. Due to this greedy nature, the algorithm might not return the globally optimal tree.
So what’s the cost function or node impurity measure?
At each node of the decision tree, we want the resulting nodes/leaves to be as homogeneous as possible, i.e., each branch should contain samples with similar target values. The cost function measures the variance/impurity of the split with respect to the target, and it’s different for regression and classification trees.
Let’s look at the most commonly used impurity measures.
For regression trees, we could use a simple sum of squared errors, which is commonly used for predictive models:
sum_over_observation ( y – y_hat )²
where y is the actual output value, and y_hat is the predicted output of the partition the observation falls into, i.e., the mean of the target values in that partition.
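A minimal sketch of this cost for one candidate split, assuming each partition predicts the mean of its own target values. The function names and toy target values are made up for illustration.

```python
def sse(values):
    """Sum of squared errors when every value is predicted by the partition mean."""
    if not values:
        return 0.0
    mean = sum(values) / len(values)
    return sum((y - mean) ** 2 for y in values)


def split_sse(left, right):
    """Cost of a binary split: SSE of the left partition plus SSE of the right one."""
    return sse(left) + sse(right)


y = [1.0, 2.0, 3.0, 10.0, 11.0, 12.0]  # toy targets, already ordered by some feature
print(sse(y))                   # 125.5: cost with no split
print(split_sse(y[:3], y[3:]))  # 4.0: each side is very homogeneous
```

The split that separates the small values from the large ones cuts the cost from 125.5 to 4.0, which is exactly what the regression tree is looking for.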
For classification trees, we could use either the Gini index or cross-entropy/deviance to grow the trees. Both measure node impurity.
Gini = sum_over_k (pk*(1-pk))
Cross-Entropy = – sum_over_k (pk*log(pk))
where pk = (# of observations in class k in the node)/(# of observations in the same node), i.e., the proportion of class k observations in a particular node.
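A minimal Python sketch of both measures, computing pk directly from a node’s class labels (the helper names are our own). It uses the natural log for cross-entropy, and the base only rescales the measure. A pure node gets a Gini of 0, and a 50-50 binary node gets the maximum Gini of 0.5.

```python
import math
from collections import Counter


def class_proportions(labels):
    """p_k for every class present in the node."""
    n = len(labels)
    return [count / n for count in Counter(labels).values()]


def gini(labels):
    """Gini index: sum over classes of p_k * (1 - p_k)."""
    return sum(p * (1 - p) for p in class_proportions(labels))


def cross_entropy(labels):
    """Cross-entropy (deviance): -sum over classes of p_k * log(p_k).

    Only classes present in the node are included, so p_k > 0 and log is defined;
    absent classes would contribute 0 anyway.
    """
    return -sum(p * math.log(p) for p in class_proportions(labels))


print(gini([1, 1, 1, 1]))           # 0.0 (pure node)
print(gini([0, 1, 0, 1]))           # 0.5 (50-50 binary node, the maximum)
print(cross_entropy([0, 1, 0, 1]))  # about 0.693, i.e. ln 2
```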
It might seem confusing. But we’ll be showing a calculation below. Keep reading to find out!
How do they measure the impurity?
Let’s use the Gini index as an example. A perfectly “pure” partition contains samples/observations of a single class, in which case each pk is either 1 or 0 and Gini = 0. In contrast, the most “impure” partition for a binary target has a 50-50 split of classes, so pk = 0.5 for both classes and Gini reaches its maximum value of 0.5*0.5 + 0.5*0.5 = 0.5.
Equivalently, instead of minimizing these impurity measures, we can maximize the information gain, which is the impurity of the parent node minus the weighted average impurity of its child nodes (each child weighted by its share of the parent’s observations). The best split is then the feature and threshold that yield the largest information gain at the node.
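Information gain can be sketched the same way. The function below is illustrative (the names are our own). It defaults to the Gini index and weights each child by its share of the parent’s observations.

```python
from collections import Counter


def gini(labels):
    """Gini index of a node, computed from its class labels."""
    n = len(labels)
    return sum((c / n) * (1 - c / n) for c in Counter(labels).values())


def information_gain(parent, left, right, impurity=gini):
    """Parent impurity minus the size-weighted average impurity of the two children."""
    n = len(parent)
    children = len(left) / n * impurity(left) + len(right) / n * impurity(right)
    return impurity(parent) - children


parent = [0, 0, 1, 1]
print(information_gain(parent, [0, 0], [1, 1]))  # 0.5: both children are pure
print(information_gain(parent, [0, 1], [0, 1]))  # 0.0: the split separates nothing
```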
Once we’ve found the best split, the process is the same for the next one: we split each of the two resulting partitions further, and so on.
That’s how a decision tree model is grown!
How do we know when to stop?
There are many different stop criteria to use. The popular ones are setting:
- the minimum observations within each node or
- the maximum levels (depths) of the tree or
- the minimum decrease of the cost function due to the split.
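Here is a from-scratch sketch that makes the greedy, recursive process and these stop criteria concrete. It is not scikit-learn’s implementation and makes some simplifying assumptions: numeric features only, Gini impurity, splits of the form x <= t, and min_impurity_decrease compared against the node’s unweighted impurity drop (scikit-learn also weights this drop by the node’s share of samples). The function names are hypothetical.

```python
from collections import Counter


def gini(labels):
    """Gini impurity: sum over classes of p_k * (1 - p_k)."""
    n = len(labels)
    return sum((c / n) * (1 - c / n) for c in Counter(labels).values())


def best_split(X, y):
    """Greedy search over every feature and threshold for the lowest weighted Gini."""
    n, best = len(y), None  # best = (weighted_gini, feature_index, threshold)
    for j in range(len(X[0])):
        # Skip the largest value: splitting there would leave the right side empty.
        for t in sorted({row[j] for row in X})[:-1]:
            left = [lab for row, lab in zip(X, y) if row[j] <= t]
            right = [lab for row, lab in zip(X, y) if row[j] > t]
            score = len(left) / n * gini(left) + len(right) / n * gini(right)
            if best is None or score < best[0]:
                best = (score, j, t)
    return best


def grow(X, y, depth=0, max_depth=3, min_samples_split=2, min_impurity_decrease=0.0):
    """Recursively split (X, y): a dict is an internal node, a class label is a leaf."""
    leaf = Counter(y).most_common(1)[0][0]  # majority class
    # Stop: maximum depth reached, too few samples to split, or node already pure.
    if depth >= max_depth or len(y) < min_samples_split or gini(y) == 0.0:
        return leaf
    split = best_split(X, y)
    # Stop: no valid split, or the impurity drop is below the threshold.
    if split is None or gini(y) - split[0] < min_impurity_decrease:
        return leaf
    _, j, t = split
    settings = (max_depth, min_samples_split, min_impurity_decrease)
    left = [(row, lab) for row, lab in zip(X, y) if row[j] <= t]
    right = [(row, lab) for row, lab in zip(X, y) if row[j] > t]
    return {
        "feature": j,
        "threshold": t,
        "left": grow([r for r, _ in left], [lab for _, lab in left], depth + 1, *settings),
        "right": grow([r for r, _ in right], [lab for _, lab in right], depth + 1, *settings),
    }


def predict(tree, row):
    """Follow the binary decisions from the root down to a leaf."""
    while isinstance(tree, dict):
        tree = tree["left"] if row[tree["feature"]] <= tree["threshold"] else tree["right"]
    return tree


X = [[1.0], [2.0], [3.0], [10.0], [11.0], [12.0]]  # toy data: one numeric feature
y = [0, 0, 0, 1, 1, 1]
tree = grow(X, y, max_depth=2)
print(tree)                  # {'feature': 0, 'threshold': 3.0, 'left': 0, 'right': 1}
print(predict(tree, [9.0]))  # 1
```

Both children of the first split are already pure, so growth stops there even though max_depth would allow one more level.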
After building a tree model, is there anything else we need to consider?
Decision trees tend to overfit: the model fits the particular training sample too closely and doesn’t give good predictions on new data.
To avoid overfitting and increase prediction accuracy, the fully grown tree needs to be pruned by removing some of the nodes. One way of doing this is called minimal cost-complexity pruning. Within this algorithm, we try to find the subtree of the original tree that minimizes the following equation:
R_alpha(T) = R(T) + alpha*|T|
- alpha: the complexity parameter.
- T: a pruned subtree of the original tree.
- |T|: the number of terminal nodes in T.
- R(T): traditionally the misclassification error of the tree T, but sklearn defines it as the total sample-weighted impurity of the terminal nodes.
We won’t go into details, but the higher the value of alpha, the more heavily alpha*|T| penalizes larger trees with more terminal nodes, and the smaller the pruned tree. The value of alpha can be tuned by cross-validation.
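A quick way to see this effect, assuming scikit-learn and its bundled breast cancer dataset: get the sequence of effective alphas for the fully grown tree with cost_complexity_pruning_path, refit with some of them as ccp_alpha, and count the terminal nodes |T|.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)

# Effective alphas at which subtrees of the fully grown tree get pruned away.
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X, y)

leaf_counts = []
for alpha in path.ccp_alphas[::4]:  # every 4th alpha keeps the output short
    # Clip at 0 as a precaution: ccp_alpha must be non-negative.
    tree = DecisionTreeClassifier(random_state=0, ccp_alpha=max(alpha, 0.0)).fit(X, y)
    leaf_counts.append(tree.get_n_leaves())
    print(f"alpha = {alpha:.4f}  ->  |T| = {tree.get_n_leaves()}")
```

As alpha grows, the number of terminal nodes can only stay the same or shrink, because each larger alpha prunes a subtree of the previous one.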
That’s a lot of theory. Let’s see an example with calculations.
Classification Tree Example: Step-by-Step
The CART algorithm we just covered can be summarized in five main steps. Let’s go through them with a simple example, along with tips about Python sklearn.
Don’t worry about doing the detailed calculations. Python will handle those for us when we are building decision trees.
Step #1: Set up the training dataset based on the tasks.
To begin the analysis, we must identify the features (input variables) X and the target (output variable) y.
For example, say we have the dataset below, and we want to predict fraudulent transactions, i.e., classify the transactions as fraud or legit.
| Transaction Amount | Hour of Day | Merchant | Is Fraud |
What is the target?
Is Fraud is our target. It’s a binary variable, which makes this a classification problem. The other three variables, Transaction Amount, Hour of Day, and Merchant, are the features.
Note: this is a simplified fraud detection dataset for demonstration. In reality, there would be a lot more features.
Step #2: Go through each feature and the possible splits.
Recall that there are two main data types: numerical and categorical. Categorical data can be further divided into nominal and ordinal, while numerical data can be further divided into discrete and continuous. So let’s treat each feature as one of four types: continuous numerical, discrete numerical, nominal categorical, or ordinal categorical. Due to their nature, each type has a different number of possible splits.
Each partition resulting from a split must contain at least one observation. We’ll ignore the splitting criteria for now and instead enumerate all the possible splits for each feature.
Let’s look at their possible number of splits one-by-one.
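- A continuous numerical feature has n – 1 possible splits, where n is the number of observations (assuming all values are distinct; with repeated values, it’s the number of distinct values minus one).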
For example, Transaction Amount is a continuous numerical feature. There are 7 observations, which means 6 possible splits:
– split 1: Transaction Amount < 23.54, Transaction Amount >= 23.54
– split 2: Transaction Amount < 251.46, Transaction Amount >= 251.46
– split 3: Transaction Amount < 425.87, Transaction Amount >= 425.87
– split 4: Transaction Amount < 526.01, Transaction Amount >= 526.01
– split 5: Transaction Amount < 827.45, Transaction Amount >= 827.45
– split 6: Transaction Amount < 1030.98, Transaction Amount >= 1030.98
- A discrete numerical or ordinal categorical feature has k – 1 possible splits, where k is the number of distinct values.
For example, Hour of Day has distinct values 10, 11, 12, which means k = 3 and there are 2 possible splits:
– split 1: Hour of Day < 11, Hour of Day >= 11
– split 2: Hour of Day < 12, Hour of Day >= 12
- A nominal categorical feature has 2^(k-1) – 1 possible splits, where k is the number of distinct values or categories.
For example, Merchant is a nominal categorical feature with k = 4, so there are 2^(4-1) – 1 = 7 possible splits:
– split 1: Merchant = [Restaurant], Merchant = [Gas Station, Online Store, Physical Store]
– split 2: Merchant = [Gas Station], Merchant = [Restaurant, Online Store, Physical Store]
– split 3: Merchant = [Online Store], Merchant = [Restaurant, Gas Station, Physical Store]
– split 4: Merchant = [Physical Store], Merchant = [Restaurant, Gas Station, Online Store]
– split 5: Merchant = [Restaurant, Gas Station], Merchant = [Online Store, Physical Store]
– split 6: Merchant = [Restaurant, Online Store], Merchant = [Gas Station, Physical Store]
– split 7: Merchant = [Restaurant, Physical Store], Merchant = [Gas Station, Online Store]
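These counting rules can be checked by enumerating the splits. In this minimal sketch (the function names are our own), numeric and ordinal features use each distinct value except the smallest as an “x < t” threshold. Nominal features are divided into two non-empty groups of categories.

```python
from itertools import combinations


def numeric_splits(values):
    """Thresholds t for 'x < t' vs. 'x >= t': k - 1 splits for k distinct values."""
    distinct = sorted(set(values))
    # Every distinct value except the smallest can be a threshold,
    # so both sides of each split keep at least one observation.
    return distinct[1:]


def nominal_splits(categories):
    """All divisions of k categories into two non-empty groups: 2^(k-1) - 1 splits."""
    cats = sorted(set(categories))
    first, rest = cats[0], cats[1:]
    splits = []
    # Keep the first category on the left so each {left, right} pair appears once
    # (swapping sides gives the same split). Group sizes stop short of len(rest)
    # so the right group is never empty.
    for size in range(len(rest)):
        for combo in combinations(rest, size):
            left = [first, *combo]
            right = [c for c in cats if c not in left]
            splits.append((left, right))
    return splits


print(numeric_splits([10, 11, 12]))    # [11, 12], i.e. Hour of Day < 11 and < 12
merchants = ["Restaurant", "Gas Station", "Online Store", "Physical Store"]
print(len(nominal_splits(merchants)))  # 7
```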
As you can see, for nominal categorical features, the number of possible splits (2^(k-1) – 1) grows exponentially with the number of distinct values k. It’s more challenging to deal with this type of data.
Let’s say we have a categorical feature with 1000 distinct values, and each distinct value has 100 observations.
So the total number of observations in the dataset is 1000*100 = 100,000. This is not a large dataset and should fit into the memory of a modern laptop. However, the number of possible splits for the categorical feature is 2^(1000-1) – 1, roughly 5 × 10^300, which is far too many to evaluate in any reasonable computing time.
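Fortunately, for classification problems with a binary target (such as fraud or legit) and for regression problems, we don’t need to look through all the possible splits. There is an efficient shortcut: order the categories by the proportion of observations in class 1 (or, for regression, by the mean of the target), then split the feature as if it were ordinal, which leaves only k – 1 candidate splits. This gives the optimal split among all 2^(k-1) – 1 possibilities, in terms of Gini or cross-entropy for a binary target and squared error for regression (see details on page 310 of The Elements of Statistical Learning).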
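Here is a minimal sketch of that shortcut. It assumes a 0/1 target (or a numeric target for regression) and uses hypothetical toy data. The function name is our own, and it only produces the k – 1 candidate splits, which would then be scored with the impurity measure as usual.

```python
from collections import defaultdict


def ordered_category_splits(categories, y):
    """Sort categories by mean target (the share of class 1 for a 0/1 target) and
    return the k - 1 'first i categories vs. the rest' splits of that ordering."""
    totals, counts = defaultdict(float), defaultdict(int)
    for category, target in zip(categories, y):
        totals[category] += target
        counts[category] += 1
    order = sorted(counts, key=lambda c: totals[c] / counts[c])
    return [(order[:i], order[i:]) for i in range(1, len(order))]


# Hypothetical toy data: "A" is always 1, "B" always 0, "C" half and half.
cats = ["A", "A", "B", "B", "C", "C"]
labels = [1, 1, 0, 0, 1, 0]
print(ordered_category_splits(cats, labels))
# [(['B'], ['C', 'A']), (['B', 'C'], ['A'])]
```

With 1000 categories, this approach would produce only 999 candidate splits instead of 2^999 – 1.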
Note, however, that scikit-learn currently doesn’t use this algorithm: its decision trees don’t support categorical features natively, so they must first be encoded as numbers (e.g., one-hot encoding). Thus, if we have a lot of categorical features with high cardinality, it may be better to use a different package such as LightGBM or CatBoost.
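One common workaround in scikit-learn is to one-hot encode nominal features before fitting. This is a minimal sketch with hypothetical toy data, not the fraud dataset. OneHotEncoder, make_pipeline, and DecisionTreeClassifier are standard scikit-learn APIs.

```python
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.tree import DecisionTreeClassifier

# Hypothetical toy data: one nominal feature (merchant type) and a 0/1 target.
X = [["Restaurant"], ["Gas Station"], ["Online Store"], ["Physical Store"],
     ["Online Store"], ["Restaurant"], ["Online Store"]]
y = [0, 0, 1, 0, 1, 0, 1]

# One-hot encoding turns each category into its own 0/1 column the tree can split on;
# handle_unknown="ignore" encodes unseen categories as all zeros at prediction time.
model = make_pipeline(OneHotEncoder(handle_unknown="ignore"),
                      DecisionTreeClassifier(random_state=0))
model.fit(X, y)
print(model.predict([["Online Store"], ["Gas Station"]]))  # [1 0]
```

There is a trade-off. Each one-hot column lets a split isolate only one category from the rest, so a grouping like [Restaurant, Gas Station] vs. the others can take several splits and a deeper tree.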
Step #3: Based on the impurity measures, choose the single best split.
We’ve covered all the possible splits for the features. But at each node, the tree can make only one split.
As mentioned in the last section, we can use either the Gini or cross-entropy impurity measure as the criterion. Gini is the default in the Python scikit-learn library, so we’ll use the Gini index in our example below.
Assume Dm is the dataset to be partitioned at node m. We need to compare the Ginis before and after splits.
The Gini Impurity function for Dm (before any splits) is below:
H(Dm) = Gini(Dm) = sum_over_k(pk*(1-pk))
where pk = (# of observations that are in class k in Dm)/(# of observations in Dm)
When we do a split, the dataset at the parent node is partitioned into two (e.g., left and right partitions). We define the Gini resulting from any split as the weighted average of the Ginis of the resulting partitions:
H(Dm_split) = (# of observations of left partition) / (# of observations in Dm) * H(Dm_left) + (# of observations of right partition) / (# of observations in Dm) * H(Dm_right)
We would calculate the above Ginis for all the possible splits and find the split (feature and threshold) that results in the lowest H(Dm_split). The lowest H(Dm_split) should be less than H(Dm), otherwise we wouldn’t make the split.
It might still look a little confusing. But don’t worry, let’s see an example of calculation from our fraud dataset.
The root node of the tree is node 0 with the whole training dataset (D0) to be partitioned. Recall that our target Is Fraud has two classes 0 and 1. D0 has three observations with the target being 0 and four observations being 1. So p0 = 3/7 and p1 = 4/7.
The Gini Impurity before any splits can be calculated as:
H(D0) = p0*(1-p0) + p1*(1-p1) = (3/7*4/7) + (4/7*3/7) = 0.490
As shown in the previous step, there are many possible splits of D0. But let’s consider a split of D0 using the feature Transaction Amount (< 425.87 or >= 425.87) as below.
The resulting partitions are:
- D0_left: # of observations = 3, p0 = 2/3, p1 = 1/3.
- D0_right: # of observations = 4, p0 = 1/4, p1 = 3/4.
We can also calculate the Gini impurities for these two datasets:
- H(D0_left) = (2/3*1/3) + (1/3*2/3) = 0.444
- H(D0_right) = (1/4*3/4) + (3/4*1/4) = 0.375
Plugging these values into the split partitions formula, we get this particular split’s Gini impurity:
H(D0_split) = 3/7*0.444 + 4/7*0.375 = 0.405
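The same numbers can be reproduced with a few lines of Python. This sketch (the function names are our own) works from per-class counts, with class 0 = legit and class 1 = fraud. It also reports the information gain of the split.

```python
def gini_from_counts(counts):
    """Gini impurity from per-class counts, e.g. [n_class0, n_class1]."""
    n = sum(counts)
    return sum((c / n) * (1 - c / n) for c in counts)


def weighted_split_gini(left_counts, right_counts):
    """H(Dm_split): size-weighted average Gini of the left and right partitions."""
    n_left, n_right = sum(left_counts), sum(right_counts)
    n = n_left + n_right
    return (n_left / n * gini_from_counts(left_counts)
            + n_right / n * gini_from_counts(right_counts))


h_root = gini_from_counts([3, 4])   # D0: 3 legit, 4 fraud
h_left = gini_from_counts([2, 1])   # Transaction Amount < 425.87
h_right = gini_from_counts([1, 3])  # Transaction Amount >= 425.87
h_split = weighted_split_gini([2, 1], [1, 3])

print(f"{h_root:.3f} {h_left:.3f} {h_right:.3f} {h_split:.3f}")  # 0.490 0.444 0.375 0.405
print(f"information gain: {h_root - h_split:.3f}")                # 0.085
```

Since 0.405 is lower than 0.490, this split would be a valid candidate. CART keeps it only if no other split scores lower.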
Remember that this is only the calculation for one of the possible splits. We need to calculate the same metric for all the possible splits and choose the one with the lowest impurity.
Step #4: Partition using the best splits recursively until the stopping condition is met.
As mentioned, there are different stop conditions, and the algorithm stops splitting a node as soon as any one of them is met. We can set them using the input parameters of Python scikit-learn’s DecisionTreeClassifier. Some popular ones are:
- min_samples_split: the minimum number of observations (samples) in the node for us to do the split.
- min_samples_leaf: the minimum number of observations required in a terminal node. This means that a partition resulting from a split cannot have fewer than this number of observations.
- max_depth: the maximum depth of the tree.
- min_impurity_decrease: the minimum amount of decrease of the impurity due to the split.
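Here is a minimal sketch of setting these parameters, assuming scikit-learn’s bundled breast cancer dataset. The parameter values are arbitrary illustrations, not recommendations.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)

# Each parameter is one stop condition; a branch stops growing once any of them applies.
tree = DecisionTreeClassifier(
    max_depth=4,                  # no root-to-leaf path with more than 4 splits
    min_samples_split=20,         # a node needs at least 20 samples to be split
    min_samples_leaf=5,           # every leaf keeps at least 5 samples
    min_impurity_decrease=0.001,  # split only if the weighted impurity drops by at least this
    random_state=0,
).fit(X, y)

print("depth:", tree.get_depth(), "leaves:", tree.get_n_leaves())
```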
Step #5: Prune the decision tree.
As mentioned earlier, Python scikit-learn uses the minimal cost-complexity pruning technique, controlled by the ccp_alpha parameter of DecisionTreeClassifier. For details, visit scikit-learn Minimal Cost-Complexity Pruning.
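Here is a minimal sketch of tuning ccp_alpha by cross-validation. It assumes the breast cancer dataset from scikit-learn, 5-fold CV, and candidate alphas taken from cost_complexity_pruning_path. Ties in CV accuracy are broken toward the larger alpha (the simpler tree). That tie-breaking choice is ours, not a scikit-learn default.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)

# Candidate alphas: where the fully grown tree loses nodes.
# The last alpha is dropped because it prunes the tree down to the root alone.
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X_train, y_train)
alphas = [max(a, 0.0) for a in path.ccp_alphas[:-1]]

# Mean 5-fold cross-validated accuracy for each candidate alpha.
cv_scores = [
    cross_val_score(DecisionTreeClassifier(random_state=0, ccp_alpha=a),
                    X_train, y_train, cv=5).mean()
    for a in alphas
]

# Highest CV score wins; on ties, max() over (score, alpha) pairs picks the larger alpha.
best_score, best_alpha = max(zip(cv_scores, alphas))

pruned = DecisionTreeClassifier(random_state=0, ccp_alpha=best_alpha).fit(X_train, y_train)
print(f"best ccp_alpha = {best_alpha:.5f}, CV accuracy = {best_score:.3f}, "
      f"leaves = {pruned.get_n_leaves()}, test accuracy = {pruned.score(X_test, y_test):.3f}")
```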
Decision Tree model Advantages and Disadvantages
In this section, we’ll cover the advantages and disadvantages of the decision tree model.
Why are decision trees so popular?
Let’s look at its main pros:
- Easier to understand and interpret.
Compared to other models, it’s simpler to visualize, explain, and apply the results of a decision tree, even to people without a data science background.
- Requires little data preparation.
Unlike many other techniques, decision trees don’t need feature scaling and are relatively robust to outliers.
- Fewer assumptions needed.
Linear and logistic regression assume linear relationships. Decision trees don’t make these assumptions and can learn non-linear functions by themselves.
- Built-in feature selection.
If a feature is not useful, it won’t show up much in the decision tree model. The hierarchy of a decision tree model reflects the importance of features. The features on the top are more informative. For the less informative features, we can potentially remove them on subsequent runs.
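Here is a minimal sketch of how to inspect this in scikit-learn, assuming the breast cancer dataset and a depth-3 tree. feature_importances_ reports each feature’s normalized total impurity decrease, so any feature the tree never splits on gets 0.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

data = load_breast_cancer()
tree = DecisionTreeClassifier(max_depth=3, random_state=0).fit(data.data, data.target)

# Rank features by their normalized total impurity decrease across all splits.
ranked = sorted(zip(tree.feature_importances_, data.feature_names), reverse=True)
for importance, name in ranked[:5]:
    print(f"{name:<25} {importance:.3f}")

# Features never used in a split have an importance of exactly 0.
unused = [name for importance, name in ranked if importance == 0]
print(f"{len(unused)} of {len(data.feature_names)} features are never used")
```

Features with zero importance are natural candidates to drop on later runs, although a feature can also score low because a correlated feature was chosen instead.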
However, decision trees also have some disadvantages that we need to be aware of. The main cons are:
- Relatively unstable.
The structure of the tree is less robust than that of linear/logistic regression. If we take another sample of data from the same population, it could result in a very different tree and different predictions.
This is due to the tree’s hierarchical nature. Small differences in the training dataset can lead to different splits at the top, and these differences affect all the child nodes. These differences add up for the deeper nodes.
This can be alleviated by ensembling the trees (e.g., Random Forests or Gradient Boosted decision trees), at the cost of being harder to interpret.
- More prone to overfitting.
Decision tree models are more likely to have overfitting problems. Most of the time, it’s necessary to set stop conditions, prune the trees, and use cross-validation techniques to avoid this problem.
Further Reading: Practical Guide to Cross-Validation in Machine Learning
Python Example: sklearn DecisionTreeClassifier
Since we’ve been talking about the Python scikit-learn library, let’s see a quick example of building a decision tree model.
First, we import the packages/functions needed for building and plotting decision trees.
Next, we read in the sample dataset – breast cancer data.
Then we can build the decision tree model and visualize the result. We are using breast_cancer with binary values (benign or malignant) as the target.
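The original code for these steps isn’t shown here. Below is a minimal sketch of the same workflow, assuming scikit-learn’s bundled load_breast_cancer dataset and max_depth=3. To keep it dependency-light, it prints the tree as text with export_text. sklearn.tree.plot_tree (which requires matplotlib) draws the graphical version.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier, export_text

# Load the data: 30 numeric features, binary target (0 = malignant, 1 = benign).
data = load_breast_cancer()
X, y = data.data, data.target

# Build a decision tree with a maximum depth of three.
model = DecisionTreeClassifier(max_depth=3, random_state=0)
model.fit(X, y)

# Visualize the fitted tree as indented if/else rules.
print(export_text(model, feature_names=list(data.feature_names)))
```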
As you can see below, the resulting decision tree model has a maximum depth of three.
This is only a demonstration: we didn’t explore or clean the data. In practice, we need to apply these standard procedures and prune the tree as well. For more implementation tips, visit scikit-learn.
You’ve learned a lot about the popular decision tree algorithms in machine learning. We hope you’re ready to build your own decision tree models.
Leave a comment for any questions you may have or anything else.