R與機器學習系列|15.可解釋的機器學習算法(Interpretable Machine Learning)(下)

今天我們介紹可解釋機器學習算法的最后一部分,基于XGBoost算法的SHAP值可視化。關于SHAP值其實我們之前的很多個推文中都介紹到,不論是R版本的還是Python版本的,亦不論是普通的分類問題還是生存數據模型的。在此推文中我們將基于XGBoost模型理解SHAP值的計算過程。此外,我們之前的SHAP可視化是基于別人封裝好的函數。在今天的推文中,我們將學習如何使用ggplot2實現更加美觀的SHAP值可視化。

生存數據機器學習算法模型的SHAP值可視化

R與機器學習系列|shapviz——機器學習“黑箱模型”SHAP值可視化

機器學習|SHAP value的另一種R可視化方式以及Python實現SHAP value可視化

機器學習|分享一篇25分臨床預測模型文章,再次體現SHAP 值在機器學習中的重要性!

R學習|R復現機器學習算法XGBoost特征重要性解釋——SHAP value

SHAP值在機器學習算法中的重要性主要體現在以下幾個方面:

解釋模型預測結果:SHAP值能夠解釋單個樣本預測結果的貢獻。它告訴我們每個特征對于某個特定預測結果的影響程度,從而幫助我們理解模型是如何基于輸入特征做出預測的。

特征重要性評估:SHAP值可以用來評估特征的重要性。通過分析多個樣本的SHAP值,我們可以得出哪些特征對于整體模型的性能影響最大,從而在特征選擇、降維等任務中提供指導。

模型調試與驗證:通過檢查每個樣本的SHAP值,可以幫助我們識別模型在某些特定預測上可能出現的問題。如果某個樣本的預測與真實值相差較大,SHAP值可以揭示哪些特征導致了這種預測差異。

透明性和可信度:SHAP值的計算基于合理的博弈論原理,它們為模型的預測結果提供了一種可解釋的解釋。這可以增加模型的可信度,特別是在需要對模型決策做出解釋的場景中。

特征交互分析:SHAP值不僅僅告訴我們單個特征的影響,還可以揭示不同特征之間的交互作用對預測結果的影響。這對于理解特征之間的復雜關系以及模型如何從這些關系中學習非常有幫助。

我們也可以看到SHAP值對模型的解釋在高分的機器學習文獻中出現的還是很頻繁,如下面的兩篇分別發表在EClinicalMedicine和 JAMA surgery上的文章。


Tsai, Shang-Feng et al. “Development and validation of an insulin resistance model for a population without diabetes mellitus and its clinical implication: a prospective cohort study.” EClinicalMedicine vol. 58 101934. 4 Apr. 2023, doi:10.1016/j.eclinm.2023.101934
Bertsimas, Dimitris et al. “Using Artificial Intelligence to Find the Optimal Margin Width in Hepatectomy for Colorectal Cancer Liver Metastases.” JAMA surgery vol. 157,8 (2022): e221819. doi:10.1001/jamasurg.2022.1819

1.1介紹

真實的 Shapley 值在理論上被認為是最優的;然而,真實SHAP值的計算會花費大量的時間。因此,iml包提供了近似 的Shapley 值計算方法。此外,Lundberg 和 Lee也開發了其他SHAP值的近似計算方法,雖然不是純粹的模型無關方法,但也適用于基于樹的模型,并且在大多數 XGBoost 算法實現中(包括 xgboost 包)完全可行。與 iml 的近似方法類似,這種基于樹的 Shapley 值估計方法也是一種近似估計的方法,但其運行的時間遠遠要比iml包的計算時間短。為了演示,我們將使用第 12.5.2 節中使用的特征和最終創建的 XGBoost 模型。

1.2 SHAP計算

為了說明我們上面提到的問題,我們利用之前的數據再xgboost中擬合一個模型。xgboost算法的執行、參數調整及特征重要性解釋在之前的章節中也有介紹。這里不過多介紹。首先,我們加載相關依賴包。

# Helper packages
library(tidyverse)    # for general data wrangling needs
# Modeling packages
library(gbm)      # for original implementation of regular and stochastic GBMs
library(h2o)      # for a java-based implementation of GBM variants
library(xgboost)  # for fitting extreme gradient boosting
library(rsample)# for data split
library(caret)# dummy funtion for categorical variables

然后我們加載需要用到的數據。需要注意的是,我們將一個變量處理為多分類變量,已說明獨熱編碼在xgboost模型數據預處理中的應用。此外,如果這里直接將多分類變量處理為數值型變量,那么最后的SHAP圖里面也不會看到該變量其他啞變量的信息。
此外,因為xgboost的輸入特征文件格式為矩陣,如果這個時候不對多分類變量進行虛擬編碼,那么直接轉換為矩陣后數據維度便會出錯。

data<-read.csv("diabetes.csv",header = T)
data%>%
  mutate(Pregnancies=case_when(
    Pregnancies<3~"A",
    Pregnancies>=3 &Pregnancies<=6~"B",
    Pregnancies>6~"C"
  ))->data
data$Pregnancies<-as.factor(data$Pregnancies)
# Stratified sampling with the rsample package
set.seed(123)
split <- initial_split(data, prop = 0.7, 
                       strata = "Outcome")
data_train  <- training(split)
data_test   <- testing(split)

data_train2=select(data_train, -Outcome)

獨熱編碼

dmytr = dummyVars(" ~ .", data =data_train2, fullRank=T)
data_train3 = predict(dmytr, newdata =data_train2)

X <-data_train3
Y<- data_train[,ncol(data_train)]

此時的X為經過獨熱編碼之后的特征矩陣。下面我們利用之前的超參數直接建立xgboost模型

# optimal parameter list
params <- list(
  eta = 0.01,
  max_depth = 3,
  min_child_weight = 3,
  subsample = 0.5,
  colsample_bytree = 0.5
)

# train final model
xgb.fit.final <- xgboost(
  params = params,
  data = X,
  label = Y,
  nrounds = 602,
  objective = "binary:logistic",
  verbose = 0
)

然后我們將特征重新由低到高進行標準化

feature_values <- X %>%
  as.data.frame() %>%
  mutate_all(scale) %>%
  gather(feature, feature_value) %>% 
  pull(feature_value)

然后我們計算特征的SHAP值以及SHAP重要性等參數

shap_df <- xgb.fit.final %>%
  predict(newdata = X, predcontrib = TRUE) %>%
  as.data.frame() %>%
  select(-BIAS) %>%
  gather(feature, shap_value) %>%
  mutate(feature_value = feature_values) %>%
  group_by(feature) %>%
  mutate(shap_importance = mean(abs(shap_value)))

1.3 SHAP可視化

現在,我們已經計算得到了這些特征的SHAP值,下面我們進行可視化。首先我們使用ggplot2進行可視化,嚴格的來說是基于ggplot2的蜂群圖可視化。看過SHAP圖后可以看到其實就是一個散點圖,橫坐標是SHAP值,縱坐標是每個特征,每個點代表一個觀測值。此外,縱坐標按照SHAP值的重要性進行排序。

library(ggbeeswarm)
p1 <- ggplot(shap_df, aes(x = shap_value, y = reorder(feature, shap_importance))) +
  geom_quasirandom(groupOnX = FALSE, varwidth = TRUE, size =1, alpha = 0.8, aes(color = shap_value)) +
  scale_color_gradient(low = "#ffcd30", high = "#6600cd") +
  labs(x="SHAP value",y="")+
  theme_bw()+
  theme(axis.text = element_text(color = "black"),
        panel.border = element_rect(linewidth = 1))+
  geom_vline(xintercept = 0,linetype="dashed",color="grey",linewidth=1)

p1 
基于ggplot2的SHAP值可視化

從上圖中我們可以看出患者血糖對結局影響最大,其次是年齡、BMI。

下面我們再根據SHAP重要性值做一個SHAP重要性圖

p2 <- shap_df %>% 
  select(feature, shap_importance) %>%
  filter(row_number() == 1) %>%
  ggplot(aes(x = reorder(feature, shap_importance), y = shap_importance,fill=feature)) +
  geom_col(alpha=0.6) +
  coord_flip() +
  xlab(NULL) +
  ylab("mean(|SHAP value|)")+
  scale_fill_brewer(palette = "Set1")+
  theme_bw()+
  theme(legend.position = "",
        axis.text = element_text(color = "black"),
        panel.border = element_rect(linewidth = 1))
p2
SHAP重要性圖

我們也可以把兩個拼圖展示

library(patchwork)
plot<-p1+p2&
  plot_layout(widths = c(2,1))
plot
SHAP值可視化及SHAP重要性排序

下面我們用之前封裝好的SHAP.R函數看看效果

source("shap.R")
shap_result = shap.score.rank(xgb_model =xgb.fit.final, 
                              X_train =data_train3,
                              shap_approx = F)

#計算前10個特征的SHAP值
shap_long_hd = shap.prep(X_train =data_train3 , top_n =9)
#SHAP值可視化
shapR<-plot.shap.summary(data_long =shap_long_hd)
shapR

可以看到結果是一致的。
我們還可以利用這些信息來創建與PDPs(部分依賴圖)相對應的另一種方法。基于Shapley值的依賴圖將一個特征的Shapley值顯示在y軸上,將該特征的值顯示在x軸上。通過為數據集中的所有觀察值繪制這些值,我們可以看到隨著特征的值變化,其歸因重要性如何變化。

shap_df %>% 
  filter(feature %in% c("BMI", "Glucose")) %>%
  ggplot(aes(x = feature_value, y = shap_value)) +
  geom_point(aes(color = shap_value)) +
  scale_colour_viridis_c(name = "Feature value\n(standardized)", option = "C") +
  facet_wrap(~ feature, scales = "free") +
  scale_y_continuous('Shapley value', labels = scales::comma) +
  xlab('Normalized feature value')+
  theme_bw()

我們可以看到BMI和血糖與SHAP值明顯正相關,隨著這兩個特征值增大,SHAP值也逐漸增大,說明對結局的影響也增加。
終于,這個系列(有監督機器學習)更新到今天結束了。希望大家都有收獲,下個系列我們再見!


圖源于網絡

參考來源:Bradley Boehmke & Brandon Greenwell R與機器學習

?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內容