Classification with caret
Classification is a fundamental machine learning task where you predict categorical outcomes. The caret package (Classification And REgression Training) provides a unified interface to hundreds of classification algorithms in R.
Installing and Loading caret
First, install and load the package:
# Install caret (once)
install.packages("caret")
# Load the package
library(caret)
Preparing Your Data
Let’s use the classic Titanic dataset to demonstrate classification:
# Load Titanic data (from the titanic package)
library(titanic)
data(titanic_train)
# Quick look at the data
str(titanic_train)
The caret package expects your data in a data frame with the target variable as a factor:
# Convert Survived to factor
titanic_train$Survived <- factor(titanic_train$Survived,
                                 levels = c(0, 1),
                                 labels = c("No", "Yes"))
# Check the distribution
table(titanic_train$Survived)
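The Titanic data also contains missing Age values, which most model fits will reject. One common remedy is median imputation; here is a minimal base-R sketch on a toy data frame (the passengers object and its values are made up for illustration):

```r
# Toy data frame standing in for the Titanic Age column (hypothetical values)
passengers <- data.frame(Age = c(22, NA, 38, NA, 26))

# Replace missing ages with the median of the observed ages
age_median <- median(passengers$Age, na.rm = TRUE)
passengers$Age[is.na(passengers$Age)] <- age_median

passengers$Age  # no NAs remain
```

The examples below instead drop incomplete rows via na.action = na.omit, which is simpler but discards data; imputation keeps every row at the cost of some made-up values.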
Splitting Data into Training and Test Sets
Before training, split your data to evaluate performance:
# Set seed for reproducibility
set.seed(42)
# Create train/test split (70/30)
train_index <- createDataPartition(titanic_train$Survived, p = 0.7, list = FALSE)
train_data <- titanic_train[train_index, ]
test_data <- titanic_train[-train_index, ]
cat("Training set:", nrow(train_data), "rows\n")
cat("Test set:", nrow(test_data), "rows\n")
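Because createDataPartition() samples within each outcome class, the split keeps the class balance of Survived roughly intact in both pieces. A small self-contained check on a toy factor with a 70/30 class mix:

```r
library(caret)

set.seed(42)
# Toy outcome with a 70/30 class mix
y <- factor(rep(c("No", "Yes"), times = c(70, 30)))

idx <- createDataPartition(y, p = 0.7, list = FALSE)

# Class proportions are (approximately) preserved in both pieces
prop.table(table(y[idx]))
prop.table(table(y[-idx]))
```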
Training a Classification Model
Caret provides a consistent interface for model training:
# Train a logistic regression model
logistic_model <- train(
  Survived ~ Pclass + Sex + Age + Fare,
  data = train_data,
  method = "glm",
  family = "binomial",
  na.action = na.omit  # drop rows with missing Age; train() errors on NAs by default
)
# View model details
print(logistic_model)
Training with Different Algorithms
Caret supports many classifiers. Here are two common ones:
Decision Tree
# Decision Tree
tree_model <- train(
  Survived ~ Pclass + Sex + Age + Fare,
  data = train_data,
  method = "rpart",
  na.action = na.omit
)
print(tree_model)
Random Forest
# Random Forest (may take a moment)
rf_model <- train(
  Survived ~ Pclass + Sex + Age + Fare,
  data = train_data,
  method = "rf",
  ntree = 100,
  na.action = na.omit
)
print(rf_model)
Making Predictions
Use the trained model to predict on new data. Note that predict() for a train object silently drops rows with missing predictor values, so the predictions cover only the complete cases:
# Predict on the test set (rows with missing predictors are dropped)
predictions <- predict(rf_model, newdata = test_data)
# View first few predictions
head(predictions)
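For metrics such as ROC curves you often need class probabilities rather than hard labels, which predict() returns when given type = "prob". A self-contained sketch, using a two-class subset of iris as a stand-in dataset (the model and object names here are illustrative):

```r
library(caret)

# Two-class stand-in dataset: setosa vs. versicolor from iris
iris2 <- droplevels(subset(iris, Species != "virginica"))

set.seed(42)
prob_model <- train(Species ~ Sepal.Length + Sepal.Width,
                    data = iris2,
                    method = "glm",
                    family = "binomial")

# One column of probabilities per class level; each row sums to 1
probs <- predict(prob_model, newdata = head(iris2), type = "prob")
probs
```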
Evaluating Model Performance
Caret provides various metrics for classification:
# Confusion matrix: compare only against test rows with no missing
# predictors, since predict() silently dropped the incomplete ones
keep <- complete.cases(test_data[, c("Pclass", "Sex", "Age", "Fare")])
confusionMatrix(predictions, test_data$Survived[keep])
This outputs several metrics:
- Accuracy - proportion of correct predictions
- Sensitivity - true positive rate
- Specificity - true negative rate
- Kappa - agreement with the true labels beyond what chance alone would produce
Note that confusionMatrix() treats the first factor level ("No" here) as the positive class by default; pass positive = "Yes" to change that.
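To make these definitions concrete, here is a base-R sketch that computes the first three metrics from a hypothetical 2x2 confusion matrix (the counts are made up for illustration):

```r
# Hypothetical counts, with "Yes" as the positive class:
#             actual No   actual Yes
# pred No         140           25
# pred Yes         20           80
TP <- 80; TN <- 140; FP <- 20; FN <- 25

accuracy    <- (TP + TN) / (TP + TN + FP + FN)  # correct / total
sensitivity <- TP / (TP + FN)                   # true positive rate
specificity <- TN / (TN + FP)                   # true negative rate

round(c(accuracy = accuracy, sensitivity = sensitivity, specificity = specificity), 3)
```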
Hyperparameter Tuning
Many models have hyperparameters to tune. Caret uses cross-validation:
# Train with 5-fold cross-validation
rf_tuned <- train(
  Survived ~ Pclass + Sex + Age + Fare,
  data = train_data,
  method = "rf",
  trControl = trainControl(method = "cv", number = 5),
  tuneGrid = data.frame(mtry = c(2, 3, 4)),
  na.action = na.omit
)
print(rf_tuned)
# Best parameters
rf_tuned$bestTune
Comparing Multiple Models
Use resampling to compare models fairly:
# Use the same cross-validation settings (and seed) for each model
ctrl <- trainControl(method = "cv", number = 5)

set.seed(42)
glm_cv <- train(Survived ~ Pclass + Sex + Age + Fare, data = train_data,
                method = "glm", family = "binomial", metric = "Accuracy",
                trControl = ctrl, na.action = na.omit)

set.seed(42)
rf_cv <- train(Survived ~ Pclass + Sex + Age + Fare, data = train_data,
               method = "rf", metric = "Accuracy",
               trControl = ctrl, na.action = na.omit)

# Collect and compare the resampling results
cv_results <- resamples(list(Logistic = glm_cv, RandomForest = rf_cv))
summary(cv_results)
Feature Importance
Understand which features matter most:
# Get variable importance
var_importance <- varImp(rf_model)
plot(var_importance)
Summary
The caret package simplifies classification in R by providing a consistent interface to hundreds of algorithms. Instead of learning the unique syntax of each underlying package, such as randomForest, rpart, or e1071, you use a single train() function for all of them. To recap:
- Unified interface - Same syntax for 200+ algorithms
- Built-in preprocessing - Handle missing values, scale features
- Cross-validation - Robust model evaluation
- Hyperparameter tuning - Automatic grid search
- Performance metrics - Confusion matrices, ROC curves, and more
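As a taste of the built-in preprocessing, the preProcess argument to train() can center and scale predictors (or impute missing values with "medianImpute") in a single step. A minimal self-contained sketch on iris (the choice of k-nearest neighbors here is just for illustration):

```r
library(caret)

set.seed(42)
# Center and scale predictors as part of training; adding "medianImpute"
# would also fill in missing values (iris has none)
knn_model <- train(Species ~ ., data = iris,
                   method = "knn",
                   preProcess = c("center", "scale"),
                   trControl = trainControl(method = "cv", number = 5))

knn_model$preProcess  # shows which transformations were applied
```

The same transformations are automatically reapplied inside predict(), so you never scale the training and test sets inconsistently.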
With these fundamentals, you can tackle any classification problem in R using caret. Whether you’re predicting customer churn, diagnosing diseases, or identifying spam emails, the workflow remains the same: prepare your data, train a model, evaluate performance, and iterate until you achieve satisfactory results.