簡介
Graph Neural Networks 簡稱 GNN,稱為圖神經網絡,是深度學習中近年來一個比較受關注的領域。近年來 GNN 在學術界受到的關注越來越多,與之相關的論文數量呈上升趨勢,GNN 通過對信息的傳遞,轉換和聚合實現特征的提取,類似于傳統的 CNN,只是 CNN 只能處理規則的輸入,如圖片等輸入的高、寬和通道數都是固定的,而 GNN 可以處理不規則的輸入,如點云等。 可查看【GNN】萬字長文帶你入門 GCN。
而 PyTorch Geometric Library (簡稱 PyG) 是一個基于 PyTorch 的圖神經網絡庫,地址是:https://github.com/rusty1s/pytorch_geometric。它包含了很多 GNN 相關論文中的方法實現和常用數據集,并且提供了簡單易用的接口來生成圖,因此對于復現論文來說也是相當方便。用法大多數和 PyTorch 很相近,因此熟悉 PyTorch 的同學使用這個庫可以很快上手。
torch_geometric.data.Data
節點和節點之間的邊構成了圖。所以在 PyG 中,如果你要構建圖,那么需要兩個要素:節點和邊。PyG 提供了torch_geometric.data.Data
(下面簡稱Data
) 用于構建圖,包括 5 個屬性,每一個屬性都不是必須的,可以為空。
- x: 用于存儲每個節點的特征,形狀是
[num_nodes, num_node_features]
。 - edge_index: 用于存儲節點之間的邊,形狀是
[2, num_edges]
。 - pos: 存儲節點的坐標,形狀是
[num_nodes, num_dimensions]
。 - y: 存儲樣本標簽。如果是每個節點都有標簽,那么形狀是
[num_nodes, *]
;如果是整張圖只有一個標簽,那么形狀是[1, *]
。 - edge_attr: 存儲邊的特征。形狀是
[num_edges, num_edge_features]
。
實際上,Data
對象不僅僅限制于這些屬性,我們可以通過data.face
來擴展Data
,以張量保存三維網格中三角形的連接性。
需要注意的的是,在Data
里包含了樣本的 label,這意味和 PyTorch 稍有不同。在PyTorch
中,我們重寫Dataset
的__getitem__()
,根據 index 返回對應的樣本和 label。在 PyG 中,我們使用的不是這種寫法,而是在get()
函數中根據 index 返回torch_geometric.data.Data
類型的數據,在Data
里包含了數據和 label。
下面一個例子是未加權無向圖 ( 未加權指邊上沒有權值 ),包括 3 個節點和 4 條邊。
<div align="center"><img src="https://image.zhangxiann.com/20200522215349.png"/></div>
由于是無向圖,因此有 4 條邊:(0 -> 1), (1 -> 0), (1 -> 2), (2 -> 1)。每個節點都有自己的特征。上面這個圖可以使用torch_geometric.data.Data
來表示如下:
import torch
from torch_geometric.data import Data
# 由于是無向圖,因此有 4 條邊:(0 -> 1), (1 -> 0), (1 -> 2), (2 -> 1)
edge_index = torch.tensor([[0, 1, 1, 2],
[1, 0, 2, 1]], dtype=torch.long)
# 節點的特征
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index)
注意edge_index
中邊的存儲方式,有兩個list
,第 1 個list
是邊的起始點,第 2 個list
是邊的目標節點。注意與下面的存儲方式的區別。
import torch
from torch_geometric.data import Data
edge_index = torch.tensor([[0, 1],
[1, 0],
[1, 2],
[2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index.t().contiguous())
這種情況edge_index
需要先轉置然后使用contiguous()
方法。關于contiguous()
函數的作用,查看 PyTorch中的contiguous。
最后再復習一遍,Data
中最基本的 4 個屬性是x
、edge_index
、pos
、y
,我們一般都需要這 4 個參數。
有了Data
,我們可以創建自己的Dataset
,讀取并返回Data
了。
Dataset 與 DataLoader
PyG 的 Dataset
繼承自torch.utils.data.Dataset
,自帶了很多圖數據集,我們以TUDataset
為例,通過以下代碼就可以加載數據集,root
參數設置數據下載的位置。通過索引可以訪問每一個數據。
from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
data = dataset[0]
在一個圖中,由edge_index
和edge_attr
可以決定所有節點的鄰接矩陣。PyG 通過創建稀疏的對角鄰接矩陣,并在節點維度中連接特征矩陣和 label 矩陣,實現了在 mini-batch 的并行化。PyG 允許在一個 mini-batch 中的每個Data
(圖) 使用不同數量的節點和邊。
<div align="center"><img src="https://image.zhangxiann.com/20200522225100.png"/></div>
自定義 Dataset
盡管 PyG 已經包含許多有用的數據集,我們也可以通過繼承torch_geometric.data.Dataset
使用自己的數據集。提供 2 種不同的Dataset
:
- InMemoryDataset:使用這個
Dataset
會一次性把數據全部加載到內存中。 - Dataset: 使用這個
Dataset
每次加載一個數據到內存中,比較常用。
我們需要在自定義的Dataset
的初始化方法中傳入數據存放的路徑,然后 PyG 會在這個路徑下再劃分 2 個文件夾:
-
raw_dir
: 存放原始數據的路徑,一般是 csv、mat 等格式 -
processed_dir
: 存放處理后的數據,一般是 pt 格式 ( 由我們重寫process()
方法實現)。
在 PyTorch 中,是沒有這兩個文件夾的。下面來說明一下這兩個文件夾在 PyG 中的實際意義和處理邏輯。
torch_geometric.data.Dataset
繼承自torch.utils.data.Dataset
,在初始化方法 __init__()
中,會調用_download()
方法和_process()
方法。
def __init__(self, root=None, transform=None, pre_transform=None,
pre_filter=None):
super(Dataset, self).__init__()
if isinstance(root, str):
root = osp.expanduser(osp.normpath(root))
self.root = root
self.transform = transform
self.pre_transform = pre_transform
self.pre_filter = pre_filter
self.__indices__ = None
# 執行 self._download() 方法
if 'download' in self.__class__.__dict__.keys():
self._download()
# 執行 self._process() 方法
if 'process' in self.__class__.__dict__.keys():
self._process()
_download()
方法如下,首先檢查self.raw_paths
列表中的文件是否存在;如果存在,則返回;如果不存在,則調用self.download()
方法下載文件。
def _download(self):
if files_exist(self.raw_paths): # pragma: no cover
return
makedirs(self.raw_dir)
self.download()
_process()
方法如下,首先在self.processed_dir
中有pre_transform
,那么判斷這個pre_transform
和傳進來的pre_transform
是否一致,如果不一致,那么警告提示用戶先刪除self.processed_dir
文件夾。pre_filter
同理。
然后檢查self.processed_paths
列表中的文件是否存在;如果存在,則返回;如果不存在,則調用self.process()
生成文件。
def _process(self):
f = osp.join(self.processed_dir, 'pre_transform.pt')
if osp.exists(f) and torch.load(f) != __repr__(self.pre_transform):
warnings.warn(
'The `pre_transform` argument differs from the one used in '
'the pre-processed version of this dataset. If you really '
'want to make use of another pre-processing technique, make '
'sure to delete `{}` first.'.format(self.processed_dir))
f = osp.join(self.processed_dir, 'pre_filter.pt')
if osp.exists(f) and torch.load(f) != __repr__(self.pre_filter):
warnings.warn(
'The `pre_filter` argument differs from the one used in the '
'pre-processed version of this dataset. If you really want to '
'make use of another pre-fitering technique, make sure to '
'delete `{}` first.'.format(self.processed_dir))
if files_exist(self.processed_paths): # pragma: no cover
return
print('Processing...')
makedirs(self.processed_dir)
self.process()
path = osp.join(self.processed_dir, 'pre_transform.pt')
torch.save(__repr__(self.pre_transform), path)
path = osp.join(self.processed_dir, 'pre_filter.pt')
torch.save(__repr__(self.pre_filter), path)
print('Done!')
一般來說不用實現downloand()
方法。
如果你直接把處理好的 pt 文件放在了self.processed_dir
中,那么也不用實現process()
方法。
在 Pytorch 的dataset
中,我們需要實現__getitem__()
方法,根據index
返回樣本和標簽。在這里torch_geometric.data.Dataset
中,重寫了__getitem__()
方法,其中調用了get()
方法獲取數據。
def __getitem__(self, idx):
if isinstance(idx, int):
data = self.get(self.indices()[idx])
data = data if self.transform is None else self.transform(data)
return data
else:
return self.index_select(idx)
我們需要實現的是get()
方法,根據index
返回torch_geometric.data.Data
類型的數據。
process()
方法存在的意義是原始的格式可能是 csv 或者 mat,在process()
函數里可以轉化為 pt 格式的文件,這樣在get()
方法中就可以直接使用torch.load()
函數讀取 pt 格式的文件,返回的是torch_geometric.data.Data
類型的數據,而不用在get()
方法做數據轉換操作 (把其他格式的數據轉換為 torch_geometric.data.Data
類型的數據)。當然我們也可以提前把數據轉換為 torch_geometric.data.Data
類型,使用 pt 格式保存在self.processed_dir
中。
DataLoader
通過torch_geometric.data.DataLoader
可以方便地使用 mini-batch。
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
for batch in loader:
# 對每一個 mini-batch 進行操作
...
torch_geometric.data.Batch
繼承自torch_geometric.data.Data
,并且多了一個屬性:batch
。batch
是一個列向量,它將每個元素映射到每個 mini-batch 中的相應圖:
batch
我們可以使用它分別為每個圖的節點維度計算平均的節點特征:
from torch_scatter import scatter_mean
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
for data in loader:
data
#data: Batch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32])
x = scatter_mean(data.x, data.batch, dim=0)
# x.size(): torch.Size([32, 21])
關于 batching 的流程細節,你可以點擊Pytorch Geometric Documentation查看。關于scatter
方法的說明,你可以查看torch-scatter說明文檔。
Transforms
transforms
在計算機視覺領域是一種很常見的數據增強。PyG 有自己的transforms
,輸出是Data
類型,輸出也是Data
類型。可以使用torch_geometric.transforms.Compose
封裝一系列的transforms
。我們以 ShapeNet 數據集 (包含 17000 個 point clouds,每個 point 分類為 16 個類別的其中一個) 為例,我們可以使用transforms
從 point clouds 生成最近鄰圖:
import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
pre_transform=T.KNNGraph(k=6))
# dataset[0]: Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])
還可以通過transform
在一定范圍內隨機平移每個點,增加坐標上的擾動,做數據增強:
import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
pre_transform=T.KNNGraph(k=6),
transform=T.RandomTranslate(0.01))
# dataset[0]: Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])
模型訓練
這里只是展示一個簡單的 GCN 模型構造和訓練過程,沒有用到Dataset
和DataLoader
。
我們將使用一個簡單的 GCN 層,并在 Cora 數據集上實驗。有關 GCN 的更多內容,請查看 關于 GCN 的理解。
我們首先加載數據集:
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/tmp/Cora', name='Cora')
然后定義 2 層的 GCN:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = GCNConv(dataset.num_node_features, 16)
self.conv2 = GCNConv(16, dataset.num_classes)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
然后訓練 200 個 epochs:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
model.train()
for epoch in range(200):
optimizer.zero_grad()
out = model(data)
loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
最后在測試集上驗證了模型的準確率:
model.eval()
_, pred = model(data).max(dim=1)
correct = float (pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct / data.test_mask.sum().item()
print('Accuracy: {:.4f}'.format(acc))
至此,關于Pytorch Geometric
的簡單使用教程就講完了。
回顧一下,在這篇文章中,在講述使用Pytorch Geometric
的過程中,花了較多篇幅分析了圖數據是如何表示的,分析了Dataset
的工作流程,讓你明白圖數據在Dataset
里都經過了哪些步驟,才得以輸入到模型,最終可以利用Dataset
來構建自己的數據集。
如果你覺得這篇文章對你有幫助,不妨點個贊,讓我有更多動力寫出好文章。
我的文章會首發在公眾號上,歡迎掃碼關注我的公眾號張賢同學。