# Attention Mechanism

In [1]:
import math
from mxnet import nd
from mxnet.gluon import nn


In [2]:
def masked_softmax(X, valid_length=None):
    """Softmax over the last axis of X, ignoring positions beyond valid_length.

    Parameters:
        X: 3-D tensor of shape (batch_size, rows, cols).
        valid_length: None, or a 1-D tensor of shape (batch_size,), or a 2-D
            tensor of shape (batch_size, rows), giving the number of valid
            columns per example (or per row).

    Returns a tensor of the same shape as X; masked positions get weight ~0.
    """
    if valid_length is None:
        return X.softmax()
    else:
        shape = X.shape
        if valid_length.ndim == 1:
            # One length per example: repeat it so every row shares it.
            valid_length = valid_length.repeat(shape[1], axis=0)
        else:
            # One length per row: flatten to match the reshaped X below.
            valid_length = valid_length.reshape((-1,))
        # Fill masked elements with a large negative value, whose exp is ~0.
        X = nd.SequenceMask(X.reshape((-1, shape[-1])), valid_length, True,
                            axis=1, value=-1e6)
        return X.softmax().reshape(shape)


Example

In [3]:
masked_softmax(nd.random.uniform(shape=(2,2,4)), nd.array([2,3]))

Out[3]:
[[[0.488994   0.511006   0.         0.        ]
[0.43654838 0.56345165 0.         0.        ]]

[[0.28817102 0.3519408  0.3598882  0.        ]
[0.29034293 0.25239873 0.45725834 0.        ]]]
<NDArray 2x2x4 @cpu(0)>

## Dot Product Attention

$$\alpha(\mathbf Q, \mathbf K) = \langle \mathbf Q, \mathbf K^T \rangle /\sqrt{d}.$$
In [5]:
class DotProductAttention(nn.Block):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d)) V."""

    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        # Dropout is applied to the attention weights, not the values.
        self.dropout = nn.Dropout(dropout)

    # query: (batch_size, #queries, d)
    # key: (batch_size, #kv_pairs, d)
    # value: (batch_size, #kv_pairs, dim_v)
    # valid_length: either (batch_size,) or (batch_size, seq_len)
    def forward(self, query, key, value, valid_length=None):
        d = query.shape[-1]
        # Set transpose_b=True to swap the last two dimensions of key.
        scores = nd.batch_dot(query, key, transpose_b=True) / math.sqrt(d)
        # Mask out-of-range positions, then apply dropout to the weights.
        attention_weights = self.dropout(masked_softmax(scores, valid_length))
        # (batch_size, #queries, #kv_pairs) x (batch_size, #kv_pairs, dim_v)
        return nd.batch_dot(attention_weights, value)


Example:

In [6]:
atten = DotProductAttention(dropout=0.5)
atten.initialize()
keys = nd.ones((2,10,2))
values = nd.arange(40).reshape((1,10,4)).repeat(2,axis=0)
atten(nd.ones((2,1,2)), keys, values, nd.array([2, 6]))

Out[6]:
[[[ 2.        3.        4.        5.      ]]

[[10.       11.       12.000001 13.      ]]]
<NDArray 2x1x4 @cpu(0)>

## Multilayer Perceptron Attention

$\mathbf W_k\in\mathbb R^{h\times d_k}$, $\mathbf W_q\in\mathbb R^{h\times d_q}$, and $\mathbf v\in\mathbb R^{h}$:

$$\alpha(\mathbf k, \mathbf q) = \mathbf v^T \text{tanh}(\mathbf W_k \mathbf k + \mathbf W_q\mathbf q).$$
In [7]:
class MLPAttention(nn.Block):  # This class is saved in d2l.
    """Additive (MLP) attention: score(k, q) = v^T tanh(W_k k + W_q q)."""

    def __init__(self, units, dropout, **kwargs):
        super(MLPAttention, self).__init__(**kwargs)
        # Use flatten=False to keep query's and key's 3-D shapes.
        # NOTE(review): tanh is applied inside each Dense here, rather than
        # once after the sum as in the formula above — confirm intended.
        self.W_k = nn.Dense(units, activation='tanh',
                            use_bias=False, flatten=False)
        self.W_q = nn.Dense(units, activation='tanh',
                            use_bias=False, flatten=False)
        self.v = nn.Dense(1, use_bias=False, flatten=False)
        self.dropout = nn.Dropout(dropout)

    # query: (batch_size, #queries, d_q)
    # key: (batch_size, #kv_pairs, d_k)
    # value: (batch_size, #kv_pairs, dim_v)
    # valid_length: either (batch_size,) or (batch_size, seq_len)
    def forward(self, query, key, value, valid_length):
        # Project with the matching layer for each input (W_q for queries,
        # W_k for keys, consistent with the formula).
        query, key = self.W_q(query), self.W_k(key)
        # Expand query to (batch_size, #queries, 1, units) and key to
        # (batch_size, 1, #kv_pairs, units), then add with broadcasting.
        features = query.expand_dims(axis=2) + key.expand_dims(axis=1)
        # v maps units -> 1; squeeze to (batch_size, #queries, #kv_pairs).
        scores = self.v(features).squeeze(axis=-1)
        # Mask invalid positions, then apply dropout to the weights.
        attention_weights = self.dropout(masked_softmax(scores, valid_length))
        return nd.batch_dot(attention_weights, value)


Example

In [8]:
atten = MLPAttention(units=8, dropout=0.1)
atten.initialize()
atten(nd.ones((2,1,2)), keys, values, nd.array([2, 6]))

Out[8]:
[[[ 2.        3.        4.        5.      ]]

[[10.       11.       12.000001 13.      ]]]
<NDArray 2x1x4 @cpu(0)>