Pytorch 上的端到端語音識別
基于 Transformer 的語音識別模型
開源地址
https://github.com/gentaiscool/end2end-asr-pytorch
簡介
自從十年前采用基于深度神經網絡 (DNN)的混合建模以來,自動語音識別 (ASR) 的準確率得到了顯著提高。這種突破主要是使用DNN代替?zhèn)鹘y(tǒng)的高斯混合模型進行聲學似然評估,同時保留聲學模型、語言模型和詞典模型等所有模塊,進而組成了混合ASR系統(tǒng)。最近,語音社區(qū)通過從混合建模過渡到端到端(E2E)建模有了新的突破,新方案使用單個網絡將輸入語音序列直接轉換為輸出標記序列。這樣的突破更具革命性,因為它推翻了傳統(tǒng)ASR系統(tǒng)中已經使用了幾十年的模塊式建模。
端到端模型比傳統(tǒng)的混合模型有幾個主要優(yōu)點:
首先,端到端模型使用與ASR目標一致的單一目標函數(shù)來優(yōu)化整個網絡,而傳統(tǒng)的混合模型單獨優(yōu)化每個模塊,無法保證全局最優(yōu)。并且,端到端模型已被證明不論在學術界還是在工業(yè)界都優(yōu)于傳統(tǒng)的混合模型。
其次,由于端到端模型直接輸出字符甚至單詞,大大簡化了語音識別流程。相比之下,傳統(tǒng)混合模型的設計復雜,需要大量ASR專家經驗知識。
最后,由于端到端模型采用單一網絡,比傳統(tǒng)的混合模型更加緊湊,因此,端到端模型可以部署到高精度、低延遲的設備上。
隨著深度神經網絡的發(fā)展和硬件算力支持,基于RNN,DCNN,attenton 和transformer等神經網絡也逐漸開始在語音識別應用,并得到好的效果。
這里就是基于一種低秩結構 low-rank的transformer ,就是將attention 的keys和values的長度維投影到較低維的表示形式,從而改善了transformer在內存的存儲復雜度和提高了計算效率。此方法減少了沖過50%的神經網絡參數(shù),比baseline的transformer模型提高了1.35x的速度。同時實驗說明了LRT model 在測試集獲得了更好的性能表現(xiàn)。 LRT在現(xiàn)存的一些數(shù)據(jù)集上表現(xiàn)更佳,在不用外部語言模型火聲學數(shù)據(jù)的情況下。
部署
這里先使用docker 部署,后期可以轉到本機或帶聲音的設備
Dockerfile
FROM pytorch/pytorch:1.4-cuda10.1-cudnn7-devel
RUN apt-get update \ apt-get install -y libsndfile1
RUN pip install -i https://pypi.douban.com/simple torchaudio tqdm python-Levenshtein librosa wget
RUN pip install -i https://pypi.douban.com/simple SoundFile numpy==1.19 numba==0.48.0 librosa==0.6.0
運行容器
docker run --gpus=all -itd --name asr --shm-size 12G -v /media/nizhengqi/sdf/wyh/end2end-asr-pytorch:/workspace asr:v2
數(shù)據(jù)處理
從https://www.openslr.org/33/下載中文數(shù)據(jù)集
數(shù)據(jù)集存放
在工作目錄 end2end-asr-pytorch下建立Aishell_dataset文件夾
下面存放
transcript 原始數(shù)據(jù)
transcript_clean transcript_clean_lang 處理后數(shù)據(jù)
建立劃分元數(shù)據(jù)存放訓練集 開發(fā)集和測試集的劃分
位于end2end-asr-pytorch/manifests
aishell_dev_lang_manifest.csv aishell_test_lang_manifest.csv aishell_train_lang_manifest.csv
aishell_dev_manifest.csv aishell_test_manifest.csv aishell_train_manifest.csv
修改 /opt/conda/lib/python3.7/codecs.py
這里只是暫時略過異常,會有不少臟數(shù)據(jù)
ef decode(self, input, final=False):
# decode input (taking the buffer into account)
try:
data = self.buffer + input
(result, consumed) = self._buffer_decode(data, self.errors, final)
# keep undecoded input until the next call
self.buffer = data[consumed:]
except:
result = "012345"
return result
修改處理代碼 data/aishell.py
with open(text_file_path, "r", encoding="utf-8") as text_file:
for line in text_file.readlines():
if line=="012345":
continue
print(line)
with open(text_file_path, "r", encoding="utf-8") as text_file:
for line in text_file.readlines():
if line=="012345":
continue
print(line)
with open(text_file_path, "r", encoding="utf-8") as text_file:
for line in text_file.readlines():
if line=="012345":
continue
print(line)
修改文件名錯誤
with open("manifests/aishell_train_manifest.csv", "w+") as train_manifest:
for i in range(len(tr_file_list)):
wav_filename = tr_file_list[i]
text_filename = tr_file_list[i].replace(".wav", "").replace("transcript", "transcript_clean") # 修改
print(text_filename)
with open("manifests/aishell_dev_manifest.csv", "w+") as valid_manifest:
for i in range(len(dev_file_list)):
wav_filename = dev_file_list[i]
text_filename = dev_file_list[i].replace(".wav", "").replace("transcript", "transcript_clean")
with open("manifests/aishell_test_manifest.csv", "w+") as test_manifest:
for i in range(len(test_file_list)):
wav_filename = test_file_list[i]
text_filename = test_file_list[i].replace(".wav", "").replace("transcript", "transcript_clean")
with open("manifests/aishell_train_lang_manifest.csv", "w+") as train_manifest:
for i in range(len(tr_file_list)):
wav_filename = tr_file_list[i]
text_filename = tr_file_list[i].replace(".wav", "").replace("transcript", "transcript_clean_lang")
with open("manifests/aishell_dev_lang_manifest.csv", "w+") as valid_manifest:
for i in range(len(dev_file_list)):
wav_filename = dev_file_list[i]
text_filename = dev_file_list[i].replace(".wav", "").replace("transcript", "transcript_clean_lang")
with open("manifests/aishell_test_lang_manifest.csv", "w+") as test_manifest:
for i in range(len(test_file_list)):
wav_filename = test_file_list[i]
text_filename = test_file_list[i].replace(".wav", "").replace("transcript", "transcript_clean_lang")
注意 label位置,根據(jù)具體情況修改
with open("data/labels/aishell_labels.json", "w+") as labels_json:
修改 utils/audio.py
def load_audio(path):
sound, _ = torchaudio.load(path,normalize=True) # normalization=True)
運行aishell.py
訓練
python train.py --train-manifest-list manifests/aishell_train_manifest.csv --valid-manifest-list manifests/aishell_dev_manifest.csv --test-manifest-list manifests/aishell_test_manifest.csv --cuda --batch-size 12 --labels-path data/labels/aishell_labels.json --lr 1e-4 --name aishell_drop0.1_cnn_batch12_4_vgg_layer4 --save-folder save/ --save-every 5 --feat_extractor vgg_cnn --dropout 0.1 --num-layers 4 --num-heads 8 --dim-model 512 --dim-key 64 --dim-value 64 --dim-input 161 --dim-inner 2048 --dim-emb 512 --shuffle --min-lr 1e-6 --k-lr 1
測試
python test.py --test-manifest-list libri_test_clean_manifest.csv --cuda --continue_from save/model
后續(xù)還要收集處理自己的數(shù)據(jù),繼續(xù)訓練,調試參數(shù),這也是更為麻煩的工作