最近看到一個巨牛的人工智能教程,分享一下給大家。教程不僅是零基礎(chǔ),通俗易懂,而且非常風趣幽默,像看小說一樣!覺得太牛了,所以分享給大家。平時碎片時間可以當小說看,【點這里可以去膜拜一下大神的“小說”】。
1 CKPT模型轉(zhuǎn)換pb文件
使用上一篇博客《MobileNet V1官方預(yù)訓練模型的使用》中下載的MobileNet V1官方預(yù)訓練的模型《MobileNet_v1_1.0_192》。雖然打包下載的文件中包含已經(jīng)轉(zhuǎn)換過的pb
文件,但是官方提供的pb
模型輸出是1001
類別對應(yīng)的概率,我們需要的是概率最大的3類。可在原始網(wǎng)絡(luò)中使用函數(shù)tf.nn.top_k
獲取概率最大的3類,將函數(shù)tf.nn.top_k
作為網(wǎng)絡(luò)中的一個計算節(jié)點。模型轉(zhuǎn)換代碼如下所示。
import tensorflow as tf
from mobilenet_v1 import mobilenet_v1,mobilenet_v1_arg_scope
import numpy as np
slim = tf.contrib.slim
CKPT = 'mobilenet_v1_1.0_192.ckpt'
def build_model(inputs):
with slim.arg_scope(mobilenet_v1_arg_scope(is_training=False)):
logits, end_points = mobilenet_v1(inputs, is_training=False, depth_multiplier=1.0, num_classes=1001)
scores = end_points['Predictions']
print(scores)
#取概率最大的5個類別及其對應(yīng)概率
output = tf.nn.top_k(scores, k=3, sorted=True)
#indices為類別索引,values為概率值
return output.indices,output.values
def load_model(sess):
loader = tf.train.Saver()
loader.restore(sess,CKPT)
inputs=tf.placeholder(dtype=tf.float32,shape=(1,192,192,3),name='input')
classes_tf,scores_tf = build_model(inputs)
classes = tf.identity(classes_tf, name='classes')
scores = tf.identity(scores_tf, name='scores')
with tf.Session() as sess:
load_model(sess)
graph = tf.get_default_graph()
output_graph_def = tf.graph_util.convert_variables_to_constants(
sess, graph.as_graph_def(), [classes.op.name,scores.op.name])
tf.train.write_graph(output_graph_def, 'model', 'mobilenet_v1_1.0_192.pb', as_text=False)
上面代碼中,單一的所有類別概率經(jīng)過計算節(jié)點tf.nn.top_k
后分為兩個輸出:概率最大的3個類別classes
,概率最大的3個類別的概率scores
。執(zhí)行上面代碼后,在目錄“model”
中得到文件mobilenet_v1_1.0_192.pb
。
2 移植到Android中
2.1 AndroidStudio中使用Tensorflow Mobile
首先,AndroidStudio
版本必須是3.0
及以上。創(chuàng)建Android Project
后,在Module:app
的build.gradle
文件中的dependencies中加入如下:
compile 'org.tensorflow:tensorflow-android:+'
2.2 Tensorflow Mobile接口
使用Tensorflow Mobile庫中模型調(diào)用封裝類org.tensorflow.contrib.android.TensorFlowInferenceInterface
完成模型的調(diào)用,主要使用的如下函數(shù)。
public TensorFlowInferenceInterface(AssetManager assetManager, String model){...}
public void feed(String inputName, float[] src, long... dims) {...}
public void run(String[] outputNames) {...}
public void fetch(String outputName, int[] dst) {...}
其中,構(gòu)造函數(shù)中的參數(shù)model
表示目錄“assets”
中模型名稱。feed
函數(shù)中參數(shù)inputName
表示輸入節(jié)點的名稱,即對應(yīng)模型轉(zhuǎn)換時指定輸入節(jié)點的名稱“input”
,參數(shù)src
表示輸入數(shù)據(jù)數(shù)組,變長參數(shù)dims
表示輸入的維度,如傳入1,192,192,3
則表示輸入數(shù)據(jù)的Shape=[1,192,192,3]
。函數(shù)run
的參數(shù)outputNames
表示執(zhí)行從輸入節(jié)點到outputNames
中節(jié)點的所有路徑。函數(shù)fetch
中參數(shù)outputName
表示輸出節(jié)點的名稱,將指定的輸出節(jié)點的數(shù)據(jù)拷貝到dst
中。
2.3 Bitmap對象轉(zhuǎn)float[]
注意到,在2.1小節(jié)中函數(shù)feed
傳入到輸入節(jié)點的數(shù)據(jù)對象是float[]
。因此有必要將Bitmap
轉(zhuǎn)為float[]
對象,示例代碼如下所示。
//讀取Bitmap像素值,并放入到浮點數(shù)數(shù)組中。歸一化到[-1,1]
private float[] getFloatImage(Bitmap bitmap){
Bitmap bm = getResizedBitmap(bitmap,inputWH,inputWH);
bm.getPixels(inputIntData, 0, bm.getWidth(), 0, 0, bm.getWidth(), bm.getHeight());
for (int i = 0; i < inputIntData.length; ++i) {
final int val = inputIntData[i];
inputFloatData[i * 3 + 0] =(float) (((val >> 16) & 0xFF)/255.0-0.5)*2;
inputFloatData[i * 3 + 1] = (float)(((val >> 8) & 0xFF)/255.0-0.5)*2;
inputFloatData[i * 3 + 2] = (float)(( val & 0xFF)/255.0-0.5)*2 ;
}
return inputFloatData;
}
由于MobileNet V1
預(yù)訓練的模型輸入數(shù)據(jù)歸一化到[-1,1]
,因此在函數(shù)getFloatImage
中轉(zhuǎn)換數(shù)據(jù)的同時將數(shù)據(jù)歸一化到[-1,1]
。
2.4 封裝模型調(diào)用
為了便于調(diào)用,將與模型相關(guān)的調(diào)用函數(shù)封裝到類TFModelUtils
中,通過TFModelUtils
的run
函數(shù)完成模型的調(diào)用,示例代碼如下所示。
package com.huachao.mn_v1_192;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.Matrix;
import android.util.Log;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.HashMap;
import java.util.Map;
public class TFModelUtils {
private TensorFlowInferenceInterface inferenceInterface;
private int[] inputIntData ;
private float[] inputFloatData ;
private int inputWH;
private String inputName;
private String[] outputNames;
private Map<Integer,String> dict;
public TFModelUtils(AssetManager assetMngr,int inputWH,String inputName,String[]outputNames,String modelName){
this.inputWH=inputWH;
this.inputName=inputName;
this.outputNames=outputNames;
this.inputIntData=new int[inputWH*inputWH];
this.inputFloatData = new float[inputWH*inputWH*3];
//從assets目錄加載模型
inferenceInterface= new TensorFlowInferenceInterface(assetMngr, modelName);
this.loadLabel(assetMngr);
}
public Map<String,Object> run(Bitmap bitmap){
float[] inputData = getFloatImage(bitmap);
//將輸入數(shù)據(jù)復(fù)制到TensorFlow中,指定輸入Shape=[1,INPUT_WH,INPUT_WH,3]
inferenceInterface.feed(inputName, inputData, 1, inputWH, inputWH, 3);
// 執(zhí)行模型
inferenceInterface.run( outputNames );
//將輸出Tensor對象復(fù)制到指定數(shù)組中
int[] classes=new int[3];
float[] scores=new float[3];
inferenceInterface.fetch(outputNames[0], classes);
inferenceInterface.fetch(outputNames[1], scores);
Map<String,Object> results=new HashMap<>();
results.put("scores",scores);
String[] classesLabel = new String[3];
for(int i =0;i<3;i++){
int idx=classes[i];
classesLabel[i]=dict.get(idx);
// System.out.printf("classes:"+dict.get(idx)+",scores:"+scores[i]+"\n");
}
results.put("classes",classesLabel);
return results;
}
//讀取Bitmap像素值,并放入到浮點數(shù)數(shù)組中。歸一化到[-1,1]
private float[] getFloatImage(Bitmap bitmap){
Bitmap bm = getResizedBitmap(bitmap,inputWH,inputWH);
bm.getPixels(inputIntData, 0, bm.getWidth(), 0, 0, bm.getWidth(), bm.getHeight());
for (int i = 0; i < inputIntData.length; ++i) {
final int val = inputIntData[i];
inputFloatData[i * 3 + 0] =(float) (((val >> 16) & 0xFF)/255.0-0.5)*2;
inputFloatData[i * 3 + 1] = (float)(((val >> 8) & 0xFF)/255.0-0.5)*2;
inputFloatData[i * 3 + 2] = (float)(( val & 0xFF)/255.0-0.5)*2 ;
}
return inputFloatData;
}
//對圖像做Resize
public Bitmap getResizedBitmap(Bitmap bm, int newWidth, int newHeight) {
int width = bm.getWidth();
int height = bm.getHeight();
float scaleWidth = ((float) newWidth) / width;
float scaleHeight = ((float) newHeight) / height;
Matrix matrix = new Matrix();
matrix.postScale(scaleWidth, scaleHeight);
Bitmap resizedBitmap = Bitmap.createBitmap( bm, 0, 0, width, height, matrix, false);
bm.recycle();
return resizedBitmap;
}
private void loadLabel( AssetManager assetManager ) {
dict=new HashMap<>();
try {
InputStream stream = assetManager.open("label.txt");
InputStreamReader isr=new InputStreamReader(stream);
BufferedReader br=new BufferedReader(isr);
String line;
while((line=br.readLine())!=null){
line=line.trim();
String[] arr = line.split(",");
if(arr.length!=2)
continue;
int key=Integer.parseInt(arr[0]);
String value = arr[1];
dict.put(key,value);
}
}catch (Exception e){
e.printStackTrace();
Log.e("ERROR",e.getMessage());
}
}
}