#!/usr/bin/env python
# coding: utf-8

# # CS 20 : TensorFlow for Deep Learning Research
# ## Lecture 12 : Seq2Seq with Attention
# Simple example of Seq2Seq (Machine Translation) with Attention, using a bi-directional RNN encoder and an RNN decoder.
#
# ### Seq2Seq with Attention by Encoder Bi-directional RNN and Decoder RNN
# - Creating the **data pipeline** with `tf.data`
# - Preprocessing word sequences (variable input sequence length) using a `padding technique` via a `user function (pad_seq)`
# - Using `tf.nn.embedding_lookup` to get vectors of tokens (e.g. words, characters)
# - Training **many-to-many classification** with `tf.contrib.seq2seq.sequence_loss`
# - Masking invalid tokens with `tf.sequence_mask`
# - Using the **attention mechanism** via `tf.contrib.seq2seq.LuongAttention` and `tf.contrib.seq2seq.AttentionWrapper`
# - Using `tf.contrib.seq2seq.dynamic_decode`
# - Training with `tf.contrib.seq2seq.TrainingHelper`
# - Translating with `tf.contrib.seq2seq.GreedyEmbeddingHelper`
# - Creating the model as a **Class**
# - Reference
#     - https://github.com/golbin/TensorFlow-Tutorials/blob/master/10%20-%20RNN/03%20-%20Seq2Seq.py
#     - https://github.com/HiJiGOO/tf_nmt_tutorial
#     - https://github.com/hccho2/RNN-Tutorial
#     - https://www.tensorflow.org/tutorials/seq2seq
#     - https://github.com/j-min/tf_tutorial_plus/tree/master/RNN_seq2seq/contrib_seq2seq
#     - https://gist.github.com/ilblackdragon/c92066d9d38b236a21d5a7b729a10f12

# ### Setup

# In[1]:

import os, sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import string
from pprint import pprint
get_ipython().run_line_magic('matplotlib', 'inline')

s2s = tf.contrib.seq2seq
print(tf.__version__)

# ### Prepare example data

# In[2]:

sources = [['I', 'feel', 'hungry'],
           ['tensorflow', 'is', 'very', 'difficult'],
           ['tensorflow', 'is', 'a', 'framework', 'for', 'deep', 'learning'],
           ['tensorflow', 'is', 'very', 'fast', 'changing']]
targets = [['나는', '배가', '고프다'],
           ['텐서플로우는', '매우', '어렵다'],
           ['텐서플로우는', '딥러닝을', '위한', '프레임워크이다'],
           ['텐서플로우는', '매우', '빠르게', '변화한다']]

# In[3]:

# word dic for source sentences
source_words = []
for elm in sources:
    source_words += elm
source_words = list(set(source_words))
source_words.sort()
source_words = ['<pad>'] + source_words

source_dic = {word : idx for idx, word in enumerate(source_words)}
print(source_dic)
print(len(source_dic))

# In[4]:

source_idx_dic = {elm[1] : elm[0] for elm in source_dic.items()}
source_idx_dic

# In[5]:

# word dic for translations
target_words = []
for elm in targets:
    target_words += elm
target_words = list(set(target_words))
target_words.sort()
# add '<pad>' plus the '<start>' and '<end>' tokens marking the beginning and end of a translation
target_words = ['<pad>'] + ['<start>'] + ['<end>'] + target_words

target_dic = {word : idx for idx, word in enumerate(target_words)}
print(target_dic)
print(len(target_dic))

# In[6]:

target_idx_dic = {elm[1] : elm[0] for elm in target_dic.items()}
target_idx_dic

# ### Create pad_seq functions for sentences

# In[7]:

def pad_seq_enc(sequences, max_len, dic):
    seq_len = []
    seq_indices = []
    for seq in sequences:
        seq_len.append(len(seq))
        seq_idx = [dic.get(word) for word in seq]
        seq_idx += (max_len - len(seq_idx)) * [dic.get('<pad>')]
        seq_indices.append(seq_idx)
    return seq_len, seq_indices

# In[8]:

def pad_seq_dec(sequences, max_len, dic):
    seq_input_len = []
    seq_input_indices = []
    seq_target_indices = []

    # for decoder input
    for seq in sequences:
        seq_input_idx = [dic.get('<start>')] + [dic.get(token) for token in seq]
        seq_input_len.append(len(seq_input_idx))
        seq_input_idx += (max_len - len(seq_input_idx)) * [dic.get('<pad>')]
        seq_input_indices.append(seq_input_idx)

    # for decoder output
    for seq in sequences:
        seq_target_idx = [dic.get(token) for token in seq] + [dic.get('<end>')]
        seq_target_idx += (max_len - len(seq_target_idx)) * [dic.get('<pad>')]
        seq_target_indices.append(seq_target_idx)

    return seq_input_len, seq_input_indices, seq_target_indices
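# As a quick sanity check (a small addition, not part of the original notebook),
# the decoder padding helper can be exercised on a single sentence. With
# max_len = 5, pad_seq_dec should prepend '<start>' to the decoder input,
# append '<end>' to the decoder target, and fill the remainder with '<pad>':

demo_len, demo_input, demo_target = pad_seq_dec(sequences = [targets[0]],
                                                max_len = 5, dic = target_dic)
print(demo_len)     # [4] -> '<start>' + 3 tokens
print(demo_input)   # indices of ['<start>', '나는', '배가', '고프다', '<pad>']
print(demo_target)  # indices of ['나는', '배가', '고프다', '<end>', '<pad>']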
# ### Pre-process example data

# In[9]:

# for encoder
source_max_len = 10
X_length, X_indices = pad_seq_enc(sequences = sources, max_len = source_max_len, dic = source_dic)
print(X_length, np.shape(X_indices))

# In[10]:

# for decoder
target_max_len = 12
y_length, y_input_indices, y_target_indices = pad_seq_dec(sequences = targets,
                                                          max_len = target_max_len,
                                                          dic = target_dic)
pprint(y_length)
pprint(y_input_indices)
pprint(y_target_indices)

# ### Define SimpleNMT
# Encoder RNN, Decoder RNN, Attention

# In[11]:

class SimpleNMT:
    def __init__(self, s_len, s_indices, t_len, t_input_indices, t_output_indices,
                 t_max_len = target_max_len, s_dic = source_dic, t_dic = target_dic,
                 n_of_classes = len(target_dic), enc_hdim = 8, dec_hdim = 4):

        with tf.variable_scope('input_layer'):
            # s : source, t : target
            self._s_len = s_len
            self._s_indices = s_indices
            self._t_len = t_len
            self._t_input_indices = t_input_indices
            self._t_output_indices = t_output_indices
            self._s_dic = s_dic
            self._t_dic = t_dic
            self._t_max_len = t_max_len

            # one-hot (identity) embeddings, kept fixed during training
            s_embeddings = tf.eye(num_rows = len(self._s_dic), dtype = tf.float32)
            s_embeddings = tf.get_variable(name = 's_embeddings',
                                           initializer = s_embeddings,
                                           trainable = False)
            s_batch = tf.nn.embedding_lookup(params = s_embeddings, ids = self._s_indices)

        with tf.variable_scope('encoder'):
            enc_fw_cell = tf.contrib.rnn.BasicRNNCell(num_units = enc_hdim, activation = tf.nn.tanh)
            enc_bw_cell = tf.contrib.rnn.BasicRNNCell(num_units = enc_hdim, activation = tf.nn.tanh)
            enc_outputs, _ = tf.nn.bidirectional_dynamic_rnn(cell_fw = enc_fw_cell,
                                                             cell_bw = enc_bw_cell,
                                                             inputs = s_batch,
                                                             sequence_length = self._s_len,
                                                             dtype = tf.float32)
            # concatenate forward and backward outputs along the feature axis
            enc_outputs = tf.concat(values = [enc_outputs[0], enc_outputs[1]], axis = 2)

        with tf.variable_scope('pipe'):
            t_embeddings = tf.eye(num_rows = len(self._t_dic))
            t_embeddings = tf.get_variable(name = 'embeddings',
                                           initializer = t_embeddings,
                                           trainable = False)
            t_batch = tf.nn.embedding_lookup(params = t_embeddings, ids = self._t_input_indices)

            # dynamic batch size, inferred from the length tensor
            batch_size = tf.reduce_sum(tf.ones_like(tensor = self._s_len))
            tr_tokens = tf.tile(input = [self._t_max_len], multiples = [batch_size])
            trans_tokens = tf.tile(input = [self._t_dic.get('<start>')], multiples = [batch_size])

        with tf.variable_scope('decoder'):
            dec_cell = tf.contrib.rnn.BasicRNNCell(num_units = dec_hdim, activation = tf.nn.tanh)

            # Applying the attention mechanism
            attn = s2s.LuongAttention(num_units = dec_hdim, memory = enc_outputs,
                                      memory_sequence_length = self._s_len, dtype = tf.float32)
            attn_cell = s2s.AttentionWrapper(cell = dec_cell, attention_mechanism = attn)
            dec_initial_state = attn_cell.zero_state(batch_size = batch_size, dtype = tf.float32)
            output_layer = tf.layers.Dense(units = n_of_classes,
                                           kernel_initializer = tf.contrib.layers.xavier_initializer(uniform = False))

        with tf.variable_scope('training'):
            tr_helper = s2s.TrainingHelper(inputs = t_batch, sequence_length = tr_tokens)
            tr_decoder = s2s.BasicDecoder(cell = attn_cell, helper = tr_helper,
                                          initial_state = dec_initial_state,
                                          output_layer = output_layer)
            self._tr_outputs, _, _ = s2s.dynamic_decode(decoder = tr_decoder,
                                                        impute_finished = True,
                                                        maximum_iterations = self._t_max_len)

        with tf.variable_scope('translation'):
            trans_helper = s2s.GreedyEmbeddingHelper(embedding = t_embeddings,
                                                     start_tokens = trans_tokens,
                                                     end_token = self._t_dic.get('<end>'))
            trans_decoder = s2s.BasicDecoder(cell = attn_cell, helper = trans_helper,
                                             initial_state = dec_initial_state,
                                             output_layer = output_layer)
            self._trans_outputs, _, _ = s2s.dynamic_decode(decoder = trans_decoder,
                                                           impute_finished = True,
                                                           maximum_iterations = self._t_max_len * 2)

        with tf.variable_scope('seq2seq_loss'):
            masking = tf.sequence_mask(lengths = self._t_len,
                                       maxlen = self._t_max_len, dtype = tf.float32)
            self.__seq2seq_loss = s2s.sequence_loss(logits = self._tr_outputs.rnn_output,
                                                    targets = self._t_output_indices,
                                                    weights = masking)

    def translate(self, sess, s_len, s_indices):
        feed_translation = {self._s_len : s_len, self._s_indices : s_indices}
        return sess.run(self._trans_outputs.sample_id, feed_dict = feed_translation)

    @property
    def loss(self):
        return self.__seq2seq_loss
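# At each decoder step, the Luong-style attention above scores the decoder
# state (query) against every encoder output (memory), softmaxes the scores
# into weights, and returns a weighted sum of the memory as a context vector.
# Below is a minimal NumPy sketch of that idea; it is illustrative only, and
# the shapes and the projection matrix W_a are assumptions of the 'general'
# Luong score, not the internals of tf.contrib.seq2seq:

def luong_attention_step(query, memory, W_a):
    # query  : (dec_dim,)          current decoder hidden state
    # memory : (src_len, enc_dim)  encoder outputs
    # W_a    : (dec_dim, enc_dim)  learned projection for the 'general' score
    scores = memory @ (W_a.T @ query)   # (src_len,)  score_s = q^T W_a m_s
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()            # softmax over source positions
    context = weights @ memory          # (enc_dim,)  weighted sum of memory
    return context, weights

# e.g. luong_attention_step(np.ones(4), np.random.rand(10, 16), np.random.rand(4, 16))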
# ### Create a model of SimpleNMT

# In[12]:

# hyper-parameters
lr = .003
epochs = 500
batch_size = 2
total_step = int(np.shape(X_indices)[0] / batch_size)
print(total_step)

# In[13]:

## create data pipeline with tf.data
tr_dataset = tf.data.Dataset.from_tensor_slices((X_length, X_indices,
                                                 y_length, y_input_indices, y_target_indices))
tr_dataset = tr_dataset.shuffle(buffer_size = 20)
tr_dataset = tr_dataset.batch(batch_size = batch_size)
tr_iterator = tr_dataset.make_initializable_iterator()
print(tr_dataset)

# In[14]:

X_length_mb, X_indices_mb, y_length_mb, y_input_indices_mb, y_target_indices_mb = tr_iterator.get_next()

# In[15]:

sim_nmt = SimpleNMT(s_len = X_length_mb, s_indices = X_indices_mb,
                    t_len = y_length_mb, t_input_indices = y_input_indices_mb,
                    t_output_indices = y_target_indices_mb)

# ### Create training op and train model

# In[16]:

## create training op
opt = tf.train.AdamOptimizer(learning_rate = lr)
training_op = opt.minimize(loss = sim_nmt.loss)

# In[17]:

sess = tf.Session()
sess.run(tf.global_variables_initializer())

tr_loss_hist = []
for epoch in range(epochs):
    avg_tr_loss = 0
    tr_step = 0

    sess.run(tr_iterator.initializer)
    try:
        while True:
            _, tr_loss = sess.run(fetches = [training_op, sim_nmt.loss])
            avg_tr_loss += tr_loss
            tr_step += 1
    except tf.errors.OutOfRangeError:
        pass

    avg_tr_loss /= tr_step
    tr_loss_hist.append(avg_tr_loss)
    if (epoch + 1) % 100 == 0:
        print('epoch : {:3}, tr_loss : {:.3f}'.format(epoch + 1, avg_tr_loss))

# In[18]:

yhat = sim_nmt.translate(sess = sess, s_len = X_length, s_indices = X_indices)
yhat

# In[19]:

# original sentences
originals = list(map(lambda elm : [target_idx_dic.get(idx) for idx in elm], y_target_indices))
for original in originals:
    print(original)

# In[20]:

# translated sentences, decoded back to Korean words
translations = list(map(lambda elm : [target_idx_dic.get(idx) for idx in elm], yhat))
for translation in translations:
    print(translation)
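# Optionally (a small helper not in the original notebook), each decoded
# sequence can be cut at the first '<end>' token so that padding and any
# ids produced after '<end>' are not printed:

def trim_translation(idx_seq, dic = target_dic, idx_dic = target_idx_dic):
    words = []
    for idx in idx_seq:
        if idx == dic.get('<end>'):
            break
        words.append(idx_dic.get(idx))
    return words

for translation in map(trim_translation, yhat):
    print(translation)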