利用Python裝飾器來組織Tensorflow代碼的結構

裝飾器

定義Python裝飾器

裝飾器是一種設計模式, 可以使用OOP中的繼承和組合實現, 而Python還直接從語法層面支持了裝飾器.
裝飾器可以在不改變函數定義的前提下, 在代碼運行期間動態增加函數的功能, 本質上就是將原來的函數與新加的功能包裝成一個新的函數wrapper, 并讓原函數的名字指向wrapper.

Python中實現decorator有兩種方式: 函數方式 和 類方式

函數方式

可以用一個返回函數的高階函數來實現裝飾器

簡單的無參數裝飾器

def log(func):
    def wrapper(*args, **kw):
        print('call %s():' % func.__name__)
        return func(*args, **kw)
    return wrapper
@log
def now():
    print('NOW')

在函數fun的定義前面放入@decorator實現的功能相當于fun=decorator(fun),
從而現在調用now()將打印前面的調用信息.

實現帶參數的裝飾器

只要給裝飾器提供參數后,返回的object具備一個無參數裝飾器的功能即可.
可以用返回無參數裝飾器函數的高階函數來實現.

def log(text):
    def decorator(func):
        def wrapper(*args, **kw):
            print('%s %s():' % (text, func.__name__))
            return func(*args, **kw)
        return wrapper
    return decorator

@log('execute')
def now():
  print("parametric NOW")

該語法糖相當于now=log('execute')(now).

如果要保存原函數的__name__屬性, 使用python的functools模塊中的wraps()裝飾器, 只需要將@functools.wraps(func)放在def wrapper()前面即可.該裝飾器實現的功能就相當于添加了wrapper.__name__ = func.__name__語句.

類方式

Python中的類和函數差別不大, 實現類的__call__ method就可以把類當成一個函數來使用了.

實現以上帶參數裝飾器同樣功能的裝飾器類的代碼如下:

class log():
    def __init__(self, text):
        self.text = text
    def __call__(self,func):
        @functools.wraps(func)
        def wrapper(*args, **kw):
            print("%s %s" % (self.text, func.__name__))
            return func(*args, **kw)
        return wrapper

@log("I love Python")
def now():
    print("class decorator NOW")

使用類的好處是可以繼承

使用場景

裝飾器最巧妙的使用場景在Flask和Django Web框架中,它可以用來檢查某人是否被授權使用Web應用的某個endpoint(假設是f函數), 下面是一個檢查授權的示意性代碼片段.

from functools import wraps

def require_auth(f):
  @wraps(f)
  def decorated(*args, **kw):
    auth = request.authorization
    if not auth or not check_auth(auth.username, auth.password):
      authenticate()
    return f(*args, **kw)
  return decorated

另一個常見的用處是用于日志記錄

from functools import wraps

def logit(func):
    @wraps(func)
    def with_logging(*args, **kwargs):
        print(func.__name__ + " was called")
        return func(*args, **kwargs)
    return with_logging

@logit
def addition_func(x):
   """Do some math."""
   return x + x

result = addition_func(4)

是不是超級靈活呢? 雖然裝飾器有點難定義, 但是一旦掌握, 它就像不可思議的魔法. Σ(*?д??)?

利用裝飾器改善你的Tensorflow代碼結構

重頭戲終于來了! 當你在寫Tensorflow代碼時, 定義模型的代碼和動態運行的代碼經常會混亂不清. 一方面, 我們希望定義compute graph的"靜態"Python代碼只執行一次, 而相反, 我們希望調用session來運行的代碼可以運行多次取得不同狀態的數據信息, 而兩類代碼一旦雜糅在一起, 很容易造成Graph中有冗余的nodes被定義了多次, 感覺十分不爽, 寫過那種丑代碼的你們都懂.

那么,如何以一種可讀又可復用的方式來組織你的TF代碼結構呢?

版本1

我們都希望用一個類來抽象一個模型, 這無疑是明智的. 但是如何定義類的接口呢?
我們的模型需要接受input的feature data和target value, 需要進行 training, evaluation 和 inference 操作.

class Model:

    def __init__(self, data, target):
        data_size = int(data.get_shape()[1])   # 假設data的shape為[N,D] N為Batch Size  D是輸入維度
        target_size = int(target.get_shape()[1]) # 假設target的shape為[N,K] K是one-hot的label深度, 即要分類的類的數量
        weight = tf.Variable(tf.truncated_normal([data_size, target_size]))
        bias = tf.Variable(tf.constant(0.1, shape=[target_size]))
        incoming = tf.matmul(data, weight) + bias
        self._prediction = tf.nn.softmax(incoming)
        cross_entropy = tf.reduce_mean(-tf.reduce_sum(target * tf.log(self._prediction), reduction_indices=[1]))
        self._optimize = tf.train.RMSPropOptimizer(0.03).minimize(cross_entropy)
        mistakes = tf.not_equal(
            tf.argmax(target, 1), tf.argmax(self._prediction, 1))
        self._error = tf.reduce_mean(tf.cast(mistakes, tf.float32))

    @property
    def prediction(self):
        return self._prediction

    @property
    def optimize(self):
        return self._optimize

    @property
    def error(self):
        return self._error

這是最基本的形式, 但是它存在很多問題. 最嚴重的問題是整個圖都被定義在init構造函數中, 這既不可讀又不可復用.

版本2

直接將代碼分離開來,放在多個函數中是不行的, 因為每次函數調用時都會向Graph中添加nodes, 所以我們必須確保這些Node Operations只在函數第一次調用的時候才添加到Graph中, 這有點類似于singleton模式, 或者叫做lazy-loading(使用時才創建).

class Model:

    def __init__(self, data, target):
        self.data = data
        self.target = target
        self._prediction = None
        self._optimize = None
        self._error = None

    @property
    def prediction(self):
        if not self._prediction:
            data_size = int(self.data.get_shape()[1])
            target_size = int(self.target.get_shape()[1])
            weight = tf.Variable(tf.truncated_normal([data_size, target_size]))
            bias = tf.Variable(tf.constant(0.1, shape=[target_size]))
            incoming = tf.matmul(self.data, weight) + bias
            self._prediction = tf.nn.softmax(incoming)
        return self._prediction

    @property
    def optimize(self):
        if not self._optimize:
             cross_entropy = tf.reduce_mean(-tf.reduce_sum(self.target * tf.log(self._prediction), reduction_indices=[1]))
            optimizer = tf.train.RMSPropOptimizer(0.03)
            self._optimize = optimizer.minimize(cross_entropy)
        return self._optimize

    @property
    def error(self):
        if not self._error:
            mistakes = tf.not_equal(
                tf.argmax(self.target, 1), tf.argmax(self.prediction, 1))
            self._error = tf.reduce_mean(tf.cast(mistakes, tf.float32))
        return self._error

這好多了, 但是每次都需要if判斷還是有點太臃腫, 利用裝飾器, 我們可以做的更好!

版本3

實現一個自定義裝飾器lazy_property, 它的功能和property類似,但是只運行function一次, 然后將返回結果存在一個屬性中, 該屬性的名字是 "_cache_" + function.__name__, 后續函數調用將直接返回緩存好的屬性.

import functools

def lazy_property(function):
    attribute = '_cache_' + function.__name__

    @property
    @functools.wraps(function)
    def decorator(self):
        if not hasattr(self, attribute):
            setattr(self, attribute, function(self))
        return getattr(self, attribute)

    return decorator

使用該裝飾器, 優化后的代碼如下:

class Model:

    def __init__(self, data, target):
        self.data = data
        self.target = target
        self.prediction
        self.optimize
        self.error

    @lazy_property
    def prediction(self):
        data_size = int(self.data.get_shape()[1])
        target_size = int(self.target.get_shape()[1])
        weight = tf.Variable(tf.truncated_normal([data_size, target_size]))
        bias = tf.Variable(tf.constant(0.1, shape=[target_size]))
        incoming = tf.matmul(self.data, weight) + bias
        return tf.nn.softmax(incoming)

    @lazy_property
    def optimize(self):
        cross_entropy = tf.reduce_mean(-tf.reduce_sum(self.target * tf.log(self.prediction), reduction_indices=[1]))
        optimizer = tf.train.RMSPropOptimizer(0.03)
        return optimizer.minimize(cross_entropy)

    @lazy_property
    def error(self):
        mistakes = tf.not_equal(
            tf.argmax(self.target, 1), tf.argmax(self.prediction, 1))
        return tf.reduce_mean(tf.cast(mistakes, tf.float32))

注意, 在init構造函數中調用了屬性prediction,optimize和error, 這會讓其第一次執行, 因此構造函數完成后Compute Graph也就構建完畢了.

有時我們使用TensorBoard來可視化Graph時, 希望將相關的Node分組到一起, 這樣看起來更為清楚直觀, 我們只需要修改之前的lazy_property裝飾器, 在其中加上with tf.name_scope("name") 或者 with tf.variable_scope("name")即可, 修改之前的裝飾器如下:

import functools

def define_scope(function):
    attribute = '_cache_' + function.__name__

    @property
    @functools.wraps(function)
    def decorator(self):
        if not hasattr(self, attribute):
            with tf.variable_scope(function.__name__):
                setattr(self, attribute, function(self))
        return getattr(self, attribute)

    return decorator

我們現在能夠用一種結構化和緊湊的方式來定義TensorFlow的模型了, 這歸功于Python的強大的decorator語法糖.

完整的代碼在這里, 有關該代碼的詳細注釋請參考我的博客.

References:

  1. https://danijar.com/structuring-your-tensorflow-models/
  2. https://www.liaoxuefeng.com/wiki/0014316089557264a6b348958f449949df42a6d3a2e542c000/0014318435599930270c0381a3b44db991cd6d858064ac0000
  3. https://eastlakeside.gitbooks.io/interpy-zh/content/decorators/deco_class.html
  4. https://www.liaoxuefeng.com/wiki/001374738125095c955c1e6d8bb493182103fac9270762a000/001386820062641f3bcc60a4b164f8d91df476445697b9e000
  5. https://www.tensorflow.org/get_started/mnist/beginners?hl=zh-cn#training
  6. https://mozillazg.github.io/2016/12/python-super-is-not-as-simple-as-you-thought.html
最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市,隨后出現的幾起案子,更是在濱河造成了極大的恐慌,老刑警劉巖,帶你破解...
    沈念sama閱讀 229,619評論 6 539
  • 序言:濱河連續發生了三起死亡事件,死亡現場離奇詭異,居然都是意外死亡,警方通過查閱死者的電腦和手機,發現死者居然都...
    沈念sama閱讀 99,155評論 3 425
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人,你說我怎么就攤上這事?!?“怎么了?”我有些...
    開封第一講書人閱讀 177,635評論 0 382
  • 文/不壞的土叔 我叫張陵,是天一觀的道長。 經常有香客問我,道長,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 63,539評論 1 316
  • 正文 為了忘掉前任,我火速辦了婚禮,結果婚禮上,老公的妹妹穿的比我還像新娘。我一直安慰自己,他們只是感情好,可當我...
    茶點故事閱讀 72,255評論 6 410
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著,像睡著了一般。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發上,一...
    開封第一講書人閱讀 55,646評論 1 326
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了?” 一聲冷哼從身側響起,我...
    開封第一講書人閱讀 42,838評論 0 289
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后,有當地人在樹林里發現了一具尸體,經...
    沈念sama閱讀 49,399評論 1 335
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內容為張勛視角 年9月15日...
    茶點故事閱讀 41,146評論 3 356
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發現自己被綠了。 大學時的朋友給我發了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 43,338評論 1 372
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖,靈堂內的尸體忽然破棺而出,到底是詐尸還是另有隱情,我是刑警寧澤,帶...
    沈念sama閱讀 38,893評論 5 363
  • 正文 年R本政府宣布,位于F島的核電站,受9級特大地震影響,放射性物質發生泄漏。R本人自食惡果不足惜,卻給世界環境...
    茶點故事閱讀 44,565評論 3 348
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧,春花似錦、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 34,983評論 0 28
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至,卻和暖如春,著一層夾襖步出監牢的瞬間,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 36,257評論 1 292
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人。 一個月前我還...
    沈念sama閱讀 52,059評論 3 397
  • 正文 我出身青樓,卻偏偏與公主長得像,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當晚...
    茶點故事閱讀 48,296評論 2 376

推薦閱讀更多精彩內容

  • Python進階框架 希望大家喜歡,點贊哦首先感謝廖雪峰老師對于該課程的講解 一、函數式編程 1.1 函數式編程簡...
    Gaolex閱讀 5,510評論 6 53
  • 我還是那么想你。 當我能夠輕松自如地和別人談笑風生,當我不再輾轉反側失眠到深夜,當我平靜地把你的微信拖到了黑名單,...
    安夏的花花世界閱讀 257評論 0 0
  • 打開窗戶,感覺世界距離我 有著遙遠的路程, 而我也懶得動身前往。 ——《孤獨》城子玄
    城子玄閱讀 185評論 0 0
  • 有一種嘮叨,最容易使我們煩躁,那就是母親的嘮叨。這種嘮叨,都是源于她內心深處對我們的愛。 在生活...
    林琨皓閱讀 553評論 0 2