In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
In [2]:
import ktrain
from ktrain import graph as gr
Using TensorFlow backend.
using Keras version: 2.2.4

Node Classification in Graphs

In this notebook, we will use ktrain to perform node classification on the PubMed Diabetes citation graph. In the PubMed graph, each node represents a paper pertaining to one of three topics: Diabetes Mellitus - Experimental, Diabetes Mellitus - Type 1, and Diabetes Mellitus - Type 2. Links represent citations between papers. The features assigned to each node are the TF-IDF scores of the words in the corresponding paper. The dataset is available here.
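The exact TF-IDF weighting used to build the PubMed feature vectors is not described here. Purely as an illustration, the sketch below computes one common variant (term frequency times log(N / document frequency)) on three made-up token lists. Note that a word appearing in every document gets a score of zero.

```python
import math
from collections import Counter


def tf_idf(docs):
    """Return one {word: score} dict per tokenized document, using tf * log(N / df)."""
    n_docs = len(docs)
    # document frequency: number of documents that contain each word at least once
    doc_freq = Counter(word for doc in docs for word in set(doc))
    vectors = []
    for doc in docs:
        counts = Counter(doc)
        vectors.append({
            word: (count / len(doc)) * math.log(n_docs / doc_freq[word])
            for word, count in counts.items()
        })
    return vectors


# Three made-up "papers", already tokenized
docs = [
    ["insulin", "glucose", "insulin"],
    ["glucose", "islet", "rat"],
    ["glucose", "insulin", "children"],
]
for vec in tf_idf(docs):
    print({w: round(s, 3) for w, s in vec.items()})
```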

ktrain expects two files for node classification problems. The first is a comma- or tab-delimited file listing the edges in the graph, where each row contains the IDs of the two nodes forming an edge. The second is a comma- or tab-delimited file listing the features or attributes of each node in the graph. The first column of this file is the node ID and the last column should be a string representing the target or label of the node. All other columns should be numerical features that are non-null and assumed to be appropriately standardized.
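Before handing files to ktrain, it can help to check that they match the layout described above. The following is a minimal, hypothetical checker (check_node_rows and check_edge_rows are not part of ktrain); it assumes tab-delimited rows with no header, mirroring the files this notebook writes, and it adds one extra sanity check of its own: every node referenced by an edge must appear in the node file.

```python
import csv
import io
import math


def check_node_rows(text, sep="\t"):
    """Check node rows: ID first, numeric non-null features in between, string label last.

    Returns (node_ids, labels); raises ValueError on the first malformed row.
    """
    node_ids, labels = [], []
    for lineno, row in enumerate(csv.reader(io.StringIO(text), delimiter=sep), start=1):
        if len(row) < 3:
            raise ValueError(f"row {lineno}: need an ID, at least one feature, and a label")
        node_id, *features, label = row
        for value in features:
            try:
                number = float(value)
            except ValueError:
                raise ValueError(f"row {lineno}: non-numeric feature {value!r}") from None
            if math.isnan(number):
                raise ValueError(f"row {lineno}: null (NaN) feature")
        node_ids.append(node_id)
        labels.append(label)
    return node_ids, labels


def check_edge_rows(text, known_ids, sep="\t"):
    """Check edge rows: exactly two node IDs per row, both present in the node file."""
    edges = []
    for lineno, row in enumerate(csv.reader(io.StringIO(text), delimiter=sep), start=1):
        if len(row) != 2:
            raise ValueError(f"row {lineno}: expected 2 columns, got {len(row)}")
        unknown = [node for node in row if node not in known_ids]
        if unknown:
            raise ValueError(f"row {lineno}: unknown node ID(s) {unknown}")
        edges.append((row[0], row[1]))
    return edges


# Tiny made-up files in the expected tab-delimited layout (no header row)
nodes_txt = "101\t0.0\t0.25\tDiabetes_Mellitus-Type_1\n102\t0.1\t0.0\tDiabetes_Mellitus-Type_2\n"
edges_txt = "101\t102\n"
ids, labels = check_node_rows(nodes_txt)
print(ids, labels, check_edge_rows(edges_txt, set(ids)))
```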

We must prepare the raw data to conform to the above before we begin.

Preparing the Data

The code below will create two files that can be processed directly by ktrain:

  • /tmp/pubmed-nodes.tab
  • /tmp/pubmed-edges.tab
In [3]:
# set this to the location of the downloaded Pubmed-Diabetes data
DATADIR = 'data/pubmed/Pubmed-Diabetes/data'
In [4]:
import os.path
import pandas as pd
import itertools

# process links
edgelist = pd.read_csv(os.path.join(DATADIR, 'Pubmed-Diabetes.DIRECTED.cites.tab'),
                       skiprows=2, header=None, delimiter='\t')
# keep only the source and target columns (columns 0 and 2 are not needed)
edgelist.drop(columns=[0, 2], inplace=True)
edgelist.columns = ['source', 'target']
# remove the literal 'paper:' prefix; str.lstrip('paper:') would strip a *set* of characters
edgelist['source'] = edgelist['source'].str.replace('paper:', '', regex=False)
edgelist['target'] = edgelist['target'].str.replace('paper:', '', regex=False)
edgelist.to_csv('/tmp/pubmed-edges.tab', sep='\t', header=False, index=False)

# process nodes and their attributes
nodes_as_dict = []
with open(os.path.join(os.path.expanduser(DATADIR), "Pubmed-Diabetes.NODE.paper.tab")) as fp:
    for line in itertools.islice(fp, 2, None):
        line_res = line.split("\t")
        pid = line_res[0]
        feat_name = ['pid'] + [l.split("=")[0] for l in line_res[1:]][:-1] # delete summary
        feat_value = [l.split("=")[1] for l in line_res[1:]][:-1] # delete summary
        feat_value = [pid] + [ float(x) for x in feat_value ] # change to numeric from str
        row = dict(zip(feat_name, feat_value))
        nodes_as_dict.append(row)
colnames = set()
for row in nodes_as_dict:
    colnames.update(list(row.keys()))
colnames = list(colnames)
colnames.sort()
colnames.remove('label')
colnames.append('label')
target_dict = {1:'Diabetes_Mellitus-Experimental', 2: 'Diabetes_Mellitus-Type_1', 3:'Diabetes_Mellitus-Type_2', }
with open('/tmp/pubmed-nodes.tab', 'w') as fp:
    #fp.write("\t".join(colnames)+'\n')
    for row in nodes_as_dict:
        feats = []
        for col in colnames:
            feats.append(row.get(col, 0.0))
        feats = [str(feat) for feat in feats]
        feats[-1] = round(float(feats[-1]))
        feats[-1] = target_dict[feats[-1]]
        fp.write("\t".join(feats) + '\n')
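The node-processing loop above turns each raw row of the form pid, name=value pairs, and a trailing summary field into a dictionary. The sketch below applies the same steps to a single made-up row (the ID and values are invented). It also shows a pitfall when removing the 'paper:' prefix from IDs: str.lstrip treats its argument as a set of characters rather than a prefix, so it only gives the right answer for purely numeric IDs. strip_prefix is a hypothetical helper equivalent to str.removeprefix in Python 3.9+.

```python
def strip_prefix(text, prefix):
    """Remove `prefix` from the start of `text` if present (like str.removeprefix in Python 3.9+)."""
    return text[len(prefix):] if text.startswith(prefix) else text


def parse_node_line(line):
    """Parse one raw node row: 'pid<TAB>name=value<TAB>...<TAB>summary=...'."""
    pid, *fields = line.rstrip("\n").split("\t")
    row = {"pid": pid}
    for field in fields[:-1]:  # skip the trailing summary field, as the notebook code does
        name, value = field.split("=", 1)
        row[name] = float(value)
    return row


# A made-up row in the same shape as the lines the notebook reads
line = "12345678\tlabel=1\tw-rat=0.0931\tw-insulin=0.0412\tsummary=w-rat,w-insulin\n"
print(parse_node_line(line))
# {'pid': '12345678', 'label': 1.0, 'w-rat': 0.0931, 'w-insulin': 0.0412}

# str.lstrip removes leading characters drawn from a *set*, not a prefix
print("paper:12345".lstrip("paper:"))          # 12345 (correct only because IDs are numeric)
print("paper:apple".lstrip("paper:"))          # le
print(strip_prefix("paper:apple", "paper:"))   # apple
```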

STEP 1: Load and Preprocess Data

We will hold out 20% of the nodes as test nodes by setting holdout_pct=0.2. Because we also set holdout_for_inductive=True, these held-out nodes are removed from the graph so that we can later simulate making predictions on new nodes added to the graph (inductive inference). If holdout_for_inductive=False, the features (but not the labels) of these nodes are accessible to the model during training. Of the remaining nodes, 5% (train_pct=0.05) are used for training and the rest are used for validation (transductive inference). Transductive and inductive inference, along with the return values df_holdout and G_complete, are described in more detail below.

Note that if there are any unlabeled nodes in the graph, they are automatically used as held-out nodes for which predictions can be made once the model is trained. See the Twitter example notebook for an example of this.
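The node counts printed by graph_nodes_from_csv below can be checked by hand. This sketch assumes each fraction is rounded down; that assumption reproduces the logged numbers but is not taken from ktrain's source.

```python
def split_counts(n_nodes, holdout_pct, train_pct):
    """Node counts for a holdout / train / validation split, rounding each fraction down."""
    n_holdout = int(n_nodes * holdout_pct)  # removed from the graph for inductive testing
    n_graph = n_nodes - n_holdout           # nodes left in the training graph
    n_train = int(n_graph * train_pct)      # labeled nodes used for training
    n_val = n_graph - n_train               # remaining nodes used for validation
    return n_holdout, n_graph, n_train, n_val


print(split_counts(19717, 0.2, 0.05))  # (3943, 15774, 788, 14986)
```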

In [5]:
(train_data, val_data, preproc,
 df_holdout, G_complete) = gr.graph_nodes_from_csv('/tmp/pubmed-nodes.tab',
                                                   '/tmp/pubmed-edges.tab',
                                                   sample_size=10, holdout_pct=0.2,
                                                   holdout_for_inductive=True,
                                                   train_pct=0.05, sep='\t')
Largest subgraph statistics: 19717 nodes, 44327 edges
Size of training graph: 15774 nodes
Training nodes: 788
Validation nodes: 14986
Nodes treated as unlabeled for testing/inference: 3943
Size of graph with added holdout nodes: 19717
Holdout node features are not visible during training (inductive_inference)

The preproc object includes a reference to the training graph and a dataframe showing the features and target for each node in the graph (both training and validation nodes).

In [6]:
preproc.df.target.value_counts()
Out[6]:
Diabetes_Mellitus-Type_1          6255
Diabetes_Mellitus-Type_2          6242
Diabetes_Mellitus-Experimental    3277
Name: target, dtype: int64

STEP 2: Build a Model and Wrap in Learner Object

In [7]:
gr.print_node_classifiers()
graphsage: GraphSAGE:  https://arxiv.org/pdf/1706.02216.pdf
In [8]:
learner = ktrain.get_learner(model=gr.graph_node_classifier('graphsage', train_data), 
                             train_data=train_data, 
                             val_data=val_data, 
                             batch_size=64)
Is Multi-Label? False
done

STEP 3: Estimate LR

Because each epoch contains only about a dozen batches (788 training nodes at a batch size of 64), many epochs are needed to sweep through enough learning rates to estimate a good one. We cap the search at 100 epochs here.
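lr_find is a learning-rate range test: the learning rate is raised over many mini-batches while the loss is recorded, and the run ends once the loss blows up. The sketch below illustrates that idea under stated assumptions: a geometric schedule between hypothetical bounds start_lr=1e-7 and end_lr=10, and a hypothetical divergence rule "loss exceeds 4 times the best loss so far". ktrain's actual bounds, loss smoothing, and stopping rule may differ. The step budget follows from the 12 batches per epoch shown in the lr_find log.

```python
def lr_range_schedule(start_lr, end_lr, n_steps):
    """Learning rates growing geometrically from start_lr to end_lr over n_steps mini-batches."""
    ratio = end_lr / start_lr
    return [start_lr * ratio ** (i / (n_steps - 1)) for i in range(n_steps)]


def divergence_step(losses, stop_factor=4.0):
    """Index of the first loss exceeding stop_factor * (best loss so far), or None."""
    best = float("inf")
    for i, loss in enumerate(losses):
        best = min(best, loss)
        if loss > stop_factor * best:
            return i
    return None


steps_per_epoch = 788 // 64          # 12 full batches, matching the "12/12" in the lr_find log
n_steps = 100 * steps_per_epoch      # max_epochs=100 -> at most 1200 learning-rate values
lrs = lr_range_schedule(1e-7, 10.0, n_steps)   # hypothetical bounds
print(n_steps, lrs[0], lrs[n_steps // 2], lrs[-1])
print(divergence_step([1.10, 1.05, 0.60, 0.25, 0.30, 1.20]))  # 5: 1.20 > 4 * 0.25
```

A geometric (rather than linear) schedule spends equal numbers of steps in each order of magnitude, which is why the loss plot is usually drawn with a log-scaled learning-rate axis.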

In [9]:
learner.lr_find(max_epochs=100)
simulating training for different learning rates... this may take a few moments...
Epoch 1/100
12/12 [==============================] - 1s 85ms/step - loss: 1.1021 - acc: 0.3638
Epoch 2/100
12/12 [==============================] - 0s 30ms/step - loss: 1.0971 - acc: 0.3743
Epoch 3/100
12/12 [==============================] - 0s 34ms/step - loss: 1.1027 - acc: 0.3324
Epoch 4/100
12/12 [==============================] - 0s 30ms/step - loss: 1.1024 - acc: 0.3402
Epoch 5/100
12/12 [==============================] - 0s 33ms/step - loss: 1.1034 - acc: 0.3294
Epoch 6/100
12/12 [==============================] - 0s 32ms/step - loss: 1.0960 - acc: 0.3763
Epoch 7/100
12/12 [==============================] - 0s 33ms/step - loss: 1.0965 - acc: 0.3534
Epoch 8/100
12/12 [==============================] - 0s 33ms/step - loss: 1.1091 - acc: 0.3430
Epoch 9/100
12/12 [==============================] - 0s 33ms/step - loss: 1.1116 - acc: 0.3320
Epoch 10/100
12/12 [==============================] - 0s 31ms/step - loss: 1.0876 - acc: 0.3856
Epoch 11/100
12/12 [==============================] - 0s 31ms/step - loss: 1.0977 - acc: 0.3698
Epoch 12/100
12/12 [==============================] - 0s 30ms/step - loss: 1.1057 - acc: 0.3329
Epoch 13/100
12/12 [==============================] - 0s 30ms/step - loss: 1.1049 - acc: 0.3359
Epoch 14/100
12/12 [==============================] - 0s 32ms/step - loss: 1.1036 - acc: 0.3468
Epoch 15/100
12/12 [==============================] - 0s 34ms/step - loss: 1.0954 - acc: 0.3613
Epoch 16/100
12/12 [==============================] - 0s 32ms/step - loss: 1.1098 - acc: 0.3744
Epoch 17/100
12/12 [==============================] - 0s 33ms/step - loss: 1.0955 - acc: 0.3547
Epoch 18/100
12/12 [==============================] - 0s 33ms/step - loss: 1.1050 - acc: 0.3443
Epoch 19/100
12/12 [==============================] - 0s 30ms/step - loss: 1.0893 - acc: 0.3744
Epoch 20/100
12/12 [==============================] - 0s 40ms/step - loss: 1.0921 - acc: 0.3665
Epoch 21/100
12/12 [==============================] - 0s 34ms/step - loss: 1.0996 - acc: 0.3600
Epoch 22/100
12/12 [==============================] - 0s 29ms/step - loss: 1.1018 - acc: 0.3390
Epoch 23/100
12/12 [==============================] - 0s 33ms/step - loss: 1.0971 - acc: 0.3555
Epoch 24/100
12/12 [==============================] - 0s 32ms/step - loss: 1.0936 - acc: 0.3613
Epoch 25/100
12/12 [==============================] - 0s 32ms/step - loss: 1.0910 - acc: 0.3795
Epoch 26/100
12/12 [==============================] - 0s 26ms/step - loss: 1.0853 - acc: 0.3953
Epoch 27/100
12/12 [==============================] - 0s 33ms/step - loss: 1.0881 - acc: 0.3613
Epoch 28/100
12/12 [==============================] - 0s 30ms/step - loss: 1.0881 - acc: 0.3756
Epoch 29/100
12/12 [==============================] - 0s 38ms/step - loss: 1.0866 - acc: 0.3808
Epoch 30/100
12/12 [==============================] - 0s 30ms/step - loss: 1.0757 - acc: 0.4281
Epoch 31/100
12/12 [==============================] - 0s 31ms/step - loss: 1.0709 - acc: 0.4267
Epoch 32/100
12/12 [==============================] - 0s 31ms/step - loss: 1.0838 - acc: 0.3919
Epoch 33/100
12/12 [==============================] - 0s 32ms/step - loss: 1.0801 - acc: 0.3847
Epoch 34/100
12/12 [==============================] - 0s 32ms/step - loss: 1.0669 - acc: 0.4266
Epoch 35/100
12/12 [==============================] - 0s 32ms/step - loss: 1.0600 - acc: 0.4369
Epoch 36/100
12/12 [==============================] - 0s 34ms/step - loss: 1.0477 - acc: 0.4766
Epoch 37/100
12/12 [==============================] - 0s 41ms/step - loss: 1.0432 - acc: 0.4712
Epoch 38/100
12/12 [==============================] - 0s 33ms/step - loss: 1.0141 - acc: 0.5523
Epoch 39/100
12/12 [==============================] - 0s 28ms/step - loss: 1.0296 - acc: 0.4856
Epoch 40/100
12/12 [==============================] - 0s 31ms/step - loss: 1.0022 - acc: 0.5169
Epoch 41/100
12/12 [==============================] - 0s 37ms/step - loss: 0.9882 - acc: 0.5443
Epoch 42/100
12/12 [==============================] - 0s 32ms/step - loss: 0.9765 - acc: 0.5565
Epoch 43/100
12/12 [==============================] - 0s 36ms/step - loss: 0.9715 - acc: 0.5536
Epoch 44/100
12/12 [==============================] - 0s 39ms/step - loss: 0.9235 - acc: 0.6191
Epoch 45/100
12/12 [==============================] - 0s 32ms/step - loss: 0.9103 - acc: 0.6361
Epoch 46/100
12/12 [==============================] - 0s 40ms/step - loss: 0.8570 - acc: 0.6875
Epoch 47/100
12/12 [==============================] - 0s 33ms/step - loss: 0.8375 - acc: 0.7134
Epoch 48/100
12/12 [==============================] - 0s 30ms/step - loss: 0.7789 - acc: 0.7579
Epoch 49/100
12/12 [==============================] - 0s 33ms/step - loss: 0.7346 - acc: 0.7943
Epoch 50/100
12/12 [==============================] - 0s 34ms/step - loss: 0.6784 - acc: 0.8115
Epoch 51/100
12/12 [==============================] - 0s 41ms/step - loss: 0.6411 - acc: 0.8181
Epoch 52/100
12/12 [==============================] - 0s 30ms/step - loss: 0.5853 - acc: 0.8338
Epoch 53/100
12/12 [==============================] - 0s 29ms/step - loss: 0.5497 - acc: 0.8469
Epoch 54/100
12/12 [==============================] - 0s 33ms/step - loss: 0.4958 - acc: 0.8587
Epoch 55/100
12/12 [==============================] - 0s 34ms/step - loss: 0.4363 - acc: 0.8861
Epoch 56/100
12/12 [==============================] - 0s 36ms/step - loss: 0.3972 - acc: 0.8927
Epoch 57/100
12/12 [==============================] - 0s 29ms/step - loss: 0.3773 - acc: 0.8737
Epoch 58/100
12/12 [==============================] - 0s 34ms/step - loss: 0.3735 - acc: 0.8652
Epoch 59/100
12/12 [==============================] - 0s 39ms/step - loss: 0.3351 - acc: 0.8974
Epoch 60/100
12/12 [==============================] - 0s 32ms/step - loss: 0.3001 - acc: 0.9097
Epoch 61/100
12/12 [==============================] - 0s 35ms/step - loss: 0.2728 - acc: 0.9215
Epoch 62/100
12/12 [==============================] - 0s 31ms/step - loss: 0.2761 - acc: 0.9128
Epoch 63/100
12/12 [==============================] - 0s 39ms/step - loss: 0.2826 - acc: 0.9071
Epoch 64/100
12/12 [==============================] - 0s 32ms/step - loss: 0.1876 - acc: 0.9372
Epoch 65/100
12/12 [==============================] - 0s 28ms/step - loss: 0.2418 - acc: 0.9163
Epoch 66/100
12/12 [==============================] - 0s 32ms/step - loss: 0.2193 - acc: 0.9254
Epoch 67/100
12/12 [==============================] - 0s 34ms/step - loss: 0.2385 - acc: 0.9175
Epoch 68/100
12/12 [==============================] - 0s 30ms/step - loss: 0.2542 - acc: 0.9045
Epoch 69/100
12/12 [==============================] - 0s 35ms/step - loss: 0.2287 - acc: 0.9071
Epoch 70/100
12/12 [==============================] - 0s 32ms/step - loss: 0.2167 - acc: 0.9245
Epoch 71/100
12/12 [==============================] - 0s 33ms/step - loss: 0.1879 - acc: 0.9342
Epoch 72/100
12/12 [==============================] - 0s 32ms/step - loss: 0.2314 - acc: 0.9128
Epoch 73/100
12/12 [==============================] - 0s 40ms/step - loss: 0.2224 - acc: 0.9319
Epoch 74/100
12/12 [==============================] - 0s 33ms/step - loss: 0.2177 - acc: 0.9267
Epoch 75/100
12/12 [==============================] - 0s 32ms/step - loss: 0.2174 - acc: 0.9241
Epoch 76/100
12/12 [==============================] - 0s 32ms/step - loss: 0.2491 - acc: 0.9149
Epoch 77/100
12/12 [==============================] - 0s 32ms/step - loss: 0.2436 - acc: 0.9136
Epoch 78/100
12/12 [==============================] - 0s 28ms/step - loss: 0.1898 - acc: 0.9293
Epoch 79/100
12/12 [==============================] - 0s 31ms/step - loss: 0.1675 - acc: 0.9307
Epoch 80/100
12/12 [==============================] - 0s 33ms/step - loss: 0.1949 - acc: 0.9294
Epoch 81/100
12/12 [==============================] - 0s 32ms/step - loss: 0.1895 - acc: 0.9385
Epoch 82/100
12/12 [==============================] - 0s 33ms/step - loss: 0.2656 - acc: 0.9084
Epoch 83/100
12/12 [==============================] - 0s 33ms/step - loss: 0.2249 - acc: 0.9232
Epoch 84/100
12/12 [==============================] - 0s 34ms/step - loss: 0.2740 - acc: 0.8953
Epoch 85/100
12/12 [==============================] - 0s 36ms/step - loss: 0.2511 - acc: 0.9019
Epoch 86/100
12/12 [==============================] - 0s 32ms/step - loss: 0.2529 - acc: 0.9084
Epoch 87/100
12/12 [==============================] - 0s 33ms/step - loss: 0.2575 - acc: 0.8992
Epoch 88/100
12/12 [==============================] - 0s 32ms/step - loss: 0.2617 - acc: 0.9097
Epoch 89/100
12/12 [==============================] - 0s 32ms/step - loss: 0.3072 - acc: 0.9097
Epoch 90/100
12/12 [==============================] - 0s 32ms/step - loss: 0.4338 - acc: 0.8692
Epoch 91/100
12/12 [==============================] - 0s 27ms/step - loss: 0.5869 - acc: 0.8324
Epoch 92/100
12/12 [==============================] - 0s 35ms/step - loss: 0.8302 - acc: 0.8206
Epoch 93/100
12/12 [==============================] - 0s 34ms/step - loss: 1.0682 - acc: 0.7723
Epoch 94/100
12/12 [==============================] - 0s 30ms/step - loss: 0.8207 - acc: 0.8390
Epoch 95/100
12/12 [==============================] - 0s 35ms/step - loss: 0.6389 - acc: 0.8260
Epoch 96/100
12/12 [==============================] - 0s 32ms/step - loss: 0.8993 - acc: 0.8521
Epoch 97/100
 8/12 [===================>..........] - ETA: 0s - loss: 1.2450 - acc: 0.8379

done.
Please invoke the Learner.lr_plot() method to visually inspect the loss plot to help identify the maximal learning rate associated with falling loss.
In [10]:
learner.lr_plot()

STEP 4: Train the Model

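We will train the model using autofit with a maximum learning rate of 0.01; autofit uses a triangular learning rate policy. Because no number of epochs is specified, autofit enables early stopping (patience=5) and reduce-on-plateau (patience=2), as the log below shows, so training stops automatically once the validation loss stops improving and the weights from the best epoch are restored.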
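A triangular policy ramps the learning rate linearly up to the maximum and back down, repeating in cycles (the cyclical learning rate schedule described by Leslie Smith). The sketch below shows the standard triangular formula. The base learning rate (max_lr / 10) and half-cycle length (6 iterations) are hypothetical placeholders, not the values autofit uses internally.

```python
import math


def triangular_lr(iteration, base_lr, max_lr, step_size):
    """Triangular cyclical LR: ramps linearly base_lr -> max_lr -> base_lr every 2 * step_size iterations."""
    cycle = math.floor(1 + iteration / (2 * step_size))
    x = abs(iteration / step_size - 2 * cycle + 1)
    return base_lr + (max_lr - base_lr) * max(0.0, 1 - x)


max_lr = 0.01
base_lr = max_lr / 10     # hypothetical lower bound
step_size = 6             # hypothetical: half-cycle of 6 iterations
for it in range(0, 13, 3):
    print(it, round(triangular_lr(it, base_lr, max_lr, step_size), 5))
```

The learning rate starts at base_lr, peaks at max_lr after step_size iterations, and returns to base_lr after 2 * step_size iterations.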

In [9]:
learner.autofit(0.01)
early_stopping automatically enabled at patience=5
reduce_on_plateau automatically enabled at patience=2


begin training using triangular learning rate policy with max lr of 0.01...
Epoch 1/1024
13/13 [==============================] - 6s 484ms/step - loss: 1.0057 - acc: 0.4807 - val_loss: 0.8324 - val_acc: 0.6990
Epoch 2/1024
13/13 [==============================] - 6s 425ms/step - loss: 0.8001 - acc: 0.7077 - val_loss: 0.6512 - val_acc: 0.7795
Epoch 3/1024
13/13 [==============================] - 6s 438ms/step - loss: 0.6322 - acc: 0.8045 - val_loss: 0.5574 - val_acc: 0.7875
Epoch 4/1024
13/13 [==============================] - 6s 430ms/step - loss: 0.5251 - acc: 0.8237 - val_loss: 0.5077 - val_acc: 0.8106
Epoch 5/1024
13/13 [==============================] - 6s 476ms/step - loss: 0.4407 - acc: 0.8600 - val_loss: 0.5061 - val_acc: 0.8086
Epoch 6/1024
13/13 [==============================] - 6s 454ms/step - loss: 0.3857 - acc: 0.8697 - val_loss: 0.5033 - val_acc: 0.8046
Epoch 7/1024
13/13 [==============================] - 6s 453ms/step - loss: 0.3682 - acc: 0.8528 - val_loss: 0.4966 - val_acc: 0.8058
Epoch 8/1024
13/13 [==============================] - 6s 462ms/step - loss: 0.3110 - acc: 0.8938 - val_loss: 0.4791 - val_acc: 0.8254
Epoch 9/1024
13/13 [==============================] - 6s 444ms/step - loss: 0.2822 - acc: 0.9035 - val_loss: 0.4873 - val_acc: 0.8160
Epoch 10/1024
13/13 [==============================] - 6s 443ms/step - loss: 0.2734 - acc: 0.9035 - val_loss: 0.4955 - val_acc: 0.8101

Epoch 00010: Reducing Max LR on Plateau: new max lr will be 0.005 (if not early_stopping).
Epoch 11/1024
13/13 [==============================] - 6s 435ms/step - loss: 0.2361 - acc: 0.9264 - val_loss: 0.4898 - val_acc: 0.8214
Epoch 12/1024
13/13 [==============================] - 6s 498ms/step - loss: 0.2292 - acc: 0.9155 - val_loss: 0.5074 - val_acc: 0.8174

Epoch 00012: Reducing Max LR on Plateau: new max lr will be 0.0025 (if not early_stopping).
Epoch 13/1024
13/13 [==============================] - 6s 442ms/step - loss: 0.1969 - acc: 0.9421 - val_loss: 0.5203 - val_acc: 0.8132
Restoring model weights from the end of the best epoch
Epoch 00013: early stopping
Weights from best epoch have been loaded into model.
Out[9]:
<keras.callbacks.History at 0x7f673538af98>
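The stopping behavior in the log above can be replayed from the logged validation losses. This simplified sketch assumes that counters reset on any strict improvement, that the max LR is halved after rop_patience non-improving epochs, and that training stops after es_patience non-improving epochs. ktrain's callbacks may differ in details such as min_delta or cooldown. With the values from the log, it reproduces the LR reductions at epochs 10 and 12, the stop at epoch 13, and epoch 8 as the best epoch.

```python
def simulate_callbacks(val_losses, max_lr, es_patience=5, rop_patience=2, factor=0.5):
    """Replay early-stopping and reduce-on-plateau decisions over per-epoch validation losses.

    Returns (best_epoch, stop_epoch, lr_reductions) with 1-based epochs;
    stop_epoch is None if early stopping never triggers.
    """
    best, best_epoch = float("inf"), None
    es_wait = rop_wait = 0
    lr = max_lr
    reductions = []
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:                      # improvement: remember it and reset both counters
            best, best_epoch = loss, epoch
            es_wait = rop_wait = 0
            continue
        es_wait += 1
        rop_wait += 1
        if es_wait >= es_patience:           # early stopping wins over another LR reduction
            return best_epoch, epoch, reductions
        if rop_wait >= rop_patience:
            lr *= factor
            reductions.append((epoch, lr))
            rop_wait = 0
    return best_epoch, None, reductions


# val_loss values copied from the autofit log above
val_losses = [0.8324, 0.6512, 0.5574, 0.5077, 0.5061, 0.5033, 0.4966,
              0.4791, 0.4873, 0.4955, 0.4898, 0.5074, 0.5203]
print(simulate_callbacks(val_losses, max_lr=0.01))
# (8, 13, [(10, 0.005), (12, 0.0025)])
```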

Evaluate

Validate

In [10]:
learner.validate(class_names=preproc.get_classes())
                                precision    recall  f1-score   support

Diabetes_Mellitus-Experimental       0.76      0.82      0.79      3113
      Diabetes_Mellitus-Type_1       0.84      0.81      0.82      5943
      Diabetes_Mellitus-Type_2       0.85      0.84      0.85      5930

                      accuracy                           0.83     14986
                     macro avg       0.82      0.82      0.82     14986
                  weighted avg       0.83      0.83      0.83     14986

Out[10]:
array([[2553,  362,  198],
       [ 440, 4815,  688],
       [ 359,  572, 4999]])
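The figures in the classification report can be recomputed from the confusion matrix. The row sums (3113, 5943, 5930) equal the supports, so rows are true classes and columns are predicted classes. Precision divides each diagonal entry by its column sum, recall divides by its row sum, and F1 is their harmonic mean. Overall accuracy is 12367 / 14986, about 0.8252, which the report rounds to 0.83.

```python
def report_from_confusion(cm, class_names):
    """Per-class (precision, recall, f1, support) and overall accuracy.

    Assumes rows are true classes and columns are predicted classes.
    """
    n = len(cm)
    total = sum(sum(row) for row in cm)
    correct = sum(cm[i][i] for i in range(n))
    report = {}
    for i, name in enumerate(class_names):
        tp = cm[i][i]
        predicted = sum(cm[r][i] for r in range(n))  # column sum: nodes predicted as this class
        support = sum(cm[i])                         # row sum: nodes truly in this class
        precision, recall = tp / predicted, tp / support
        f1 = 2 * precision * recall / (precision + recall)
        report[name] = (precision, recall, f1, support)
    return report, correct / total


cm = [[2553, 362, 198],
      [440, 4815, 688],
      [359, 572, 4999]]   # Out[10]
classes = ["Diabetes_Mellitus-Experimental", "Diabetes_Mellitus-Type_1", "Diabetes_Mellitus-Type_2"]
report, acc = report_from_confusion(cm, classes)
for name, (p, r, f, s) in report.items():
    print(f"{name:>32} {p:.2f} {r:.2f} {f:.2f} {s}")
print(f"accuracy {acc:.4f}")   # accuracy 0.8252
```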

Create a Predictor Object

In [11]:
p = ktrain.get_predictor(learner.model, preproc)

Transductive Inference: Making Predictions for Unlabeled Nodes in Original Training Graph

In transductive inference, we make predictions for unlabeled nodes whose features were visible during training. Predicting labels for the validation nodes, which are part of the training graph, is therefore transductive inference.

Let's check the prediction for the first validation node against its ground-truth label.

In [12]:
p.predict_transductive(val_data.ids[0:1], return_proba=True)
Out[12]:
array([[0.04122107, 0.9422023 , 0.0165766 ]], dtype=float32)
In [13]:
val_data[0][1][0]
Out[13]:
array([0., 1., 0.])
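With return_proba=True, the predictor returns one score per class, and the ground truth is a one-hot vector. Assuming the columns follow the class order used in the validation report (Experimental, Type_1, Type_2), both map to a class name via argmax. The sketch below uses the two outputs shown above.

```python
class_names = ["Diabetes_Mellitus-Experimental", "Diabetes_Mellitus-Type_1", "Diabetes_Mellitus-Type_2"]


def decode(scores, classes):
    """Return the class name with the highest score (argmax)."""
    best = max(range(len(scores)), key=scores.__getitem__)
    return classes[best]


proba = [0.04122107, 0.9422023, 0.0165766]   # Out[12]
one_hot = [0.0, 1.0, 0.0]                    # Out[13]
print(decode(proba, class_names), decode(one_hot, class_names))
# Diabetes_Mellitus-Type_1 Diabetes_Mellitus-Type_1
```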

Let's make predictions for all validation nodes and visually compare some of them with ground truth.

In [14]:
y_pred = p.predict_transductive(val_data.ids, return_proba=False)
In [15]:
# look up labels in the same order as val_data.ids so they line up with y_pred
y_true = preproc.df.loc[val_data.ids, 'target'].values
In [16]:
import pandas as pd
pd.DataFrame(list(zip(y_true, y_pred)), columns=['Ground Truth', 'Predicted']).head()
Out[16]:
                     Ground Truth                       Predicted
0        Diabetes_Mellitus-Type_1        Diabetes_Mellitus-Type_1
1        Diabetes_Mellitus-Type_2        Diabetes_Mellitus-Type_1
2        Diabetes_Mellitus-Type_1        Diabetes_Mellitus-Type_1
3        Diabetes_Mellitus-Type_1        Diabetes_Mellitus-Type_1
4  Diabetes_Mellitus-Experimental  Diabetes_Mellitus-Experimental
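Comparing ground truth with predictions element by element only works if both sequences list nodes in the same order. In pandas, a boolean filter such as df[df.index.isin(ids)] returns rows in the DataFrame's own order, whereas df.loc[ids] follows the order of ids. The stdlib sketch below uses made-up IDs and labels to show how a mismatch silently lowers the measured accuracy.

```python
# Made-up node IDs and labels, standing in for the target column of a DataFrame
labels_by_id = {"a": "Type_1", "b": "Type_2", "c": "Experimental", "d": "Type_1"}
val_ids = ["c", "a", "d"]                       # e.g. a shuffled validation split
y_pred = ["Experimental", "Type_1", "Type_1"]   # predictions returned in val_ids order

# Membership filtering keeps the table's own order (a, c, d), not val_ids order
wanted = set(val_ids)
y_true_filtered = [lab for node, lab in labels_by_id.items() if node in wanted]
# Looking up each ID in turn keeps val_ids order (c, a, d)
y_true_aligned = [labels_by_id[node] for node in val_ids]


def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


print(accuracy(y_true_filtered, y_pred))  # pairs are misaligned, so this is too low
print(accuracy(y_true_aligned, y_pred))   # 1.0
```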

Inductive Inference: Making Predictions for New Nodes Not in the Original Training Graph

In inductive inference, we make predictions for entirely new nodes that were not present in the training graph, so their features or attributes were not visible during training. Here we consider a graph in which the held-out nodes are added back into the training graph, which yields the original graph of 19,717 nodes. This graph, G_complete, is the last value returned by graph_nodes_from_csv.
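GraphSAGE can embed nodes it never saw during training because it learns functions that aggregate neighbor features, rather than a lookup table with one embedding per training node. The plain-Python sketch below is a simplified, single-layer mean aggregator, not ktrain's implementation. It uses made-up weights, averages over all neighbors instead of a fixed-size random sample as in the GraphSAGE paper, and omits the extra layers and the final classifier.

```python
def mean_vector(vectors):
    """Element-wise mean of equal-length vectors."""
    return [sum(column) / len(vectors) for column in zip(*vectors)]


def sage_mean_layer(node, features, adjacency, weights):
    """One simplified GraphSAGE-style layer: ReLU(W @ concat(x_node, mean of neighbor features)).

    Uses all neighbors instead of a fixed-size random sample, for determinism.
    """
    neighbors = adjacency.get(node, [])
    dim = len(features[node])
    agg = mean_vector([features[u] for u in neighbors]) if neighbors else [0.0] * dim
    concat = features[node] + agg
    return [max(0.0, sum(w * x for w, x in zip(row, concat))) for row in weights]


# Weights "learned" on the training graph (made-up numbers): 2 outputs x (2 + 2) inputs
weights = [[0.5, -0.5, 1.0, 0.0],
           [0.0, 1.0, 0.0, 1.0]]
features = {"old1": [1.0, 0.0], "old2": [0.0, 1.0]}
adjacency = {"old1": ["old2"], "old2": ["old1"]}

# A new node arrives with its own features and citation links to existing nodes;
# the same weights apply, so no retraining is needed to embed it.
features["new"] = [1.0, 1.0]
adjacency["new"] = ["old1", "old2"]
print(sage_mean_layer("new", features, adjacency, weights))  # [0.5, 1.5]
```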

In [17]:
y_pred = p.predict_inductive(df_holdout, G_complete, return_proba=False)
In [18]:
y_true = df_holdout['target'].values
In [19]:
import numpy as np
(y_true == np.array(y_pred)).mean()
Out[19]:
0.8303322343393356

With 83.03% accuracy on the held-out nodes, inductive performance is quite good and comparable to the transductive performance on the validation nodes (0.83 accuracy in the report above).
