#!/usr/bin/env python
# coding: utf-8

# # CS 20 : TensorFlow for Deep Learning Research
# ## Lecture 11 : Recurrent Neural Networks
# Simple example of many-to-many classification (a simple POS tagger) using a bi-directional Long Short-Term Memory (LSTM) network.
# 
# ### Many-to-Many Classification by Bi-directional LSTM
# - Creating the **data pipeline** with `tf.data`
# - Preprocessing variable-length word sequences by padding them with a user-defined function (`pad_seq`)
# - Using `tf.nn.embedding_lookup` to look up the vector of each token (e.g. word, character)
# - Training the **many-to-many classification** model with `tf.contrib.seq2seq.sequence_loss`
# - Masking invalid (padded) tokens with `tf.sequence_mask` (a short sketch follows this list)
# - Creating the model as a **class**
# - Reference
#     - https://github.com/aisolab/sample_code_of_Deep_learning_Basics/blob/master/DLEL/DLEL_12_2_RNN_(toy_example).ipynb
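
# Below is a minimal sketch (not part of the original lecture code) of what `tf.sequence_mask` produces for the sentence lengths used in this notebook: a `(batch, max_len)` matrix that is 1 over valid tokens and 0 over padded positions, which `tf.contrib.seq2seq.sequence_loss` later uses to ignore the `<pad>` positions.

# In[ ]:


import tensorflow as tf

# lengths of the four example sentences below are 3, 4, 7, 5; max_len is 10
demo_mask = tf.sequence_mask(lengths = [3, 4, 7, 5], maxlen = 10, dtype = tf.float32)
with tf.Session() as demo_sess:
    print(demo_sess.run(demo_mask))
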

# ### Setup

# In[1]:


import os, sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import string
get_ipython().run_line_magic('matplotlib', 'inline')

slim = tf.contrib.slim
print(tf.__version__)


# ### Prepare example data 

# In[2]:


sentences = [['I', 'feel', 'hungry'],
     ['tensorflow', 'is', 'very', 'difficult'],
     ['tensorflow', 'is', 'a', 'framework', 'for', 'deep', 'learning'],
     ['tensorflow', 'is', 'very', 'fast', 'changing']]
pos = [['pronoun', 'verb', 'adjective'],
     ['noun', 'verb', 'adverb', 'adjective'],
     ['noun', 'verb', 'determiner', 'noun', 'preposition', 'adjective', 'noun'],
     ['noun', 'verb', 'adverb', 'adjective', 'verb']]


# In[3]:


# word dic
word_list = []
for elm in sentences:
    word_list += elm
word_list = list(set(word_list))
word_list.sort()
word_list = ['<pad>'] + word_list

word_dic = {word : idx for idx, word in enumerate(word_list)}
print(word_dic)


# In[4]:


# pos dic
pos_list = []
for elm in pos:
    pos_list += elm
pos_list = list(set(pos_list))
pos_list.sort()
pos_list = ['<pad>'] + pos_list
print(pos_list)

pos_dic = {pos : idx for idx, pos in enumerate(pos_list)}
pos_dic


# In[5]:


pos_idx_to_dic = {elm[1] : elm[0] for elm in pos_dic.items()}
pos_idx_to_dic


# ### Create pad_seq function

# In[6]:


def pad_seq(sequences, max_len, dic):
    """Convert each token sequence to indices and pad it to max_len with '<pad>'."""
    seq_len, seq_indices = [], []
    for seq in sequences:
        seq_len.append(len(seq))
        seq_idx = [dic.get(token) for token in seq]
        seq_idx += (max_len - len(seq_idx)) * [dic.get('<pad>')] # pad with the idx of the meaningless token '<pad>' (0)
        seq_indices.append(seq_idx)
    return seq_len, seq_indices


# ### Pre-process data

# In[7]:


max_length = 10
X_length, X_indices = pad_seq(sequences = sentences, max_len = max_length, dic = word_dic)
print(X_length, np.shape(X_indices))


# In[8]:


y = [elm + ['<pad>'] * (max_length - len(elm)) for elm in pos]
y = [list(map(lambda el : pos_dic.get(el), elm)) for elm in y]
print(np.shape(y))


# In[9]:


y


# ### Define SimPosBiLSTM

# In[10]:


class SimPosBiLSTM:
    def __init__(self, X_length, X_indices, y, n_of_classes, hidden_dim, max_len, word_dic):
        
        # Data pipeline
        with tf.variable_scope('input_layer'):
            self._X_length = X_length
            self._X_indices = X_indices
            self._y = y
            
            one_hot = tf.eye(len(word_dic), dtype = tf.float32)
            self._one_hot = tf.get_variable(name='one_hot_embedding', initializer = one_hot,
                                            trainable = False) # fixed one-hot embedding; the embedding vectors are not trained
            self._X_batch = tf.nn.embedding_lookup(params = self._one_hot, ids = self._X_indices)
    
        # Bi-directional LSTM (many to many)
        with tf.variable_scope('bi-directional_lstm'):
            lstm_fw_cell = tf.contrib.rnn.BasicLSTMCell(num_units = hidden_dim,
                                                        activation = tf.nn.tanh)
            lstm_bw_cell = tf.contrib.rnn.BasicLSTMCell(num_units = hidden_dim,
                                                        activation = tf.nn.tanh)
            outputs, _ = tf.nn.bidirectional_dynamic_rnn(cell_fw = lstm_fw_cell, cell_bw = lstm_bw_cell,
                                                         inputs = self._X_batch, sequence_length = self._X_length,
                                                         dtype = tf.float32)
            concatenated_outputs = tf.concat([outputs[0], outputs[1]], axis = 2) # (batch, max_len, 2 * hidden_dim)
            weights = tf.get_variable(name = 'weights', shape = (2 * hidden_dim, n_of_classes),
                                      initializer = tf.contrib.layers.xavier_initializer())
            self._score = tf.map_fn(lambda elm : tf.matmul(elm, weights), concatenated_outputs) # time-distributed projection to class scores
            
        with tf.variable_scope('seq2seq_loss'):
            masks = tf.sequence_mask(lengths = self._X_length, maxlen = max_len, dtype = tf.float32)
            self.seq2seq_loss = tf.contrib.seq2seq.sequence_loss(logits = self._score,
                                                                 targets = self._y, weights = masks)
    
        with tf.variable_scope('prediction'):
            self._prediction = tf.argmax(input = self._score,
                                         axis = 2, output_type = tf.int32)
    
    def predict(self, sess, X_length, X_indices):
        feed_prediction = {self._X_length : X_length, self._X_indices : X_indices}
        return sess.run(self._prediction, feed_dict = feed_prediction)


# ### Create a SimPosBiLSTM model

# In[11]:


# hyper-parameters
lr = .003
epochs = 100
batch_size = 2
total_step = int(np.shape(X_indices)[0] / batch_size)
print(total_step)


# In[12]:


## create data pipeline with tf.data
tr_dataset = tf.data.Dataset.from_tensor_slices((X_length, X_indices, y))
tr_dataset = tr_dataset.shuffle(buffer_size = 20)
tr_dataset = tr_dataset.batch(batch_size = batch_size)
tr_iterator = tr_dataset.make_initializable_iterator()
print(tr_dataset)


# In[13]:


X_length_mb, X_indices_mb, y_mb = tr_iterator.get_next()


# In[14]:


sim_pos_bi_lstm = SimPosBiLSTM(X_length = X_length_mb, X_indices = X_indices_mb, y = y_mb,
                               n_of_classes = 8, hidden_dim = 16, max_len = max_length, word_dic = word_dic)


# ### Create training op and train model

# In[15]:


## create training op
opt = tf.train.AdamOptimizer(learning_rate = lr)
training_op = opt.minimize(loss = sim_pos_bi_lstm.seq2seq_loss)


# In[16]:


sess = tf.Session()
sess.run(tf.global_variables_initializer())

tr_loss_hist = []

for epoch in range(epochs):
    avg_tr_loss = 0
    tr_step = 0
    
    sess.run(tr_iterator.initializer)
    try:
        while True:
            _, tr_loss = sess.run(fetches = [training_op, sim_pos_bi_lstm.seq2seq_loss])
            avg_tr_loss += tr_loss
            tr_step += 1
            
    except tf.errors.OutOfRangeError:
        pass
    
    avg_tr_loss /= tr_step
    tr_loss_hist.append(avg_tr_loss)
    if (epoch + 1) % 10 == 0:
        print('epoch : {:3}, tr_loss : {:.3f}'.format(epoch + 1, avg_tr_loss))
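

# Plot the training loss history collected in `tr_loss_hist` (using the matplotlib setup from the top of the notebook).

# In[ ]:


plt.plot(tr_loss_hist, label = 'train')
plt.xlabel('epoch')
plt.ylabel('sequence loss')
plt.legend()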


# In[17]:


yhat = sim_pos_bi_lstm.predict(sess = sess, X_length = X_length, X_indices = X_indices)
yhat


# In[18]:


y


# In[19]:


yhat = [list(map(lambda elm : pos_idx_to_dic.get(elm), row)) for row in yhat]
for elm in yhat:
    print(elm)
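

# As a quick sanity check (not part of the original notebook), compute token-level accuracy over the valid (non-`<pad>`) positions by comparing `yhat` with the original `pos` tags.

# In[ ]:


correct, total = 0, 0
for pred, target, length in zip(yhat, pos, X_length):
    for p, t in zip(pred[:length], target):
        correct += int(p == t)
        total += 1
print('token-level accuracy : {:.2%}'.format(correct / total))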