文件隊列
參考了這篇博客的內(nèi)容
為了實現(xiàn)數(shù)據(jù)讀入和數(shù)據(jù)處理的管線化,tensorflow使用文件隊列來獨立處理數(shù)據(jù)讀入。在設(shè)置好文件隊列,以及相應(yīng)的處理函數(shù)以后,tensorflow會自動調(diào)度,不斷的把文件隊列中的數(shù)據(jù)取出來,并包裝成batch,使我們的訓(xùn)練過程始終有數(shù)據(jù)可用,不存在停下來等數(shù)據(jù)的情況。這很像一個虹吸管,在你架設(shè)好管道以后,數(shù)據(jù)就會源源不斷的流向桶里。實現(xiàn)這個過程主要使用了Tensorflow里面的如下幾個方法:
tf.train.string_input_producer #創(chuàng)建文件隊列
tf.train.shuffle_batch #創(chuàng)建batch
首先,我們使用tf.train.string_input_producer
方法來將文件列表轉(zhuǎn)化為一個隊列:
import tensorflow as tf
import glob
def listfiles(rootpath,ext): #獲得文件列表
return glob.glob(rootpath+'/*.{}'.format(ext))
filelist = listfiles('rootfolder','hdr') #獲得根目錄下所有后綴文件
filelist_queue = tf.train.string_input_producer(filelist,shuffle=True) #利用tf的方法,將filelist(list類型)轉(zhuǎn)化為文件隊列。
其中,文件隊列filelist_queue
中,每一個entry對應(yīng)了一個文件的位置,那么下一步應(yīng)該是將文件隊列中的文件取出來。在很多的博客中都使用了下面的方法來從隊列中讀取一個文件位置對應(yīng)的圖像。
image_reader = tf.WholeFileReader() ##定義一個reader
filename, image_file = image_reader.read(filelist_queue) ##將文件隊列作為參數(shù)輸入,
##每一次執(zhí)行該函數(shù)就讀取隊列中一個entry對應(yīng)文件位置的圖像,返回
##圖像名稱和圖像文件本身
image = tf.image.decode_jpeg(image_file) ##利用tf自帶的jpeg解碼器進(jìn)行解碼,得到圖像矩陣
上面的方法適合用在png,jpg,tiff這樣一些常用的圖像格式上,因為tensorflow提供了image_reader,還有相應(yīng)的decoder,但是如果數(shù)據(jù)不在這些格式里面的話,怎么辦呢?首先想到的是自己定義一個函數(shù),在這個函數(shù)里面讀取需要的文件
我嘗試寫一個函數(shù),將文件列表queue作為參數(shù)傳遞給該函數(shù),然后希望該函數(shù)能夠使用queue.dequeue()方法將文件位置返回出來,但是我發(fā)現(xiàn)使用queue.dequeue()的時候,在函數(shù)里面取出來的仍然是tensor類型,而不是一個文件路徑的字符串。只有使用 content = sess.run(content)之后才能正確的返回這個字符串。我覺得這個可能跟tensorflow的處理機(jī)制有關(guān)。
為了避免浪費更多的時間,這里使用一種折衷的辦法,首先將數(shù)據(jù)轉(zhuǎn)換為tfrecord格式,然后利用tf自帶的對tf的讀取方法來解決這個問題。tfrecord跟mxnet中的數(shù)據(jù)制作過程非常類似,就是利用已經(jīng)有的數(shù)據(jù),制作一個二進(jìn)制文件。我參考了這篇博客的內(nèi)容和這篇博客。
TFRecodrs結(jié)構(gòu)
TFRecords將用戶的數(shù)據(jù)以二進(jìn)制串的方式存儲,這也就意味著你要首先定義數(shù)據(jù)格式。Tensorflow為此提供了兩種格式:tf.train.Examples
和tf.train.SequenceExample
,所有需要存儲到TFRecord中的數(shù)據(jù)都要首先轉(zhuǎn)換為這兩種格式之一,然后使用 tf.python_io.TFRecordWriter
將這些數(shù)據(jù)寫入磁盤。在Tensorflow的官方網(wǎng)站可以看到tf.train.Examples
和tf.train.SequenceExample
其實都是protocol buffer
文件,而不是python類。
對于輸入神經(jīng)網(wǎng)絡(luò)的每個數(shù)據(jù)entry,都由多個feature組成。tf.train.BytesList
, tf.train.FloatList
和tf.train.Int64List
同樣為protocol,他們只有一個屬性attibute
: value。他們的目的是將數(shù)據(jù)轉(zhuǎn)化為相應(yīng)的列表。比如肖申克的救贖和搏擊俱樂部兩部電影是兩個數(shù)據(jù)entry,分別他們的得分為9.0和9.7。那么這些數(shù)據(jù)可以通過下面的方式轉(zhuǎn)化為list:
movie_name_list = tf.train.BytesList(value=[b'The Shawshank Redemption', b'Fight Club'])
movie_rating_list = tf.train.FloatList(value=[9.0, 9.7])
注意:Python strings need to be converted to bytes, (e.g. my_string.encode(‘utf-8’)) before they are stored in a tf.train.BytesList. 所以,即使是在用ByteList去處理一般的8bits圖像數(shù)據(jù)的時候,也需要進(jìn)行img.tobytes()
處理,將圖像處理轉(zhuǎn)化為bytes。
tf.Train.Feature
則將某種特殊類型的數(shù)據(jù)列表包裝成Tensorflow理解的格式,它也只有一個屬性attibute
,即為bytes_list
,float_list
和int64_list
之一。緊接上面的例子,我們可以將movie_name_list
和 movie_rating_list
包裝為Feature
movie_names = tf.train.Feature(bytes_list=movie_name_list)
movie_ratings = tf.train.Feature(float_list=movie_rating_list)
很自然的,F(xiàn)eature可以組成了Features
movie_dict = {
'Movie Names': movie_names,
'Movie Ratings': movie_ratings
}
movies = tf.train.Features(feature=movie_dict)
所有Features都組合完成以后,將其組成一個example,并通過tf.python_io.TFRecordWriter寫入文件中,
example = tf.train.Example(features=movies)
with tf.python_io.TFRecordWriter('movie_ratings.tfrecord') as writer:
writer.write(example.SerializeToString())
總結(jié)起來,就是一個類似包包裝的例子,tf.train.FloastFist-->tf.train.Feature-->tf.train.Features-->tf.train.Example。
TFRecord的讀取
參考了Site1997、yyeqiustu以及這篇文章,Youtube上有一個TensorFlow的官方視頻也做了非常好的講解。這個知乎專欄也進(jìn)行了詳細(xì)的講解。
使用Dataset API主要包含如下三個步驟:
- 載入數(shù)據(jù):為數(shù)據(jù)創(chuàng)建一個Dataset實例
- 創(chuàng)建迭代器:使用Dataset實例來構(gòu)建一個Iterator,這樣就能通過其遍歷數(shù)據(jù)
- 使用數(shù)據(jù):跟模型對接
由于我們的數(shù)據(jù)存儲在tfrecord中,可以使用如下的代碼來創(chuàng)建一個Dataset實例:
dataset = tf.contrib.data.TFRecordDataset(tfrecords)
那么如何從dataset中取出元素呢?方法就是從Dataset實例化一個Iterator,然后對Iterator進(jìn)行迭代
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next() ###one_element 這里是取到的一個tensor
with tf.Session() as sess:
for i in range(5):
print(sess.run(one_element))
但是,這里我們要注意一個問題,在我們制作tfrecord的時候,