%reload_ext autoreload
%autoreload 2
%matplotlib inline
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID";
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import ktrain
from ktrain import graph as gr
Using TensorFlow backend.
using Keras version: 2.2.4
In this notebook, we will use ktrain to perform node classification on the PubMed Diabetes citation graph. In the PubMed graph, each node represents a paper pertaining to one of three topics: Diabetes Mellitus - Experimental, Diabetes Mellitus - Type 1, and Diabetes Mellitus - Type 2. Links represent citations between papers. The attributes or features assigned to each node are in the form of a vector of words in each paper and their corresponding TF-IDF scores. The dataset is available here.
ktrain expects two files for node classification problems. The first is a comma- or tab-delimited file listing the edges in the graph, where each row contains the node IDs forming the edge. The second is a comma- or tab-delimited file listing the features or attributes associated with each node in the graph. The first column in this file is the node ID, and the last column should be a string representing the target or label of the node. All other columns should be numerical features, assumed to be appropriately standardized and non-null.
We must prepare the raw data to conform to the format above before we begin.
The code below creates two files that can be processed directly by ktrain:
/tmp/pubmed-nodes.tab
/tmp/pubmed-edges.tab
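For reference, each row of the edge file is simply a pair of node IDs, and each row of the node file is a node ID followed by the numeric features and, last, the string label. The rows below are illustrative only (tab-separated; the "..." stands for the remaining feature columns):

12187484    3542527

12187484    0.0    0.1039    ...    Diabetes_Mellitus-Type_1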
# set this to the location of the downloaded Pubmed-Diabetes data
DATADIR = 'data/pubmed/Pubmed-Diabetes/data'
import os.path
import pandas as pd
import itertools
# process links
edgelist = pd.read_csv(os.path.join(DATADIR, 'Pubmed-Diabetes.DIRECTED.cites.tab'),
                       skiprows=2, header=None, delimiter='\t')
edgelist.drop(columns=[0, 2], inplace=True)
edgelist.columns = ['source', 'target']
# remove the 'paper:' prefix from node IDs
# (str.lstrip strips a *set of characters*, not a prefix, so use replace instead)
edgelist['source'] = edgelist['source'].map(lambda x: x.replace('paper:', '', 1))
edgelist['target'] = edgelist['target'].map(lambda x: x.replace('paper:', '', 1))
edgelist.head()
edgelist.to_csv('/tmp/pubmed-edges.tab', sep='\t', header=None, index=False)
# process nodes and their attributes
nodes_as_dict = []
with open(os.path.join(os.path.expanduser(DATADIR), "Pubmed-Diabetes.NODE.paper.tab")) as fp:
    for line in itertools.islice(fp, 2, None):  # skip the two header lines
        line_res = line.split("\t")
        pid = line_res[0]
        feat_name = ['pid'] + [l.split("=")[0] for l in line_res[1:]][:-1]  # drop trailing 'summary' field
        feat_value = [l.split("=")[1] for l in line_res[1:]][:-1]           # drop trailing 'summary' field
        feat_value = [pid] + [float(x) for x in feat_value]                 # convert feature strings to floats
        row = dict(zip(feat_name, feat_value))
        nodes_as_dict.append(row)
colnames = set()
for row in nodes_as_dict:
    colnames.update(list(row.keys()))
colnames = list(colnames)
colnames.sort()
colnames.remove('label')
colnames.append('label')  # move 'label' to the end: ktrain expects it as the last column
target_dict = {1: 'Diabetes_Mellitus-Experimental',
               2: 'Diabetes_Mellitus-Type_1',
               3: 'Diabetes_Mellitus-Type_2'}
with open('/tmp/pubmed-nodes.tab', 'w') as fp:
    # fp.write("\t".join(colnames) + '\n')  # header row omitted
    for row in nodes_as_dict:
        feats = []
        for col in colnames:
            feats.append(row.get(col, 0.0))
        feats = [str(feat) for feat in feats]
        feats[-1] = round(float(feats[-1]))  # the label was parsed as a float above
        feats[-1] = target_dict[feats[-1]]   # map the numeric label to its class name
        fp.write("\t".join(feats) + '\n')
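As a quick optional sanity check, we can read the two files back and confirm their shapes and that the last column of the node file holds the string label:

import pandas as pd
nodes = pd.read_csv('/tmp/pubmed-nodes.tab', sep='\t', header=None)
edges = pd.read_csv('/tmp/pubmed-edges.tab', sep='\t', header=None)
print(nodes.shape, edges.shape)  # one row per paper / per citation
print(nodes.iloc[0, -1])         # should be one of the three class names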
We will hold out 20% of the nodes as test nodes by setting holdout_pct=0.2. Since we specified holdout_for_inductive=True, these held-out nodes are removed from the graph in order to later simulate making predictions on new nodes added to the graph (i.e., inductive inference). If holdout_for_inductive=False, the features (but not the labels) of these nodes are accessible to the model during training. Of the remaining nodes, 5% will be used for training and the rest will be used for validation (transductive inference). More information on transductive and inductive inference and on the return values df_holdout and G_complete is provided below.
Note that, if there are any unlabeled nodes in the graph, these will automatically be treated as held-out nodes for which predictions can be made once the model is trained. See the twitter example notebook for an example of this.
(train_data, val_data, preproc,
 df_holdout, G_complete) = gr.graph_nodes_from_csv('/tmp/pubmed-nodes.tab',
                                                   '/tmp/pubmed-edges.tab',
                                                   sample_size=10,
                                                   holdout_pct=0.2,
                                                   holdout_for_inductive=True,
                                                   train_pct=0.05, sep='\t')
Largest subgraph statistics: 19717 nodes, 44327 edges
Size of training graph: 15774 nodes
Training nodes: 788
Validation nodes: 14986
Nodes treated as unlabeled for testing/inference: 3943
Size of graph with added holdout nodes: 19717
Holdout node features are not visible during training (inductive_inference)
The preproc object includes a reference to the training graph and a DataFrame showing the features and target for each node in the graph (both training and validation nodes).
preproc.df.target.value_counts()
Diabetes_Mellitus-Type_1          6255
Diabetes_Mellitus-Type_2          6242
Diabetes_Mellitus-Experimental    3277
Name: target, dtype: int64
gr.print_node_classifiers()
graphsage: GraphSAGE: https://arxiv.org/pdf/1706.02216.pdf
learner = ktrain.get_learner(model=gr.graph_node_classifier('graphsage', train_data),
                             train_data=train_data,
                             val_data=val_data,
                             batch_size=64)
Is Multi-Label? False
done
Given the small number of batches per epoch, a larger number of epochs is required to estimate the learning rate. We will cap it at 100 here.
learner.lr_find(max_epochs=100)
simulating training for different learning rates... this may take a few moments...
Epoch 1/100
12/12 [==============================] - 1s 85ms/step - loss: 1.1021 - acc: 0.3638
Epoch 2/100
12/12 [==============================] - 0s 30ms/step - loss: 1.0971 - acc: 0.3743
...
Epoch 50/100
12/12 [==============================] - 0s 34ms/step - loss: 0.6784 - acc: 0.8115
...
Epoch 79/100
12/12 [==============================] - 0s 31ms/step - loss: 0.1675 - acc: 0.9307
...
Epoch 96/100
12/12 [==============================] - 0s 32ms/step - loss: 0.8993 - acc: 0.8521
Epoch 97/100
 8/12 [===================>..........] - ETA: 0s - loss: 1.2450 - acc: 0.8379
done. Please invoke the Learner.lr_plot() method to visually inspect the loss plot to help identify the maximal learning rate associated with falling loss.
learner.lr_plot()
We will train the model using autofit, which employs a triangular learning rate policy. Training will automatically stop when the validation loss no longer improves.
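The triangular policy cycles the learning rate linearly between a base rate and the maximum rate passed to autofit. Below is a minimal sketch of such a schedule for illustration only; it is not ktrain's internal implementation, and the function name and parameters are hypothetical:

import numpy as np

def triangular_lr(iteration, step_size, base_lr, max_lr):
    # Cycle the learning rate linearly between base_lr and max_lr,
    # completing one full triangle every 2*step_size iterations.
    cycle = np.floor(1 + iteration / (2 * step_size))
    x = np.abs(iteration / step_size - 2 * cycle + 1)
    return base_lr + (max_lr - base_lr) * max(0.0, 1 - x)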
learner.autofit(0.01)
early_stopping automatically enabled at patience=5
reduce_on_plateau automatically enabled at patience=2
begin training using triangular learning rate policy with max lr of 0.01...
Epoch 1/1024
13/13 [==============================] - 6s 484ms/step - loss: 1.0057 - acc: 0.4807 - val_loss: 0.8324 - val_acc: 0.6990
Epoch 2/1024
13/13 [==============================] - 6s 425ms/step - loss: 0.8001 - acc: 0.7077 - val_loss: 0.6512 - val_acc: 0.7795
Epoch 3/1024
13/13 [==============================] - 6s 438ms/step - loss: 0.6322 - acc: 0.8045 - val_loss: 0.5574 - val_acc: 0.7875
Epoch 4/1024
13/13 [==============================] - 6s 430ms/step - loss: 0.5251 - acc: 0.8237 - val_loss: 0.5077 - val_acc: 0.8106
Epoch 5/1024
13/13 [==============================] - 6s 476ms/step - loss: 0.4407 - acc: 0.8600 - val_loss: 0.5061 - val_acc: 0.8086
Epoch 6/1024
13/13 [==============================] - 6s 454ms/step - loss: 0.3857 - acc: 0.8697 - val_loss: 0.5033 - val_acc: 0.8046
Epoch 7/1024
13/13 [==============================] - 6s 453ms/step - loss: 0.3682 - acc: 0.8528 - val_loss: 0.4966 - val_acc: 0.8058
Epoch 8/1024
13/13 [==============================] - 6s 462ms/step - loss: 0.3110 - acc: 0.8938 - val_loss: 0.4791 - val_acc: 0.8254
Epoch 9/1024
13/13 [==============================] - 6s 444ms/step - loss: 0.2822 - acc: 0.9035 - val_loss: 0.4873 - val_acc: 0.8160
Epoch 10/1024
13/13 [==============================] - 6s 443ms/step - loss: 0.2734 - acc: 0.9035 - val_loss: 0.4955 - val_acc: 0.8101
Epoch 00010: Reducing Max LR on Plateau: new max lr will be 0.005 (if not early_stopping).
Epoch 11/1024
13/13 [==============================] - 6s 435ms/step - loss: 0.2361 - acc: 0.9264 - val_loss: 0.4898 - val_acc: 0.8214
Epoch 12/1024
13/13 [==============================] - 6s 498ms/step - loss: 0.2292 - acc: 0.9155 - val_loss: 0.5074 - val_acc: 0.8174
Epoch 00012: Reducing Max LR on Plateau: new max lr will be 0.0025 (if not early_stopping).
Epoch 13/1024
13/13 [==============================] - 6s 442ms/step - loss: 0.1969 - acc: 0.9421 - val_loss: 0.5203 - val_acc: 0.8132
Restoring model weights from the end of the best epoch
Epoch 00013: early stopping
Weights from best epoch have been loaded into model.
<keras.callbacks.History at 0x7f673538af98>
learner.validate(class_names=preproc.get_classes())
                                precision    recall  f1-score   support

Diabetes_Mellitus-Experimental       0.76      0.82      0.79      3113
      Diabetes_Mellitus-Type_1       0.84      0.81      0.82      5943
      Diabetes_Mellitus-Type_2       0.85      0.84      0.85      5930

                      accuracy                           0.83     14986
                     macro avg       0.82      0.82      0.82     14986
                  weighted avg       0.83      0.83      0.83     14986
array([[2553,  362,  198],
       [ 440, 4815,  688],
       [ 359,  572, 4999]])
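Since validate returns the confusion matrix shown above, we can optionally wrap it in a labeled DataFrame to make the rows and columns easier to read (re-running validate will also re-print the report):

import pandas as pd
classes = preproc.get_classes()
cm = learner.validate(class_names=classes)  # returns the confusion matrix
pd.DataFrame(cm, index=classes, columns=classes)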
p = ktrain.get_predictor(learner.model, preproc)
In transductive inference, we make predictions for unlabeled nodes whose features are visible during training. Making predictions on the validation nodes in the training graph is thus transductive inference.
Let's see how well we predict the first validation example.
p.predict_transductive(val_data.ids[0:1], return_proba=True)
array([[0.04122107, 0.9422023 , 0.0165766 ]], dtype=float32)
val_data[0][1][0]
array([0., 1., 0.])
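The one-hot vector above is the ground truth for this node. To map the predicted probabilities to a class name directly, we can take the argmax against preproc.get_classes() (a small convenience, not a ktrain-specific API):

import numpy as np
probs = p.predict_transductive(val_data.ids[0:1], return_proba=True)
print(preproc.get_classes()[int(np.argmax(probs[0]))])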
Let's make predictions for all validation nodes and visually compare some of them with ground truth.
y_pred = p.predict_transductive(val_data.ids, return_proba=False)
y_true = preproc.df[preproc.df.index.isin(val_data.ids)]['target'].values
import pandas as pd
pd.DataFrame(list(zip(y_true, y_pred)), columns=['Ground Truth', 'Predicted']).head()
|   | Ground Truth | Predicted |
|---|---|---|
| 0 | Diabetes_Mellitus-Type_1 | Diabetes_Mellitus-Type_1 |
| 1 | Diabetes_Mellitus-Type_2 | Diabetes_Mellitus-Type_1 |
| 2 | Diabetes_Mellitus-Type_1 | Diabetes_Mellitus-Type_1 |
| 3 | Diabetes_Mellitus-Type_1 | Diabetes_Mellitus-Type_1 |
| 4 | Diabetes_Mellitus-Experimental | Diabetes_Mellitus-Experimental |
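For a direct comparison with the inductive accuracy computed below, we can also compute the overall transductive accuracy from these predictions:

import numpy as np
print((np.array(y_true) == np.array(y_pred)).mean())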
In inductive inference, we make predictions for entirely new nodes that were not present in the training graph; the features or attributes of these nodes were not visible during training. Here, we consider the graph in which the held-out nodes are added back into the training graph, yielding the original graph of 19,717 nodes. This graph, G_complete, was returned as the last return value of graph_nodes_from_csv.
y_pred = p.predict_inductive(df_holdout, G_complete, return_proba=False)
y_true = df_holdout['target'].values
import numpy as np
(y_true == np.array(y_pred)).mean()
0.8303322343393356
With an accuracy of 83.03%, inductive performance is quite good and comparable to transductive performance.
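If desired, the predictor can be saved to disk and reloaded later to make predictions on new nodes (the path below is arbitrary):

p.save('/tmp/pubmed_predictor')
p = ktrain.load_predictor('/tmp/pubmed_predictor')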