ETC3250/5250 Tutorial 6
Trees and forests

Load the libraries and avoid conflicts
# Load libraries used everywhere
library(tidyverse)
library(tidymodels)
library(patchwork)
library(mulgar)
library(palmerpenguins)
library(GGally)
library(tourr)
library(MASS)
library(discrim)
library(classifly)
library(detourr)
library(crosstalk)
library(plotly)
library(viridis)
library(colorspace)
library(randomForest)
library(geozoo)
library(ggbeeswarm)
library(conflicted)
conflicts_prefer(dplyr::filter)
conflicts_prefer(dplyr::select)
conflicts_prefer(dplyr::slice)
conflicts_prefer(palmerpenguins::penguins)
conflicts_prefer(viridis::viridis_pal)

options(digits=2)
p_tidy <- penguins |>
  select(species, bill_length_mm:body_mass_g) |>
  rename(bl=bill_length_mm,
         bd=bill_depth_mm,
         fl=flipper_length_mm,
         bm=body_mass_g) |>
  filter(!is.na(bl)) |>
  arrange(species) |>
  na.omit()
p_tidy_std <- p_tidy |>
  mutate_if(is.numeric, function(x) (x-mean(x))/sd(x))
🎯 Objectives

The goal for this week is to learn to fit, diagnose, assess assumptions, and predict from classification tree and random forest models.
🔧 Preparation

- Make sure you have all the necessary libraries installed. There are a few new ones this week!
Exercises:

Open your project for this unit called iml.Rproj. For all the work we will use the penguins data. Start by splitting it into a training and test set, as follows.
set.seed(1156)
p_sub <- p_tidy_std |>
  filter(species != "Gentoo") |>
  mutate(species = factor(species)) |>
  select(species, bl, bm)
p_split <- initial_split(p_sub, 2/3, strata = species)
p_tr <- training(p_split)
p_ts <- testing(p_split)
1. Becoming a car mechanic - looking under the hood at the tree algorithm
- Write down the equation for the Gini measure of impurity for two groups, in terms of the parameter \(p\), the proportion of observations in class 1. Specify the domain of the function, determine the value of \(p\) which gives the maximum value, and report what that maximum function value is.
- For two groups, how would the impurity of a split be measured? Give the equation.
- Below is an R function to compute the Gini impurity for a particular split on a single variable. Work through the code of the function, and document what each step does. Make sure to include a note on what the minsplit parameter does to prevent splits near the edges, where one side would have fewer than the specified number of observations. (A short usage sketch follows the function code below.)
# This works for two classes, and one variable
mygini <- function(p) {
  g <- 0
  if (p>0 && p<1) {
    g <- 2*p*(1-p)
  }
  return(g)
}

mysplit <- function(x, spl, cl, minsplit=5) {
  # Assumes x is sorted
  # Count number of observations
  n <- length(x)

  # Check number of classes
  cl_unique <- unique(cl)

  # Split into two subsets on the given value
  left <- x[x<spl]
  cl_left <- cl[x<spl]
  n_l <- length(left)

  right <- x[x>=spl]
  cl_right <- cl[x>=spl]
  n_r <- length(right)

  # Don't calculate if either subset is smaller than minsplit
  if ((n_l < minsplit) | (n_r < minsplit))
    impurity <- NA
  else {
    # Compute the Gini value for the split
    p_l <- length(cl_left[cl_left == cl_unique[1]])/n_l
    p_r <- length(cl_right[cl_right == cl_unique[1]])/n_r
    if (is.na(p_l)) p_l <- 0.5
    if (is.na(p_r)) p_r <- 0.5
    impurity <- (n_l/n)*mygini(p_l) + (n_r/n)*mygini(p_r)
  }
  return(impurity)
}
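A minimal usage sketch (not part of the original tutorial code; the split value 0 is an arbitrary choice) showing how the two functions might be called:

# mygini is defined on [0, 1]; it is largest at p = 0.5,
# where 2*0.5*(1-0.5) = 0.5
mygini(0.5)

# Impurity of splitting the (sorted) standardised bill length at 0,
# for the Adelie/Chinstrap training set created earlier
p_srt <- p_tr |> arrange(bl)
mysplit(p_srt$bl, spl = 0, cl = p_srt$species, minsplit = 5)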
- Apply the function to compute the value for all possible splits of body mass (bm), setting minsplit to be 1, so that all possible splits will be evaluated. Make a plot of these values vs the variable. (One way to do this is sketched after the next question.)
- Use your function to compute the first two steps of a classification tree model for separating Adelie from Chinstrap penguins, after setting minsplit to be 5. Make a scatterplot of the two variables that would be used in the splits, with points coloured by species, and the splits as line segments.
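A minimal sketch of one way to approach the all-possible-splits question for bm (using the observed values themselves as candidate split points, and the Adelie/Chinstrap training set p_tr, are choices made here rather than prescribed by the tutorial):

# Evaluate the split impurity at every observed value of bm
p_srt_bm <- p_tr |> arrange(bm)
splits_bm <- tibble(
  spl = p_srt_bm$bm,
  impurity = sapply(p_srt_bm$bm, function(s)
    mysplit(p_srt_bm$bm, spl = s, cl = p_srt_bm$species, minsplit = 1))
)

# Plot impurity against the candidate split values
ggplot(splits_bm, aes(x = spl, y = impurity)) +
  geom_point() +
  labs(x = "bm (standardised)", y = "Gini impurity of the split")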
2. Digging deeper into diagnosing an error
- Fit the random forest model to the full penguins data. (A sketch of one way to set this up is given just before the linked brushing code below.)
- Report the confusion matrix.
- Use linked brushing to learn which was the Gentoo penguin that the model was confused about. When we looked at the data in a tour, there was one Gentoo penguin that was an outlier, appearing to be away from the other Gentoos and closer to the Chinstrap group. We would expect this to be the penguin that the forest model is confused about. Is it?

Have a look at the other misclassifications, to understand whether they are ones we'd expect to misclassify, or whether the model is not well constructed.
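The linked brushing code below uses objects p_tr2 and p_fit_rf that are not created elsewhere in this sheet. A minimal sketch of how they might be set up, assuming a 2/3 training split of the full standardised data and the randomForest engine (so that the out-of-bag predictions in p_fit_rf$fit$predicted are available); the seed and mtry values are arbitrary:

# Training/test split of the full (three species) standardised data
set.seed(923)
p_split2 <- initial_split(p_tidy_std, 2/3, strata = species)
p_tr2 <- training(p_split2)
p_ts2 <- testing(p_split2)

# Fit a random forest, keeping the underlying randomForest object
# so the out-of-bag predictions and confusion matrix can be extracted
p_rf_spec <- rand_forest(mtry = 2, trees = 1000) |>
  set_mode("classification") |>
  set_engine("randomForest")
p_fit_rf <- p_rf_spec |>
  fit(species ~ ., data = p_tr2)

# Out-of-bag confusion matrix
p_fit_rf$fit$confusion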
p_cl <- p_tr2 |>
  mutate(pspecies = p_fit_rf$fit$predicted) |>
  dplyr::select(bl:bm, species, pspecies) |>
  mutate(sp_jit = jitter(as.numeric(species)),
         psp_jit = jitter(as.numeric(pspecies)))
p_cl_shared <- SharedData$new(p_cl)

detour_plot <- detour(p_cl_shared, tour_aes(
  projection = bl:bm,
  colour = species)) |>
  tour_path(grand_tour(2),
            max_bases = 50, fps = 60) |>
  show_scatter(alpha = 0.9, axes = FALSE,
               width = "100%", height = "450px")

conf_mat <- plot_ly(p_cl_shared,
                    x = ~psp_jit,
                    y = ~sp_jit,
                    color = ~species,
                    colors = viridis_pal(option = "D")(3),
                    height = 450) |>
  highlight(on = "plotly_selected",
            off = "plotly_doubleclick") |>
  add_trace(type = "scatter",
            mode = "markers")

bscols(
  detour_plot, conf_mat,
  widths = c(5, 6)
)
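As a complement to the linked brushing (a suggestion, not part of the original code), the misclassified penguins can also be listed directly:

# Rows where the out-of-bag prediction disagrees with the true species
p_cl |> filter(species != pspecies)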
3. Deciding on variables in a large data problem

- Fit a random forest to the bushfire data. You can read more about the bushfire data at https://dicook.github.io/mulgar_book/A2-data.html. Examine the votes matrix using a tour. What do you learn about the confusion between fire causes?

This code might help:
data(bushfires)

bushfires_sub <- bushfires[,c(5, 8:45, 48:55, 57:60)] |>
  mutate(cause = factor(cause))

set.seed(1239)
bf_split <- initial_split(bushfires_sub, 3/4, strata = cause)
bf_tr <- training(bf_split)
bf_ts <- testing(bf_split)

rf_spec <- rand_forest(mtry = 5, trees = 1000) |>
  set_mode("classification") |>
  set_engine("ranger", probability = TRUE,
             importance = "permutation")
bf_fit_rf <- rf_spec |>
  fit(cause ~ ., data = bf_tr)

# Create votes matrix data
bf_rf_votes <- bf_fit_rf$fit$predictions |>
  as_tibble() |>
  mutate(cause = bf_tr$cause)

# Project 4D into 3D
proj <- t(geozoo::f_helmert(4)[-1,])
bf_rf_v_p <- as.matrix(bf_rf_votes[,1:4]) %*% proj
colnames(bf_rf_v_p) <- c("x1", "x2", "x3")
bf_rf_v_p <- bf_rf_v_p |>
  as.data.frame() |>
  mutate(cause = bf_tr$cause)

# Add simplex
simp <- simplex(p = 3)
sp <- data.frame(simp$points)
colnames(sp) <- c("x1", "x2", "x3")
sp$cause <- ""
bf_rf_v_p_s <- bind_rows(sp, bf_rf_v_p) |>
  mutate(cause = factor(cause))
labels <- c("accident", "arson",
            "burning_off", "lightning",
            rep("", nrow(bf_rf_v_p)))

# Examine votes matrix with bounding simplex
animate_xy(bf_rf_v_p_s[,1:3], col = bf_rf_v_p_s$cause,
           axes = "off", half_range = 1.3,
           edges = as.matrix(simp$edges),
           obs_labels = labels)
- Check the variable importance. Plot the most important variables. (A sketch of one way to plot them follows the code.)

This code might help:
bf_fit_rf$fit$variable.importance |>
  as_tibble() |>
  rename(imp = value) |>
  mutate(var = colnames(bf_tr)[1:50]) |>
  select(var, imp) |>
  arrange(desc(imp)) |>
  print(n = 50)
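The code above prints the importance values but does not plot them. A minimal sketch of one way to plot the most important variables (taking the top 10, and building the table from the names of the importance vector, are choices made here rather than in the tutorial):

# Tidy the permutation importance and plot the top 10 variables
bf_imp <- tibble(
  var = names(bf_fit_rf$fit$variable.importance),
  imp = bf_fit_rf$fit$variable.importance
) |>
  arrange(desc(imp))

bf_imp |>
  slice_head(n = 10) |>
  ggplot(aes(x = imp, y = fct_reorder(var, imp))) +
  geom_col() +
  labs(x = "Permutation importance", y = "")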
4. Can boosting better detect bushfire cause?

Fit a boosted tree model using xgboost to the bushfires data. You can use the code below. Compute the confusion tables and the balanced accuracy for the test data for both the forest model and the boosted tree model, to make the comparison. (A sketch of the test-set comparison follows the model fitting code.)
set.seed(121)
bf_spec2 <- boost_tree() |>
  set_mode("classification") |>
  set_engine("xgboost")
bf_fit_bt <- bf_spec2 |>
  fit(cause ~ ., data = bf_tr)
👋 Finishing up

Make sure you say thanks and good-bye to your tutor. This is a time to also report what you enjoyed and what you found difficult.