Classification Models with tidymodels
Classification predicts categorical outcomes. You might predict whether a customer will churn, if an email is spam, or whether a tumor is malignant. tidymodels provides a unified interface for all these tasks. This guide shows you how to build, evaluate, and compare classification models.
The Classification Workflow
The basic steps mirror the regression workflow. You split data, define a model, create a workflow, and evaluate. Classification adds a few important differences: you deal with class labels instead of continuous values, you measure performance with different metrics, and you often face class imbalance.
You will work with the Titanic dataset. It records survival outcomes along with passenger details. The goal is to predict whether a passenger survived.
Data Preparation
The Titanic dataset has missing values and categorical variables that need processing. Load the data and create a recipe:
library(tidymodels)
library(titanic)
data(titanic_train, package = "titanic")
titanic <- titanic_train |>
  select(Survived, Pclass, Sex, Age, SibSp, Parch, Fare, Embarked) |>
  mutate(Survived = factor(Survived, levels = c("0", "1")))
titanic_rec <- recipe(Survived ~ ., data = titanic) |>
  step_impute_mode(Embarked) |>
  step_impute_median(Age) |>
  step_dummy(all_nominal_predictors()) |>
  step_zv(all_predictors())
This recipe imputes missing Embarked values with the most common category and missing Age values with the median. It then creates dummy variables for categorical predictors and removes zero-variance predictors.
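To see what the processed data will look like, you can prep() the recipe and bake() it by hand. The workflow does this automatically at fit time, so this step is purely for inspection:

```r
# Estimate the recipe steps, then apply them and inspect the result
titanic_rec |>
  prep(training = titanic) |>
  bake(new_data = NULL) |>
  glimpse()
```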
Now split the data:
set.seed(123)
titanic_split <- initial_split(titanic, strata = Survived, prop = 0.8)
titanic_train <- training(titanic_split)
titanic_test <- testing(titanic_split)
The strata argument ensures both sets have similar survival proportions. This matters for classification with imbalanced classes.
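Stratification is easy to verify by comparing the class proportions in each set:

```r
# Survival proportions should be nearly identical in both sets
titanic_train |> count(Survived) |> mutate(prop = n / sum(n))
titanic_test |> count(Survived) |> mutate(prop = n / sum(n))
```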
Logistic Regression
Start with logistic regression. It is interpretable and serves as a baseline:
log_spec <- logistic_reg() |>
  set_engine("glm") |>
  set_mode("classification")

log_wf <- workflow() |>
  add_recipe(titanic_rec) |>
  add_model(log_spec)
log_fit <- fit(log_wf, titanic_train)
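One advantage of the glm baseline is interpretability: the fitted coefficients can be pulled out of the workflow with extract_fit_parsnip() and tidy():

```r
# Coefficients are on the log-odds scale; negative values lower survival odds
log_fit |>
  extract_fit_parsnip() |>
  tidy()
```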
Fit on training data and predict on test data:
predictions <- predict(log_fit, titanic_test, type = "class")
Confusion Matrix
A confusion matrix shows how your predictions compare to actual outcomes:
confusion <- predictions |>
  bind_cols(titanic_test |> select(Survived)) |>
  conf_mat(truth = Survived, estimate = .pred_class)
confusion
The output shows true negatives, false positives, false negatives, and true positives. Each row represents the predicted class, each column the actual class.
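The conf_mat object can also be summarized into a full set of derived metrics in one call. Here event_level = "second" treats "1" (survived) as the positive class, since yardstick defaults to the first factor level:

```r
# Accuracy, kappa, sensitivity, specificity, and more from one matrix
summary(confusion, event_level = "second")
```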
Extract accuracy directly:
acc <- predictions |>
  bind_cols(titanic_test |> select(Survived)) |>
  accuracy(truth = Survived, estimate = .pred_class)
acc
Decision Trees
Decision trees split data based on feature values. They are easy to interpret but often overfit. Create a tree specification:
tree_spec <- decision_tree(cost_complexity = 0.01) |>
  set_engine("rpart") |>
  set_mode("classification")
The cost_complexity parameter controls pruning: higher values penalize additional splits more heavily, producing simpler trees.
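The tree specification drops into the same workflow pattern as the other models. A sketch of fitting and scoring it:

```r
tree_wf <- workflow() |>
  add_recipe(titanic_rec) |>
  add_model(tree_spec)

tree_fit <- fit(tree_wf, titanic_train)

# Test-set accuracy for the tree
predict(tree_fit, titanic_test) |>
  bind_cols(titanic_test |> select(Survived)) |>
  accuracy(truth = Survived, estimate = .pred_class)
```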
Random Forests
Random forests combine many decision trees. Each tree sees a random subset of features and data. The final prediction aggregates all trees:
rf_spec <- rand_forest(mtry = 3, trees = 200) |>
  set_engine("ranger") |>
  set_mode("classification")

rf_wf <- workflow() |>
  add_recipe(titanic_rec) |>
  add_model(rf_spec)
rf_fit <- fit(rf_wf, titanic_train)
The trees argument sets how many trees to grow. The mtry argument sets how many features each tree considers at each split.
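ranger can also report variable importance if you request it through an engine argument (importance here is a ranger option passed via set_engine, not a parsnip argument):

```r
# Request impurity-based importance from the ranger engine
rf_imp_spec <- rand_forest(mtry = 3, trees = 200) |>
  set_engine("ranger", importance = "impurity") |>
  set_mode("classification")

workflow() |>
  add_recipe(titanic_rec) |>
  add_model(rf_imp_spec) |>
  fit(titanic_train) |>
  extract_fit_engine() |>
  ranger::importance()
```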
Class Probability Predictions
For many applications you need probabilities instead of hard class predictions. Generate them:
prob_predictions <- predict(rf_fit, titanic_test, type = "prob")
head(prob_predictions)
The output is a tibble with one probability column per class. The column names are the factor levels prefixed with .pred_, here .pred_0 and .pred_1.
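Hard class predictions correspond to a 0.5 cutoff on these probabilities. When false negatives are costlier than false positives, you can apply your own threshold (the 0.3 below is purely illustrative):

```r
# Predict "survived" whenever the probability exceeds 0.3
prob_predictions |>
  mutate(.pred_class = factor(if_else(.pred_1 > 0.3, "1", "0"),
                              levels = c("0", "1"))) |>
  bind_cols(titanic_test |> select(Survived)) |>
  accuracy(truth = Survived, estimate = .pred_class)
```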
ROC Curves
An ROC curve plots the true positive rate against the false positive rate at every threshold:
roc_data <- prob_predictions |>
  bind_cols(titanic_test |> select(Survived)) |>
  roc_curve(truth = Survived, .pred_1, event_level = "second")
autoplot(roc_data)
The event_level = "second" argument tells yardstick that "1" (survived) is the positive class; by default it treats the first factor level, here "0", as the event.
The curve shows how sensitivity and specificity change as you move the classification threshold. A perfect classifier hugs the top-left corner.
Calculate the area under the curve:
auc <- prob_predictions |>
  bind_cols(titanic_test |> select(Survived)) |>
  roc_auc(truth = Survived, .pred_1, event_level = "second")
auc
An AUC of 0.5 means random guessing. A value of 1.0 means perfect classification.
Multiple Metrics
Evaluate many metrics at once:
multi_metrics <- metric_set(accuracy, sensitivity, specificity, f_meas)
predictions |>
  bind_cols(titanic_test |> select(Survived)) |>
  multi_metrics(truth = Survived, estimate = .pred_class, event_level = "second")
The sensitivity function calculates the true positive rate and specificity the true negative rate, while f_meas balances precision and recall. As with the ROC metrics, event_level = "second" makes "1" the positive class.
Cross-Validation
Use cross-validation to estimate how your model will perform on new data:
folds <- vfold_cv(titanic_train, v = 5, strata = Survived)
rf_res <- rf_wf |>
  fit_resamples(folds)
collect_metrics(rf_res)
This evaluates the random forest on five different train/validation splits. The metrics show mean performance and standard error across folds.
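Because the folds are fixed, you can evaluate the logistic regression workflow on the same resamples and compare the two models fairly:

```r
log_res <- log_wf |>
  fit_resamples(folds)

# Mean cross-validated metrics, side by side
bind_rows(
  collect_metrics(log_res) |> mutate(model = "logistic"),
  collect_metrics(rf_res) |> mutate(model = "random forest")
)
```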
Handling Class Imbalance
The Titanic data has more victims than survivors, and other datasets can be far more imbalanced. One remedy is class weights, which make errors on the minority class cost more:
rf_balanced <- rand_forest(mtry = 3, trees = 200) |>
  set_engine("ranger", class.weights = c(1, 1.6)) |>  # order follows the factor levels: "0", "1"
  set_mode("classification")
This tells ranger to penalize errors on the minority class ("1", survived) more heavily. The weights above are illustrative; tune them for your data.
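A common way to choose the weights is inverse class frequency computed from the training data; ranger expects the class.weights vector in factor-level order:

```r
# Rarer classes receive proportionally larger weights
class_counts <- table(titanic_train$Survived)
inv_freq_wts <- as.numeric(nrow(titanic_train) / (length(class_counts) * class_counts))
inv_freq_wts

rf_weighted <- rand_forest(mtry = 3, trees = 200) |>
  set_engine("ranger", class.weights = inv_freq_wts) |>
  set_mode("classification")
```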
Summary
You now know how to build classification models with tidymodels. The key steps are:
- Prepare data with a recipe that handles missing values and categorical variables
- Split data with initial_split(), optionally using stratification
- Define a model with parsnip: try logistic regression, decision trees, random forests, or gradient boosting
- Create a workflow combining recipe and model
- Evaluate with predict() on test data or fit_resamples() for cross-validation
- Measure performance with accuracy, confusion matrices, ROC curves, and AUC
Classification with tidymodels follows the same patterns as regression. The main differences are the metrics you use and the attention you pay to class imbalance. Apply these techniques to your own classification problems.
See Also
- dplyr Data Wrangling — Data preprocessing techniques
- purrr Functional Programming — Functional iteration patterns
- R Memory Management — Optimizing large datasets