tensorflow教程2:數(shù)據(jù)讀取

Tensorflow的數(shù)據(jù)讀取有三種方式:

Preloaded data: 預(yù)加載數(shù)據(jù),也就是TensorFlow圖中的常量或變量保留所有數(shù)據(jù)(對(duì)于小數(shù)據(jù)集)。
Feeding: Python產(chǎn)生數(shù)據(jù),再把數(shù)據(jù)喂給后端。
Reading from file: 從文件中直接讀取,輸入流水線從TensorFlow圖開(kāi)頭的文件中讀取數(shù)據(jù)。

Preloaded data: 預(yù)加載數(shù)據(jù)

預(yù)加載數(shù)據(jù)方法僅限于用在可以完全加載到內(nèi)存中的小數(shù)據(jù)集上,主要有兩種方法:

把數(shù)據(jù)存在常量(constant)中。
把數(shù)據(jù)存在變量(variable)中,我們初始化并且永不改變它的值。
用常量更簡(jiǎn)單些,但會(huì)占用更多的內(nèi)存,因?yàn)槌A看鎯?chǔ)在graph數(shù)據(jù)結(jié)構(gòu)內(nèi)部。例如:

import tensorflow as tf
# 構(gòu)造Graph
x1 = tf.constant([2, 3, 4])
x2 = tf.constant([4, 0, 1])
y = tf.add(x1, x2)
# 打開(kāi)一個(gè)session --> 計(jì)算y
with tf.Session() as sess:
    print sess.run(y)

這種方法在設(shè)計(jì)Graph的時(shí)候,x1和x2就被定義成了兩個(gè)有值的列表,在計(jì)算y的時(shí)候直接取x1和x2的值。

如果用變量的話(huà),我們需要在graph構(gòu)建好之后初始化該變量。例如:

training_data = ...
training_labels = ...
with tf.Session() as sess:
  data_initializer = tf.placeholder(dtype=training_data.dtype,
                                    shape=training_data.shape)
  label_initializer = tf.placeholder(dtype=training_labels.dtype,
                                     shape=training_labels.shape)
  input_data = tf.Variable(data_initializer, trainable=False, collections=[])
  input_labels = tf.Variable(label_initializer, trainable=False, collections=[])
  ...
  sess.run(input_data.initializer,
           feed_dict={data_initializer: training_data})
  sess.run(input_labels.initializer,
           feed_dict={label_initializer: training_labels})

Feeding: 供給數(shù)據(jù)

我們一般用tf.placeholder節(jié)點(diǎn)來(lái)feed數(shù)據(jù),該節(jié)點(diǎn)不需要初始化也不包含任何數(shù)據(jù),我們?cè)趫?zhí)行run()或者eval()指令時(shí)通過(guò)feed_dict參數(shù)把數(shù)據(jù)傳入graph中來(lái)計(jì)算。如果在運(yùn)行過(guò)程中沒(méi)有對(duì)tf.placeholder節(jié)點(diǎn)傳入數(shù)據(jù),程序會(huì)報(bào)錯(cuò)。例如:

import tensorflow as tf
# 設(shè)計(jì)Graph
x1 = tf.placeholder(tf.int16)
x2 = tf.placeholder(tf.int16)
y = tf.add(x1, x2)
# 用Python產(chǎn)生數(shù)據(jù)
li1 = [2, 3, 4]
li2 = [4, 0, 1]
# 打開(kāi)一個(gè)session --> 喂數(shù)據(jù) --> 計(jì)算y
with tf.Session() as sess:
    print sess.run(y, feed_dict={x1: li1, x2: li2})

兩種方法的區(qū)別

Preload:

將數(shù)據(jù)直接內(nèi)嵌到Graph中,再把Graph傳入Session中運(yùn)行。當(dāng)數(shù)據(jù)量比較大時(shí),Graph的傳輸會(huì)遇到效率問(wèn)題。

Feeding:

用占位符替代數(shù)據(jù),待運(yùn)行的時(shí)候填充數(shù)據(jù)。

Reading From File 從文件中讀數(shù)據(jù)

前兩種方法很方便,但是遇到大型數(shù)據(jù)的時(shí)候就會(huì)很吃力,即使是Feeding,中間環(huán)節(jié)的增加也是不小的開(kāi)銷(xiāo),比如數(shù)據(jù)類(lèi)型轉(zhuǎn)換等等。最優(yōu)的方案就是在Graph定義好文件讀取的方法,讓TF自己去從文件中讀取數(shù)據(jù),并解碼成可使用的樣本集。從文件中讀取數(shù)據(jù)一般包含以下步驟:

  • 文件名列表
  • 文件名隨機(jī)排序(可選的)
  • 迭代控制(可選的)
  • 文件名隊(duì)列
  • 針對(duì)輸入文件格式的閱讀器
  • 記錄解析器
  • 預(yù)處理器(可選的)
  • 樣本隊(duì)列

在了解具體的操作之前首先了解文件讀取數(shù)據(jù)的優(yōu)點(diǎn):


AnimatedFileQueues.gif

在上圖中,首先由一個(gè)單線程把文件名堆入隊(duì)列,兩個(gè)Reader同時(shí)從隊(duì)列中取文件名并讀取數(shù)據(jù),Decoder將讀出的數(shù)據(jù)解碼后堆入樣本隊(duì)列,最后單個(gè)或批量取出樣本(圖中沒(méi)有展示樣本出列)。我們這里通過(guò)三段代碼逐步實(shí)現(xiàn)上圖的數(shù)據(jù)流,這里我們不使用隨機(jī),讓結(jié)果更清晰。

文件準(zhǔn)備

$ echo -e "Alpha1,A1\nAlpha2,A2\nAlpha3,A3" > A.csv
$ echo -e "Bee1,B1\nBee2,B2\nBee3,B3" > B.csv
$ echo -e "Sea1,C1\nSea2,C2\nSea3,C3" > C.csv
$ cat A.csv
Alpha1,A1
Alpha2,A2
Alpha3,A3

單個(gè)Reader,單個(gè)樣本

import tensorflow as tf
# 生成一個(gè)先入先出隊(duì)列和一個(gè)QueueRunner
filenames = ['A.csv', 'B.csv', 'C.csv']
filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
# 定義Reader
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
# 定義Decoder
example, label = tf.decode_csv(value, record_defaults=[['null'], ['null']])
# 運(yùn)行Graph
with tf.Session() as sess:
    coord = tf.train.Coordinator()  #創(chuàng)建一個(gè)協(xié)調(diào)器,管理線程
    threads = tf.train.start_queue_runners(coord=coord)  #啟動(dòng)QueueRunner, 此時(shí)文件名隊(duì)列已經(jīng)進(jìn)隊(duì)。
    for i in range(10):
        print example.eval()   #取樣本的時(shí)候,一個(gè)Reader先從文件名隊(duì)列中取出文件名,讀出數(shù)據(jù),Decoder解析后進(jìn)入樣本隊(duì)列。
    coord.request_stop()
    coord.join(threads)
# outpt
Alpha1
Alpha2
Alpha3
Bee1
Bee2
Bee3
Sea1
Sea2
Sea3
Alpha1

單個(gè)Reader,多個(gè)樣本

import tensorflow as tf
filenames = ['A.csv', 'B.csv', 'C.csv']
filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
example, label = tf.decode_csv(value, record_defaults=[['null'], ['null']])
# 使用tf.train.batch()會(huì)多加了一個(gè)樣本隊(duì)列和一個(gè)QueueRunner。Decoder解碼后數(shù)據(jù)會(huì)進(jìn)入這個(gè)隊(duì)列,再批量出隊(duì)。
# 雖然這里只有一個(gè)Reader,但可以設(shè)置多線程,相應(yīng)增加線程數(shù)會(huì)提高讀取速度,但并不是線程越多越好。
example_batch, label_batch = tf.train.batch(
      [example, label], batch_size=5)
with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    for i in range(10):
        print example_batch.eval()
    coord.request_stop()
    coord.join(threads)
# output
# ['Alpha1' 'Alpha2' 'Alpha3' 'Bee1' 'Bee2']
# ['Bee3' 'Sea1' 'Sea2' 'Sea3' 'Alpha1']
# ['Alpha2' 'Alpha3' 'Bee1' 'Bee2' 'Bee3']
# ['Sea1' 'Sea2' 'Sea3' 'Alpha1' 'Alpha2']
# ['Alpha3' 'Bee1' 'Bee2' 'Bee3' 'Sea1']
# ['Sea2' 'Sea3' 'Alpha1' 'Alpha2' 'Alpha3']
# ['Bee1' 'Bee2' 'Bee3' 'Sea1' 'Sea2']
# ['Sea3' 'Alpha1' 'Alpha2' 'Alpha3' 'Bee1']
# ['Bee2' 'Bee3' 'Sea1' 'Sea2' 'Sea3']
# ['Alpha1' 'Alpha2' 'Alpha3' 'Bee1' 'Bee2']

多Reader,多個(gè)樣本

import tensorflow as tf
filenames = ['A.csv', 'B.csv', 'C.csv']
filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
record_defaults = [['null'], ['null']]
example_list = [tf.decode_csv(value, record_defaults=record_defaults)
                  for _ in range(2)]  # Reader設(shè)置為2
# 使用tf.train.batch_join(),可以使用多個(gè)reader,并行讀取數(shù)據(jù)。每個(gè)Reader使用一個(gè)線程。
example_batch, label_batch = tf.train.batch_join(
      example_list, batch_size=5)
with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    for i in range(10):
        print example_batch.eval()
    coord.request_stop()
    coord.join(threads)
    
# output
# ['Alpha1' 'Alpha2' 'Alpha3' 'Bee1' 'Bee2']
# ['Bee3' 'Sea1' 'Sea2' 'Sea3' 'Alpha1']
# ['Alpha2' 'Alpha3' 'Bee1' 'Bee2' 'Bee3']
# ['Sea1' 'Sea2' 'Sea3' 'Alpha1' 'Alpha2']
# ['Alpha3' 'Bee1' 'Bee2' 'Bee3' 'Sea1']
# ['Sea2' 'Sea3' 'Alpha1' 'Alpha2' 'Alpha3']
# ['Bee1' 'Bee2' 'Bee3' 'Sea1' 'Sea2']
# ['Sea3' 'Alpha1' 'Alpha2' 'Alpha3' 'Bee1']
# ['Bee2' 'Bee3' 'Sea1' 'Sea2' 'Sea3']
# ['Alpha1' 'Alpha2' 'Alpha3' 'Bee1' 'Bee2']

tf.train.batch與tf.train.shuffle_batch函數(shù)是單個(gè)Reader讀取,但是可以多線程。tf.train.batch_join與tf.train.shuffle_batch_join可設(shè)置多Reader讀取,每個(gè)Reader使用一個(gè)線程。至于兩種方法的效率,單Reader時(shí),2個(gè)線程就達(dá)到了速度的極限。多Reader時(shí),2個(gè)Reader就達(dá)到了極限。所以并不是線程越多越快,甚至更多的線程反而會(huì)使效率下降。

迭代控制

filenames = ['A.csv', 'B.csv', 'C.csv']
filename_queue = tf.train.string_input_producer(filenames, shuffle=False, num_epochs=3)  # num_epoch: 設(shè)置迭代數(shù)
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
record_defaults = [['null'], ['null']]
example_list = [tf.decode_csv(value, record_defaults=record_defaults)
                  for _ in range(2)]
example_batch, label_batch = tf.train.batch_join(
      example_list, batch_size=5)
init_local_op = tf.initialize_local_variables()
with tf.Session() as sess:
    sess.run(init_local_op)   # 初始化本地變量 
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        while not coord.should_stop():
            print example_batch.eval()
    except tf.errors.OutOfRangeError:
        print('Epochs Complete!')
    finally:
        coord.request_stop()
    coord.join(threads)
    coord.request_stop()
    coord.join(threads)
# output
# ['Alpha1' 'Alpha2' 'Alpha3' 'Bee1' 'Bee2']
# ['Bee3' 'Sea1' 'Sea2' 'Sea3' 'Alpha1']
# ['Alpha2' 'Alpha3' 'Bee1' 'Bee2' 'Bee3']
# ['Sea1' 'Sea2' 'Sea3' 'Alpha1' 'Alpha2']
# ['Alpha3' 'Bee1' 'Bee2' 'Bee3' 'Sea1']

在迭代控制中,記得添加tf.initialize_local_variables(),官網(wǎng)教程沒(méi)有說(shuō)明,但是如果不初始化,運(yùn)行就會(huì)報(bào)錯(cuò)。

下面開(kāi)始正式的步驟:

文件名列表

文件名列表.jpg

我們首先要有個(gè)文件名列表,為了產(chǎn)生文件名列表,我們可以手動(dòng)用Python輸入字符串,例如:

["file0", "file1"]
[("file%d" % i) for i in range(2)]
[("file%d" % i) for i in range(2)]

我們也可以用tf.train.match_filenames_once函數(shù)來(lái)生成文件名列表。

有了文件名列表后,我們需要把它送入 tf.train.string_input_producer函數(shù)中生成一個(gè)先入先出的文件名隊(duì)列,文件閱讀器需要從該隊(duì)列中讀取文件名。

string_input_producer(
    string_tensor,
    num_epochs=None,
    shuffle=True,
    seed=None,
    capacity=32,
    shared_name=None,
    name=None,
    cancel_op=None
)

一個(gè)QueueRunner每次會(huì)把每批次的所有文件名送入隊(duì)列中,可以通過(guò)設(shè)置string_input_producer函數(shù)的shuffle參數(shù)來(lái)對(duì)文件名隨機(jī)排序,或者通過(guò)設(shè)置num_epochs來(lái)決定對(duì)string_tensor里的文件使用多少次,類(lèi)型為整型,如果想要迭代控制則需要設(shè)置了num_epochs參數(shù),同時(shí)需要添加tf.local_variables_initializer()進(jìn)行初始化,如果不初始化會(huì)報(bào)錯(cuò)。
這個(gè)QueueRunner的工作線程獨(dú)立于文件閱讀器的線程, 因此隨機(jī)排序和將文件名送入到文件名隊(duì)列這些過(guò)程不會(huì)阻礙文件閱讀器的運(yùn)行。

文件格式

根據(jù)不同的文件格式, 應(yīng)該選擇對(duì)應(yīng)的文件閱讀器, 然后將文件名隊(duì)列提供給閱讀器的read方法。閱讀器每次從隊(duì)列中讀取一個(gè)文件,它的read方法會(huì)輸出一個(gè)key來(lái)表征讀入的文件和其中的紀(jì)錄(對(duì)于調(diào)試非常有用),同時(shí)得到一個(gè)字符串標(biāo)量, 這個(gè)字符串標(biāo)量可以被一個(gè)或多個(gè)解析器,或者轉(zhuǎn)換操作將其解碼為張量并且構(gòu)造成為樣本。
根據(jù)不同的文件類(lèi)型,有三種不同的文件閱讀器:

  • tf.TextLineReader
  • tf.FixedLengthRecordReader
  • tf.TFRecordReader

它們分別用于單行讀取(如CSV文件)、固定長(zhǎng)度讀取(如CIFAR-10的.bin二進(jìn)制文件)、TensorFlow標(biāo)準(zhǔn)格式讀取。

根據(jù)不同的文件閱讀器,有三種不同的解析器,它們分別對(duì)應(yīng)上面三種閱讀器:

  • tf.decode_csv
  • tf.decode_raw
  • tf.parse_single_exampletf.parse_example

CSV文件

當(dāng)我們讀入CSV格式的文件時(shí),我們可以使用tf.TextLineReader閱讀器和tf.decode_csv解析器。例如:

#!/usr/bin/python
# -*- coding: UTF-8 -*-
import tensorflow as tf
import numpy as np

filename_queue = tf.train.string_input_producer(["./data/data1.csv", "./data/data2.csv"])

reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
# key返回的是讀取文件和行數(shù)信息 b'./data/iris.csv:146'
# value是按行讀取到的原始字符串,送到下面的decoder去解析

record_defaults = [[1.0], [1.0], [1.0], [1.0], ["Null"]] # 這里的數(shù)據(jù)類(lèi)型決定了讀取的數(shù)據(jù)類(lèi)型,而且必須是list形式
col1, col2, col3, col4, col5 = tf.decode_csv(value, record_defaults=record_defaults) # 解析出的每一個(gè)屬性都是rank為0的標(biāo)量,每次解碼一行,col對(duì)應(yīng)這一行的一列也就是一個(gè)數(shù)字
features = tf.stack([col1, col2, col3, col4])

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    for i in range(100):
        example, label = sess.run([features, col5])
        print (example,col5)
    coord.request_stop()
    coord.join(threads)   

每次read的執(zhí)行都會(huì)從文件中讀取一行內(nèi)容,decode_csv操作會(huì)解析這一行內(nèi)容并將其轉(zhuǎn)為張量列表。在調(diào)用run或者eval去執(zhí)行read之前, 必須先調(diào)用tf.train.start_queue_runners來(lái)將文件名填充到隊(duì)列。否則read操作會(huì)被阻塞到文件名隊(duì)列中有值為止。

record_defaults = [[1], [1], [1], [1], [1]]代表了解析的摸版,默認(rèn)用,隔開(kāi),是用于指定矩陣格式以及數(shù)據(jù)類(lèi)型的,CSV文件中的矩陣是NXM的,則此處為1XM,例如上例中M=5[1]表示解析為整型,如果矩陣中有小數(shù),則應(yīng)為float型,[1]應(yīng)該變?yōu)?code>[1.0],[‘null’]解析為string類(lèi)型。

col1, col2, col3, col4, col5 = tf.decode_csv(value, record_defaults = record_defaults), 矩陣中有幾列,這里就要寫(xiě)幾個(gè)參數(shù),比如5列,就要寫(xiě)到col5,不管你到底用多少。否則報(bào)錯(cuò)。

固定長(zhǎng)度記錄

我們也可以從二進(jìn)制文件‘(.bin)中讀取固定長(zhǎng)度的數(shù)據(jù),使用的是tf.FixedLengthRecordReader閱讀器和tf.decode_raw解析器。decode_raw節(jié)點(diǎn)會(huì)把string轉(zhuǎn)化為uint8類(lèi)型的張量。

例如CIFAR-10數(shù)據(jù)集就采用的固定長(zhǎng)度的數(shù)據(jù),1字節(jié)的標(biāo)簽,后面跟著3072字節(jié)的圖像數(shù)據(jù)。使用uint8類(lèi)型張量的標(biāo)準(zhǔn)操作可以把每個(gè)圖像的片段截取下來(lái)并且按照需要重組。下面有一個(gè)例子:

reader = tf.FixedLengthRecordReader(record_bytes = record_bytes)
key, value = reader.read(filename_queue)
record_bytes = tf.decode_raw(value, tf.uint8)
label = tf.cast(tf.slice(record_bytes, [0], [label_bytes]), tf.int32)
image_raw = tf.slice(record_bytes, [label_bytes], [image_bytes])
image_raw = tf.reshape(image_raw, [depth, height, width])
image = tf.transpose(image_raw, (1,2,0)) # 圖像形狀為[height, width, channels]     
image = tf.cast(image, tf.float32)

這里介紹上述代碼中出現(xiàn)的函數(shù):tf.slice()

slice(
    input_,
    begin,
    size,
    name=None
)

從一個(gè)張量input中提取出長(zhǎng)度為size的一部分,提取的起點(diǎn)由begin定義。size是一個(gè)向量,它代表著在每個(gè)維度提取出的tensor的大小。begin表示提取的位置,它表示的是input的起點(diǎn)偏離值,也就是從每個(gè)維度第幾個(gè)值開(kāi)始提取。

begin從0開(kāi)始,size從1開(kāi)始,如果size[i]的值為-1,則第i個(gè)維度從begin處到余下的所有值都被提取出來(lái)。

例如:

# 'input' is [[[1, 1, 1], [2, 2, 2]],
#             [[3, 3, 3], [4, 4, 4]],
#             [[5, 5, 5], [6, 6, 6]]]
tf.slice(input, [1, 0, 0], [1, 1, 3]) ==> [[[3, 3, 3]]]
tf.slice(input, [1, 0, 0], [1, 2, 3]) ==> [[[3, 3, 3],
                                            [4, 4, 4]]]
tf.slice(input, [1, 0, 0], [2, 1, 3]) ==> [[[3, 3, 3]],
                                           [[5, 5, 5]]]

標(biāo)準(zhǔn)TensorFlow格式

我們也可以把任意的數(shù)據(jù)轉(zhuǎn)換為T(mén)ensorFlow所支持的格式, 這種方法使TensorFlow的數(shù)據(jù)集更容易與網(wǎng)絡(luò)應(yīng)用架構(gòu)相匹配。這種方法就是使用TFRecords文件,TFRecords文件包含了tf.train.Exampleprotocol buffer(里面包含了名為Features的字段)。你可以寫(xiě)一段代碼獲取你的數(shù)據(jù), 將數(shù)據(jù)填入到Exampleprotocol buffer,將protocol buffer序列化為一個(gè)字符串, 并且通過(guò)tf.python_io.TFRecordWriter類(lèi)寫(xiě)入到TFRecords文件。

從TFRecords文件中讀取數(shù)據(jù), 可以使用tf.TFRecordReader閱讀器以及tf.parse_single_example解析器。parse_single_example操作可以將Exampleprotocol buffer解析為張量。 具體可以參考如下例子,把MNIST數(shù)據(jù)集轉(zhuǎn)化為T(mén)FRecords格式:

SparseTensors這種稀疏輸入數(shù)據(jù)類(lèi)型使用隊(duì)列來(lái)處理不是太好。如果要使用SparseTensors你就必須在批處理之后使用tf.parse_example去解析字符串記錄 (而不是在批處理之前使用tf.parse_single_example) 。

預(yù)處理

我們可以對(duì)輸入的樣本數(shù)據(jù)進(jìn)行任意的預(yù)處理, 這些預(yù)處理不依賴(lài)于訓(xùn)練參數(shù), 比如數(shù)據(jù)歸一化, 提取隨機(jī)數(shù)據(jù)片,增加噪聲或失真等等。具體可以參考如下對(duì)CIFAR-10處理的例子:

批處理

經(jīng)過(guò)了之前的步驟,在數(shù)據(jù)讀取流程的最后, 我們需要有另一個(gè)隊(duì)列來(lái)批量執(zhí)行輸入樣本的訓(xùn)練,評(píng)估或者推斷。根據(jù)要不要打亂順序,我們常用的有兩個(gè)函數(shù):

  • tf.train.batch()
  • tf.train.shuffle_batch()

下面來(lái)分別介紹:

tf.train.batch()

tf.train.batch(
   tensors,
   batch_size,
   num_threads=1,
   capacity=32,
   enqueue_many=False,
   shapes=None,
   dynamic_pad=False,
   allow_smaller_final_batch=False,
   shared_name=None,
   name=None
)

該函數(shù)將會(huì)使用一個(gè)隊(duì)列,函數(shù)讀取一定數(shù)量的tensors送入隊(duì)列,然后每次從中選取batch_size個(gè)tensors組成一個(gè)新的tensors返回出來(lái)。

capacity參數(shù)決定了隊(duì)列的長(zhǎng)度。

num_threads決定了有多少個(gè)線程進(jìn)行入隊(duì)操作,如果設(shè)置的超過(guò)一個(gè)線程,它們將從不同文件不同位置同時(shí)讀取,可以更加充分的混合訓(xùn)練樣本。

如果enqueue_many參數(shù)為False,則輸入?yún)?shù)tensors為一個(gè)形狀為[x, y, z]的張量,輸出為一個(gè)形狀為[batch_size, x, y, z]的張量。如果enqueue_many參數(shù)為T(mén)rue,則輸入?yún)?shù)tensors為一個(gè)形狀為[*, x, y, z]的張量,其中所有*的數(shù)值相同,輸出為一個(gè)形狀為[batch_size, x, y, z]的張量。

當(dāng)allow_smaller_final_batchTrue時(shí),如果隊(duì)列中的張量數(shù)量不足batch_size,將會(huì)返回小于batch_size長(zhǎng)度的張量,如果為False,剩下的張量會(huì)被丟棄。

tf.train.shuffle_batch()

tf.train.shuffle_batch(
    tensors,
    batch_size,
    capacity,
    min_after_dequeue,
    num_threads=1,
    seed=None,
    enqueue_many=False,
    shapes=None,
    allow_smaller_final_batch=False,
    shared_name=None,
    name=None
)

該函數(shù)類(lèi)似于上面的tf.train.batch(),同樣創(chuàng)建一個(gè)隊(duì)列,主要區(qū)別是會(huì)首先把隊(duì)列中的張量進(jìn)行亂序處理,然后再選取其中的batch_size個(gè)張量組成一個(gè)新的張量返回。但是新增加了幾個(gè)參數(shù)。

capacity參數(shù)依然為隊(duì)列的長(zhǎng)度,建議capacity的取值如下:

min_after_dequeue + (num_threads + a small safety margin) * batch_size

min_after_dequeue這個(gè)參數(shù)的意思是隊(duì)列中,做dequeue(取數(shù)據(jù))的操作后,線程要保證隊(duì)列中至少剩下min_after_dequeue個(gè)數(shù)據(jù)。如果min_after_dequeue設(shè)置的過(guò)少,則即使shuffleTrue,也達(dá)不到好的混合效果。

假設(shè)你有一個(gè)隊(duì)列,現(xiàn)在里面有m個(gè)數(shù)據(jù),你想要每次隨機(jī)從隊(duì)列中取n個(gè)數(shù)據(jù),則代表先混合了m個(gè)數(shù)據(jù),再?gòu)闹腥∽遪個(gè)。

當(dāng)?shù)谝淮稳∽遪個(gè)后,隊(duì)列就變?yōu)閙-n個(gè)數(shù)據(jù);當(dāng)你下次再想要取n個(gè)時(shí),假設(shè)隊(duì)列在此期間入隊(duì)進(jìn)來(lái)了k個(gè)數(shù)據(jù),則現(xiàn)在的隊(duì)列中有(m-n+k)個(gè)數(shù)據(jù),則此時(shí)會(huì)從混合的(m-n+k)個(gè)數(shù)據(jù)中隨機(jī)取走n個(gè)。

如果隊(duì)列填充的速度比較慢,k就比較小,那你取出來(lái)的n個(gè)數(shù)據(jù)只是與周?chē)苄〉囊徊糠?m-n+k)個(gè)數(shù)據(jù)進(jìn)行了混合。

因?yàn)槲覀兊哪康目隙ㄊ窍氡M最大可能的混合數(shù)據(jù),因此設(shè)置min_after_dequeue,可以保證每次dequeue后都有足夠量的數(shù)據(jù)填充盡隊(duì)列,保證下次dequeue時(shí)可以很充分的混合數(shù)據(jù)。

但是min_after_dequeue也不能設(shè)置的太大,這樣會(huì)導(dǎo)致隊(duì)列填充的時(shí)間變長(zhǎng),尤其是在最初的裝載階段,會(huì)花費(fèi)比較長(zhǎng)的時(shí)間。

其他參數(shù)和tf.train.batch()相同。

這里我們使用tf.train.shuffle_batch函數(shù)來(lái)對(duì)隊(duì)列中的樣本進(jìn)行亂序處理。如下的模版:

def read_my_file_format(filename_queue):
  reader = tf.SomeReader()
  key, record_string = reader.read(filename_queue)
  example, label = tf.some_decoder(record_string)
  processed_example = some_processing(example)
  return processed_example, label
def input_pipeline(filenames, batch_size, num_epochs=None):
  filename_queue = tf.train.string_input_producer(
      filenames, num_epochs=num_epochs, shuffle=True)
  example, label = read_my_file_format(filename_queue)
  # min_after_dequeue 越大意味著隨機(jī)效果越好但是也會(huì)占用更多的時(shí)間和內(nèi)存
  # capacity 必須比 min_after_dequeue 大
  # 建議capacity的取值如下:
  # min_after_dequeue + (num_threads + a small safety margin) * batch_size
  min_after_dequeue = 10000
  capacity = min_after_dequeue + 3 * batch_size
  example_batch, label_batch = tf.train.shuffle_batch(
      [example, label], batch_size=batch_size, capacity=capacity,
      min_after_dequeue=min_after_dequeue)
  return example_batch, label_batch```

一個(gè)具體的例子如下,該例采用了CIFAR-10數(shù)據(jù)集,采用了固定長(zhǎng)度讀取的tf.FixedLengthRecordReader閱讀器和tf.decode_raw解析器,同時(shí)進(jìn)行了數(shù)據(jù)預(yù)處理操作中的標(biāo)準(zhǔn)化操作,最后使用tf.train.shuffle_batch函數(shù)批量執(zhí)行數(shù)據(jù)的亂序處理。

class cifar10_data(object):
    def __init__(self, filename_queue):
        self.height = 32
        self.width = 32
        self.depth = 3
        self.label_bytes = 1
        self.image_bytes = self.height * self.width * self.depth
        self.record_bytes = self.label_bytes + self.image_bytes
        self.label, self.image = self.read_cifar10(filename_queue)
        
    def read_cifar10(self, filename_queue):
        reader = tf.FixedLengthRecordReader(record_bytes = self.record_bytes)
        key, value = reader.read(filename_queue)
        record_bytes = tf.decode_raw(value, tf.uint8)
        label = tf.cast(tf.slice(record_bytes, [0], [self.label_bytes]), tf.int32)
        image_raw = tf.slice(record_bytes, [self.label_bytes], [self.image_bytes])
        image_raw = tf.reshape(image_raw, [self.depth, self.height, self.width])
        image = tf.transpose(image_raw, (1,2,0))        
        image = tf.cast(image, tf.float32)
        return label, image

def inputs(data_dir, batch_size, train = True, name = 'input'):
    with tf.name_scope(name):
        if train:    
            filenames = [os.path.join(data_dir,'data_batch_%d.bin' % ii) 
                        for ii in range(1,6)]
            for f in filenames:
                if not tf.gfile.Exists(f):
                    raise ValueError('Failed to find file: ' + f)
                    
            filename_queue = tf.train.string_input_producer(filenames)
            read_input = cifar10_data(filename_queue)
            images = read_input.image
            images = tf.image.per_image_standardization(images)
            labels = read_input.label
            image, label = tf.train.shuffle_batch(
                                    [images,labels], batch_size = batch_size, 
                                    min_after_dequeue = 20000, capacity = 20192)
        
            return image, tf.reshape(label, [batch_size])
            
        else:
            filenames = [os.path.join(data_dir,'test_batch.bin')]
            for f in filenames:
                if not tf.gfile.Exists(f):
                    raise ValueError('Failed to find file: ' + f)
                    
            filename_queue = tf.train.string_input_producer(filenames)
            read_input = cifar10_data(filename_queue)
            images = read_input.image
            images = tf.image.per_image_standardization(images)
            labels = read_input.label
            image, label = tf.train.shuffle_batch(
                                    [images,labels], batch_size = batch_size, 
                                    min_after_dequeue = 20000, capacity = 20192)
        
            return image, tf.reshape(label, [batch_size])

這里介紹下函數(shù)tf.image.per_image_standardization(image),該函數(shù)對(duì)圖像進(jìn)行線性變換使它具有零均值和單位方差,即規(guī)范化。其中參數(shù)image是一個(gè)3-D的張量,形狀為[height, width, channels]。

參考 ZangBo

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

推薦閱讀更多精彩內(nèi)容