
Caret Classification in R: Train, Tune, and Evaluate Models

Caret classification gives R users a single interface to hundreds of machine learning algorithms. Whether you are building a logistic regression, a random forest, or a support vector machine, caret (Classification And REgression Training) standardizes the workflow: split data, train with train(), tune hyperparameters through resampling configured with trainControl(), and evaluate with confusion matrices and ROC curves. This tutorial walks you through each stage on a real dataset.

What you’ll learn

This tutorial covers the key concepts and practical techniques for classification with caret. By the end, you will know how to apply its core functions in real data analysis workflows.

Installing and loading caret

First, install and load the package:

# Install caret (once)
install.packages("caret")

# Load the package
library(caret)

Preparing your data

Let’s use the classic Titanic dataset to demonstrate classification:

# Load Titanic data (from the titanic package)
library(titanic)
data(titanic_train)

# Quick look at the data
str(titanic_train)

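The caret package expects your data in a data frame, with the target variable stored as a factor for classification:

# Convert Survived to a factor; use level names that are valid R names
# ("No"/"Yes", not 0/1), which classProbs = TRUE requires later
titanic_train$Survived <- factor(titanic_train$Survived,
                                 levels = c(0, 1),
                                 labels = c("No", "Yes"))

# Age has missing values, and train()'s formula interface stops on NAs by
# default, so keep only rows with a recorded Age
titanic_train <- titanic_train[!is.na(titanic_train$Age), ]

# Check the distribution
table(titanic_train$Survived)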

Splitting data into training and test sets

Before training, split your data to evaluate performance:

# Set seed for reproducibility
set.seed(42)

# Create train/test split (70/30)
train_index <- createDataPartition(titanic_train$Survived, p = 0.7, list = FALSE)

train_data <- titanic_train[train_index, ]
test_data <- titanic_train[-train_index, ]

cat("Training set:", nrow(train_data), "rows\n")
cat("Test set:", nrow(test_data), "rows\n")

Training a classification model

Caret provides a consistent interface for model training:

# Train a logistic regression model
logistic_model <- train(
  Survived ~ Pclass + Sex + Age + Fare,
  data = train_data,
  method = "glm",
  family = "binomial"
)

# View model details
print(logistic_model)

Training with different algorithms

Caret supports many classifiers. Here are common ones:

Decision tree

# Decision Tree
tree_model <- train(
  Survived ~ Pclass + Sex + Age + Fare,
  data = train_data,
  method = "rpart"
)

print(tree_model)

Random forest

# Random Forest (may take a moment)
rf_model <- train(
  Survived ~ Pclass + Sex + Age + Fare,
  data = train_data,
  method = "rf",
  ntree = 100
)

print(rf_model)

Making predictions

Use the trained model to predict on new data:

# Predict on test set
predictions <- predict(rf_model, newdata = test_data)

# View first few predictions
head(predictions)

Evaluating model performance

Caret provides various metrics for classification:

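# Confusion matrix; positive = "Yes" makes sensitivity refer to survivors
# (by default caret treats the first factor level, "No", as positive)
confusionMatrix(predictions, test_data$Survived, positive = "Yes")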

This outputs:

  • Accuracy - proportion of correct predictions
  • Sensitivity - true positive rate
  • Specificity - true negative rate
  • Kappa - agreement measure accounting for chance
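To see where those numbers come from, here is a minimal sketch that computes accuracy, sensitivity, specificity and kappa by hand for ten made-up passengers and checks them against confusionMatrix(). The vectors truth and pred are hypothetical, and "Yes" is treated as the positive class.

```r
library(caret)

# Hypothetical true labels and predictions for ten passengers (illustrative only)
truth <- factor(c("Yes", "Yes", "Yes", "No", "No", "No", "No", "No", "Yes", "No"),
                levels = c("No", "Yes"))
pred  <- factor(c("Yes", "No",  "Yes", "No", "No", "Yes", "No", "No", "Yes", "No"),
                levels = c("No", "Yes"))

# The four cells of the confusion matrix, with "Yes" as the positive class
tp <- sum(pred == "Yes" & truth == "Yes")
tn <- sum(pred == "No"  & truth == "No")
fp <- sum(pred == "Yes" & truth == "No")
fn <- sum(pred == "No"  & truth == "Yes")

accuracy    <- (tp + tn) / length(truth)
sensitivity <- tp / (tp + fn)   # true positive rate
specificity <- tn / (tn + fp)   # true negative rate

# Cohen's kappa: observed agreement corrected for agreement expected by chance
expected <- (sum(pred == "Yes") * sum(truth == "Yes") +
             sum(pred == "No")  * sum(truth == "No")) / length(truth)^2
kappa_stat <- (accuracy - expected) / (1 - expected)

# caret's version of the same statistics
cm <- confusionMatrix(pred, truth, positive = "Yes")
cm$overall[c("Accuracy", "Kappa")]
cm$byClass[c("Sensitivity", "Specificity")]
```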

Comparing multiple models

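Train each model on the same cross-validation folds, then collect the results with resamples() to compare them fairly:

# Create the folds once so every model is resampled on identical splits
set.seed(42)
cv_folds <- createFolds(train_data$Survived, k = 5, returnTrain = TRUE)
ctrl <- trainControl(method = "cv", index = cv_folds)

fml <- Survived ~ Pclass + Sex + Age + Fare

glm_cv  <- train(fml, data = train_data, method = "glm", family = "binomial",
                 trControl = ctrl, metric = "Accuracy")
tree_cv <- train(fml, data = train_data, method = "rpart",
                 trControl = ctrl, metric = "Accuracy")
rf_cv   <- train(fml, data = train_data, method = "rf", ntree = 100,
                 trControl = ctrl, metric = "Accuracy")

# Compare cross-validated accuracy and kappa across the three models
model_results <- resamples(list(glm = glm_cv, tree = tree_cv, rf = rf_cv))
summary(model_results)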

Feature importance

Understand which features matter most:

# Get variable importance
var_importance <- varImp(rf_model)
plot(var_importance)

Summary

The caret package simplifies classification in R by providing a consistent interface to hundreds of algorithms. Instead of learning the unique syntax for each package like randomForest, glm, e1071, or rpart, caret lets you use a single train() function for all of them.

Its main strengths are:

  1. Unified interface - Same syntax for 200+ algorithms
  2. Built-in preprocessing - Handle missing values, scale features
  3. Cross-validation - Reliable model evaluation
  4. Hyperparameter tuning - Automatic grid search
  5. Performance metrics - Confusion matrices, ROC curves, and more

With these fundamentals, you can tackle most classification problems in R using caret. Whether you're predicting customer churn, diagnosing diseases, or identifying spam emails, the workflow remains the same: prepare your data, train a model, evaluate performance, and iterate until you achieve satisfactory results.

Training with caret

caret::train() fits a model and estimates its performance by resampling. The method argument selects the algorithm: "rf" for random forest, "xgbTree" for XGBoost, "glm" for logistic regression, "svmRadial" for an SVM with a radial basis kernel. By default, train() uses 25 bootstrap resamples; trControl = trainControl(method = "cv", number = 10) switches to 10-fold cross-validation. tuneLength = 5 asks caret to generate up to 5 candidate values for each tuning parameter.
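A minimal sketch of these arguments, assuming the built-in iris data as a stand-in for a real project. It uses method = "knn" rather than the methods listed above because caret fits k-nearest neighbours with its own knn3() and needs no extra model package.

```r
library(caret)

set.seed(42)
ctrl <- trainControl(method = "cv", number = 10)   # 10-fold cross-validation

knn_fit <- train(
  Species ~ .,
  data = iris,
  method = "knn",
  trControl = ctrl,
  tuneLength = 5        # let caret choose 5 candidate values of k
)

knn_fit$results[, c("k", "Accuracy")]   # one row per candidate k
knn_fit$bestTune                        # the k with the best CV accuracy
```

Swapping in "rf" or "svmRadial" only changes the method string, provided the randomForest or kernlab package is installed.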

Preprocessing in caret

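preProcess(train_data, method = c("center", "scale")) standardizes numeric features. To build preprocessing into the model instead, pass it to train(): train(outcome ~ ., data = train, method = "rf", preProcess = c("center", "scale")) estimates the transformation from the training data and reapplies it automatically whenever you call predict() on the fitted model. method = "medianImpute" fills missing values with column medians; with the formula interface, also pass na.action = na.pass, because train() otherwise stops on missing values before imputation can run. method = "pca" applies PCA dimensionality reduction.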
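The sketch below combines imputation and scaling inside train(), assuming a copy of iris with a few Sepal.Width values deliberately set to NA (the row numbers are arbitrary). Note na.action = na.pass in both train() and predict(): without it, train() rejects rows with missing values and predict() drops them before medianImpute ever sees them.

```r
library(caret)

# Copy iris and blank out a few values to simulate missing data (illustrative)
iris_na <- iris
iris_na$Sepal.Width[c(3, 40, 77, 120)] <- NA

set.seed(42)
fit <- train(
  Species ~ .,
  data = iris_na,
  method = "knn",
  preProcess = c("medianImpute", "center", "scale"),
  trControl = trainControl(method = "cv", number = 5),
  na.action = na.pass   # let rows with NAs reach the imputation step
)

# predict() reapplies the stored imputation, centering and scaling
new_rows <- iris_na[c(3, 40), ]
preds <- predict(fit, newdata = new_rows, na.action = na.pass)
preds
```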

Model evaluation

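confusionMatrix(predictions, actual) computes accuracy, kappa, sensitivity, specificity and related statistics; set mode = "prec_recall" or mode = "everything" to also get precision, recall and F1. For multi-class problems, the per-class statistics come back as one row per class. varImp(model) extracts variable importance scores. To check how well predicted probabilities are calibrated, calibration(actual ~ predicted, data = results) from the caret package bins the predicted class probabilities, and plotting the result (for example with xyplot()) draws the calibration curve.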
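A short sketch of the multi-class case, assuming iris and a k-nearest-neighbour model scored on its own training data purely for illustration (use a held-out set in practice). With mode = "everything", byClass gains precision, recall and F1 columns, with one row per class.

```r
library(caret)

# Fit a single k-NN model without resampling (k fixed at 5 for illustration)
fit <- train(Species ~ ., data = iris, method = "knn",
             trControl = trainControl(method = "none"),
             tuneGrid = data.frame(k = 5))
pred <- predict(fit, newdata = iris)

cm <- confusionMatrix(pred, iris$Species, mode = "everything")
cm$byClass[, c("Precision", "Recall", "F1")]   # one row per species
```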

Comparing models

Train multiple models on identical cross-validation folds for a fair comparison: either call set.seed() with the same value before each train() call, or pass fixed fold indices through trainControl(index = ...). Sharing a trainControl object alone does not fix the folds. resamples(list(rf = rf_model, glm = glm_model)) collects the CV results. summary(resamples_obj) shows metric statistics per model. dotplot(resamples_obj) visualizes the comparison. summary(diff(resamples_obj)) reports pairwise differences with p-values from paired t-tests, to assess whether the differences are significant.
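A minimal comparison sketch, assuming a two-class subset of iris (versicolor vs virginica) and two quick models, "knn" and "lda" (the latter uses MASS, which ships with R). The folds are created once with createFolds() and passed through trainControl(index = ...), so both models are evaluated on identical splits.

```r
library(caret)

# Two-class subset of iris so both models fit quickly (illustrative data)
iris2 <- droplevels(subset(iris, Species != "setosa"))

# Fix the training-fold indices once and share them across models
set.seed(42)
folds <- createFolds(iris2$Species, k = 5, returnTrain = TRUE)
ctrl  <- trainControl(method = "cv", index = folds)

knn_fit <- train(Species ~ ., data = iris2, method = "knn", trControl = ctrl)
lda_fit <- train(Species ~ ., data = iris2, method = "lda", trControl = ctrl)

res <- resamples(list(knn = knn_fit, lda = lda_fit))
summary(res)          # accuracy and kappa distribution per model
dotplot(res)          # visual comparison across folds
summary(diff(res))    # paired differences with p-values
```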

Data preprocessing

caret::preProcess() estimates preprocessing parameters from the training data. Common transformations: "center" subtracts the mean, "scale" divides by the standard deviation, and "pca" reduces dimensionality. Pass the resulting object to predict(preproc, newdata) to apply the same transformation to new data. Supplying preprocessing through the preProcess argument of train() instead re-estimates the transformation within each cross-validation fold.
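Here is a sketch of the standalone workflow on iris, assuming an 80/20 split. The centering and scaling parameters come from the training rows only, so the test columns end up close to, but not exactly at, mean zero.

```r
library(caret)

set.seed(42)
idx     <- createDataPartition(iris$Species, p = 0.8, list = FALSE)
train_x <- iris[idx, 1:4]
test_x  <- iris[-idx, 1:4]

# Estimate means and standard deviations from the training rows only
pp <- preProcess(train_x, method = c("center", "scale"))

train_scaled <- predict(pp, train_x)
test_scaled  <- predict(pp, test_x)   # reuses the training statistics

round(colMeans(train_scaled), 10)  # zero on the training data
colMeans(test_scaled)              # near zero, but not exactly
```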

The caret workflow

caret (Classification And REgression Training) provides a unified interface for training, tuning, and evaluating hundreds of machine learning models in R. The core function is train(), which accepts a formula or x/y matrix, a method name, training control options, and tuning parameters.

Data preparation happens before train(). Split data with createDataPartition(y, p = 0.8, list = FALSE) for a stratified split that preserves class proportions. preProcess() computes transformation parameters (centering, scaling, PCA) from training data, and predict(preProc, data) applies the same transformation to test data, crucial for preventing data leakage.
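A quick check of the stratification, using iris (50 flowers per species) as an assumed example: with p = 0.8, each species contributes the same share of rows to the training indices.

```r
library(caret)

set.seed(42)
idx <- createDataPartition(iris$Species, p = 0.8, list = FALSE)

# Class proportions in the full data and in the training split
prop.table(table(iris$Species))
prop.table(table(iris$Species[idx]))
```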

Model training and cross-validation

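library(caret)

# train_df is assumed to hold a two-level factor outcome named diagnosis,
# with levels that are valid R names (classProbs = TRUE requires this)
ctrl <- trainControl(
  method = "cv",
  number = 10,
  classProbs = TRUE,                 # keep class probabilities for ROC
  summaryFunction = twoClassSummary  # report ROC AUC, Sens and Spec
)

model <- train(
  diagnosis ~ .,
  data = train_df,
  method = "rf",       # requires the randomForest package
  trControl = ctrl,
  metric = "ROC"       # pick tuning values by cross-validated ROC AUC
)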

trainControl() defines the resampling strategy. method = "cv" with number = 10 uses 10-fold cross-validation. method = "repeatedcv" with repeats = 3 repeats this process three times for more stable estimates. method = "boot" uses bootstrap resampling.

classProbs = TRUE requests class probability estimates for ROC curve computation. summaryFunction = twoClassSummary computes ROC AUC, sensitivity, and specificity as evaluation metrics, which requires two-class problems and class probability support.
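The block above depends on a train_df that is not shown. Here is a runnable variant under stated assumptions: the two-class subset of iris replaces train_df, "knn" replaces "rf" so no extra package is needed, and repeatedcv with three repeats illustrates the repeated option described above.

```r
library(caret)

# Stand-in two-class data: versicolor vs virginica (valid R level names)
iris2 <- droplevels(subset(iris, Species != "setosa"))

ctrl <- trainControl(
  method = "repeatedcv", number = 10, repeats = 3,
  classProbs = TRUE,
  summaryFunction = twoClassSummary   # reports ROC, Sens, Spec
)

set.seed(42)
fit <- train(Species ~ ., data = iris2, method = "knn",
             trControl = ctrl, metric = "ROC")

fit$results[, c("k", "ROC", "Sens", "Spec")]
```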

Hyperparameter tuning

caret searches over a grid of hyperparameter values. tuneGrid = expand.grid(mtry = c(2, 4, 6, 8)) for random forests evaluates four values of mtry. The value with the best cross-validated metric is selected. model$bestTune shows the winning parameters.

For broad exploration, tuneLength = 10 tells caret to generate up to 10 candidate values for each tunable parameter (fewer when a parameter's range is small). This works without knowing the parameter names; caret chooses appropriate ranges.

For random search, set search = "random" in trainControl(). caret then samples tuning combinations at random (tuneLength sets how many) rather than exhaustively, which is useful when the hyperparameter space has many dimensions.
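Both tuning styles in one sketch, assuming iris and k-nearest neighbours, whose single tuning parameter is k. The explicit grid evaluates exactly four values; the random search asks caret for six sampled candidates, and because sampled values can coincide, six is an upper bound.

```r
library(caret)

# Explicit grid: evaluate exactly these values of k
set.seed(42)
grid_fit <- train(Species ~ ., data = iris, method = "knn",
                  trControl = trainControl(method = "cv", number = 5),
                  tuneGrid = expand.grid(k = c(3, 5, 7, 9)))
grid_fit$bestTune

# Random search: caret samples candidate values of k instead
set.seed(42)
rand_fit <- train(Species ~ ., data = iris, method = "knn",
                  trControl = trainControl(method = "cv", number = 5,
                                           search = "random"),
                  tuneLength = 6)
rand_fit$results$k
```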

Evaluating models

confusionMatrix(predictions, actual) returns accuracy, sensitivity, specificity, and kappa. For multi-class problems, it also returns per-class statistics. predict(model, newdata) returns class predictions; predict(model, newdata, type = "prob") returns class probabilities.
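A sketch of the two prediction types, assuming iris with an 80/20 split. classProbs = TRUE is set in trainControl() so probability output is available throughout; each row of the probability table sums to one.

```r
library(caret)

set.seed(42)
idx <- createDataPartition(iris$Species, p = 0.8, list = FALSE)
fit <- train(Species ~ ., data = iris[idx, ], method = "knn",
             trControl = trainControl(method = "cv", number = 5,
                                      classProbs = TRUE))

test    <- iris[-idx, ]
classes <- predict(fit, newdata = test)                  # factor of labels
probs   <- predict(fit, newdata = test, type = "prob")   # one column per class

head(classes)
head(probs)
```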

pROC::roc(actual, prob_positive) computes an ROC curve. pROC::auc() extracts the AUC. plot(roc_obj) draws the curve. For comparing models, pROC::roc.test(roc1, roc2) tests whether two AUC values are significantly different.
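A sketch of the pROC calls, assuming the versicolor-vs-virginica subset of iris with only the two sepal measurements as predictors, so the classes overlap and the AUCs are not trivially 1. Here "virginica" is treated as the positive class; levels and direction are passed explicitly so roc() does not have to guess them.

```r
library(caret)
library(pROC)

iris2 <- droplevels(subset(iris, Species != "setosa"))
set.seed(42)
idx  <- createDataPartition(iris2$Species, p = 0.7, list = FALSE)
ctrl <- trainControl(method = "cv", number = 5, classProbs = TRUE)
fml  <- Species ~ Sepal.Length + Sepal.Width

knn_fit <- train(fml, data = iris2[idx, ], method = "knn", trControl = ctrl)
lda_fit <- train(fml, data = iris2[idx, ], method = "lda", trControl = ctrl)

test  <- iris2[-idx, ]
# Predicted probability of the positive class, "virginica"
p_knn <- predict(knn_fit, test, type = "prob")[, "virginica"]
p_lda <- predict(lda_fit, test, type = "prob")[, "virginica"]

# levels = c(controls, cases); direction "<" means cases score higher
roc_knn <- roc(test$Species, p_knn, levels = c("versicolor", "virginica"), direction = "<")
roc_lda <- roc(test$Species, p_lda, levels = c("versicolor", "virginica"), direction = "<")

auc(roc_knn)
plot(roc_knn)
rt <- roc.test(roc_knn, roc_lda)   # paired curves: DeLong's test by default
rt
```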

resamples(list(model1 = m1, model2 = m2)) collects cross-validated metrics from multiple models for comparison. dotplot(resamples_obj) visualizes the distribution of metrics across folds.

Feature importance and variable selection

varImp(model) extracts variable importance scores. The underlying measure varies by algorithm: random forests use the mean decrease in Gini impurity, and linear models use the absolute value of each coefficient's t-statistic (the z-statistic for glm). By default, varImp() rescales the scores to run from 0 to 100. plot(varImp(model), top = 20) shows the top 20 variables.
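A sketch for a logistic regression, assuming the two-class iris subset. For this model, caret scores each predictor by the absolute z-statistic of its coefficient and then rescales the scores to 0-100.

```r
library(caret)

iris2 <- droplevels(subset(iris, Species != "setosa"))
set.seed(42)
glm_fit <- train(Species ~ ., data = iris2, method = "glm", family = "binomial",
                 trControl = trainControl(method = "cv", number = 5))

imp <- varImp(glm_fit)   # |z statistic| per coefficient, rescaled to 0-100
imp$importance
plot(imp, top = 4)
```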

Recursive feature elimination (rfe()) wraps the training process to progressively remove low-importance features. Specify a sizes vector: rfe(x, y, sizes = c(5, 10, 20, 50), rfeControl = rfeControl(functions = rfFuncs)). The function reports the optimal number of variables. This is computationally expensive but useful when you have hundreds of features.
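A small rfe() sketch, assuming iris and caret's ldaFuncs helper (linear discriminant analysis via MASS) instead of rfFuncs, so no random forest package is required. With only four predictors it runs in seconds; the same call pattern applies when there are hundreds of features.

```r
library(caret)

set.seed(42)
ctrl <- rfeControl(functions = ldaFuncs, method = "cv", number = 5)

# Evaluate subsets of 1, 2 and 3 predictors (the full set of 4 is added)
rfe_fit <- rfe(x = iris[, 1:4], y = iris$Species,
               sizes = c(1, 2, 3),
               rfeControl = ctrl)

rfe_fit$optsize        # number of variables with the best CV accuracy
predictors(rfe_fit)    # names of the selected variables
```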

Alternatives to caret

tidymodels is the current recommended replacement for new projects. It separates the concerns of preprocessing (recipes), model specification (parsnip), resampling (rsample), and tuning (tune) into composable packages. The interface is tidyverse-consistent with pipe-friendly APIs. caret remains widely used in existing code and books but receives minimal new development.

The unified interface value

Machine learning in R suffers from inconsistent interfaces. Different model packages use different argument names, different prediction interfaces, and different resampling methods. caret's core value is a single train function that wraps over two hundred models from many packages behind one consistent interface. You specify the model type with a method string, and caret handles the package-specific details. Switching from a random forest to a gradient boosted tree means changing one argument, not rewriting the entire modeling pipeline.

This abstraction has costs. caret adds a layer between you and the underlying model package, which can obscure what is actually being called. When a model fails or produces unexpected results, debugging requires understanding both caret’s interface and the underlying package’s behavior. For production use of a single model type, using the underlying package directly often gives more control. caret’s value is in exploration and comparison across model types.

Training and resampling

The train function combines model fitting with resampling-based performance estimation. The trControl argument accepts a trainControl object that specifies the resampling strategy — k-fold cross-validation, repeated cross-validation, bootstrap, or leave-one-out. The resampling runs automatically as part of the training call, and the result includes both the final model fit on all training data and the cross-validated performance estimates.

Cross-validated performance estimates are more reliable than single train-test split estimates, especially for small datasets. With ten-fold cross-validation, the model is trained ten times on 90% of the data and evaluated on the remaining 10%, and the performance metrics are averaged across all ten evaluations. Repeated cross-validation runs this process multiple times with different folds, further reducing variance in the estimates.
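To see the averaging directly, this sketch (assuming iris and a k-nearest-neighbour model with k fixed at 5) runs 10-fold cross-validation repeated three times. The per-fold results live in fit$resample, and their mean is the accuracy caret reports.

```r
library(caret)

set.seed(42)
fit <- train(Species ~ ., data = iris, method = "knn",
             tuneGrid = data.frame(k = 5),   # one candidate, so no tuning
             trControl = trainControl(method = "repeatedcv",
                                      number = 10, repeats = 3))

nrow(fit$resample)            # 10 folds x 3 repeats = 30 held-out evaluations
mean(fit$resample$Accuracy)   # average across those evaluations
fit$results$Accuracy          # the cross-validated accuracy caret reports
```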

Preprocessing integration

caret integrates preprocessing into the modeling pipeline through the preProcess argument to train. Specifying centering and scaling transforms numeric predictors before model fitting and automatically applies the same transformation to new data at prediction time. This integration prevents the common mistake of fitting a scaler on training data and forgetting to apply it to test data.

Other preprocessing options include imputation of missing values, near-zero variance removal, and principal component analysis. These are computed from the training data and applied consistently to any data passed to the predict function. When cross-validating, preprocessing is recomputed separately for each fold to avoid data leakage — the scaler parameters for fold one are computed only from the fold one training data, not from the held-out validation data.
