INN實現(xiàn)理解——toy_example

github 地址:https://github.com/hagabbar/cINNamon

A Toy Example

1. 設置超參數(shù):

2. 生成數(shù)據(jù)

^^ 調(diào)用函數(shù)來生成樣本數(shù)據(jù),參數(shù) labels 用于限制生成的樣本類型,labels 有三種取值:all、some、none,分別對應圖 [A Toy Example] 中的三種樣本分布;參數(shù) tot_dataset_size 用于指定生成樣本的個數(shù)。

pos 是一個大小為 (tot_dataset_size, 2) 的二階矩陣,其元素都符合均值為0,方差為0.2的正態(tài)分布,表示樣本點的坐標;labels 是一個 (tot_dataset_size, 3) 的矩陣,表示樣本點的顏色 RGB 值。樣本被均勻分為 8 堆,對每堆的樣本坐標進行一定修改,使這堆樣本點落在相同區(qū)域內(nèi),且有相同的顏色。

^^ 分別取 pos、labels 的前 test_split 個元素作為測試樣本,畫出測試數(shù)據(jù)的分布圖如下:

3. 建立模型

^^ ndim_tot = max(ndim_x, ndim_y+ndim_z) + n_neurons,ndim_tot 的值對網(wǎng)絡結(jié)構有重要影響,輸入結(jié)點會將其作為維度值。如果維度 ndim_tot 相對較小,但卻需要學習一個很復雜的轉(zhuǎn)換,最好對網(wǎng)絡的輸入和輸出都進行相同數(shù)量的 0 填充。這并不會改變輸入和輸出的固有維度,但使得網(wǎng)絡內(nèi)部層可以以一種更靈活的方式將數(shù)據(jù)嵌入到更大的表示空間。

^^ ReversibleGraphNet 構造函數(shù)會做四件事:
① 構造 INN 網(wǎng)絡的正向連接,即:inp → t1 → t2 → t3 → outp。其中 t1、t2、t3 都是一個基礎構建塊,其結(jié)構為:

用公式表示為:

其中,s1、t1、s2、t2 都是一種線性映射關系,因此都被構造為一個有三層隱藏層的全連接神經(jīng)網(wǎng)絡。需要說明的是,隱藏層的神經(jīng)元個數(shù),被簡單設置為輸出層神經(jīng)元個數(shù)的 2 倍。

② 確定 INN 網(wǎng)絡的反向連接,使得可以進行反向訓練。
③ 確定正向訓練過程中涉及的變量及操作順序。
④ 確定反向訓練過程中涉及的變量及操作順序。

4. 訓練前準備工作

^^ 設置訓練參數(shù)。

^^ 各項損失的相對權重。INN 訓練過程中考慮三項損失:
① 模型輸出 yi = s(xi) 與網(wǎng)絡預測 fy(xi) 之間的偏差,損失記為 Ly(yi,fy(xi)),Ly 可以是任意有監(jiān)督的損失;lamdb_predict 為 Ly 的權重;
② 模型輸出 p(y = s(x)) = p(x) / |Js| 和潛在變量 p(z) 的邊際分布的乘積與網(wǎng)絡輸出 q(y = fy(x),z = fz(x)) = p(x) / |Jyz| 間的偏差,記為 Lz(p(y)p(z),q(y,z));lambd_latent 為 Lz 的權重;
③ 輸入端的損失 LxLx(p(x),q(x)) 表示了 p(x) 與后向預測分布 q(x) = p(y = fy(x)) p(z = fz(x)) / |Jx| 間的偏差;lambd_rev 為 Lx 的權重.

^^ 定義權重更新規(guī)則(optimizer),scheduler 對其進行封裝,目的是使其學習率每隔 step_size 輪就進行一次衰減。

^^ 定義損失函數(shù)。需要定義三個函數(shù)來進行三種損失的計算,其中 Lx、Lz 是無監(jiān)督損失,因此選擇了 MMD_multiscale(多刻度的 MMD,MMD 常用于度量兩個不同但相關的分布的距離)作為損失函數(shù);Ly 是有監(jiān)督損失,因此選擇了 fit(即平方誤差)。

^^ 建立測試數(shù)據(jù)裝載器和訓練集數(shù)據(jù)裝載器。DataLoader 返回的是一個迭代器,可以使用迭代器分批獲取數(shù)據(jù),或直接使用 for 循環(huán)對其進行遍歷。

^^ 初始化網(wǎng)絡權重。這里 block 指各 INN 構建塊, coeff 指 INN 構建塊中的全連接神經(jīng)網(wǎng)絡。它們都是 Module 類的子類對象,因此使用了三層 for 循環(huán)對權重初始化。

^^ 從測試樣本集中選取一部分,用于在模型訓練完成后進行模型測試。

5. 訓練模型

這個實現(xiàn)對 INN 訓練了 2000 次。我們只看一次訓練的步驟。

首先是設置此輪訓練的學習率,這個一般由我們之前包裝的 scheduler 根據(jù)訓練輪數(shù)來進行設定。

^^ 訓練網(wǎng)絡。其中調(diào)用了 train() 函數(shù):


核心函數(shù) train()

首先要將訓練涉及的各個模塊的狀態(tài)設置為 training。

^^ 設置 loss_factor,當 i_epoch 大于 300 時,其值為 1。

每輪訓練只能使用數(shù)據(jù)裝載器裝載 n_its_per_epoch(設定為4)批的樣本數(shù)據(jù),對于每一批數(shù)據(jù),進行如下處理。

^^ 對 x 和 yz 進行對其填充,使它們的維度和 ndim_tot 相同。在填充前,先為 y 增加隨機噪聲。這里也可以看出,z 服從標準正態(tài)分布

^^ 在開始訓練前,需要清除已存在的梯度。

^^ 執(zhí)行正向傳播得到輸出 output(正向計算由 PyTorch 實現(xiàn)),output 與輸入有相同的維度。y_short 維度為 (樣本數(shù) × 4),其中,前 2 列表示 z,后兩列表示 y。

^^ 計算損失 Ly,即為樣本 y 和網(wǎng)絡預測結(jié)果 y' (包括了補齊部分,但不包括 z)的均方誤差。

^^ 計算損失 Lz。output_block_grad 維度為 (樣本數(shù) × 4),其中,前 2 列表示 z,后兩列表示 y;其與 y_short 相對,區(qū)別是一個來源于正向網(wǎng)絡預測結(jié)果 output,一個來源于樣本 y。

^^ 這個 backward() 函數(shù)是 PyTorch 實現(xiàn)的,調(diào)用它是為了進行梯度計算。l 是正向過程的總損失,調(diào)用 l.backward() 計算梯度,是為了之后更新權值做準備。

^^ 這些都是反向訓練需要的變量。y_rev 除了補齊部分外,包含了增加了隨機噪聲的、上一輪正向訓練得到的 z;以及增加了隨機噪聲的原始樣本 y。y_rev_rand 與 y_rev 大小相同,不同的地方在于,y_rev_rand 包含的是隨機生成的服從標準正態(tài)分布的 z'。

^^ 對 y_rev 和 y_rev_rand 進行反向訓練,得到輸出結(jié)果。

^^ 計算反向訓練的損失 Lx,可見其由兩部分組成:一是樣本 x 與反向訓練結(jié)果 output_rev_rand 中的 x' 的差異;二是正向訓練的輸入與反向訓練的輸出(正向訓練的輸出為反向訓練的輸入)間的差異。

^^ l_rev 是逆向過程的總損失,調(diào)用 l_rev.backward() 計算梯度,是為了之后更新權值做準備。

^^ 將各參數(shù)的梯度值限制在 [-15.0,15.0] 區(qū)間內(nèi),然后更新網(wǎng)絡權值。

^^ 一輪訓練結(jié)束,返回訓練每批數(shù)據(jù)的總損失。

train() 執(zhí)行結(jié)束


在每輪訓練結(jié)束后,使用測試數(shù)據(jù)對模型進行測試,畫出其分布,即對訓練結(jié)果可視化。

?著作權歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內(nèi)容