端到端的OCR:LSTM+CTC的實現(xiàn)

前面提到了用CNN來做OCR。這篇文章介紹另一種做OCR的方法,就是通過LSTM+CTC。這種方法的好處是他可以事先不用知道一共有幾個字符需要識別。之前我試過不用CTC,只用LSTM,效果一直不行,后來下決心加上CTC,效果一下就上去了。

CTC是序列標志的一個重要算法,它主要解決了label對齊的問題。有很多實現(xiàn)。百度IDL在16年初公開了一個GPU的實現(xiàn),號稱速度比之前的theano-ctc, stanford-ctc都要快。Mxnet目前還沒有ctc的實現(xiàn),因此決定吧warpctc集成進mxnet。

根據(jù)issue里作者們的建議,決定和集成torch一樣,寫一個plugin,因此C++代碼放在plugin/warpctc目錄中。整個集成任務其實就是寫一個wrapctc的op。代碼在 plugin/warpctc/warpctc-inl.h.

CTC這一層其實和SoftmaxOutput很像。其實他們的forward的實現(xiàn)就是一模一樣的。唯一的差別就是backward中g(shù)rad的實現(xiàn),在這里需要調(diào)用warpctc的compute_ctc_loss函數(shù)來計算梯度。實際上warpctc的主要接口也就是這個函數(shù)。

下面說說具體怎么用lstm+ctc來做ocr的任務。詳細的代碼在 examples/warpctc/lstm_ocr.py。這里只說說大體思路。

假設我們要解決的是4位數(shù)字的識別,圖片是80*30的圖片。那么我們就將每張圖片按列切分成80個30維的向量。然后作為一個lstm的80個輸入。一個lstm的輸出和輸入數(shù)目應該是相同的。而我們的預測目標卻只有4個數(shù)字。而不是80個數(shù)字。在沒有用ctc時我想了兩個解決方案。第一個是用encode-decode模式。也就是80個輸入做encode,然后decode成4個輸出。實測效果很挫。第二個是把4個label每個copy20遍,從而變成80個label。實測也很挫。沒辦法,最后只能用ctc loss了。

用ctc loss的體會就是,如果input的長度遠遠大于label的長度,比如我這里是80和4的關(guān)系。那么一開始的收斂會比較慢。在其中有一段時間cost幾乎不變。此刻一定要有耐心,最終一定會收斂的。在ocr識別的這個例子上最終可以收斂到95%的精度。

目前代碼還在等待merge。pull request

---------------

歡迎關(guān)注 微信公眾號【ResysChina】

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內(nèi)容