3. 用Keras實現SimpleRNN——生成文本

愛麗絲:瘋帽子,為什么烏鴉會像寫字臺呢?
瘋帽子:我也沒有答案。
瘋帽子:別了,愛麗絲。 小時候你來過這里,你對我說,你喜歡我,我問你為什么,你回答說,因為烏鴉長得像寫字臺。 如今,我反復說著這句,是想喚起你的記憶,可惜的是,你什么都忘了。 為什么烏鴉看起來會像寫字臺呢?沒有原因的。 就像我愛你,沒有什么理由。 ——《愛麗絲夢游仙境》

曾經有一次和妹子去看《愛麗絲夢游仙境2》的電影。看完電影后,妹子問我:烏鴉為什么想寫字臺?我:???然后就沒有然后了,最后也沒有和她在一起。作為一個程序員也要多看點技術以外的書啊,要做一個有趣的人嘛。

扯遠了。我們今天要用Keras構建我們的第一個RNN實例,我們將在《愛麗絲夢游仙境》的文本上訓練一個基于字符的語言模型,這個模型將通過給定的前10個字符預測下一個字符。我們選擇一個基于字符的模型,是因為它的字典較小,并可以訓練的更快,這和基于詞的語言模型的想法是一樣的。

1. 文本預處理

首先我們先獲取《愛麗絲夢游仙境的》輸入文本
下載地址

導入必要的庫,讀入文件并作基本的處理

from keras.layers.recurrent import SimpleRNN
from keras.models import Sequential
from keras.layers import Dense, Activation
import numpy as np

INPUT_FILE = "./alice_in_wonderland.txt"

# extract the input as a stream of characters
print("Extracting text from input...")
fin = open(INPUT_FILE, 'rb')
lines = []
for line in fin:
    line = line.strip().lower()
    line = line.decode("ascii", "ignore")
    if len(line) == 0:
        continue
    lines.append(line)
fin.close()
text = " ".join(lines)

因為我們在構建一個字符級水平的RNN,我們將字典設置為文本中出現的所有字符。因為我們將要處理的是這些字符的索引而非字符本身,于是我們要創建必要的查詢表:

chars = set([c for c in text])
nb_chars = len(chars)
char2index = dict((c, i) for i, c in enumerate(chars))
index2char = dict((i, c) for i, c in enumerate(chars))

2. 創建輸入和標簽文本

我們通過STEP變量給出字符數目(本例為1)來步進便利文本,并提取出一段大小為SEQLEN變量定義值(本例為10)的文本段。文本段的下一字符是我們的標簽字符。

#   例如:輸入"The sky was falling",輸出如下(前5個):
#   The sky wa -> s
#   he sky was ->  
#   e sky was  -> f
#    sky was f -> a
#   sky was fa -> l
print("Creating input and label text...")
SEQLEN = 10
STEP = 1

input_chars = []
label_chars = []
for i in range(0, len(text) - SEQLEN, STEP):
    input_chars.append(text[i:i + SEQLEN])
    label_chars.append(text[i + SEQLEN])

3. 輸入和標簽文本向量化

RNN輸入中的每行都對應了前面展示的一個輸入文本。輸入中共有SEQLEN個字符,因為我們的字典大小是nb_chars給定的,我們把每個輸入字符表示成one-hot編碼的大小為(nb_chars)的向量。這樣每行輸入就是一個大小為(SEQLEN, nb_chars)的張量。我們的輸出標簽是一個單個的字符,所以和輸入中的每個字符的表示類似。我們將輸出標簽表示成大小為(nb_chars)的one-hot編碼的向量。因此,每個標簽的形狀就是nb_chars。

print("Vectorizing input and label text...")
X = np.zeros((len(input_chars), SEQLEN, nb_chars), dtype=np.bool)
y = np.zeros((len(input_chars), nb_chars), dtype=np.bool)
for i, input_char in enumerate(input_chars):
    for j, ch in enumerate(input_char):
        X[i, j, char2index[ch]] = 1
    y[i, char2index[label_chars[i]]] = 1

4. 構建模型

設定超參數BATCH_SIZE = 128
我們想返回一個字符作為輸出,而非字符序列,因而設置return_sequences=False
輸入形狀為(SEQLEN, nb_chars)
為了改善TensorFlow后端性能,設置unroll=True

HIDDEN_SIZE = 128
BATCH_SIZE = 128
NUM_ITERATIONS = 25
NUM_EPOCHS_PER_ITERATION = 1
NUM_PREDS_PER_EPOCH = 100

model = Sequential()
model.add(SimpleRNN(HIDDEN_SIZE, return_sequences=False,
                    input_shape=(SEQLEN, nb_chars),
                    unroll=True))
model.add(Dense(nb_chars))
model.add(Activation("softmax"))

model.compile(loss="categorical_crossentropy", optimizer="rmsprop")

5. 模型的訓練和測試

我們分批訓練模型,每一步都測試輸出
測試模型:我們先從輸入里隨機選一個,然后用它去預測接下來的100個字符

for iteration in range(NUM_ITERATIONS):
    print("=" * 50)
    print("Iteration #: %d" % (iteration))
    model.fit(X, y, batch_size=BATCH_SIZE, epochs=NUM_EPOCHS_PER_ITERATION, verbose=0)

    test_idx = np.random.randint(len(input_chars))
    test_chars = input_chars[test_idx]
    print("Generating from seed: %s" % (test_chars))
    print(test_chars, end="")
    for i in range(NUM_PREDS_PER_EPOCH):
        Xtest = np.zeros((1, SEQLEN, nb_chars))
        for i, ch in enumerate(test_chars):
            Xtest[0, i, char2index[ch]] = 1
        pred = model.predict(Xtest, verbose=0)[0]
        ypred = index2char[np.argmax(pred)]
        print(ypred, end="")
        # move forward with test_chars + ypred
        test_chars = test_chars[1:] + ypred
    print()

6. 輸出測試結果

模型剛開始的預測毫無意義,當隨著訓練輪數的增加,它已經可以進行正確的拼寫,盡管在表達語義上還有困難。畢竟我們這個模型是基于字符的,它對詞沒有任何認識,然而它的表現也足夠令人驚艷了。

Iteration #: 1
Generating from seed: cat," said
cat," said the master and the salle the the the the the the the the the the the the the the the the the the th
==================================================
Iteration #: 2
Generating from seed: e went str
e went streand the sare and and and and and and and and and and and and and and and and and and and and and an
==================================================

......

==================================================
Iteration #: 21
Generating from seed: ere near t
ere near the white rabbit was and the book of the mock turtle should the mock turtle should the mock turtle sh
==================================================
Iteration #: 22
Generating from seed: appens whe
appens where of the same the rabbit to see the end of the same the rabbit to see the end of the same the rabbi
==================================================
Iteration #: 23
Generating from seed:  "why you 
 "why you and the rabbit seemed to her head the mock turtle replied to be to the mock turtle replied to be to 
==================================================
最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內容

  • HTML標簽解釋大全 一、HTML標記 標簽:!DOCTYPE 說明:指定了 HTML 文檔遵循的文檔類型定義(D...
    米塔塔閱讀 3,310評論 1 41
  • 司令全名叫蒙古國海軍司令。 有點常識的人都知道,蒙古國是內陸國家,海軍就算有,也不過是蝦兵蟹將。作為最高長官,免不...
    四小姐的家閱讀 699評論 1 51
  • 有多久沒有靜下來思考了,太久了,久到已經想不起上一次是什么時候了。最近,發現自己的拖延癥已經是越來越嚴重了,進入到...
    行走在時空的迷茫人閱讀 227評論 0 0