from google.colab import drive
drive.mount('/content/gdrive')
import os
os.chdir('/content/gdrive/My Drive/finch/tensorflow2/knowledge_graph_completion/wn18/main')
%tensorflow_version 2.x
!pip install tensorflow-addons
from tensorflow_addons.optimizers.cyclical_learning_rate import Triangular2CyclicalLearningRate
import tensorflow as tf
import pprint
import logging
import time
print("TensorFlow Version", tf.__version__)
print('GPU Enabled:', tf.test.is_gpu_available())
def get_vocab(f_path):
  word2idx = {}
  with open(f_path) as f:
    for i, line in enumerate(f):
      line = line.rstrip()
      word2idx[line] = i
  return word2idx
"""
we use 1vN fast evaluation as purposed in ConvE paper:
"https://arxiv.org/abs/1707.01476"
sp2o is a dictionary that maps a pair of <subject, predicate>
to multiple possible corresponding <objects> in graph
"""
def make_sp2o(f_paths, e2idx, r2idx):
  sp2o = {}
  for f_path in f_paths:
    with open(f_path) as f:
      for line in f:
        line = line.rstrip()
        s, p, o = line.split()
        s, p, o = e2idx[s], r2idx[p], e2idx[o]
        if (s,p) not in sp2o:
          sp2o[(s,p)] = [o]
        else:
          if o not in sp2o[(s,p)]:
            sp2o[(s,p)].append(o)
  return sp2o
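
# A toy illustration (hypothetical identifiers, not the real WN18 vocab) of what
# make_sp2o builds: if the graph contains (cat, _hypernym, feline) and
# (cat, _hypernym, animal), then
#   sp2o[(e2idx['cat'], r2idx['_hypernym'])] == [e2idx['feline'], e2idx['animal']]
# so each <subject, predicate> query can be scored against all of its true objects at once.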
def map_fn(x, y):
  i, v, s = y[0]
  one_hot = tf.SparseTensor(i, v, s)
  return x, (one_hot, y[1], y[2])
# stream data from text files
def data_generator(f_path, params, sp2o):
  with open(f_path) as f:
    print('Reading', f_path)
    for line in f:
      line = line.rstrip()
      s, p, o = line.split()
      s, p, o = params['e2idx'][s], params['r2idx'][p], params['e2idx'][o]
      sparse_i = [[x] for x in sp2o[(s, p)]]
      sparse_v = [1.] * len(sparse_i)
      sparse_s = [len(params['e2idx'])]
      yield ((s, p), ((sparse_i, sparse_v, sparse_s), o, len(sparse_i)))
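
# Each generated element is ((s, p), ((sparse_i, sparse_v, sparse_s), o, num_pos)):
# the model input is the (subject, predicate) id pair, and the label side carries the
# COO pieces of a multi-hot vector over all entities (every object known for that pair),
# the single gold object o, and the number of positives (used later for pos_weight).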
def dataset(is_training, params, sp2o):
  _shapes = (([], []), (([None, 1], [None], [1]), [], []))
  _types = ((tf.int32, tf.int32),
            ((tf.int64, tf.float32, tf.int64), tf.int32, tf.int32))
  if is_training:
    ds = tf.data.Dataset.from_generator(
      lambda: data_generator(params['train_path'], params, sp2o),
      output_shapes = _shapes,
      output_types = _types,)
    ds = ds.shuffle(params['num_samples'])
    ds = ds.map(map_fn)
    ds = ds.batch(params['batch_size'])
  else:
    ds = tf.data.Dataset.from_generator(
      lambda: data_generator(params['test_path'], params, sp2o),
      output_shapes = _shapes,
      output_types = _types,)
    ds = ds.map(map_fn)
    ds = ds.batch(params['batch_size'])
  return ds
def update_metrics(scores, query, metrics):
  to_float = lambda x: tf.cast(x, tf.float32)
  _, i = tf.math.top_k(scores, sorted=True, k=scores.shape[1])
  query = tf.expand_dims(query, 1)
  is_query = to_float(tf.equal(i, query))
  r = tf.argmax(is_query, -1) + 1
  mrr = 1. / to_float(r)
  hits_10 = to_float(tf.less_equal(r, 10))
  hits_3 = to_float(tf.less_equal(r, 3))
  hits_1 = to_float(tf.less_equal(r, 1))
  metrics['mrr'].update_state(mrr)
  metrics['hits_10'].update_state(hits_10)
  metrics['hits_3'].update_state(hits_3)
  metrics['hits_1'].update_state(hits_1)
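
# In short: update_metrics ranks every entity by score, finds the rank r of the gold
# object, and accumulates MRR = mean(1 / r) and Hits@k = mean(r <= k) for k in {1, 3, 10}.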
class TuckER(tf.keras.Model):
  def __init__(self, params):
    super().__init__()
    self.ent_embed_dim = params['ent_embed_dim']
    self.rel_embed_dim = params['rel_embed_dim']
    self.embed_ent = tf.keras.layers.Embedding(input_dim=len(params['e2idx']),
                                               output_dim=params['ent_embed_dim'],
                                               name='Entity',
                                               embeddings_initializer=tf.initializers.GlorotNormal())
    self.embed_rel = tf.keras.layers.Embedding(input_dim=len(params['r2idx']),
                                               output_dim=params['rel_embed_dim'],
                                               name='Relation',
                                               embeddings_initializer=tf.initializers.GlorotNormal())
    # core tensor W of the Tucker decomposition
    self.kernel = self.add_weight(name='Kernel',
                                  shape=(params['rel_embed_dim'],
                                         params['ent_embed_dim'],
                                         params['ent_embed_dim'],),
                                  initializer=tf.initializers.RandomUniform(-1., 1.))

  def call(self, inputs, training):
    s, p = inputs
    batch_sz = tf.shape(s)[0]
    s = self.embed_ent(s)   # (batch, ent_embed_dim)
    p = self.embed_rel(p)   # (batch, rel_embed_dim)
    # contract the relation embedding with the core tensor
    kernel = tf.matmul(p, tf.reshape(self.kernel, (self.rel_embed_dim, -1)))
    kernel = tf.reshape(kernel, (batch_sz, self.ent_embed_dim, self.ent_embed_dim))
    # contract with the subject embedding, then score against every entity (1-N scoring)
    x = tf.matmul(tf.expand_dims(s, 1), kernel)
    x = tf.squeeze(x, 1)
    x = tf.matmul(x, self.embed_ent.embeddings, transpose_b=True)
    return x
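
# For reference, TuckER scores a triple (s, p, o) with a Tucker decomposition,
#   score(s, p, o) = W x_1 e_s x_2 w_p x_3 e_o
# where W is the core tensor (self.kernel), x_n denotes the mode-n tensor product,
# e_s / e_o are entity embeddings and w_p is the relation embedding; call() above
# evaluates this against every candidate object in a single matmul.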
def label_smoothing(inputs, epsilon):
  V = inputs.get_shape().as_list()[-1]
  return ((1-epsilon) * inputs) + (epsilon / V)
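
# A quick sanity check of label_smoothing (toy tensor, not part of training): each
# positive label is pulled from 1.0 towards (1 - epsilon) and each negative is lifted
# from 0.0 by epsilon / V, where V is the number of candidate entities.
_toy = tf.constant([[1., 0., 0., 0.]])
print(label_smoothing(_toy, .1))  # values: [[0.925 0.025 0.025 0.025]]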
params = {
  'train_path': '../data/wn18/train.txt',
  'valid_path': '../data/wn18/valid.txt',
  'test_path': '../data/wn18/test.txt',
  'entity_path': '../vocab/entity.txt',
  'relation_path': '../vocab/relation.txt',
  'batch_size': 128,
  'ent_embed_dim': 200,
  'rel_embed_dim': 30,
  'num_samples': 141442,
  'init_lr': 1e-4,
  'max_lr': 5e-3,
  'num_patience': 10,
  'epsilon': .1,
}
params['e2idx'] = get_vocab(params['entity_path'])
params['r2idx'] = get_vocab(params['relation_path'])
sp2o_tr = make_sp2o([params['train_path']], params['e2idx'], params['r2idx'])
sp2o_all = make_sp2o([params['train_path'],
                      params['test_path'],
                      params['valid_path']], params['e2idx'], params['r2idx'])
model = TuckER(params)
model.build(input_shape=[[None], [None]])
pprint.pprint([(v.name, v.shape) for v in model.trainable_variables])
decay_lr = Triangular2CyclicalLearningRate(
  initial_learning_rate = params['init_lr'],
  maximal_learning_rate = params['max_lr'],
  step_size = 8 * params['num_samples'] // params['batch_size'],)
optim = tf.optimizers.Adam(params['init_lr'])
global_step = 0
best_mrr = 0.
count = 0
t0 = time.time()
logger = logging.getLogger('tensorflow')
logger.setLevel(logging.INFO)
while True:
  # TRAINING
  for ((s, p), (multi_o, o, num_pos)) in dataset(is_training=True, params=params, sp2o=sp2o_tr):
    with tf.GradientTape() as tape:
      logits = model((s, p), training=True)
      multi_o = tf.sparse.to_dense(multi_o, validate_indices=False)
      num_neg = len(params['e2idx']) - num_pos
      pos_weight = tf.expand_dims(tf.cast(num_neg/num_pos, tf.float32), 1)
      labels = label_smoothing(multi_o, params['epsilon'])
      loss = tf.nn.weighted_cross_entropy_with_logits(labels=labels, logits=logits, pos_weight=pos_weight)
      loss = tf.reduce_mean(loss)

    optim.lr.assign(decay_lr(global_step))
    grads = tape.gradient(loss, model.trainable_variables)
    optim.apply_gradients(zip(grads, model.trainable_variables))

    if global_step % 50 == 0:
      logger.info("Step {} | Loss: {:.4f} | Spent: {:.1f} secs | LR: {:.6f}".format(
        global_step, loss.numpy().item(), time.time()-t0, optim.lr.numpy().item()))
      t0 = time.time()
    global_step += 1
  # EVALUATION
  metrics = {
    'mrr': tf.metrics.Mean(),
    'hits_10': tf.metrics.Mean(),
    'hits_3': tf.metrics.Mean(),
    'hits_1': tf.metrics.Mean(),
  }
  for ((s, p), (multi_o, o, num_pos)) in dataset(is_training=False, params=params, sp2o=sp2o_all):
    logits = model((s, p), training=False)
    multi_o = tf.sparse.to_dense(multi_o, validate_indices=False)
    # create masks for filtered metrics: zero out every other known true object
    # (from train/valid/test) so only the queried object competes with corrupted ones
    o_one_hot = tf.one_hot(o, len(params['e2idx']))
    unwanted = multi_o - o_one_hot
    masks = tf.cast(tf.equal(unwanted, 0.), tf.float32)
    scores = tf.sigmoid(logits) * masks
    update_metrics(scores=scores, query=o, metrics=metrics)
logger.info("MRR: {:.3f}| [email protected]: {:.3f} | [email protected]: {:.3f} | [email protected]: {:.3f}".format(
metrics['mrr'].result().numpy(),
metrics['hits_10'].result().numpy(),
metrics['hits_3'].result().numpy(),
metrics['hits_1'].result().numpy()))
  mrr = metrics['mrr'].result().numpy()
  if mrr > best_mrr:
    best_mrr = mrr
    # you can save the model here (see the checkpointing sketch after the training loop)
    count = 0
  else:
    count += 1
logger.info("Best MRR: {:.3f}".format(best_mrr))
if count == params['num_patience']:
print(params['num_patience'], "times not improve the best result, therefore stop training")
break
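
# A minimal checkpointing sketch for the "save the model here" branch above
# ('../model/tucker_wn18' is a hypothetical path, not part of the original script):
#
#   model.save_weights('../model/tucker_wn18')   # call when a new best MRR is reached
#   ...
#   model.load_weights('../model/tucker_wn18')   # restore the best weights afterwards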