Tensorflow核心代碼解析之計算圖篇其一:計算圖結構初探

介紹

當今計算機科學給人對未來最大想象的莫過于人工智能的大規模應用前景。它對于人類文明進步所帶來的潛在貢獻可以被視為與四大發明、蒸汽機、電力、計算機等人類史標桿性工具具有同等的地位。當今人工智能的蓬勃發展主要是機器學習尤其是深度學習的大規模成功應用。憑著日益增多的海量數據,快速發展的計算機并行計算能力,快速迭代、高效更新的各種模型、算法、策略以及各個國家、政府、大企業對它的日益重視與超高的資金、政策投入,當今AI的發展速度真可謂一日千里!

深度學習最終要與具體行業場景有效率地結合起來才能發揮出其效益來。當下的整個深度學習已經行成了良好的產業鏈體系。最下端的是用于深度學習加速的各種硬件芯片如CPU/GPU/FPGA/ASIC專用芯片等。目前此領域里面GPU憑其優良的超多弱核心并行計算能力獨領風騷;但CPU在推理加速、成本等方面也挺有競爭力;FPGA憑其靈活性也非常適宜于進行AI方案的原型設計,但因其開發難度較大,生態相對較缺乏,當下大公司里面大規模部署應用FPGA的唯有微軟;至于ASIC專用芯片,可謂是給了諸多有心在AI半導體上面實現彎道超車的公司一個很好的機會,尤其是那些有著大規模機器學習用戶,可以基于上層封裝提供AI云計算實例服務的公司如Google就聲勢浩大地推出了自己快速迭代著的TPU,號稱常規模型(如Resnet-50)加速快于最新的GPU,同時功耗更為節省,其它像Amazon,Facebook,Ali等云計算公司也在搞自己的AI專用ASIC芯片。此外一些創業型公司也對此雄心勃勃,國內已經涌現出了一堆此類的獨角獸像賣自家礦機發了大財的Bit大陸,技術勢力雄厚的地平線等等。還有些手機大廠則將精力用在了終端一側的AI芯片研發像華為、蘋果等已經有了自己的AI芯片并部署在了自己的手機新品當中,其它像小米、三星等也在紛紛跟進。需要提下的是ARM這個在移動時代的主要得意者也于最近發布了自己家設計的用于AI加速的各種硬件IP。

AI芯片

計算芯片向上走則是將一些基本運算如矩陣乘積、各類型卷積運算等結合硬件平臺優化過了的數學計算庫如用于Intel CPU端的MKL/MKLDNN,用于nVidia GPU的CUDA/cuDNN,用于FPGA專用DNN網絡加速過的openCL,還有針對各大ASIC芯片產商針對自己ASIC加速過了的種種計算庫套件。這些數學計算庫基本由AI芯片產商自己來完成,目的即在于借用軟件的力量給自家的硬件以強大的驅動力。一般他們會選擇將優化過了的核心程序開源、公布出來為自己的客戶所借鑒、使用,最終通過AI芯片的出售來獲得價值。值得一提的是nVIDIA在基于CUDA/cuDNN上面的多年耕耘,相關社區、生態的耐心培養直接帶來了它們今天的巨大成功。現在半導體公司已經再不同往日只需要賣出芯片即可了。芯片相關的軟件庫(并行計算庫等)的性能,對用戶提供API的友好性,用戶社區的培養,用戶支持的力度等等軟實力真正是愈益重要。

底層計算庫再向上走則是使用這些優化過了的數學計算庫來完成基本類型計算,然后將之抽象封裝后向上提供出友好用戶API的深度學習框架。這些深度學習框架是我們軟件人員開發應用的基本工具。當前最流行的框架有Google的Tensorflow,Facebook的Pytorch,Amazon的Mxnet,微軟的CNTK,當然還有傳統社區在維護的bvlc/Caffe等等。它們大都提供類似的功能,相似的API用于用戶程序構建計算圖,并能將圖方便地導入、導出為序列化文件,還提供了基于Framework level對圖的一些優化如合并(fusion),去重(典型的如CSE),并行計算(通過使用OMP等并行庫)等,此外還有一些功能如用于進行內存分配、管理及線程執行、調度、檢測的session/workspace等,當然還有用于具體執行某計算的op / kernel等,這也是常規計算優化的核心所在。

在這些計算框架中,無疑Google brain團隊開發的Tensorflow是最為流行的。它的框架設計最為復雜,可以天生地支持模型并行訓練、推理等,它的背后有一個google強大的開發團隊在快速迭代、開發,它的底下也有集成當前最好的像cuDNN/mklDNN/TensorRT等加速技術,它的用戶社區也已經非常完善、活躍(作為程序員這個還是蠻重要的。畢竟在APP開發中出了問題,肯定都希望通過在網上翻一下看有沒有人踩過類似的坑以來快速解決問題。)。

當下對于如何使用Tensorflow來開發一個AI程序,構建深度學習模型并進行訓練或推理的文章已經很多了。本系列單元中筆者想試著跟大家一起理一下它框架核心的一些代碼實現。無益它對于我們基于TF做一些開發,加深對TF的理解是很有幫助的。此外TF框架的設計、代碼實現非常良好,對它們的理解、梳理清楚對于我們日常的軟件設計、開發也會有較強的工程借鑒意義。

計算圖

計算圖(Graph)描述了一組需要依次序完成的計算單元以及表示這些計算單元之間相互依賴的關系。一般的深度學習模型都會被分化組裝成一個單向無環圖(DAG)來執行。圖當中的結點(node)用來表示某一具體的計算單元(如Multmul結點表示兩個張量之間的乘積,Conv結點則表示兩個張量之間的卷積計算)。圖上的片(edge)則被用來表示兩個結點之間的依賴關系。比如A結點的第i個輸出來自于B結點的第j輸入,那么就會構成(B,j) -> (A,i)這么一條邊來,如此結點A的執行就對結點B構成依賴。

在Tensorflow的計算圖中,一般會包含兩個特殊的結點分別為Source節點(也稱Start節點)與Sink節點(也稱為Finish節點)。其中Source節點表示此節點不依賴于任何其它節點作為其輸入,而Sink節點則表示該節點并無任何輸出來作為其它節點的輸入。

Tensorflow計算圖示例

class graph代碼實例

  • Tensorflow中Graph構造的描述可見于class Graph當中(可在core/graph/graph.h中找到其定義)
class Graph {
 public:
  // Constructs a graph with a single SOURCE (always id kSourceId) and a
  // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
  //
  // The graph can hold ops found in registry. `registry`s lifetime must be at
  // least that of the constructed graph's.
  explicit Graph(const OpRegistryInterface* registry);

  // Constructs a graph with a single SOURCE (always id kSourceId) and a
  // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
  //
  // The graph can hold ops found in `flib_def`. Unlike the constructor taking
  // an OpRegistryInterface, this constructor copies the function definitions in
  // `flib_def` so its lifetime may be shorter than that of the graph's. The
  // OpRegistryInterface backing `flib_def` must still have the lifetime of the
  // graph though.
  explicit Graph(const FunctionLibraryDefinition& flib_def);

以上構造函數當中,我們看到Graph只需引入一個參數即OpRegistryInterface或FunctionLibraryDefinition。這兩個參數提供了具體每個節點的實際執行定義。在我們構建計算圖的時候,我們找到一個node的nodeDef(通常是基于google protocol buffer協議的node參數定義)后,會在OpRegistryInterface或FunctionLibraryDefinition當中去獲取其具體的類型實現。也就是說我們如果實現了一個在某硬件平臺上優化過了的Op或一種嶄新的Op操作,為了將此操作能夠作為計算圖的一個節點為我們的模型所用,那么需要將此新創建的Op實現函數注冊于OpRegistryInterface或FunctionLibraryDefinition結構當中。

  • 計算圖中有對Node與Edge的Add/Remove/Update等操作

其函數接口如下。具體的定義可見于core/graph/graph.cc當中。本身實現起來因為是使用了指針鏈表結構的DAG所以還是比較簡單、容易理解的,在此就不多說了。

 // Adds a new node to this graph, and returns it. Infers the Op and
  // input/output types for the node. *this owns the returned instance.
  // Returns nullptr and sets *status on error.
  Node* AddNode(const NodeDef& node_def, Status* status);

  // Copies *node, which may belong to another graph, to a new node,
  // which is returned.  Does not copy any edges.  *this owns the
  // returned instance.
  Node* CopyNode(const Node* node);

  // Removes a node from this graph, including all edges from or to it.
  // *node should not be accessed after calling this function.
  // REQUIRES: node->IsOp()
  void RemoveNode(Node* node);

  // Adds an edge that connects the xth output of `source` to the yth input of
  // `dest` and returns it. Does not update dest's NodeDef.
  const Edge* AddEdge(Node* source, int x, Node* dest, int y);
  
  // Removes edge from the graph. Does not update the destination node's
  // NodeDef.
  // REQUIRES: The edge must exist.
  void RemoveEdge(const Edge* edge);
  // Updates the input to a node.  The existing edge to `dst` is removed and an
  // edge from `new_src` to `dst` is created. The NodeDef associated with `dst`
  // is also updated.
  Status UpdateEdge(Node* new_src, int new_src_index, Node* dst, int dst_index);
  • 圖之上的Op函數庫

class graph中有一個類成員為 FunctionLibraryDefinition ops_,其中包含了所有已知的具體類型的Op函數定義。而我們可利用以下函數來增加、拓展其Op函數庫。

// Adds the function and gradient definitions in `fdef_lib` to this graph's op
  // registry. Ignores duplicate functions, and returns a bad status if an
  // imported function differs from an existing function or op with the same
  // name.
  Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib);
  • Node節點對應的宿主設備

Tensorflow當中計算圖的執行是并發的。圖上的每個Node都可被分布在不同的計算設備上計算。TF有提供API可以讓我們指定某個Op操作的宿主設備。當然也有函數用來提供相應的查詢操作。如下所示,見名可知其義。

const string& get_assigned_device_name(const Node& node) const {
    return device_names_[node.assigned_device_name_index()];
  }

  void set_assigned_device_name_index(Node* node, int device_name_index) {
    CheckDeviceNameIndex(device_name_index);
    node->assigned_device_name_index_ = device_name_index;
  }

  void set_assigned_device_name(Node* node, const string& device_name) {
    node->assigned_device_name_index_ = InternDeviceName(device_name);
  }

參考文獻

最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內容