Homework 8: Classification Trees
Introduction
In this homework we will make our classifications using classification trees, which are a type of decision tree. A decision tree works by using splitting rules to divide up the predictor space. For numerical predictors, these rules take the form of linear separators, which are just linear inequalities in the predictors. For example, consider the problem of trying to predict whether a student will pass their first college mathematics course. The scatterplot below shows the mathematics ACT score and GPA for 57 students. Each student is shown as a black dot (0) if they failed their first college mathematics course and as a red dot (1) if they passed.
The node at the top of the tree represents the linear separator \(ACT < 18.65\). If a student in the training data has an ACT score less than 18.65, they are sent to the left side of the tree; if their ACT score is 18.65 or greater, they are sent to the right side. Students on the right side of the tree are then further divided by the linear separator \(GPA < 2.86111\). Thus, every student is placed at one of the three leaves of the tree. The leaf labels represent the majority class at each leaf. For example, more than half of the students with \(ACT < 18.65\) failed the class, so the leaf at the far left of the tree is labeled 0. To classify a new student, we apply the linear separators starting at the top of the tree to determine which leaf the student belongs to, and then classify the student according to the label of that leaf.
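As a minimal sketch of this routing step in R (the new student's scores are made-up values, and the leaf names are our own; only the far-left leaf's label is stated above):

```r
# hypothetical new student; 21 and 3.4 are made-up values
ACT <- 21
GPA <- 3.4

# route the student through the separators, top to bottom
if (ACT < 18.65) {
  leaf <- "far left"   # majority class here is 0 (fail)
} else if (GPA < 2.86111) {
  leaf <- "middle"
} else {
  leaf <- "far right"
}
leaf   # the student is classified by the majority class at this leaf
```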
Learning goals
In this assignment, you will…
- Fit and interpret classification trees
- Use grid-based techniques to choose tuning parameters
Getting Started
You are free to use whatever version control strategy you like.
Teams & Rules
You can find your team for this assignment on Canvas in the People section. The group set is called HW8. Your group will consist of 2-3 people and has been randomly generated. The GitHub assignment can be found here. Rules:
- You are all responsible for understanding the work that your team turns in.
- All team members must make roughly equal contributions to the homework.
- Any work completed by a team member must be committed and pushed to GitHub by that person.
Data: iris
We will be working with the famous iris data set, which consists of four measurements (in centimeters) for 150 plants belonging to three species of iris. This data set was first published in a classic 1936 paper by English statistician, and notable racist/eugenicist, Ronald Fisher. In that paper, multivariate linear models were applied to classify these plants. Of course, back then, model fitting was an extremely laborious process that was done without the aid of calculators or statistical software.
Exercise 1
Import the datasets package, which contains the iris data set, and split the data using a 60-40 split. Be sure that each species is represented proportionally in the test and train sets. Use a seed of 427.
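One way to do this (a sketch assuming the rsample package; the object names iris_split, train, and test are our own choices) is a stratified initial split:

```r
library(rsample)   # assumption: rsample is available for data splitting

set.seed(427)
iris_split <- initial_split(iris, prop = 0.6, strata = Species)
train <- training(iris_split)   # 60% of the plants, stratified by Species
test  <- testing(iris_split)    # the remaining 40%
```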
Exercise 2
Create a plot showing the response variable Species. Comment on the relative frequency of each category and what impact the balance will have on modeling.
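A bar chart is one natural choice here; this sketch assumes the tidyverse is loaded and the training set is named train:

```r
library(tidyverse)

# one bar per species, showing how many training plants fall in each
train |>
  ggplot(aes(x = Species)) +
  geom_bar()
```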
Exercise 3
Create a scatterplot of Petal.Length versus Sepal.Width colored by Species using the training data. Notice that these are different features than we used in the KNN homework.
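For example (again assuming the training set is named train):

```r
# Petal.Length on the vertical axis, Sepal.Width on the horizontal axis
train |>
  ggplot(aes(x = Sepal.Width, y = Petal.Length, color = Species)) +
  geom_point()
```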
Exercise 4
Based on your scatterplot, determine some simple linear separators that can be used to classify the observations in test. That is, build a small decision tree by hand. (Remember, you can use the logical operators < , > , & (and), and | (or) to subset vectors in R.) For example, if we were given data for new students to classify based on the scatterplot in the introduction, we might predict whether they will pass or fail using simple linear separators as in the code below (though this won't run, because we don't have access to the data).
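```r
new_students <- new_students |>
  mutate(prediction = if_else(ACT < 20 & GPA < 3.25, "Fail", "Pass"))
```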
Exercise 5
To get a sense of how good your simple model predictions are, create a confusion matrix.
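A quick way to build one (a sketch assuming your hand-built predictions are stored in a test column named prediction, a name we made up) is base R's table():

```r
# rows are your hand-built predictions, columns are the true species
table(predicted = test$prediction, actual = test$Species)
```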
Hopefully, your simple linear separators were fairly successful at making predictions. Again, this is because the iris data set just isn't that challenging. For more challenging data sets, we will want R to search for the best linear separators and to build us a classification tree.
Exercise 6
Using the training data, train and display a classification tree that predicts Species as a function of Sepal.Width and Petal.Length.
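One option (a sketch assuming the rpart and rpart.plot packages; tree_fit is our own object name):

```r
library(rpart)        # assumption: rpart is used to fit the tree
library(rpart.plot)   # assumption: rpart.plot is used to display it

tree_fit <- rpart(Species ~ Sepal.Width + Petal.Length, data = train)
rpart.plot(tree_fit)
```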
Exercise 7
Use this classification tree model to classify the species in the test set. Print the confusion matrix and the accuracy.
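Continuing the sketch above (the names tree_fit and test are assumptions from earlier exercises):

```r
# predicted species for each plant in the test set
test_pred <- predict(tree_fit, newdata = test, type = "class")

table(predicted = test_pred, actual = test$Species)   # confusion matrix
mean(test_pred == test$Species)                       # accuracy
```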
The classification tree in the above example is kind of boring. To get a better sense of what the decision boundaries from classification trees look like, we’ll try a more interesting example.
Consider the simulated data set below. We randomly generate points in the grid \([0,1] \times [0,1]\) and split these points into two classes (“1” and “2”) based on whether they are above or below the line \(y = x\). These points are stored in the data frame DFsim.
```r
library(tidyverse)

set.seed(1)
ss <- 300                           # sample size
x1 <- runif(ss, min = 0, max = 1)   # horizontal coordinates
x2 <- runif(ss, min = 0, max = 1)   # vertical coordinates
class <- rep(1, ss)                 # start every point in class 1
class[x1 > x2] <- 2                 # points below the line y = x get class 2
DFsim <- tibble(x1 = x1, x2 = x2, CL = as.factor(class))

DFsim |>
  ggplot(aes(x = x1, y = x2, col = CL)) +
  geom_point()
```
Exercise 8
Build a classification tree on DFsim that predicts CL as a function of x1 and x2. Now, apply your model to predict the class of each point in grid, defined below. Plot the points in grid colored by the predicted class. (One possible approach is sketched after the grid definition.)
```r
# grid points from [0,1] x [0,1]
g1 <- rep((1:100) * (1/100), 100)
g2 <- rep((1:100) * (1/100), each = 100)
grid <- data.frame(x1 = g1,
                   x2 = g2)
```
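Here is one possible approach (a sketch; tree_sim is our own object name, and we assume rpart is loaded as in Exercise 6):

```r
# fit a classification tree to the simulated data
tree_sim <- rpart(CL ~ x1 + x2, data = DFsim)

# predict a class for every grid point and plot the resulting decision regions
grid$pred <- predict(tree_sim, newdata = grid, type = "class")

grid |>
  ggplot(aes(x = x1, y = x2, col = pred)) +
  geom_point()
```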