R語言:lasso建模和預測

導讀:

clustlasso函數包lasso函數進行建模和預測,包中clustlasso函數也可以進行相似的建模和預測。

clustlasso安裝:
http://www.lxweimin.com/p/2aed75aeca91

clustlasso lasso使用文檔:
https://gitlab.com/biomerieux-data-science/clustlasso/-/blob/master/vignettes/vignette.pdf

1 加載包和數據

# load package
library(clustlasso)
# specify / set random seed
seed = 42
set.seed(seed)
# load example dataset
input.file = system.file("data", "NG-dataset.Rdata", package = "clustlasso")
load(input.file)

2 隨機選擇20%的ID

# pick 20% for test
test.frac = 0.2
# stratify by origin / population structure
ind.by.struct = split(seq(nrow(meta)), meta$pop_structure)
# split按值分割成列表
ind.sample = sapply(ind.by.struct, function(x){sample(x, round(test.frac * length(x)))})  # 每個表種select 20%, sample對List中的每個df執行一次function。

3 制備test set和train set

ind.test = unlist(ind.sample)
# test dataset
X.test = X[ind.test, ]
y.test = y[ind.test]
meta.test = meta[ind.test, ]
# train datasets
X.train = X[-ind.test, ]
y.train = y[-ind.test]
meta.train = meta[-ind.test, ]

4 建模和交叉驗證

# 1. Cross-validation process
# specify cross-validation parameters
n.folds = 10
n.lambda = 100
n.repeat = 3
# run cross-validation process
cv.res.lasso = lasso_cv(X.train, y.train, subgroup = meta.train$pop_structure, n.lambda = n.lambda, n.folds = n.folds, n.repeat = n.repeat, seed = seed, verbose = FALSE)

pdf("cv.pdf", width=15)
par(mfcol = c(1, 3))  # 一頁多圖,一行三列
show_cv_overall(cv.res.lasso, modsel.criterion = "balanced.accuracy.best", best.eps = 1)
dev.off()

5 最佳模型

# 2. Selecting the best model
pdf("cv_best.pdf", width=15)
layout(matrix(c(1, 2, 3), nrow = 1, byrow = TRUE), width = c(0.3,
0.3, 0.4), height = c(1))
perf.best.lasso = show_cv_best(cv.res.lasso, modsel.criterion = "balanced.accuracy.best", best.eps = 1, method = "lasso")
dev.off()
# print cross-validation performance of best model
print(perf.best.lasso)
best.model.lasso = extract_best_model(cv.res.lasso, modsel.criterion = "balanced.accuracy.best", best.eps = 1)

6 模型預測和表現評估

# 3. Making predictions and measuring performance
# make predictions # preds.lasso$preds預測結果
preds.lasso = predict_clustlasso(X.test, best.model.lasso)
# compute performance
perf.lasso = compute_perf(preds.lasso$preds, preds.lasso$probs,
y.test)
# print
print(t(perf.lasso$perf))
pdf("predict.pdf", width=15)
par(mfcol = c(1, 2))
plot(perf.lasso$roc.curves[[1]], lwd = 2, main = "lasso - test set ROC curve")
grid()
plot(perf.lasso$pr.curves[[1]], lwd = 2, main = "lasso - test set precision / recall curve")
grid()
dev.off()

參考:
【機器學習】Cross-Validation(交叉驗證)詳解
Lasso regression(稀疏學習,R)
lasso_cv

最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內容