BERT(Bidirectional Encoder Representations from Transformers)的MLM(Masked Language Model)損失是這樣設計的:在訓練過程中,BERT隨機地將輸入文本中的一些單詞替換為一個特殊的[MASK]標記,然后模型的任務是預測這些被掩蓋的單詞。具體來說,它會預測整個詞匯表中每個單詞作為掩蓋位置的概率。
MLM損失的計算方式是使用交叉熵損失函數。對于每個被掩蓋的單詞,模型會輸出一個概率分布,表示每個可能的單詞是正確單詞的概率。交叉熵損失函數會計算模型輸出的概率分布與真實單詞的分布(實際上是一個one-hot編碼,其中正確單詞的位置是1,其余位置是0)之間的差異。
具體來說,如果你有一個詞匯表大小為V,對于一個被掩蓋的單詞,模型會輸出一個V維的向量,表示詞匯表中每個單詞的概率。如果y是一個one-hot編碼的真實分布,而p是模型預測的分布,則交叉熵損失可以表示為(用于衡量模型預測概率分布與真實標簽概率分布之間的差異):
其中:
-
表示損失函數的值
-
表示類別的數量
-
是第
個類別的真實標簽,通常為0或1
-
是模型預測第
個類別的概率
-
表示自然對數
-
表示對所有類別求和
在這個公式中,是真實分布中的第i個元素,而
是模型預測的分布中的第i個元素。由于y是one-hot編碼的,所以除了正確單詞對應的位置為1,其余位置都是0,這意味著上面的求和實際上只在正確單詞的位置計算。
在實際操作中,為了提高效率,通常不會對整個詞匯表進行預測,而是使用采樣技術,如負采樣(negative sampling)或者層次softmax(hierarchical softmax),來減少每個訓練步驟中需要計算的輸出數量。