CNN Transfer Learning with VGG16 in Practice

Table of Contents

  • Application Scenarios
  • Prerequisite Knowledge
  • Code Example
  • Conclusion

Application Scenarios

Suppose we need to classify images into specific categories, for example:

  1. Classify images as cat, dog, wolf, etc.
  2. Classify images as Mercedes-Benz, BMW, or Audi
  3. ...

In practice, very few people train such a network from scratch; instead, the parameters of an already-trained network are reused and adapted to the new dataset. See the transfer-learning notes:

In practice, very few people train an entire Convolutional Network from scratch (with random initialization), because it is relatively rare to have a dataset of sufficient size. Instead, it is common to pretrain a ConvNet on a very large dataset (e.g. ImageNet, which contains 1.2 million images with 1000 categories), and then use the ConvNet either as an initialization or a fixed feature extractor for the task of interest.
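
The two uses named in the quote differ only in which variables the optimizer is allowed to update. A minimal sketch in TF 1.x with slim (the loss tensor and the 'vgg_16/fc8' variable scope are the ones used in the full example later in this post):

# Fixed feature extractor: update only the new classifier head (fc8);
# every pretrained conv/fc weight stays frozen.
fc8_variables = tf.contrib.framework.get_variables('vgg_16/fc8')
head_only_train_op = tf.train.GradientDescentOptimizer(1e-3).minimize(loss, var_list=fc8_variables)

# Using the ConvNet as an initialization: update all variables, typically
# with a much smaller learning rate so the pretrained weights move gently.
full_train_op = tf.train.GradientDescentOptimizer(1e-5).minimize(loss)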

Prerequisite Knowledge

  • The CNN convolution process
  • The TensorFlow API

To inspect the network structure, you can paste the caffemodel definition into a visualization tool; below is a reference for the first few layers of VGG16.

(figure: the first few layers of VGG16)
The earlier the layer, the simpler the patterns its activations respond to, so the easier it is to share. Taking the parameters of a network trained on ImageNet for 1000 classes, the first layers can be treated as already trained; only the final fc layer is replaced with one sized to the target number of classes.
If we are distinguishing cats from dogs, that fc layer has two classes, and it is the only layer that needs retraining from scratch.
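
To get a feel for the scale of what is retrained, a back-of-the-envelope sketch (fc7 in VGG16 has 4096 output units; fc8 is a plain linear layer of weights plus biases):

fc7_units = 4096
params_imagenet_head = fc7_units * 1000 + 1000   # 4,097,000 parameters in the original 1000-class fc8
params_cat_dog_head  = fc7_units * 2 + 2         # 8,194 parameters in a 2-class fc8

The retrained head is tiny next to the roughly 138 million parameters of the full VGG16, which is why this first stage trains quickly even without a GPU.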

Code Example

This example fine-tunes VGG16 with TensorFlow.
For how the convolution feature-map sizes change through the network, see the sketch below.
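
For reference, how the feature-map shapes evolve through the standard VGG16 on a 224x224x3 input (all convolutions are 3x3 with SAME padding, so only the pooling layers halve the spatial size):

# input           224 x 224 x 3
# conv1_1-2       224 x 224 x 64    -> pool1: 112 x 112 x 64
# conv2_1-2       112 x 112 x 128   -> pool2:  56 x  56 x 128
# conv3_1-3        56 x  56 x 256   -> pool3:  28 x  28 x 256
# conv4_1-3        28 x  28 x 512   -> pool4:  14 x  14 x 512
# conv5_1-3        14 x  14 x 512   -> pool5:   7 x   7 x 512
# fc6, fc7: 4096 units each; fc8: num_classes logits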

For the data preprocessing involved, see the neural-networks-2 notes:

Mean subtraction is the most common form of preprocessing. It involves subtracting the mean across every individual feature in the data, and has the geometric interpretation of centering the cloud of data around the origin along every dimension. In numpy, this operation would be implemented as: X -= np.mean(X, axis = 0). With images specifically, for convenience it can be common to subtract a single value from all pixels (e.g. X -= np.mean(X)), or to do so separately across the three color channels.
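
A small numpy sketch of both variants (the batch shape here is made up for illustration):

import numpy as np

X = np.random.rand(8, 224, 224, 3).astype(np.float32) * 255.0  # hypothetical image batch

# Per-feature mean subtraction, as in the quoted notes:
X_centered = X - np.mean(X, axis=0)

# Per-channel variant (one mean per color channel); the VGG_MEAN
# constant in the code below plays this role with fixed ImageNet means:
channel_means = X.mean(axis=(0, 1, 2))              # shape (3,)
X_vgg_style = X - channel_means.reshape(1, 1, 1, 3)
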
The code is as follows:

"""
#訓練好的參數(shù)http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz
目錄結(jié)構(gòu)
  train/
    貓/
      COCO_train2014_000000005785.jpg
      COCO_train2014_000000015870.jpg
    ??/
  val/
    貓/
    狗/
"""
import argparse
import os

import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorflow.contrib.slim.nets


parser = argparse.ArgumentParser()
# Training data directory
parser.add_argument('--train_dir', default='train')
# Validation data directory
parser.add_argument('--val_dir', default='val')
# Pretrained network weights to initialize from
parser.add_argument('--model_path', default='vgg_16.ckpt', type=str)
parser.add_argument('--batch_size', default=32, type=int)
parser.add_argument('--num_workers', default=4, type=int)
parser.add_argument('--num_epochs1', default=10, type=int)
parser.add_argument('--num_epochs2', default=10, type=int)
parser.add_argument('--learning_rate1', default=1e-3, type=float)
parser.add_argument('--learning_rate2', default=1e-5, type=float)
parser.add_argument('--dropout_keep_prob', default=0.5, type=float)
parser.add_argument('--weight_decay', default=5e-4, type=float)

# Per-channel ImageNet means (RGB) used to center the images
VGG_MEAN = [123.68, 116.78, 103.94]


def list_images(directory):
    labels = os.listdir(directory)
    files_and_labels = []
    for label in labels:
        for f in os.listdir(os.path.join(directory, label)):
            files_and_labels.append((os.path.join(directory, label, f), label))

    filenames, labels = zip(*files_and_labels)
    filenames = list(filenames)
    labels = list(labels)
    unique_labels = list(set(labels))

    label_to_int = {}
    for i, label in enumerate(unique_labels):
        label_to_int[label] = i

    labels = [label_to_int[l] for l in labels]

    return filenames, labels


def check_accuracy(sess, correct_prediction, is_training, dataset_init_op):
    # Initialize the correct dataset
    sess.run(dataset_init_op)
    num_correct, num_samples = 0, 0
    while True:
        try:
            correct_pred = sess.run(correct_prediction, {is_training: False})
            num_correct += correct_pred.sum()
            num_samples += correct_pred.shape[0]
        except tf.errors.OutOfRangeError:
            break

    acc = float(num_correct) / num_samples
    return acc


def main(args):
    # Get the training & validation filenames and labels
    train_filenames, train_labels = list_images(args.train_dir)
    val_filenames, val_labels = list_images(args.val_dir)

    num_classes = len(set(train_labels))


    graph = tf.Graph()
    with graph.as_default():
        # Read the image from disk and resize it
        def _parse_function(filename, label):
            image_string = tf.read_file(filename)
            image_decoded = tf.image.decode_jpeg(image_string, channels=3)          
            image = tf.cast(image_decoded, tf.float32)

            smallest_side = 256.0
            height, width = tf.shape(image)[0], tf.shape(image)[1]
            height = tf.to_float(height)
            width = tf.to_float(width)
            # Scale so the smallest side becomes 256
            scale = tf.cond(tf.greater(height, width),
                            lambda: smallest_side / width,
                            lambda: smallest_side / height)
            new_height = tf.to_int32(height * scale)
            new_width = tf.to_int32(width * scale)

            resized_image = tf.image.resize_images(image, [new_height, new_width])  # (2)
            return resized_image, label

        # Training preprocessing: random 224x224 crop, random flip, mean subtraction
        def training_preprocess(image, label):
            crop_image = tf.random_crop(image, [224, 224, 3])                       # (3)
            flip_image = tf.image.random_flip_left_right(crop_image)                # (4)

            means = tf.reshape(tf.constant(VGG_MEAN), [1, 1, 3])
            centered_image = flip_image - means                                     # (5)

            return centered_image, label

        # Validation preprocessing: take the central 224x224 crop, subtract the mean
        def val_preprocess(image, label):
            crop_image = tf.image.resize_image_with_crop_or_pad(image, 224, 224)    # (3)

            means = tf.reshape(tf.constant(VGG_MEAN), [1, 1, 3])
            centered_image = crop_image - means                                     # (4)

            return centered_image, label

        train_filenames = tf.constant(train_filenames)
        train_labels = tf.constant(train_labels)
        train_dataset = tf.contrib.data.Dataset.from_tensor_slices((train_filenames, train_labels))
        train_dataset = train_dataset.map(_parse_function,
            num_threads=args.num_workers, output_buffer_size=args.batch_size)
        train_dataset = train_dataset.map(training_preprocess,
            num_threads=args.num_workers, output_buffer_size=args.batch_size)
        train_dataset = train_dataset.shuffle(buffer_size=10000) 
        batched_train_dataset = train_dataset.batch(args.batch_size)


        val_filenames = tf.constant(val_filenames)
        val_labels = tf.constant(val_labels)
        val_dataset = tf.contrib.data.Dataset.from_tensor_slices((val_filenames, val_labels))
        val_dataset = val_dataset.map(_parse_function,
            num_threads=args.num_workers, output_buffer_size=args.batch_size)
        val_dataset = val_dataset.map(val_preprocess,
            num_threads=args.num_workers, output_buffer_size=args.batch_size)
        batched_val_dataset = val_dataset.batch(args.batch_size)


        # Iterator that yields batches of images & labels
        iterator = tf.contrib.data.Iterator.from_structure(batched_train_dataset.output_types,
                                                           batched_train_dataset.output_shapes)
        images, labels = iterator.get_next()
        
        # Ops that initialize the iterator on either dataset
        train_init_op = iterator.make_initializer(batched_train_dataset)
        val_init_op = iterator.make_initializer(batched_val_dataset)

        # Fed to the VGG16 network to switch between inference and training behavior
        is_training = tf.placeholder(tf.bool)
      
        vgg = tf.contrib.slim.nets.vgg
        with slim.arg_scope(vgg.vgg_arg_scope(weight_decay=args.weight_decay)):
            # Use the network wrapped by TF-Slim, setting the number of output classes
            logits, _ = vgg.vgg_16(images, num_classes=num_classes, is_training=is_training,
                                   dropout_keep_prob=args.dropout_keep_prob)

        model_path = args.model_path
        assert(os.path.isfile(model_path))

        # Restore all pretrained weights except fc8
        variables_to_restore = tf.contrib.framework.get_variables_to_restore(exclude=['vgg_16/fc8'])
        init_fn = tf.contrib.framework.assign_from_checkpoint_fn(model_path, variables_to_restore)

        # Initializer for the fc8 variables only
        fc8_variables = tf.contrib.framework.get_variables('vgg_16/fc8')
        fc8_init = tf.variables_initializer(fc8_variables)

        # The loss is added to the tf.GraphKeys.LOSSES collection
        tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
        loss = tf.losses.get_total_loss()

        # First train only the fc8 layer's parameters
        fc8_optimizer = tf.train.GradientDescentOptimizer(args.learning_rate1)
        fc8_train_op = fc8_optimizer.minimize(loss, var_list=fc8_variables)

        # Then train the whole network
        full_optimizer = tf.train.GradientDescentOptimizer(args.learning_rate2)
        full_train_op = full_optimizer.minimize(loss)

        # Model evaluation
        prediction = tf.to_int32(tf.argmax(logits, 1))
        correct_prediction = tf.equal(prediction, labels)
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        tf.get_default_graph().finalize()

    with tf.Session(graph=graph) as sess:
        # Load the conv1-fc7 weights from the checkpoint
        init_fn(sess)
        # Initialize the new fc8 weights from scratch
        sess.run(fc8_init)

        # First phase: train only the fc8 layer
        for epoch in range(args.num_epochs1):
            print('Starting epoch %d / %d' % (epoch + 1, args.num_epochs1))
            sess.run(train_init_op)
            while True:
                try:
                    # The filenames and labels are fed through the iterator
                    _ = sess.run(fc8_train_op, {is_training: True})
                except tf.errors.OutOfRangeError:
                    break

            train_acc = check_accuracy(sess, correct_prediction, is_training, train_init_op)
            val_acc = check_accuracy(sess, correct_prediction, is_training, val_init_op)
            print('Train accuracy: %f' % train_acc)
            print('Val accuracy: %f\n' % val_acc)

        # Second phase: fine-tune the whole network
        for epoch in range(args.num_epochs2):
            print('Starting epoch %d / %d' % (epoch + 1, args.num_epochs2))
            sess.run(train_init_op)
            while True:
                try:
                    _ = sess.run(full_train_op, {is_training: True})
                except tf.errors.OutOfRangeError:
                    break

            train_acc = check_accuracy(sess, correct_prediction, is_training, train_init_op)
            val_acc = check_accuracy(sess, correct_prediction, is_training, val_init_op)
            print('Train accuracy: %f' % train_acc)
            print('Val accuracy: %f\n' % val_acc)


if __name__ == '__main__':
    args = parser.parse_args()
    main(args)
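
Assuming the script is saved as finetune_vgg16.py (a name chosen here for illustration), the checkpoint has been downloaded and extracted, and train/ and val/ follow the layout in the header comment, it can be run with:

python finetune_vgg16.py --train_dir train --val_dir val --model_path vgg_16.ckpt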

The vgg_16 function ships in TensorFlow's slim library; its prototype is:

def vgg_16(inputs,
           num_classes=1000,
           is_training=True,
           dropout_keep_prob=0.5,
           spatial_squeeze=True,
           scope='vgg_16'):
  """Oxford Net VGG 16-Layers version D Example.
  Note: All the fully_connected layers have been transformed to conv2d layers.
        To use in classification mode, resize input to 224x224.
  Args:
    inputs: a tensor of size [batch_size, height, width, channels].
    num_classes: number of predicted classes.
    is_training: whether or not the model is being trained.
    dropout_keep_prob: the probability that activations are kept in the dropout
      layers during training.
    spatial_squeeze: whether or not should squeeze the spatial dimensions of the
      outputs. Useful to remove unnecessary dimensions for classification.
    scope: Optional scope for the variables.
  Returns:
    the last op containing the log predictions and end_points dict.
  """
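
A minimal standalone usage sketch of this prototype (TF 1.x, reusing the slim and vgg names imported in the script above; the placeholder shape and class count are arbitrary):

images = tf.placeholder(tf.float32, [None, 224, 224, 3])
with slim.arg_scope(vgg.vgg_arg_scope()):
    logits, end_points = vgg.vgg_16(images, num_classes=5, is_training=False)
# With spatial_squeeze=True (the default), logits has shape [batch_size, 5].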

This example does not require a GPU; it runs fine on macOS.


Conclusion

In practice, engineers rarely design new network architectures, and seldom even modify one heavily; but understanding the network structure and how the loss converges makes it much easier to apply transfer learning to a specific scenario.
