#!/usr/bin/env python
# coding: utf-8

# In[1]:


get_ipython().run_line_magic('load_ext', 'autoreload')
get_ipython().run_line_magic('autoreload', '2')

import banner

topics = ['Introduction to Transformers',
          'Very Simple Implementation of Transformers']

banner.reset(topics)


# In[2]:


banner.next_topic()


# # Introduction to Transformers
# 
# **adapted from [Transformers from Scratch](https://peterbloem.nl/blog/transformers).**
# 
# * [GPT — Intuitively and Exhaustively Explained](https://towardsdatascience.com/gpt-intuitively-and-exhaustively-explained-c70c38e87491), by Daniel Warfield
# * [Simplifying Transformer Blocks](https://arxiv.org/abs/2311.01906), by Bobby He and Thomas Hofmann
# 
# I am using JAX in this example.  A recent, very brief summary of JAX is at [JAX: Fast as PyTorch, Simple as NumPy](https://medium.com/@hylke.donker/jax-fast-as-pytorch-simple-as-numpy-a0c14893a738).

# In[3]:


banner.next_topic()


# # Very Simple Implementation of Transformers

# In[6]:


import jax.numpy as jnp
import jax
import numpy
import pandas
import matplotlib.pyplot as plt
import time
import re  # regular expressions

# jax.config.update('jax_platform_name', 'cpu')


# The `Fake.csv` and `True.csv` files are from the [Fake and real news dataset](https://www.kaggle.com/datasets/emineyetm/fake-news-detection-datasets) at Kaggle.

# In[7]:


df_fake = pandas.read_csv('Fake.csv', usecols=['title'])
df_real = pandas.read_csv('True.csv', usecols=['title'])
df_fake.shape, df_real.shape


# In[8]:


pandas.set_option('max_colwidth', None)


# In[9]:


df_fake.head(5)


# In[10]:


df_real.head(5)


# In[11]:


keep = 1000
headlines_fake = df_fake.values[:keep]
labels_fake = numpy.zeros((headlines_fake.shape[0]))
headlines_real = df_real.values[:keep]
labels_real = numpy.ones((headlines_real.shape[0]))
headlines_fake.shape, labels_fake.shape, headlines_real.shape, labels_real.shape


# In[12]:


# Keep the original (uncleaned) headlines, fake first then real, for display later.
headlines_orig = numpy.vstack((headlines_fake, headlines_real))


# In[13]:


from string import punctuation
print(punctuation)

def clean_up_words(titles):
    titles_words = [[word.lower() for word in re.split(r'\W+', title[0])] for title in titles]
    # words = [w.strip(punctuation) for w in words]
    words = [[w for w in title if len(w) > 1] for title in titles_words]
    return words


# In[14]:


re.split(r'\W+', headlines_fake[0][0])


# In[15]:


clean_up_words(headlines_fake[0:2])


# In[16]:


headlines_fake = clean_up_words(headlines_fake)
headlines_real = clean_up_words(headlines_real)


# In[17]:


headlines = headlines_fake + headlines_real
labels = numpy.hstack((labels_fake, labels_real))
len([len(h) for h in headlines]), len(labels)


# In[18]:


mx = max([len(h) for h in headlines])
mx


# In[19]:


# Pad every headline with ' ' tokens so all have length mx.
headlines = [headline + [' '] * (mx - len(headline)) for headline in headlines]


# In[20]:


len(headlines[0]), headlines[0]


# In[21]:


words = [word for headline in headlines for word in headline]
vocabulary = numpy.unique(words)
len(vocabulary)


# In[22]:


for i in range(100):
    print(i, vocabulary[i], end='; ')
print()
for i in range(4300, 4581):
    print(i, vocabulary[i], end='; ')


# In[23]:


numpy.where(numpy.array(headlines[0]).reshape(-1, 1) == vocabulary)


# In[24]:


numpy.where(numpy.array(headlines[0]).reshape(-1, 1) == vocabulary)[1]


# In[25]:


tokens = [numpy.where(numpy.array(headline).reshape(-1, 1) == vocabulary)[1] for headline in headlines]
tokens[:20]


# In[26]:


tokens[0]


# In[27]:


X_tokens = jnp.array(tokens)
X_tokens
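# In[ ]:


# Quick sanity check of the encoding (a sketch using the vocabulary and X_tokens
# built above): map the first headline's first ten token indices back to words.
print(X_tokens[0][:10])
print(vocabulary[numpy.asarray(X_tokens[0][:10])])
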
# In[23]:


# Create one transformer block with the following steps.
#
# 1. Make embedding layer.
# 2. Make position encoding.
# 3. Make weight matrices for mapping word embeddings to keys, queries, and values.
# 4. Make weight matrix for combining the outputs of all heads.
# 5. Define forward pass for self-attention.
# 6. Make weights for the dense net applied to the output of self-attention.
# 7. Define forward pass through the dense net.
# 8. Make weights to linearly convert the output of the dense net to log probs for each class.
#
# Finally, start classification.
#
# 9. Convert the list of tokens for each headline into their respective embeddings.
# 10. Pass each embedding through self-attention, then the dense net.
# 11. Calculate the mean over all outputs for a headline, then linearly reduce to log probs for each class.
# 12. Convert log probs to probs.


# In[31]:


######################################################################
# 1. Make embedding layer

n_vocabulary_words = len(vocabulary)
embed_dim = 40
embedder_W = jnp.array(numpy.random.normal(size=(n_vocabulary_words, embed_dim)))
type(embedder_W)  # , embedder_W.device()


# In[32]:


######################################################################
# 2. Make positional encoding (from http://jalammar.github.io/illustrated-transformer/)
#    Note that base 10 is used here rather than the 10000 of the original formulation.

def get_angles(pos, i, d_model):
    angle_rates = 1 / numpy.power(10, (2 * (i // 2)) / numpy.float32(d_model))
    return pos * angle_rates

def make_position_encoding(position, d_model):
    angle_rads = get_angles(numpy.arange(position)[:, numpy.newaxis],
                            numpy.arange(d_model)[numpy.newaxis, :],
                            d_model)
    # apply sin to even indices in the array; 2i
    angle_rads[:, 0::2] = numpy.sin(angle_rads[:, 0::2])
    # apply cos to odd indices in the array; 2i+1
    angle_rads[:, 1::2] = numpy.cos(angle_rads[:, 1::2])
    return jnp.array(angle_rads)

n_tokens_per_review = X_tokens.shape[1]
position_encoding = make_position_encoding(n_tokens_per_review, embed_dim)
plt.imshow(position_encoding[:, :]); plt.axis('auto')


# In[33]:


######################################################################
# 3. Make weight matrices for mapping word embeddings to keys, queries, and values.
# 4. Make weight matrix for combining the outputs of all heads.

def make_weights(n_in, n_out):
    w_scale = 1 / jnp.sqrt(n_in)
    return jnp.array(numpy.random.uniform(-w_scale, w_scale, size=(n_in, n_out)))

n_heads = 8
n_in_per_head = embed_dim  # each head uses the full embedding width, rather than embed_dim // n_heads

W_keys = [make_weights(n_in_per_head, n_in_per_head) for h in range(n_heads)]
W_queries = [make_weights(n_in_per_head, n_in_per_head) for h in range(n_heads)]
W_values = [make_weights(n_in_per_head, n_in_per_head) for h in range(n_heads)]
W_combine = make_weights(n_heads * n_in_per_head, embed_dim)

print(f'W_key shapes {[w.shape for w in W_keys]}')
print(f'W_query shapes {[w.shape for w in W_queries]}')
print(f'W_value shapes {[w.shape for w in W_values]}')
print(f'W_combine shape {W_combine.shape}')
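# In[ ]:


# A minimal preview of what one head computes (a sketch using the W_keys,
# W_queries, and W_values just created): scaled dot-product attention,
# softmax(Q K^T / sqrt(d)) V, on a single made-up "headline" of 5 tokens,
# just to see the shapes before forward_attention is defined in the next step.
X_toy = jnp.array(numpy.random.normal(size=(1, 5, embed_dim)))   # 1 sample, 5 tokens
K = X_toy @ W_keys[0]
Q = X_toy @ W_queries[0]
V = X_toy @ W_values[0]
A = jax.nn.softmax(Q @ jnp.swapaxes(K, 1, 2) / jnp.sqrt(embed_dim), axis=-1)
print(f'attention weights {A.shape}, head output {(A @ V).shape}')
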
# In[34]:


######################################################################
# 5. Define forward pass for self-attention

def softmax(Y, dim):
    maxY = jnp.max(Y, axis=dim, keepdims=True)
    eY = jnp.exp(Y - maxY)
    eY_sum = jnp.sum(eY, axis=dim, keepdims=True)
    return eY / eY_sum

def forward_attention(params, X):
    embedder_W, W_keys, W_queries, W_values, W_combine, _, _, _ = params
    n_samples, n_tokens, embed_dim = X.shape
    # Pre-LayerNorm version, as in https://arxiv.org/pdf/2002.04745.pdf
    X = (X - X.mean(-1, keepdims=True)) / X.std(-1, keepdims=True)
    keys = [X @ W_key for W_key in W_keys]
    queries = [X @ W_query for W_query in W_queries]
    values = [X @ W_value for W_value in W_values]
    scale = jnp.sqrt(embed_dim)
    QKs = [query @ jnp.swapaxes(key, 1, 2) / scale for query, key in zip(queries, keys)]
    QKs = [softmax(QK, dim=2) for QK in QKs]
    attentions = [QK @ value for QK, value in zip(QKs, values)]
    attention = jnp.stack(attentions, axis=-1).reshape(n_samples, n_tokens, -1) @ W_combine
    return attention, QKs


# In[35]:


######################################################################
# 6. Make weights for the dense net applied to the output of self-attention.

n_ff_units = 10
ff_W1 = make_weights(embed_dim, n_ff_units * embed_dim)
ff_W2 = make_weights(n_ff_units * embed_dim, embed_dim)

print(f'ff_W1 shape {ff_W1.shape}')
print(f'ff_W2 shape {ff_W2.shape}')


# In[36]:


######################################################################
# 7. Define forward pass through the dense net.

def forward_transform_block(params, attention, X):
    ff_W1 = params[-2]
    ff_W2 = params[-1]
    X = attention + X   # residual connection
    # layernorm of X
    X = (X - X.mean(-1, keepdims=True)) / X.std(-1, keepdims=True)
    Y = jnp.tanh(jnp.tanh(X @ ff_W1) @ ff_W2)
    return Y


# In[37]:


######################################################################
# 8. Make weights to linearly convert the output of the dense net to log probs for each class.

n_classes = 2
W_toprobs = make_weights(embed_dim, n_classes)
print(f'W_toprobs shape {W_toprobs.shape}')


# In[38]:


######################################################################
# Finally, start classification.
# 9. Convert the list of tokens for each headline into their respective embeddings.

embedding = jnp.take(embedder_W, X_tokens, axis=0)   # n_reviews x max_n_tokens x embed_dim
X = embedding + position_encoding
print(f'X.shape {X.shape}')


# In[39]:


######################################################################
# 10. Pass each embedding through self-attention, then the dense net.

def forward(params, X_tokens):
    # X_tokens is a mini-batch; each row is an array of token indices, one per word.
    embedder_W = params[0]
    X_embedding = jnp.take(embedder_W, X_tokens, axis=0)   # n_reviews x max_n_tokens x embed_dim
    X = X_embedding + position_encoding
    Y = forward_transform_block(params, forward_attention(params, X)[0], X)
    W_toprobs = params[-3]
    Y = Y @ W_toprobs
    # print(f'Output Y shape {Y.shape}')
    Y = Y.mean(axis=1)   # mean over all token positions
    return Y

params = [embedder_W, W_keys, W_queries, W_values, W_combine, W_toprobs, ff_W1, ff_W2]
forward(params, X_tokens)


# In[40]:


labels.reshape(-1, 1) == numpy.unique(labels)


# In[41]:


def make_indicator_vars(labels):
    return jnp.array((labels.reshape(-1, 1) == numpy.unique(labels)).astype(int))

make_indicator_vars(labels)
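# In[ ]:


# Sanity check (a sketch using forward, softmax, and params defined above):
# class probabilities from the still-untrained network for the first three
# headlines.  Each row should sum to 1; the values are not meaningful yet.
probs_untrained = softmax(forward(params, X_tokens[:3]), -1)
print(probs_untrained)
print(probs_untrained.sum(axis=1))
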
# In[42]:


# Now we can implement the loss function, compute its gradient, and implement a training loop.

def loss(params, X_tokens, T):
    Y = forward(params, X_tokens)
    Y = softmax(Y, -1)
    return -jnp.mean(T * jnp.log(Y))

T = make_indicator_vars(labels)


# In[43]:


loss(params, X_tokens, T)


# In[44]:


def generate_stratified_partitions(X, T, n_folds, validation=True, shuffle=True):
    '''Generates sets of Xtrain,Ttrain,Xvalidate,Tvalidate,Xtest,Ttest,
    or sets of Xtrain,Ttrain,Xtest,Ttest if validation is False.
    Builds a dictionary keyed by class label.  Each entry contains rowIndices and
    start and stop indices into rowIndices for each of n_folds folds.'''

    def rows_in_fold(folds, k):
        all_rows = []
        for c, rows in folds.items():
            class_rows, starts, stops = rows
            all_rows += class_rows[starts[k]:stops[k]].tolist()
        return all_rows

    def rows_in_folds(folds, ks):
        all_rows = []
        for k in ks:
            all_rows += rows_in_fold(folds, k)
        return all_rows

    row_indices = numpy.arange(X.shape[0])
    if shuffle:
        numpy.random.shuffle(row_indices)

    # Stratify on the class of each row.  T may be a single column of class
    # labels or, as here, an indicator matrix with one column per class.
    if T.ndim > 1 and T.shape[1] > 1:
        class_of_row = numpy.argmax(numpy.asarray(T), axis=1)
    else:
        class_of_row = numpy.asarray(T).ravel()

    folds = {}
    classes = numpy.unique(class_of_row)
    for c in classes:
        class_indices = row_indices[numpy.where(class_of_row[row_indices] == c)[0]]
        n_in_class = len(class_indices)
        n_each = int(n_in_class / n_folds)
        starts = numpy.arange(0, n_each * n_folds, n_each)
        stops = starts + n_each
        stops[-1] = n_in_class
        folds[c] = [class_indices, starts, stops]

    for test_fold in range(n_folds):
        if validation:
            for validate_fold in range(n_folds):
                if test_fold == validate_fold:
                    continue
                train_folds = numpy.setdiff1d(range(n_folds), [test_fold, validate_fold])
                rows = rows_in_fold(folds, test_fold)
                Xtest = X[rows, :]
                Ttest = T[rows, :]
                rows = rows_in_fold(folds, validate_fold)
                Xvalidate = X[rows, :]
                Tvalidate = T[rows, :]
                rows = rows_in_folds(folds, train_folds)
                Xtrain = X[rows, :]
                Ttrain = T[rows, :]
                yield Xtrain, Ttrain, Xvalidate, Tvalidate, Xtest, Ttest
        else:
            # No validation set
            train_folds = numpy.setdiff1d(range(n_folds), [test_fold])
            rows = rows_in_fold(folds, test_fold)
            Xtest = X[rows, :]
            Ttest = T[rows, :]
            rows = rows_in_folds(folds, train_folds)
            Xtrain = X[rows, :]
            Ttrain = T[rows, :]
            yield Xtrain, Ttrain, Xtest, Ttest


# In[45]:


Xtrain, Ttrain, Xtest, Ttest = next(generate_stratified_partitions(X_tokens, T, 4, validation=False, shuffle=True))
print(f'{Xtrain.shape=} {Ttrain.shape=} {Xtest.shape=} {Ttest.shape=}')

def frac_pos(T):
    return (T[:, 1] == 1).mean().item()

print(f'{frac_pos(Ttrain)=:.2f} {frac_pos(Ttest)=:.2f}')


# In[46]:


loss_grad = jax.value_and_grad(loss)

def train(n_steps, batch_size, learning_rate):
    global params

    if batch_size < 0:
        batch_size = Xtrain.shape[0]

    print('Training started')
    losses = []
    likelihoods = []
    n_samples = Xtrain.shape[0]
    start_time = time.time()

    for step in range(n_steps):

        likelihoods_batch = []
        for batch_i, first in enumerate(range(0, n_samples, batch_size)):
            Xtrain_batch = Xtrain[first:first + batch_size]
            Ttrain_batch = Ttrain[first:first + batch_size]
            loss_value, grads = loss_grad(params, Xtrain_batch, Ttrain_batch)
            losses.append(loss_value)
            likelihoods_batch.append(jnp.exp(-loss_value))
            # SGD update; lists of per-head weights are updated element by element
            params = [param - learning_rate * grad if not isinstance(grad, list)
                      else [par - learning_rate * gra for (par, gra) in zip(param, grad)]
                      for (param, grad) in zip(params, grads)]

        likelihoods.append(jnp.mean(jnp.array(likelihoods_batch)))

        if (step + 1) % max(1, (n_steps // 20)) == 0:
            print(f'Step {step+1} Likelihood {likelihoods[-1]:.4f}')

    losses = jnp.array(losses)
    elapsed = time.time() - start_time
    print(f'Training took {elapsed/60:.1f} minutes.')
    return losses, likelihoods
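# In[ ]:


# The nested list comprehension in train() applies plain SGD by hand because
# params mixes single arrays with lists of per-head arrays.  An equivalent and
# more idiomatic JAX version (a sketch, not used by train() above) treats
# params as a pytree and updates every array leaf with jax.tree_util.tree_map.

def sgd_step(params, grads, learning_rate):
    # p <- p - learning_rate * g for every array leaf in the params pytree
    return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)

# usage: params = sgd_step(params, grads, learning_rate)
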
# In[48]:


batch_size = -1
learning_rate = 0.1
n_steps = 500

losses, likelihoods = train(n_steps, batch_size, learning_rate)

plt.plot(likelihoods);


# In[49]:


# To use our transformer, do a forward pass in batches, convert the outputs to
# probabilities, then calculate a confusion matrix.

Y = []
n_samples = X_tokens.shape[0]
if batch_size > 0:
    for first in range(0, n_samples, batch_size):
        Y.append(forward(params, X_tokens[first:first + batch_size]))
    Y = jnp.vstack(Y)
else:
    Y = forward(params, X_tokens)

probs = softmax(Y, dim=1)
probs


# In[50]:


pred_classes = jnp.argmax(Y, axis=1)
actual_classes = labels   # 0 = fake, 1 = real

row0 = [jnp.mean(pred_classes[actual_classes == 0] == 0),
        jnp.mean(pred_classes[actual_classes == 0] == 1)]
row1 = [jnp.mean(pred_classes[actual_classes == 1] == 0),
        jnp.mean(pred_classes[actual_classes == 1] == 1)]
cm = jnp.array([row0, row1])
pandas.DataFrame(100 * cm,
                 columns=('Pred Fake', 'Pred Real'),
                 index=('Actual Fake', 'Actual Real'))


# In[51]:


pandas.set_option('max_colwidth', None)
n = 10
df = pandas.DataFrame(numpy.hstack((actual_classes[:n].reshape(-1, 1),
                                    pred_classes[:n].reshape(-1, 1),
                                    headlines_orig[:n])),
                      columns=('Actual', 'Predicted', 'Headline'))
df


# In[52]:


pandas.set_option('max_colwidth', None)
df = pandas.DataFrame(numpy.hstack((actual_classes[-n:].reshape(-1, 1),
                                    pred_classes[-n:].reshape(-1, 1),
                                    headlines_orig[-n:])),
                      columns=('Actual', 'Predicted', 'Headline'))
df


# In[53]:


###### attention weights

headline_i = 1900

embedder_W = params[0]
X_embedding = jnp.take(embedder_W, X_tokens[headline_i:headline_i + 1], axis=0)   # 1 x max_n_tokens x embed_dim
X = X_embedding + position_encoding
attn, QKs = forward_attention(params, X)

words = headlines[headline_i]
n_words = words.index(' ') if ' ' in words else len(words)   # number of words before the padding starts
words = words[:n_words]

plt.figure(figsize=(12, 12))
nplot = int(numpy.sqrt(n_heads)) + 1
ploti = 0
for h in range(n_heads):
    ploti += 1
    plt.subplot(nplot, nplot, ploti)
    plt.imshow(QKs[h][0, :n_words, :n_words])
    plt.colorbar()
plt.suptitle(' '.join(words))
plt.tight_layout()

## Draw attention weights as lines between words

plt.figure(figsize=(12, 12))
plt.suptitle(' '.join(words))
ploti = 0
for h in range(n_heads):
    ploti += 1
    plt.subplot(nplot, nplot, ploti)
    plt.xlim(0, 4)
    plt.ylim(0, n_words)
    plt.axis('off')
    for i, w in enumerate(words):
        plt.text(1, n_words - i, w, ha='right')
        plt.text(3, n_words - i, w, ha='left')
    QK = QKs[h][0, :n_words, :n_words]
    # Emphasize the strongest weights: zero values below 70% of the maximum,
    # then scale to [0, 1] for use as the line alpha.
    qk_max = jnp.max(QK)
    QK = jnp.where(QK < 0.7 * qk_max, 0.0, QK) / qk_max
    for i in range(n_words):
        for j in range(n_words):
            plt.plot([1, 3], [n_words - i, n_words - j], 'r', alpha=QK[i, j].item())
plt.tight_layout()
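# In[ ]:


# The confusion matrix above is computed over all headlines, including those
# used for training.  A sketch of evaluating only the held-out Xtest, Ttest
# from the earlier stratified partition:
Ytest = forward(params, Xtest)
test_pred = jnp.argmax(Ytest, axis=1)
test_actual = jnp.argmax(Ttest, axis=1)
print(f'Test accuracy {jnp.mean(test_pred == test_actual).item():.3f}')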