import math
import torch
import torch.nn as nn
def generate_relative_positions_matrix(length, max_relative_positions,
cache=False):
"""Generate the clipped relative positions matrix
for a given length and maximum relative positions"""
if cache:
distance_mat = torch.arange(-length+1, 1, 1).unsqueeze(0)
else:
range_vec = torch.arange(length)
range_mat = range_vec.unsqueeze(-1).expand(-1, length).transpose(0, 1)
distance_mat = range_mat - range_mat.transpose(0, 1)
distance_mat_clipped = torch.clamp(distance_mat,
min=-max_relative_positions,
max=max_relative_positions)
# Shift values to be >= 0
final_mat = distance_mat_clipped + max_relative_positions
return final_mat
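# Quick sanity check (values worked out by hand for this demo; re-run to confirm).
# Each entry is the key position minus the query position, clipped to
# [-max_relative_positions, max_relative_positions] and shifted so it starts at 0.
generate_relative_positions_matrix(4, 2)
tensor([[2, 3, 4, 4],
        [1, 2, 3, 4],
        [0, 1, 2, 3],
        [0, 0, 1, 2]])
# With cache=True only the row for the current decoding step is produced.
generate_relative_positions_matrix(4, 2, cache=True)
tensor([[0, 0, 1, 2]])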
def relative_matmul(x, z, transpose):
"""Helper function for relative positions attention."""
batch_size = x.shape[0]
heads = x.shape[1]
length = x.shape[2]
x_t = x.permute(2, 0, 1, 3)
x_t_r = x_t.reshape(length, heads * batch_size, -1)
if transpose:
z_t = z.transpose(1, 2)
x_tz_matmul = torch.matmul(x_t_r, z_t)
else:
x_tz_matmul = torch.matmul(x_t_r, z)
x_tz_matmul_r = x_tz_matmul.reshape(length, batch_size, heads, -1)
x_tz_matmul_r_t = x_tz_matmul_r.permute(1, 2, 0, 3)
return x_tz_matmul_r_t
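# Shape check for relative_matmul (illustrative sizes, not taken from the model):
# with transpose=True it produces the query x relative-key term added to the attention
# logits, with transpose=False the attention x relative-value term added to the context.
q = torch.randn(2, 8, 5, 64)       # (batch, heads, query_len, dim_per_head)
rel_k = torch.randn(5, 5, 64)      # (query_len, key_len, dim_per_head)
relative_matmul(q, rel_k, True).shape
torch.Size([2, 8, 5, 5])
attn_w = torch.randn(2, 8, 5, 5)   # (batch, heads, query_len, key_len)
rel_v = torch.randn(5, 5, 64)
relative_matmul(attn_w, rel_v, False).shape
torch.Size([2, 8, 5, 64])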
class MultiHeadedAttention(nn.Module):
def __init__(self, head_count, model_dim, dropout=0.1,
max_relative_positions=0):
assert model_dim % head_count == 0
self.dim_per_head = model_dim // head_count
self.model_dim = model_dim
super(MultiHeadedAttention, self).__init__()
self.head_count = head_count
self.linear_keys = nn.Linear(model_dim,
head_count * self.dim_per_head)
self.linear_values = nn.Linear(model_dim,
head_count * self.dim_per_head)
self.linear_query = nn.Linear(model_dim,
head_count * self.dim_per_head)
self.softmax = nn.Softmax(dim=-1)
self.dropout = nn.Dropout(dropout)
self.final_linear = nn.Linear(model_dim, model_dim)
self.max_relative_positions = max_relative_positions
if max_relative_positions > 0:
vocab_size = max_relative_positions * 2 + 1
self.relative_positions_embeddings = nn.Embedding(
vocab_size, self.dim_per_head)
def forward(self, key, value, query, mask=None, layer_cache=None, attn_type=None):
batch_size = key.size(0)
dim_per_head = self.dim_per_head
head_count = self.head_count
key_len = key.size(1)
query_len = query.size(1)
def shape(x):
"""Projection."""
return x.view(batch_size, -1, head_count, dim_per_head) \
.transpose(1, 2)
def unshape(x):
"""Compute context."""
return x.transpose(1, 2).contiguous() \
.view(batch_size, -1, head_count * dim_per_head)
# 1) Project key, value, and query.
if layer_cache is not None:
if attn_type == "self":
query, key, value = self.linear_query(query),\
self.linear_keys(query),\
self.linear_values(query)
key = shape(key)
value = shape(value)
if layer_cache["self_keys"] is not None:
key = torch.cat(
(layer_cache["self_keys"], key),
dim=2)
if layer_cache["self_values"] is not None:
value = torch.cat(
(layer_cache["self_values"], value),
dim=2)
layer_cache["self_keys"] = key
layer_cache["self_values"] = value
elif attn_type == "context":
query = self.linear_query(query)
if layer_cache["memory_keys"] is None:
key, value = self.linear_keys(key),\
self.linear_values(value)
key = shape(key)
value = shape(value)
else:
key, value = layer_cache["memory_keys"],\
layer_cache["memory_values"]
layer_cache["memory_keys"] = key
layer_cache["memory_values"] = value
else:
key = self.linear_keys(key)
value = self.linear_values(value)
query = self.linear_query(query)
key = shape(key)
value = shape(value)
if self.max_relative_positions > 0 and attn_type == "self":
key_len = key.size(2)
# 1 or key_len x key_len
relative_positions_matrix = generate_relative_positions_matrix(
key_len, self.max_relative_positions,
cache=True if layer_cache is not None else False)
# 1 or key_len x key_len x dim_per_head
relations_keys = self.relative_positions_embeddings(
relative_positions_matrix.to(key.device))
# 1 or key_len x key_len x dim_per_head
relations_values = self.relative_positions_embeddings(
relative_positions_matrix.to(key.device))
query = shape(query)
key_len = key.size(2)
query_len = query.size(2)
# 2) Calculate and scale scores.
query = query / math.sqrt(dim_per_head)
# batch x num_heads x query_len x key_len
query_key = torch.matmul(query, key.transpose(2, 3))
if self.max_relative_positions > 0 and attn_type == "self":
scores = query_key + relative_matmul(query, relations_keys, True)
else:
scores = query_key
scores = scores.float()
if mask is not None:
mask = mask.unsqueeze(1) # [B, 1, 1, T_values]
scores = scores.masked_fill(mask, -1e18)
# 3) Apply attention dropout and compute context vectors.
attn = self.softmax(scores).to(query.dtype)
drop_attn = self.dropout(attn)
context_original = torch.matmul(drop_attn, value)
if self.max_relative_positions > 0 and attn_type == "self":
context = unshape(context_original
+ relative_matmul(drop_attn,
relations_values,
False))
else:
context = unshape(context_original)
output = self.final_linear(context)
# Return multi-head attn
attns = attn.view(batch_size, head_count, query_len, key_len)
return output, attns
def update_dropout(self, dropout):
self.dropout.p = dropout
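# Sketch of the incremental-decoding path (a toy cache dict made up for this demo,
# not the full OpenNMT cache object): with attn_type="self" and a layer_cache,
# keys/values of earlier steps are concatenated along the length dimension.
mh_step = MultiHeadedAttention(16, 128)
cache = {"self_keys": None, "self_values": None}
step = torch.randn(1, 1, 128)                       # one decoding step
out, _ = mh_step(step, step, step, layer_cache=cache, attn_type="self")
cache["self_keys"].shape
torch.Size([1, 16, 1, 8])
step = torch.randn(1, 1, 128)                       # next step reuses the cache
out, attn = mh_step(step, step, step, layer_cache=cache, attn_type="self")
cache["self_keys"].shape
torch.Size([1, 16, 2, 8])
attn.shape                                          # current query attends over 2 keys
torch.Size([1, 16, 1, 2])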
class MultiHeadedAttention2(nn.Module):
def __init__(self, head_count, model_dim, dropout=0.1):
assert model_dim % head_count == 0
self.dim_per_head = model_dim // head_count
self.model_dim = model_dim
super(MultiHeadedAttention2, self).__init__()
self.head_count = head_count
self.linear_keys = nn.Linear(model_dim, head_count * self.dim_per_head)
self.linear_values = nn.Linear(model_dim, head_count * self.dim_per_head)
self.linear_query = nn.Linear(model_dim, head_count * self.dim_per_head)
self.softmax = nn.Softmax(dim=-1)
self.dropout = nn.Dropout(dropout)
self.final_linear = nn.Linear(model_dim, model_dim)
def forward(self, key, value, query, mask=None):
batch_size = key.size(0)
dim_per_head = self.dim_per_head
head_count = self.head_count
key_len = key.size(1)
query_len = query.size(1)
def shape(x):
"""Projection."""
return x.view(batch_size, -1, head_count, dim_per_head).transpose(1, 2)
def unshape(x):
"""Compute context."""
return x.transpose(1, 2).contiguous().view(
batch_size, -1, head_count * dim_per_head)
# 1) Project key, value, and query.
key = self.linear_keys(key)
value = self.linear_values(value)
query = self.linear_query(query)
key = shape(key)
value = shape(value)
query = shape(query)
key_len = key.size(2)
query_len = query.size(2)
# 2) Calculate and scale scores.
query = query / math.sqrt(dim_per_head)
# batch x num_heads x query_len x key_len
query_key = torch.matmul(query, key.transpose(2, 3))
scores = query_key
scores = scores.float()
# 3) Apply attention dropout and compute context vectors.
attn = self.softmax(scores).to(query.dtype)
drop_attn = self.dropout(attn)
print(drop_attn.shape, value.shape)
context_original = torch.matmul(drop_attn, value)
print(context_original.shape)
context = unshape(context_original)
output = self.final_linear(context)
print(output.shape)
# Return multi-head attn
print(attn.shape)
attns = attn.view(batch_size, head_count, query_len, key_len)
print(attns.shape)
return output, attns
mh2 = MultiHeadedAttention2(1, 128)
output2, att2 = mh2(input, input, input)
torch.Size([1, 1, 10, 10]) torch.Size([1, 1, 10, 128])
torch.Size([1, 1, 10, 128])
torch.Size([1, 10, 128])
torch.Size([1, 1, 10, 10])
torch.Size([1, 1, 10, 10])
x = nn.Linear(128, 128)(input)
x.shape
torch.Size([1, 10, 128])
x.view(1, -1, 16, 8).transpose(1, 2).shape
torch.Size([1, 16, 10, 8])
_.transpose(1, 2).contiguous().view(1, -1, 16 * 8).shape
torch.Size([1, 10, 128])
input = torch.LongTensor([1])
input.shape
torch.Size([1])
nn.Embedding(1000, 64)(input).shape
torch.Size([1, 64])
# (batch, key_len, dim)
input = torch.randn(1, 10, 128)
mh = MultiHeadedAttention(16, 128)
output, attn = mh(input, input, input)
output.shape
torch.Size([1, 10, 128])
attn.shape
torch.Size([1, 16, 10, 10])
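# With max_relative_positions > 0 and attn_type="self", the same call also adds the
# learned relative-position terms (a quick check on the same (1, 10, 128) input;
# mh_rel is just a name for this demo).
mh_rel = MultiHeadedAttention(16, 128, max_relative_positions=4)
output_rel, attn_rel = mh_rel(input, input, input, attn_type="self")
output_rel.shape
torch.Size([1, 10, 128])
attn_rel.shape
torch.Size([1, 16, 10, 10])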
class PositionalEncoding(nn.Module):
"Implement the PE function."
def __init__(self, d_model, dropout, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
# Compute the positional encodings once in log space.
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
        x = x + self.pe[:, :x.size(1)]  # pe is a registered buffer, so no gradient is tracked
return self.dropout(x)
torch.arange(2, 100, 2)
tensor([ 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62, 64, 66, 68, 70, 72, 74, 76, 78, 80, 82, 84, 86, 88, 90, 92, 94, 96, 98])
max_len = 10
dim = 16
pe = torch.zeros(max_len, dim)
pe.shape
torch.Size([10, 16])
position = torch.arange(0, max_len).unsqueeze(1)
position.shape
torch.Size([10, 1])
div_term = torch.exp(torch.arange(0, dim, 2) * -(math.log(10000.0) / dim))
div_term.shape
torch.Size([8])
torch.sin(position * div_term).shape
torch.Size([10, 8])
pe[:, 0::2] = torch.sin(position.float() * div_term)
pe[:, 1::2] = torch.cos(position.float() * div_term)
pe.shape
torch.Size([10, 16])
pe = pe.unsqueeze(1)
pe.shape
torch.Size([10, 1, 16])
x = torch.randn(10, 1, 16)
x.shape
torch.Size([10, 1, 16])
x[0]
tensor([[ 0.0599, 1.7412, 0.9173, 1.3884, 0.6826, -0.6830, 0.8005, -0.4316, 0.1788, 0.0446, -1.1409, -0.8781, 0.8874, 0.7887, 0.3160, -1.0217]])
pe[0]
tensor([[0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.]])
(x + pe[:x.size(0)])[0]
tensor([[ 0.0599, 2.7412, 0.9173, 2.3884, 0.6826, 0.3170, 0.8005, 0.5684, 0.1788, 1.0446, -1.1409, 0.1219, 0.8874, 1.7887, 0.3160, -0.0217]])
pe = pe.unsqueeze(0)
pe.shape
torch.Size([1, 10, 16])
x = torch.randn(1, 10, 16)
x.shape
torch.Size([1, 10, 16])
(x + (pe[:, :x.size(1), :])).shape
torch.Size([1, 10, 16])
(x + pe).shape
torch.Size([1, 10, 16])
pe[:, 1].shape
torch.Size([1, 16])
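# Using the module directly (dropout set to 0 so the output is deterministic;
# a small sanity check added here rather than part of the original walkthrough).
pe_module = PositionalEncoding(16, dropout=0.0)
pe_module(torch.zeros(1, 10, 16))[0, 0]
tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])
pe_module(torch.randn(1, 10, 16)).shape
torch.Size([1, 10, 16])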
def sequence_mask(lengths, max_len=None):
"""
Creates a boolean mask from sequence lengths.
"""
batch_size = lengths.numel()
max_len = max_len or lengths.max()
return (torch.arange(0, max_len, device=lengths.device)
.type_as(lengths)
.repeat(batch_size, 1)
.lt(lengths.unsqueeze(1)))
# (batch_size, )
# each element is the length of one sequence
ts = torch.Tensor([5, 4, 3, 2])
ts.shape
torch.Size([4])
mask = ~sequence_mask(ts).unsqueeze(1)
mask.shape
torch.Size([4, 1, 5])
mask
tensor([[[False, False, False, False, False]],

        [[False, False, False, False,  True]],

        [[False, False, False,  True,  True]],

        [[False, False,  True,  True,  True]]])
# (batch_size, 1, 1, seq_len)
mask = mask.unsqueeze(1) # [B, 1, 1, T_values]
mask.shape
torch.Size([4, 1, 1, 5])
# (batch_size, head_count, query_len, key_len)
score = torch.randint(1, 10, (4, 8, 5, 5))
score.shape
torch.Size([4, 8, 5, 5])
scores = score.masked_fill(mask, -1e18)
scores.shape
torch.Size([4, 8, 5, 5])
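# The masked (padding) positions get zero attention weight after the softmax
# (checking the last key position of the second sequence, which is padding).
attn_check = torch.softmax(scores.float(), dim=-1)
attn_check[1, 0, 0, -1]
tensor(0.)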
import numpy as np

def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(subsequent_mask) == 0
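# The decoder-side causal mask: True marks positions a query may attend to
# (itself and everything before it).
subsequent_mask(5)
tensor([[[ True, False, False, False, False],
         [ True,  True, False, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True]]])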
sequence_mask(ts)
tensor([[ True,  True,  True,  True,  True],
        [ True,  True,  True,  True, False],
        [ True,  True,  True, False, False],
        [ True,  True, False, False, False]])