這幾天在看Spatial Transformer Networks 空間變換網絡,該網絡的結構圖如下所示,STN由三部分組成:分別為localization network,grid generator 和 bilinear sampler組成。
其中U和V分別為輸入/輸出圖片(也可以是feature map)
Localization network: 局部網絡,一般由全連接或者卷積神經網絡加上回歸層構成,輸入為U, 輸出為空間變換矩陣θ,如果是仿射變換,就是6個神經元,如果是射影變換就是8個神經元,可以參考這兩篇博客:Creating a Gallery of Transformed Images以及圖像的等距變換,相似變換,仿射變換,射影變換及其matlab實現。
grid generator:通過下面的變換,利用空間變換矩陣θ 在輸入特征圖上產生輸出特征圖像素應該被采樣的坐標點。
bilinear sampler:在輸入特征圖上,結合上一步產生的坐標,在對應位置進行雙線性插值,得到輸出特征圖對應點的像素值!
下面是STN的TensorFlow版本實現,這里簡單記錄一下編譯的過程:
1、下載 mnist_cluttered_60x60_6distortions.npz 數據集,放在data目錄下;
2、在utils文件夾下創建 __init__.py文件(空文件)即可,這樣就將utils文件夾變成了Python模塊,否則會提醒找不到 data_utils一系列模塊。
3、修改main.py中root_dir、logs_dir、save_dir 和 vis_path的路徑;
4、在終端下進入工程:命令行輸入python main.py 回車,程序開始運行。這時打開utils文件夾,會看到生成了4個對應的 .pyc文件。
5、如果是在GPU上運行代碼,可能還會出現顯示的問題如下:同樣在main.py文件import matplotlib.pyplot as plt 后邊添加 plt.switch_backend(‘agg’)即可。
6、用tensorboard進行可視化,將日志的地址指向程序日志輸出的地址(/spatial_transformer_network/logs),進入工程目錄如下輸入命令:
可以看到服務器端口6006已經在使用中,所以使用其它端口,復制 http://0.0.0.0:7001 到瀏覽器,就可以看到可視化圖表了:
點擊進入GRAPHS欄,可以看到上面程序TensorFlow計算圖的可視化結果。
7、程序運行結束后,可以看到生成了對應的文件:
8、最后對第一下CPU和GPU的運行速度:可以看到兩個速度差了6倍左右!
PS: 不知道是不是由于編輯器的問題,我的代碼在終端可以訓練,但是在本地不能單步調試,進入不了主函數。因為在調試過程中發現__name__ = {str}'main',和平常的類型不一致(__name__ = {str}'__main__'),最后將 if__name__ =='__main__': 修改為:if __name__ == 'main': 就可以了。
簡單記錄一下,防止日后忘了,現在我要去分析結果啦!