More Classification
Support Vector Machine, Kernels, more classification metrics
STOR 390
This lecture
- maximal margin classifier (MM)
- AKA hard margin support vector machine
- support vector machine (SVM)
- kernels
- other classification metrics
Packages
library(e1071) # SVM
library(caret) # tuning
## Loading required package: lattice
## Loading required package: ggplot2
library(kernlab) # Kernel SVM
##
## Attaching package: 'kernlab'
## The following object is masked from 'package:ggplot2':
##
## alpha
Binary classes
Linear classifiers
- nearest centroid
- split the two classes by a line (hyperplane)
linearly separable data
Which separating hyperplane?
The Margin
Call the distance from the separating hyperplane to the nearest data point the margin.
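To make the definition concrete, here is a minimal R sketch of computing a margin by hand; the hyperplane \(w^T x + b = 0\) and the data matrix X below are hypothetical, not from the lecture code.

# distance from a point x to the hyperplane w'x + b = 0
dist_to_hyperplane <- function(x, w, b) abs(sum(w * x) + b) / sqrt(sum(w^2))

# hypothetical hyperplane and data points (one point per row)
w <- c(1, -2)
b <- 0.5
X <- rbind(c(1, 0), c(0, 1), c(2, 2))

# the margin is the smallest distance over all the data points
min(apply(X, 1, dist_to_hyperplane, w = w, b = b))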
Maximal Margin classifier (intuition)
- data should be as far away from the separating hyperplane as possible.
Maximal Margin classifier (words)
- The separating hyperplane that maximizes the margin.
- Maximizes the minimum distance from the data points to the separating hyperplane.
- Warning: only defined when the data are linearly separable.
- See ISLR chapter 9 for details.
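e1071 (loaded above) has no separate hard margin mode; a common trick, sketched here using svm() (discussed below) and the train data frame created below, is to set a huge cost so that margin violations are effectively forbidden. This approximates, rather than exactly solves, the maximal margin problem.

# approximate the maximal margin classifier with a very large cost
# (only sensible when the data are linearly separable)
mm_fit <- svm(y ~ ., data = train, kernel = 'linear',
              cost = 1e10, # huge penalty on violations ~ hard margin
              scale = FALSE, type = 'C-classification')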
Non-linearly separable data
Soft margin SVM (intuition)
A good linear classifier should aim to put points
- that are on the correct side of the separating hyperplane far away from it
- that are on the wrong side of the separating hyperplane close to it
Soft margin SVM
- Allow points to be on the wrong side of the separating hyperplane, but penalize them.
- Keep as many points on the correct side of the separating hyperplane as possible.
Two competing objectives
- maximize the margin
- penalize disobedient points
Balance competing objectives
Tuning parameter!
- SVM comes with a tuning parameter \(C > 0\) that controls the balance between the two competing objectives.
- Larger values of \(C\) make SVM care more about “bad” points.
- Smaller values of \(C\) mean SVM is more relaxed about misclassified points.
C large
C moderate
C small
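Plots like the three above can be reproduced by refitting SVM across a range of C values; a sketch using svm() from e1071 (introduced below) and the train data frame defined below:

# refit a linear SVM for several values of the tuning parameter C
for (C in c(100, 1, .01)) {
    fit <- svm(y ~ ., data = train, kernel = 'linear', cost = C,
               scale = FALSE, type = 'C-classification')
    # fit$index lists the support vectors; smaller C typically
    # means a wider margin and more support vectors
    print(length(fit$index))
}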
Fitting SVM
- cannot be done in closed form
- requires a numerical optimization algorithm (quadratic programming)
e1071 package
- open source implementation of SVM
- see documentation
?svm()
Some training data
# this function comes from the synthetic_distributions.R script
train <- two_class_guasssian_meatballs(n_pos=200, n_neg=200,
mu_pos=c(1,0), mu_neg=c(-1,0),
sigma_pos=diag(2), sigma_neg=diag(2),
seed=103)
train
## # A tibble: 400 × 3
## x1 x2 y
## <dbl> <dbl> <fctr>
## 1 2.15074931 -0.4792060 1
## 2 -0.08786829 1.6166332 1
## 3 1.76506913 -0.1514896 1
## 4 0.61865671 1.2622023 1
## 5 1.07803836 -0.5458768 1
## 6 0.85736662 0.3856852 1
## 7 1.75866283 0.3809878 1
## 8 2.21017318 -2.4745634 1
## 9 0.96406375 0.3872157 1
## 10 -0.60946315 1.7316120 1
## # ... with 390 more rows
SVM code
# fit SVM
svmfit <- svm(y ~ ., # R's formula notation
data=train, # data frame to use
cost=10, # set the tuning parameter
scale=FALSE,
type='C-classification',
shrinking=FALSE,
kernel='linear')
Main arguments
data=train says fit SVM using the data stored in the train data frame.
The svm() function uses R’s formula notation. Recall from linear regression that y ~ . means fit y on all the rest of the columns of data. We could equivalently have used y ~ x1 + x2.
cost = 10 fixes the tuning parameter \(C\) to 10. The tuning parameter \(C\) is also sometimes called a cost parameter.
shrinking=FALSE turns off libsvm’s shrinking heuristic, a trick for speeding up the optimization; I don’t want anything extra to happen, so I told it to stop.
Other arguments
scale = FALSE says please do not center and scale our data. svm() applies some preprocessing to the data by default. While preprocessing (e.g. centering and scaling) is often a good thing to do, I strongly disagree with making this the default behavior.
type='C-classification' tells svm() to do classification. It turns out SVM can be used to do other things than classification (see http://kernelsvm.tripod.com/).
kernel='linear' says do linear SVM. The svm() function can also do kernel SVM (discussed below).
Predictions
# this is equivalent to svmfit$fitted
train_predictions <- predict(svmfit, newdata = train)
train_predictions[1:5]
## 1 2 3 4 5
## 1 -1 1 1 1
## Levels: -1 1
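From here it is one line to the training error rate:

# training error rate: fraction of misclassified training points
mean(train_predictions != train$y)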
Some kind soul took the time to code
- a good implementation of SVM in C, then release it to the public
- a package in R (and Python and many other languages) so that we data scientists don’t have to learn C to use the aforementioned C implementation of SVM
Saves time and money!
Non-linear classifiers
- sometimes linear doesn’t cut it (e.g. the Boston Cream doughnut)
Linear classifier
Why not keep going?
Two issues come up when we add more and more non-linear variables
- overfitting
- computational cost
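To see the computational cost concretely, here is a sketch of the explicit approach: add the quadratic terms as new columns by hand, then fit a linear SVM in the expanded feature space (the new column names are made up for this example).

# explicitly add the degree 2 terms as new columns
train_quad <- transform(train,
                        x1_sq = x1^2,
                        x2_sq = x2^2,
                        x1_x2 = x1 * x2)

# a linear SVM in the expanded feature space gives a non-linear
# classifier in the original x1, x2 space
svm_quad <- svm(y ~ ., data = train_quad, kernel = 'linear',
                cost = 10, scale = FALSE, type = 'C-classification')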
Kernels
- computational trick that turns a linear classifier into a non-linear classifier
- see section 9.3.2 from ISLR
Key idea
- Many algorithms (such as SVM) rely only on the distance between each pair of data points.
- A kernel is a function \(K(a, b)\) that computes the similarity between two objects \(a\) and \(b\).
Consequences
- If we can compute the distance between points cheaply then we can fit SVM more quickly.
- If we have the ability to compute a “distance” between pairs of objects then we can use SVM.
Upshot of kernels
kernels = a computationally cheaper way to get a non-linear version of SVM.
Non-standard data
Anything where we can compute a “similarity function”
Polynomial kernel
- \(n\) data points \(x_1, \dots, x_n\)
- \(d\) variables (i.e. \(x_i \in \mathbb{R}^d\)).
- degree \(m\) polynomial kernel is defined as
\[K(a, b) = (a^T b + 1)^m\]
The kernel might have more parameters, e.g. \(K(a, b) = (\gamma a^T b + c)^m\).
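The definition is one line of R:

# degree m polynomial kernel between vectors a and b
poly_kernel <- function(a, b, m) (sum(a * b) + 1)^m

poly_kernel(c(1, 2), c(3, 4), m = 2) # (1*3 + 2*4 + 1)^2 = 144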
Kernel’s reduce computational complexity
- Explicitly all quadratic terms
- parwise distance = \(O(d^2)\)
- Using a degree 2 polynomial kernel
- parwise distance = \(O(d)\)
Kernel SVM in R
# svm() is from the e1071 package
svm_kern2 <- svm(y ~ .,
data=train,
cost=10,
kernel='polynomial', # use a polynomial kernel
degree = 2, # degree two polynomial
gamma=1, # other kernel parameters
coef0 =1, # other kernel parameters
scale=FALSE,
type='C-classification',
shrinking=FALSE)
# test_grid is a grid of test points covering the plane (defined elsewhere)
kern2_predictions <- predict(svm_kern2, newdata = test_grid)
Kernel SVM in R
Degree 100 polynomial kernel
Common kernels
- Polynomial kernel
- Radial basis (or Gaussian) kernel, sketched in code below \[K(a, b) = e^{-\frac{1}{\sigma}||a - b||^2}\]
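The radial basis kernel is also one line of R; note that svm() exposes it via kernel='radial', where the gamma argument plays the role of \(1/\sigma\).

# radial basis (Gaussian) kernel between vectors a and b
rbf_kernel <- function(a, b, sigma) exp(-sum((a - b)^2) / sigma)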
Cross-validation grid search
SVM with a polynomial kernel has two parameters: \(C\) (for SVM) and \(d\) (the degree of the polynomial).
- Select a sequence of \(C\) values (e.g. \(C = 1e-5, 1e-4, \dots, 1e5\)).
- Select a sequence of degrees (e.g. \(d = 1, 2, 5, 10, 20, 50, 100\)).
- For each pair of \(C, d\) values (think of a grid), use cross-validation to estimate the test error (plain cross-validation is two for loops; searching the grid adds a third).
- Select the pair of \(C, d\) values that gives the best cross-validation error.
Triple for loop!
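Spelled out, the grid search that caret automates below looks roughly like this sketch (the loop bodies are left as comments):

C_values <- c(.01, .1, 1, 10, 100)
degree_values <- c(1, 5, 10, 20)

for (C in C_values) {
    for (d in degree_values) {
        for (fold in 1:5) {
            # fit SVM with cost C and degree d on the other 4 folds,
            # then record the error on the held-out fold
        }
        # average the 5 fold errors to estimate the test error of (C, d)
    }
}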
Tuning procedure
# specify tuning procedure
trControl <- trainControl(method = "cv", # perform cross validation
number = 5) # use 5 folds
Tuning grid
# the values of tuning parameters to look over in cross validation
# C: cost parameters
# degree: polynomial degree
# scale: another polynomial kernel parameter -- we don't care about it today
tune_grid <- expand.grid(C=c(.01, .1, 1, 10, 100),
degree=c(1, 5, 10, 20),
scale=1)
Tune and train model
# fit the SVM model
tuned_svm <- train(x=train_x,
y=train_y,
method = "svmPoly", # SVM with a polynomial kernel (via kernlab)
tuneGrid = tune_grid, # tuning parameters to look at
trControl = trControl, # tuning procedure defined above
metric='Accuracy') # what classification metric to use
## Warning: Setting row names on a tibble is deprecated.
## (this warning repeats many times; repetitions omitted)
End result
## Support Vector Machines with Polynomial Kernel
##
## 401 samples
## 2 predictor
## 2 classes: '-1', '1'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 320, 321, 321, 321, 321
## Resampling results across tuning parameters:
##
## C degree Accuracy Kappa
## 1e-02 1 0.9252160 0.8504163
## 1e-02 5 0.9476235 0.8952559
## 1e-02 10 0.9401235 0.8802318
## 1e-02 20 0.9376235 0.8752318
## 1e-01 1 0.9301543 0.8602973
## 1e-01 5 0.9326543 0.8652823
## 1e-01 10 0.9376852 0.8753659
## 1e-01 20 0.9251852 0.8503659
## 1e+00 1 0.9301852 0.8603659
## 1e+00 5 0.9401235 0.8802318
## 1e+00 10 0.9376852 0.8753478
## 1e+00 20 0.9177160 0.8354374
## 1e+01 1 0.9276852 0.8553659
## 1e+01 5 0.9525926 0.9051784
## 1e+01 10 0.9202160 0.8403952
## 1e+01 20 0.9152160 0.8304374
## 1e+02 1 0.9276852 0.8553659
## 1e+02 5 0.9575926 0.9151784
## 1e+02 10 0.9078086 0.8156098
## 1e+02 20 0.9127469 0.8254878
##
## Tuning parameter 'scale' was held constant at a value of 1
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were degree = 5, scale = 1 and C = 100.
Best parameters
## degree scale C
## 18 5 1 100
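caret stores the winning combination in the fitted train object; the table above comes from:

# best tuning parameters found by cross-validation
tuned_svm$bestTune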
Main arguments to train
method = "svmPoly" says use SVM with a polynomial kernel. caret then uses the ksvm() function from the kernlab package.
tuneGrid = tune_grid tells train() what tuning parameters to search over (defined above).
trControl = trControl sets the tuning procedure (defined above).
metric='Accuracy' tells caret to use the cross-validation accuracy to pick the optimal tuning parameters (this is equivalent to using the error rate).
Predict function
test_grid_pred <- predict(tuned_svm, newdata = test_grid)