In this tutorial we develop a declarative node in PyTorch for solving the ordinary least squares problem,
$$ \text{minimize}_{x} \quad \|Ax - b\|_2^2 $$parameterized by the $m$-by-$n$ matrix $A$ and $m$-vector $b$. The solution has the closed form
$$ \begin{align*} x^\star &= \left(A^T \! A\right)^{-1} \! A^T b \\ &= R^{-1} Q^{T} b \end{align*} $$where $A = QR$ is the QR decomposition of $A$.
import torch
print(torch.__version__)
print(torch.cuda.get_device_name() if torch.cuda.is_available() else "No CUDA")
torch.manual_seed(22)
import sys
sys.path.append("../")
1.13.0+cu117
GeForce RTX 2080 Ti
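As a quick sanity check (a minimal sketch with an arbitrarily chosen problem size), we can verify that both closed-form expressions agree with PyTorch's built-in least-squares solver:
# sanity check: both closed-form expressions match torch.linalg.lstsq
# (problem size chosen arbitrarily for illustration)
A = torch.randn(6, 4, dtype=torch.double)
b = torch.randn(6, 1, dtype=torch.double)
x_normal = torch.linalg.inv(A.T @ A) @ (A.T @ b)               # (A^T A)^{-1} A^T b
Q, R = torch.linalg.qr(A, mode='reduced')
x_qr = torch.linalg.solve_triangular(R, Q.T @ b, upper=True)   # R^{-1} Q^T b
x_ref = torch.linalg.lstsq(A, b).solution
print(torch.allclose(x_normal, x_ref), torch.allclose(x_qr, x_ref))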
Since the QR decomposition is differentiable, we can simply write a PyTorch function that solves the least-squares problem in the forward pass and let autograd compute the gradients automatically in the backward pass.
def solve_by_qr(A, b):
    """Auto-differentiable solver for least squares by QR."""
    B, M, N = A.shape
    assert b.shape == (B, M, 1)
    Q, R = torch.linalg.qr(A, mode='reduced')
    return torch.linalg.solve_triangular(R, torch.bmm(b.view(B, 1, M), Q).view(B, N, 1), upper=True)
Indeed, PyTorch's torch.linalg.lstsq function is itself auto-differentiable.
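We can confirm that autograd produces correct gradients through solve_by_qr using torch.autograd.gradcheck on a small double-precision problem (a minimal check; the batch and problem sizes are arbitrary):
# gradient check for the autograd QR solver on a small double-precision problem
A_test = torch.randn(2, 6, 4, dtype=torch.double, requires_grad=True)
b_test = torch.randn(2, 6, 1, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(solve_by_qr, (A_test, b_test)))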
We can develop a more efficient backward pass by deriving the gradients by hand.
Consider differentiating $x^\star$ with respect to a single element $A_{ij}$. We have
$$ \begin{align*} \frac{\text{d}}{\text{d} A_{ij}} x^\star &= \frac{\text{d}}{\text{d} A_{ij}} \left(A^T \! A\right)^{-1} \! A^T b \\ &= \left(\frac{\text{d}}{\text{d} A_{ij}} \left(A^T \! A\right)^{-1} \right) A^T b + \left(A^T \! A\right)^{-1} \! \left(\frac{\text{d}}{\text{d} A_{ij}} A^T b\right) \end{align*} $$Now using the identity $\frac{\text{d}}{\text{d} z} Z^{-1} = -Z^{-1} \left(\frac{\text{d}}{\text{d} z} Z\right) Z^{-1}$ we get, for the first term,
$$ \begin{align*} \frac{\text{d}}{\text{d} A_{ij}} \left(A^T \! A\right)^{-1} &= -\left(A^T \! A\right)^{-1} \!\left(\frac{\text{d}}{\text{d} A_{ij}} \left(A^T \! A\right) \right) \left(A^T \! A\right)^{-1} \\ &= -\left(A^T \! A\right)^{-1} \!\left(E_{ij}^T A + A^T E_{ij}\right) \left(A^T \! A\right)^{-1} \end{align*} $$where $E_{ij}$ is a matrix with one in the $(i,j)$-th element and zeros elsewhere. Furthermore, for the second term,
$$ \begin{align*} \frac{\text{d}}{\text{d} A_{ij}} A^T b &= E_{ij}^T b \end{align*} $$Plugging these back into the product-rule expansion above, we have
$$ \begin{align*} \frac{\text{d}}{\text{d} A_{ij}} x^\star &= -\left(A^T \! A\right)^{-1} \!\left(E_{ij}^T A + A^T E_{ij}\right) \left(A^T \! A\right)^{-1} \! A^T b + \left(A^T \! A\right)^{-1} E_{ij}^T b \\ &= -\left(A^T \! A\right)^{-1} \!\left(E_{ij}^T A + A^T E_{ij}\right) x^\star + \left(A^T \! A\right)^{-1} E_{ij}^T b \\ &= -\left(A^T \! A\right)^{-1} \left( E_{ij}^T (Ax^\star - b) + A^T E_{ij} x^\star \right) \\ &= -\left(A^T \! A\right)^{-1} \left( (a_i^T x^\star - b_i) e_j + x_j^\star a_i \right) \end{align*} $$where $e_j = (0, 0, \ldots, 1, 0, \ldots) \in \mathbb{R}^n$ is the $j$-th canonical vector, i.e., vector with a one in the $j$-th component and zeros everywhere else, and $a_i^T \in \mathbb{R}^{1 \times n}$ is the $i$-th row of matrix $A$.
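Before exploiting this expression, it is worth checking it numerically. The following sketch (with arbitrary sizes, data, and an arbitrarily chosen index pair) compares the formula against a central finite-difference approximation of $\frac{\text{d}}{\text{d} A_{ij}} x^\star$:
# numerical check of the derivative formula for a single element A_ij
# (sizes, data and the index pair (i, j) are arbitrary)
def lstsq_closed_form(A, b):
    return torch.linalg.inv(A.T @ A) @ (A.T @ b)

A = torch.randn(6, 4, dtype=torch.double)
b = torch.randn(6, 1, dtype=torch.double)
x_star = lstsq_closed_form(A, b)

i, j = 2, 1
e_j = torch.zeros(4, 1, dtype=torch.double); e_j[j] = 1.0
a_i = A[i].view(4, 1)
analytic = -torch.linalg.inv(A.T @ A) @ ((A[i] @ x_star - b[i]) * e_j + x_star[j] * a_i)

eps = 1.0e-6
A_plus, A_minus = A.clone(), A.clone()
A_plus[i, j] += eps; A_minus[i, j] -= eps
numeric = (lstsq_closed_form(A_plus, b) - lstsq_closed_form(A_minus, b)) / (2.0 * eps)
print(torch.allclose(analytic, numeric, atol=1.0e-6))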
Observe that the term $(A^TA)^{-1}$ appears both in the solution for $x^\star$ and in the derivative with respect to each $A_{ij}$. Moreover, since $Q$ has orthonormal columns, $A^TA = R^T Q^T Q R = R^T R$, so multiplying by $(A^TA)^{-1}$ amounts to a pair of triangular solves against $R$. Thus, we only need to factor $A$ once in the forward pass and cache $R$ for use during the backward pass, which saves significant compute. We can also reuse terms across the different $\frac{\text{d}}{\text{d} A_{ij}}$ and combine them efficiently with the incoming gradient of the loss with respect to $x^\star$, as we now show.
Let $r = b - Ax^\star$ and let $v^T$ denote the incoming gradient $\frac{\text{d} L}{\text{d} x^\star}$ from the backward pass. Then
$$ \begin{align*} \frac{\text{d} L}{\text{d} A_{ij}} &= v^T \frac{\text{d} x^\star}{\text{d} A_{ij}} \\ &= v^T \left(A^T \! A\right)^{-1} \left( r_i e_j - x_j^\star a_i \right) \\ &= w^T \! \left( r_i e_j - x_j^\star a_i \right) \\ &= r_i w_j - w^T \! a_i x_j^\star \end{align*} $$where $w = \left(A^T \! A\right)^{-1}v$ does not depend on which $A_{ij}$ we are differentiating with respect to. We can therefore compute the entire matrix of $m \times n$ derivatives efficiently as a difference of two outer products
$$ \begin{align*} \left(\frac{\text{d} L}{\text{d} A}\right)^{\! T} = \begin{bmatrix} \displaystyle \frac{\text{d} L}{\text{d} A_{ij}} \end{bmatrix}_{\substack{i=1, \ldots, m \\ j=1, \ldots, n}} &= rw^T - (Aw)(x^\star)^T \end{align*} $$We can also differentiate with respect to $b$, which is much simpler,
$$ \begin{align*} \left(\frac{\text{d} L}{\text{d} b}\right)^{\! T} = Aw \end{align*} $$The following code implements a PyTorch function with a hand-coded backward pass. It is also available as BasicLeastSquaresFcn in the ddn.pytorch.leastsquares module.
class BasicLeastSquaresFcn(torch.autograd.Function):
    r"""
    PyTorch autograd function for basic least squares problems,
        y = argmin_u \|Au - b\|
    solved via QR decomposition.
    """

    @staticmethod
    def forward(ctx, A, b, cache_decomposition=True):
        B, M, N = A.shape
        assert b.shape == (B, M, 1)

        with torch.no_grad():
            Q, R = torch.linalg.qr(A, mode='reduced')
            x = torch.linalg.solve_triangular(R, torch.bmm(b.view(B, 1, M), Q).view(B, N, 1), upper=True)

        # save state for backward pass
        ctx.save_for_backward(A, b, x, R if cache_decomposition else None)

        # return solution
        return x

    @staticmethod
    def backward(ctx, dx):
        # check for None tensors
        if dx is None:
            return None, None, None

        # unpack cached tensors
        A, b, x, R = ctx.saved_tensors
        B, M, N = A.shape

        if R is None:
            # recompute the factorization; Q is an empty tensor when mode='r'
            _, R = torch.linalg.qr(A, mode='r')

        dA, db = None, None
        w = torch.linalg.solve_triangular(R, torch.linalg.solve_triangular(torch.transpose(R, 2, 1), dx, upper=False), upper=True)
        Aw = torch.bmm(A, w)
        if ctx.needs_input_grad[0]:
            r = b - torch.bmm(A, x)
            dA = torch.bmm(r.view(B, M, 1), w.view(B, 1, N)) - torch.bmm(Aw.view(B, M, 1), x.view(B, 1, N))
        if ctx.needs_input_grad[1]:
            db = Aw

        # return gradients (None for non-tensor inputs)
        return dA, db, None
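As a quick consistency check (a minimal sketch with arbitrary batch and problem sizes), the hand-coded backward pass should produce the same gradients as backpropagating through the autograd QR solver above:
# check that the hand-coded backward pass agrees with autograd through solve_by_qr
# (batch size and problem size are arbitrary)
A1 = torch.randn(4, 10, 5, dtype=torch.double, requires_grad=True)
b1 = torch.randn(4, 10, 1, dtype=torch.double, requires_grad=True)
A2 = A1.detach().clone().requires_grad_(True)
b2 = b1.detach().clone().requires_grad_(True)
solve_by_qr(A1, b1).square().sum().backward()
BasicLeastSquaresFcn.apply(A2, b2).square().sum().backward()
print(torch.allclose(A1.grad, A2.grad, atol=1.0e-6), torch.allclose(b1.grad, b2.grad, atol=1.0e-6))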
We profile the running time and memory required by the various implementations. Our test is to solve the bi-level optimisation problem,
$$ \begin{align*} \begin{array}{ll} \text{minimize} & \frac{1}{2} \|x^\star - x^{\text{target}}\|_2^2 \\ \text{subject to} & x^\star = \textrm{argmin}_{x} \; \|Ax - b\|_2^2 \end{array} \end{align*} $$by gradient descent for a fixed number of iterations.
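Before profiling, here is the bi-level loop in its simplest form (a minimal sketch, separate from the profiling harness below; the sizes, learning rate and iteration count are arbitrary):
# minimal bi-level gradient descent sketch: fit A so that the least-squares
# solution x* matches a target (sizes and hyper-parameters are arbitrary)
B, M, N = 2, 20, 10
b = torch.randn(B, M, 1)
x_target = torch.linalg.lstsq(torch.randn(B, M, N), b).solution
A = torch.randn(B, M, N, requires_grad=True)
optimizer = torch.optim.SGD([A], lr=1.0e-2)
for step in range(100):
    x_star = BasicLeastSquaresFcn.apply(A, b)         # inner problem: argmin_x ||Ax - b||^2
    loss = 0.5 * (x_star - x_target).square().sum()   # outer objective
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
print(loss.item())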
import time
import torch.profiler as profiler
def do_profile(fcn, B, M, N, device=torch.device("cpu"), seed=0, iters=1000, mem_iters=10):
    """
    Profile an implementation. Uses a different number of iterations for memory
    profiling and speed profiling since memory profiling is slow.
    """
    torch.manual_seed(seed)
    A_true = torch.randn((B, M, N), device=device, requires_grad=False)
    b_true = torch.randn((B, M, 1), device=device, requires_grad=False)
    x_true = torch.linalg.lstsq(A_true, b_true, driver='gels').solution
    A_init = torch.randn((B, M, N), device=device, requires_grad=True)

    # profile time
    model = [torch.nn.Parameter(A_init.clone())]
    optimizer = torch.optim.AdamW(model, lr=1.0e-3)
    loss_trace = [None for i in range(iters)]
    time_start = time.monotonic()
    for i in range(len(loss_trace)):
        x = fcn(model[0], b_true)
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(x, x_true)
        loss_trace[i] = loss.item()
        loss.backward()
        optimizer.step()
    duration = time.monotonic() - time_start

    # profile memory
    model = [torch.nn.Parameter(A_init.clone())]
    optimizer = torch.optim.AdamW(model, lr=1.0e-3)
    with profiler.profile(profile_memory=True) as prof:
        for i in range(mem_iters):
            x = fcn(model[0], b_true)
            optimizer.zero_grad()
            loss = torch.nn.functional.mse_loss(x, x_true)
            loss.backward()
            optimizer.step()
    if device == torch.device("cpu"):
        memory = prof.key_averages().total_average().cpu_memory_usage
    else:
        memory = prof.key_averages().total_average().cuda_memory_usage
    prof.enabled = False

    return duration, memory, loss_trace
import matplotlib.pyplot as plt
def plot_profiling(device=torch.device("cpu")):
    # implementations to profile
    if device == torch.device("cpu"):
        f_array = [
            lambda A, b: torch.linalg.lstsq(A, b, driver='gels').solution, # auto-diff torch least-squares solver
            solve_by_qr,                                          # auto-diff solve by QR decomposition
            BasicLeastSquaresFcn.apply,                           # custom diff cache QR decomposition
            lambda A, b: BasicLeastSquaresFcn.apply(A, b, False)  # custom diff recompute QR decomposition
        ]
        f_legend = ['torch.lstsq', 'autograd QR', 'custom QR', 'custom QR (no cache)']
        styles = ['--', ':', '-', '-.']
        colours = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']
        batch_size = 1
        iters = 1000
    else:
        f_array = [
            solve_by_qr,                                          # auto-diff solve by QR decomposition
            BasicLeastSquaresFcn.apply,                           # custom diff cache QR decomposition
            lambda A, b: BasicLeastSquaresFcn.apply(A, b, False)  # custom diff recompute QR decomposition
        ]
        f_legend = ['autograd QR', 'custom QR', 'custom QR (no cache)']
        styles = [':', '-', '-.']
        colours = ['#ff7f0e', '#2ca02c', '#d62728']
        batch_size = 128
        iters = 100

    runtime = [[] for i in range(len(f_array))]
    memused = [[] for i in range(len(f_array))]
    endloss = [[] for i in range(len(f_array))]

    N_array = [100, 200, 300, 400, 500]
    for N in N_array:
        print("...running on size {}-by-{}".format(2*N, N))
        for i in range(len(f_array)):
            t, m, loss = do_profile(f_array[i], batch_size, 2 * N, N, device=device, iters=iters)
            runtime[i].append(t)
            memused[i].append(m)
            endloss[i].append(loss[-1])

    for i in range(len(f_array)):
        print("time: {}\n mem: {}\nloss: {}".format(runtime[i], memused[i], endloss[i]))

    plt.figure()
    for i in range(len(f_array)):
        plt.plot(N_array, runtime[i], linestyle=styles[i], color=colours[i])
    plt.xlabel('problem size (2n-by-n)', fontsize=12)
    plt.ylabel('running time (seconds)', fontsize=12)
    plt.title('Running time on {} with batch size {}'.format(device, batch_size))
    plt.legend(f_legend, fontsize=12)
    #plt.savefig("ls_runtime_{}.png".format(device), dpi=300, bbox_inches='tight')

    plt.figure()
    for i in range(len(f_array)):
        plt.plot(N_array, [m / (1024 * 1024) for m in memused[i]], linestyle=styles[i], color=colours[i])
    plt.xlabel('problem size (2n-by-n)', fontsize=12)
    plt.ylabel('memory used (MB)', fontsize=12)
    plt.title('Memory usage on {} with batch size {}'.format(device, batch_size))
    #plt.savefig("ls_memory_{}.png".format(device), dpi=300, bbox_inches='tight')
# profile on cpu
plot_profiling(torch.device("cpu"))
plt.show()
...running on size 200-by-100
...running on size 400-by-200
...running on size 600-by-300
...running on size 800-by-400
...running on size 1000-by-500
time: [1.5934700816869736, 4.518609205260873, 10.245868610218167, 18.57739918678999, 31.83123323507607]
 mem: [34153388, 133212204, 296661131, 546878118, 828188972]
loss: [1.1331253296775756e-10, 3.399825559968761e-11, 1.6916133424582647e-11, 1.1065891258876803e-11, 9.069364183267314e-12]
time: [0.7251121029257774, 1.3898825701326132, 2.6002293042838573, 4.631701299920678, 7.616993023082614]
 mem: [13947592, 57559900, 130030188, 231495456, 354964332]
loss: [1.1323958437614579e-10, 3.401116194234888e-11, 1.6925034290737884e-11, 1.1048066107799404e-11, 9.078602453138629e-12]
time: [0.6337910871952772, 1.0285055814310908, 1.8827926190569997, 2.957524019293487, 4.728598020039499]
 mem: [8488908, 33470964, 72097788, 129878496, 194785900]
loss: [1.1323725984668798e-10, 3.395194195232598e-11, 1.6915352799018457e-11, 1.1050370687937239e-11, 9.076939720686905e-12]
time: [0.7863148981705308, 1.3768486883491278, 2.5757922334596515, 4.334589009173214, 7.4314648024737835]
 mem: [12770412, 47484968, 111920008, 197073776, 298955156]
loss: [1.1323725984668798e-10, 3.395194195232598e-11, 1.6915352799018457e-11, 1.1050370687937239e-11, 9.076939720686905e-12]
# profile on gpu
if torch.cuda.is_available():
plot_profiling(torch.device("cuda"))
plt.show()
...running on size 200-by-100
...running on size 400-by-200
...running on size 600-by-300
...running on size 800-by-400
...running on size 1000-by-500
time: [1.8819324253126979, 11.011459725908935, 19.37903497274965, 28.165275018662214, 39.22133521549404]
 mem: [1802263552, 7158280192, 15085527552, 26864092160, 41835319296]
loss: [0.004872623365372419, 0.0023604189045727253, 0.0018237746553495526, 0.0016208530869334936, 0.0014888780424371362]
time: [1.7105968864634633, 10.76188808772713, 18.692402061074972, 26.319541100412607, 36.21370824240148]
 mem: [963711488, 4181243904, 7467657216, 13042922496, 20332475392]
loss: [0.004872623365372419, 0.0023604189045727253, 0.0018237746553495526, 0.0016208532033488154, 0.0014888780424371362]
time: [2.575088639743626, 19.450912252999842, 34.05344322603196, 47.81562957260758, 66.19001938961446]
 mem: [1331968512, 5921859072, 11095494144, 19711094784, 31298075648]
loss: [0.004872623365372419, 0.0023604189045727253, 0.0018237746553495526, 0.0016208532033488154, 0.0014888780424371362]