import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import AvgPooling
from ogb.graphproppred.mol_encoder import AtomEncoder
from model import EGINConv, EGCNConv
# import virtual node
from gtrick.dgl import VirtualNode
class EGNN(nn.Module):
    """GNN for graph property prediction with edge features and virtual nodes.

    Stacks `num_layers` edge-aware conv layers (GIN or GCN variants), each
    preceded by a virtual-node message pass, followed by mean pooling and a
    linear readout.

    Args:
        hidden_channels (int): hidden embedding size.
        out_channels (int): output size (num tasks / num classes).
        num_layers (int): number of conv layers.
        dropout (float): dropout probability.
        conv_type (str): 'gin' or 'gcn'.
        mol (bool): if True, use OGB's AtomEncoder for molecular node
            features; otherwise a single-entry Embedding (e.g. ogbg-ppa,
            where nodes carry no features).
    """

    def __init__(self, hidden_channels, out_channels, num_layers,
                 dropout, conv_type, mol=False):
        super(EGNN, self).__init__()

        self.mol = mol

        if mol:
            self.node_encoder = AtomEncoder(hidden_channels)
        else:
            # featureless nodes: every node gets the same learned embedding
            self.node_encoder = nn.Embedding(1, hidden_channels)

        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.vns = nn.ModuleList()

        for i in range(num_layers):
            if conv_type == 'gin':
                self.convs.append(
                    EGINConv(hidden_channels, self.mol))
            elif conv_type == 'gcn':
                self.convs.append(
                    EGCNConv(hidden_channels, self.mol))

            self.bns.append(nn.BatchNorm1d(hidden_channels))
            # add a virtual node layer
            self.vns.append(VirtualNode(hidden_channels, hidden_channels, dropout=dropout))

        if not self.mol:
            # extra bn/vn used for the final conv layer in the non-mol branch
            # of forward() (accessed as self.bns[-1] / self.vns[-1]).
            # NOTE(review): in both branches forward() skips index
            # num_layers-1 of bns/vns (the loop covers 0..num_layers-2 and
            # the tail uses [-1]); that slot is allocated but unused — kept
            # here to preserve checkpoint/state-dict compatibility.
            self.bns.append(nn.BatchNorm1d(hidden_channels))
            # add a virtual node layer
            self.vns.append(VirtualNode(hidden_channels, hidden_channels, dropout=dropout))

        self.dropout = dropout

        self.pool = AvgPooling()

        self.out = nn.Linear(hidden_channels, out_channels)

    def reset_parameters(self):
        """Re-initialize all learnable parameters.

        Bug fix: the previous version looped over range(len(self.convs)),
        which missed the extra BatchNorm/VirtualNode appended in the
        non-mol case (len(self.bns) == num_layers + 1 there), leaving
        self.bns[-1] / self.vns[-1] unreset. Iterating each ModuleList
        directly resets every module regardless of list length.
        """
        if self.mol:
            for emb in self.node_encoder.atom_embedding_list:
                nn.init.xavier_uniform_(emb.weight.data)
        else:
            nn.init.xavier_uniform_(self.node_encoder.weight.data)

        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()
        for vn in self.vns:
            vn.reset_parameters()

        self.out.reset_parameters()

    def forward(self, g, x, ex):
        """Compute graph-level predictions.

        Args:
            g: batched DGL graph.
            x: node features (mol: atom feature matrix; otherwise indices
               into the single-entry embedding).
            ex: edge features.

        Returns:
            Tensor of shape (num_graphs, out_channels).
        """
        h = self.node_encoder(x)

        vx = None
        for i, conv in enumerate(self.convs[:-1]):
            # use virtual node to update node embedding
            h, vx = self.vns[i].update_node_emb(g, h, vx)

            h = conv(g, h, ex)
            h = self.bns[i](h)
            h = F.relu(h)
            h = F.dropout(h, p=self.dropout, training=self.training)

            # use updated node embedding to update virtual node embeddding
            vx = self.vns[i].update_vn_emb(g, h, vx)

        if self.mol:
            # last layer: no batch norm / relu (standard OGB mol setup)
            h = self.convs[-1](g, h, ex)
            h = F.dropout(h, self.dropout, training=self.training)
        else:
            h, vx = self.vns[-1].update_node_emb(g, h, vx)
            h = self.convs[-1](g, h, ex)
            h = self.bns[-1](h)
            h = F.dropout(h, p=self.dropout, training=self.training)

        h = self.pool(g, h)
        h = self.out(h)

        return h
import argparse
from ogb.graphproppred import DglGraphPropPredDataset
from graph_pred import run_graph_pred
!nvidia-smi
parser = argparse.ArgumentParser(
description='train graph property prediction')
parser.add_argument('--dataset', type=str, default='ogbg-ppa',
choices=['ogbg-molhiv', 'ogbg-ppa'])
parser.add_argument('--dataset_path', type=str, default='/dev/dataset',
help='path to dataset')
parser.add_argument('--device', type=int, default=0)
parser.add_argument('--log_steps', type=int, default=1)
parser.add_argument('--num_layers', type=int, default=5)
parser.add_argument('--hidden_channels', type=int, default=300)
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--batch_size', type=int, default=32,
help='batch size')
parser.add_argument('--num_workers', type=int, default=0,
help='number of workers (default: 0)')
parser.add_argument('--model', type=str, default='gin')
parser.add_argument('--epochs', type=int, default=500)
parser.add_argument('--runs', type=int, default=3)
parser.add_argument('--patience', type=int, default=30)
args = parser.parse_args(args=[])
print(args)
dataset = DglGraphPropPredDataset(
name=args.dataset, root=args.dataset_path)
if args.dataset == 'ogbg-molhiv':
model = EGNN(args.hidden_channels,
dataset.num_tasks, args.num_layers,
args.dropout, args.model, mol=True)
elif args.dataset == 'ogbg-ppa':
model = EGNN(args.hidden_channels,
int(dataset.num_classes), args.num_layers,
args.dropout, args.model)
run_graph_pred(args, model, dataset)