tf.train.Checkpoint
:變量的保存與恢復
??Tensorflow的Checkpoint機制將可追蹤變量以二進制的方式儲存成一個.ckpt文件,儲存了變量的名稱及對應張量的值。
??Checkpoint 只保存模型的參數,不保存模型的計算過程,因此一般用于在具有模型源代碼的時候恢復之前訓練好的模型參數。如果需要導出模型(無需源代碼也能運行模型)。
??很多時候,我們希望在模型訓練完成后能將訓練好的參數(變量)保存起來。在需要使用模型的其他地方載入模型和參數,就能直接得到訓練好的模型。可能你第一個想到的是用 Python 的序列化模塊 pickle
存儲 model.variables
。但不幸的是,TensorFlow 的變量類型 ResourceVariable
并不能被序列化。
??好在 TensorFlow 提供了 tf.train.Checkpoint
這一強大的變量保存與恢復類,可以使用其 save()
和 restore()
方法將 TensorFlow 中所有包含 Checkpointable State 的對象進行保存和恢復。具體而言,tf.keras.optimizer
、 tf.Variable
、 tf.keras.Layer
或者 tf.keras.Model
實例都可以被保存。其使用方法非常簡單,我們首先聲明一個 Checkpoint:
checkpoint = tf.train.Checkpoint(model=model)
??這里 tf.train.Checkpoint()
接受的初始化參數比較特殊,是一個 **kwargs
。具體而言,是一系列的鍵值對,鍵名可以隨意取,值為需要保存的對象。例如,如果我們希望保存一個繼承 tf.keras.Model
的模型實例 model
和一個繼承 tf.train.Optimizer
的優化器 optimizer
,我們可以這樣寫:
checkpoint = tf.train.Checkpoint(myAwesomeModel=model, myAwesomeOptimizer=optimizer)
??這里 myAwesomeModel
是我們為待保存的模型 model
所取的任意鍵名。注意,在恢復變量的時候,我們還將使用這一鍵名。
接下來,當模型訓練完成需要保存的時候,使用:
checkpoint.save(save_path_with_prefix)
就可以。 save_path_with_prefix
是保存文件的目錄 + 前綴。
- 注解
??例如,在源代碼目錄建立一個名為 save 的文件夾并調用一次 checkpoint.save('./save/model.ckpt')
,我們就可以在可以在 save 目錄下發現名為 checkpoint
、 model.ckpt-1.index
、 model.ckpt-1.data-00000-of-00001
的三個文件,這些文件就記錄了變量信息。checkpoint.save()
方法可以運行多次,每運行一次都會得到一個. index 文件和. data 文件,序號依次累加。
??當在其他地方需要為模型重新載入之前保存的參數時,需要再次實例化一個 checkpoint,同時保持鍵名的一致。再調用 checkpoint 的 restore 方法。就像下面這樣:
model_to_be_restored = MyModel() # 待恢復參數的同一模型
checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored) # 鍵名保持為“myAwesomeModel”
checkpoint.restore(save_path_with_prefix_and_index)
即可恢復模型變量。 save_path_with_prefix_and_index
是之前保存的文件的目錄 + 前綴 + 編號。例如,調用 checkpoint.restore('./save/model.ckpt-1')
就可以載入前綴為 model.ckpt
,序號為 1 的文件來恢復模型。
??當保存了多個文件時,我們往往想載入最近的一個。可以使用 tf.train.latest_checkpoint(save_path)
這個輔助函數f。例如如果 save 目錄下有 model.ckpt-1.index
到 model.ckpt-10.index
的 10 個保存文件, tf.train.latest_checkpoint('./save')
即返回 ./save/model.ckpt-10
。
總體而言,恢復與保存變量的典型代碼框架如下:
# train.py 模型訓練階段
model = MyModel()
# 實例化Checkpoint,指定保存對象為model(如果需要保存Optimizer的參數也可加入)
checkpoint = tf.train.Checkpoint(myModel=model)
# ...(模型訓練代碼)
# 模型訓練完畢后將參數保存到文件(也可以在模型訓練過程中每隔一段時間就保存一次)
checkpoint.save('./save/model.ckpt')
# test.py 模型使用階段
model = MyModel()
checkpoint = tf.train.Checkpoint(myModel=model) # 實例化Checkpoint,指定恢復對象為model
checkpoint.restore(tf.train.latest_checkpoint('./save')) # 從文件恢復模型參數
# 模型使用代碼
- 注解
??tf.train.Checkpoint
與以前版本常用的 tf.train.Saver
相比,強大之處在于其支持在即時執行模式下 “延遲” 恢復變量。具體而言,當調用了 checkpoint.restore()
,但模型中的變量還沒有被建立的時候,Checkpoint 可以等到變量被建立的時候再進行數值的恢復。即時執行模式下,模型中各個層的初始化和變量的建立是在模型第一次被調用的時候才進行的(好處在于可以根據輸入的張量形狀而自動確定變量形狀,無需手動指定)。這意味著當模型剛剛被實例化的時候,其實里面還一個變量都沒有,這時候使用以往的方式去恢復變量數值是一定會報錯的。比如,你可以試試在 train.py 調用 tf.keras.Model
的 save_weight()
方法保存 model 的參數,并在 test.py 中實例化 model 后立即調用 load_weight()
方法,就會出錯,只有當調用了一遍 model 之后再運行 load_weight()
方法才能得到正確的結果。可見, tf.train.Checkpoint
在這種情況下可以給我們帶來相當大的便利。另外, tf.train.Checkpoint
同時也支持圖執行模式。
在代碼目錄下建立 save 文件夾并運行代碼進行訓練后,save 文件夾內將會存放每隔 100 個 batch 保存一次的模型變量數據。在命令行參數中加入 --mode=test
并再次運行代碼,將直接使用最后一次保存的變量值恢復模型并在測試集上測試模型性能,可以直接獲得 95% 左右的準確率。
使用 tf.train.CheckpointManager
刪除舊的 Checkpoint 以及自定義文件編號
在模型的訓練過程中,我們往往每隔一定步數保存一個 Checkpoint 并進行編號。不過很多時候我們會有這樣的需求:
在長時間的訓練后,程序會保存大量的 Checkpoint,但我們只想保留最后的幾個 Checkpoint;
Checkpoint 默認從 1 開始編號,每次累加 1,但我們可能希望使用別的編號方式(例如使用當前 Batch 的編號作為文件編號)。
這時,我們可以使用 TensorFlow 的 tf.train.CheckpointManager
來實現以上需求。具體而言,在定義 Checkpoint 后接著定義一個 CheckpointManager:
checkpoint = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(checkpoint, directory='./save', checkpoint_name='model.ckpt', max_to_keep=k)
此處, directory
參數為文件保存的路徑, checkpoint_name
為文件名前綴(不提供則默認為 ckpt
), max_to_keep
為保留的 Checkpoint 數目。
在需要保存模型的時候,我們直接使用 manager.save()
即可。如果我們希望自行指定保存的 Checkpoint 的編號,則可以在保存時加入 checkpoint_number
參數。例如 manager.save(checkpoint_number=100)
。
以下是一個基于CIFAR10數據集的一個示例,讀者可進行參考。
GPU環境測試
import tensorflow as tf
# 使用顯卡進行時,將GPU的顯存使用策略設置為 “僅在需要時申請顯存空間”,不然會申請所有顯存空間,報錯
gpu_devices = tf.config.experimental.list_physical_devices("GPU")
for device in gpu_devices:
tf.config.experimental.set_memory_growth(device, True)
返回運行時可見的物理設備列表,默認情況下,所有發現的CPU和GPU設備都被視為可見的。
tf.config.experimental.list_physical_devices(device_type=None)
[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
PhysicalDevice(name='/physical_device:XLA_CPU:0', device_type='XLA_CPU'),
PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'),
PhysicalDevice(name='/physical_device:XLA_GPU:0', device_type='XLA_GPU')]
查看GPU設備信息
!nvidia-smi
Tue May 19 23:17:57 2020
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.64 Driver Version: 440.64 CUDA Version: 10.2 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
|===============================+======================+======================|
| 0 GeForce RTX 2070 Off | 00000000:01:00.0 On | N/A |
| 0% 51C P8 19W / 175W | 4923MiB / 7979MiB | 0% Default |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: GPU Memory |
| GPU PID Type Process name Usage |
|=============================================================================|
| 0 903 G /usr/lib/xorg/Xorg 66MiB |
| 0 1566 G /usr/bin/gnome-shell 85MiB |
| 0 5359 C /home/wcjb/anaconda3/bin/python 4759MiB |
+-----------------------------------------------------------------------------+
檢查GPU是否可用
tf.test.is_gpu_available()
True
查看GPU是否可用
tf.config.experimental.list_physical_devices('GPU')
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
# 啟用設備放置日志記錄將導致打印任何張量分配或操作
tf.debugging.set_log_device_placement(True)
tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)
VirtualDeviceConfiguration(memory_limit=1024)
數據處理
數據載入
- CIFAR-10數據集
??CIFAR-10數據集是一個用于識別普適物體的小型數據集,它一共包含10個類別的RGB彩色圖片:飛機(airplane)、汽車(automobile)、鳥類(bird)、貓(cat)、鹿(deer)、狗(dog)、蛙類(frog)、馬(horse)、船(ship)和卡車(truck)。圖片的尺寸為32x32,該數據集一共有50000張訓練圖片和10000張測試圖片。
??1個10000x3072大小的uint8s數組。數組的每行存儲1張32*32的圖像,第1個1024包含紅色通道值,下1個包含綠色,最后的1024包含藍色。圖像存儲以行順序為主,所以數組的前32列為圖像第1行的紅色通道值。
import pickle
import os
from PIL import Image
from tqdm import tqdm
import numpy as np
class CIFAR10(object):
def __init__(self,path='/home/wcjb/Code/Dataset/cifar-10-batches-py/'):
self.trainpath = [os.path.join(path,'data_batch_'+str(i+1)) for i in range(5)]
self.testpath = [os.path.join(path,'test_batch')]
def unpickle(self,file):
with open(file, 'rb') as fo:
dict = pickle.load(fo, encoding = 'iso-8859-1') #
return dict
def load_batch(self,file):
with open(file, 'rb')as f:
datadict = self.unpickle(file)
data = datadict['data']
label = datadict['labels']
data = data.reshape(10000, 3, 32, 32)
label = np.array(label)
return data,label
def toimg(self,data):
img = []
for i in range(data.shape[0]):
imgs = data[i - 1]
r = imgs[0]
g = imgs[1]
b = imgs[2]
R = Image.fromarray(r)
G = Image.fromarray(g)
B = Image.fromarray(b)
img.append(Image.merge("RGB",(R,G,B)))
return img
def cif2img(self):
train_img,test_img = [],[]
for tp in tqdm(self.trainpath,desc='Train-img'):
data,label = self.load_batch(tp)
train_img.append(self.toimg(data))
for tp in tqdm(self.testpath,desc='Test-img '):
data,label = self.load_batch(tp)
test_img.append(self.toimg(data))
return train_img,test_img
def cif2data(self):
x_train,y_train,x_test,y_test = [],[],[],[]
for tp in tqdm(self.trainpath,desc='Train'):
data,label = self.load_batch(tp)
x_train.append(data)
y_train.append(label)
for tp in tqdm(self.testpath,desc='Test '):
data,label = self.load_batch(tp)
x_test.append(data)
y_test.append(label)
x_train,y_train = np.array(x_train).reshape(-1,3,32,32),np.array(y_train).reshape(-1,)
x_test,y_test = np.array(x_test).reshape(-1,3,32,32),np.array(y_test).reshape(-1,)
x_train,x_test = np.rollaxis(x_train, 1,4),np.rollaxis(x_test,1, 4)
return x_train,y_train,x_test,y_test
cif = CIFAR10()
# 將CIFAR10數據集加載為圖片數據
train_img,test_img = cif.cif2img()
# 將將CIFAR10數據集加載為多維數據用于訓練
x_train,y_train,x_test,y_test = cif.cif2data()
Train-img: 100%|██████████| 5/5 [00:02<00:00, 2.39it/s]
Test-img : 100%|██████████| 1/1 [00:00<00:00, 2.41it/s]
Train: 100%|██████████| 5/5 [00:00<00:00, 40.06it/s]
Test : 100%|██████████| 1/1 [00:00<00:00, 40.89it/s]
查看數據集樣本的圖片
import matplotlib.pyplot as plt
%matplotlib inline
plt.imshow(train_img[0][0])
<matplotlib.image.AxesImage at 0x7fdfd8cbe2d0>
plt.imshow(test_img[0][0])
<matplotlib.image.AxesImage at 0x7fdfd8bf4250>
查看數據集的樣本的數組形態
x_train.shape
(50000, 32, 32, 3)
數據增強處理函數
- 直方圖均衡化
??直方圖均衡化通常用來增加許多圖像的全局對比度,尤其是當圖像的有用數據的對比度相當接近的時候。通過這種方法,亮度可以更好地在直方圖上分布。這樣就可以用于增強局部的對比度而不影響整體的對比度,直方圖均衡化通過有效地擴展常用的亮度來實現這種功能。這種方法對于背景和前景都太亮或者太暗的圖像非常有用,這種方法尤其是可以帶來X光圖像中更好的骨骼結構顯示以及曝光過度或者曝光不足照片中更好的細節。這種方法的一個主要優勢是它是一個相當直觀的技術并且是可逆操作,如果已知均衡化函數,那么就可以恢復原始的直方圖,并且計算量也不大。這種方法的一個缺點是它對處理的數據不加選擇,它可能會增加背景噪聲的對比度并且降低有用信號的對比度。
import shutil
from PIL import Image
import sys
import cv2
from tqdm import notebook
class DataAugumentation(object):
def __init__(self,num=10):
self.num = num
def CLAHE(self,img):
grayimg = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# 局部直方圖均值化
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
cl1 = clahe.apply(grayimg)
return cl1
def Histograms_Equalization(self,img):
grayimg = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# 直方圖均值化
equ = cv2.equalizeHist(grayimg)
return equ
def make_one_hot(self,data):
return (np.arange(self.num)==data[:,None]).astype(np.int64)
def augument(self,imgs,labels):
'''
使用圖像處理方法進行數據增強,直方圖均值化和局部直方圖均值化
再加上灰度圖和原圖片可以將數據集增大三倍
'''
x_data,y_data = [],[]
for img,label in notebook.tqdm(zip(imgs,labels),desc='數據增強進度'):
imggray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
x_data.append(imggray.astype('float32') / 255.0)
y_data.append(label)
he_image = self.Histograms_Equalization(img)
x_data.append(he_image.astype('float32') / 255.0)
y_data.append(label)
clahe_img = self.CLAHE(img)
x_data.append(clahe_img.astype('float32') / 255.0)
y_data.append(label)
return np.array(x_data),np.array(y_data)
處理訓練集
da = DataAugumentation()
x_new_train,y_new_train = da.augument(x_train,y_train)
x_new_test,y_new_test = da.augument(x_test,y_test)
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='數據增強進度', max=1.0, style=ProgressStyle(d…
?
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='數據增強進度', max=1.0, style=ProgressStyle(d…
?
- 擴展數據維度,計算卷積
x_new_train = np.expand_dims(x_new_train, 3)
x_new_test = np.expand_dims(x_new_test,3)
可以看到,數據集增強之后比較大,所以可以把增強后的數據集保存在本地方便再次復用。
pickle.dump(x_new_train, open('./CifaData/x_new_train.p', 'wb'))
pickle.dump(y_new_train, open('./CifaData/y_new_train.p', 'wb'))
pickle.dump(x_new_test, open('./CifaData/x_new_test.p', 'wb'))
pickle.dump(y_new_test, open('./CifaData/y_new_test.p', 'wb'))
!cd CifaData && ls -hl
總用量 705M
-rw-rw-r-- 1 wcjb wcjb 118M 5月 19 22:17 x_new_test.p
-rw-rw-r-- 1 wcjb wcjb 586M 5月 19 22:17 x_new_train.p
-rw-rw-r-- 1 wcjb wcjb 235K 5月 19 22:17 y_new_test.p
-rw-rw-r-- 1 wcjb wcjb 1.2M 5月 19 22:17 y_new_train.p
# with open('./CifaData/y_new_train.p', 'rb') as fo:
# y_n_train = pickle.load(fo, encoding = 'iso-8859-1')
搭建模型
import tensorflow as tf
import datetime
import time
MODEL_DIR = "./models"
class network(tf.keras.Model):
def __init__(self,n_class=10,learning_rate=1e-4):
super(network,self).__init__()
# 定義網絡結構
self.conv2d_01 = tf.keras.layers.Convolution2D (kernel_size = (5, 5),input_shape=(32,32,1), filters = 100, activation='relu')
self.maxpool2d_01 = tf.keras.layers.MaxPool2D()
self.conv2d_02 = tf.keras.layers.Convolution2D (kernel_size = (3, 3), filters = 150, activation='relu')
self.maxpool2d_02 = tf.keras.layers.MaxPool2D()
self.conv2d_03 = tf.keras.layers.Convolution2D (kernel_size = (3, 3), filters = 250, padding='same', activation='relu')
self.maxpool2d_03 = tf.keras.layers.MaxPool2D()
self.flatten = tf.keras.layers.Flatten()
self.dense_01 = tf.keras.layers.Dense(512, activation='relu')
self.dense_02 = tf.keras.layers.Dense(300, activation='relu')
self.dense_03 = tf.keras.layers.Dense(10,activation='softmax')
# 優化器
self.optimizer = tf.optimizers.Adam(learning_rate=learning_rate)
# 確認模型日志目錄是否存在,若不存在則創建
if not tf.io.gfile.exists(MODEL_DIR):
tf.io.gfile.makedirs(MODEL_DIR)
# 申明訓練和測試日志路徑
train_dir = os.path.join(MODEL_DIR, 'summaries', 'train')
test_dir = os.path.join(MODEL_DIR, 'summaries', 'eval')
# 根據給定文件在當前上下文環境中創建日志記錄器,記錄數據摘要,便于可視化及分析并且每個10000刷新
self.train_summary_writer = tf.summary.create_file_writer(train_dir, flush_millis=10000)
self.test_summary_writer = tf.summary.create_file_writer(test_dir, flush_millis=10000, name='test')
# 將可追蹤變量以二進制的方式儲存成一個checkpoint 檔(.ckpt),
# 即儲存變量的名字和對應的張量的數值。
checkpoint_dir = os.path.join(MODEL_DIR, 'checkpoints')
self.checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
self.checkpoint = tf.train.Checkpoint(model=self, optimizer=self.optimizer)
# 只保存最近10個模型文件
tf.train.CheckpointManager(self.checkpoint, directory=checkpoint_dir, checkpoint_name='network.ckpt', max_to_keep=10)
# 返回目錄下最近一次checkpoint的文件名,并恢復模型參數
self.checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
def call(self,inputs):
x = self.conv2d_01(inputs)
x = self.maxpool2d_01(x)
x = self.conv2d_02(x)
x = self.maxpool2d_02(x)
x = self.conv2d_03(x)
x = self.maxpool2d_03(x)
x = self.flatten(x)
x = self.dense_01(x)
x = self.dense_02(x)
x = self.dense_03(x)
return x
@tf.function()
def loss(self, logits, labels):
return tf.reduce_mean(tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True))
@tf.function()
def accuracy(self, logits, labels):
return tf.keras.metrics.sparse_categorical_accuracy(labels, logits)
@tf.function(experimental_relax_shapes=True)
def train_step(self, images, labels):
with tf.device('/GPU:0'):
with tf.GradientTape() as tape:
# 前向計算
logits = self.call(images)
# 計算當前批次模型的損失函數
loss = self.loss(logits, labels)
# 計算當前批次的模型準確率
accuracy = self.accuracy(logits, labels)
#=====================反向過程=====================
# 計算梯度
grads = tape.gradient(loss, self.trainable_variables)
# 使用梯度更新可訓練集合的變量
self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
return loss, accuracy, logits
def train(self, train_dataset, test_dataset, epochs=1, log_freq=50):
for i in range(epochs):
train_start = time.time()
# 在該上下文環境中記錄可追蹤變量
with self.train_summary_writer.as_default():
start = time.time()
# metrics指標是有狀態的。當調用.result()時,會計算累計值并返回累計的結果。使用.reset_states()可以清除累積值
avg_loss = tf.keras.metrics.Mean('loss', dtype=tf.float32)
avg_accuracy = tf.keras.metrics.Mean('accuracy', dtype=tf.float32)
for images, labels in train_dataset:
loss, accuracy, logits = self.train_step(images, labels)
# 持續紀律損失值
avg_loss(loss)
# 持續記錄分類正確率
avg_accuracy(accuracy)
# 在訓練log_freq次后,記錄變量,并計算累計指標值
# optimizer.iterations 記錄了優化器運行的訓練步數
if tf.equal(self.optimizer.iterations % log_freq, 0):
# 在日志中寫入變量的摘要
tf.summary.scalar('loss', avg_loss.result(), step=self.optimizer.iterations)
tf.summary.scalar('accuracy', avg_accuracy.result(), step=self.optimizer.iterations)
# 計算完成一個批次訓練所需要的時間
rate = log_freq / (time.time() - start)
print('Step{} Loss: {:0.4f} accuracy: {:0.2f}% ({:0.2f} steps/sec)'.format(self.optimizer.iterations.numpy(), loss, (avg_accuracy.result() * 100), rate))
# 清除當前訓練批次的指標累計值,進入下一訓練批次
avg_loss.reset_states()
avg_accuracy.reset_states()
start = time.time()
train_end = time.time()
print('\nTrain time for epoch: {} ({} total steps): {}'.format(i + 1, self.optimizer.iterations.numpy(), train_end - train_start))
with self.test_summary_writer.as_default():
self.test(test_dataset, self.optimizer.iterations)
# 保存當前epoch的模型參數
self.checkpoint.save(self.checkpoint_prefix)
#在訓練后保存模型會報錯,暫時沒有解決
# self.export_path = os.path.join(MODEL_DIR, 'export')
# tf.saved_model.save(self, self.export_path)
def test(self, test_dataset, step_num):
"""
評估模型在驗證集上的正確率
"""
avg_loss = tf.keras.metrics.Mean('loss', dtype=tf.float32)
avg_accuracy = tf.keras.metrics.Mean('accuracy', dtype=tf.float32)
# 只需要計算前向過程,需要計算相應指標
for (images, labels) in test_dataset:
logits = self.call(images)
avg_loss(self.loss(logits, labels))
avg_accuracy(self.accuracy(logits, labels))
print('Test-Loss: {:0.4f} Test-Accuracy: {:0.2f}%'.format(avg_loss.result(), avg_accuracy.result() * 100))
tf.summary.scalar('loss', avg_loss.result(), step=step_num)
tf.summary.scalar('accuracy', avg_accuracy.result(), step=step_num)
def evaluat(self, test_dataset):
# 模型保存報錯,暫未解決,故無法讀取
# restored_model = tf.saved_model.restore(self.export_path)
# y_predict = restored_model(x_test)
avg_accuracy = tf.keras.metrics.Mean('accuracy', dtype=tf.float32)
for (images, labels) in test_dataset:
logits = self.call(images)
avg_accuracy(self.accuracy(logits, labels))
print('Model accuracy: {:0.2f}%'.format(avg_accuracy.result() * 100))
def forward(self, xs):
"""
完成模型的前向計算,用于實際預測
"""
predictions = self.call(xs)
logits = tf.nn.softmax(predictions)
return logits
使用tf.data.Dataset創建可迭代訪問的數據集,便于按批次進行訓練
# 由于用顯卡進行訓練,不是大顯存請使用較小的Batch Size
val_dataset = tf.data.Dataset.from_tensor_slices((x_new_test.astype(np.float32), y_new_test))
val_dataset = val_dataset.shuffle(10000).batch(1024)
dataset = tf.data.Dataset.from_tensor_slices((x_new_train.astype(np.float32), y_new_train))
dataset = dataset.shuffle(5000).batch(1024)
net = network()
net.train(dataset, val_dataset,1)
Step72550 Loss: 1.5211 accuracy: 94.14% (13.43 steps/sec)
Step72600 Loss: 1.5097 accuracy: 94.07% (15.38 steps/sec)
Step72650 Loss: 1.5188 accuracy: 94.38% (15.44 steps/sec)
Train time for epoch: 1 (72666 total steps): 11.80623173713684
Test-Loss: 1.7723 Test-Accuracy: 68.68%
net.forward(x_new_train[:1])
<tf.Tensor: shape=(1, 10), dtype=float32, numpy=
array([[0.0853368 , 0.0853368 , 0.0853368 , 0.08533724, 0.0853368 ,
0.0853368 , 0.23196831, 0.0853368 , 0.0853368 , 0.0853368 ]],
dtype=float32)>
y_new_train[:1]
array([6])
CIFAR-10數據集
如果有讀者想自己在本地復現,可以參考我的代碼:
CIFAR10-Tensorflow
由于文件較大,可能需要一定時間加載,歡迎大家盡情Star、Fork。