圖神經網絡 PyTorch Geometric 入門教程

簡介

Graph Neural Networks 簡稱 GNN,稱為圖神經網絡,是深度學習中近年來一個比較受關注的領域。近年來 GNN 在學術界受到的關注越來越多,與之相關的論文數量呈上升趨勢,GNN 通過對信息的傳遞,轉換和聚合實現特征的提取,類似于傳統的 CNN,只是 CNN 只能處理規則的輸入,如圖片等輸入的高、寬和通道數都是固定的,而 GNN 可以處理不規則的輸入,如點云等。 可查看【GNN】萬字長文帶你入門 GCN

而 PyTorch Geometric Library (簡稱 PyG) 是一個基于 PyTorch 的圖神經網絡庫,地址是:https://github.com/rusty1s/pytorch_geometric。它包含了很多 GNN 相關論文中的方法實現和常用數據集,并且提供了簡單易用的接口來生成圖,因此對于復現論文來說也是相當方便。用法大多數和 PyTorch 很相近,因此熟悉 PyTorch 的同學使用這個庫可以很快上手。

torch_geometric.data.Data

節點和節點之間的邊構成了圖。所以在 PyG 中,如果你要構建圖,那么需要兩個要素:節點和邊。PyG 提供了torch_geometric.data.Data (下面簡稱Data) 用于構建圖,包括 5 個屬性,每一個屬性都不是必須的,可以為空。

  • x: 用于存儲每個節點的特征,形狀是[num_nodes, num_node_features]
  • edge_index: 用于存儲節點之間的邊,形狀是 [2, num_edges]
  • pos: 存儲節點的坐標,形狀是[num_nodes, num_dimensions]
  • y: 存儲樣本標簽。如果是每個節點都有標簽,那么形狀是[num_nodes, *];如果是整張圖只有一個標簽,那么形狀是[1, *]
  • edge_attr: 存儲邊的特征。形狀是[num_edges, num_edge_features]

實際上,Data對象不僅僅限制于這些屬性,我們可以通過data.face來擴展Data,以張量保存三維網格中三角形的連接性。

需要注意的的是,在Data里包含了樣本的 label,這意味和 PyTorch 稍有不同。在PyTorch中,我們重寫Dataset__getitem__(),根據 index 返回對應的樣本和 label。在 PyG 中,我們使用的不是這種寫法,而是在get()函數中根據 index 返回torch_geometric.data.Data類型的數據,在Data里包含了數據和 label。

下面一個例子是未加權無向圖 ( 未加權指邊上沒有權值 ),包括 3 個節點和 4 條邊。

<div align="center"><img src="https://image.zhangxiann.com/20200522215349.png"/></div>

由于是無向圖,因此有 4 條邊:(0 -> 1), (1 -> 0), (1 -> 2), (2 -> 1)。每個節點都有自己的特征。上面這個圖可以使用torch_geometric.data.Data來表示如下:

import torch
from torch_geometric.data import Data
# 由于是無向圖,因此有 4 條邊:(0 -> 1), (1 -> 0), (1 -> 2), (2 -> 1)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
# 節點的特征                           
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index)

注意edge_index中邊的存儲方式,有兩個list,第 1 個list是邊的起始點,第 2 個list是邊的目標節點。注意與下面的存儲方式的區別。

import torch
from torch_geometric.data import Data

edge_index = torch.tensor([[0, 1],
                           [1, 0],
                           [1, 2],
                           [2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index.t().contiguous())

這種情況edge_index需要先轉置然后使用contiguous()方法。關于contiguous()函數的作用,查看 PyTorch中的contiguous

最后再復習一遍,Data中最基本的 4 個屬性是xedge_indexposy,我們一般都需要這 4 個參數。

有了Data,我們可以創建自己的Dataset,讀取并返回Data了。

Dataset 與 DataLoader

PyG 的 Dataset繼承自torch.utils.data.Dataset,自帶了很多圖數據集,我們以TUDataset為例,通過以下代碼就可以加載數據集,root參數設置數據下載的位置。通過索引可以訪問每一個數據。

from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
data = dataset[0]

在一個圖中,由edge_indexedge_attr可以決定所有節點的鄰接矩陣。PyG 通過創建稀疏的對角鄰接矩陣,并在節點維度中連接特征矩陣和 label 矩陣,實現了在 mini-batch 的并行化。PyG 允許在一個 mini-batch 中的每個Data (圖) 使用不同數量的節點和邊。

<div align="center"><img src="https://image.zhangxiann.com/20200522225100.png"/></div>

自定義 Dataset

盡管 PyG 已經包含許多有用的數據集,我們也可以通過繼承torch_geometric.data.Dataset使用自己的數據集。提供 2 種不同的Dataset

  • InMemoryDataset:使用這個Dataset會一次性把數據全部加載到內存中。
  • Dataset: 使用這個Dataset每次加載一個數據到內存中,比較常用。

我們需要在自定義的Dataset的初始化方法中傳入數據存放的路徑,然后 PyG 會在這個路徑下再劃分 2 個文件夾:

  • raw_dir: 存放原始數據的路徑,一般是 csv、mat 等格式
  • processed_dir: 存放處理后的數據,一般是 pt 格式 ( 由我們重寫process()方法實現)。

在 PyTorch 中,是沒有這兩個文件夾的。下面來說明一下這兩個文件夾在 PyG 中的實際意義和處理邏輯。

torch_geometric.data.Dataset繼承自torch.utils.data.Dataset,在初始化方法 __init__()中,會調用_download()方法和_process()方法。

def __init__(self, root=None, transform=None, pre_transform=None,
             pre_filter=None):
    super(Dataset, self).__init__()

    if isinstance(root, str):
        root = osp.expanduser(osp.normpath(root))

    self.root = root
    self.transform = transform
    self.pre_transform = pre_transform
    self.pre_filter = pre_filter
    self.__indices__ = None

    # 執行 self._download() 方法
    if 'download' in self.__class__.__dict__.keys():
        self._download()
    # 執行 self._process() 方法
    if 'process' in self.__class__.__dict__.keys():
        self._process()

_download()方法如下,首先檢查self.raw_paths列表中的文件是否存在;如果存在,則返回;如果不存在,則調用self.download()方法下載文件。

def _download(self):
    if files_exist(self.raw_paths):  # pragma: no cover
        return

    makedirs(self.raw_dir)
    self.download()

_process()方法如下,首先在self.processed_dir中有pre_transform,那么判斷這個pre_transform和傳進來的pre_transform是否一致,如果不一致,那么警告提示用戶先刪除self.processed_dir文件夾。pre_filter同理。

然后檢查self.processed_paths列表中的文件是否存在;如果存在,則返回;如果不存在,則調用self.process()生成文件。

def _process(self):
    f = osp.join(self.processed_dir, 'pre_transform.pt')
    if osp.exists(f) and torch.load(f) != __repr__(self.pre_transform):
        warnings.warn(
            'The `pre_transform` argument differs from the one used in '
            'the pre-processed version of this dataset. If you really '
            'want to make use of another pre-processing technique, make '
            'sure to delete `{}` first.'.format(self.processed_dir))
    f = osp.join(self.processed_dir, 'pre_filter.pt')
    if osp.exists(f) and torch.load(f) != __repr__(self.pre_filter):
        warnings.warn(
            'The `pre_filter` argument differs from the one used in the '
            'pre-processed version of this dataset. If you really want to '
            'make use of another pre-fitering technique, make sure to '
            'delete `{}` first.'.format(self.processed_dir))

    if files_exist(self.processed_paths):  # pragma: no cover
        return

    print('Processing...')

    makedirs(self.processed_dir)
    self.process()

    path = osp.join(self.processed_dir, 'pre_transform.pt')
    torch.save(__repr__(self.pre_transform), path)
    path = osp.join(self.processed_dir, 'pre_filter.pt')
    torch.save(__repr__(self.pre_filter), path)

    print('Done!')

一般來說不用實現downloand()方法

如果你直接把處理好的 pt 文件放在了self.processed_dir中,那么也不用實現process()方法。

在 Pytorch 的dataset中,我們需要實現__getitem__()方法,根據index返回樣本和標簽。在這里torch_geometric.data.Dataset中,重寫了__getitem__()方法,其中調用了get()方法獲取數據。

def __getitem__(self, idx):
    if isinstance(idx, int):
        data = self.get(self.indices()[idx])
        data = data if self.transform is None else self.transform(data)
        return data
    else:
        return self.index_select(idx)

我們需要實現的是get()方法,根據index返回torch_geometric.data.Data類型的數據。

process()方法存在的意義是原始的格式可能是 csv 或者 mat,在process()函數里可以轉化為 pt 格式的文件,這樣在get()方法中就可以直接使用torch.load()函數讀取 pt 格式的文件,返回的是torch_geometric.data.Data類型的數據,而不用在get()方法做數據轉換操作 (把其他格式的數據轉換為 torch_geometric.data.Data類型的數據)。當然我們也可以提前把數據轉換為 torch_geometric.data.Data類型,使用 pt 格式保存在self.processed_dir中。

DataLoader

通過torch_geometric.data.DataLoader可以方便地使用 mini-batch。

from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader

dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for batch in loader:
    # 對每一個 mini-batch 進行操作
    ...

torch_geometric.data.Batch繼承自torch_geometric.data.Data,并且多了一個屬性:batchbatch是一個列向量,它將每個元素映射到每個 mini-batch 中的相應圖:

batch =\left[\begin{array}{cccccccc}0 & \cdots & 0 & 1 & \cdots & n-2 & n-1 & \cdots & n-1\end{array}\right]^{\top}

我們可以使用它分別為每個圖的節點維度計算平均的節點特征:

from torch_scatter import scatter_mean
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader

dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for data in loader:
    data
    #data: Batch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32])

    x = scatter_mean(data.x, data.batch, dim=0)
    # x.size(): torch.Size([32, 21])

關于 batching 的流程細節,你可以點擊Pytorch Geometric Documentation查看。關于scatter方法的說明,你可以查看torch-scatter說明文檔

Transforms

transforms在計算機視覺領域是一種很常見的數據增強。PyG 有自己的transforms,輸出是Data類型,輸出也是Data類型。可以使用torch_geometric.transforms.Compose封裝一系列的transforms。我們以 ShapeNet 數據集 (包含 17000 個 point clouds,每個 point 分類為 16 個類別的其中一個) 為例,我們可以使用transforms從 point clouds 生成最近鄰圖:

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet

dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
                    pre_transform=T.KNNGraph(k=6))
# dataset[0]: Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])

還可以通過transform在一定范圍內隨機平移每個點,增加坐標上的擾動,做數據增強:

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet

dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
                    pre_transform=T.KNNGraph(k=6),
                    transform=T.RandomTranslate(0.01))
# dataset[0]: Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])

模型訓練

這里只是展示一個簡單的 GCN 模型構造和訓練過程,沒有用到DatasetDataLoader

我們將使用一個簡單的 GCN 層,并在 Cora 數據集上實驗。有關 GCN 的更多內容,請查看 關于 GCN 的理解

我們首先加載數據集:

from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='/tmp/Cora', name='Cora')

然后定義 2 層的 GCN:

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)

        return F.log_softmax(x, dim=1)

然后訓練 200 個 epochs:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

最后在測試集上驗證了模型的準確率:

model.eval()
_, pred = model(data).max(dim=1)
correct = float (pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct / data.test_mask.sum().item()
print('Accuracy: {:.4f}'.format(acc))



至此,關于Pytorch Geometric的簡單使用教程就講完了。

回顧一下,在這篇文章中,在講述使用Pytorch Geometric的過程中,花了較多篇幅分析了圖數據是如何表示的,分析了Dataset的工作流程,讓你明白圖數據在Dataset里都經過了哪些步驟,才得以輸入到模型,最終可以利用Dataset來構建自己的數據集。

如果你覺得這篇文章對你有幫助,不妨點個贊,讓我有更多動力寫出好文章。

我的文章會首發在公眾號上,歡迎掃碼關注我的公眾號張賢同學

公眾號
最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市,隨后出現的幾起案子,更是在濱河造成了極大的恐慌,老刑警劉巖,帶你破解...
    沈念sama閱讀 230,321評論 6 543
  • 序言:濱河連續發生了三起死亡事件,死亡現場離奇詭異,居然都是意外死亡,警方通過查閱死者的電腦和手機,發現死者居然都...
    沈念sama閱讀 99,559評論 3 429
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人,你說我怎么就攤上這事。” “怎么了?”我有些...
    開封第一講書人閱讀 178,442評論 0 383
  • 文/不壞的土叔 我叫張陵,是天一觀的道長。 經常有香客問我,道長,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 63,835評論 1 317
  • 正文 為了忘掉前任,我火速辦了婚禮,結果婚禮上,老公的妹妹穿的比我還像新娘。我一直安慰自己,他們只是感情好,可當我...
    茶點故事閱讀 72,581評論 6 412
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著,像睡著了一般。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發上,一...
    開封第一講書人閱讀 55,922評論 1 328
  • 那天,我揣著相機與錄音,去河邊找鬼。 笑死,一個胖子當著我的面吹牛,可吹牛的內容都是我干的。 我是一名探鬼主播,決...
    沈念sama閱讀 43,931評論 3 447
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了?” 一聲冷哼從身側響起,我...
    開封第一講書人閱讀 43,096評論 0 290
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后,有當地人在樹林里發現了一具尸體,經...
    沈念sama閱讀 49,639評論 1 336
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內容為張勛視角 年9月15日...
    茶點故事閱讀 41,374評論 3 358
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發現自己被綠了。 大學時的朋友給我發了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 43,591評論 1 374
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖,靈堂內的尸體忽然破棺而出,到底是詐尸還是另有隱情,我是刑警寧澤,帶...
    沈念sama閱讀 39,104評論 5 364
  • 正文 年R本政府宣布,位于F島的核電站,受9級特大地震影響,放射性物質發生泄漏。R本人自食惡果不足惜,卻給世界環境...
    茶點故事閱讀 44,789評論 3 349
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧,春花似錦、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 35,196評論 0 28
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至,卻和暖如春,著一層夾襖步出監牢的瞬間,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 36,524評論 1 295
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人。 一個月前我還...
    沈念sama閱讀 52,322評論 3 400
  • 正文 我出身青樓,卻偏偏與公主長得像,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當晚...
    茶點故事閱讀 48,554評論 2 379