前面提到了用CNN來做OCR。這篇文章介紹另一種做OCR的方法,就是通過LSTM+CTC。這種方法的好處是他可以事先不用知道一共有幾個字符需要識別。之前我試過不用CTC,只用LSTM,效果一直不行,后來下決心加上CTC,效果一下就上去了。
CTC是序列標志的一個重要算法,它主要解決了label對齊的問題。有很多實現(xiàn)。百度IDL在16年初公開了一個GPU的實現(xiàn),號稱速度比之前的theano-ctc, stanford-ctc都要快。Mxnet目前還沒有ctc的實現(xiàn),因此決定吧warpctc集成進mxnet。
根據(jù)issue里作者們的建議,決定和集成torch一樣,寫一個plugin,因此C++代碼放在plugin/warpctc目錄中。整個集成任務其實就是寫一個wrapctc的op。代碼在 plugin/warpctc/warpctc-inl.h.
CTC這一層其實和SoftmaxOutput很像。其實他們的forward的實現(xiàn)就是一模一樣的。唯一的差別就是backward中g(shù)rad的實現(xiàn),在這里需要調(diào)用warpctc的compute_ctc_loss函數(shù)來計算梯度。實際上warpctc的主要接口也就是這個函數(shù)。
下面說說具體怎么用lstm+ctc來做ocr的任務。詳細的代碼在 examples/warpctc/lstm_ocr.py。這里只說說大體思路。
假設我們要解決的是4位數(shù)字的識別,圖片是80*30的圖片。那么我們就將每張圖片按列切分成80個30維的向量。然后作為一個lstm的80個輸入。一個lstm的輸出和輸入數(shù)目應該是相同的。而我們的預測目標卻只有4個數(shù)字。而不是80個數(shù)字。在沒有用ctc時我想了兩個解決方案。第一個是用encode-decode模式。也就是80個輸入做encode,然后decode成4個輸出。實測效果很挫。第二個是把4個label每個copy20遍,從而變成80個label。實測也很挫。沒辦法,最后只能用ctc loss了。
用ctc loss的體會就是,如果input的長度遠遠大于label的長度,比如我這里是80和4的關(guān)系。那么一開始的收斂會比較慢。在其中有一段時間cost幾乎不變。此刻一定要有耐心,最終一定會收斂的。在ocr識別的這個例子上最終可以收斂到95%的精度。
目前代碼還在等待merge。pull request。
---------------
歡迎關(guān)注 微信公眾號【ResysChina】