Attention Mechanism

In [1]:
import math
from mxnet import nd
from mxnet.gluon import nn

Masked softmax.

In [2]:
# X: 3-D tensor, valid_length: 1-D or 2-D tensor
def masked_softmax(X, valid_length):
    """Softmax over the last axis of X that ignores positions past valid_length.

    X is a 3-D tensor. valid_length is None, a 1-D tensor (one length per
    entry along X's first axis, reused for every row along the second axis),
    or a 2-D tensor (flattened to one length per row of the reshaped X).
    The result has the same shape as X.
    """
    # Nothing to mask: plain softmax.
    if valid_length is None:
        return X.softmax()

    original_shape = X.shape
    if valid_length.ndim == 1:
        # Duplicate each length once per row along X's second axis.
        row_lengths = valid_length.repeat(original_shape[1], axis=0)
    else:
        row_lengths = valid_length.reshape((-1,))

    # Collapse the leading axes so every row gets its own length.
    flat_scores = X.reshape((-1, original_shape[-1]))
    # Masked entries become a large negative number, whose exp is 0.
    masked_scores = nd.SequenceMask(flat_scores, row_lengths, True,
                                    axis=1, value=-1e6)
    return masked_scores.softmax().reshape(original_shape)

Example

In [3]:
# Random (2, 2, 4) scores with valid lengths [2, 3]: the first batch entry keeps
# 2 columns and the second keeps 3; the remaining columns come out as 0.
masked_softmax(nd.random.uniform(shape=(2,2,4)), nd.array([2,3]))
Out[3]:
[[[0.488994   0.511006   0.         0.        ]
  [0.43654838 0.56345165 0.         0.        ]]

 [[0.28817102 0.3519408  0.3598882  0.        ]
  [0.29034293 0.25239873 0.45725834 0.        ]]]
<NDArray 2x2x4 @cpu(0)>

Dot Product Attention

$$\alpha(\mathbf Q, \mathbf K) = \mathbf Q \mathbf K^T /\sqrt{d}.$$
In [5]:
class DotProductAttention(nn.Block):
    """Scaled dot-product attention with dropout on the attention weights."""

    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, valid_length=None):
        """Return the attention-weighted combination of `value`.

        query: (batch_size, #queries, d)
        key: (batch_size, #kv_pairs, d)
        value: (batch_size, #kv_pairs, dim_v)
        valid_length: None, (batch_size, ) or (batch_size, seq_len); it is
            handed to masked_softmax unchanged.
            NOTE(review): seq_len presumably equals #queries - confirm.
        """
        # Scores are divided by sqrt(d), d being the query feature size.
        scale = math.sqrt(query.shape[-1])
        # transpose_b=True swaps the last two dimensions of key.
        raw_scores = nd.batch_dot(query, key, transpose_b=True)
        weights = masked_softmax(raw_scores / scale, valid_length)
        return nd.batch_dot(self.dropout(weights), value)

Example:

In [6]:
atten = DotProductAttention(dropout=0.5)
atten.initialize()
# All keys are identical, so every valid position gets the same score and the
# attention weights are uniform over the valid positions.
keys = nd.ones((2,10,2))
# Row i of values is [4i, 4i+1, 4i+2, 4i+3]; both batch entries share the rows.
values = nd.arange(40).reshape((1,10,4)).repeat(2,axis=0)
# Valid lengths 2 and 6: each output is the mean of the first 2 (resp. 6) value
# rows. The recorded output equals those plain means, so dropout did not alter
# this call - presumably it only acts in training mode; confirm.
atten(nd.ones((2,1,2)), keys, values, nd.array([2, 6]))
Out[6]:
[[[ 2.        3.        4.        5.      ]]

 [[10.       11.       12.000001 13.      ]]]
<NDArray 2x1x4 @cpu(0)>

Multilayer Perceptron Attention

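Project the key and the query to a common hidden size $h$ with learnable parameters $\mathbf W_k\in\mathbb R^{h\times d_k}$, $\mathbf W_q\in\mathbb R^{h\times d_q}$, and $\mathbf v\in\mathbb R^{h}$; the score function is: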

$$\alpha(\mathbf k, \mathbf q) = \mathbf v^T \text{tanh}(\mathbf W_k \mathbf k + \mathbf W_q\mathbf q). $$
In [7]:
class MLPAttention(nn.Block):  # This class is saved in d2l.
    """Additive (MLP) attention.

    Scores each (query, key) pair as v^T tanh(W_q q + W_k k), turns the scores
    into weights with masked_softmax, applies dropout to the weights, and
    returns the weighted combination of the values.
    """

    def __init__(self, units, dropout, **kwargs):
        super(MLPAttention, self).__init__(**kwargs)
        # Use flatten=False to keep query's and key's 3-D shapes.
        # No activation on the projections: tanh has to be applied to their
        # sum (see forward), not to each projection separately.
        self.W_k = nn.Dense(units, use_bias=False, flatten=False)
        self.W_q = nn.Dense(units, use_bias=False, flatten=False)
        self.v = nn.Dense(1, use_bias=False, flatten=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, valid_length=None):
        """Return the attention-weighted combination of `value`.

        query: (batch_size, #queries, d_q)
        key: (batch_size, #kv_pairs, d_k)
        value: (batch_size, #kv_pairs, dim_v)
        valid_length: None, or a 1-D / 2-D tensor handed to masked_softmax.
        """
        # Project each input with its own matrix: W_q for queries, W_k for keys.
        query, key = self.W_q(query), self.W_k(key)
        # Expand query to (batch_size, #queries, 1, units) and key to
        # (batch_size, 1, #kv_pairs, units), add them with broadcasting, then
        # apply tanh to the sum.
        features = nd.tanh(query.expand_dims(axis=2) + key.expand_dims(axis=1))
        # v maps the last axis to size 1; squeeze it away to get
        # (batch_size, #queries, #kv_pairs) scores.
        scores = self.v(features).squeeze(axis=-1)
        attention_weights = self.dropout(masked_softmax(scores, valid_length))
        return nd.batch_dot(attention_weights, value)

Example

In [8]:
# Rebinds `atten` and reuses `keys` and `values` from the dot-product example
# cell, so that cell must have been run first.
atten = MLPAttention(units=8, dropout=0.1)
atten.initialize()
# The keys are all identical, so every valid position receives the same score
# and each output is again the mean of the first 2 (resp. 6) value rows.
atten(nd.ones((2,1,2)), keys, values, nd.array([2, 6]))
Out[8]:
[[[ 2.        3.        4.        5.      ]]

 [[10.       11.       12.000001 13.      ]]]
<NDArray 2x1x4 @cpu(0)>