56-caret包學習:模型訓練與調優

1、模型訓練與參數優化

在進行建模時,需對模型的參數進行優化,在caret包中其主要函數是train。
一旦定義了模型和調優參數值,就應該指定重采樣的類型。目前,k折交叉驗證重采樣(一次或重復)、留一交叉驗證重采樣和 bootstrap (簡單估計或632規則)重采樣方法可以被train函數使用。 重采樣后,過程中生成性能度量的概要,以指導用戶選擇哪些調優參數值。 默認情況下,函數會自動選擇與最佳值相關聯的調優參數。

重采樣方法:
交叉驗證重采樣將樣本分割,一部分作為訓練樣本,一部分作為測試樣本,通過計算在測試樣本上的誤差率來估計測試誤差,常見的交叉驗證技術有留一交叉驗證和K折交叉驗證法;拔靴法(bootstrap)是利用有限的樣本資料經由多次重復抽樣,重新建立起足以代表母體樣本分布的新樣本,其主要特點是能夠被廣泛的應用到各種統計學習方法中,特別是在對難以估計或者統計軟件不能直接給出結果的變量的估計中。

2、自定義調優過程

train()函數可以在模型擬合之前以各種方式對數據進行預處理, 為了指定需要進行哪些預處理,train函數有個preProcess的參數可供調整。

交替調優網格:
tuneGrid()函數可以為每個調整參數生成一個數據框,數據框的列名為擬合模型的參數,比如RDA模型,列名將為gamma和lambda。

> library(pacman)
> p_load(caret,mlbench,dplyr)
> # 使用Sonar數據集
> data("Sonar")
> 
> # 拆分為訓練集和測試集
> # 默認情況下,該函數使用分層隨機分割
> set.seed(123)
> ind <- createDataPartition(y = Sonar$Class,
+                            # 訓練集所占比例
+                            p = 0.75,
+                            list = F)
> train <- Sonar[ind,]
> test <- Sonar[-ind,]
> ctrl <- trainControl(method = "repeatedcv",
+                      # number交叉驗證折數或重采樣迭代次數
+                      number = 10,
+                      # repeats確定了反復次數
+                      repeats = 10,
+                      # 是否顯示訓練過程
+                      verboseIter = T,
+                      # 是否將數據保存到trainingData
+                      returnData = F,
+                      # 訓練百分比
+                      p = 0.75,
+                      classProbs = T,
+                      summaryFunction = twoClassSummary,
+                      allowParallel = T)

method確定多次交叉檢驗的抽樣方法;
method可選:"boot", "cv", "LOOCV", "LGOCV", "repeatedcv", "timeslice", "none" 和 "oob";
"oob"袋外估計值,只能用于randomForest、袋外決策樹、袋外earth、袋外柔性判別分析或者條件樹森林模型,不適合GBM模型;
對時間序列method = "timeslice", 有三個參數initialWindow, horizon 和 fixedWindow;
有的模型預測結果為計算概率,例如“Prob”、“After”、“Response”、“Probability”或“RAW”,classProbs =TRUE則讓其返回類別"class";
summaryFunction 指定模型性能統計的函數;
selectionFunction 選擇最優參數和抽樣的函數;
returnResamp 指定要保存多少性能指標,可為all,final,none;
allowParallel 是否使用并行計算。

> gbm.grid <- expand.grid(interaction.depth = c(1,5,9),
+                         n.trees = c(50, 100, 150, 200, 250, 300),
+                         shrinkage = 0.1,
+                         n.minobsinnode = 20)
> head(gbm.grid)

梯度提升機模型(GBM)有三個主要的參數:

  1. n.trees:樹的迭代次數
  2. interaction.depth:樹的復雜度
  3. n.minobsinnode:收斂一個節點中開始分割的最小訓練集樣本數
##   interaction.depth n.trees shrinkage n.minobsinnode
## 1                 1      50       0.1             20
## 2                 5      50       0.1             20
## 3                 9      50       0.1             20
## 4                 1     100       0.1             20
## 5                 5     100       0.1             20
## 6                 9     100       0.1             20
> nrow(gbm.grid)
## [1] 18
> set.seed(123)
> # method="svmRadial支持向量機
> # “rda正則判別分析模型
> # treebag裝袋樹
> fit.gbm <- train(Class ~ .,data = train,
+                  # 梯度提升樹模型
+                  method = "gbm",
+                  trControl = ctrl,
+                  verbose = F,
+                  tuneGrid = gbm.grid,
+                  metric = "ROC")
>
> fit.gbm
## Stochastic Gradient Boosting 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 10 times) 
## Summary of sample sizes: 140, 142, 142, 141, 141, 141, ... 
## Resampling results across tuning parameters:
## 
##   interaction.depth  n.trees  ROC        Sens       Spec     
##   1                   50      0.8768105  0.8397222  0.7094643
##   1                  100      0.8934425  0.8737500  0.7714286
##   1                  150      0.8991369  0.8656944  0.7801786
##   1                  200      0.8993031  0.8619444  0.7773214
##   1                  250      0.8987153  0.8650000  0.7789286
##   1                  300      0.9034747  0.8700000  0.7803571
##   5                   50      0.8890377  0.8586111  0.7505357
##   5                  100      0.8994891  0.8694444  0.7746429
##   5                  150      0.9019147  0.8626389  0.7908929
##   5                  200      0.9028720  0.8637500  0.7876786
##   5                  250      0.9028844  0.8700000  0.7876786
##   5                  300      0.9028299  0.8695833  0.7880357
##   9                   50      0.8984772  0.8515278  0.7678571
##   9                  100      0.9083358  0.8659722  0.7864286
##   9                  150      0.9162029  0.8776389  0.7983929
##   9                  200      0.9153100  0.8905556  0.7971429
##   9                  250      0.9168998  0.8902778  0.7991071
##   9                  300      0.9188070  0.8843056  0.7957143
## 
## Tuning parameter 'shrinkage' was held constant at a value of
##  0.1
## Tuning parameter 'n.minobsinnode' was held constant at a value
##  of 20
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were n.trees =
##  300, interaction.depth = 9, shrinkage = 0.1 and n.minobsinnode = 20.

最終選擇的參數為n.trees = 300, interaction.depth = 9, shrinkage = 0.1, n.minobsinnode = 20。
plot函數可用于檢查模型的性能估計值與調整參數之間的關系:

> trellis.par.set(caretTheme())
> plot(fit.gbm)
ROC與參數之間的關系

可以看到參數n.trees = 300, interaction.depth = 9時ROC達到最高點。
plot中使用metric參數可以查看其他性能指標,比如要查看“Kappa”指標,可以指定:metric = "Kappa"。本模型沒有記錄Kappa性能,所以無法查看。

查看擬合結果熱力圖,結果與上圖一樣,n.trees = 300, interaction.depth = 9時顏色最深:

> trellis.par.set(caretTheme())
> plot(fit.gbm, metric = "ROC", plotType = "level", scales = list(x = list(rot = 90)))
ROC熱力圖

也可以使用ggplot2包:

> ggplot(fit.gbm)
ggplot2畫圖

xyplot和stripplot可用于繪制針對(數值型)調整參數的重采樣統計信息。
histogram和densityplot還可用于查看調整參數在調整參數之間的分布。

> trellis.par.set(caretTheme())
> densityplot(fit.gbm, pch = "|", resamples = "all")
ROC密度圖

3、模型選擇

tolerance()函數可用于找到不太復雜的模型,例如,要基于2%的性能損失選擇參數值:

> which.pct <- tolerance(fit.gbm, metric = "ROC", tol = 2, maximize = T)
>
> fit.gbm$results[which.pct, 1:6]
##      shrinkage interaction.depth   n.minobsinnode   n.trees       ROC   Sens
##  8       0.1                5              20         100     0.9026612 0.8675

最后選擇了一個不那么復雜的模型,在ROC曲線下的面積為0.9026612。

如果擬合了多個模型,可以通過resamples()函數對他們的性能差異做出統計報表:

> set.seed(123)
> fit.svm <- train(Class ~ .,data=train,
>                  method = "svmRadial", 
>                  trControl = ctrl, 
>                  preProc = c("center", "scale"),
>                  tuneLength = 8,
>                  metric = "ROC")
> 
> set.seed(123)
> fit.rda <- train(Class ~ ., data=train, 
>                  method = "rda", 
>                  trControl = ctrl, 
>                  tuneLength = 4,
>                  metric = "ROC")
>
> resamps <- resamples(list(GBM = fit.gbm,
>                           SVM = fit.svm,
>                           RDA = fit.rda))
>
> resamps
## Call:
## resamples.default(x = list(GBM = fit.gbm, SVM = fit.svm, RDA
##  = fit.rda))
## 
## Models: GBM, SVM, RDA 
## Number of resamples: 100 
## Performance metrics: ROC, Sens, Spec 
## Time estimates for: everything, final model fit 
> summary(resamps)
## Call:
## summary.resamples(object = resamps)
## 
## Models: GBM, SVM, RDA 
## Number of resamples: 100 
## 
## ROC 
##          Min.   1st Qu.    Median      Mean   3rd Qu. Max. NA's
## GBM 0.7142857 0.8700397 0.9186508 0.9107093 0.9683780    1    0
## SVM 0.7031250 0.8888889 0.9285714 0.9224578 0.9821429    1    0
## RDA 0.5781250 0.8571429 0.9206349 0.8978720 0.9598214    1    0
## 
## Sens 
##          Min.   1st Qu.    Median      Mean   3rd Qu. Max. NA's
## GBM 0.5000000 0.7777778 0.8888889 0.8798611 1.0000000    1    0
## SVM 0.5555556 0.8750000 0.8750000 0.8825000 0.9166667    1    0
## RDA 0.6250000 0.7777778 0.8750000 0.8747222 1.0000000    1    0
## 
## Spec 
##          Min.   1st Qu.    Median      Mean   3rd Qu. Max. NA's
## GBM 0.4285714 0.7142857 0.8571429 0.7819643 0.8750000    1    0
## SVM 0.3750000 0.7142857 0.7500000 0.7835714 0.8750000    1    0
## RDA 0.2857143 0.5714286 0.7142857 0.7164286 0.8571429    1    0

畫個圖看看:

trellis.par.set(caretTheme())
dotplot(resamps, metric = "ROC")
ROC對比圖

通過ROC性能對比,SVM模型高于GBM模型高于RDA模型。
由于設置了同樣的隨機數種子,使用同樣的數據訓練模型,因此對模型的差異進行推斷是有意義的。通過這種方法,降低了存在內部重采樣相關性的可能,然后使用簡單的t檢驗來評估模型之間沒有差異的零假設:

> dif.value <- diff(resamps)
> summary(dif.value)
## Call:
## summary.diff.resamples(object = dif.value)
## 
## p-value adjustment: bonferroni 
## Upper diagonal: estimates of the difference
## Lower diagonal: p-value for H0: difference = 0
## 
## ROC 
##     GBM      SVM      RDA     
## GBM          -0.01175  0.01284
## SVM 0.614698           0.02459
## RDA 0.844504 0.005992         
## 
## Sens 
##     GBM SVM       RDA      
## GBM     -0.002639  0.005139
## SVM 1              0.007778
## RDA 1   1                  
## 
## Spec 
##     GBM      SVM       RDA      
## GBM          -0.001607  0.065536
## SVM 1.000000            0.067143
## RDA 0.029286 0.004865 
trellis.par.set(caretTheme())
dotplot(dif.value)
模型之間差異性檢驗

4、擬合最終模型

當最優模型和最優參數已經找到時,可以使用最優參數直接擬合模型:

> fit.ctrl <- trainControl(method = "none", classProbs = TRUE)
> 
> set.seed(123)
> fit.final <- train(Class ~ .,data=train,
>                    method = "svmRadial", 
>                    trControl = fit.ctrl, 
>                    preProc = c("center", "scale"),
>                    tuneGrid = fit.svm$bestTune,
>                    metric = "ROC")
> 
> fit.final
## Support Vector Machines with Radial Basis Function Kernel 
## 
## 157 samples
##  60 predictor
##   2 classes: 'M', 'R' 
## 
## Pre-processing: centered (60), scaled (60) 
## Resampling: None 
> pred.final <- predict(fit.final,newdata=test)
> confusionMatrix(pred.final,test$Class)
## Stochastic Gradient Boosting 
## 
## 157 samples
##  60 predictor
##   2 classes: 'M', 'R' 
## 
## No pre-processing
## Resampling: None
> pred.final <- predict(fit.gbm.final, newdata = test)
> confusionMatrix(pred.final, test$Class)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  M  R
##          M 23  2
##          R  4 22
##                                           
##                Accuracy : 0.8824          
##                  95% CI : (0.7613, 0.9556)
##     No Information Rate : 0.5294          
##     P-Value [Acc > NIR] : 8.488e-08       
##                                           
##                   Kappa : 0.765           
##                                           
##  Mcnemar's Test P-Value : 0.6831          
##                                           
##             Sensitivity : 0.8519          
##             Specificity : 0.9167          
##          Pos Pred Value : 0.9200          
##          Neg Pred Value : 0.8462          
##              Prevalence : 0.5294          
##          Detection Rate : 0.4510          
##    Detection Prevalence : 0.4902          
##       Balanced Accuracy : 0.8843          
##                                           
##        'Positive' Class : M
最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內容