愛麗絲:瘋帽子,為什么烏鴉會像寫字臺呢?
瘋帽子:我也沒有答案。
瘋帽子:別了,愛麗絲。 小時候你來過這里,你對我說,你喜歡我,我問你為什么,你回答說,因為烏鴉長得像寫字臺。 如今,我反復說著這句,是想喚起你的記憶,可惜的是,你什么都忘了。 為什么烏鴉看起來會像寫字臺呢?沒有原因的。 就像我愛你,沒有什么理由。 ——《愛麗絲夢游仙境》
曾經有一次和妹子去看《愛麗絲夢游仙境2》的電影。看完電影后,妹子問我:烏鴉為什么想寫字臺?我:???然后就沒有然后了,最后也沒有和她在一起。作為一個程序員也要多看點技術以外的書啊,要做一個有趣的人嘛。
扯遠了。我們今天要用Keras構建我們的第一個RNN實例,我們將在《愛麗絲夢游仙境》的文本上訓練一個基于字符的語言模型,這個模型將通過給定的前10個字符預測下一個字符。我們選擇一個基于字符的模型,是因為它的字典較小,并可以訓練的更快,這和基于詞的語言模型的想法是一樣的。
1. 文本預處理
首先我們先獲取《愛麗絲夢游仙境的》輸入文本
下載地址
導入必要的庫,讀入文件并作基本的處理
from keras.layers.recurrent import SimpleRNN
from keras.models import Sequential
from keras.layers import Dense, Activation
import numpy as np
INPUT_FILE = "./alice_in_wonderland.txt"
# extract the input as a stream of characters
print("Extracting text from input...")
fin = open(INPUT_FILE, 'rb')
lines = []
for line in fin:
line = line.strip().lower()
line = line.decode("ascii", "ignore")
if len(line) == 0:
continue
lines.append(line)
fin.close()
text = " ".join(lines)
因為我們在構建一個字符級水平的RNN,我們將字典設置為文本中出現的所有字符。因為我們將要處理的是這些字符的索引而非字符本身,于是我們要創建必要的查詢表:
chars = set([c for c in text])
nb_chars = len(chars)
char2index = dict((c, i) for i, c in enumerate(chars))
index2char = dict((i, c) for i, c in enumerate(chars))
2. 創建輸入和標簽文本
我們通過STEP變量給出字符數目(本例為1)來步進便利文本,并提取出一段大小為SEQLEN變量定義值(本例為10)的文本段。文本段的下一字符是我們的標簽字符。
# 例如:輸入"The sky was falling",輸出如下(前5個):
# The sky wa -> s
# he sky was ->
# e sky was -> f
# sky was f -> a
# sky was fa -> l
print("Creating input and label text...")
SEQLEN = 10
STEP = 1
input_chars = []
label_chars = []
for i in range(0, len(text) - SEQLEN, STEP):
input_chars.append(text[i:i + SEQLEN])
label_chars.append(text[i + SEQLEN])
3. 輸入和標簽文本向量化
RNN輸入中的每行都對應了前面展示的一個輸入文本。輸入中共有SEQLEN
個字符,因為我們的字典大小是nb_chars
給定的,我們把每個輸入字符表示成one-hot編碼的大小為(nb_chars)
的向量。這樣每行輸入就是一個大小為(SEQLEN, nb_chars)
的張量。我們的輸出標簽是一個單個的字符,所以和輸入中的每個字符的表示類似。我們將輸出標簽表示成大小為(nb_chars)
的one-hot編碼的向量。因此,每個標簽的形狀就是nb_chars。
print("Vectorizing input and label text...")
X = np.zeros((len(input_chars), SEQLEN, nb_chars), dtype=np.bool)
y = np.zeros((len(input_chars), nb_chars), dtype=np.bool)
for i, input_char in enumerate(input_chars):
for j, ch in enumerate(input_char):
X[i, j, char2index[ch]] = 1
y[i, char2index[label_chars[i]]] = 1
4. 構建模型
設定超參數BATCH_SIZE = 128
我們想返回一個字符作為輸出,而非字符序列,因而設置return_sequences=False
輸入形狀為(SEQLEN, nb_chars)
為了改善TensorFlow后端性能,設置unroll=True
HIDDEN_SIZE = 128
BATCH_SIZE = 128
NUM_ITERATIONS = 25
NUM_EPOCHS_PER_ITERATION = 1
NUM_PREDS_PER_EPOCH = 100
model = Sequential()
model.add(SimpleRNN(HIDDEN_SIZE, return_sequences=False,
input_shape=(SEQLEN, nb_chars),
unroll=True))
model.add(Dense(nb_chars))
model.add(Activation("softmax"))
model.compile(loss="categorical_crossentropy", optimizer="rmsprop")
5. 模型的訓練和測試
我們分批訓練模型,每一步都測試輸出
測試模型:我們先從輸入里隨機選一個,然后用它去預測接下來的100個字符
for iteration in range(NUM_ITERATIONS):
print("=" * 50)
print("Iteration #: %d" % (iteration))
model.fit(X, y, batch_size=BATCH_SIZE, epochs=NUM_EPOCHS_PER_ITERATION, verbose=0)
test_idx = np.random.randint(len(input_chars))
test_chars = input_chars[test_idx]
print("Generating from seed: %s" % (test_chars))
print(test_chars, end="")
for i in range(NUM_PREDS_PER_EPOCH):
Xtest = np.zeros((1, SEQLEN, nb_chars))
for i, ch in enumerate(test_chars):
Xtest[0, i, char2index[ch]] = 1
pred = model.predict(Xtest, verbose=0)[0]
ypred = index2char[np.argmax(pred)]
print(ypred, end="")
# move forward with test_chars + ypred
test_chars = test_chars[1:] + ypred
print()
6. 輸出測試結果
模型剛開始的預測毫無意義,當隨著訓練輪數的增加,它已經可以進行正確的拼寫,盡管在表達語義上還有困難。畢竟我們這個模型是基于字符的,它對詞沒有任何認識,然而它的表現也足夠令人驚艷了。
Iteration #: 1
Generating from seed: cat," said
cat," said the master and the salle the the the the the the the the the the the the the the the the the the th
==================================================
Iteration #: 2
Generating from seed: e went str
e went streand the sare and and and and and and and and and and and and and and and and and and and and and an
==================================================
......
==================================================
Iteration #: 21
Generating from seed: ere near t
ere near the white rabbit was and the book of the mock turtle should the mock turtle should the mock turtle sh
==================================================
Iteration #: 22
Generating from seed: appens whe
appens where of the same the rabbit to see the end of the same the rabbit to see the end of the same the rabbi
==================================================
Iteration #: 23
Generating from seed: "why you
"why you and the rabbit seemed to her head the mock turtle replied to be to the mock turtle replied to be to
==================================================