How to Make a Tree Plot Using Caret Package in R
Last Updated: 10 Jul, 2024
Tree-based methods are powerful tools for both classification and regression tasks in machine learning. The caret package in R provides a consistent interface for training, tuning, and evaluating many machine learning models, including decision trees. In this article, we will train a decision tree with caret and draw its tree plot with the rpart.plot package.
What is a Tree Plot?
A tree plot, also known as a decision tree plot, is a visual representation of a decision tree used in machine learning and statistical analysis. Decision trees are models for classification and regression tasks that repeatedly split the data into branches based on feature values; each final leaf holds a prediction.
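To make the idea of splitting on feature values concrete, the sketch below applies a single hand-picked split to the built-in iris data and cross-tabulates the result against the true species. The threshold of 2.45 on Petal.Length is an assumption chosen for illustration, not a value taken from the model trained later in this article. Because every setosa flower in iris has a shorter petal than every versicolor or virginica flower, this one split already isolates setosa in its own branch.

```r
# Illustrative only: apply one hand-picked split to the built-in iris data.
# The threshold 2.45 is an assumption chosen for demonstration purposes.
data(iris)
left_branch <- iris$Petal.Length < 2.45

# Cross-tabulate branch membership against the true species
branch_table <- table(
  Branch  = ifelse(left_branch, "Petal.Length < 2.45", "Petal.Length >= 2.45"),
  Species = iris$Species
)
print(branch_table)
```

A decision tree algorithm searches for splits like this automatically, choosing at each node the feature and threshold that best separate the classes.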
Setting Up the Environment
First, ensure the necessary packages are installed and loaded.
R
# Load the libraries
library(caret)       # model training and tuning
library(rpart)       # CART decision trees used by method = "rpart"
library(rpart.plot)  # plotting rpart trees
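If you are not sure whether the packages are installed, a small sketch like the one below installs only the missing ones and then loads all three. The names required_pkgs and missing_pkgs are illustrative, and the call to install.packages assumes an internet connection and a configured CRAN mirror.

```r
# Hypothetical helper: install any required packages that are missing,
# then attach them. Assumes a reachable CRAN mirror.
required_pkgs <- c("caret", "rpart", "rpart.plot")

# requireNamespace() returns FALSE (without error) for packages not installed
is_installed <- vapply(required_pkgs, requireNamespace, logical(1), quietly = TRUE)
missing_pkgs <- required_pkgs[!is_installed]

if (length(missing_pkgs) > 0) {
  install.packages(missing_pkgs)
}

# Attach all three packages by name
invisible(lapply(required_pkgs, library, character.only = TRUE))
```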
The following steps train a decision tree with caret and plot it using the R programming language.
Step 1: Loading and Preparing the Data
We'll use the Iris dataset for this example. This dataset contains measurements of iris flowers from three different species.
R
# Load the Iris dataset
data(iris)
# Inspect the dataset
head(iris)
Output:
  Sepal.Length Sepal.Width Petal.Length Petal.Width Species
1          5.1         3.5          1.4         0.2  setosa
2          4.9         3.0          1.4         0.2  setosa
3          4.7         3.2          1.3         0.2  setosa
4          4.6         3.1          1.5         0.2  setosa
5          5.0         3.6          1.4         0.2  setosa
6          5.4         3.9          1.7         0.4  setosa
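head() only shows the first six rows, which are all setosa. A quick check of the dataset's size and class counts, sketched below, confirms that iris has 150 rows and 5 columns, with 50 flowers of each of the three species, so the classes are perfectly balanced.

```r
# Confirm the size of the dataset and the number of flowers per species
data(iris)
dim(iris)            # 150 rows, 5 columns
table(iris$Species)  # 50 flowers per species
```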
Step 2: Splitting the Data into Training and Testing Sets
We'll split the data into training and testing sets to evaluate the performance of our decision tree model.
R
# Set seed for reproducibility
set.seed(123)
# Split the data into training and testing sets
trainIndex <- createDataPartition(iris$Species, p = 0.8, list = FALSE)
trainData <- iris[trainIndex, ]
testData <- iris[-trainIndex, ]
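createDataPartition samples within each level of the outcome, so the species proportions should carry over to both subsets. The sketch below repeats the split from this step and tabulates the classes in each part; with 50 flowers per species and p = 0.8, each species should contribute 40 rows to the training set and 10 to the test set, matching the 120 training samples reported in the next step.

```r
library(caret)

# Repeat the split from this step
data(iris)
set.seed(123)
trainIndex <- createDataPartition(iris$Species, p = 0.8, list = FALSE)
trainData <- iris[trainIndex, ]
testData  <- iris[-trainIndex, ]

# Class counts in each subset (stratified sampling keeps them balanced)
table(trainData$Species)
table(testData$Species)

# Class proportions in the training set
round(prop.table(table(trainData$Species)), 3)
```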
Step 3: Training a Decision Tree Model
The train function from the caret package fits the model. Setting method = "rpart" selects a CART decision tree from the rpart package, and caret tunes its complexity parameter cp by resampling.
R
# Train a decision tree model
model <- train(Species ~ ., data = trainData, method = "rpart")
# Print the model summary
print(model)
Output:
CART
120 samples
4 predictor
3 classes: 'setosa', 'versicolor', 'virginica'
No pre-processing
Resampling: Bootstrapped (25 reps)
Summary of sample sizes: 120, 120, 120, 120, 120, 120, ...
Resampling results across tuning parameters:
  cp    Accuracy   Kappa
  0.00  0.9398492  0.9086993
  0.45  0.7426390  0.6253355
  0.50  0.5557896  0.3665192
Accuracy was used to select the optimal model using the largest value.
The final value used for the model was cp = 0.
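The accuracy figures above come from resampling the training data; the testData set held out in Step 2 is not used by train. A minimal sketch of scoring the model on those held-out rows with caret's predict and confusionMatrix functions is shown below. It repeats the earlier steps so it can run on its own; the exact accuracy you get may differ from run to run and across package versions, so no specific value is claimed here.

```r
library(caret)

# Repeat the split and training from Steps 2 and 3
data(iris)
set.seed(123)
trainIndex <- createDataPartition(iris$Species, p = 0.8, list = FALSE)
trainData <- iris[trainIndex, ]
testData  <- iris[-trainIndex, ]
model <- train(Species ~ ., data = trainData, method = "rpart")

# Predict the species of the held-out flowers (returns a factor of classes)
predictions <- predict(model, newdata = testData)

# Compare predictions with the true labels
cm <- confusionMatrix(predictions, testData$Species)
print(cm)

# Overall accuracy on the test set
cm$overall["Accuracy"]
```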
Step 4: Plotting the Decision Tree
To visualize the trained decision tree, we'll use the rpart.plot function from the rpart.plot package.
R
# Plot the decision tree
rpart.plot(model$finalModel, main = "Decision Tree for Iris Dataset")
Output:
[Figure: decision tree for the Iris dataset, drawn with rpart.plot]
Step 5: Customizing the Tree Plot
The rpart.plot function offers several customization options to enhance the appearance of the tree plot. Let's explore some of these options.
R
# Customize the tree plot
rpart.plot(model$finalModel,
           main = "Customized Decision Tree for Iris Dataset",
           type = 3,               # Draw separate split labels for the left and right branches
           extra = 101,            # Show per-class counts and the percentage of observations in each node
           fallen.leaves = TRUE,   # Align all leaves at the bottom of the plot
           shadow.col = "gray",    # Draw gray shadows behind the node boxes
           box.palette = "Blues",  # Fill the node boxes using the "Blues" palette
           cex = 0.8)              # Scale the text size
Output:
[Figure: customized decision tree for the Iris dataset]
Conclusion
Creating and visualizing a decision tree model with the caret package in R is straightforward and highly customizable. By following the steps outlined in this article, you can train a decision tree model and visualize it; the test set held out in Step 2 can then be used to evaluate its performance. The rpart.plot function provides various options to customize the tree plot, making it easier to interpret and present your results.