In [1]:
"""
We use following lines because we are running on Google Colab
If you are running notebook on a local computer, you don't need this cell
"""
from google.colab import drive
drive.mount('/content/gdrive')
import os
os.chdir('/content/gdrive/My Drive/finch/tensorflow1/knowledge_graph_completion/wn18/main')
Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
In [2]:
import tensorflow as tf
print("TensorFlow Version", tf.__version__)
print('GPU Enabled:', tf.test.is_gpu_available())
import os

from pathlib import Path
TensorFlow Version 1.13.0-rc1
GPU Enabled: True
In [0]:
def get_vocab(f_path):
  """Read a vocabulary file (one token per line) and return a token -> line-index dict.

  Trailing whitespace (including the newline) is stripped from each line.
  If a token appears twice, the later line's index wins.
  """
  with open(f_path) as f:
    return {line.rstrip(): idx for idx, line in enumerate(f)}
In [0]:
def graph_fn(s, p, params):
  """Build DistMult 1-N scoring logits over all entities.

  Args:
    s: int tensor of subject entity ids (presumably shape [batch] — confirm
       against serving_input_receiver_fn, which declares [None] placeholders).
    p: int tensor of predicate/relation ids, same leading shape as s.
    params: dict providing 'e2idx' and 'r2idx' vocab maps (sizes the embedding
       tables) and 'embed_dim' (embedding width).

  Returns:
    logits: float tensor [batch, num_entities]; entry (b, o) is the DistMult
    score <s_b, p_b, o> plus a per-entity bias.
  """
  # Entity and relation embedding tables (trainable, truncated-normal init).
  e_embed = tf.get_variable('e_embed',
                            [len(params['e2idx']), params['embed_dim']],
                            initializer=tf.initializers.truncated_normal())
  r_embed = tf.get_variable('r_embed',
                            [len(params['r2idx']), params['embed_dim']],
                            initializer=tf.initializers.truncated_normal())

  # Look up the embedding vectors for this batch of (subject, predicate) pairs.
  s = tf.nn.embedding_lookup(e_embed, s)
  p = tf.nn.embedding_lookup(r_embed, p)
  
  # DistMult 1-N scoring: (s * p) dotted against every entity embedding at
  # once, yielding a score for each candidate object entity.
  logits = tf.matmul(s*p, e_embed, transpose_b=True)
  
  # Per-entity bias term added to every row of scores.
  bias = tf.get_variable('bias', [len(params['e2idx'])])
  logits = tf.nn.bias_add(logits, bias)
  
  return logits
In [0]:
def model_fn(features, labels, mode, params):
  """Estimator model_fn for inference/export only (PREDICT mode).

  Args:
    features: dict with int tensors 'subject' and 'predicate'.
    labels: unused (inference only).
    mode: a tf.estimator.ModeKeys value; only PREDICT is supported.
    params: dict with 'num_top' (k for top-k) and 'entity_path' (vocab file
       used to map predicted indices back to entity strings), plus the keys
       graph_fn needs.

  Returns:
    An EstimatorSpec whose predictions are the top-k entity-name strings.

  Raises:
    ValueError: if mode is not PREDICT (this notebook only exports a trained
      checkpoint; training/eval live elsewhere).
  """
  s, p = features['subject'], features['predicate']
  
  logits = graph_fn(s, p, params)
  
  if mode == tf.estimator.ModeKeys.PREDICT:
    # Independent sigmoid scores per candidate entity (1-N scoring), ranked.
    scores = tf.sigmoid(logits)
    _, indices = tf.nn.top_k(scores, sorted=True, k=params['num_top'])
    # Map entity indices back to entity-name strings via the vocab file.
    vocab_rev = tf.contrib.lookup.index_to_string_table_from_file(params['entity_path'])
    predictions = vocab_rev.lookup(tf.cast(indices, tf.int64))
    
    return tf.estimator.EstimatorSpec(mode=mode,
                                      predictions=predictions)

  # The original fell through and returned None here, which makes the
  # Estimator fail with an opaque internal error if TRAIN/EVAL is requested;
  # fail loudly with a clear message instead.
  raise ValueError('This export-only model_fn supports PREDICT mode only, got: %s' % mode)
In [0]:
# Hyper-parameters and file locations for the DistMult 1-N model.
params = {
  'model_dir': '../model/distmult_1-N',          # trained checkpoint directory
  'export_dir': '../model/distmult_1-N_export',  # SavedModel export target
  'entity_path': '../vocab/entity.txt',          # entity vocab, one id per line
  'relation_path': '../vocab/relation.txt',      # relation vocab, one per line
  'embed_dim': 300,  # embedding width for entities and relations
  'num_top': 5,      # number of top-ranked entities returned at predict time
}
In [0]:
# Build token -> index maps from the vocabulary files and stash them in
# params so graph_fn can size its embedding tables.
e2idx = get_vocab(params['entity_path'])
r2idx = get_vocab(params['relation_path'])
params['e2idx'] = e2idx
params['r2idx'] = r2idx
In [0]:
def serving_input_receiver_fn():
    """Declare the placeholders a SavedModel client feeds at serving time.

    Returns a ServingInputReceiver whose features and receiver tensors are
    the same dict: int32 'subject' and 'predicate' id vectors of shape [None].
    """
    subject = tf.placeholder(tf.int32, [None], 'subject')
    predicate = tf.placeholder(tf.int32, [None], 'predicate')
    
    features = {'subject': subject, 'predicate': predicate}
    
    # Receiver tensors are identical to the features — clients feed ids directly.
    return tf.estimator.export.ServingInputReceiver(features, features)
In [9]:
# Restore the trained DistMult model from its checkpoint directory and export
# it as a SavedModel for serving (writes a timestamped subdirectory under
# params['export_dir']).
estimator = tf.estimator.Estimator(model_fn, params['model_dir'], params=params)
estimator.export_saved_model(params['export_dir'], serving_input_receiver_fn)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '../model/distmult_1-N', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7ffb221cf240>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
INFO:tensorflow:Calling model_fn.

WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.

INFO:tensorflow:Done calling model_fn.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:205: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
INFO:tensorflow:Signatures INCLUDED in export for Classify: None
INFO:tensorflow:Signatures INCLUDED in export for Regress: None
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['serving_default']
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from ../model/distmult_1-N/model.ckpt-54146
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:Assets written to: ../model/distmult_1-N_export/temp-b'1550135921'/assets
INFO:tensorflow:SavedModel written to: ../model/distmult_1-N_export/temp-b'1550135921'/saved_model.pb
Out[9]:
b'../model/distmult_1-N_export/1550135921'
In [10]:
# Demo query: which entities does the model rank as objects of
# (02174461, _hypernym, ?) — expected answer is 02176268.
example_s = '02174461'
example_p = '_hypernym'

# Pick the newest finished export, skipping any half-written 'temp-...' dirs.
export_versions = [d for d in Path(params['export_dir']).iterdir()
                   if d.is_dir() and 'temp' not in str(d)]
latest = str(max(export_versions))

# Load the SavedModel and run one (subject, predicate) query through it.
predict_fn = tf.contrib.predictor.from_saved_model(latest)
predictions = predict_fn(
    {'subject': [params['e2idx'][example_s]],
     'predicate': [params['r2idx'][example_p]]})

print()
print('Input Entity:', example_s)
print('Input Relation', example_p)
print()
print('Answer: 02176268')
print('Top %d Prediction:' % params['num_top'], predictions['output'][0])
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/predictor/saved_model_predictor.py:153: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.
INFO:tensorflow:Restoring parameters from ../model/distmult_1-N_export/1550135921/variables/variables

Input Entity: 02174461
Input Relation _hypernym

Answer: 02176268
Top 5 Prediction: [b'02176268' b'02180529' b'01831531' b'09334396' b'02186360']