import argparse
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
import torch
import torch.nn.functional as F
from node_pred import train, test
from utils import EarlyStopping, seed_everything
from model import GNN
import numpy as np
# import CorrectAndSmooth
from gtrick.dgl import CorrectAndSmooth
def run_node_pred(args, model, dataset):
    """Train a GNN for node property prediction, then apply Correct & Smooth.

    Performs ``args.runs`` independent training runs. Each run trains
    ``model`` with early stopping on validation accuracy, keeps the soft
    predictions from the best validation epoch, and post-processes them
    with Correct & Smooth (C&S) using the train+valid labels. Prints
    per-run and aggregate (mean ± std) test accuracy with and without C&S.

    Args:
        args: Parsed CLI namespace (device, lr, epochs, patience,
            log_steps, runs, and the C&S hyperparameters).
        model: A GNN exposing ``reset_parameters()``; moved to the target
            device in place.
        dataset: A ``DglNodePropPredDataset`` providing the graph, labels,
            split indices, ``task_type`` and ``eval_metric``.
    """
    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    model.to(device)

    evaluator = Evaluator(name=args.dataset)

    g, y = dataset[0]

    # Make the graph undirected by adding reverse edges
    # (ogbn-arxiv edges are directed citations).
    srcs, dsts = g.all_edges()
    g.add_edges(dsts, srcs)

    # Ensure exactly one self-loop per node: remove any existing ones
    # first so repeated runs / pre-looped graphs don't duplicate them.
    print(f'Total edges before adding self-loop {g.number_of_edges()}')
    g = g.remove_self_loop().add_self_loop()
    print(f'Total edges after adding self-loop {g.number_of_edges()}')

    g, y = g.to(device), y.to(device)

    # ogbn-proteins has no 'feat' node feature; fall back to 'species'.
    if args.dataset == 'ogbn-proteins':
        x = g.ndata['species']
    else:
        x = g.ndata['feat']

    split_idx = dataset.get_idx_split()
    train_idx, valid_idx = split_idx['train'], split_idx['valid']

    final_test_acc, final_test_acc_cs = [], []
    for run in range(args.runs):
        model.reset_parameters()
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        early_stopping = EarlyStopping(
            patience=args.patience, verbose=True, mode='max')

        best_test_acc, best_val_acc = 0, 0
        best_out = None  # soft predictions at the best validation epoch
        test_acc_cs = 0

        for epoch in range(1, 1 + args.epochs):
            loss = train(model, g, x, y, train_idx,
                         optimizer, dataset.task_type)
            result = test(model, g, x, y, split_idx,
                          evaluator, dataset.eval_metric)
            train_acc, valid_acc, test_acc, out = result

            if epoch % args.log_steps == 0:
                print(f'Run: {run + 1:02d}, '
                      f'Epoch: {epoch:02d}, '
                      f'Loss: {loss:.4f}, '
                      f'Train: {100 * train_acc:.2f}%, '
                      f'Valid: {100 * valid_acc:.2f}% '
                      f'Test: {100 * test_acc:.2f}%')

            # Model selection: track the test accuracy and predictions at
            # the epoch with the best validation accuracy.
            if valid_acc > best_val_acc:
                best_val_acc = valid_acc
                best_test_acc = test_acc
                best_out = out

            if early_stopping(valid_acc, model):
                break

        # define c & s
        cs = CorrectAndSmooth(num_correction_layers=args.num_correction_layers,
                              correction_alpha=args.correction_alpha,
                              num_smoothing_layers=args.num_smoothing_layers,
                              smoothing_alpha=args.smoothing_alpha,
                              autoscale=args.autoscale)

        # use labels of train and valid set to propagate
        mask_idx = torch.cat([train_idx, valid_idx])
        y_soft = cs.correct(g, best_out, y[mask_idx], mask_idx)
        y_soft = cs.smooth(g, y_soft, y[mask_idx], mask_idx)
        y_pred = y_soft.argmax(dim=-1, keepdim=True)

        test_acc_cs = evaluator.eval({
            'y_true': y[split_idx['test']],
            'y_pred': y_pred[split_idx['test']],
        })[dataset.eval_metric]

        print('Best Test Acc: {:.4f}, Best Test Acc with C & S: {:.4f}'.format(
            best_test_acc, test_acc_cs))

        final_test_acc.append(best_test_acc)
        final_test_acc_cs.append(test_acc_cs)

    print('Test Acc: {:.4f} ± {:.4f}, Test Acc with C & S: {:.4f} ± {:.4f}'.format(
        np.mean(final_test_acc), np.std(final_test_acc),
        np.mean(final_test_acc_cs), np.std(final_test_acc_cs)))
# ---------------------------------------------------------------------------
# Script entry: parse CLI arguments, build the model, and run training.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(
    description='train node property prediction')
parser.add_argument('--dataset', type=str, default='ogbn-arxiv',
                    choices=['ogbn-arxiv'])
parser.add_argument('--dataset_path', type=str, default='/dev/dataset',
                    help='path to dataset')
parser.add_argument('--device', type=int, default=1)
parser.add_argument('--log_steps', type=int, default=1)
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--epochs', type=int, default=500)
parser.add_argument('--runs', type=int, default=3)
parser.add_argument('--patience', type=int, default=30)

# params for GNN
parser.add_argument('--model', type=str, default='gcn')
parser.add_argument('--num_layers', type=int, default=3)
parser.add_argument('--hidden_channels', type=int, default=256)
parser.add_argument('--dropout', type=float, default=0.5)

# params for C & S (hyphens map to underscore attribute names, e.g.
# --num-correction-layers -> args.num_correction_layers)
parser.add_argument('--num-correction-layers', type=int, default=50)
parser.add_argument('--correction-alpha', type=float, default=0.979)
parser.add_argument('--num-smoothing-layers', type=int, default=50)
parser.add_argument('--smoothing-alpha', type=float, default=0.756)
parser.add_argument('--autoscale', action='store_true', default=True)

# BUG FIX: the original called parse_args(args=[]) — a notebook leftover
# that ignored sys.argv, so every CLI flag was silently discarded and the
# defaults were always used. Parse the real command line instead.
args = parser.parse_args()
print(args)

# Fix all RNG seeds for reproducibility across runs.
seed_everything(3042)

dataset = DglNodePropPredDataset(name=args.dataset, root=args.dataset_path)
g, _ = dataset[0]
num_features = g.ndata['feat'].shape[1]

model = GNN(num_features, args.hidden_channels,
            dataset.num_classes, args.num_layers,
            args.dropout, args.model)

run_node_pred(args, model, dataset)