原文地址https://www.tensorflow.org/programmersguide/threadingand_queues
主要內(nèi)容
隊(duì)列對(duì)于使用TensorFlow來(lái)進(jìn)行異步計(jì)算是一個(gè)強(qiáng)大的機(jī)制。
像所有的TensorFlow中的東西一樣,一個(gè)隊(duì)列是一個(gè)TensorFlow圖中的一個(gè)節(jié)點(diǎn)。這是個(gè)狀態(tài)節(jié)點(diǎn),像一個(gè)變量:其他節(jié)點(diǎn)可以修改它的內(nèi)容。特別的來(lái)說(shuō),其他節(jié)點(diǎn)可以將新的元素加入隊(duì)列,或者從隊(duì)列中出隊(duì)現(xiàn)存的元素。
為了找點(diǎn)對(duì)于隊(duì)列概念的感覺(jué),讓我們來(lái)考慮一個(gè)簡(jiǎn)單的例子。我們會(huì)創(chuàng)建一個(gè)“先進(jìn),先出”的隊(duì)列(FIFOQueue),并且用0填充這個(gè)隊(duì)列。然后,我們構(gòu)建了一個(gè)圖,出隊(duì)一個(gè)元素,將這個(gè)元素加1,然后將這個(gè)元素放回到這個(gè)隊(duì)列的尾部。于是,隊(duì)列中的數(shù)字都在逐漸增加。
Enqueue,EnqueueMany和Dequeue都是特殊的節(jié)點(diǎn)。他們保存著指向隊(duì)列的指針而不是一般的數(shù)值,這樣使得它們能夠修改隊(duì)列。我們建議你將這些方法認(rèn)為是和隊(duì)列的方法類似。實(shí)際上,在Python的API中,他們是隊(duì)列對(duì)象的方法(例如q.enqueue(…))。
注意:隊(duì)列的方法(例如q.enqueue())必須和隊(duì)列本身運(yùn)行在相同的硬件上。在創(chuàng)建這些操作的時(shí)候,不一致的硬件部署指令會(huì)被忽略。
現(xiàn)在你應(yīng)該對(duì)于隊(duì)列有些感覺(jué)了,那讓我們深入細(xì)節(jié)。
隊(duì)列使用概述
例如像tf.FIFOQueue和tf.RandomShuffleQueue這樣的隊(duì)列,對(duì)于在圖中異步計(jì)算張量而言,是非常重要的TensorFlow對(duì)象。
舉例來(lái)說(shuō),一個(gè)典型的輸入架構(gòu)就是使用RandomShuffleQueue來(lái)為訓(xùn)練模型準(zhǔn)備輸入:
- 多線程準(zhǔn)備訓(xùn)練數(shù)據(jù),并將其壓入隊(duì)列。
- 一個(gè)訓(xùn)練線程運(yùn)行訓(xùn)練操作,這個(gè)操作從隊(duì)列中出隊(duì)小批量的數(shù)據(jù)。
這個(gè)架構(gòu)有很多好處,正如Reading data中所強(qiáng)調(diào)的那樣,同時(shí)這篇文章也一些簡(jiǎn)化構(gòu)建輸入管道的函數(shù)的概述。
TensorFlow的Session對(duì)象是多線程的,因此多線程可以很容易的使用相同的圖,和并行的運(yùn)行操作。但是實(shí)現(xiàn)一個(gè)如上所述的使用線程的Python程序,并不總是那么容易。所有的線程必須能夠同時(shí)停止,異常需要被捕捉并通知(reported),而隊(duì)列比如在線程停止的時(shí)候,被恰當(dāng)?shù)年P(guān)閉。
TensorFlow提供兩個(gè)類來(lái)幫助上述任務(wù)的實(shí)現(xiàn):tf.train.Coordinator和tf.train.QueueRunner。這兩個(gè)類被設(shè)計(jì)為同時(shí)使用。Coordinator類幫助多線程同時(shí)停止,并向那些等待他們停止的程序報(bào)告異常。QueueRunner類被用來(lái)創(chuàng)建多個(gè)線程,這些線程用來(lái)協(xié)作在相同的隊(duì)列中入隊(duì)張量。
Coordinator
Coordinator類幫助多線程同時(shí)停止。 它的關(guān)鍵方法為:
- tf.train.Coordinator.should_stop:如果線程應(yīng)該停止,那么返回True
- tf.train.Coordinator.request_stop:請(qǐng)求停止線程、
- tf.train.Coordinator.join:等待,直到指定的線程已經(jīng)停止
你首先創(chuàng)建一個(gè)Coordinator對(duì)象,然后創(chuàng)建一些線程來(lái)使用Coordinator。通常,線程運(yùn)行循環(huán),而這個(gè)循環(huán)會(huì)在should_stop()返回為T(mén)rue的時(shí)候終止。
任何線程可以決定這次計(jì)算應(yīng)該終止。它僅僅需要調(diào)用requeststop(),然后其他線程會(huì)隨著shouldstop()返回為T(mén)rue而終止。
# Thread body: loop until the coordinator indicates a stop was requested.
# If some condition becomes true, ask the coordinator to stop.
def MyLoop(coord):
while not coord.should_stop():
...do something...
if ...some condition...:
coord.request_stop()
# Main thread: create a coordinator.
coord = tf.train.Coordinator()
# Create 10 threads that run 'MyLoop()'
threads = [##threading.Thread(target=MyLoop, args=(coord,)) for i in xrange(10)]
# Start the threads and wait for all of them to stop.
for t in threads:
t.start()
coord.join(threads)
顯然,協(xié)調(diào)器能管理線程做不同的事情。他們并不需要像上述例子中一樣,全都是做同樣的事情。協(xié)調(diào)器同樣支持捕捉和報(bào)告異常。參見(jiàn)tf.train.Coordinator文檔獲取更詳細(xì)的信息。
QueueRunner
QueueRunner類創(chuàng)建一些線程,來(lái)反復(fù)的運(yùn)行入隊(duì)操作。這些線程可以使用一個(gè)協(xié)調(diào)器來(lái)同時(shí)終止。除此之外,一個(gè)隊(duì)列運(yùn)行器運(yùn)行一個(gè)關(guān)系更為密切的線程,它在協(xié)調(diào)器報(bào)告異常的情況下,自動(dòng)的關(guān)閉隊(duì)列。
你可以使用隊(duì)列UN星期來(lái)實(shí)現(xiàn)上述的架構(gòu)。
首先構(gòu)建一個(gè)圖,使用TenorFlow的隊(duì)列(比如tf.RandomShuffleQueue)來(lái)輸入樣本。然后添加操作來(lái)處理樣本,并講他們?nèi)腙?duì)。最后,添加以出隊(duì)元素開(kāi)始的訓(xùn)練操作。
example = ...ops to create one example...
# Create a queue, and an op that enqueues examples one at a time in the queue.
queue = tf.RandomShuffleQueue(...)
enqueue_op = queue.enqueue(example)
# Create a training graph that starts by dequeuing a batch of examples.
inputs = queue.dequeue_many(batch_size)
train_op = ...use 'inputs' to build the training part of the graph...
在Python的訓(xùn)練程序中,創(chuàng)建一個(gè)QueueRunner對(duì)象,會(huì)運(yùn)行多個(gè)線程來(lái)處理和入隊(duì)樣本。創(chuàng)建一個(gè)Coordinator的對(duì)象,并用coordinator(協(xié)調(diào)器)來(lái)要求隊(duì)列運(yùn)行器來(lái)啟動(dòng)它的線程。編寫(xiě)訓(xùn)練的循環(huán)也可以使用coordinator(協(xié)調(diào)器)。
# Create a queue runner that will run 4 threads in parallel to enqueue
# examples.
qr = tf.train.QueueRunner(queue, [enqueue_op] * 4)
# Launch the graph.
sess = tf.Session()
# Create a coordinator, launch the queue runner threads.
coord = tf.train.Coordinator()
enqueue_threads = qr.create_threads(sess, coord=coord, start=True)
# Run the training loop, controlling termination with the coordinator.
for step in xrange(1000000):
if coord.should_stop():
break
sess.run(train_op)
# When done, ask the threads to stop.
coord.request_stop()
# And wait for them to actually do it.
coord.join(enqueue_threads)
處理異常
以隊(duì)列運(yùn)行器開(kāi)始的線程不僅僅是運(yùn)行入隊(duì)操作。它們也會(huì)捕捉并處理由隊(duì)列產(chǎn)生的異常,這些異常包括tf.errors.OutOfRangeError異常,這個(gè)異常被用來(lái)報(bào)道隊(duì)列被關(guān)閉。
使用協(xié)調(diào)器的訓(xùn)練程序必須在其主循環(huán)中,同樣地捕捉并報(bào)道異常。
這里是上面訓(xùn)練循環(huán)的增強(qiáng)版。
try:
for step in xrange(1000000):
if coord.should_stop():
break
sess.run(train_op)
except Exception, e:
# Report exceptions to the coordinator.
coord.request_stop(e)
finally:
# Terminate as usual. It is safe to call `coord.request_stop()` twice.
coord.request_stop()
coord.join(threads)