Keras為我們提供了很多已經定義好的網絡,比如Embdding層,LSTM層,GRU等。但是在有些情況下,這些預先定義好的網絡層并不能很好的滿足我們的需求,這個時候我們就需要自定義網絡層。
當然,通過線上,我們可以很方便的查閱到大把資料關于使用keras定義自己的網絡層,很多blog都直入主題,直接告訴我們自定義layer需要涉及到三個方法---------build(),call(),以及compute_output_shape(),處理好這幾個方法,我們便可以實現我們所需要的功能。但是知其然也要知其所以然,因而今天我們通過閱讀keras/engine/base_layer中的Layer類,來更好的理解整個網絡層的運行過程。
上面我將Layer類中的一些關鍵方法貼出來,為了更加直觀理解,方法的具體內容都刪了,后面具體分析的時候,在進行補充。
首先我們可以看到,在Layer類中有兩個特殊的方法,__init__()和__call__()。
__init__()是構造方法,當我們建立類對象時,首先調用該方法初始化類對象。
__call__()是可調用方法,一旦實現該方法,我們的類對象在某些行為上可以表現的和函數一樣??梢灾苯油ㄟ^類對象object()進行調用。下面舉個例子。
上面我們定義了一個類,我們可以發現可以直接通過類對象()來調用__call__方法?!? obj()等價于obj.__call__()? 】
這里我們也可以明白為什么平時可以直接使用例如? LSTM(32)(input)這種形式來添加網絡層。其實這種形式本質是
上面的實例,我們也可以知道,在layer中__call__方法的參數是input,返回值是output。那么__call__方法究竟做了什么?下面貼關鍵源碼(方便理解整個流程,貼完整的不易理解)感興趣的可以對照源碼理解。
OK!? 觀察上述代碼,我們可以知道在__call__()方法中有幾個關鍵操作,調用build(),調用call(),調用compute_output_shape(),最后再利用node將該層和上一層鏈接起來(如何鏈接可以不用關心)。
emmmmmmm,到這里,其實就是整個網絡層的運行流程了。大家看懂了就可以撤了。
(somebody :"。。。。。。。what Fuck! 你這講的都是啥,我還云里霧里呢!")
好吧,為了不讓網友罵,我接著將build()等幾個方法具體分析。
build():我們知道,當我們定義網絡層的時候,需要用到一些張量(tensor)來對我們的輸入進行操作。比如權重信息Weights,偏差Biases。其實一個網絡本身就可以理解為這些張量的集合。keras是如何在我們給定input以及output_dim的情況下定義這些張量的呢?這里主要就是build()方法的功勞了。build函數就是為該網絡定義一層相應的張量的集合。在Layer類中有兩個成員變量,分別是trainable_weights和non_trainable_weights,分別是指可以訓練的參數的集合和不可訓練的參數的集合。這兩個參數都是list。在build中建立的張量通過add_weight()方法加入到上面兩個張量集合中,進而建立網絡層。需要注意的是,一個網絡層的參數是固定的,我們不能重復添加,因此,build()方法最多只能調用一次。如何保證每個layer的build()最多調用一次???這是通過self._built變量來控制的。如果built變量為True,那么build()方法將不再會被調用,否則build()才能被調用。在調用之后built會被賦值為True,防止以后build()被重復調用。這在__call__()方法中有體現。所以我們如何沒有重新寫__call__()方法,那么我們不用擔心build()方法會被多次調用。但是如果重新寫了__call__()方法,一定要注意在build()調用之后,將built置為True。【TIP:build只接受input一個參數,所以如果需要用到output_shape,可以在__init__()中將output_shape賦值給一個成員變量,這樣就可以在build中直接使用output_shape的值了】。舉個簡單例子,以output=tanh(X*W+B),我們首先定義build()函數。這里用到的參數分別是W和B。假設輸出的大小為output_dim,且已經在__init__()中已經初始化了。
call(): 該方法是整個網絡層的邏輯輸出。通過build(),我們已經有了網絡層的權重等信息,接下來便是通過input以及這些權重張量(W,B)等來獲得輸出了。如何得到output就要根據大家需要的功能來說了。//////該函數返回值是output,__call__()方法也是通過調用call()來獲得輸出output。
comput_output_shape():返回輸出的形狀,便于keras搭建下一層網絡時,可以自行推導出輸入的形狀
好了,以上便大功告成了【第一次寫blog,希望能幫助大家更好理解keras網絡層的整個控制流程】