#!/usr/bin/env python # coding: utf-8 # # Attention Mechanism # In[1]: import math from mxnet import nd from mxnet.gluon import nn # Masked softmax. # In[2]: # X: 3-D tensor, valid_length: 1-D or 2-D tensor def masked_softmax(X, valid_length): if valid_length is None: return X.softmax() else: shape = X.shape if valid_length.ndim == 1: valid_length = valid_length.repeat(shape[1], axis=0) else: valid_length = valid_length.reshape((-1,)) # fill masked elements with a large negative, whose exp is 0 X = nd.SequenceMask(X.reshape((-1, shape[-1])), valid_length, True, axis=1, value=-1e6) return X.softmax().reshape(shape) # Example # In[3]: masked_softmax(nd.random.uniform(shape=(2,2,4)), nd.array([2,3])) # ## Dot Product Attention # # $$\alpha(\mathbf Q, \mathbf K) = \langle \mathbf Q, \mathbf K^T \rangle /\sqrt{d}.$$ # In[5]: class DotProductAttention(nn.Block): def __init__(self, dropout, **kwargs): super(DotProductAttention, self).__init__(**kwargs) self.dropout = nn.Dropout(dropout) # query: (batch_size, #queries, d) # key: (batch_size, #kv_pairs, d) # value: (batch_size, #kv_pairs, dim_v) # valid_length: either (batch_size, ) or (batch_size, seq_len) def forward(self, query, key, value, valid_length=None): d = query.shape[-1] # set transpose_b=True to swap the last two dimensions of key scores = nd.batch_dot(query, key, transpose_b=True) / math.sqrt(d) attention_weights = self.dropout(masked_softmax(scores, valid_length)) return nd.batch_dot(attention_weights, value) # Example: # In[6]: atten = DotProductAttention(dropout=0.5) atten.initialize() keys = nd.ones((2,10,2)) values = nd.arange(40).reshape((1,10,4)).repeat(2,axis=0) atten(nd.ones((2,1,2)), keys, values, nd.array([2, 6])) # ## Multilayer Perception Attention # # $\mathbf W_k\in\mathbb R^{h\times d_k}$, $\mathbf W_q\in\mathbb R^{h\times d_q}$, and $\mathbf v\in\mathbb R^{p}$: # # $$\alpha(\mathbf k, \mathbf q) = \mathbf v^T \text{tanh}(\mathbf W_k \mathbf k + \mathbf W_q\mathbf q). $$ # # In[7]: class MLPAttention(nn.Block): # This class is saved in d2l. def __init__(self, units, dropout, **kwargs): super(MLPAttention, self).__init__(**kwargs) # Use flatten=True to keep query's and key's 3-D shapes. self.W_k = nn.Dense(units, activation='tanh', use_bias=False, flatten=False) self.W_q = nn.Dense(units, activation='tanh', use_bias=False, flatten=False) self.v = nn.Dense(1, use_bias=False, flatten=False) self.dropout = nn.Dropout(dropout) def forward(self, query, key, value, valid_length): query, key = self.W_k(query), self.W_q(key) # expand query to (batch_size, #querys, 1, units), and key to # (batch_size, 1, #kv_pairs, units). Then plus them with broadcast. features = query.expand_dims(axis=2) + key.expand_dims(axis=1) scores = self.v(features).squeeze(axis=-1) attention_weights = self.dropout(masked_softmax(scores, valid_length)) return nd.batch_dot(attention_weights, value) # Example # In[8]: atten = MLPAttention(units=8, dropout=0.1) atten.initialize() atten(nd.ones((2,1,2)), keys, values, nd.array([2, 6]))
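# As a quick sanity check, the cell below is a minimal sketch that reuses the
# toy `keys` and `values` defined above; the names `query`, `dot_atten`,
# `mlp_atten`, and `valid_2d` and the per-query lengths in `valid_2d` are
# introduced here purely for illustration. Both attention layers map queries of
# shape (2, 1, 2) and values of shape (2, 10, 4) to outputs of shape
# (2, 1, 4), and `masked_softmax` also accepts a 2-D `valid_length` with one
# valid length per query.

# In[9]:


query = nd.ones((2, 1, 2))
dot_atten = DotProductAttention(dropout=0)
dot_atten.initialize()
mlp_atten = MLPAttention(units=8, dropout=0)
mlp_atten.initialize()
# Both outputs have shape (batch_size, #queries, dim_v) = (2, 1, 4)
print(dot_atten(query, keys, values, nd.array([2, 6])).shape)
print(mlp_atten(query, keys, values, nd.array([2, 6])).shape)
# A 2-D valid_length gives each of the 2 queries in each batch its own length
valid_2d = nd.array([[2, 3], [1, 4]])
masked_softmax(nd.random.uniform(shape=(2, 2, 4)), valid_2d)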