More Classification

Support Vector Machine, Kernels, more classification metrics

STOR 390

This lecture

packages

library(e1071) # SVM
library(caret) # tuning
## Loading required package: lattice
## Loading required package: ggplot2
library(kernlab) # Kernel SVM
## 
## Attaching package: 'kernlab'
## The following object is masked from 'package:ggplot2':
## 
##     alpha

Binary classes

Linear classifiers

linearly separable data

Which separating hyperplane?

The Margin

Call the distance from the separating hyperplane to the nearest data point the margin.
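To make the definition concrete: the distance from a point \(x_0\) to the hyperplane \(w^T x + b = 0\) is \(|w^T x_0 + b| / \|w\|\), and the margin is the smallest such distance over the training points. A minimal base-R sketch (the hyperplane here is hypothetical, chosen for illustration):

```r
# distance from a point x0 to the hyperplane w'x + b = 0
point_to_hyperplane <- function(x0, w, b) {
  abs(sum(w * x0) + b) / sqrt(sum(w^2))
}

# hypothetical hyperplane x1 + x2 - 1 = 0
w <- c(1, 1)
b <- -1

point_to_hyperplane(c(1, 1), w, b)  # 1/sqrt(2), about 0.707
# the margin of a classifier is the minimum of these distances
# over all training points
```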

The Margin

Maximal Margin classifier (intuition)

Maximal Margin classifier (words)

Non-linearly separable data

Soft margin SVM (intuition)

A good linear classifier should aim to put points on the correct side of the separating hyperplane, far from the boundary, while tolerating a few points that end up on the wrong side.

Soft margin SVM

Two competing objectives

Balance competing objectives

[image from http://img04.deviantart.net/a1bc/i/2012/189/4/8/angel_vs_devil_by_sakura_wind-d568jp4.png]

Tuning parameter!

C large

C moderate

C small

Fitting SVM

e1071 package

Some training data

# this function comes from the synthetic_distributions.R script
train <- two_class_guasssian_meatballs(n_pos=200, n_neg=200,
                                       mu_pos=c(1,0), mu_neg=c(-1,0),
                                       sigma_pos=diag(2), sigma_neg=diag(2),
                                       seed=103)
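The synthetic_distributions.R helper isn't shown in these slides; if you don't have it, a rough base-R stand-in (assuming identity covariances, as in the call above, and independent coordinates) looks like:

```r
# a hypothetical base-R stand-in for two_class_guasssian_meatballs()
# (the real helper lives in synthetic_distributions.R)
make_gaussian_classes <- function(n_pos, n_neg, mu_pos, mu_neg, seed = 103) {
  set.seed(seed)
  pos <- cbind(rnorm(n_pos, mu_pos[1]), rnorm(n_pos, mu_pos[2]))
  neg <- cbind(rnorm(n_neg, mu_neg[1]), rnorm(n_neg, mu_neg[2]))
  data.frame(x1 = c(pos[, 1], neg[, 1]),
             x2 = c(pos[, 2], neg[, 2]),
             y  = factor(rep(c(1, -1), times = c(n_pos, n_neg))))
}

train_alt <- make_gaussian_classes(200, 200, c(1, 0), c(-1, 0))
dim(train_alt)  # 400 rows, 3 columns
```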

train
## # A tibble: 400 × 3
##             x1         x2      y
##          <dbl>      <dbl> <fctr>
## 1   2.15074931 -0.4792060      1
## 2  -0.08786829  1.6166332      1
## 3   1.76506913 -0.1514896      1
## 4   0.61865671  1.2622023      1
## 5   1.07803836 -0.5458768      1
## 6   0.85736662  0.3856852      1
## 7   1.75866283  0.3809878      1
## 8   2.21017318 -2.4745634      1
## 9   0.96406375  0.3872157      1
## 10 -0.60946315  1.7316120      1
## # ... with 390 more rows

SVM code

# fit SVM
svmfit <- svm(y ~ ., # R's formula notation
              data=train, # data frame to use
              cost=10, # set the tuning parameter
              scale=FALSE,
              type='C-classification',
              shrinking=FALSE,
              kernel='linear') 

main arguments

Other arguments

Predictions

# this is equivalent to svmfit$fitted
train_predictions <- predict(svmfit, newdata = train)
train_predictions[1:5] 
##  1  2  3  4  5 
##  1 -1  1  1  1 
## Levels: -1 1
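With predictions in hand, training accuracy and a confusion table are one-liners. A toy illustration with hand-made label vectors (standing in for train_predictions and train$y):

```r
# toy predicted and true labels, standing in for train_predictions and train$y
truth <- factor(c(-1, -1, -1, 1, 1), levels = c(-1, 1))
pred  <- factor(c(-1,  1, -1, 1, 1), levels = c(-1, 1))

mean(pred == truth)   # accuracy: 0.8
table(pred, truth)    # 2x2 confusion table
```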

Open source software

Some kind soul took the time to code

  1. a good implementation of SVM in C and then release it to the public
  2. a package in R (and Python and many other languages) so that we data scientists don’t have to learn C to use the aforementioned C implementation of SVM

Saves time and money!

Trade offs

Non-linear classifiers

Explicit variable transformation

Linear classifier

Linear classifier

Manually add some non-linear transformations

# add polynomial terms to the training data
train_poly <- train %>% 
                mutate(x1_sq = x1^2, x1x2 = x1*x2, x2_sq = x2^2)


test_grid_poly <- test_grid %>% 
                    mutate(x1_sq = x1^2, x1x2 = x1*x2, x2_sq = x2^2)


# fit SVM
svm_poly <- svm(y ~ ., 
                  data=train_poly,
                  scale=FALSE,
                  type='C-classification',
                  shrinking=FALSE,
                  kernel='linear', 
                  cost=10)

grid_poly_predictions <- predict(svm_poly, newdata = test_grid_poly)

Manually add some non-linear transformations

Why not keep going?

Two issues come up when we add more and more non-linear variables

Kernels

Key idea

Consequences

  1. If we can compute the distance between points cheaply then we can fit SVM more quickly.
  2. If we have the ability to compute a “distance” between pairs of objects then we can use SVM.

Upshot of kernels

kernels = an easier way to compute non-linear versions of SVM.

Non-standard data

Anything where we can compute a “similarity function”

Polynomial kernel

\[K(a, b) = (a^T b + 1)^m\]

Might have more parameters i.e. \(K(a, b) = (\gamma a^T b + c)^m\).
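Evaluating this kernel directly is just arithmetic; a sketch with the extra \(\gamma\) and \(c\) parameters included:

```r
# polynomial kernel K(a, b) = (gamma * a'b + coef0)^degree
poly_kernel <- function(a, b, degree = 2, gamma = 1, coef0 = 1) {
  (gamma * sum(a * b) + coef0)^degree
}

poly_kernel(c(1, 2), c(3, -1))  # (3 - 2 + 1)^2 = 4
```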

Surprising math fact
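One way to see the fact: for degree m = 2 in two dimensions, \((a^T b + 1)^2\) equals an ordinary inner product after mapping each point to six explicit polynomial features. A base-R check (the feature map phi is written out by hand here for illustration):

```r
# explicit degree-2 feature map in 2d: six coordinates
phi <- function(x) {
  c(1, sqrt(2) * x[1], sqrt(2) * x[2],
    x[1]^2, sqrt(2) * x[1] * x[2], x[2]^2)
}

a <- c(1, 2)
b <- c(3, -1)

kernel_val  <- (sum(a * b) + 1)^2   # kernel: one inner product in 2d
feature_val <- sum(phi(a) * phi(b)) # same number via 6 explicit features
all.equal(kernel_val, feature_val)  # TRUE
```

The kernel side costs one length-2 inner product no matter the degree, while the number of explicit features grows combinatorially with the degree, which is the computational point.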

Kernels reduce computational complexity

Kernel SVM in R

# svm() is from the e1071 package
svm_kern2 <- svm(y ~ ., 
                  data=train,
                  cost=10,
                  kernel='polynomial', # use a polynomial kernel
                  degree = 2, # degree two polynomial
                  gamma=1, # other kernel parameters
                  coef0 =1, # other kernel parameters
                  scale=FALSE,
                  type='C-classification',
                  shrinking=FALSE)

kern2_predictions <- predict(svm_kern2, newdata = test_grid)

Kernel SVM in R

Degree 100 polynomial kernel

Common kernels

caret package

Format data

# break the data frame up into separate x and y data
train_x <- train %>% select(-y)
train_y <- train$y

Tuning procedure

# specify tuning procedure
trControl <- trainControl(method = "cv", # perform cross validation
                          number = 5) # use 5 folds
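Under the hood, 5-fold CV just partitions the row indices into five groups and holds each one out in turn; caret automates this, but a hand-rolled sketch makes the idea concrete:

```r
# hand-rolled 5-fold split of 400 row indices (caret does this internally)
set.seed(1)
n <- 400
folds <- split(sample(n), rep(1:5, length.out = n))

sapply(folds, length)  # five folds of 80 rows each
# for each fold k: train on the other four folds, measure accuracy on
# fold k, then average the five accuracies
```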

Tuning grid

# the values of tuning parameters to look over in cross validation
    # C: cost parameters
    # degree: polynomial degree
    # scale: another polynomial kernel parameter -- we don't care about today
tune_grid <- expand.grid(C=c(.01, .1, 1, 10, 100),
                         degree=c(1, 5, 10, 20),
                         scale=1)
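expand.grid takes the Cartesian product of its arguments, so the grid above contains 5 × 4 × 1 = 20 candidate parameter settings, each of which gets its own cross-validation run:

```r
# the full grid: every combination of cost and degree (scale fixed at 1)
tune_grid <- expand.grid(C = c(.01, .1, 1, 10, 100),
                         degree = c(1, 5, 10, 20),
                         scale = 1)

nrow(tune_grid)  # 20 parameter combinations
```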

Tune and train model

# fit the SVM model
tuned_svm <- train(x=train_x,
                   y=train_y,
                   method = "svmPoly", # polynomial kernel SVM from the kernlab package
                   tuneGrid = tune_grid, # tuning parameters to look at
                   trControl = trControl, # tuning procedure defined above
                   metric='Accuracy') # what classification metric to use
## Warning: Setting row names on a tibble is deprecated.


End result

tuned_svm
## Support Vector Machines with Polynomial Kernel 
## 
## 401 samples
##   2 predictor
##   2 classes: '-1', '1' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 320, 321, 321, 321, 321 
## Resampling results across tuning parameters:
## 
##   C      degree  Accuracy   Kappa    
##   1e-02   1      0.9252160  0.8504163
##   1e-02   5      0.9476235  0.8952559
##   1e-02  10      0.9401235  0.8802318
##   1e-02  20      0.9376235  0.8752318
##   1e-01   1      0.9301543  0.8602973
##   1e-01   5      0.9326543  0.8652823
##   1e-01  10      0.9376852  0.8753659
##   1e-01  20      0.9251852  0.8503659
##   1e+00   1      0.9301852  0.8603659
##   1e+00   5      0.9401235  0.8802318
##   1e+00  10      0.9376852  0.8753478
##   1e+00  20      0.9177160  0.8354374
##   1e+01   1      0.9276852  0.8553659
##   1e+01   5      0.9525926  0.9051784
##   1e+01  10      0.9202160  0.8403952
##   1e+01  20      0.9152160  0.8304374
##   1e+02   1      0.9276852  0.8553659
##   1e+02   5      0.9575926  0.9151784
##   1e+02  10      0.9078086  0.8156098
##   1e+02  20      0.9127469  0.8254878
## 
## Tuning parameter 'scale' was held constant at a value of 1
## Accuracy was used to select the optimal model using  the largest value.
## The final values used for the model were degree = 5, scale = 1 and C = 100.

Best parameters

tuned_svm$bestTune
##    degree scale   C
## 18      5     1 100

Main arguments to train

Predict function

test_grid_pred <- predict(tuned_svm, newdata = test_grid)