TensorFlow 保存和加載模型

可以在訓練期間和訓練后保存模型進度。 這意味著模型可以從中斷的地方恢復,并避免長時間的訓練。 保存也意味著您可以共享您的模型,而其他人可以重新創建您的工作。 在發布研究模型和技術時,大多數機器學習從業者分享:

  1. 用于創建模型的代碼
  2. 模型的訓練權重或參數

共享此數據有助于其他人了解模型的工作原理,并使用新數據自行嘗試。

注意:小心不受信任的代碼 - TensorFlow模型是代碼。 有關詳細信息,請參閱安全使用TensorFlow。

選項

保存TensorFlow模型有多種方法 - 取決于您使用的API。 本指南使用tf.keras,一個高級API,用于在TensorFlow中構建和訓練模型。 有關其他方法,請參閱TensorFlow保存和還原指南或保存在急切中。

安裝

安裝和引用

安裝和導入TensorFlow和依賴項,有下面兩種方式:

  1. 命令行:pip install -q h5py pyyaml
  2. 在Anaconda Navigator中安裝;

下載樣本數據集

from __future__ import absolute_import, division, print_function

import os

import tensorflow as tf
from tensorflow import keras

tf.__version__

'1.11.0'

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0

定義模型

讓我們構建一個簡單的模型,我們將用它來演示保存和加載權重。

# Returns a short sequential model
def create_model():
  model = tf.keras.models.Sequential([
    keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(784,)),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10, activation=tf.nn.softmax)
  ])
  
  model.compile(optimizer=tf.keras.optimizers.Adam(), 
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    metrics=['accuracy'])
  
  return model


# Create a basic model instance
model = create_model()
model.summary()

在訓練期間保存檢查點

主要用例是在訓練期間和訓練結束時自動保存檢查點。 通過這種方式,您可以使用訓練有素的模型,而無需重新訓練,或者在您離開的地方接受訓練 - 以防止訓練過程中斷。

tf.keras.callbacks.ModelCheckpoint是執行此任務的回調。 回調需要幾個參數來配置檢查點。

檢查點回調使用情況

訓練模型并將模型傳遞給ModelCheckpoint:

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create checkpoint callback
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, 
 save_weights_only=True,
 verbose=1)

model = create_model()

model.fit(train_images, train_labels,  epochs = 10, 
  validation_data = (test_images,test_labels),
  callbacks = [cp_callback])  # pass callback to training

這將創建一個TensorFlow檢查點文件集合,這些文件在每個時期結束時更新:

!ls {checkpoint_dir}

checkpoint cp.ckpt.data-00000-of-00001 cp.ckpt.index

創建一個新的未經訓練的模型。 僅從權重還原模型時,必須具有與原始模型具有相同體系結構的模型。 由于它是相同的模型架構,我們可以共享權重,盡管它是模型的不同實例。

現在重建一個新的未經訓練的模型,并在測試集上進行評估。 未經訓練的模型將在偶然水平上執行(準確度約為10%):

model = create_model()

loss, acc = model.evaluate(test_images, test_labels)
print("Untrained model, accuracy: {:5.2f}%".format(100*acc))

然后從檢查點加載權重,并重新評估:

model.load_weights(checkpoint_path)
loss,acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

1000/1000 [==============================] - 0s 40us/step
Restored model, accuracy: 87.60%

檢查點回調選項

回調提供了幾個選項,可以為生成的檢查點提供唯一的名稱,并調整檢查點頻率。

訓練一個新模型,每5個時期保存一次唯一命名的檢查點:

# include the epoch in the file name. (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, verbose=1, save_weights_only=True,
    # Save weights, every 5-epochs.
    period=5)

model = create_model()
model.fit(train_images, train_labels,
  epochs = 50, callbacks = [cp_callback],
  validation_data = (test_images,test_labels),
  verbose=0)

現在,查看生成的檢查點并選擇最新的檢查點:

! ls {checkpoint_dir}

checkpoint cp-0030.ckpt.data-00000-of-00001
cp-0005.ckpt.data-00000-of-00001 cp-0030.ckpt.index
cp-0005.ckpt.index cp-0035.ckpt.data-00000-of-00001
cp-0010.ckpt.data-00000-of-00001 cp-0035.ckpt.index
cp-0010.ckpt.index cp-0040.ckpt.data-00000-of-00001
cp-0015.ckpt.data-00000-of-00001 cp-0040.ckpt.index
cp-0015.ckpt.index cp-0045.ckpt.data-00000-of-00001
cp-0020.ckpt.data-00000-of-00001 cp-0045.ckpt.index
cp-0020.ckpt.index cp-0050.ckpt.data-00000-of-00001
cp-0025.ckpt.data-00000-of-00001 cp-0050.ckpt.index
cp-0025.ckpt.index

latest = tf.train.latest_checkpoint(checkpoint_dir)
latest

'training_2/cp-0050.ckpt'

注意:默認的tensorflow格式僅保存最近的5個檢查點。

要測試,請重置模型并加載最新的檢查點:

model = create_model()
model.load_weights(latest)
loss, acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

1000/1000 [==============================] - 0s 96us/step
Restored model, accuracy: 86.80%

這些文件是什么?

上述代碼將權重存儲到檢查點格式的文件集合中,這些文件僅包含二進制格式的訓練權重。 檢查點包含:*一個或多個包含模型權重的分片。 *索引文件,指示哪些權重存儲在哪個分片中。

如果您只在一臺機器上訓練模型,那么您將有一個帶有后綴的分片:.data-00000-of-00001

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Restore the weights
model = create_model()
model.load_weights('./checkpoints/my_checkpoint')

loss,acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

保存整個模型

整個模型可以保存到包含權重值,模型配置甚至優化器配置的文件中。 這允許您檢查模型并稍后從完全相同的狀態恢復培訓 - 無需訪問原始代碼。

在Keras中保存功能齊全的模型非常有用 - 您可以在TensorFlow.js中加載它們,然后在Web瀏覽器中訓練和運行它們。

Keras使用HDF5標準提供基本保存格式。 出于我們的目的,可以將保存的模型視為單個二進制blob。

model = create_model()

model.fit(train_images, train_labels, epochs=5)

# Save entire model to a HDF5 file
model.save('my_model.h5')

Epoch 1/5
1000/1000 [==============================] - 0s 395us/step - loss: 1.1260 - acc: 0.6870
Epoch 2/5
1000/1000 [==============================] - 0s 135us/step - loss: 0.4136 - acc: 0.8760
Epoch 3/5
1000/1000 [==============================] - 0s 138us/step - loss: 0.2811 - acc: 0.9280
Epoch 4/5
1000/1000 [==============================] - 0s 153us/step - loss: 0.2078 - acc: 0.9480
Epoch 5/5
1000/1000 [==============================] - 0s 154us/step - loss: 0.1452 - acc: 0.9750

現在從該文件重新創建模型:

# Recreate the exact same model, including weights and optimizer.
new_model = keras.models.load_model('my_model.h5')
new_model.summary()

檢查其準確性:

loss, acc = new_model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

這項技術可以保存以下:

  1. 權重值
  2. 模型的配置(架構)
  3. 優化器配置

Keras通過檢查架構來保存模型。 目前,它無法保存TensorFlow優化器(來自tf.train)。 使用這些時,您需要在加載后重新編譯模型,并且您將失去優化器的狀態。

下一步是什么

這是使用tf.keras保存和加載的快速指南。

tf.keras指南顯示了有關使用tf.keras保存和加載模型的更多信息。

請參閱在急切執行期間保存以備保存。

“保存和還原”指南包含有關TensorFlow保存的低級詳細信息。

完整代碼:

from __future__ import absolute_import,division,print_function
import os
import tensorflow as tf
from tensorflow import keras

print(tf.__version__)


# Download dataset
(train_images,train_labels),(test_images,test_labels) = tf.keras.datasets.mnist.load_data()
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

train_images = train_images[:1000].reshape(-1,28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1,28 * 28) / 255.0

# Define a model
# Returns a short sequential model
def create_model():
    model = tf.keras.models.Sequential([
    keras.layers.Dense(512,activation=tf.nn.relu,input_shape=(784,)),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10,activation=tf.nn.softmax)
])

model.compile(optimizer=tf.keras.optimizers.Adam(),
  loss=tf.keras.losses.sparse_categorical_crossentropy,
  metrics=['accuracy'])
return model

# Create a basic model instance
model = create_model()
model.summary()

# Checkpoint callback usage
checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
# Create checkpoint callback
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
 save_weights_only=True,
 verbose=1)
model = create_model()
model.fit(train_images,train_labels,epochs=10,
  validation_data=(test_images,test_labels),
  callbacks=[cp_callback]) # pass callback to training

# Create a new, untrained model. 
model = create_model()
loss,acc = model.evaluate(test_images,test_labels)
print("Untrained model, accuracy: {:5.2f}%".format(100*acc))

# Load the weights from chekpoint, and re-evaluate.
model.load_weights(checkpoint_path)
loss,acc = model.evaluate(test_images,test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

# Train a new model, and save uniquely named checkpoints once every 5epochs
# include the epoch in the file name. (uses 'str.format')
checkpoint_path = 'training_2/cp-{epoch:04d}.ckpt'
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, verbose=1,save_weights_only=True,
    # Save weights, every 5-epochs
    period=5)

model = create_model()
model.fit(train_images,train_labels,
  epochs=50,callbacks = [cp_callback],
  validation_data = (test_images,test_labels),
  verbose=0)


latest = tf.train.latest_checkpoint(checkpoint_dir)
print(latest)


# To test, reset the model and load the latest checkpoint
model = create_model()
model.load_weights(latest)
loss, acc = model.evaluate(test_images,test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

# Manually save weights
# Save the weights
model.save_weights('./checkpoints/my_checkpoint')
# Restore the weights
model = create_model()
model.load_weights('./checkpoints/my_checkpoint')

loss, acc = model.evaluate(test_images,test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))


# Save the entire model
model = create_model()
model.fit(train_images,train_labels,
  epochs=5)
# Save entire model to a HDF5 file
model.save('my_model.h5')

# Recreate the exact same model, including weights and optimizer.
new_model = keras.models.load_model('my_model.h5')
new_model.summary()

loss, acc = new_model.evaluate(test_images,test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市,隨后出現的幾起案子,更是在濱河造成了極大的恐慌,老刑警劉巖,帶你破解...
    沈念sama閱讀 228,702評論 6 534
  • 序言:濱河連續發生了三起死亡事件,死亡現場離奇詭異,居然都是意外死亡,警方通過查閱死者的電腦和手機,發現死者居然都...
    沈念sama閱讀 98,615評論 3 419
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人,你說我怎么就攤上這事。” “怎么了?”我有些...
    開封第一講書人閱讀 176,606評論 0 376
  • 文/不壞的土叔 我叫張陵,是天一觀的道長。 經常有香客問我,道長,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 63,044評論 1 314
  • 正文 為了忘掉前任,我火速辦了婚禮,結果婚禮上,老公的妹妹穿的比我還像新娘。我一直安慰自己,他們只是感情好,可當我...
    茶點故事閱讀 71,826評論 6 410
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著,像睡著了一般。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發上,一...
    開封第一講書人閱讀 55,227評論 1 324
  • 那天,我揣著相機與錄音,去河邊找鬼。 笑死,一個胖子當著我的面吹牛,可吹牛的內容都是我干的。 我是一名探鬼主播,決...
    沈念sama閱讀 43,307評論 3 442
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了?” 一聲冷哼從身側響起,我...
    開封第一講書人閱讀 42,447評論 0 289
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后,有當地人在樹林里發現了一具尸體,經...
    沈念sama閱讀 48,992評論 1 335
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內容為張勛視角 年9月15日...
    茶點故事閱讀 40,807評論 3 355
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發現自己被綠了。 大學時的朋友給我發了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 43,001評論 1 370
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖,靈堂內的尸體忽然破棺而出,到底是詐尸還是另有隱情,我是刑警寧澤,帶...
    沈念sama閱讀 38,550評論 5 361
  • 正文 年R本政府宣布,位于F島的核電站,受9級特大地震影響,放射性物質發生泄漏。R本人自食惡果不足惜,卻給世界環境...
    茶點故事閱讀 44,243評論 3 347
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧,春花似錦、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 34,667評論 0 26
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至,卻和暖如春,著一層夾襖步出監牢的瞬間,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 35,930評論 1 287
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人。 一個月前我還...
    沈念sama閱讀 51,709評論 3 393
  • 正文 我出身青樓,卻偏偏與公主長得像,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當晚...
    茶點故事閱讀 47,996評論 2 374

推薦閱讀更多精彩內容

  • 近期做了一些反垃圾的工作,除了使用常用的規則匹配過濾等手段,也采用了一些機器學習方法進行分類預測。我們使用Tens...
    liuyan731閱讀 12,799評論 0 19
  • 在這篇tensorflow教程中,我會解釋: 1) Tensorflow的模型(model)長什么樣子? 2) 如...
    JunsorPeng閱讀 3,448評論 1 6
  • 世界這么大,你應該去看看 今天剛剛高考完的表妹問起大學報考志愿應該怎么填,因為發揮的不太好,家里人都建議她學護理,...
    Miss凌妹妹閱讀 462評論 6 4
  • 近海風云烈, 征帆拓遠洲。 迷霧隱奇偉, 但去必賢優。
    村客閱讀 156評論 0 6
  • 你說你不吃香菜 我說我吃餃子不帶湯 我倆的碗里卻是漂著香菜的餃子湯 在南方第一次吃熱干面,我嫌它太噎人 在北方,后...
    Crazy麻麻閱讀 348評論 7 9