TensorFlow-4: tf.contrib.learn 快速入門

學習資料:
https://www.tensorflow.org/get_started/tflearn

相應的中文翻譯:
http://studyai.site/2017/03/05/%E3%80%90Tensorflow%20r1.0%20%E6%96%87%E6%A1%A3%E7%BF%BB%E8%AF%91%E3%80%91%E3%80%90tf.contrib.learn%E5%BF%AB%E9%80%9F%E5%85%A5%E9%97%A8%E3%80%91/


今天學習用 tf.contrib.learn 來建立 DNN 對 Iris 數據集進行分類.

問題:
我們有 Iris 數據集,它包含150個樣本數據,分別來自三個品種,每個品種有50個樣本,每個樣本具有四個特征,以及它屬于哪一類,分別由 0,1,2 代表三個品種。
我們將這150個樣本分為兩份,一份是訓練集具有120個樣本,另一份是測試集具有30個樣本。
我們要做的就是建立一個神經網絡分類模型對每個樣本進行分類,識別它是哪個品種。

一共有 5 步:

  • 導入 CSV 格式的數據集
  • 建立神經網絡分類模型
  • 用訓練數據集訓練模型
  • 評價模型的準確率
  • 對新樣本數據進行分類

代碼:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import urllib

import numpy as np
import tensorflow as tf

# Data sets
IRIS_TRAINING = "iris_training.csv"
IRIS_TRAINING_URL = "http://download.tensorflow.org/data/iris_training.csv"

IRIS_TEST = "iris_test.csv"
IRIS_TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"

def main():
  # If the training and test sets aren't stored locally, download them.
  if not os.path.exists(IRIS_TRAINING):
    raw = urllib.urlopen(IRIS_TRAINING_URL).read()
    with open(IRIS_TRAINING, "w") as f:
      f.write(raw)

  if not os.path.exists(IRIS_TEST):
    raw = urllib.urlopen(IRIS_TEST_URL).read()
    with open(IRIS_TEST, "w") as f:
      f.write(raw)

  # Load datasets.
  training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
      filename=IRIS_TRAINING,
      target_dtype=np.int,
      features_dtype=np.float32)
  test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
      filename=IRIS_TEST,
      target_dtype=np.int,
      features_dtype=np.float32)

  # Specify that all features have real-value data
  feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

  # Build 3 layer DNN with 10, 20, 10 units respectively.
  classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                              hidden_units=[10, 20, 10],
                                              n_classes=3,
                                              model_dir="/tmp/iris_model")
  # Define the training inputs
  def get_train_inputs():
    x = tf.constant(training_set.data)
    y = tf.constant(training_set.target)

    return x, y

  # Fit model.
  classifier.fit(input_fn=get_train_inputs, steps=2000)

  # Define the test inputs
  def get_test_inputs():
    x = tf.constant(test_set.data)
    y = tf.constant(test_set.target)

    return x, y

  # Evaluate accuracy.
  accuracy_score = classifier.evaluate(input_fn=get_test_inputs,
                                       steps=1)["accuracy"]

  print("\nTest Accuracy: {0:f}\n".format(accuracy_score))

  # Classify two new flower samples.
  def new_samples():
    return np.array(
      [[6.4, 3.2, 4.5, 1.5],
       [5.8, 3.1, 5.0, 1.7]], dtype=np.float32)

  predictions = list(classifier.predict(input_fn=new_samples))

  print(
      "New Samples, Class Predictions:    {}\n"
      .format(predictions))

if __name__ == "__main__":
    main()

從代碼可以看出很簡短的幾行就可以完成之前學過的很長的代碼所做的事情,用起來和用 sklearn 相似。

關于 tf.contrib.learn 可以查看:
https://www.tensorflow.org/api_guides/python/contrib.learn

可以看到里面也有 kmeans,logistic,linear 等模型:


在上面的代碼中:

  • tf.contrib.learn.datasets.base.load_csv_with_header 可以導入 CSV 數據集。
  • 分類器模型只需要一行代碼,就可以設置這個模型具有多少隱藏層,每個隱藏層有多少神經元,以及最后分為幾類。
  • 模型的訓練也是只需要一行代碼,輸入指定的數據,包括特征和標簽,再指定迭代的次數,就可以進行訓練。
  • 獲得準確率也同樣很簡單,只需要輸入測試集,調用 evaluate。
  • 預測新的數據集,只需要把新的樣本數據傳遞給 predict。

關于代碼里幾個新的方法:

1. load_csv_with_header():

用于導入 CSV,需要三個必需的參數:

  • filename,CSV文件的路徑
  • target_dtype,數據集的目標值的numpy數據類型。
  • features_dtype,數據集的特征值的numpy數據類型。

在這里,target 是花的品種,它是一個從 0-2 的整數,所以對應的numpy數據類型是np.int

2. tf.contrib.layers.real_valued_column:

所有的特征數據都是連續的,因此用 tf.contrib.layers.real_valued_column,數據集中有四個特征(萼片寬度,萼片高度,花瓣寬度和花瓣高度),因此 dimension=4 。

feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

3. DNNClassifier:

  • feature_columns=feature_columns, 上面定義的一組特征
  • hidden_units=[10, 20, 10],三個隱藏層分別包含10,20,10個神經元。
  • n_classes=3,三個目標類,代表三個 Iris 品種。
  • model_dir=/tmp/iris_model,TensorFlow在模型訓練期間將保存 checkpoint data。

在后面會學到關于 TensorFlow 的 logging and monitoring 的章節,可以 track 一下訓練中的模型: “Logging and Monitoring Basics with tf.contrib.learn”。


推薦閱讀 歷史技術博文鏈接匯總
http://www.lxweimin.com/p/28f02bb59fe5
也許可以找到你想要的

最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內容