Tensorflow的數(shù)據(jù)讀取有三種方式:
Preloaded data: 預(yù)加載數(shù)據(jù),也就是TensorFlow圖中的常量或變量保留所有數(shù)據(jù)(對(duì)于小數(shù)據(jù)集)。
Feeding: Python產(chǎn)生數(shù)據(jù),再把數(shù)據(jù)喂給后端。
Reading from file: 從文件中直接讀取,輸入流水線從TensorFlow圖開(kāi)頭的文件中讀取數(shù)據(jù)。
Preloaded data: 預(yù)加載數(shù)據(jù)
預(yù)加載數(shù)據(jù)方法僅限于用在可以完全加載到內(nèi)存中的小數(shù)據(jù)集上,主要有兩種方法:
把數(shù)據(jù)存在常量(constant)中。
把數(shù)據(jù)存在變量(variable)中,我們初始化并且永不改變它的值。
用常量更簡(jiǎn)單些,但會(huì)占用更多的內(nèi)存,因?yàn)槌A看鎯?chǔ)在graph數(shù)據(jù)結(jié)構(gòu)內(nèi)部。例如:
import tensorflow as tf
# 構(gòu)造Graph
x1 = tf.constant([2, 3, 4])
x2 = tf.constant([4, 0, 1])
y = tf.add(x1, x2)
# 打開(kāi)一個(gè)session --> 計(jì)算y
with tf.Session() as sess:
print sess.run(y)
這種方法在設(shè)計(jì)Graph的時(shí)候,x1和x2就被定義成了兩個(gè)有值的列表,在計(jì)算y的時(shí)候直接取x1和x2的值。
如果用變量的話(huà),我們需要在graph構(gòu)建好之后初始化該變量。例如:
training_data = ...
training_labels = ...
with tf.Session() as sess:
data_initializer = tf.placeholder(dtype=training_data.dtype,
shape=training_data.shape)
label_initializer = tf.placeholder(dtype=training_labels.dtype,
shape=training_labels.shape)
input_data = tf.Variable(data_initializer, trainable=False, collections=[])
input_labels = tf.Variable(label_initializer, trainable=False, collections=[])
...
sess.run(input_data.initializer,
feed_dict={data_initializer: training_data})
sess.run(input_labels.initializer,
feed_dict={label_initializer: training_labels})
Feeding: 供給數(shù)據(jù)
我們一般用tf.placeholder節(jié)點(diǎn)來(lái)feed數(shù)據(jù),該節(jié)點(diǎn)不需要初始化也不包含任何數(shù)據(jù),我們?cè)趫?zhí)行run()或者eval()指令時(shí)通過(guò)feed_dict參數(shù)把數(shù)據(jù)傳入graph中來(lái)計(jì)算。如果在運(yùn)行過(guò)程中沒(méi)有對(duì)tf.placeholder節(jié)點(diǎn)傳入數(shù)據(jù),程序會(huì)報(bào)錯(cuò)。例如:
import tensorflow as tf
# 設(shè)計(jì)Graph
x1 = tf.placeholder(tf.int16)
x2 = tf.placeholder(tf.int16)
y = tf.add(x1, x2)
# 用Python產(chǎn)生數(shù)據(jù)
li1 = [2, 3, 4]
li2 = [4, 0, 1]
# 打開(kāi)一個(gè)session --> 喂數(shù)據(jù) --> 計(jì)算y
with tf.Session() as sess:
print sess.run(y, feed_dict={x1: li1, x2: li2})
兩種方法的區(qū)別
Preload:
將數(shù)據(jù)直接內(nèi)嵌到Graph中,再把Graph傳入Session中運(yùn)行。當(dāng)數(shù)據(jù)量比較大時(shí),Graph的傳輸會(huì)遇到效率問(wèn)題。
Feeding:
用占位符替代數(shù)據(jù),待運(yùn)行的時(shí)候填充數(shù)據(jù)。
Reading From File 從文件中讀數(shù)據(jù)
前兩種方法很方便,但是遇到大型數(shù)據(jù)的時(shí)候就會(huì)很吃力,即使是Feeding,中間環(huán)節(jié)的增加也是不小的開(kāi)銷(xiāo),比如數(shù)據(jù)類(lèi)型轉(zhuǎn)換等等。最優(yōu)的方案就是在Graph定義好文件讀取的方法,讓TF自己去從文件中讀取數(shù)據(jù),并解碼成可使用的樣本集。從文件中讀取數(shù)據(jù)一般包含以下步驟:
- 文件名列表
- 文件名隨機(jī)排序(可選的)
- 迭代控制(可選的)
- 文件名隊(duì)列
- 針對(duì)輸入文件格式的閱讀器
- 記錄解析器
- 預(yù)處理器(可選的)
- 樣本隊(duì)列
在了解具體的操作之前首先了解文件讀取數(shù)據(jù)的優(yōu)點(diǎn):
在上圖中,首先由一個(gè)單線程把文件名堆入隊(duì)列,兩個(gè)Reader同時(shí)從隊(duì)列中取文件名并讀取數(shù)據(jù),Decoder將讀出的數(shù)據(jù)解碼后堆入樣本隊(duì)列,最后單個(gè)或批量取出樣本(圖中沒(méi)有展示樣本出列)。我們這里通過(guò)三段代碼逐步實(shí)現(xiàn)上圖的數(shù)據(jù)流,這里我們不使用隨機(jī),讓結(jié)果更清晰。
文件準(zhǔn)備
$ echo -e "Alpha1,A1\nAlpha2,A2\nAlpha3,A3" > A.csv
$ echo -e "Bee1,B1\nBee2,B2\nBee3,B3" > B.csv
$ echo -e "Sea1,C1\nSea2,C2\nSea3,C3" > C.csv
$ cat A.csv
Alpha1,A1
Alpha2,A2
Alpha3,A3
單個(gè)Reader,單個(gè)樣本
import tensorflow as tf
# 生成一個(gè)先入先出隊(duì)列和一個(gè)QueueRunner
filenames = ['A.csv', 'B.csv', 'C.csv']
filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
# 定義Reader
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
# 定義Decoder
example, label = tf.decode_csv(value, record_defaults=[['null'], ['null']])
# 運(yùn)行Graph
with tf.Session() as sess:
coord = tf.train.Coordinator() #創(chuàng)建一個(gè)協(xié)調(diào)器,管理線程
threads = tf.train.start_queue_runners(coord=coord) #啟動(dòng)QueueRunner, 此時(shí)文件名隊(duì)列已經(jīng)進(jìn)隊(duì)。
for i in range(10):
print example.eval() #取樣本的時(shí)候,一個(gè)Reader先從文件名隊(duì)列中取出文件名,讀出數(shù)據(jù),Decoder解析后進(jìn)入樣本隊(duì)列。
coord.request_stop()
coord.join(threads)
# outpt
Alpha1
Alpha2
Alpha3
Bee1
Bee2
Bee3
Sea1
Sea2
Sea3
Alpha1
單個(gè)Reader,多個(gè)樣本
import tensorflow as tf
filenames = ['A.csv', 'B.csv', 'C.csv']
filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
example, label = tf.decode_csv(value, record_defaults=[['null'], ['null']])
# 使用tf.train.batch()會(huì)多加了一個(gè)樣本隊(duì)列和一個(gè)QueueRunner。Decoder解碼后數(shù)據(jù)會(huì)進(jìn)入這個(gè)隊(duì)列,再批量出隊(duì)。
# 雖然這里只有一個(gè)Reader,但可以設(shè)置多線程,相應(yīng)增加線程數(shù)會(huì)提高讀取速度,但并不是線程越多越好。
example_batch, label_batch = tf.train.batch(
[example, label], batch_size=5)
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(10):
print example_batch.eval()
coord.request_stop()
coord.join(threads)
# output
# ['Alpha1' 'Alpha2' 'Alpha3' 'Bee1' 'Bee2']
# ['Bee3' 'Sea1' 'Sea2' 'Sea3' 'Alpha1']
# ['Alpha2' 'Alpha3' 'Bee1' 'Bee2' 'Bee3']
# ['Sea1' 'Sea2' 'Sea3' 'Alpha1' 'Alpha2']
# ['Alpha3' 'Bee1' 'Bee2' 'Bee3' 'Sea1']
# ['Sea2' 'Sea3' 'Alpha1' 'Alpha2' 'Alpha3']
# ['Bee1' 'Bee2' 'Bee3' 'Sea1' 'Sea2']
# ['Sea3' 'Alpha1' 'Alpha2' 'Alpha3' 'Bee1']
# ['Bee2' 'Bee3' 'Sea1' 'Sea2' 'Sea3']
# ['Alpha1' 'Alpha2' 'Alpha3' 'Bee1' 'Bee2']
多Reader,多個(gè)樣本
import tensorflow as tf
filenames = ['A.csv', 'B.csv', 'C.csv']
filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
record_defaults = [['null'], ['null']]
example_list = [tf.decode_csv(value, record_defaults=record_defaults)
for _ in range(2)] # Reader設(shè)置為2
# 使用tf.train.batch_join(),可以使用多個(gè)reader,并行讀取數(shù)據(jù)。每個(gè)Reader使用一個(gè)線程。
example_batch, label_batch = tf.train.batch_join(
example_list, batch_size=5)
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(10):
print example_batch.eval()
coord.request_stop()
coord.join(threads)
# output
# ['Alpha1' 'Alpha2' 'Alpha3' 'Bee1' 'Bee2']
# ['Bee3' 'Sea1' 'Sea2' 'Sea3' 'Alpha1']
# ['Alpha2' 'Alpha3' 'Bee1' 'Bee2' 'Bee3']
# ['Sea1' 'Sea2' 'Sea3' 'Alpha1' 'Alpha2']
# ['Alpha3' 'Bee1' 'Bee2' 'Bee3' 'Sea1']
# ['Sea2' 'Sea3' 'Alpha1' 'Alpha2' 'Alpha3']
# ['Bee1' 'Bee2' 'Bee3' 'Sea1' 'Sea2']
# ['Sea3' 'Alpha1' 'Alpha2' 'Alpha3' 'Bee1']
# ['Bee2' 'Bee3' 'Sea1' 'Sea2' 'Sea3']
# ['Alpha1' 'Alpha2' 'Alpha3' 'Bee1' 'Bee2']
tf.train.batch與tf.train.shuffle_batch函數(shù)是單個(gè)Reader讀取,但是可以多線程。tf.train.batch_join與tf.train.shuffle_batch_join可設(shè)置多Reader讀取,每個(gè)Reader使用一個(gè)線程。至于兩種方法的效率,單Reader時(shí),2個(gè)線程就達(dá)到了速度的極限。多Reader時(shí),2個(gè)Reader就達(dá)到了極限。所以并不是線程越多越快,甚至更多的線程反而會(huì)使效率下降。
迭代控制
filenames = ['A.csv', 'B.csv', 'C.csv']
filename_queue = tf.train.string_input_producer(filenames, shuffle=False, num_epochs=3) # num_epoch: 設(shè)置迭代數(shù)
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
record_defaults = [['null'], ['null']]
example_list = [tf.decode_csv(value, record_defaults=record_defaults)
for _ in range(2)]
example_batch, label_batch = tf.train.batch_join(
example_list, batch_size=5)
init_local_op = tf.initialize_local_variables()
with tf.Session() as sess:
sess.run(init_local_op) # 初始化本地變量
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
try:
while not coord.should_stop():
print example_batch.eval()
except tf.errors.OutOfRangeError:
print('Epochs Complete!')
finally:
coord.request_stop()
coord.join(threads)
coord.request_stop()
coord.join(threads)
# output
# ['Alpha1' 'Alpha2' 'Alpha3' 'Bee1' 'Bee2']
# ['Bee3' 'Sea1' 'Sea2' 'Sea3' 'Alpha1']
# ['Alpha2' 'Alpha3' 'Bee1' 'Bee2' 'Bee3']
# ['Sea1' 'Sea2' 'Sea3' 'Alpha1' 'Alpha2']
# ['Alpha3' 'Bee1' 'Bee2' 'Bee3' 'Sea1']
在迭代控制中,記得添加tf.initialize_local_variables(),官網(wǎng)教程沒(méi)有說(shuō)明,但是如果不初始化,運(yùn)行就會(huì)報(bào)錯(cuò)。
下面開(kāi)始正式的步驟:
文件名列表
我們首先要有個(gè)文件名列表,為了產(chǎn)生文件名列表,我們可以手動(dòng)用Python輸入字符串,例如:
["file0", "file1"]
[("file%d" % i) for i in range(2)]
[("file%d" % i) for i in range(2)]
我們也可以用tf.train.match_filenames_once
函數(shù)來(lái)生成文件名列表。
有了文件名列表后,我們需要把它送入 tf.train.string_input_producer
函數(shù)中生成一個(gè)先入先出的文件名隊(duì)列,文件閱讀器需要從該隊(duì)列中讀取文件名。
string_input_producer(
string_tensor,
num_epochs=None,
shuffle=True,
seed=None,
capacity=32,
shared_name=None,
name=None,
cancel_op=None
)
一個(gè)QueueRunner每次會(huì)把每批次的所有文件名送入隊(duì)列中,可以通過(guò)設(shè)置string_input_producer
函數(shù)的shuffle
參數(shù)來(lái)對(duì)文件名隨機(jī)排序,或者通過(guò)設(shè)置num_epochs
來(lái)決定對(duì)string_tensor
里的文件使用多少次,類(lèi)型為整型,如果想要迭代控制則需要設(shè)置了num_epochs
參數(shù),同時(shí)需要添加tf.local_variables_initializer()
進(jìn)行初始化,如果不初始化會(huì)報(bào)錯(cuò)。
這個(gè)QueueRunner
的工作線程獨(dú)立于文件閱讀器的線程, 因此隨機(jī)排序和將文件名送入到文件名隊(duì)列這些過(guò)程不會(huì)阻礙文件閱讀器的運(yùn)行。
文件格式
根據(jù)不同的文件格式, 應(yīng)該選擇對(duì)應(yīng)的文件閱讀器, 然后將文件名隊(duì)列提供給閱讀器的read方法。閱讀器每次從隊(duì)列中讀取一個(gè)文件,它的read方法會(huì)輸出一個(gè)key來(lái)表征讀入的文件和其中的紀(jì)錄(對(duì)于調(diào)試非常有用),同時(shí)得到一個(gè)字符串標(biāo)量, 這個(gè)字符串標(biāo)量可以被一個(gè)或多個(gè)解析器,或者轉(zhuǎn)換操作將其解碼為張量并且構(gòu)造成為樣本。
根據(jù)不同的文件類(lèi)型,有三種不同的文件閱讀器:
tf.TextLineReader
tf.FixedLengthRecordReader
tf.TFRecordReader
它們分別用于單行讀取(如CSV文件)、固定長(zhǎng)度讀取(如CIFAR-10的.bin二進(jìn)制文件)、TensorFlow標(biāo)準(zhǔn)格式讀取。
根據(jù)不同的文件閱讀器,有三種不同的解析器,它們分別對(duì)應(yīng)上面三種閱讀器:
tf.decode_csv
tf.decode_raw
-
tf.parse_single_example
和tf.parse_example
CSV文件
當(dāng)我們讀入CSV格式的文件時(shí),我們可以使用tf.TextLineReader
閱讀器和tf.decode_csv
解析器。例如:
#!/usr/bin/python
# -*- coding: UTF-8 -*-
import tensorflow as tf
import numpy as np
filename_queue = tf.train.string_input_producer(["./data/data1.csv", "./data/data2.csv"])
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
# key返回的是讀取文件和行數(shù)信息 b'./data/iris.csv:146'
# value是按行讀取到的原始字符串,送到下面的decoder去解析
record_defaults = [[1.0], [1.0], [1.0], [1.0], ["Null"]] # 這里的數(shù)據(jù)類(lèi)型決定了讀取的數(shù)據(jù)類(lèi)型,而且必須是list形式
col1, col2, col3, col4, col5 = tf.decode_csv(value, record_defaults=record_defaults) # 解析出的每一個(gè)屬性都是rank為0的標(biāo)量,每次解碼一行,col對(duì)應(yīng)這一行的一列也就是一個(gè)數(shù)字
features = tf.stack([col1, col2, col3, col4])
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(100):
example, label = sess.run([features, col5])
print (example,col5)
coord.request_stop()
coord.join(threads)
每次read的執(zhí)行都會(huì)從文件中讀取一行內(nèi)容,decode_csv
操作會(huì)解析這一行內(nèi)容并將其轉(zhuǎn)為張量列表。在調(diào)用run
或者eval
去執(zhí)行read
之前, 必須先調(diào)用tf.train.start_queue_runners
來(lái)將文件名填充到隊(duì)列。否則read操作會(huì)被阻塞到文件名隊(duì)列中有值為止。
record_defaults = [[1], [1], [1], [1], [1]]
代表了解析的摸版,默認(rèn)用,
隔開(kāi),是用于指定矩陣格式以及數(shù)據(jù)類(lèi)型的,CSV文件中的矩陣是NXM
的,則此處為1XM
,例如上例中M=5
。[1]
表示解析為整型,如果矩陣中有小數(shù),則應(yīng)為float
型,[1]
應(yīng)該變?yōu)?code>[1.0],[‘null’]
解析為string
類(lèi)型。
col1, col2, col3, col4, col5 = tf.decode_csv(value, record_defaults = record_defaults)
, 矩陣中有幾列,這里就要寫(xiě)幾個(gè)參數(shù),比如5
列,就要寫(xiě)到col5
,不管你到底用多少。否則報(bào)錯(cuò)。
固定長(zhǎng)度記錄
我們也可以從二進(jìn)制文件‘(.bin)
中讀取固定長(zhǎng)度的數(shù)據(jù),使用的是tf.FixedLengthRecordReader
閱讀器和tf.decode_raw
解析器。decode_raw
節(jié)點(diǎn)會(huì)把string
轉(zhuǎn)化為uint8
類(lèi)型的張量。
例如CIFAR-10
數(shù)據(jù)集就采用的固定長(zhǎng)度的數(shù)據(jù),1字節(jié)的標(biāo)簽,后面跟著3072字節(jié)的圖像數(shù)據(jù)。使用uint8類(lèi)型張量的標(biāo)準(zhǔn)操作可以把每個(gè)圖像的片段截取下來(lái)并且按照需要重組。下面有一個(gè)例子:
reader = tf.FixedLengthRecordReader(record_bytes = record_bytes)
key, value = reader.read(filename_queue)
record_bytes = tf.decode_raw(value, tf.uint8)
label = tf.cast(tf.slice(record_bytes, [0], [label_bytes]), tf.int32)
image_raw = tf.slice(record_bytes, [label_bytes], [image_bytes])
image_raw = tf.reshape(image_raw, [depth, height, width])
image = tf.transpose(image_raw, (1,2,0)) # 圖像形狀為[height, width, channels]
image = tf.cast(image, tf.float32)
這里介紹上述代碼中出現(xiàn)的函數(shù):tf.slice()
slice(
input_,
begin,
size,
name=None
)
從一個(gè)張量input
中提取出長(zhǎng)度為size
的一部分,提取的起點(diǎn)由begin
定義。size是一個(gè)向量,它代表著在每個(gè)維度提取出的tensor
的大小。begin
表示提取的位置,它表示的是input
的起點(diǎn)偏離值,也就是從每個(gè)維度第幾個(gè)值開(kāi)始提取。
begin
從0開(kāi)始,size
從1開(kāi)始,如果size[i]
的值為-1,則第i個(gè)維度從begin
處到余下的所有值都被提取出來(lái)。
例如:
# 'input' is [[[1, 1, 1], [2, 2, 2]],
# [[3, 3, 3], [4, 4, 4]],
# [[5, 5, 5], [6, 6, 6]]]
tf.slice(input, [1, 0, 0], [1, 1, 3]) ==> [[[3, 3, 3]]]
tf.slice(input, [1, 0, 0], [1, 2, 3]) ==> [[[3, 3, 3],
[4, 4, 4]]]
tf.slice(input, [1, 0, 0], [2, 1, 3]) ==> [[[3, 3, 3]],
[[5, 5, 5]]]
標(biāo)準(zhǔn)TensorFlow格式
我們也可以把任意的數(shù)據(jù)轉(zhuǎn)換為T(mén)ensorFlow所支持的格式, 這種方法使TensorFlow的數(shù)據(jù)集更容易與網(wǎng)絡(luò)應(yīng)用架構(gòu)相匹配。這種方法就是使用TFRecords文件,TFRecords文件包含了tf.train.Example
的protocol buffer(里面包含了名為Features
的字段)。你可以寫(xiě)一段代碼獲取你的數(shù)據(jù), 將數(shù)據(jù)填入到Example
的protocol buffer,將protocol buffer序列化為一個(gè)字符串, 并且通過(guò)tf.python_io.TFRecordWriter
類(lèi)寫(xiě)入到TFRecords文件。
從TFRecords文件中讀取數(shù)據(jù), 可以使用tf.TFRecordReader
閱讀器以及tf.parse_single_example
解析器。parse_single_example
操作可以將Example
protocol buffer解析為張量。 具體可以參考如下例子,把MNIST數(shù)據(jù)集轉(zhuǎn)化為T(mén)FRecords格式:
SparseTensors這種稀疏輸入數(shù)據(jù)類(lèi)型使用隊(duì)列來(lái)處理不是太好。如果要使用SparseTensors你就必須在批處理之后使用tf.parse_example
去解析字符串記錄 (而不是在批處理之前使用tf.parse_single_example
) 。
預(yù)處理
我們可以對(duì)輸入的樣本數(shù)據(jù)進(jìn)行任意的預(yù)處理, 這些預(yù)處理不依賴(lài)于訓(xùn)練參數(shù), 比如數(shù)據(jù)歸一化, 提取隨機(jī)數(shù)據(jù)片,增加噪聲或失真等等。具體可以參考如下對(duì)CIFAR-10處理的例子:
批處理
經(jīng)過(guò)了之前的步驟,在數(shù)據(jù)讀取流程的最后, 我們需要有另一個(gè)隊(duì)列來(lái)批量執(zhí)行輸入樣本的訓(xùn)練,評(píng)估或者推斷。根據(jù)要不要打亂順序,我們常用的有兩個(gè)函數(shù):
tf.train.batch()
tf.train.shuffle_batch()
下面來(lái)分別介紹:
tf.train.batch()
tf.train.batch(
tensors,
batch_size,
num_threads=1,
capacity=32,
enqueue_many=False,
shapes=None,
dynamic_pad=False,
allow_smaller_final_batch=False,
shared_name=None,
name=None
)
該函數(shù)將會(huì)使用一個(gè)隊(duì)列,函數(shù)讀取一定數(shù)量的tensors送入隊(duì)列,然后每次從中選取batch_size個(gè)tensors組成一個(gè)新的tensors返回出來(lái)。
capacity
參數(shù)決定了隊(duì)列的長(zhǎng)度。
num_threads
決定了有多少個(gè)線程進(jìn)行入隊(duì)操作,如果設(shè)置的超過(guò)一個(gè)線程,它們將從不同文件不同位置同時(shí)讀取,可以更加充分的混合訓(xùn)練樣本。
如果enqueue_many
參數(shù)為False
,則輸入?yún)?shù)tensors
為一個(gè)形狀為[x, y, z]
的張量,輸出為一個(gè)形狀為[batch_size, x, y, z]
的張量。如果enqueue_many
參數(shù)為T(mén)rue,則輸入?yún)?shù)tensors
為一個(gè)形狀為[*, x, y, z]
的張量,其中所有*的數(shù)值相同,輸出為一個(gè)形狀為[batch_size, x, y, z]
的張量。
當(dāng)allow_smaller_final_batch
為True
時(shí),如果隊(duì)列中的張量數(shù)量不足batch_size
,將會(huì)返回小于batch_size
長(zhǎng)度的張量,如果為False
,剩下的張量會(huì)被丟棄。
tf.train.shuffle_batch()
tf.train.shuffle_batch(
tensors,
batch_size,
capacity,
min_after_dequeue,
num_threads=1,
seed=None,
enqueue_many=False,
shapes=None,
allow_smaller_final_batch=False,
shared_name=None,
name=None
)
該函數(shù)類(lèi)似于上面的tf.train.batch()
,同樣創(chuàng)建一個(gè)隊(duì)列,主要區(qū)別是會(huì)首先把隊(duì)列中的張量進(jìn)行亂序處理,然后再選取其中的batch_size
個(gè)張量組成一個(gè)新的張量返回。但是新增加了幾個(gè)參數(shù)。
capacity
參數(shù)依然為隊(duì)列的長(zhǎng)度,建議capacity
的取值如下:
min_after_dequeue + (num_threads + a small safety margin) * batch_size
min_after_dequeue
這個(gè)參數(shù)的意思是隊(duì)列中,做dequeue(取數(shù)據(jù))的操作后,線程要保證隊(duì)列中至少剩下min_after_dequeue
個(gè)數(shù)據(jù)。如果min_after_dequeue
設(shè)置的過(guò)少,則即使shuffle
為True
,也達(dá)不到好的混合效果。
假設(shè)你有一個(gè)隊(duì)列,現(xiàn)在里面有m個(gè)數(shù)據(jù),你想要每次隨機(jī)從隊(duì)列中取n個(gè)數(shù)據(jù),則代表先混合了m個(gè)數(shù)據(jù),再?gòu)闹腥∽遪個(gè)。
當(dāng)?shù)谝淮稳∽遪個(gè)后,隊(duì)列就變?yōu)閙-n個(gè)數(shù)據(jù);當(dāng)你下次再想要取n個(gè)時(shí),假設(shè)隊(duì)列在此期間入隊(duì)進(jìn)來(lái)了k個(gè)數(shù)據(jù),則現(xiàn)在的隊(duì)列中有(m-n+k)個(gè)數(shù)據(jù),則此時(shí)會(huì)從混合的(m-n+k)個(gè)數(shù)據(jù)中隨機(jī)取走n個(gè)。
如果隊(duì)列填充的速度比較慢,k就比較小,那你取出來(lái)的n個(gè)數(shù)據(jù)只是與周?chē)苄〉囊徊糠?m-n+k)個(gè)數(shù)據(jù)進(jìn)行了混合。
因?yàn)槲覀兊哪康目隙ㄊ窍氡M最大可能的混合數(shù)據(jù),因此設(shè)置
min_after_dequeue
,可以保證每次dequeue后都有足夠量的數(shù)據(jù)填充盡隊(duì)列,保證下次dequeue時(shí)可以很充分的混合數(shù)據(jù)。但是
min_after_dequeue
也不能設(shè)置的太大,這樣會(huì)導(dǎo)致隊(duì)列填充的時(shí)間變長(zhǎng),尤其是在最初的裝載階段,會(huì)花費(fèi)比較長(zhǎng)的時(shí)間。
其他參數(shù)和tf.train.batch()
相同。
這里我們使用tf.train.shuffle_batch
函數(shù)來(lái)對(duì)隊(duì)列中的樣本進(jìn)行亂序處理。如下的模版:
def read_my_file_format(filename_queue):
reader = tf.SomeReader()
key, record_string = reader.read(filename_queue)
example, label = tf.some_decoder(record_string)
processed_example = some_processing(example)
return processed_example, label
def input_pipeline(filenames, batch_size, num_epochs=None):
filename_queue = tf.train.string_input_producer(
filenames, num_epochs=num_epochs, shuffle=True)
example, label = read_my_file_format(filename_queue)
# min_after_dequeue 越大意味著隨機(jī)效果越好但是也會(huì)占用更多的時(shí)間和內(nèi)存
# capacity 必須比 min_after_dequeue 大
# 建議capacity的取值如下:
# min_after_dequeue + (num_threads + a small safety margin) * batch_size
min_after_dequeue = 10000
capacity = min_after_dequeue + 3 * batch_size
example_batch, label_batch = tf.train.shuffle_batch(
[example, label], batch_size=batch_size, capacity=capacity,
min_after_dequeue=min_after_dequeue)
return example_batch, label_batch```
一個(gè)具體的例子如下,該例采用了CIFAR-10數(shù)據(jù)集,采用了固定長(zhǎng)度讀取的tf.FixedLengthRecordReader
閱讀器和tf.decode_raw
解析器,同時(shí)進(jìn)行了數(shù)據(jù)預(yù)處理操作中的標(biāo)準(zhǔn)化操作,最后使用tf.train.shuffle_batch
函數(shù)批量執(zhí)行數(shù)據(jù)的亂序處理。
class cifar10_data(object):
def __init__(self, filename_queue):
self.height = 32
self.width = 32
self.depth = 3
self.label_bytes = 1
self.image_bytes = self.height * self.width * self.depth
self.record_bytes = self.label_bytes + self.image_bytes
self.label, self.image = self.read_cifar10(filename_queue)
def read_cifar10(self, filename_queue):
reader = tf.FixedLengthRecordReader(record_bytes = self.record_bytes)
key, value = reader.read(filename_queue)
record_bytes = tf.decode_raw(value, tf.uint8)
label = tf.cast(tf.slice(record_bytes, [0], [self.label_bytes]), tf.int32)
image_raw = tf.slice(record_bytes, [self.label_bytes], [self.image_bytes])
image_raw = tf.reshape(image_raw, [self.depth, self.height, self.width])
image = tf.transpose(image_raw, (1,2,0))
image = tf.cast(image, tf.float32)
return label, image
def inputs(data_dir, batch_size, train = True, name = 'input'):
with tf.name_scope(name):
if train:
filenames = [os.path.join(data_dir,'data_batch_%d.bin' % ii)
for ii in range(1,6)]
for f in filenames:
if not tf.gfile.Exists(f):
raise ValueError('Failed to find file: ' + f)
filename_queue = tf.train.string_input_producer(filenames)
read_input = cifar10_data(filename_queue)
images = read_input.image
images = tf.image.per_image_standardization(images)
labels = read_input.label
image, label = tf.train.shuffle_batch(
[images,labels], batch_size = batch_size,
min_after_dequeue = 20000, capacity = 20192)
return image, tf.reshape(label, [batch_size])
else:
filenames = [os.path.join(data_dir,'test_batch.bin')]
for f in filenames:
if not tf.gfile.Exists(f):
raise ValueError('Failed to find file: ' + f)
filename_queue = tf.train.string_input_producer(filenames)
read_input = cifar10_data(filename_queue)
images = read_input.image
images = tf.image.per_image_standardization(images)
labels = read_input.label
image, label = tf.train.shuffle_batch(
[images,labels], batch_size = batch_size,
min_after_dequeue = 20000, capacity = 20192)
return image, tf.reshape(label, [batch_size])
這里介紹下函數(shù)tf.image.per_image_standardization(image),該函數(shù)對(duì)圖像進(jìn)行線性變換使它具有零均值和單位方差,即規(guī)范化。其中參數(shù)image是一個(gè)3-D的張量,形狀為[height, width, channels]。
參考 ZangBo