Text Classification in R
Text classification assigns predefined categories to text documents. Unlike sentiment analysis, which assigns emotional scores, classification puts documents into discrete buckets—spam versus not spam, urgent versus normal, topic A versus topic B. This tutorial shows you how to build text classification models in R using the tidymodels framework.
Prerequisites
This is an advanced tutorial. You should be comfortable with:
- Text preprocessing with tidytext (tokenization, stop words)
- Basic R programming (functions, pipes, data frames)
- Machine learning concepts (training/test splits, accuracy)
If you need background, work through the earlier tutorials in this series first.
What You Will Build
By the end of this tutorial, you will have:
- A text preprocessing pipeline
- A document-term matrix (DTM) for machine learning
- Trained classifiers using tidymodels
- Model evaluation with confusion matrices
Installing Required Packages
install.packages("tidymodels")
install.packages("tidytext")
install.packages("textrecipes")
install.packages("discrim")
install.packages("naivebayes")
install.packages("glmnet")
Key packages:
- tidymodels: Unified interface for model training
- textrecipes: Text preprocessing steps for recipes
- discrim: Discriminant analysis models
- glmnet: Regularized regression (Lasso/Elastic Net)
Understanding Text Classification
Text classification is supervised learning—you need labeled examples. Common applications:
| Application | Categories |
|---|---|
| Spam detection | spam, not_spam |
| Topic labeling | sports, politics, technology, entertainment |
| Sentiment categorization | positive, negative, neutral |
| Intent detection | question, complaint, compliment |
The workflow: preprocess text → create features → train model → evaluate → predict.
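The "create features" step is where text becomes numbers. A minimal base-R sketch of the idea, using two toy documents and no packages (the real pipeline below uses textrecipes for this): each document becomes a row of word counts over a shared vocabulary.

```r
# Two toy documents
docs <- c("free prize call now", "call me about lunch")

# Tokenize on spaces and build a shared vocabulary
tokens <- strsplit(docs, " ")
vocab  <- sort(unique(unlist(tokens)))

# Count each vocabulary word in each document (rows = documents)
dtm <- t(sapply(tokens, function(tok) table(factor(tok, levels = vocab))))
rownames(dtm) <- c("doc1", "doc2")
dtm
```

A classifier never sees the raw strings, only rows of this matrix.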
Loading and Exploring Data
For this tutorial, use the spam dataset from the textrecipes package:
library(tidyverse)
library(tidytext)
library(tidymodels)
library(textrecipes)
# Load spam data (comes with textrecipes)
data("spam", package = "textrecipes")
# Explore the data
glimpse(spam)
# Rows: 3,581
# Columns: 2 (text, type)
# Check class distribution
spam %>%
count(type)
# # A tibble: 2 × 2
#   type      n
#   <fct> <int>
# 1 ham    2772
# 2 spam    809
The dataset has 3,581 emails with imbalanced classes (more ham than spam).
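Imbalance matters for evaluation: a model that always predicts ham is right about 77% of the time, so that is the floor any real classifier must beat.

```r
# Majority-class baseline: always predict "ham"
n_ham  <- 2772
n_spam <- 809

baseline_accuracy <- n_ham / (n_ham + n_spam)
round(baseline_accuracy, 3)
# 0.774
```

Keep this number in mind when reading the accuracy scores later in the tutorial.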
Text Preprocessing Pipeline
Clean and standardize text before feature extraction:
# Basic text preprocessing
spam_clean <- spam %>%
mutate(
# Convert to lowercase
text = str_to_lower(text),
# Remove numbers
text = str_remove_all(text, "[0-9]+"),
# Remove URLs
text = str_remove_all(text, "http[^ ]*"),
# Remove extra whitespace
text = str_squish(text)
)
print(spam_clean)
Creating Train/Test Split
Always split data before building models:
# Set seed for reproducibility
set.seed(1234)
# Create train/test split (stratified by type)
spam_split <- spam_clean %>%
initial_split(strata = type, prop = 0.8)
spam_train <- training(spam_split)
spam_test <- testing(spam_split)
# Check split sizes
dim(spam_train)  # ~2864 rows, 2 columns
dim(spam_test)   # ~717 rows, 2 columns
Stratified sampling preserves the ham/spam ratio in both splits.
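A quick sanity check: compare the class proportions in each split. Both should sit near the full data's roughly 77% ham / 23% spam mix.

```r
# Verify the stratified split preserved the class ratio
spam_train %>%
  count(type) %>%
  mutate(prop = n / sum(n))

spam_test %>%
  count(type) %>%
  mutate(prop = n / sum(n))
```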
Feature Engineering: Document-Term Matrix
Convert text to numerical features using bag-of-words:
# Add a document id, then tokenize and count words
# (the spam data has only text and type columns, so we create the id)
spam_tokens <- spam_train %>%
  mutate(id = row_number()) %>%
  unnest_tokens(word, text) %>%
  filter(!word %in% stop_words$word) %>%
  count(id, word, sort = TRUE)
print(spam_tokens)
# ~30,000 rows, one per (document, word) pair:
#      id word      n
#   <int> <chr> <int>
Creating the DTM
# Cast to document-term matrix
spam_dtm <- spam_tokens %>%
cast_dtm(document = id, term = word, value = n)
print(spam_dtm)
# <<DocumentTermMatrix (documents: 2864, terms: 6281)>>
The DTM has 2,864 documents and 6,281 terms—this is high-dimensional.
Building a Classification Model
Use tidymodels for a consistent interface:
Step 1: Define the Model
# Specify a logistic regression model
log_reg <- logistic_reg() %>%
set_engine("glm") %>%
set_mode("classification")
print(log_reg)
Step 2: Create a Recipe
The recipe defines preprocessing steps:
# Create preprocessing recipe
spam_rec <- recipe(type ~ text, data = spam_train) %>%
step_tokenize(text) %>%
step_tokenfilter(text, max_tokens = 1000) %>%
step_tfidf(text)
print(spam_rec)
This tokenizes, keeps the top 1000 tokens, and applies TF-IDF weighting.
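To see what TF-IDF actually computes, here is a base-R sketch on a two-document toy matrix. textrecipes' implementation may use slightly different smoothing, but the weight is essentially tf(t, d) × log(N / df(t)): terms that appear in every document get down-weighted toward zero.

```r
# Toy term-count matrix: "call" appears in both documents, "prize" in one
dtm <- rbind(doc1 = c(call = 2, prize = 1),
             doc2 = c(call = 1, prize = 0))

N   <- nrow(dtm)          # number of documents
df  <- colSums(dtm > 0)   # documents containing each term
idf <- log(N / df)        # rare terms get larger weights

tfidf <- sweep(dtm, 2, idf, `*`)  # multiply each column by its idf
tfidf
# "call" gets weight 0 everywhere (it appears in all documents);
# "prize" keeps a positive weight in doc1
```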
Step 3: Create a Workflow
Combine recipe and model:
# Create workflow
spam_wf <- workflow() %>%
add_recipe(spam_rec) %>%
add_model(log_reg)
print(spam_wf)
Step 4: Train the Model
# Fit the model
spam_fit <- spam_wf %>%
fit(data = spam_train)
print(spam_fit)
Evaluating the Model
Assess performance on the test set:
# Generate predictions
spam_pred <- spam_fit %>%
predict(spam_test) %>%
bind_cols(spam_test %>% select(type, text))
print(spam_pred)
Confusion Matrix
# Confusion matrix
spam_pred %>%
conf_mat(truth = type, estimate = .pred_class) %>%
autoplot(type = "heatmap")
Performance Metrics
# Calculate metrics
spam_pred %>%
metrics(truth = type, estimate = .pred_class)
# # A tibble: 2 × 3
#   .metric  .estimator .estimate
#   <chr>    <chr>          <dbl>
# 1 accuracy binary         0.923
# 2 kap      binary         0.783
About 92% accuracy from basic logistic regression is a solid start.
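Because of the class imbalance, also look at precision and recall for the spam class. Both fall straight out of the confusion-matrix cell counts; the counts below are made up for illustration.

```r
# Hypothetical confusion matrix counts, with spam as the positive class
tp <- 130  # spam correctly flagged
fn <- 30   # spam missed (predicted ham)
fp <- 25   # ham wrongly flagged as spam
tn <- 530  # ham correctly passed through

accuracy  <- (tp + tn) / (tp + tn + fp + fn)
precision <- tp / (tp + fp)  # of flagged messages, share that were spam
recall    <- tp / (tp + fn)  # of actual spam, share that was caught
f1        <- 2 * precision * recall / (precision + recall)

round(c(accuracy = accuracy, precision = precision,
        recall = recall, f1 = f1), 3)
# accuracy 0.923, precision 0.839, recall 0.812, f1 0.825
```

Note how 92% accuracy can coexist with missing nearly one in five spam messages.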
Class-Imbalanced Data
The dataset is imbalanced (more ham than spam). The glm engine does not accept class weights through logistic_reg(), so a common remedy is to rebalance the training data inside the recipe using the themis package:
# install.packages("themis")
library(themis)
# Recipe that downsamples ham to match spam during training only
spam_rec_balanced <- recipe(type ~ text, data = spam_train) %>%
  step_tokenize(text) %>%
  step_tokenfilter(text, max_tokens = 1000) %>%
  step_tfidf(text) %>%
  step_downsample(type)
# Swap the balanced recipe into the existing workflow
spam_wf_balanced <- spam_wf %>%
  update_recipe(spam_rec_balanced)
spam_fit_balanced <- spam_wf_balanced %>%
  fit(data = spam_train)
# Evaluate
spam_pred_balanced <- spam_fit_balanced %>%
predict(spam_test) %>%
bind_cols(spam_test %>% select(type))
spam_pred_balanced %>%
metrics(truth = type, estimate = .pred_class)
Alternative Models
Try other classifiers:
Naive Bayes
library(discrim)
# Naive Bayes classifier
nb_spec <- naive_Bayes() %>%
set_mode("classification") %>%
set_engine("naivebayes")
nb_wf <- spam_wf %>%
update_model(nb_spec)
nb_fit <- nb_wf %>% fit(data = spam_train)
nb_pred <- nb_fit %>%
predict(spam_test) %>%
bind_cols(spam_test %>% select(type))
nb_pred %>%
metrics(truth = type, estimate = .pred_class)
Regularized Regression (Lasso)
# Lasso regression
lasso_spec <- logistic_reg(
penalty = 0.1,
mixture = 1
) %>%
set_engine("glmnet") %>%
set_mode("classification")
lasso_wf <- spam_wf %>%
update_model(lasso_spec)
lasso_fit <- lasso_wf %>% fit(data = spam_train)
lasso_pred <- lasso_fit %>%
predict(spam_test) %>%
bind_cols(spam_test %>% select(type))
lasso_pred %>%
metrics(truth = type, estimate = .pred_class)
Cross-Validation
For more robust evaluation, use k-fold cross-validation:
# Create 5-fold cross-validation
spam_folds <- vfold_cv(spam_train, v = 5, strata = type)
# Fit with cross-validation
spam_cv_results <- spam_wf %>%
fit_resamples(
spam_folds,
metrics = metric_set(accuracy, precision, recall, f_meas),
control = control_resamples(save_pred = TRUE)
)
# View results
spam_cv_results %>%
collect_metrics()
This gives you average performance across multiple train/val splits.
Hyperparameter Tuning
Optimize model parameters for better performance:
# Define tunable model
log_reg_tune <- logistic_reg(
penalty = tune(),
mixture = tune()
) %>%
set_engine("glmnet") %>%
set_mode("classification")
# Update workflow
spam_wf_tune <- spam_wf %>%
update_model(log_reg_tune)
# Grid search
spam_grid <- grid_regular(
penalty(),
mixture(),
levels = 5
)
# Tune
tune_results <- spam_wf_tune %>%
tune_grid(
spam_folds,
grid = spam_grid,
metrics = metric_set(accuracy)
)
# Best parameters
best_params <- tune_results %>%
  select_best(metric = "accuracy")
print(best_params)
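The selected values still have to be plugged back into the workflow before the model is usable. A sketch of the standard tidymodels finishing step, using finalize_workflow():

```r
# Lock in the winning penalty/mixture values and refit on the training set
final_wf <- spam_wf_tune %>%
  finalize_workflow(best_params)

final_fit <- final_wf %>%
  fit(data = spam_train)

# Evaluate the tuned model on the held-out test set
final_fit %>%
  predict(spam_test) %>%
  bind_cols(spam_test %>% select(type)) %>%
  metrics(truth = type, estimate = .pred_class)
```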
Feature Importance
Understand what the model learned:
# Extract model coefficients
spam_coefs <- spam_fit %>%
extract_fit_engine() %>%
tidy() %>%
filter(term != "(Intercept)") %>%
arrange(desc(abs(estimate)))
print(spam_coefs)
# Terms carry a tfidf_text_ prefix added by step_tfidf(); tidy() on a
# glm fit returns term, estimate, std.error, statistic, and p.value:
# # A tibble: 1,000 × 5
#   term            estimate std.error statistic p.value
#   <chr>              <dbl>     <dbl>     <dbl>   <dbl>
# 1 tfidf_text_call     2.45 ...
# 2 tfidf_text_free     2.12 ...
# 3 tfidf_text_txt      1.89 ...
# 4 tfidf_text_ur       1.76 ...
# Visualize top features (strip the tfidf_text_ prefix for readability)
spam_coefs %>%
  head(20) %>%
  mutate(
    term = str_remove(term, "tfidf_text_"),
    term = fct_reorder(term, estimate)
  ) %>%
  ggplot(aes(estimate, term, fill = estimate > 0)) +
  geom_col() +
geom_col() +
labs(
title = "Most Important Words for Spam Detection",
x = "Coefficient (positive = spam)",
y = NULL
) +
theme_minimal() +
scale_fill_manual(values = c("coral", "steelblue"), guide = "none")
Making Predictions on New Data
Use the trained model to predict new emails:
# New emails to classify
new_emails <- tibble(
text = c(
"Congratulations! You have won a free iPhone. Click here to claim your prize!",
"Hey, are we still meeting for lunch tomorrow?",
"URGENT: Your account has been compromised. Verify your password immediately."
)
)
# Predict the class and the class probabilities
new_predictions <- spam_fit %>%
  predict(new_emails) %>%
  bind_cols(predict(spam_fit, new_emails, type = "prob")) %>%
  bind_cols(new_emails)
print(new_predictions)
# # A tibble: 3 × 4
#   .pred_class .pred_ham .pred_spam text
# 1 spam            0.002      0.998 Congratulations! You have won…
# 2 ham             0.991      0.009 Hey, are we still meeting…
# 3 spam            0.277      0.723 URGENT: Your account has been…
What You Have Learned
| Step | Description |
|---|---|
| Text preprocessing | Clean and standardize text |
| Feature extraction | Create document-term matrix |
| Model training | Fit classifier with tidymodels |
| Evaluation | Confusion matrix and metrics |
| Tuning | Optimize hyperparameters |
Key Takeaways
- Text classification requires labeled training data
- TF-IDF weighting often improves performance
- Class imbalance requires special handling (weights, sampling)
- Cross-validation gives more reliable performance estimates
- Logistic regression works well as a baseline
See Also
- dplyr::filter — Filter rows by condition
- dplyr::count — Count words per document
Next Steps
Continue building your text mining skills:
- Text Regression — Predict continuous values from text
- Deep Learning with torch — Neural networks for text
- BERT embeddings — Modern transformer-based features