TF官網上給出了三種讀取數據的方式:
- Preloaded data: 預加載數據
- Feeding: Python 產生數據,再把數據喂給后端
- Reading from file:從文件中直接讀取
(Ps: 此處參考博客 詳解TF數據讀取有三種方式(next_batch))
(Pps: 文中的代碼均基于Python3.6版本)
TF的核心是用C++寫的,運行快,但是調用不靈活。結合Python和TF,將計算的核心算子和運行框架用C++寫,然后以API的形式提供給Python調用。Python的主要工作是設計計算圖(模型及數據),將設計好的Graph提供給后端執行。簡而言之,TF是Run,Pyhton的角色是Design。
一. Preloaded Data
- constant,常量
- variable,初始化或者后面更新均可
這種數據讀取方式只適合小數據,通常在程序中定義某固定值,如循環次數等,而很少用來讀取訓練數據。
import tensorflow as tf
# 設計Graph
a = tf.constant([1, 2, 3])
b = tf.Variable([1, 2, 4])
c = tf.add(a, b)
二. Feeding
Feeding的方式在設計Graph的時候留占位符,在真正Run的時候向占位符中傳遞數據,喂給后端訓練。
#!/usr/bin/env python3
# _*_coding:utf-8 _*_
import tensorflow as tf
# 設計Graph
a = tf.placeholder(tf.int16)
b = tf.placeholder(tf.int16)
c = tf.add(a, b)
# 用Python產生數據
li1 = [2, 3, 4] # li1:<type:'list'>: [2, 3, 4]
li2 = [4, 0, 1]
# 打開一個session --> 喂數據 --> 計算y
with tf.Session() as sess:
print(sess.run(c, feed_dict={a: li1, b: li2})) # [6, 3, 5]
這里tf.placeholder代表占位符,先定一下變量a的類型。在實際運行的時候,通過feed_dict來指定a在計算中的實際值。
這種數據讀取方式非常靈活,而且易于理解,但是在讀取大數據時會非常吃力。
三. Read from file
官網上給出的例子是從csv等文件中讀取數據,這里都會涉及到隊列的概念, 我們首先簡單介紹一下Queue讀取數據的原理,便于后面代碼的理解。(參考 Blog)
讀取數據其實是為了后續的計算,以圖片為例,假設我們的硬盤中有一個圖片數據集0001.jpg,0002.jpg,0003.jpg……我們只需要把它們讀取到內存中,然后提供給GPU或是CPU進行計算就可以了。這聽起來很容易,但事實遠沒有那么簡單。事實上,我們必須要把數據先讀入后才能進行計算,假設讀入用時0.1s,計算用時0.9s,那么就意味著每過1s,GPU都會有0.1s無事可做,這就大大降低了運算的效率。
隊列的存在就是為了使計算的速度不完全受限于數據讀取的速度,保證有足夠多的數據喂給計算。如圖所示,將數據的讀入和計算分別放在兩個線程中,讀入的數據保存為內存中的一個隊列,負責計算的線程可以源源不斷地從內存隊列中讀取數據。這樣就解決了GPU因為IO而空閑的問題。文件名隊列,我們用tf.train.string_input_producer()
函數創建文件名隊列。
tf.train.string_input_producer(
string_tensor, # 文件名列表
num_epochs=None, # epoch的個數,None代表無限循環
shuffle=True, # 一個epoch內的樣本(文件)順序是否打亂
seed=None, # 當shuffle=True時才用,應該是指定一個打亂順序的入口
capacity=32, # 設置隊列的容量
shared_name=None,
name=None,
cancel_op=None)
ps: 在Tensorflow中,內存隊列不需要我們自己建立,后續只需要使用reader從文件名隊列中讀取數據就可以。
tf.train.string_input_produecer()會將一個隱含的QueueRunner添加到全局圖中(類似的操作還有tf.train.shuffle_batch()等)。由于沒有顯式地返回QueueRunner()來調用create_threads()啟動線程,這里使用了tf.train.start_queue_runners()方法直接啟動tf.GraphKeys.QUEUE_RUNNERS集合中的所有隊列線程。
在我們使用tf.train.string_input_producer創建文件名隊列后,整個系統其實還是處于“停滯狀態”的,也就是說,我們文件名并沒有真正被加入到隊列中(如下圖所示)。此時如果我們開始計算,因為內存隊列中什么也沒有,計算單元就會一直等待,導致整個系統被阻塞。在讀取文件的整個過程中會涉及到:
- 文件名隊列創建: tf.train.string_input_producer()
- 文件閱讀器: tf.TFRecordReader()
- 文件解析器:tf.parse_single_example() 或者decode_csv()
- Batch_size:tf.train.shuffle_batch()
- 填充進程:tf.train.start_queue_runners()
下面我們用python生成數據,并將數據轉換成tfrecord格式,然后讀取tfrecord文件。在這過程中,我們會介紹幾種不同的從文件讀取數據的方法。
生成數據:
#!/usr/bin/env python3
# _*_coding:utf-8 _*_
import os
import numpy as np
'''
二分類問題,樣本數據是形如1,2,5,8,9(1*5)的隨機數,對應標簽是0或1
arg:
data_filename: 路徑下的文件名 'data/data_train.txt'
size: 設定生成樣本數據的size=(10000, 5),其中10000是樣本個數,5是單個樣本的特征。
'''
gene_data = 'data/data_train.txt'
size = (100000, 5)
def generate_data(gene_data, size):
if not os.path.exists(gene_data):
np.random.seed(9)
x_data = np.random.randint(0, 10, size=size)
# 這里設置標簽值一半樣本是0,一半樣本是1
y1_data = np.ones((size[0]//2, 1), int) # 這里需要注意python3和python2的區別。
y2_data = np.zeros((size[0]//2, 1), int) # python2用/得到整數,python3要用//。否則會報錯“'float' object cannot be interpreted as an integer”
y_data = np.append(y1_data, y2_data)
np.random.shuffle(y_data)
# 將樣本和標簽以1 2 3 6 8/1的形式來保存
xy_data = str('')
for xy_row in range(len(x_data)):
x_str = str('')
for xy_col in range(len(x_data[0])):
if not xy_col == (len(x_data[0])-1):
x_str =x_str+str(x_data[xy_row, xy_col])+' '
else:
x_str = x_str + str(x_data[xy_row, xy_col])
y_str = str(y_data[xy_row])
xy_data = xy_data+(x_str+'/'+y_str + '\n')
#print(xy_data[1])
# write to txt 保存成txt格式
write_txt = open(gene_data, 'w')
write_txt.write(xy_data)
write_txt.close()
return
# generate_data(gene_data=gene_data, size=size) # 取消注釋后可以直接生成數據
從txt文件中讀取數據,并轉換成TFrecord格式
tfrecord數據文件是一種將數據和標簽統一存儲的二進制文件,能更好的利用內存,在tensorflow中快速的復制,移動,讀取,存儲等。
TFRecord 文件中的數據是通過 tf.train.Example() 以 Protocol Buffer
(協議緩沖區) 的格式存儲。Protocol Buffer是Google的一種數據交換的格式,他獨立于語言,獨立于平臺,以二進制的形式存在,能更好的利用內存,方便復制和移動。
tf.train.Example()包含Features字段,通過feature將數據和label進行統一封裝, 然后將example協議內存塊轉化為字符串。tf.train.Features()是字典結構,包括字符串格式的key,可以自己定義key。與key對應的是value值,這里需要注意的是,feature的value值只支持列表,可以是字符串(Byteslist),浮點數列表(Floatlist)和整型數列表(int64list),所以,在給value賦值時一定要注意類型將數據轉換為這三種類型的列表。
- 類型為標量:如0,1標簽,轉為列表。 tf.train.Int64List(value=[label])
- 類型為數組:sample = [1, 2, 3],tf.train.Int64List(value=sample)
- 類型為矩陣:sample = [[1, 2, 3], [1, 2 ,3]],
兩種方式:
轉成list類型:將張量fatten成list(向量)
轉成string類型:將張量用.tostring()轉換成string類型。
同時要記得保存形狀信息,在讀取后恢復shape。
'''
讀取txt中的數據,并將數據保存成tfrecord文件
arg:
txt_filename: 是txt保存的路徑+文件名 'data/data_train.txt'
tfrecord_path:tfrecord文件將要保存的路徑及名稱 'data/test_data.tfrecord'
'''
def txt_to_tfrecord(txt_filename=gene_data, tfrecord_path=tfrecord_path):
# 第一步:生成TFRecord Writer
writer = tf.python_io.TFRecordWriter(tfrecord_path)
# 第二步:讀取TXT數據,并分割出樣本數據和標簽
file = open(txt_filename)
for data_line in file.readlines(): # 每一行
data_line = data_line.strip('\n') # 去掉換行符
sample = []
spls = data_line.split('/', 1)[0]# 樣本
for m in spls.split(' '):
sample.append(int(m))
label = data_line.split('/', 1)[1]# 標簽
label = int(label)
# print('sample:', sample, 'labels:', label)
# 第三步: 建立feature字典,tf.train.Feature()對單一數據編碼成feature
feature = {'sample': tf.train.Feature(int64_list=tf.train.Int64List(value=sample)),
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))}
# 第四步:可以理解為將內層多個feature的字典數據再編碼,集成為features
features = tf.train.Features(feature = feature)
# 第五步:將features數據封裝成特定的協議格式
example = tf.train.Example(features=features)
# 第六步:將example數據序列化為字符串
Serialized = example.SerializeToString()
# 第七步:將序列化的字符串數據寫入協議緩沖區
writer.write(Serialized)
# 記得關閉writer和open file的操作
writer.close()
file.close()
return
# txt_to_tfrecord(txt_filename=gene_data, tfrecord_path=tfrecord_path)
所以在上面的程序中我們涉及到了讀取txt文本數據,并將數據寫成tfrecord文件。在網絡訓練過程中數據的讀取通常是對tfrecord文件的操作。
TF讀取tfrecord文件有兩種方式:一種是Queue方式,就是上面介紹的隊列,另外一種是用dataset來讀取。先介紹Queue讀取文件數據的方法
1. Queue方式
Queue讀取數據可以分為兩種:tf.parse_single_example()和tf.parse_example()
(1). tf.parse_single_example()讀取數據
tf.parse_single_example(
serialized, # 張量
features, # 對應寫入的features
name=None,
example_names=None)
'''
用tf.parse_single_example()讀取并解析tfrecord文件
args:
filename_queue: 文件名隊列
shuffle_batch: 判斷在batch的時候是否要打亂順序
if_enq_many: 設定batch中的參數enqueue_many,評估該參數的作用
'''
# 第一步: 建立文件名隊列,可設置Epoch次數
filename_queue = tf.train.string_input_producer([tfrecord_path], num_epochs=3)
def read_single(filename_queue, shuffle_batch, if_enq_many):
# 第二步: 建立閱讀器
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
# 第三步:根據寫入時的格式建立相對應的讀取features
features = {
'sample': tf.FixedLenFeature([5], tf.int64),# 如果不是標量,一定要在這里說明數組的長度
'label': tf.FixedLenFeature([], tf.int64)
}
# 第四步: 用tf.parse_single_example()解析單個EXAMPLE PROTO
Features = tf.parse_single_example(serialized_example, features)
# 第五步:對數據進行后處理
sample = tf.cast(Features['sample'], tf.float32)
label = tf.cast(Features['label'], tf.float32)
# 第六步:生成Batch數據 generate batch
if shuffle_batch: # 打亂數據順序,隨機取樣
sample_single, label_single = tf.train.shuffle_batch([sample, label],
batch_size=2,
capacity=200000,
min_after_dequeue=10000,
num_threads=1,
enqueue_many=if_enq_many)# 主要是為了評估enqueue_many的作用
else: # # 如果不打亂順序則用tf.train.batch(), 輸出隊列按順序組成Batch輸出
sample_single, label_single = tf.train.batch([sample, label],
batch_size=2,
capacity=200000,
min_after_dequeue=10000,
num_threads=1,
enqueue_many = if_enq_many)
return sample_single, label_single
x1_samples, y1_labels = read_single(filename_queue=filename_queue,
shuffle_batch=False, if_enq_many=False)
x2_samples, y2_labels = read_single(filename_queue=filename_queue,
shuffle_batch=True, if_enq_many=False)
print(x1_samples, y1_labels) # 因為是tensor,這里還處于構造tensorflow計算圖的過程,輸出僅僅是shape等,不會是具體的數值。
# 如果想得到具體的數值,必須建立session,是tensor在計算圖中流動起來,也就是用session.run()的方式得到具體的數值。
# 定義初始化變量范圍
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init) # 初始化
# 如果tf.train.string_input_producer([tfrecord_path], num_epochs=3)中num_epochs不為空的化,必須要初始化local變量
sess.run(tf.local_variables_initializer())
coord = tf.train.Coordinator() # 管理線程
threads = tf.train.start_queue_runners(coord=coord) # 文件名開始進入文件名隊列和內存
for i in range(1):
# Queue + tf.parse_single_example()讀取tfrecord文件
X1, Y1 = sess.run([x1_samples, y1_labels])
print('X1: ', X1, 'Y1: ', Y1) # 這里就可以得到tensor具體的數值
X2, Y2 = sess.run([x2_samples, y2_labels])
print('X2: ', X2, 'Y2: ', Y2) # 這里就可以得到tensor具體的數值
coord.request_stop()
coord.join(threads)
Ps
: 如果建立文件名tf.train.string_input_producer([tfrecord_path], num_epochs=3)時, 設置num_epochs為具體的值(不是None)。在初始化的時候必須對local_variables進行初始化sess.run(tf.local_variables_initializer())
。否則會報錯:
OutOfRangeError (see above for traceback): RandomShuffleQueue '_1_shuffle_batch/random_shuffle_queue' is closed and has insufficient elements (requested 2, current size 0)
上面第六步batch前取到的是單個樣本數據,在實際訓練中通常用批量數據來更新參數,設置批量讀取數據的時候有按順序讀取數據的tf.train.batch()
和打亂數據出列順序的tf.train.shuffle_batch()
。假設文本中的數據如圖所示:
X11: [[5. 6. 8. 6. 1.] [6. 4. 8. 1. 8.]] Y11: [1. 1.] #用tf.train.batch()
X21: [[0. 4. 3. 7. 8.] [5. 0. 2. 8. 7.]] Y21: [0. 1.] # 用tf.train.shuffle_batch()
這里需要對tf.train.shuffle_batch()和tf.train.batch()的參數進行說明
tf.train.shuffle_batch(
tensors,
batch_size, # 設置batch_size的大小
capacity, # 設置隊列中最大的數據量,容量。一般要求capacity > min_after_dequeue + num_threads*batch_size
min_after_dequeue, # 隊列中最小的數據量作為隨機取樣的緩沖區。越大,數據混合越充分,認為采樣到的數據更具有隨機性。
# 但是這個值設置太大在初始啟動時,需要給隊列喂足夠多的數據,啟動慢,而且占用內存。
num_threads=1, # 設置線程數
seed=None,
enqueue_many=False, # Whether each tensor in tensor_list is a single example. 在下面單獨說明
shapes=None,
allow_smaller_final_batch=False, # (Optional) Boolean. If True, allow the final batch to be smaller if there are insufficient items left in the queue.
shared_name=None,
name=None)
tf.train.batch(
tensors,
batch_size,
num_threads=1,
capacity=32,
enqueue_many=False,
shapes=None,
dynamic_pad=False,
allow_smaller_final_batch=False,
shared_name=None,
name=None) # 注意:這里沒有min_after_dequeue這個參數
讀取數據的目的是為了訓練網絡,而使用Batch訓練網絡的原因可以解釋為:
深度學習的優化說白了就是梯度下降。每次的參數更新有兩種方式。
- 第一種,遍歷全部數據集算一次損失函數,然后算函數對各個參數的梯度,更新梯度。這種方法每更新一次參數都要把數據集里的所有樣本都看一遍,計算量開銷大,計算速度慢,不支持在線學習,這稱為Batch gradient descent,批梯度下降。
- 另一種,每看一個數據就算一下損失函數,然后求梯度更新參數,這個稱為隨機梯度下降,stochastic gradient descent。這個方法速度比較快,但是收斂性能不太好,可能在最優點附近晃來晃去,hit不到最優點。兩次參數的更新也有可能互相抵消掉,造成目標函數震蕩的比較劇烈。
為了克服兩種方法的缺點,現在一般采用的是一種折中手段,mini-batch gradient decent,小批的梯度下降,這種方法把數據分為若干個批,按批來更新參數,這樣,一個批中的一組數據共同決定了本次梯度的方向,下降起來就不容易跑偏,減少了隨機性。另一方面因為批的樣本數與整個數據集相比小了很多,計算量也不是很大。
個人理解:大Batch_size一是會受限于計算機硬件,另一方面將會降低梯度下降的隨機性。 而小Batch_size收斂速度慢
這里用代碼對enqueue_many這個參數進行理解
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
tensor_list = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
with tf.Session() as sess:
x1 = tf.train.batch(tensor_list, batch_size=3, enqueue_many=False)
x2 = tf.train.batch(tensor_list, batch_size=3, enqueue_many=True)
x3 = tf.train.shuffle_batch(tensor_list, batch_size=3, capacity = 1000, min_after_dequeue=100, num_threads=1, enqueue_many=False)
x4 = tf.train.shuffle_batch(tensor_list, batch_size=3, capacity = 1000, min_after_dequeue=100, num_threads=1, enqueue_many=True)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
print("x1 batch:" + "-" * 10)
print(sess.run(x1))
print("x2 batch:" + "-" * 10)
print(sess.run(x2))
print("x2 batch:" + "-" * 10)
print(sess.run(x2))
print("x3 batch:" + "-" * 10)
print(sess.run(x3))
print("x4 batch:" + "-" * 10)
print(sess.run(x4))
coord.request_stop()
coord.join(threads)
輸出如下:由以上輸出可以看出,當enqueue_many=False(默認值)時,輸出為batch_size*tensor.shape,把輸入tensors看作一個樣本,Batch就是對第一個維度的數據進行重復采樣,將tensor擴展一個維度。
當enqueue_many=True時,tensor是一個樣本,batch_size只是調整樣本中的維度。這里tensor的維度保持不變,只是在最后一個維度上根據batch_size調整了大小。而最后一個維度內的順序是亂序的。
對于shuffle_batch,注意到,第1維(矩陣每一行)上的數據是打亂的,所以從[1, 2, 3, 4]中取到了[2, 4, 4]。
如果輸入的樣本是一個3x6的矩陣。設置batch_size=5,enqueue_many = False時,tensor會被擴展為3x6x5的張量, 并且。當enqueue_many = True時,tensor是3x5,第二個維度上截取size。
這里比較疑惑的是shuffle在這里感覺沒有任何作用???
(2). tf.parse_example()讀取數據
'''
用tf.parse_example()批量讀取數據,據說比tf.parse_single_exaple()讀取數據的速度快(沒有驗證)
args:
filename_queue: 文件名隊列
shuffle_batch: 是否批量讀取數據
if_enq_many: batch時enqueue_many參數的設定,這里主要用于評估該參數的作用
'''
# 第一步: 建立文件名隊列
filename_queue = tf.train.string_input_producer([tfrecord_path])
def read_parse(filename_queue, shuffle_batch, if_enq_many):
# 第二步: 建立閱讀器
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
# 第三步: 設置shuffle_batch
if shuffle_batch:
batch = tf.train.shuffle_batch([serialized_example],
batch_size=3,
capacity=10000,
min_after_dequeue=1000,
num_threads=1,
enqueue_many=if_enq_many)# 主要是為了評估enqueue_many的作用
else:
batch = tf.train.batch([serialized_example],
batch_size=3,
capacity=10000,
num_threads=1,
enqueue_many=if_enq_many)
# 第四步:根據寫入時的格式建立相對應的讀取features
features = {
'sample': tf.FixedLenFeature([5], tf.int64), # 如果不是標量,一定要在這里說明數組的長度
'label': tf.FixedLenFeature([], tf.int64)
}
# 第五步: 用tf.parse_example()解析多個EXAMPLE PROTO
Features = tf.parse_example(batch, features)
# 第六步:對數據進行后處理
samples_parse= tf.cast(Features['sample'], tf.float32)
labels_parse = tf.cast(Features['label'], tf.float32)
return samples_parse, labels_parse
x2_samples, y2_labels = read_parse(filename_queue=filename_queue, shuffle_batch=True, if_enq_many=False)
print(x2_samples, y2_labels)
# 定義初始化變量范圍
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init) # 初始化
coord = tf.train.Coordinator() # 管理線程
threads = tf.train.start_queue_runners(coord=coord) # 文件名開始進入文件名隊列和內存
for i in range(1):
X2, Y2 = sess.run([x2_samples, y2_labels])
print('X2: ', X2, 'Y2: ', Y2)
coord.request_stop()
coord.join(threads)
調試的時候這里碰到一個bug,提示:return處local variable 'samples_parse' referenced before assignment。網上給的解決辦法基本是python在自上而下執行的時候無法區分變量是全局變量還是局部變量。實際上是我在寫第四步/第五步的時候多了縮進,導致沒有定義features。(??:python對縮進敏感)
?? 閱讀器 + 樣本
根據以上例子,假設txt中的數據只有2個樣本,如下圖所示:在建立文件名隊列時,加入這兩個txt文檔的文件名
# 第一步: 建立文件名隊列
filename_queue = tf.train.string_input_producer([tfrecord_path, tfrecord_path1])
(1). 單個閱讀器 + 單個樣本
batch_size=1 (注意:這里先將num_threads設置為1)
sample_single, label_single = tf.train.batch([sample, label],
batch_size=1,
capacity=10000,
num_threads=1,
enqueue_many=if_enq_many)
for i in range(5):
X14, Y14 = sess.run([x14_samples, y14_labels])
print('X14: ', X14, 'Y14: ', Y14)
打印輸出結果為:
('X14: ', array([[8., 2., 6., 8., 1.]], dtype=float32), 'Y14: ', array([0.], dtype=float32))
('X14: ', array([[8., 3., 5., 3., 6.]], dtype=float32), 'Y14: ', array([0.], dtype=float32))
('X14: ', array([[5., 6., 8., 6., 1.]], dtype=float32), 'Y14: ', array([1.], dtype=float32))
('X14: ', array([[6., 4., 8., 1., 8.]], dtype=float32), 'Y14: ', array([1.], dtype=float32))
('X14: ', array([[8., 2., 6., 8., 1.]], dtype=float32), 'Y14: ', array([0.], dtype=float32))
(2). 單個閱讀器 + 多個樣本
batch_size = 3
輸出結果為:
('X14: ', array([[8., 2., 6., 8., 1.],[8., 3., 5., 3., 6.],[5., 6., 8., 6., 1.]], dtype=float32), 'Y14: ', array([0., 0., 1.], dtype=float32))
('X14: ', array([[6., 4., 8., 1., 8.],[5., 6., 8., 6., 1.],[6., 4., 8., 1., 8.]], dtype=float32), 'Y14: ', array([1., 1., 1.], dtype=float32))
('X14: ', array([[8., 2., 6., 8., 1.],[8., 3., 5., 3., 6.],[8., 2., 6., 8., 1.]], dtype=float32), 'Y14: ', array([0., 0., 0.], dtype=float32))
('X14: ', array([[8., 3., 5., 3., 6.],[5., 6., 8., 6., 1.],[6., 4., 8., 1., 8.]], dtype=float32), 'Y14: ', array([0., 1., 1.], dtype=float32))
('X14: ', array([[8., 2., 6., 8., 1.],[8., 3., 5., 3., 6.],[5., 6., 8., 6., 1.]], dtype=float32), 'Y14: ', array([0., 0., 1.], dtype=float32))
(3). 多個閱讀器 + 多個樣本
多閱讀器需要用tf.train.batch_join()或者tf.train.shuffle_batch_join(),對程序作稍微的修改
example_list = [[sample, label] for _ in range(2)] # Reader設置為2
sample_single, label_single = tf.train.batch_join(example_list, batch_size=3)
輸出結果為:
('X14: ', array([[5., 6., 8., 6., 1.],[6., 4., 8., 1., 8.],[8., 2., 6., 8., 1.]], dtype=float32), 'Y14: ', array([1., 1., 0.], dtype=float32))
('X14: ', array([[8., 3., 5., 3., 6.],[8., 2., 6., 8., 1.],[8., 3., 5., 3., 6.]], dtype=float32), 'Y14: ', array([0., 0., 0.], dtype=float32))
('X14: ', array([[5., 6., 8., 6., 1.],[6., 4., 8., 1., 8.],[8., 2., 6., 8., 1.]], dtype=float32), 'Y14: ', array([1., 1., 0.], dtype=float32))
('X14: ', array([[8., 3., 5., 3., 6.],[5., 6., 8., 6., 1.],[6., 4., 8., 1., 8.]], dtype=float32), 'Y14: ', array([0., 1., 1.], dtype=float32))
('X14: ', array([[8., 2., 6., 8., 1.],[8., 3., 5., 3., 6.],[5., 6., 8., 6., 1.]], dtype=float32), 'Y14: ', array([0., 0., 1.], dtype=float32))
從輸出結果來看,單個閱讀器+多個樣本和多個閱讀器+多個樣本在結果呈現時并沒有什么區別,至于對運行速度的影響還有待驗證。
附上對閱讀器進行測試的完整代碼:
# -*- coding: UTF-8 -*-
# !/usr/bin/python3
# Env: python3.6
import tensorflow as tf
import numpy as np
import os
data_filename1 = 'data/data_train1.txt' # 生成txt數據保存路徑
data_filename2 = 'data/data_train2.txt' # 生成txt數據保存路徑
tfrecord_path1 = 'data/test_data1.tfrecord' # tfrecord1文件保存路徑
tfrecord_path2 = 'data/test_data2.tfrecord' # tfrecord2文件保存路徑
############################## 讀取txt文件,并轉為tfrecord文件 ###########################
# every line of data is just as follow: 1 2 3 4 5/1. train data: 1 2 3 4 5, label: 1
def txt_to_tfrecord(txt_filename, tfrecord_path):
# 第一步:生成TFRecord Writer
writer = tf.python_io.TFRecordWriter(tfrecord_path)
# 第二步:讀取TXT數據,并分割出樣本數據和標簽
file = open(txt_filename)
for data_line in file.readlines(): # 每一行
data_line = data_line.strip('\n') # 去掉換行符
sample = []
spls = data_line.split('/', 1)[0] # 樣本
for m in spls.split(' '):
sample.append(int(m))
label = data_line.split('/', 1)[1] # 標簽
label = int(label)
# 第三步: 建立feature字典,tf.train.Feature()對單一數據編碼成feature
feature = {'sample': tf.train.Feature(int64_list=tf.train.Int64List(value=sample)),
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))}
# 第四步:可以理解為將內層多個feature的字典數據再編碼,集成為features
features = tf.train.Features(feature=feature)
# 第五步:將features數據封裝成特定的協議格式
example = tf.train.Example(features=features)
# 第六步:將example數據序列化為字符串
Serialized = example.SerializeToString()
# 第七步:將序列化的字符串數據寫入協議緩沖區
writer.write(Serialized)
# 記得關閉writer和open file的操作
writer.close()
file.close()
return
txt_to_tfrecord(txt_filename=data_filename1, tfrecord_path=tfrecord_path1)
txt_to_tfrecord(txt_filename=data_filename2, tfrecord_path=tfrecord_path2)
# 第一步: 建立文件名隊列
filename_queue = tf.train.string_input_producer([tfrecord_path1, tfrecord_path2])
def read_single(filename_queue, shuffle_batch, if_enq_many):
# 第二步: 建立閱讀器
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
# 第三步:根據寫入時的格式建立相對應的讀取features
features = {
'sample': tf.FixedLenFeature([5], tf.int64), # 如果不是標量,一定要在這里說明數組的長度
'label': tf.FixedLenFeature([], tf.int64)
}
# 第四步: 用tf.parse_single_example()解析單個EXAMPLE PROTO
Features = tf.parse_single_example(serialized_example, features)
# 第五步:對數據進行后處理
sample = tf.cast(Features['sample'], tf.float32)
label = tf.cast(Features['label'], tf.float32)
# 第六步:生成Batch數據 generate batch
if shuffle_batch: # 打亂數據順序,隨機取樣
sample_single, label_single = tf.train.shuffle_batch([sample, label],
batch_size=1,
capacity=10000,
min_after_dequeue=1000,
num_threads=1,
enqueue_many=if_enq_many) # 主要是為了評估enqueue_many的作用
else: # # 如果不打亂順序則用tf.train.batch(), 輸出隊列按順序組成Batch輸出
###################### multi reader, multi samples, please code as below ###############################
'''
example_list = [[sample,label] for _ in range(2)] # Reader設置為2
sample_single, label_single = tf.train.batch_join(example_list, batch_size=3)
'''
####################### single reader, single sample, please set batch_size = 1 #########################
####################### single reader, multi samples, please set batch_size = batch_size ###############
sample_single, label_single = tf.train.batch([sample, label],
batch_size=1,
capacity=10000,
num_threads=1,
enqueue_many=if_enq_many)
return sample_single, label_single
x1_samples, y1_labels = read_single(filename_queue, shuffle_batch=False, if_enq_many=False)
# 定義初始化變量范圍
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init) # 初始化
# 如果tf.train.string_input_producer([tfrecord_path], num_epochs=30)中num_epochs不為空的化,必須要初始化local變量
sess.run(tf.local_variables_initializer())
coord = tf.train.Coordinator() # 管理線程
threads = tf.train.start_queue_runners(coord=coord) # 文件名開始進入文件名隊列和內存
for i in range(5):
# Queue + tf.parse_single_example()讀取tfrecord文件
X1, Y1 = sess.run([x1_samples, y1_labels])
print('X1: ', X1, 'Y1: ', Y1)
# Queue + tf.parse_example()讀取tfrecord文件
coord.request_stop()
coord.join(threads)
2. Dataset + TFrecrods讀取數據
這是目前官網上比較推薦的一種方式,相對于隊列讀取文件的方法,更為簡單。
Dataset API:將數據直接放在graph中進行處理,整體對數據集進行上述數據操作,使代碼更加簡潔
Dataset直接導入比較簡單,這里只是簡單介紹:
dataset = tf.data.Dataset.from_tensor_slices([1,2,3]) # 輸入必須是list
我們重點看dataset讀取tfrecord文件的過程 (關于pipeline的相關信息可以參見博客)
def _parse_function(example_proto): # 解析函數
# 創建解析字典
dics = {
'sample': tf.FixedLenFeature([5], tf.int64), # 如果不是標量,一定要在這里說明數組的長度
'label': tf.FixedLenFeature([], tf.int64)}
# 把序列化樣本和解析字典送入函數里得到解析的樣本
parsed_example = tf.parse_single_example(example_proto, dics)
# 對樣本數據類型的變換
# 這里得到的樣本數據都是向量,如果寫數據的時候對數據進行過reshape操作,可以在這里根據保存的reshape信息,對數據進行還原。
parsed_example['sample'] = tf.cast(parsed_example['sample'], tf.float32)
parsed_example['label'] = tf.cast(parsed_example['label'], tf.float32)
# 返回所有feature
return parsed_example
'''
read_dataset:
arg: tfrecord_path是需要讀取的tfrecord文件路徑,如tfrecord_path = ['test.tfrecord', 'test2.tfrecord'],同上面Queue方式相同,可以同時讀取多個文件
'''
def read_dataset(tfrecord_path = tfrecord_path):
# 第一步:聲明 tf.data.TFRecordDataset
# The tf.data.TFRecordDataset class enables you to stream over the contents of one or more TFRecord files as part of an input pipeline
dataset = tf.data.TFRecordDataset(tfrecord_path)
# 第二步:解析樣本數據。 tfrecord文件記錄的是序列化的樣本,因此需要對樣本進行解析。
# 個人理解:這個解析的過程,是通過上面_parse_function函數建立feature的字典。
# 而dataset.map()是對dataset的統一操作,map操作可以理解為在每一個元素上應用一個函數,所以其輸入是一個函數。
new_dataset = dataset.map(_parse_function)
# 創建獲取數據集中樣本的迭代器
iterator = new_dataset.make_one_shot_iterator()
# 獲得下一個樣本
next_element = iterator.get_next()
return next_element
next_element = read_dataset()
# 建立session,打印輸出,查看數據是否正確
# 定義初始化變量范圍
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init) # 初始化
coord = tf.train.Coordinator() # 管理線程
threads = tf.train.start_queue_runners(coord=coord) # 文件名開始進入文件名隊列和內存
for i in range(5):
print('dataset:', sess.run([next_element['sample'],
next_element['label']]))
coord.request_stop()
coord.join(threads)
輸出結果如下:
('dataset:', [array([5., 6., 8., 6., 1.], dtype=float32), 1.0])
('dataset:', [array([6., 4., 8., 1., 8.], dtype=float32), 1.0])
('dataset:', [array([5., 1., 0., 8., 8.], dtype=float32), 0.0])
('dataset:', [array([8., 2., 6., 8., 1.], dtype=float32), 0.0])
('dataset:', [array([8., 3., 5., 3., 6.], dtype=float32), 0.0])
PS: 這里需要特別特別注意的是當sample 或者 label不是標量,而且長度事先無法獲得的時候怎么創建解析函數。
此時 tf.FixedLenFeature(shape=(), dtype=tf.float32)的 shape 無法指定。
舉例來說: sample.shape=[2,3], 在寫入tfrecord的時候要對矩陣reshape,同時保存值和shape. 如果已經知道sample的長度,在解析函數中可以用上面的tf.FixedLenFeature([6,1], dtype=tf.float32)
來解析。一定一定不能用tf.FixedLenFeature([6], dtype=tf.float32)
。這樣無法還原sample的值,而且會報出各種奇葩錯誤。如果不知道sample的shape,可以用tf.VarLenFeature(dtype=tf.float32)
。由于變長得到的是稀疏矩陣,解析后需要進行轉為密集矩陣的處理。
parsed_example['sample'] = tf.sparse_tensor_to_dense(parsed_example['sample'])
上面的代碼輸出是每次取一個樣本,按順序一個樣本一個樣本出列。如果需要打亂順序,用.shuffle(buffer_size= ) 來打亂順序。其中buffer_size設置成大于數據集匯總樣本數量的值,以保證樣本順序充分打亂。
打亂樣本出列順序
def read_dataset(tfrecord_path = tfrecord_path):
# 聲明讀tfrecord文件
dataset = tf.data.TFRecordDataset(tfrecord_path)
# 建立解析函數
new_dataset = dataset.map(_parse_function)
# 打亂樣本順序
shuffle_dataset = new_dataset.shuffle(buffer_size=20000)
# 數據提前進入隊列
prefetch_dataset = batch_dataset.prefetch(2000) # 會快很多
# 建立迭代器
iterator = prefetch_dataset.make_one_shot_iterator()
# 獲得下一個樣本
next_element = iterator.get_next()
return next_element
輸出的結果是:
('dataset:', [array([5., 1., 1., 7., 5.], dtype=float32), 0.0])
('dataset:', [array([8., 0., 8., 2., 7.], dtype=float32), 1.0])
('dataset:', [array([6., 5., 9., 1., 2.], dtype=float32), 1.0])
('dataset:', [array([9., 9., 4., 0., 5.], dtype=float32), 0.0])
('dataset:', [array([1., 9., 9., 2., 9.], dtype=float32), 0.0])
再運行一次,取到的數據也完全不一樣。已打亂順序,單樣本輸出。
批量輸出樣本:.batch( batch_size )
def read_dataset(tfrecord_path = tfrecord_path):
# 聲明閱讀器
dataset = tf.data.TFRecordDataset(tfrecord_path)
# 建立解析函數
new_dataset = dataset.map(_parse_function)
# 打亂樣本順序
shuffle_dataset = new_dataset.shuffle(buffer_size=20000)
# batch輸出
batch_dataset = shuffle_dataset.batch(2)
# 數據提前進入隊列
prefetch_dataset = batch_dataset.prefetch(2000)
# 建立迭代器
iterator = prefetch_dataset.make_one_shot_iterator()
# 獲得下一個樣本
next_element = iterator.get_next()
return next_element
輸出結果如下:
('dataset:', [array([[1., 4., 6., 2., 5.], [3., 7., 6., 6., 9.]], dtype=float32), array([0., 0.], dtype=float32)])
('dataset:', [array([[8., 2., 2., 6., 3.], [7., 5., 3., 0., 3.]], dtype=float32), array([0., 1.], dtype=float32)])
('dataset:', [array([[2., 8., 9., 5., 7.], [0., 5., 1., 5., 5.]], dtype=float32), array([1., 0.], dtype=float32)])
('dataset:', [array([[0., 8., 1., 6., 0.], [7., 3., 8., 8., 1.]], dtype=float32), array([0., 0.], dtype=float32)])
('dataset:', [array([[2., 4., 9., 8., 9.], [3., 5., 9., 6., 0.]], dtype=float32), array([1., 0.], dtype=float32)])
Epoch: 使用.repeat(num_epochs) 來指定遍歷幾遍數據集
關于Epoch次數,在Queue讀取文件的方式中,是在創建文件名隊列時設定的
filename_queue = tf.train.string_input_producer([tfrecord_path], num_epochs=3)
根據博客中的實驗可知,先取出(樣本總數??num_Epoch)的數據,打亂順序,按照batch_size,無放回的取樣,保證每個樣本都被訪問num_Epoch次。
三種讀取方式的完整代碼
# -*- coding: UTF-8 -*-
# !/usr/bin/python3
# Env: python3.6
import tensorflow as tf
import numpy as np
import os
# path
data_filename = 'data/data_train.txt' # 生成txt數據保存路徑
size = (10000, 5)
tfrecord_path = 'data/test_data.tfrecord' # tfrecord文件保存路徑
#################### 生成txt數據 10000個樣本。########################
def generate_data(data_filename=data_filename, size=size):
if not os.path.exists(data_filename):
np.random.seed(9)
x_data = np.random.randint(0, 10, size=size)
y1_data = np.ones((size[0] // 2, 1), int) # 一半標簽是0,一半是1
y2_data = np.zeros((size[0] // 2, 1), int)
y_data = np.append(y1_data, y2_data)
np.random.shuffle(y_data)
xy_data = str('')
for xy_row in range(len(x_data)):
x_str = str('')
for xy_col in range(len(x_data[0])):
if not xy_col == (len(x_data[0]) - 1):
x_str = x_str + str(x_data[xy_row, xy_col]) + ' '
else:
x_str = x_str + str(x_data[xy_row, xy_col])
y_str = str(y_data[xy_row])
xy_data = xy_data + (x_str + '/' + y_str + '\n')
# write to txt
write_txt = open(data_filename, 'w')
write_txt.write(xy_data)
write_txt.close()
return
################ 讀取txt文件,并轉為tfrecord文件 ###########################
# every line of data is just as follow: 1 2 3 4 5/1. train data: 1 2 3 4 5, label: 1
def txt_to_tfrecord(txt_filename=data_filename, tfrecord_path=tfrecord_path):
# 第一步:生成TFRecord Writer
writer = tf.python_io.TFRecordWriter(tfrecord_path)
# 第二步:讀取TXT數據,并分割出樣本數據和標簽
file = open(txt_filename)
for data_line in file.readlines(): # 每一行
data_line = data_line.strip('\n') # 去掉換行符
sample = []
spls = data_line.split('/', 1)[0] # 樣本
for m in spls.split(' '):
sample.append(int(m))
label = data_line.split('/', 1)[1] # 標簽
label = int(label)
print('sample:', sample, 'labels:', label)
# 第三步: 建立feature字典,tf.train.Feature()對單一數據編碼成feature
feature = {'sample': tf.train.Feature(int64_list=tf.train.Int64List(value=sample)),
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))}
# 第四步:可以理解為將內層多個feature的字典數據再編碼,集成為features
features = tf.train.Features(feature=feature)
# 第五步:將features數據封裝成特定的協議格式
example = tf.train.Example(features=features)
# 第六步:將example數據序列化為字符串
Serialized = example.SerializeToString()
# 第七步:將序列化的字符串數據寫入協議緩沖區
writer.write(Serialized)
# 記得關閉writer和open file的操作
writer.close()
file.close()
return
############### 用Queue方式中的tf.parse_single_example解析tfrecord #########################
# 第一步: 建立文件名隊列
filename_queue = tf.train.string_input_producer([tfrecord_path], num_epochs=30)
def read_single(filename_queue, shuffle_batch, if_enq_many):
# 第二步: 建立閱讀器
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
# 第三步:根據寫入時的格式建立相對應的讀取features
features = {
'sample': tf.FixedLenFeature([5], tf.int64), # 如果不是標量,一定要在這里說明數組的長度
'label': tf.FixedLenFeature([], tf.int64)
}
# 第四步: 用tf.parse_single_example()解析單個EXAMPLE PROTO
Features = tf.parse_single_example(serialized_example, features)
# 第五步:對數據進行后處理
sample = tf.cast(Features['sample'], tf.float32)
label = tf.cast(Features['label'], tf.float32)
# 第六步:生成Batch數據 generate batch
if shuffle_batch: # 打亂數據順序,隨機取樣
sample_single, label_single = tf.train.shuffle_batch([sample, label],
batch_size=2,
capacity=10000,
min_after_dequeue=1000,
num_threads=1,
enqueue_many=if_enq_many) # 主要是為了評估enqueue_many的作用
else: # # 如果不打亂順序則用tf.train.batch(), 輸出隊列按順序組成Batch輸出
'''
example_list = [[sample,label] for _ in range(2)] # Reader設置為2
sample_single, label_single = tf.train.batch_join(example_list, batch_size=1)
'''
sample_single, label_single = tf.train.batch([sample, label],
batch_size=1,
capacity=10000,
num_threads=1,
enqueue_many=if_enq_many)
return sample_single, label_single
############# 用Queue方式中的tf.parse_example解析tfrecord ##################################
def read_parse(filename_queue, shuffle_batch, if_enq_many):
# 第二步: 建立閱讀器
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
# 第三步: 設置shuffle_batch
if shuffle_batch:
batch = tf.train.shuffle_batch([serialized_example],
batch_size=3,
capacity=10000,
min_after_dequeue=1000,
num_threads=1,
enqueue_many=if_enq_many) # 主要是為了評估enqueue_many的作用
else:
batch = tf.train.batch([serialized_example],
batch_size=3,
capacity=10000,
num_threads=1,
enqueue_many=if_enq_many)
# 第四步:根據寫入時的格式建立相對應的讀取features
features = {
'sample': tf.FixedLenFeature([5], tf.int64), # 如果不是標量,一定要在這里說明數組的長度
'label': tf.FixedLenFeature([], tf.int64)
}
# 第五步: 用tf.parse_example()解析多個EXAMPLE PROTO
Features = tf.parse_example(batch, features)
# 第六步:對數據進行后處理
samples_parse = tf.cast(Features['sample'], tf.float32)
labels_parse = tf.cast(Features['label'], tf.float32)
return samples_parse, labels_parse
############### 用Dataset讀取tfrecord文件 ###############################################
# 定義解析函數
def _parse_function(example_proto):
dics = { # 這里沒用default_value,隨后的都是None
'sample': tf.FixedLenFeature([5], tf.int64), # 如果不是標量,一定要在這里說明數組的長度
'label': tf.FixedLenFeature([], tf.int64)}
# 把序列化樣本和解析字典送入函數里得到解析的樣本
parsed_example = tf.parse_single_example(example_proto, dics)
parsed_example['sample'] = tf.cast(parsed_example['sample'], tf.float32)
parsed_example['label'] = tf.cast(parsed_example['label'], tf.float32)
# 返回所有feature
return parsed_example
def read_dataset(tfrecord_path=tfrecord_path):
# 聲明閱讀器
dataset = tf.data.TFRecordDataset(tfrecord_path)
# 建立解析函數,其中num_parallel_calls指定并行線程數
new_dataset = dataset.map(_parse_function, num_parallel_calls=4)
# 打亂樣本順序
shuffle_dataset = new_dataset.shuffle(buffer_size=20000)
# 設置epoch次數為10,這里需要注意的是目前看來只支持先shuffle再repeat的方式
repeat_dataset = shuffle_dataset.repeat(10)
# batch輸出
batch_dataset = repeat_dataset.batch(2)
# 數據提前進入隊列
prefetch_dataset = batch_dataset.prefetch(2000)
# 建立迭代器
iterator = prefetch_dataset.make_one_shot_iterator()
# 獲得下一個樣本
next_element = iterator.get_next()
return next_element
################## 建立graph ####################################
# 生成數據
# generate_data()
# 讀取數據轉為tfrecord文件
# txt_to_tfrecord()
# Queue + tf.parse_single_example()讀取tfrecord文件
x1_samples, y1_labels = read_single(filename_queue, shuffle_batch=True, if_enq_many=False)
# Queue + tf.parse_example()讀取tfrecord文件
x2_samples, y2_labels = read_parse(filename_queue, shuffle_batch=True, if_enq_many=False)
# Dataset讀取數據
next_element = read_dataset()
# 定義初始化變量范圍
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init) # 初始化
# 如果tf.train.string_input_producer([tfrecord_path], num_epochs=30)中num_epochs不為空的化,必須要初始化local變量
sess.run(tf.local_variables_initializer())
coord = tf.train.Coordinator() # 管理線程
threads = tf.train.start_queue_runners(coord=coord) # 文件名開始進入文件名隊列和內存
for i in range(1):
# Queue + tf.parse_single_example()讀取tfrecord文件
X1, Y1 = sess.run([x1_samples, y1_labels])
print('X1: ', X1, 'Y1: ', Y1)
# Queue + tf.parse_example()讀取tfrecord文件
X2, Y2 = sess.run([x2_samples, y2_labels])
print('X2: ', X2, 'Y2: ', Y2)
# Dataset讀取數據
print('dataset:', sess.run([next_element['sample'],
next_element['label']]))
#這里需要注意,每run一次,迭代器會取下一個樣本。
# 如果是 a= sess.run(next_element['sample'])
# b = sess.run(next_element['label']),
# 則a樣本對應的標簽值不是b,b是下一個樣本對應的標簽值。
coord.request_stop()
coord.join(threads)
另外,關于dataset加速的用法,可以參見官網說明
Dataset+TFRecord讀取變長數據
使用dataset中的padded_batch方法來進行
padded_batch(
batch_size,
padded_shapes,
padding_values=None #默認使用各類型數據的默認值,一般使用時可忽略該項
)
參數padded_shapes
指明每條記錄中各成員要pad成的形狀,成員若是scalar,則用[ ],若是list,則用[mx_length],若是array,則用[d1,...,dn],假如各成員的順序是scalar數據、list數據、array數據,則padded_shapes=([], [mx_length], [d1,...,dn]);
例如tfrecord文件中的key是fea
, e.g.fea.shape=[568, 366], 二維,長度變化。fea_shape
=[568,366],一維, label
=[1, 0, 2,0,3,0]一維,長度變化。
再讀取變長數據的時候映射函數應為:
def _parse_function(example_proto):
dics = {
'fea': tf.VarLenFeature(dtype=tf.float32),
'fea_shape': tf.FixedLenFeature(shape=(2,), dtype=tf.int64),
'label': tf.VarLenFeature(dtype=tf.float32)}
parsed_example = tf.parse_single_example(example_proto, dics)
parsed_example['fea'] = tf.sparse_tensor_to_dense(parsed_example['fea'])
parsed_example['label'] = tf.sparse_tensor_to_dense(parsed_example['label'])
parsed_example['label'] = tf.cast(parsed_example['label'], tf.int32)
parsed_example['fea'] = tf.reshape(parsed_example['fea'], parsed_example['fea_shape'])
return parsed_example
利用tf.VarLenFeature()
代替tf.FixedLenFeature(),在后處理中要注意用tf.sparse_tensor_to_dense()
將讀取的變長數據轉為稠密矩陣。
def dataset():
tf_lst = get_tf_list(tf_file_lst)
dataset = tf.data.TFRecordDataset(tf_lst)
new_dataset = dataset.map(_parse_function)
shuffle_dataset = new_dataset.shuffle(buffer_size=20000)
repeat_dataset = shuffle_dataset.repeat(10)
prefetch_dataset = repeat_dataset.prefetch(2000)
batch_dataset = prefetch_dataset.padded_batch(2, padded_shapes={'fea': [None, None], 'fea_shape': [None], 'label': [None]})
iterator = batch_dataset.make_one_shot_iterator()
next_element = iterator.get_next()
return next_element
這里padded_shapes={'fea': [None, None], 'fea_shape': [None], 'label': [None]}
如果報錯 All elements in a batch must have the same rank as the padded shape for component1: expected rank 2 but got element with rank 1
請仔細查看padded_shapes中設置的維度是否正確。如果padded_shapes={'fea': [None, None], 'fea_shape': [None, None]
, 'label': [None]}即fea_shape本來的rank應該是1,但是在pad的時候設置了2,所以報錯。
如果報錯The two structures don't have the same sequence type. Input structure has type <class 'tuple'>, while shallow structure has type <class 'dict'>.
,則可能是padded_shapes定義的格式不對,如定義成了padded_shapes=([None, None],[None],[None])
,請按照字典格式定義pad的方式。