An example of fitting multiclass classification models with
tidymodels, combining them into a stacking ensemble with the
stacks package, and using ensModelVis to visualise the fitted
models, including a majority-vote ensemble.
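Packages we will need: tidymodels, stacks, discrim (for `discrim_linear()`), stringr and ensModelVis, plus the engine packages nnet, glmnet, MASS and ranger.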
Dataset: the iris data. Split it into training and test sets, then centre and scale both using the training-set means and standard deviations.
# Load the iris data, rename the outcome column to `Response` and move it to
# the first position.
data(iris)
train <- iris |>
  rename(Response = Species) |>
  relocate(Response)

# Stratified 50/50 split into training and test sets. The seed is set first
# so the split is reproducible.
set.seed(1979)
tr <- initial_split(train, prop = .5, strata = Response)
train_data <- training(tr)
test_data <- testing(tr)

# Centre and scale every predictor (all columns except `Response`) using the
# mean and standard deviation of the TRAINING set; the same training
# statistics are then applied to the test set.
# The statistics are stored as `train_mean` / `train_sd` so that no variable
# named `sd` masks the `sd()` function, and they are computed column-wise with
# `vapply()` instead of `apply()`, which would coerce the data frame to a
# matrix first.
predictor_cols <- setdiff(names(train_data), "Response")
train_mean <- vapply(train_data[predictor_cols], mean, numeric(1))
train_sd <- vapply(train_data[predictor_cols], sd, numeric(1))

train_data[predictor_cols] <- sweep(train_data[predictor_cols], 2, train_mean, "-")
train_data[predictor_cols] <- sweep(train_data[predictor_cols], 2, train_sd, "/")
test_data[predictor_cols] <- sweep(test_data[predictor_cols], 2, train_mean, "-")
test_data[predictor_cols] <- sweep(test_data[predictor_cols], 2, train_sd, "/")

Set the recipe for stacks
# Shared preprocessing: a recipe that models `Response` on all other columns
# of the training data, wrapped in a workflow that every model below extends.
spec_rec <- recipe(Response ~ ., data = train_data)
spec_wflow <- add_recipe(workflow(), spec_rec)

# Control settings passed to `tune_grid()` and `fit_resamples()` below.
ctrl_grid <- control_stack_grid()
ctrl_res <- control_stack_resamples()

# 10-fold cross-validation on the training data, stratified by the outcome.
folds <- vfold_cv(train_data, v = 10, strata = Response)

Try some classifiers:
# Neural network (`mlp()` with the nnet engine). The number of hidden units,
# the penalty and the number of epochs are all marked for tuning.
nnet_mod <- mlp(
  hidden_units = tune(),
  penalty = tune(),
  epochs = tune()
) |>
  set_mode("classification") |>
  set_engine("nnet")

# Attach the model to the shared recipe workflow.
nnet_wf <- add_model(spec_wflow, nnet_mod)

# Tune over the resampling folds; `grid = 10` requests 10 candidate
# parameter combinations.
nnet_res <- tune_grid(
  nnet_wf,
  resamples = folds,
  grid = 10,
  control = ctrl_grid
)
# ===================================
# Penalty values to try: 10 values evenly spaced on the log10 scale,
# from 1e-8 to 1e-1.
lasso_reg_grid <- tibble(penalty = 10^seq(-8, -1, length.out = 10))

# Multinomial regression (glmnet engine) with the mixture fixed at 0.5 and
# the penalty marked for tuning.
en_mod <- multinom_reg(penalty = tune(), mixture = 0.5) |>
  set_engine("glmnet") |>
  set_mode("classification")

# Attach the model to the shared recipe workflow.
en_wf <- add_model(spec_wflow, en_mod)

# Tune the penalty over the explicit grid defined above.
en_res <- tune_grid(
  en_wf,
  resamples = folds,
  grid = lasso_reg_grid,
  control = ctrl_grid
)
# ===================================
# Linear discriminant analysis (MASS engine); nothing is tuned here.
lda_mod <- discrim_linear() |>
  set_engine("MASS") |>
  set_mode("classification")

# Attach the model to the shared recipe workflow.
lda_wf <- add_model(spec_wflow, lda_mod)

# No tuning parameters, so the workflow is simply fitted on each fold.
lda_res <- lda_wf |>
  fit_resamples(
    resamples = folds,
    control = ctrl_res
  )
# ==================================
# Random forest (ranger engine) with 500 trees. `mtry` is the floored square
# root of the number of columns in `train` minus one (the `Response` column).
rf_mod <- rand_forest(
  mtry = floor(sqrt(ncol(train) - 1)),
  trees = 500
) |>
  set_engine("ranger") |>
  set_mode("classification")

# Attach the model to the shared recipe workflow.
rf_wf <- add_model(spec_wflow, rf_mod)
# Fit the random forest workflow on each resampling fold.
rf_res <- rf_wf |>
  fit_resamples(
    resamples = folds,
    control = ctrl_res
  )

# Build the stacking ensemble: collect the resampling results of all four
# models as candidates, blend their predictions, then fit the members.
model_st <- stacks() |>
  add_candidates(lda_res) |>
  add_candidates(nnet_res) |>
  add_candidates(rf_res) |>
  add_candidates(en_res) |>
  blend_predictions() |>
  fit_members()
#> Warning: Predictions from 15 candidates were identical to those from existing candidates
#> and were removed from the data stack.

Predict with new data: class and probability.
# Bind `select` to dplyr's implementation (presumably to guard against
# another attached package masking it -- confirm against the loaded packages).
select <- dplyr::select

# Class predictions of the stack and (via `members = TRUE`) of its members
# on the test set, with the true class in the first column.
ens_pred <- bind_cols(
  select(test_data, Response),
  predict(model_st, test_data, type = "class", members = TRUE)
)

# The same, but as predicted class probabilities.
ens_prob <- bind_cols(
  select(test_data, Response),
  predict(model_st, test_data, type = "prob", members = TRUE)
)

Rename to get nicer graphs.
# Shorten the prediction column names. The substrings are matched literally
# with `fixed()`: as plain regular expressions the leading "." in
# ".pred_class_" / ".pred_" would match any character.
names(ens_pred) <- names(ens_pred) |>
  str_remove(fixed(".pred_class_")) |>
  str_remove(fixed("respre0_mod0")) |>
  str_remove(fixed("_post0"))

# The column still called `.pred_class` holds the stack's own prediction.
ens_pred <- ens_pred |> rename(stack = .pred_class)

names(ens_prob) <- names(ens_prob) |>
  str_remove(fixed(".pred_")) |>
  str_remove(fixed("respre0_mod0")) |>
  str_remove(fixed("_post0"))

# Columns 2-4 (right after `Response`) are taken to be the stack's three class
# probabilities and are tagged with a `_stack` suffix.
# NOTE(review): this is position-based -- verify if the number of classes
# or the column order changes.
names(ens_prob)[2:4] <- str_c(names(ens_prob)[2:4], "_stack")

Calculate AUC
# Multiclass AUC per model.
# Reshape the probabilities to one row per observation and model, with one
# column per class (`set`, `ver`, `vir`: the first three letters of the class
# name), then compute `roc_auc()` within each model.
auc <- ens_prob |>
  # seq_len() rather than 1:nrow(), which misbehaves for zero rows
  mutate(id = seq_len(nrow(ens_prob))) |>
  pivot_longer(-c(Response, id)) |>
  mutate(
    type = substr(name, 1, 3),
    # strip the class prefix so that `name` identifies the model only
    name = str_remove(name, "setosa_"),
    name = str_remove(name, "versicolor_"),
    name = str_remove(name, "virginica_")
  ) |>
  pivot_wider(names_from = type, values_from = value) |>
  group_by(name) |>
  roc_auc(truth = Response, set:vir)

# Keep only the estimates, as a one-row table with one column per model.
auc <- auc |>
  select(name, .estimate) |>
  pivot_wider(names_from = name, values_from = .estimate)

Only take the probability of the most likely class:
# For every observation and model keep only the largest class probability,
# i.e. the probability of the class that model predicts. The result has one
# column per model and no `Response` column.
ens_prob <- ens_prob |>
  # seq_len() rather than 1:nrow(), which misbehaves for zero rows
  mutate(id = seq_len(nrow(ens_prob))) |>
  pivot_longer(-c(Response, id)) |>
  mutate(
    # strip the class prefix so that `name` identifies the model only
    name = str_remove(name, "setosa_"),
    name = str_remove(name, "versicolor_"),
    name = str_remove(name, "virginica_")
  ) |>
  group_by(id, name) |>
  summarise(valuemax = max(value)) |>
  ungroup() |>
  pivot_wider(
    id_cols = id,
    names_from = name,
    values_from = valuemax
  ) |>
  select(-id)
#> `summarise()` has grouped output by 'id'. You can override using the `.groups`
#> argument.

Make a plot:
This is a heatmap of the predictions of the different models (x-axis)
for each observation in the test set (y-axis). The colors indicate the
class predicted for each observation by each model.
The first column is the true class (target) of each observation. The following four columns show the predictions of the individual models in the ensemble and of the stacking ensemble itself. The columns are ordered by the accuracy of each model (from left to right, most to least accurate).
We can see that all models correctly predict over 90% of the
observations. Let’s zoom in and look only at the misclassified
observations. We can do this by setting the `incorrect`
argument to `TRUE`:
Note that this argument removes all but a single observation per class for the observations that were correctly classified by all models. All models correctly classified all ‘setosa’ observations.
Add transparency by probability of the predicted class:
# Heatmap restricted via `incorrect = TRUE`, with the per-model probability
# of the predicted class supplied through `tibble_prob`.
plot_ensemble(
  pull(ens_pred, Response),
  select(ens_pred, -Response),
  tibble_prob = ens_prob,
  incorrect = TRUE
)

Reorder the models (columns) by AUC rather than accuracy:
# Put the AUC columns in the same order as the model columns of `ens_pred`.
auc <- auc[, names(select(ens_pred, -Response))]

# Same heatmap, with the AUC values passed as the `order` argument.
plot_ensemble(
  pull(ens_pred, Response),
  select(ens_pred, -Response),
  tibble_prob = ens_prob,
  order = auc
)

Add another model to our ensemble, in this case majority vote :
# Majority vote: for each observation take the class predicted by most models.
# `which.max()` returns the first maximum, so ties go to the class that comes
# first in the table's (sorted) order.
# An already existing `maj_vote` column is excluded so that re-running this
# chunk does not let the previous vote take part in the new one.
maj_vote <- apply(
  ens_pred |> select(-Response, -any_of("maj_vote")),
  1,
  function(x) names(which.max(table(x)))
)

# Use the levels of `Response` so the new column has the same class levels as
# the truth even if some class never wins the vote (`as.factor()` would drop it).
ens_pred$maj_vote <- factor(maj_vote, levels = levels(ens_pred$Response))

plot_ensemble(ens_pred |> pull(Response), ens_pred |> select(-Response))

Add probability of majority vote:
# "Probability" of the majority vote: the share of models that voted for the
# winning class. Only the actual models may be counted, so the `maj_vote`
# column itself is excluded along with `Response`; counting it as a voter
# would inflate the share from k/n to (k + 1)/(n + 1).
prob_maj_vote <- apply(
  ens_pred |> select(-Response, -any_of("maj_vote")),
  1,
  function(x) max(table(x)) / length(x)
)

ens_prob <- ens_prob |>
  mutate(maj_vote = prob_maj_vote)

# All observations, with the probability of the predicted class supplied
# through `tibble_prob`.
plot_ensemble(
  ens_pred |> pull(Response),
  ens_pred |> select(-Response),
  tibble_prob = ens_prob
)

# The same plot with `incorrect = TRUE`.
plot_ensemble(
  ens_pred |> pull(Response),
  ens_pred |> select(-Response),
  tibble_prob = ens_prob,
  incorrect = TRUE
)