用 TensorFlow.js 在瀏覽器中訓練神經網絡

本文結構:

  1. 什么是 TensorFlow.js
  2. 為什么要在瀏覽器中運行機器學習算法
  3. 應用舉例:regression
  4. 和 tflearn 的代碼比較

1. 什么是 TensorFlow.js

TensorFlow.js 是一個開源庫,不僅可以在瀏覽器中運行機器學習模型,還可以訓練模型。
具有 GPU 加速功能,并自動支持 WebGL
可以導入已經訓練好的模型,也可以在瀏覽器中重新訓練現有的所有機器學習模型
運行 Tensorflow.js 只需要你的瀏覽器,而且在本地開發的代碼與發送給用戶的代碼是相同的。

TensorFlow.js 對未來 web 開發有著重要的影響,JS 開發者可以更容易地實現機器學習,工程師和數據科學家們可以有一種新的方法來訓練算法,例如官網上 Emoji Scavenger Hunt 這樣的游戲界面,讓用戶一邊玩游戲一邊將模型訓練地更好。

用 Tensorflow.js 可以做很多事情,
例如 object detection in images, speech recognition, music composition,
而且 不需要安裝任何庫,也不用一次又一次地編譯這些代碼。


2. 為什么要在瀏覽器中運行機器學習算法

TensorFlow.js 可以為用戶解鎖巨大價值:

  1. 隱私:用戶端的機器學習,用來訓練模型的數據還有模型的使用都在用戶的設備上完成,這意味著不需要把數據傳送或存儲在服務器上。
  2. 更廣泛的使用:幾乎每個電腦手機平板上都有瀏覽器,并且幾乎每個瀏覽器都可以運行JS,無需下載或安裝任何應用程序,在瀏覽器中就可以運行機器學習框架來實現更高的用戶轉換率,提高滿意度,例如虛擬試衣間等服務。
  3. 分布式計算:每次用戶使用系統時,他都是在自己的設備上運行機器學習算法,之后新的數據點將被推送到服務器來幫助改進模型,那么未來的用戶就可以使用訓練的更好的算法了,這樣可以減少訓練成本,并且持續訓練模型。

3. 應用舉例:regression

為了很快地看看效果,有下面三種方式:

  1. 可以直接從瀏覽器里寫代碼,例如 chrome 的 View > Developer > Javascript Console,
  2. 還可以在線寫
    有三個流行的在線 JS 平臺:CodePen, JSFiddle, JSBin.
    https://codepen.io/thekevinscott/pen/aGapZL
    https://jsfiddle.net/
    https://jsbin.com/?html,output
  3. 當然還可以在本地把代碼保存為.html文件并用瀏覽器打開

那么先來看一下下面這段代碼,可以在 codepen 中運行:
https://codepen.io/pen?&editors=1011

這段代碼的目的是做個回歸預測,

數據集為:
構造符合 Y=2X-1 的幾個點,
那么當 X 取 [-1, 0, 1, 2, 3, 4] 時,
y 為 [-3, -1, 1, 3, 5, 7],

<html>

 <head>
    <!-- Load TensorFlow.js -->
    <!-- Get latest version at https://github.com/tensorflow/tfjs -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.11.2">   
    </script>
 </head>
 
 <body>
   <div id="output_field"></div>
 </body>
 
 <script>
    async function learnLinear(){
    
        const model = tf.sequential();
        model.add(tf.layers.dense({
            units: 1, 
            inputShape: [1]
        }));
        
        model.compile({
            loss: 'meanSquaredError',
            optimizer: 'sgd'
        });
  
        const xs = tf.tensor2d([-1, 0, 1, 2, 3, 4], [6, 1]);
        const ys = tf.tensor2d([-3, -1, 1, 3, 5, 7], [6, 1]);
  
        await model.fit(xs, ys, {epochs: 500});
  
        document.getElementById('output_field').innerText =
            model.predict( tf.tensor2d([10], [1, 1]) );
    }
    
    learnLinear();
 </script>
 
<html>
  • 首先是熟悉的 js 的基礎結構:
<html>
<head></head>
<body></body>
</html>
        const model = tf.sequential();
        model.add(tf.layers.dense({
            units: 1, 
            inputShape: [1]
        }));
  • 接著定義 loss 為 MSE 和 optimizer 為 SGD:
        model.compile({
            loss: 'meanSquaredError',
            optimizer: 'sgd'
        });
  • 同時需要定義 input 的 tensor,X 和 y,以及它們的維度都是 [6, 1]:
        const xs = tf.tensor2d([-1, 0, 1, 2, 3, 4], [6, 1]);
        const ys = tf.tensor2d([-3, -1, 1, 3, 5, 7], [6, 1]);
  • 然后用 fit 來訓練模型,因為要等模型訓練完才能預測,所以要用 await:
        await model.fit(xs, ys, {epochs: 500});
  • 訓練結束后,用 predict 進行預測,輸入的是 [1, 1] 維的 值為 10 的tensor ,
        document.getElementById('output_field').innerText =
            model.predict( tf.tensor2d([10], [1, 1]) );
  • 最后得到的輸出為
Tensor 
[[18.9862976],]

4. 和 tflearn 的代碼比較

再來通過一個簡單的例子來比較一下 Tensorflow.js 和 tflearn,
可以看出如果熟悉 tflearn 的話,那么 Tensorflow.js 會非常容易上手,


學習資料:
https://medium.com/tensorflow/getting-started-with-tensorflow-js-50f6783489b2
https://thekevinscott.com/reasons-for-machine-learning-in-the-browser/
https://www.analyticsvidhya.com/blog/2018/04/tensorflow-js-build-machine-learning-models-javascript/
https://hackernoon.com/introducing-tensorflow-js-3f31d70f5904
https://thekevinscott.com/tensorflowjs-hello-world/


推薦閱讀 歷史技術博文鏈接匯總
http://www.lxweimin.com/p/28f02bb59fe5
也許可以找到你想要的:
[入門問題][TensorFlow][深度學習][強化學習][神經網絡][機器學習][自然語言處理][聊天機器人]

?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內容