xgboost 庫使用入門

本文 github 地址:1-1 基本模型調(diào)用. ipynb,里面會記錄自己kaggle大賽中的內(nèi)容,歡迎start關(guān)注。

cmd 地址:xgboost 庫使用入門

# 開啟多行顯示
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
# InteractiveShell.ast_node_interactivity = "last_expr"

# 顯示圖片
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

數(shù)據(jù)探索

XGBoost中數(shù)據(jù)形式可以是libsvm的,libsvm作用是對稀疏特征進(jìn)行優(yōu)化,看個例子:

1 101:1.2 102:0.03 
0 1:2.1 10001:300 10002:400
0 2:1.2 1212:21 7777:2

每行表示一個樣本,每行開頭0,1表示標(biāo)簽,而后面的則是特征索引:數(shù)值,其他未表示都是0.

我們以判斷蘑菇是否有毒為例子來做后續(xù)的訓(xùn)練。數(shù)據(jù)集來自:http://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/ ,其中蘑菇有22個屬性,將這些原始的特征加工后得到126維特征,并保存為libsvm格式,標(biāo)簽是表示蘑菇是否有毒。其中其中 6513 個樣本做訓(xùn)練,1611 個樣本做測試。

import xgboost as xgb
from sklearn.metrics import accuracy_score

DMatrix is a internal data structure that used by XGBoost
which is optimized for both memory efficiency and training speed.

DMatrix 的數(shù)據(jù)來源可以是 string/numpy array/scipy.sparse/pd.DataFrame,如果是 string,則代表 libsvm 文件的路徑,或者是 xgboost 可讀取的二進(jìn)制文件路徑。

data_fold = "./data/"
dtrain = xgb.DMatrix(data_fold + "agaricus.txt.train")
dtest = xgb.DMatrix(data_fold + "agaricus.txt.test")

查看數(shù)據(jù)情況

(dtrain.num_col(),dtrain.num_row())
(dtest.num_col(),dtest.num_row())
(127, 6513)
(127, 1611)

模型訓(xùn)練

基本參數(shù)設(shè)定:

  • max_depth: 樹的最大深度。缺省值為6,取值范圍為:[1,∞]
  • eta:為了防止過擬合,更新過程中用到的收縮步長。eta通過縮減特征 的權(quán)重使提升計算過程更加保守。缺省值為0.3,取值范圍為:[0,1]
  • silent: 0表示打印出運(yùn)行時信息,取1時表示以緘默方式運(yùn)行,不打印 運(yùn)行時信息。缺省值為0
  • objective: 定義學(xué)習(xí)任務(wù)及相應(yīng)的學(xué)習(xí)目標(biāo),“binary:logistic” 表示 二分類的邏輯回歸問題,輸出為概率。
param = {'max_depth':2, 'eta':1, 'silent':0, 'objective':'binary:logistic' }
%time
# 設(shè)置boosting迭代計算次數(shù)
num_round = 2

bst = xgb.train(param, dtrain, num_round)
CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 65.6 μs

此處模型輸出是一個概率值,我們將其轉(zhuǎn)換為0-1值,然后再計算準(zhǔn)確率

train_preds = bst.predict(dtrain)
train_predictions = [round(value) for value in train_preds]
y_train = dtrain.get_label()
train_accuracy = accuracy_score(y_train, train_predictions)
print ("Train Accuary: %.2f%%" % (train_accuracy * 100.0))
Train Accuary: 97.77%

我們最后再測試集上看下模型的準(zhǔn)確率的

preds = bst.predict(dtest)
predictions = [round(value) for value in preds]
y_test = dtest.get_label()
test_accuracy = accuracy_score(y_test, predictions)
print("Test Accuracy: %.2f%%" % (test_accuracy * 100.0))
Test Accuracy: 97.83%
from matplotlib import pyplot
import graphviz
xgb.to_graphviz(bst, num_trees=0 )
pyplot.show()
svg
svg

scikit-learn 接口格式

from xgboost import XGBClassifier
from sklearn.datasets import load_svmlight_file
my_workpath = './data/'
X_train,y_train = load_svmlight_file(my_workpath + 'agaricus.txt.train')
X_test,y_test = load_svmlight_file(my_workpath + 'agaricus.txt.test')

# 設(shè)置boosting迭代計算次數(shù)
num_round = 2

#bst = XGBClassifier(**params)
#bst = XGBClassifier()
bst =XGBClassifier(max_depth=2, learning_rate=1, n_estimators=num_round, 
                   silent=True, objective='binary:logistic')

bst.fit(X_train, y_train)
XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
       colsample_bytree=1, gamma=0, learning_rate=1, max_delta_step=0,
       max_depth=2, min_child_weight=1, missing=None, n_estimators=2,
       n_jobs=1, nthread=None, objective='binary:logistic', random_state=0,
       reg_alpha=0, reg_lambda=1, scale_pos_weight=1, seed=None,
       silent=True, subsample=1)
# 訓(xùn)練集上準(zhǔn)確率
train_preds = bst.predict(X_train)
train_predictions = [round(value) for value in train_preds]

train_accuracy = accuracy_score(y_train, train_predictions)
print ("Train Accuary: %.2f%%" % (train_accuracy * 100.0))
Train Accuary: 97.77%
# 測試集上準(zhǔn)確率
# make prediction
preds = bst.predict(X_test)
predictions = [round(value) for value in preds]

test_accuracy = accuracy_score(y_test, predictions)
print("Test Accuracy: %.2f%%" % (test_accuracy * 100.0))
Test Accuracy: 97.83%

scikit-lean 中 cv 使用

做cross_validation主要用到下面 StratifiedKFold 函數(shù)

# 設(shè)置boosting迭代計算次數(shù)
num_round = 2
bst =XGBClassifier(max_depth=2, learning_rate=0.1,n_estimators=num_round, 
                   silent=True, objective='binary:logistic')
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_val_score
kfold = StratifiedKFold(n_splits=10, random_state=7)
results = cross_val_score(bst, X_train, y_train, cv=kfold)
print(results)
print("CV Accuracy: %.2f%% (%.2f%%)" % (results.mean()*100, results.std()*100))
[ 0.69478528  0.85276074  0.95398773  0.97235023  0.96006144  0.98771121
  1.          1.          0.96927803  0.97695853]
CV Accuracy: 93.68% (9.00%)

GridSearchcv 搜索最優(yōu)解

from sklearn.model_selection import GridSearchCV
bst =XGBClassifier(max_depth=2, learning_rate=0.1, silent=True, objective='binary:logistic')
%time
param_grid = {
 'n_estimators': range(1, 51, 1)
}
clf = GridSearchCV(bst, param_grid, "accuracy",cv=5)
clf.fit(X_train, y_train)
CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 24.3 μs
clf.best_params_, clf.best_score_
({'n_estimators': 30}, 0.98418547520343924)
## 在測試集合上測試
#make prediction
preds = clf.predict(X_test)
predictions = [round(value) for value in preds]

test_accuracy = accuracy_score(y_test, predictions)
print("Test Accuracy of gridsearchcv: %.2f%%" % (test_accuracy * 100.0))
Test Accuracy of gridsearchcv: 97.27%

early-stop

我們設(shè)置驗(yàn)證valid集,當(dāng)我們迭代過程中發(fā)現(xiàn)在驗(yàn)證集上錯誤率增加,則提前停止迭代。

from sklearn.model_selection import train_test_split
seed = 7
test_size = 0.33
X_train_part, X_validate, y_train_part, y_validate= train_test_split(X_train, y_train, test_size=test_size,
    random_state=seed)

X_train_part.shape
X_validate.shape
(4363, 126)
(2150, 126)
# 設(shè)置boosting迭代計算次數(shù)
num_round = 100


bst =XGBClassifier(max_depth=2, learning_rate=0.1, n_estimators=num_round, silent=True, objective='binary:logistic')

eval_set =[(X_validate, y_validate)]
bst.fit(X_train_part, y_train_part, early_stopping_rounds=10, eval_metric="error",
    eval_set=eval_set, verbose=True)
[0] validation_0-error:0.048372
Will train until validation_0-error hasn't improved in 10 rounds.
[1] validation_0-error:0.042326
[2] validation_0-error:0.048372
[3] validation_0-error:0.042326
[4] validation_0-error:0.042326
[5] validation_0-error:0.042326
[6] validation_0-error:0.023256
[7] validation_0-error:0.042326
[8] validation_0-error:0.042326
[9] validation_0-error:0.023256
[10]    validation_0-error:0.006512
[11]    validation_0-error:0.017674
[12]    validation_0-error:0.017674
[13]    validation_0-error:0.017674
[14]    validation_0-error:0.017674
[15]    validation_0-error:0.017674
[16]    validation_0-error:0.017674
[17]    validation_0-error:0.017674
[18]    validation_0-error:0.024651
[19]    validation_0-error:0.020465
[20]    validation_0-error:0.020465
Stopping. Best iteration:
[10]    validation_0-error:0.006512

我們可以將上面的錯誤率進(jìn)行可視化,方便我們更直觀的觀察

results = bst.evals_result()
#print(results)

epochs = len(results['validation_0']['error'])
x_axis = range(0, epochs)

# plot log loss
fig, ax = pyplot.subplots()
ax.plot(x_axis, results['validation_0']['error'], label='Test')
ax.legend()
pyplot.ylabel('Error')
pyplot.xlabel('Round')
pyplot.title('XGBoost Early Stop')
pyplot.show()
output_35_5.png-30.1kB
output_35_5.png-30.1kB
# 測試集上準(zhǔn)確率
# make prediction
preds = bst.predict(X_test)
predictions = [round(value) for value in preds]

test_accuracy = accuracy_score(y_test, predictions)
print("Test Accuracy: %.2f%%" % (test_accuracy * 100.0))
Test Accuracy: 97.27%

學(xué)習(xí)曲線

# 設(shè)置boosting迭代計算次數(shù)
num_round = 100

# 沒有 eraly_stop
bst =XGBClassifier(max_depth=2, learning_rate=0.1, n_estimators=num_round, silent=True, objective='binary:logistic')

eval_set = [(X_train_part, y_train_part), (X_validate, y_validate)]
bst.fit(X_train_part, y_train_part, eval_metric=["error", "logloss"], eval_set=eval_set, verbose=True)
# retrieve performance metrics
results = bst.evals_result()
#print(results)


epochs = len(results['validation_0']['error'])
x_axis = range(0, epochs)

# plot log loss
fig, ax = pyplot.subplots()
ax.plot(x_axis, results['validation_0']['logloss'], label='Train')
ax.plot(x_axis, results['validation_1']['logloss'], label='Test')
ax.legend()
pyplot.ylabel('Log Loss')
pyplot.title('XGBoost Log Loss')
pyplot.show()

# plot classification error
fig, ax = pyplot.subplots()
ax.plot(x_axis, results['validation_0']['error'], label='Train')
ax.plot(x_axis, results['validation_1']['error'], label='Test')
ax.legend()
pyplot.ylabel('Classification Error')
pyplot.title('XGBoost Classification Error')
pyplot.show()
output_39_5.png-27.6kB
output_39_5.png-27.6kB
output_39_11.png-33kB
output_39_11.png-33kB
# make prediction
preds = bst.predict(X_test)
predictions = [round(value) for value in preds]

test_accuracy = accuracy_score(y_test, predictions)
print("Test Accuracy: %.2f%%" % (test_accuracy * 100.0))
Test Accuracy: 99.81%
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌,老刑警劉巖,帶你破解...
    沈念sama閱讀 229,963評論 6 542
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 99,348評論 3 429
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人,你說我怎么就攤上這事。” “怎么了?”我有些...
    開封第一講書人閱讀 178,083評論 0 383
  • 文/不壞的土叔 我叫張陵,是天一觀的道長。 經(jīng)常有香客問我,道長,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 63,706評論 1 317
  • 正文 為了忘掉前任,我火速辦了婚禮,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘。我一直安慰自己,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 72,442評論 6 412
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著,像睡著了一般。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 55,802評論 1 328
  • 那天,我揣著相機(jī)與錄音,去河邊找鬼。 笑死,一個胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播,決...
    沈念sama閱讀 43,795評論 3 446
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 42,983評論 0 290
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 49,542評論 1 335
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 41,287評論 3 358
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 43,486評論 1 374
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情,我是刑警寧澤,帶...
    沈念sama閱讀 39,030評論 5 363
  • 正文 年R本政府宣布,位于F島的核電站,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 44,710評論 3 348
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧,春花似錦、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 35,116評論 0 28
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 36,412評論 1 294
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人。 一個月前我還...
    沈念sama閱讀 52,224評論 3 398
  • 正文 我出身青樓,卻偏偏與公主長得像,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 48,462評論 2 378

推薦閱讀更多精彩內(nèi)容