利用Python裝飾器來組織Tensorflow代碼的結構

裝飾器

定義Python裝飾器

裝飾器是一種設計模式, 可以使用OOP中的繼承和組合實現, 而Python還直接從語法層面支持了裝飾器.
裝飾器可以在不改變函數定義的前提下, 在代碼運行期間動態增加函數的功能, 本質上就是將原來的函數與新加的功能包裝成一個新的函數wrapper, 并讓原函數的名字指向wrapper.

Python中實現decorator有兩種方式: 函數方式 和 類方式

函數方式

可以用一個返回函數的高階函數來實現裝飾器

簡單的無參數裝飾器

def log(func):
    def wrapper(*args, **kw):
        print('call %s():' % func.__name__)
        return func(*args, **kw)
    return wrapper
@log
def now():
    print('NOW')

在函數fun的定義前面放入@decorator實現的功能相當于fun=decorator(fun),
從而現在調用now()將打印前面的調用信息.

實現帶參數的裝飾器

只要給裝飾器提供參數后,返回的object具備一個無參數裝飾器的功能即可.
可以用返回無參數裝飾器函數的高階函數來實現.

def log(text):
    def decorator(func):
        def wrapper(*args, **kw):
            print('%s %s():' % (text, func.__name__))
            return func(*args, **kw)
        return wrapper
    return decorator

@log('execute')
def now():
  print("parametric NOW")

該語法糖相當于now=log('execute')(now).

如果要保存原函數的__name__屬性, 使用python的functools模塊中的wraps()裝飾器, 只需要將@functools.wraps(func)放在def wrapper()前面即可.該裝飾器實現的功能就相當于添加了wrapper.__name__ = func.__name__語句.

類方式

Python中的類和函數差別不大, 實現類的__call__ method就可以把類當成一個函數來使用了.

實現以上帶參數裝飾器同樣功能的裝飾器類的代碼如下:

class log():
    def __init__(self, text):
        self.text = text
    def __call__(self,func):
        @functools.wraps(func)
        def wrapper(*args, **kw):
            print("%s %s" % (self.text, func.__name__))
            return func(*args, **kw)
        return wrapper

@log("I love Python")
def now():
    print("class decorator NOW")

使用類的好處是可以繼承

使用場景

裝飾器最巧妙的使用場景在Flask和Django Web框架中,它可以用來檢查某人是否被授權使用Web應用的某個endpoint(假設是f函數), 下面是一個檢查授權的示意性代碼片段.

from functools import wraps

def require_auth(f):
  @wraps(f)
  def decorated(*args, **kw):
    auth = request.authorization
    if not auth or not check_auth(auth.username, auth.password):
      authenticate()
    return f(*args, **kw)
  return decorated

另一個常見的用處是用于日志記錄

from functools import wraps

def logit(func):
    @wraps(func)
    def with_logging(*args, **kwargs):
        print(func.__name__ + " was called")
        return func(*args, **kwargs)
    return with_logging

@logit
def addition_func(x):
   """Do some math."""
   return x + x

result = addition_func(4)

是不是超級靈活呢? 雖然裝飾器有點難定義, 但是一旦掌握, 它就像不可思議的魔法. Σ(*?д??)?

利用裝飾器改善你的Tensorflow代碼結構

重頭戲終于來了! 當你在寫Tensorflow代碼時, 定義模型的代碼和動態運行的代碼經常會混亂不清. 一方面, 我們希望定義compute graph的"靜態"Python代碼只執行一次, 而相反, 我們希望調用session來運行的代碼可以運行多次取得不同狀態的數據信息, 而兩類代碼一旦雜糅在一起, 很容易造成Graph中有冗余的nodes被定義了多次, 感覺十分不爽, 寫過那種丑代碼的你們都懂.

那么,如何以一種可讀又可復用的方式來組織你的TF代碼結構呢?

版本1

我們都希望用一個類來抽象一個模型, 這無疑是明智的. 但是如何定義類的接口呢?
我們的模型需要接受input的feature data和target value, 需要進行 training, evaluation 和 inference 操作.

class Model:

    def __init__(self, data, target):
        data_size = int(data.get_shape()[1])   # 假設data的shape為[N,D] N為Batch Size  D是輸入維度
        target_size = int(target.get_shape()[1]) # 假設target的shape為[N,K] K是one-hot的label深度, 即要分類的類的數量
        weight = tf.Variable(tf.truncated_normal([data_size, target_size]))
        bias = tf.Variable(tf.constant(0.1, shape=[target_size]))
        incoming = tf.matmul(data, weight) + bias
        self._prediction = tf.nn.softmax(incoming)
        cross_entropy = tf.reduce_mean(-tf.reduce_sum(target * tf.log(self._prediction), reduction_indices=[1]))
        self._optimize = tf.train.RMSPropOptimizer(0.03).minimize(cross_entropy)
        mistakes = tf.not_equal(
            tf.argmax(target, 1), tf.argmax(self._prediction, 1))
        self._error = tf.reduce_mean(tf.cast(mistakes, tf.float32))

    @property
    def prediction(self):
        return self._prediction

    @property
    def optimize(self):
        return self._optimize

    @property
    def error(self):
        return self._error

這是最基本的形式, 但是它存在很多問題. 最嚴重的問題是整個圖都被定義在init構造函數中, 這既不可讀又不可復用.

版本2

直接將代碼分離開來,放在多個函數中是不行的, 因為每次函數調用時都會向Graph中添加nodes, 所以我們必須確保這些Node Operations只在函數第一次調用的時候才添加到Graph中, 這有點類似于singleton模式, 或者叫做lazy-loading(使用時才創建).

class Model:

    def __init__(self, data, target):
        self.data = data
        self.target = target
        self._prediction = None
        self._optimize = None
        self._error = None

    @property
    def prediction(self):
        if not self._prediction:
            data_size = int(self.data.get_shape()[1])
            target_size = int(self.target.get_shape()[1])
            weight = tf.Variable(tf.truncated_normal([data_size, target_size]))
            bias = tf.Variable(tf.constant(0.1, shape=[target_size]))
            incoming = tf.matmul(self.data, weight) + bias
            self._prediction = tf.nn.softmax(incoming)
        return self._prediction

    @property
    def optimize(self):
        if not self._optimize:
             cross_entropy = tf.reduce_mean(-tf.reduce_sum(self.target * tf.log(self._prediction), reduction_indices=[1]))
            optimizer = tf.train.RMSPropOptimizer(0.03)
            self._optimize = optimizer.minimize(cross_entropy)
        return self._optimize

    @property
    def error(self):
        if not self._error:
            mistakes = tf.not_equal(
                tf.argmax(self.target, 1), tf.argmax(self.prediction, 1))
            self._error = tf.reduce_mean(tf.cast(mistakes, tf.float32))
        return self._error

這好多了, 但是每次都需要if判斷還是有點太臃腫, 利用裝飾器, 我們可以做的更好!

版本3

實現一個自定義裝飾器lazy_property, 它的功能和property類似,但是只運行function一次, 然后將返回結果存在一個屬性中, 該屬性的名字是 "_cache_" + function.__name__, 后續函數調用將直接返回緩存好的屬性.

import functools

def lazy_property(function):
    attribute = '_cache_' + function.__name__

    @property
    @functools.wraps(function)
    def decorator(self):
        if not hasattr(self, attribute):
            setattr(self, attribute, function(self))
        return getattr(self, attribute)

    return decorator

使用該裝飾器, 優化后的代碼如下:

class Model:

    def __init__(self, data, target):
        self.data = data
        self.target = target
        self.prediction
        self.optimize
        self.error

    @lazy_property
    def prediction(self):
        data_size = int(self.data.get_shape()[1])
        target_size = int(self.target.get_shape()[1])
        weight = tf.Variable(tf.truncated_normal([data_size, target_size]))
        bias = tf.Variable(tf.constant(0.1, shape=[target_size]))
        incoming = tf.matmul(self.data, weight) + bias
        return tf.nn.softmax(incoming)

    @lazy_property
    def optimize(self):
        cross_entropy = tf.reduce_mean(-tf.reduce_sum(self.target * tf.log(self.prediction), reduction_indices=[1]))
        optimizer = tf.train.RMSPropOptimizer(0.03)
        return optimizer.minimize(cross_entropy)

    @lazy_property
    def error(self):
        mistakes = tf.not_equal(
            tf.argmax(self.target, 1), tf.argmax(self.prediction, 1))
        return tf.reduce_mean(tf.cast(mistakes, tf.float32))

注意, 在init構造函數中調用了屬性prediction,optimize和error, 這會讓其第一次執行, 因此構造函數完成后Compute Graph也就構建完畢了.

有時我們使用TensorBoard來可視化Graph時, 希望將相關的Node分組到一起, 這樣看起來更為清楚直觀, 我們只需要修改之前的lazy_property裝飾器, 在其中加上with tf.name_scope("name") 或者 with tf.variable_scope("name")即可, 修改之前的裝飾器如下:

import functools

def define_scope(function):
    attribute = '_cache_' + function.__name__

    @property
    @functools.wraps(function)
    def decorator(self):
        if not hasattr(self, attribute):
            with tf.variable_scope(function.__name__):
                setattr(self, attribute, function(self))
        return getattr(self, attribute)

    return decorator

我們現在能夠用一種結構化和緊湊的方式來定義TensorFlow的模型了, 這歸功于Python的強大的decorator語法糖.

完整的代碼在這里, 有關該代碼的詳細注釋請參考我的博客.

References:

  1. https://danijar.com/structuring-your-tensorflow-models/
  2. https://www.liaoxuefeng.com/wiki/0014316089557264a6b348958f449949df42a6d3a2e542c000/0014318435599930270c0381a3b44db991cd6d858064ac0000
  3. https://eastlakeside.gitbooks.io/interpy-zh/content/decorators/deco_class.html
  4. https://www.liaoxuefeng.com/wiki/001374738125095c955c1e6d8bb493182103fac9270762a000/001386820062641f3bcc60a4b164f8d91df476445697b9e000
  5. https://www.tensorflow.org/get_started/mnist/beginners?hl=zh-cn#training
  6. https://mozillazg.github.io/2016/12/python-super-is-not-as-simple-as-you-thought.html
最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內容

  • Python進階框架 希望大家喜歡,點贊哦首先感謝廖雪峰老師對于該課程的講解 一、函數式編程 1.1 函數式編程簡...
    Gaolex閱讀 5,532評論 6 53
  • 我還是那么想你。 當我能夠輕松自如地和別人談笑風生,當我不再輾轉反側失眠到深夜,當我平靜地把你的微信拖到了黑名單,...
    安夏的花花世界閱讀 264評論 0 0
  • 打開窗戶,感覺世界距離我 有著遙遠的路程, 而我也懶得動身前往。 ——《孤獨》城子玄
    城子玄閱讀 191評論 0 0
  • 有一種嘮叨,最容易使我們煩躁,那就是母親的嘮叨。這種嘮叨,都是源于她內心深處對我們的愛。 在生活...
    林琨皓閱讀 565評論 0 2