數(shù)據(jù)完整存儲與內存的數(shù)據(jù)集類
一、InMemoryDataset基類簡介
在PyG中,通過繼承InMemoryDataset類來自定義一個數(shù)據(jù)可全部存儲到內存的數(shù)據(jù)集類。
class InMemoryDataset(root: Optional[str] = None, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None)
每個數(shù)據(jù)集都要有一個根文件夾(root),指示數(shù)據(jù)集應該被保存在哪里。在根目錄下至少有兩個文件夾:
(1)raw_dir,用于存儲未處理的文件,保存從網絡上下載的數(shù)據(jù)集文件;
(2)processed_dir,保存處理后的數(shù)據(jù)集被保存到這里。
創(chuàng)建一個InMemoryDataset,我們需要實現(xiàn)四個基本方法:
(1)raw_file_names()是一個屬性方法,返回一個文件名列表,文件應該能在raw_dir文件夾中找到,否則調用download()函數(shù)下載文件到raw_dir文件夾。
(2)processed_raw_file_names()是一個屬性方法,返回一個文件名列表,文件應該能在processed_dir文件夾中找到,否則調用process()函數(shù)對樣本做預處理然后保存到processed_dir文件夾。
(3)download()將原始數(shù)據(jù)文件下載到raw_dir文件夾。
(4)process()對樣本做預處理然后保存到processed_dir文件夾。
import torch
from torch_geometric.data import InMemoryDataset, download_url
class MyOwnDataset(InMemoryDataset):
? ? ? ? ?def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
? ? ? ? ? ? ? ? ? ? super().__init__(root=root, transform=transform,? ? ? ? ? ? pre_transform=pre_transform, pre_filter=pre_filter)
? ? ? ? ? ? ? ? ? ? self.data, self.slices = torch.load(self.processed_paths[0])
? ? ? ? ? ?@property
? ? ? ? ? ?def raw_file_names(self):
? ? ? ? ? ? ? ? ? ? ?return ['some_file_1', 'some_file_2', ...]
? ? ? ? ? ?@property
? ? ? ? ? ?def processed_file_names(self):
? ? ? ? ? ? ? ? ? ? ?return ['data.pt']
? ? ? ? ?def download(self):
? ? ? ? ? ? ? ? ? ? ?download_url(url, self.raw_dir)
? ? ? ? def process(self):
? ? ? ? ? ? ? ? ? ? ?data_list = [...]
? ? ? ? ? ? ? ? ? ? ?if self.pre_filter is not None:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?data_list = [data for data in data_list if self.pre_filter(data)]
? ? ? ? ? ? ? ? ? ?if self.pre_transform is not None:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? data_list = [self.pre_transform(data) for data in data_list]
? ? ? ? ? ? ? ? ? data, slices = self.collate(data_list)
? ? ? ? ? ? ? ? ? torch.save((data, slices), self.processed_paths[0])
樣本從原始文件轉換成 Data類對象的過程定義在process函數(shù)中。
需要讀取和創(chuàng)建一個 Data對象的列表,并將其保存到processed_dir中。在構造函數(shù)中把Data對象和切片字典分別加載到屬性self.data和self.slices中。
二、定義一個InMemoryDataset子類
以公開數(shù)據(jù)集PubMed為例。PubMed數(shù)據(jù)集存儲的是文章引用網絡,文章對應圖的結點,如果兩篇文章存在引用關系,則這兩篇文章對應的結點之間存在邊。
基于PyG中的Planetoid類修改得到下面的PlanetoidPubMed數(shù)據(jù)集類。
import os.pathas osp
import torch
from torch_geometric.dataimport (InMemoryDataset, download_url)
from torch_geometric.ioimport read_planetoid_data
class PlanetoidPubMed(InMemoryDataset):
? ? ? ? ? ? ?url ='https://github.com/kimiyoung/planetoid/raw/master/data'
? ? ? ? ? ? ?def __init__(self, root, split="public", num_train_per_class=20,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? num_val=500, num_test=1000, transform=None,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? pre_transform=None):
? ? ? ? ? ? ? ? ? ? ? super(PlanetoidPubMed,self).__init__(root, transform, pre_transform)
? ? ? ? ? ? ? ? ? ? ? self.data,self.slices = torch.load(self.processed_paths[0])
? ? ? ? ? ? ? ? ? ? ? self.split = split
? ? ? ? ? ? ? ? ? ? ? assert self.splitin ['public','full','random']
? ? ? ? ? ? ? ? ? ? ? if split =='full':
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? data =self.get(0)
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? data.train_mask.fill_(True)
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? data.train_mask[data.val_mask | data.test_mask] =False
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? self.data,self.slices =self.collate([data])
? ? ? ? ? ? ? ? ? ? ? elif split =='random':
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?data =self.get(0)
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?data.train_mask.fill_(False)
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? for cin range(self.num_classes):
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? idx = (data.y == c).nonzero(as_tuple=False).view(-1)
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? idx = idx[torch.randperm(idx.size(0))[:num_train_per_class]]
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? data.train_mask[idx] =True
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?remaining = (~data.train_mask).nonzero(as_tuple=False).view(-1)
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?remaining = remaining[torch.randperm(remaining.size(0))]
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?data.val_mask.fill_(False)
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?data.val_mask[remaining[:num_val]] =True
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?data.test_mask.fill_(False)
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?data.test_mask[remaining[num_val:num_val + num_test]] =True
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?self.data,self.slices =self.collate([data])
? ? ? ? ? def raw_dir(self):
? ? ? ? ? ? ? ? ?return osp.join(self.root,'raw')
? ? ? ? ? def processed_dir(self):
? ? ? ? ? ? ? ? ? return osp.join(self.root,'processed')
? ? ? ? ?def raw_file_names(self):
? ? ? ? ? ? ? ? ? ?names = ['x','tx','allx','y','ty','ally','graph','test.index']
? ? ? ? ? ? ? ? ? ?return ['ind.pubmed.{}'.format(name)for namein names]
? ? ? ? ?def processed_file_names(self):
? ? ? ? ? ? ? ? ? ?return 'data.pt'
? ? ? ? ?def download(self):
? ? ? ? ? ? ? ? ? ? for namein self.raw_file_names:
? ? ? ? ? ? ? ? ? ? ? ? ? ? download_url('{}/{}'.format(self.url, name),self.raw_dir)
? ? ? ? ?def process(self):
? ? ? ? ? ? ? ? ? ? data = read_planetoid_data(self.raw_dir,'pubmed')
? ? ? ? ? ? ? ? ? ? data = dataif self.pre_transformis None else self.pre_transform(data)
? ? ? ? ? ? ? ? ? ? torch.save(self.collate([data]),self.processed_paths[0])
? ? ? ? ?def __repr__(self):
? ? ? ? ? ? ? ? ? ? return '{}()'.format(self.name)
在生成一個PlanetoidPubMed類的對象時,程序運行流程如下:
(1)檢查數(shù)據(jù)原始文件是否已下載:
檢查self.raw_dir目錄下是否存在raw_file_names()屬性方法返回的每個文件,如有文件不存在,則調用download()方法執(zhí)行原始文件下載。其中self.raw_dir為osp.join(self.root, 'raw')。
(2)檢查數(shù)據(jù)是否經過處理:
檢查之前對數(shù)據(jù)做變換的方法:檢查self.processed_dir目錄下是否存在pre_transform.pt文件:如果存在,意味著之前進行過數(shù)據(jù)變換,則需加載該文件獲取之前所用的數(shù)據(jù)變換的方法,并檢查它與當前pre_transform參數(shù)指定的方法是否相同;如果不相同則會報出一個警告,“The pre_transform argument differs from the one used in ……”。
檢查之前的樣本過濾的方法:檢查self.processed_dir目錄下是否存在pre_filter.pt文件,如果存在,意味著之前進行過樣本過濾,則需加載該文件獲取之前所用的樣本過濾的方法,并檢查它與當前pre_filter參數(shù)指定的方法是否相同,如果不相同則會報出一個警告,“The pre_filter argument differs from the one used in ……”。其中self.processed_dir為osp.join(self.root, 'processed')。
檢查是否存在處理好的數(shù)據(jù):檢查self.processed_dir目錄下是否存在self.processed_paths方法返回的所有文件,如有文件不存在,意味著不存在已經處理好的樣本的文件,如需執(zhí)行以下的操作:
(a)調用process方法,進行數(shù)據(jù)處理。
(b)如果pre_transform參數(shù)不為None,則調用pre_transform方法進行數(shù)據(jù)處理。
(c)如果pre_filter參數(shù)不為None,則進行樣本過濾。
(d)保存處理好的數(shù)據(jù)到文件,文件存儲在processed_paths()屬性方法返回的路徑。如果將數(shù)據(jù)保存到多個文件中,則返回的路徑有多個。這些路徑都在self.processed_dir目錄下,以processed_file_names()屬性方法的返回值為文件名。
(e)保存新的pre_transform.pt文件和pre_filter.pt文件,其中分別存儲當前使用的數(shù)據(jù)處理方法和樣本過濾方法。
查看數(shù)據(jù)集
將下載好的數(shù)據(jù)包復制到項目的/dataset/raw文件夾中,直接讀取本地路徑下的數(shù)據(jù)包。
代碼如下:
dataset= PlanetoidPubMed('dataset')
print(dataset.num_classes)
print(dataset[0].num_nodes)
print(dataset[0].num_edges)
print(dataset[0].num_features)
運行結果如下:
根據(jù)運行結果可以得出:該數(shù)據(jù)集包含3個分類任務,19,717個結點,88,648條邊,500個結點特征維度。
結點預測與邊預測任務實踐
一、結點預測任務實踐
重定義一個GAT神經網絡,使其能夠通過參數(shù)定義GATConv的層數(shù),以及每一層GATConv的out_channels。
神經網絡定義如下:
class GAT(torch.nn.Module):
? ? ? ? ? ?def __init__(self, num_features, hidden_channels_list, num_classes):
? ? ? ? ? ? ? ? ? ? super(GAT, self).__init__()
? ? ? ? ? ? ? ? ? ? torch.manual_seed(12345)
? ? ? ? ? ? ? ? ? ? hns = [num_features] + hidden_channels_list
? ? ? ? ? ? ? ? ? ? conv_list = []
? ? ? ? ? ? ? ? ? ? for idx in range(len(hidden_channels_list)):
? ? ? ? ? ? ? ? ? ? ? ? ? conv_list.append((GATConv(hns[idx], hns[idx+1]), 'x, edge_index -> x'))
? ? ? ? ? ? ? ? ? ? ? ? ? conv_list.append(ReLU(inplace=True),)
? ? ? ? ? ? ? ? ? ? self.convseq = Sequential('x, edge_index', conv_list)
? ? ? ? ? ? ? ? ? ? self.linear = Linear(hidden_channels_list[-1], num_classes)
? ? ? ? def forward(self, x, edge_index):
? ? ? ? ? ? ? ? ? ?x = self.convseq(x, edge_index)
? ? ? ? ? ? ? ? ? ?x = F.dropout(x, p=0.5, training=self.training)
? ? ? ? ? ? ? ? ? x = self.linear(x)
? ? ? ? ? ? ? ? ? return x
二、邊預測任務實踐
邊預測任務是預測兩個結點之間是否存在邊。
對于圖數(shù)據(jù)集,存在結點特征矩陣x,和哪些結點之間存在邊的信息edge_index。edge_index存儲的便是正樣本,為了構建邊預測任務,需要生成一些負樣本,即采樣一些不存在邊的節(jié)點對作為負樣本邊,正負樣本應平衡。要將樣本分為訓練集、驗證集和測試集三個集合。
PyG中提供了現(xiàn)成的方法,train_test_split_edges(data, val_ratio=0.05, test_ratio=0.1),其第一個參數(shù)為torch_geometric.data.Data對象,第二參數(shù)為驗證集所占比例,第三個參數(shù)為測試集所占比例。該函數(shù)將自動地采樣得到負樣本,并將正負樣本分成訓練集、驗證集和測試集三個集合。使用Cora數(shù)據(jù)集作為例子進行邊預測任務說明。
獲取數(shù)據(jù)集并進行分析
代碼如下:
import os.path as osp
from torch_geometric.utils import negative_sampling
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.utils import train_test_split_edges
dataset = 'Cora'
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
data = dataset[0]
data.train_mask = data.val_mask = data.test_mask = data.y = None
data = train_test_split_edges(data)
print(data.edge_index.shape)
for key in data.keys:
? ? ? ? ?print(key, getattr(data, key).shape)
運行結果如下:
根據(jù)運行結果可以得出:三個集合中正樣本邊的數(shù)量之和不等于原始邊的數(shù)量。這是因為原始邊的數(shù)量統(tǒng)計的是雙向邊的數(shù)量,在驗證集正樣本邊和測試集正樣本邊中只需對一個方向的邊做預測精度的衡量,對另一個方向的預測精度衡量屬于重復。
構建神經網絡模型
import torch
from torch_geometric.nn import GCNConv
class Net(torch.nn.Module):
? ? ? ? ? def __init__(self, in_channels, out_channels):
? ? ? ? ? ? ? ? ? ?super(Net, self).__init__()
? ? ? ? ? ? ? ? ? self.conv1 = GCNConv(in_channels, 128)
? ? ? ? ? ? ? ? ? self.conv2 = GCNConv(128, out_channels)
? ? ? ? def encode(self, x, edge_index):
? ? ? ? ? ? ? ? ? x = self.conv1(x, edge_index)
? ? ? ? ? ? ? ? ? x = x.relu()
? ? ? ? ? ? ? ? ?return self.conv2(x, edge_index)
? ? ? ? def decode(self, z, pos_edge_index, neg_edge_index):
? ? ? ? ? ? ? ? ? edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
? ? ? ? ? ? ? ? ? return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)
? ? ? ? def decode_all(self, z):
? ? ? ? ? ? ? ? ? ?prob_adj = z @ z.t()
? ? ? ? ? ? ? ? ? return (prob_adj > 0).nonzero(as_tuple=False).t()
用于做邊預測的神經網絡主要由兩部分組成:
其一是編碼(encode),與生成結點表征是一樣的;其二是解碼(decode),為邊兩端結點的表征生成邊為真的幾率(odds)。decode_all(self, z)用于推斷(inference)階段,對輸入結點的所有結點對預測存在邊的幾率。
定義單個epoch的訓練過程
def get_link_labels(pos_edge_index, neg_edge_index):
? ? ? ? ? ?num_links = pos_edge_index.size(1) + neg_edge_index.size(1)
? ? ? ? ? ?link_labels = torch.zeros(num_links, dtype=torch.float)
? ? ? ? ? link_labels[:pos_edge_index.size(1)] = 1.
? ? ? ? ?return link_labels
def train(data, model, optimizer):
? ? ? ? ? model.train()
? ? ? ? ? neg_edge_index = negative_sampling(
? ? ? ? ? ? ? ? ? ? ? ?edge_index=data.train_pos_edge_index,
? ? ? ? ? ? ? ? ? ? ? ?num_nodes=data.num_nodes,
? ? ? ? ? ? ? ? ? ? ? ?num_neg_samples=data.train_pos_edge_index.size(1))
? ? ? ? ? optimizer.zero_grad()
? ? ? ? ? z = model.encode(data.x, data.train_pos_edge_index)
? ? ? ? ? link_logits = model.decode(z, data.train_pos_edge_index, neg_edge_index)
? ? ? ? ? link_labels = get_link_labels(data.train_pos_edge_index,? ? ? ? ? ? ? ? neg_edge_index).to(data.x.device)
? ? ? ? ? loss = F.binary_cross_entropy_with_logits(link_logits, link_labels)
? ? ? ? ? loss.backward()
? ? ? ? ? optimizer.step()
? ? ? ? ? return loss
通常在圖上存在邊的結點對的數(shù)量往往少于不存在邊的結點對的數(shù)量。為了類平衡,在每一個epoch的訓練過程中,只需要用到與正樣本一樣數(shù)量的負樣本。
綜合以上兩點原因,在每一個epoch的訓練過程中都采樣與正樣本數(shù)量一樣的負樣本,這樣既做到了類平衡,又增加了訓練負樣本的豐富性。get_link_labels函數(shù)用于生成完整訓練集的標簽。在負樣本采樣時,傳遞了train_pos_edge_index為參數(shù),于是negative_sampling函數(shù)只會在訓練集中不存在邊的結點對中采樣。
在訓練階段,應該只見訓練集,對驗證集與測試集都是不可見的,在此階段應該要完成對所有結點的編碼,假設此處正樣本訓練集涉及到了所有的結點,這樣就能實現(xiàn)對所有結點的編碼。
定義單個epoch驗證與測試過程
@torch.no_grad()
def test(data, model):
? ? ? ? ? model.eval()
? ? ? ? ? ?z = model.encode(data.x, data.train_pos_edge_index)
? ? ? ? ? ?results = []
? ? ? ? ? ?for prefix in ['val', 'test']:
? ? ? ? ? ? ? ? ? ? pos_edge_index = data[f'{prefix}_pos_edge_index']
? ? ? ? ? ? ? ? ? ? neg_edge_index = data[f'{prefix}_neg_edge_index']
? ? ? ? ? ? ? ? ? ? link_logits = model.decode(z, pos_edge_index, neg_edge_index)
? ? ? ? ? ? ? ? ? ? link_probs = link_logits.sigmoid()
? ? ? ? ? ? ? ? ? ? link_labels = get_link_labels(pos_edge_index, neg_edge_index)
? ? ? ? ? ? ? ? ? ? results.append(roc_auc_score(link_labels.cpu(), link_probs.cpu()))
? ? ? ? ? return results
在驗證與測試過程中,只用正樣本邊訓練集做節(jié)點特征編碼。
運行完整的訓練、驗證與測試
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = 'Cora'
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
data = dataset[0]
ground_truth_edge_index = data.edge_index.to(device)
data.train_mask = data.val_mask = data.test_mask = data.y = None
data = train_test_split_edges(data)
data = data.to(device)
model = Net(dataset.num_features, 64).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
best_val_auc = test_auc = 0
for epoch in range(1, 101):
? ? ? ? loss = train(data, model, optimizer)
? ? ? ? val_auc, tmp_test_auc = test(data, model)
? ? ? ? if val_auc > best_val_auc:
? ? ? ? ? ? ? ?best_val_auc = val_auc
? ? ? ? ? ? ? ?test_auc = tmp_test_auc
? ? ? ? print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '
? ? ? ? ? ? ? f'Test: {test_auc:.4f}')
z = model.encode(data.x, data.train_pos_edge_index)
final_edge_index = model.decode_all(z)
三、總結
定義一個數(shù)據(jù)可全部存儲于內存的數(shù)據(jù)集類的方法,并且實踐結點預測任務和邊預測任務。
要重點關注InMemoryDataset子類的運行流程與實現(xiàn)四個函數(shù)的規(guī)范。在圖神經網絡的實現(xiàn)中,可以適用torch_geometric.nn.Sequential容器對神經網絡的多個模塊順序相連。
四、作業(yè)
對結點預測任務,嘗試用PyG中的不同的網絡層去代替GCNConv,以及不同的層數(shù)和不同的out_channels。
使用GAT代替GCN
class GAT(torch.nn.Module):
? ? ? ? ? ?def __init__(self, num_features, hidden_channels_list, num_classes):
? ? ? ? ? ? ? ? ? ? super(GAT, self).__init__()
? ? ? ? ? ? ? ? ? ? torch.manual_seed(12345)
? ? ? ? ? ? ? ? ? ? hns = [num_features] + hidden_channels_list
? ? ? ? ? ? ? ? ? ? conv_list = []
? ? ? ? ? ? ? ? ? ? for idx in range(len(hidden_channels_list)):
? ? ? ? ? ? ? ? ? ? ? ? ? conv_list.append((GATConv(hns[idx], hns[idx+1]), 'x, edge_index -> x'))
? ? ? ? ? ? ? ? ? ? ? ? ? conv_list.append(ReLU(inplace=True),)
? ? ? ? ? ? ? ? ? ? self.convseq = Sequential('x, edge_index', conv_list)
? ? ? ? ? ? ? ? ? ? self.linear = Linear(hidden_channels_list[-1], num_classes)
? ? ? ? ? def forward(self, x, edge_index):
? ? ? ? ? ? ? ? ? ?x = self.convseq(x, edge_index)
? ? ? ? ? ? ? ? ? ?x = F.dropout(x, p=0.5, training=self.training)
? ? ? ? ? ? ? ? ? x = self.linear(x)
? ? ? ? ? ? ? ? ? return x
對邊預測任務,嘗試用用torch_geometric.nn.Sequential容器構造圖神經網絡。
class GCN(torch.nn.Module):
? ? ? def __init__(self, num_features, hidden_channels_list, num_classes):
? ? ? ? ? ? ? ?super(GCN, self).__init__()
? ? ? ? ? ? ? ? ?torch.manual_seed(12345)
? ? ? ? ? ? ? ? ?hns = [num_features] + hidden_channels_list
? ? ? ? ? ? ? ? ?conv_list = []
? ? ? ? ? ? ? ? ?for idx in range(len(hidden_channels_list)):
? ? ? ? ? ? ? ? ? ? ? ? conv_list.append((GCNConv(hns[idx], hns[idx+1]), 'x, edge_index -> x'))
? ? ? ? ? ? ? ? ? ? ? ?conv_list.append(ReLU(inplace=True))
? ? ? ? ? ? ? ? ?self.conseq = Sequential('x, edge_index', conv_list)
? ? ? ? ? ? ? ? ?self.linear = Linear(hidden_channels_list[-1], num_classes)
? ? ? def encode(self, x, edge_index):
? ? ? ? ? ? ? ? ? x = self.conseq(x, edge_index)
? ? ? ? ? ? ? ? ? x = F.dropout(x, p=0.5, training=self.training)
? ? ? ? ? ? ? ? ? x = self.linear(x)
? ? ? ? ? ? ? ? ? return x
? ? def decode(self, z, pos_edge_index, neg_edge_index):
? ? ? ? ? ? ? ? ? ?edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
? ? ? ? ? ? ? ? ? return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)
? ? def decode_all(self, z):
? ? ? ? ? ? ? ? ? prob_adj = z @ z.t()
? ? ? ? ? ? ? ? ? return (prob_adj > 0).nonzero(as_tuple=False).t()
如下方代碼所示,以data.train_pos_edge_index為實際參數(shù),這樣采樣得到的負樣本可能包含驗證集正樣本或測試集正樣本,即可能將真實的正樣本標記為負樣本,由此會產生沖突。為什么這么做?為什么在驗證與測試階段只根據(jù)data.train_pos_edge_index做結點表征的編碼?
數(shù)據(jù)集實際上就是正樣本多,負樣本少。生成負樣本是為了使正負樣本平衡。采樣的負樣本中存在部分正樣本,對于數(shù)據(jù)集來說是極少數(shù)的,不影響正負樣本平衡。
只根據(jù)data.train_pos_edge_index做結點表征的編碼,忽略驗證集和測試集的具體情況,減少影響因素,提高準確性。
DataWhale開源學習資料:
https://github.com/datawhalechina/team-learning-nlp/tree/master/GNN