Note: this tutorial requires you to be familiar with tiled matmul. A dedicated tutorial is available in the tutorial folder of this repository.
A naive implementation of softmax requires several passes over the whole input vector. Loading the vector from global memory (GM, aka the GPU DRAM) for each pass is by far the main bottleneck, compared to the computation itself.
The softmax Triton tutorial avoids multiple read/write operations on GM by assuming that the whole input vector is small enough to be loaded into shared memory (SRAM).
Below, we describe an approach for when this assumption doesn't hold, i.e. when the vector is too large for SRAM. In the case of a transformer model, softmax is applied to each row of a matrix of shape (sequence length, sequence length), and the SRAM limit for an fp16 vector is around 128 tokens.
We will start the tutorial with a naive approach and optimize it.
import torch
torch.manual_seed(456)
row_count, col_count = 4, 16
long_input_vec: torch.Tensor = torch.rand((row_count, col_count))
To avoid FP16 or FP32 overflow in the softmax computation, it is usual to subtract from the input vector its maximum value.
This operation has no effect on the final output: exp(x_i - m) / sum_j exp(x_j - m) = exp(x_i) / sum_j exp(x_j), so the shift cancels out in the division. It improves numerical stability by reducing the amplitude of the values.
This is sometimes called safe softmax computation.
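As a quick check (not part of the original notebook), we can verify with torch.softmax that subtracting each row's maximum leaves the result unchanged; the tensor x below is an arbitrary example.
x = torch.rand(4, 16)
shifted = x - x.max(dim=1, keepdim=True).values  # subtract each row's max
assert torch.allclose(torch.softmax(x, dim=1), torch.softmax(shifted, dim=1))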
Computing safe softmax in PyTorch requires multiple passes over the whole input vector if done manually:
1. find the row maximum;
2. subtract it, exponentiate to get the numerator, and sum the numerator to get the denominator;
3. divide numerator / denominator.
Note that because of the eager execution model, in PyTorch step 2 requires 2 passes.
# torch softmax as a reference
expected_softmax = torch.softmax(long_input_vec, dim=1)
# 1st read, torch.max outputs both indices and values, we only want the values
# [:, None] turns the result into a column vector so it broadcasts against each row
row_max = torch.max(long_input_vec, dim=1).values[:, None]
print("input row max\n", row_max)
# 2nd read
input_safe = long_input_vec - row_max
print("Below we reduce values amplitude, that's the safe part of safe softmax")
print("original 1st row input:\n", long_input_vec[0, :], "safe softmax input 1st row:\n", input_safe[0, :])
softmax_numerator = torch.exp(input_safe)
# 3rd read
normalizer_term = torch.sum(softmax_numerator, dim=1)[:, None]
# 4th read
naive_softmax = softmax_numerator / normalizer_term
assert torch.allclose(naive_softmax, expected_softmax)
input row max
 tensor([[0.9820],
        [0.8412],
        [0.9198],
        [0.9778]])
Below we reduce values amplitude, that's the safe part of safe softmax
original 1st row input:
 tensor([0.6815, 0.0039, 0.7451, 0.7946, 0.6127, 0.6803, 0.9820, 0.0019, 0.1609, 0.5916, 0.6531, 0.8855, 0.7397, 0.0681, 0.3341, 0.3200])
safe softmax input 1st row:
 tensor([-0.3005, -0.9780, -0.2369, -0.1874, -0.3693, -0.3017, 0.0000, -0.9800, -0.8211, -0.3904, -0.3289, -0.0965, -0.2423, -0.9139, -0.6479, -0.6620])
In their paper Online normalizer calculation for softmax, M. Milakov et al. show an approach which makes parallelization possible by computing softmax progressively.
Basically, we load the input vector in small blocks (sized to fit in SRAM) and compute 2 statistics in a single pass:
- the maximum value seen so far;
- the normalizer term (the softmax denominator), i.e. the running sum of the exponentials.
The difficulty lies in the fact that you are supposed to know the maximum value of the vector before you can compute the denominator. At each step, our knowledge of the maximum may evolve (we may meet a value bigger than our previous maximum). When it happens, we just adjust the result computed in the previous steps.
The adjustment procedure is based on the rules of exponentiation: when multiplying a base raised to one exponent by the same base raised to another exponent, the exponents add. Concretely, exp(x - old_max) * exp(old_max - new_max) = exp(x - new_max), so multiplying the running denominator by exp(old_max - new_max) is as if we had subtracted new_max from the start.
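To make the adjustment concrete, here is a tiny check (not in the original notebook) on a 2-element vector: the denominator is first built with x[0] as the running max, then rescaled when x[1] turns out to be a new max.
x = torch.tensor([1.0, 3.0])
denominator = torch.exp(x[0] - x[0])  # after seeing x[0], the running max is x[0]
denominator = denominator * torch.exp(x[0] - x[1]) + torch.exp(x[1] - x[1])  # x[1] is a new max: rescale, then add its term
assert torch.allclose(denominator, torch.exp(x - x.max()).sum())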
online_softmax = torch.zeros_like(long_input_vec)
for row in range(row_count):
    row_max = float("-inf")  # safe initial value: works even if every input value is negative
    normalizer_term = 0.0
    print("--- new row ---")
    for col in range(col_count):  # scalar level iteration
        val = long_input_vec[row, col]
        old_row_max = row_max
        row_max = max(old_row_max, val)
        # torch.exp(old_row_max - row_max) is the adjustment factor of our previous normalizer term,
        # after this multiplication it's like we had always subtracted row_max up to this point
        # instead of old_row_max
        normalizer_term = normalizer_term * torch.exp(old_row_max - row_max) + torch.exp(val - row_max)
        if old_row_max != row_max:
            print("new max discovered")
        print(f"current row max: {row_max}, denominator: {normalizer_term}")

    # leverage our 2 statistics
    online_softmax[row, :] = torch.exp(long_input_vec[row, :] - row_max) / normalizer_term
assert torch.allclose(online_softmax, expected_softmax)
--- new row ---
new max discovered
current row max: 0.6815125346183777, denominator: 1.0
current row max: 0.6815125346183777, denominator: 1.5078496932983398
new max discovered
current row max: 0.7450968623161316, denominator: 2.4149584770202637
new max discovered
current row max: 0.79459148645401, denominator: 3.2983407974243164
current row max: 0.79459148645401, denominator: 4.132010459899902
current row max: 0.79459148645401, denominator: 5.0240325927734375
new max discovered
current row max: 0.9819886684417725, denominator: 5.165497779846191
current row max: 0.9819886684417725, denominator: 5.540792465209961
current row max: 0.9819886684417725, denominator: 5.9807353019714355
current row max: 0.9819886684417725, denominator: 6.657543182373047
current row max: 0.9819886684417725, denominator: 7.377259731292725
current row max: 0.9819886684417725, denominator: 8.285249710083008
current row max: 0.9819886684417725, denominator: 9.070070266723633
current row max: 0.9819886684417725, denominator: 9.471039772033691
current row max: 0.9819886684417725, denominator: 9.994173049926758
current row max: 0.9819886684417725, denominator: 10.509976387023926
--- new row ---
new max discovered
current row max: 0.3628944754600525, denominator: 1.0
current row max: 0.3628944754600525, denominator: 1.796286702156067
new max discovered
current row max: 0.44334477186203003, denominator: 2.6574349403381348
new max discovered
current row max: 0.5694260597229004, denominator: 3.3426434993743896
current row max: 0.5694260597229004, denominator: 4.307759761810303
new max discovered
current row max: 0.8411758542060852, denominator: 4.282706260681152
current row max: 0.8411758542060852, denominator: 5.092153549194336
current row max: 0.8411758542060852, denominator: 5.754087924957275
current row max: 0.8411758542060852, denominator: 6.385719299316406
current row max: 0.8411758542060852, denominator: 7.075372695922852
current row max: 0.8411758542060852, denominator: 7.718149185180664
current row max: 0.8411758542060852, denominator: 8.450255393981934
current row max: 0.8411758542060852, denominator: 9.37951946258545
current row max: 0.8411758542060852, denominator: 9.812650680541992
current row max: 0.8411758542060852, denominator: 10.249856948852539
current row max: 0.8411758542060852, denominator: 11.185232162475586
--- new row ---
new max discovered
current row max: 0.9197819828987122, denominator: 1.0
current row max: 0.9197819828987122, denominator: 1.8743796348571777
current row max: 0.9197819828987122, denominator: 2.4121508598327637
current row max: 0.9197819828987122, denominator: 3.1733081340789795
current row max: 0.9197819828987122, denominator: 3.648242712020874
current row max: 0.9197819828987122, denominator: 4.4900665283203125
current row max: 0.9197819828987122, denominator: 5.0027642250061035
current row max: 0.9197819828987122, denominator: 5.762831687927246
current row max: 0.9197819828987122, denominator: 6.3089094161987305
current row max: 0.9197819828987122, denominator: 6.796399116516113
current row max: 0.9197819828987122, denominator: 7.307489395141602
current row max: 0.9197819828987122, denominator: 8.28607177734375
current row max: 0.9197819828987122, denominator: 8.744580268859863
current row max: 0.9197819828987122, denominator: 9.38587760925293
current row max: 0.9197819828987122, denominator: 9.824188232421875
current row max: 0.9197819828987122, denominator: 10.793715476989746
--- new row ---
new max discovered
current row max: 0.177534282207489, denominator: 1.0
new max discovered
current row max: 0.9202759861946106, denominator: 1.4758076667785645
current row max: 0.9202759861946106, denominator: 2.0623040199279785
current row max: 0.9202759861946106, denominator: 2.7364466190338135
new max discovered
current row max: 0.9371026754379272, denominator: 3.690786600112915
current row max: 0.9371026754379272, denominator: 4.633510112762451
current row max: 0.9371026754379272, denominator: 5.228850841522217
current row max: 0.9371026754379272, denominator: 5.776777744293213
current row max: 0.9371026754379272, denominator: 6.281983852386475
current row max: 0.9371026754379272, denominator: 6.7736921310424805
current row max: 0.9371026754379272, denominator: 7.39810848236084
current row max: 0.9371026754379272, denominator: 8.079381942749023
current row max: 0.9371026754379272, denominator: 8.852633476257324
new max discovered
current row max: 0.9778261780738831, denominator: 9.49936580657959
current row max: 0.9778261780738831, denominator: 9.926789283752441
current row max: 0.9778261780738831, denominator: 10.47911262512207
Instead of working on scalars, we may prefer, for performance reasons, to work on blocks of the vector as big as the GPU SRAM can load. For that purpose, the code above needs very small modifications, something we will see in the Flash attention notebook.
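To give an idea of how small those modifications are, here is a minimal sketch (not the Flash attention notebook's code) of the same online algorithm applied block by block; the block size of 4 is an arbitrary stand-in for whatever fits in SRAM.
block_size = 4  # arbitrary stand-in for the largest block that fits in SRAM
blockwise_softmax = torch.zeros_like(long_input_vec)

for row in range(row_count):
    row_max = torch.tensor(float("-inf"))
    normalizer_term = torch.tensor(0.0)
    for block_start in range(0, col_count, block_size):
        block = long_input_vec[row, block_start:block_start + block_size]
        old_row_max = row_max
        row_max = torch.maximum(old_row_max, block.max())
        # same rescaling as in the scalar version, applied once per block instead of once per element
        normalizer_term = normalizer_term * torch.exp(old_row_max - row_max) + torch.exp(block - row_max).sum()
    # leverage the 2 statistics, exactly as before
    blockwise_softmax[row, :] = torch.exp(long_input_vec[row, :] - row_max) / normalizer_term

assert torch.allclose(blockwise_softmax, expected_softmax)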