This tutorial notebook applies the Deep Declarative Networks (DDNs) framework to the implicit differentiation of Eigen decomposition. For a square matrix $\mathbf{x} \in \mathbb{R}^{m \times m}$, the decomposition can be formulated as the mapping $\mathbf{x} \in \mathbb{R}^{m \times m} \rightarrow ( \boldsymbol{\lambda} \in \mathbb{R}^{n}, \mathbf{u} \in \mathbb{R}^{m \times n})$, where $\boldsymbol{\lambda}$ is a vector of $n$ Eigenvalues and the columns of $\mathbf{u}$ are the corresponding Eigenvectors, obtained by solving
\begin{equation} \begin{aligned} \min_{\mathbf{u} \in \mathbb{R}^{m \times n}} f(\mathbf{x}, \mathbf{u}) &= -\text{tr} \left( \mathbf{u}^T \mathbf{x} \mathbf{u} \right)\ ,\\ \text{subject to} \quad h \left( \mathbf{u} \right) &= \mathbf{u}^T \mathbf{u} = \mathbf{I}_n\ , \end{aligned} \end{equation}where $\text{tr}(\cdot)$ is the trace (the sum of the diagonal elements) and $\mathbf{I}_n$ is the $n \times n$ identity matrix. The optimal solution is
\begin{equation} \mathbf{y} = \text{argmin}_{\mathbf{u} \in \mathbb{R}^{m \times n}} f(\mathbf{x}, \mathbf{u}) \end{equation}satisfying
\begin{equation} \mathbf{x} \mathbf{y}_i = \lambda_i \mathbf{y}_i, \forall i \in \mathcal{N} = \{1, ..., n\}\ . \end{equation}For the greatest Eigenvalue, one can use the Power Iteration algorithm. More generally, we consider the Simultaneous Iteration algorithm with QR decomposition to obtain multiple Eigenvalues.
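To make the two forward solvers concrete, here is a minimal NumPy sketch of both iterations. It is not the notebook's `PowerIteration` / `SimultaneousIteration` classes; it assumes a symmetric $\mathbf{x}$ whose top Eigenvalues are positive and well separated, and it runs a fixed number of iterations without a stopping condition. The function names and the random start vectors are illustrative choices.

```python
import numpy as np

def power_iteration(x, num_iters=500, seed=0):
    """Greatest Eigenpair of x by repeated multiplication and normalization."""
    rng = np.random.default_rng(seed)
    u = rng.standard_normal(x.shape[0])
    u /= np.linalg.norm(u)
    for _ in range(num_iters):
        v = x @ u
        u = v / np.linalg.norm(v)          # u_{k+1} = x u_k / ||x u_k||
    return u @ x @ u, u                    # Rayleigh quotient and Eigenvector

def simultaneous_iteration(x, n, num_iters=500, seed=0):
    """Top-n Eigenpairs of a symmetric x: multiply a block of n vectors, re-orthonormalize with QR."""
    rng = np.random.default_rng(seed)
    u = rng.standard_normal((x.shape[0], n))
    for _ in range(num_iters):
        u, _ = np.linalg.qr(x @ u)         # reduced QR keeps the n columns orthonormal
    return np.diag(u.T @ x @ u), u         # one Rayleigh quotient per column

# Symmetric test matrix with a known, well-separated spectrum
rng = np.random.default_rng(0)
q, _ = np.linalg.qr(rng.standard_normal((6, 6)))
x = q @ np.diag([10.0, 6.0, 3.0, 2.0, 1.0, 0.5]) @ q.T

lam_pi, y_pi = power_iteration(x)
lam_si, y_si = simultaneous_iteration(x, n=2)
print(lam_pi, lam_si)  # compare with np.linalg.eigh(x)[0][::-1][:2]
```

The QR step is what distinguishes Simultaneous Iteration from running $n$ independent power iterations: without re-orthonormalization, every column would collapse onto the dominant Eigenvector.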
[WARNING] Each Eigenvector is determined only up to its sign: if $\mathbf{y}_i$ is a solution, so is $-\mathbf{y}_i$. We therefore fix the direction with a cosine-similarity-based rule, so that the gradient of the Eigenvector does not switch between these two directions but stays with one of them.
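To measure the similarity of directions, we use $\mathbf{y} \in \mathbb{R}^{m \times n}$ to denote the optimal solution and $\mathbf{r}$ to denote either a one-hot reference vector in $\mathbb{R}^{m}$ (the entry with value 1 indicates the reference direction) or a reference matrix in $\mathbb{R}^{m \times n}$. The similarity vector $\mathbf{s} \in \mathbb{R}^{n}$, with one entry for each of the $n$ Eigenvectors, is $\mathbf{s} = \mathbf{r}^T \mathbf{y}$ for a reference vector and $\mathbf{s} = \text{diagonal}(\mathbf{r}^T \mathbf{y})$ for a reference matrix, where $\text{diagonal}(\cdot)$ extracts the diagonal elements of a matrix as a vector. Each $\mathbf{s}_i$ has the same sign as the cosine similarity between $\mathbf{y}_i$ and its reference.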
Next, one chooses either $\mathbf{s}_i > 0$ or $\mathbf{s}_i < 0$ as the sign condition. If $\mathbf{y}_i$ violates this condition, it is replaced by $-\mathbf{y}_i$, for all $i \in \mathcal{N}$. The implementation can be found in the function uniform_solution_direction().
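The sign rule can be sketched in a few lines of NumPy. This is a hypothetical stand-in named `uniform_direction_sketch`, not the notebook's `uniform_solution_direction()`, whose exact signature is not shown here; it assumes Eigenvectors are stored as columns and that an exactly zero similarity is left unchanged.

```python
import numpy as np

def uniform_direction_sketch(y, r, mode="positive"):
    """Flip Eigenvector columns so each similarity s_i with the reference has the chosen sign.

    y: (m, n) Eigenvectors as columns.
    r: (m,) one-hot reference vector (shared by all columns) or (m, n) reference matrix.
    """
    r = r.reshape(-1, 1) if r.ndim == 1 else r
    s = np.sum(r * y, axis=0)                           # s_i = r_i^T y_i, same sign as the cosine similarity
    violated = s < 0 if mode == "positive" else s > 0   # columns that break the sign condition
    return np.where(violated, -y, y)                    # y_i <- -y_i where violated (broadcast over rows)

rng = np.random.default_rng(1)
y, _ = np.linalg.qr(rng.standard_normal((5, 3)))       # three orthonormal columns
r = np.eye(5)[0]                                       # one-hot reference on the first coordinate
flipped = y * np.array([1.0, -1.0, -1.0])              # same Eigenvectors with some signs reversed
print(np.allclose(uniform_direction_sketch(y, r), uniform_direction_sketch(flipped, r)))
```

Because both sign variants of an Eigenvector map to the same output, any downstream loss (and its gradient) sees a single, consistent direction.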
Differentiating $\mathbf{y}$ with respect to $\mathbf{x}$ follows Eq. (24) of Deep Declarative Networks (DDNs), which gives
\begin{equation} D_{X} \mathbf{y} = H^{-1} A^{T} \left( A H^{-1} A^T \right)^{-1} \left( A H^{-1} B \right) - H^{-1} B \in \mathbb{R}^{(m \times n) \times (m \times m)}\ , \end{equation}where
\begin{equation} \begin{aligned} A &= D_{Y} h(\mathbf{y}) \in \mathbb{R}^{n \times (m \times n)}\ ,\\ B &= D^2_{XY} f(\mathbf{x}, \mathbf{y}) \in \mathbb{R}^{(m \times n) \times (m \times m)}\ ,\\ H &= D^2_{YY} f(\mathbf{x}, \mathbf{y}) + \boldsymbol{\lambda}^T D^2_{YY} h(\mathbf{y}) \in \mathbb{R}^{(m \times n) \times (m \times n)}\ . \end{aligned} \end{equation}[ATTENTION] In Eq. (16) of DDNs, the Lagrangian subtracts the constraint term from the objective function ("-"). This sign does not change the problem itself, but it does change $H$: the Lagrange multipliers calculated by Eq. (25) are the negatives of the Eigenvalues $\boldsymbol{\lambda}$. Hence, we use "+" instead of "-" in the calculation of $H$. Equivalently, one can keep "-" in the calculation of $H$ and set $\boldsymbol{\lambda}=-\boldsymbol{\lambda}$.
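The sign flip can be checked numerically. The sketch below assumes the Lagrangian $f - \boldsymbol{\nu}^T h$ with only the normalization constraints $\mathbf{y}_i^T \mathbf{y}_i = 1$ (matching the shape of $A$), uses $D_Y f(:, i) = -(\mathbf{x} + \mathbf{x}^T)\mathbf{y}(:, i)$ and $D_Y h_i = 2\mathbf{y}(:, i)$, and solves the stationarity condition column by column in the least-squares sense. That this reproduces Eq. (25) of DDNs exactly is our assumption; the helper name `lagrange_multipliers` is hypothetical.

```python
import numpy as np

def lagrange_multipliers(x, y):
    """Per-column least-squares nu_i from the stationarity condition D_Y f(:, i) = nu_i D_Y h_i."""
    grad_f = -(x + x.T) @ y        # column i: D_Y f(:, i) = -(x + x^T) y_i
    grad_h = 2.0 * y               # column i: D_Y h_i = 2 y_i (constraint y_i^T y_i = 1)
    return np.sum(grad_h * grad_f, axis=0) / np.sum(grad_h * grad_h, axis=0)

rng = np.random.default_rng(2)
a = rng.standard_normal((6, 6))
x = a + a.T                        # symmetric input
w, v = np.linalg.eigh(x)
lam, y = w[-2:], v[:, -2:]         # two Eigenpairs
nu = lagrange_multipliers(x, y)
print(nu, -lam)                    # the multipliers are the negated Eigenvalues

# H_i built with "-" and nu_i equals H_i built with "+" and lambda_i
m = x.shape[0]
h_minus = -(x + x.T) - nu[0] * 2.0 * np.eye(m)
h_plus = -(x + x.T) + lam[0] * 2.0 * np.eye(m)
print(np.allclose(h_minus, h_plus))
```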
For an efficient implementation, we first initialize these matrices with all zeros and then assign values by exploiting their matrix structures as
\begin{equation} \begin{aligned} &\{ D_Y f(\mathbf{x}, \mathbf{y})(:, i) = -\left( \mathbf{x} + \mathbf{x}^T \right) \mathbf{y}(:, i), \forall i \in \mathcal{N} \} \in \mathbb{R}^{m \times n}\ ,\\ &\{D^2_{YY} f(\mathbf{x}, \mathbf{y})(:, i, :, i) = -\left( \mathbf{x} + \mathbf{x}^T \right), \forall i \in \mathcal{N} \} \in \mathbb{R}^{m \times n \times m \times n}\ ,\\ &\{A_i = D_{Y} h(\mathbf{y})(i, :, i) = 2 \mathbf{y}(:, i), \forall i \in \mathcal{N}\} \in \mathbb{R}^{n \times m \times n}\ ,\\ &\{D^2_{YY} h(\mathbf{y})(i, :, i, :, i) = 2 \mathbf{I}_m, \forall i \in \mathcal{N}\} \in \mathbb{R}^{n \times m \times n \times m \times n}\ . \end{aligned} \end{equation}However, $D^2_{YY} h(\mathbf{y})$ consumes more memory than the others even though its $i$th block is merely $2\mathbf{I}_m$. Exploiting this structure in the vector-matrix product $\boldsymbol{\lambda}^T D^2_{YY} h(\mathbf{y})$, denoted by $D^2_{YY} h_{\lambda}(\mathbf{y})$ for simplicity, reduces the memory consumption by a factor of $n$. One can implement it as
\begin{equation} \{D^2_{YY} h_{\lambda} (\mathbf{y}) \left( :, i, :, i \right) = 2 \lambda_i \mathbf{I}_m, \forall i \in \mathcal{N}\} \in \mathbb{R}^{m \times n \times m \times n}\ . \end{equation}Then, the first key idea for memory reduction is to avoid explicitly storing $D^2_{YY}f(\mathbf{x}, \mathbf{y})$ and $D^2_{YY}h(\mathbf{y})$ by using
\begin{equation} H_i = H \left(:, i, :, i \right) = -\left( \mathbf{x} + \mathbf{x}^T \right) + 2 \lambda_i \mathbf{I}_m, \forall i \in \mathcal{N}\ . \end{equation}The tensor $D^2_{XY} f(\mathbf{x}, \mathbf{y}) \in \mathbb{R}^{m \times n \times m \times m}$ is filled in two steps for all $i \in \mathcal{N}$ and $j \in \mathcal{M}=\{1, ..., m\}$,
\begin{equation} \begin{aligned} D^2_{XY} f(\mathbf{x}, \mathbf{y})(j,i,:,j) &\mathrel{-}= \mathbf{y}(:,i)\ ,\\ D^2_{XY} f(\mathbf{x}, \mathbf{y})(j,i,j,:) &\mathrel{-}= \mathbf{y}^T(:,i)\ . \end{aligned} \end{equation}We observe that $B=D^2_{XY} f(\mathbf{x}, \mathbf{y})$ consists only of $\mathbf{y}$ and its transpose arranged in a specific structure, so it is feasible to avoid explicitly storing $B$, which can consume a large amount of memory. Before doing so, we highlight the second key idea for memory reduction: loop over the Eigenvalues and sum the per-Eigenvector contributions to $D_X \mathcal{L}(\mathbf{y})=D_Y \mathcal{L}(\mathbf{y}) D_X \mathbf{y}$, where $\mathcal{L}$ is the scalar loss and $\mathbf{y}_i \in \mathbb{R}^m$ is the $i$th Eigenvector, that is
\begin{equation} D_X \mathcal{L}(\mathbf{y}) =\sum_{i \in \mathcal{N}} D_Y \mathcal{L}(\mathbf{y}_i) D_X \mathbf{y}_i\ . \end{equation}In this case, the explicitly stored $A$, $B$, and $H$ shrink by factors of $n^2$, $n$, and $n^2$, respectively. Even so, $B_i \in \mathbb{R}^{m \times (m \times m)}$, for all $i \in \mathcal{N}$, still requires a large amount of memory: for $m=256$ and batch size $64$ in single-precision floating point format, it takes $64 \times 256^3 \times 4$ bytes, that is, 4 GiB.
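The memory figures above are plain arithmetic, sketched below for dense float32 arrays. The choice $n = 8$ is illustrative only (the notebook's experiments use $n = 1$), and the helper `dense_gib` is hypothetical.

```python
import math

def dense_gib(shape, bytes_per_element=4):
    """Size in GiB (2**30 bytes) of a dense array; 4 bytes per element for float32."""
    return math.prod(shape) * bytes_per_element / 2**30

batch, m, n = 64, 256, 8   # n = 8 Eigenvalues is an illustrative choice
full = {"A": (batch, n, m * n), "B": (batch, m * n, m * m), "H": (batch, m * n, m * n)}
per_eig = {"A": (batch, 1, m), "B": (batch, m, m * m), "H": (batch, m, m)}
for name in full:
    ratio = math.prod(full[name]) // math.prod(per_eig[name])
    print(f"{name}: stacked {dense_gib(full[name]):.3f} GiB, "
          f"per Eigenvalue {dense_gib(per_eig[name]):.3f} GiB, factor {ratio}")

# Without B_i, only K_i and y_i (two length-m vectors per sample) are kept
print(f"K_i and y_i: {dense_gib((batch, 2 * m)) * 2**20:.0f} KiB")
```

The per-Eigenvalue $B_i$ alone is 4 GiB, while the two vectors that replace it take 128 KiB for the whole batch.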
Now, as the third key idea for memory reduction, we avoid storing $B$ and instead use $\mathbf{y}$ directly to calculate $D_X \mathcal{L}(\mathbf{y})$. Considering the implicit differentiation of the $i$th Eigenvalue, with index $i$ in all related variables, we define
\begin{equation} D_X \mathcal{L}(\mathbf{y}_i) = D_Y \mathcal{L}(\mathbf{y}_i) D_X \mathbf{y}_i = D_Y \mathcal{L}(\mathbf{y}_i) \left( H_i^{-1} A_i^{T} \left( A_i H_i^{-1} A_i^T \right)^{-1} A_i H_i^{-1} - H_i^{-1} \right) B_i = K_i^T B_i \in \mathbb{R}^{m \times m}\ , \end{equation}where $K_i \in \mathbb{R}^{m}$ is a vector because $\mathcal{L}$ is a scalar, and the row vector $K_i^T B_i \in \mathbb{R}^{1 \times (m \times m)}$ is reshaped into an $m \times m$ matrix. Due to the aforementioned structure of $D^2_{XY} f(\mathbf{x}, \mathbf{y})$, we find that
\begin{equation} D_X \mathcal{L}(\mathbf{y}_i) = -K_i \mathbf{y}_i^T - \mathbf{y}_i K_i^T\ . \end{equation}This greatly reduces the memory requirement, given the large size of $B$. To make this clearer, we take $m=2$, $K_i=[ k_1, k_2 ]^T$, and $\mathbf{y}_i = [ y_1, y_2 ]^T$ as an example. Then,
\begin{equation} B_i = B_i^1 + B_i^2 = \begin{bmatrix} -y_1 & 0 & -y_2 & 0 \\ 0 & -y_1 & 0 & -y_2 \end{bmatrix} + \begin{bmatrix} -y_1 & -y_2 & 0 & 0 \\ 0 & 0 & -y_1 & -y_2 \end{bmatrix} \end{equation}and
\begin{equation} K_i^T B_i = K_i^T B^1_i + K_i^T B^2_i = \begin{bmatrix} -k_1 y_1 & -k_2 y_1 & -k_1 y_2 & -k_2 y_2 \end{bmatrix} + \begin{bmatrix} -k_1 y_1 & -k_1 y_2 & -k_2 y_1 & -k_2 y_2 \end{bmatrix}\ , \end{equation}where, after reshaping to $m \times m$ matrices, $K_i^T B^1_i$ and $K_i^T B^2_i$ equal $-\mathbf{y}_i K_i^T$ and $-K_i \mathbf{y}_i^T$, respectively: they are negated outer products of $K_i$ and $\mathbf{y}_i$ and transposes of each other. The vectors $K_i, \mathbf{y}_i \in \mathbb{R}^{m}$ require far less memory than $B_i \in \mathbb{R}^{m \times (m \times m)}$. The structure of $B$ also makes $D_X \mathcal{L}(\mathbf{y})$ symmetric.
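The identity $K_i^T B_i = -K_i \mathbf{y}_i^T - \mathbf{y}_i K_i^T$ can be confirmed by building $B_i$ explicitly with the two-step rule and comparing. In this NumPy sketch, `b[j, p, q]` holds $B(j, i, p, q)$ for a fixed index $i$, and $(p, q)$ is flattened row-major as in the $m=2$ example; the helper name `build_b_i` is hypothetical.

```python
import numpy as np

def build_b_i(y_i):
    """Explicit B_i for one Eigenvector, shape (m, m*m), filled with the two-step rule."""
    m = y_i.shape[0]
    b = np.zeros((m, m, m))        # b[j, p, q]: derivative of D_Y f(j, i) with respect to x[p, q]
    for j in range(m):
        b[j, :, j] -= y_i          # step 1: B(j, i, :, j) -= y(:, i)
        b[j, j, :] -= y_i          # step 2: B(j, i, j, :) -= y^T(:, i)
    return b.reshape(m, m * m)     # flatten (p, q) row-major, as in the m = 2 example

rng = np.random.default_rng(3)
m = 5
k_i, y_i = rng.standard_normal(m), rng.standard_normal(m)
explicit = (k_i @ build_b_i(y_i)).reshape(m, m)          # K_i^T B_i with B_i stored: m^3 entries
closed_form = -np.outer(k_i, y_i) - np.outer(y_i, k_i)   # -K_i y_i^T - y_i K_i^T: only 2m inputs
print(np.abs(explicit - closed_form).max())
```

For $\mathbf{y}_i = [1, 2]^T$, `build_b_i` reproduces $B_i^1 + B_i^2$ from the $m=2$ example entry by entry.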
In summary, we greatly reduce the memory required to calculate $D_X \mathcal{L}(\mathbf{y})$ by 1) looping over each Eigenvalue instead of stacking all Eigenvalues into high-dimensional matrices, 2) avoiding storing the memory-inefficient $B$, and 3) forming $H_i=H(:,i,:,i)$ from only the matrix $\mathbf{x}$ and the Eigenvalue $\lambda_i$.
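The three ideas combine into a per-Eigenvector backward pass, sketched below in NumPy for a symmetric $\mathbf{x}$. One deliberate swap: instead of forming $H_i^{-1}$, $K_i$ is read off a single solve with the bordered matrix $\begin{bmatrix} H_i & A_i^T \\ A_i & 0 \end{bmatrix}$. By the block-inverse identity, the bracket in the definition of $K_i$ equals minus the top-left $m \times m$ block of its inverse whenever $H_i^{-1}$ exists. The swap matters because, for symmetric $\mathbf{x}$, $H_i \mathbf{y}_i = -2\mathbf{x}\mathbf{y}_i + 2\lambda_i \mathbf{y}_i = 0$, so $H_i$ is singular exactly at the Eigenpair, whereas the bordered matrix stays invertible as long as $\lambda_i$ is a simple Eigenvalue. The result is checked against a central finite difference through `np.linalg.eigh` along a random symmetric direction; the function names and the test matrix are illustrative.

```python
import numpy as np

def ddn_grad_one_eigvec(x, y_i, lam_i, g_i):
    """D_X L(y_i) = -K_i y_i^T - y_i K_i^T for one Eigenpair, without storing B_i (sketch).

    H_i = -(x + x^T) + 2 lam_i I and A_i = 2 y_i^T, as in the text. Instead of inverting H_i,
    K_i is read off one solve with the bordered matrix [[H_i, A_i^T], [A_i, 0]].
    """
    m = x.shape[0]
    h_i = -(x + x.T) + 2.0 * lam_i * np.eye(m)
    a_i = 2.0 * y_i
    kkt = np.zeros((m + 1, m + 1))
    kkt[:m, :m] = h_i
    kkt[:m, m] = a_i
    kkt[m, :m] = a_i
    # The bracket H^-1 A^T (A H^-1 A^T)^-1 A H^-1 - H^-1 is minus the top-left m x m block
    # of kkt^-1, so K_i = -(that block)^T g_i = -(first m entries of kkt^-T [g_i; 0]).
    k_i = -np.linalg.solve(kkt.T, np.append(g_i, 0.0))[:m]
    return -np.outer(k_i, y_i) - np.outer(y_i, k_i)

# Symmetric input with a simple, well-separated greatest Eigenvalue
rng = np.random.default_rng(4)
q, _ = np.linalg.qr(rng.standard_normal((6, 6)))
x = q @ np.diag([8.0, 4.0, 2.0, 1.0, -1.0, -3.0]) @ q.T
w, v = np.linalg.eigh(x)
lam, y = w[-1], v[:, -1]
g = rng.standard_normal(6)                      # D_Y L for the loss L = g^T y
grad = ddn_grad_one_eigvec(x, y, lam, g)

def top_eigvec(xx):
    """Greatest Eigenvector of a symmetric matrix, sign-aligned with y."""
    u = np.linalg.eigh(xx)[1][:, -1]
    return u if u @ y > 0 else -u

# Central finite difference of L along a random symmetric direction e
e = rng.standard_normal((6, 6))
e = e + e.T
t = 1e-6
fd = (g @ top_eigvec(x + t * e) - g @ top_eigvec(x - t * e)) / (2 * t)
print(fd, np.sum(grad * e))
```

Only $m \times m$ and length-$m$ arrays appear, and the returned gradient is symmetric, in line with the structure of $B$.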
When the Power Iteration algorithm is used for the greatest Eigenvalue (assumed to be positive), the implicit function is defined as
\begin{equation} f(\mathbf{x}, \mathbf{u}_k, \mathbf{u}_{k+1}) = \mathbf{u}_{k+1} - \frac{\mathbf{x} \mathbf{u}_k}{\| \mathbf{x} \mathbf{u}_k \|}\ , \end{equation}where $\mathbf{x} \in \mathbb{R}^{m \times m}$ is the source matrix and $\mathbf{u} \in \mathbb{R}^{m \times n}$ is the solution. When $\mathbf{u}_k$ converges to a fixed point, we have $\mathbf{y}=\mathbf{u}_{k+1}=\mathbf{u}_k$, and thus,
\begin{equation} f(\mathbf{x}, \mathbf{y}) = \mathbf{y}-\frac{\mathbf{x} \mathbf{y}}{\| \mathbf{x} \mathbf{y} \|}\ . \end{equation}Now, since applying the implicit function theorem to $f(\mathbf{x}, \mathbf{y})$ gives
\begin{equation} \frac{\partial f}{\partial \mathbf{x}} + \frac{\partial f}{\partial \mathbf{y}} \frac{\partial \mathbf{y}}{\partial \mathbf{x}} = 0\ , \end{equation}we have\begin{equation} \frac{\partial \mathbf{y}}{\partial \mathbf{x}} = -\left( \frac{\partial f}{\partial \mathbf{y}} \right)^{-1} \frac{\partial f}{\partial \mathbf{x}}\ . \end{equation}For notational simplicity, we denote $\mathcal{A} = \partial f / \partial \mathbf{y} \in \mathbb{R}^{(mn) \times (mn)}$, $\mathcal{B} = \partial f / \partial \mathbf{x} \in \mathbb{R}^{(mn) \times (mm)}$, and $\mathcal{K} = \partial \mathbf{y} / \partial \mathbf{x} \in \mathbb{R}^{(mn) \times (mm)}$. Hence, $\mathcal{K} = -\mathcal{A}^{-1} \mathcal{B}$.
Since the solution $\mathbf{y}$ is used to calculate the learning loss $\mathcal{L}$, we denote the gradient of $\mathcal{L}$ with respect to $\mathbf{y}$ as $\mathcal{Q} = \partial \mathcal{L} / \partial \mathbf{y} \in \mathbb{R}^{1 \times (mn)}$, and the gradient with respect to $\mathbf{x}$ is
\begin{equation} \frac{\partial \mathcal{L}}{\partial \mathbf{x}} = \frac{\partial \mathcal{L}}{\partial \mathbf{y}} \frac{\partial \mathbf{y}}{\partial \mathbf{x}} = \mathcal{Q} \mathcal{K} = -\mathcal{Q} \mathcal{A}^{-1} \mathcal{B} = -\mathcal{H} \mathcal{B} \in \mathbb{R}^{1 \times (mm)}\ , \end{equation}where $\mathcal{H} = \mathcal{Q} \mathcal{A}^{-1} \in \mathbb{R}^{1 \times (mn)}$. Since \begin{aligned} \mathcal{Q} \mathcal{A}^{-1} &= \mathcal{H} \\ \mathcal{Q} &= \mathcal{H} \mathcal{A} \\ \mathcal{Q}^T &= \mathcal{A}^T \mathcal{H}^T \\ \mathcal{A} \mathcal{Q}^T &= \mathcal{A} \mathcal{A}^T \mathcal{H}^T\ , \end{aligned}
$\mathcal{H}$ can be calculated by $\text{torch.linalg.solve}(\mathcal{A} \mathcal{A}^T, \mathcal{A} \mathcal{Q}^T)^T$, followed by $\partial \mathcal{L} / \partial \mathbf{x} = -\mathcal{H} \mathcal{B}$.
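The fixed-point route can be sketched for $n = 1$ in NumPy (the notebook itself uses PyTorch). The sketch builds $\mathcal{A}$ and $\mathcal{B}$ explicitly from $f(\mathbf{x}, \mathbf{y}) = \mathbf{y} - \mathbf{x}\mathbf{y} / \|\mathbf{x}\mathbf{y}\|$, using that the Jacobian of $\mathbf{w} \mapsto \mathbf{w}/\|\mathbf{w}\|$ is $(\mathbf{I} - \mathbf{z}\mathbf{z}^T)/\|\mathbf{w}\|$ with $\mathbf{z} = \mathbf{w}/\|\mathbf{w}\|$, and computes $\mathcal{H}$ with the normal equations from the text. Solving $\mathcal{A}^T \mathcal{H}^T = \mathcal{Q}^T$ directly would avoid squaring the condition number of $\mathcal{A}$. The resulting gradient is in general not symmetric, so it is checked only along a symmetric direction, where the antisymmetric part does not contribute; the function name and test matrix are illustrative.

```python
import numpy as np

def ift_grad_power_iteration(x, y, g):
    """dL/dx at the fixed point of f(x, y) = y - x y / ||x y|| via the implicit function theorem (n = 1)."""
    m = x.shape[0]
    w = x @ y
    norm_w = np.linalg.norm(w)
    z = w / norm_w
    proj = np.eye(m) - np.outer(z, z)               # d(w / ||w||)/dw = proj / ||w||
    a = np.eye(m) - proj @ x / norm_w               # calA = df/dy, shape (m, m)
    # calB = df/dx, shape (m, m*m): column (p, q) is -proj[:, p] * y[q] / ||x y||
    b = -np.einsum("ip,q->ipq", proj, y).reshape(m, m * m) / norm_w
    h = np.linalg.solve(a @ a.T, a @ g)             # calH^T = solve(calA calA^T, calA calQ^T)
    return (-h @ b).reshape(m, m)                   # dL/dx = -calH calB, reshaped to m x m

rng = np.random.default_rng(5)
q, _ = np.linalg.qr(rng.standard_normal((6, 6)))
x = q @ np.diag([8.0, 4.0, 2.0, 1.0, -1.0, -3.0]) @ q.T   # dominant Eigenvalue 8 > 0
y = np.linalg.eigh(x)[1][:, -1]                           # fixed point of the power iteration
g = rng.standard_normal(6)                                # calQ^T for the loss L = g^T y
grad = ift_grad_power_iteration(x, y, g)

def top_eigvec(xx):
    """Greatest Eigenvector of a symmetric matrix, sign-aligned with y."""
    u = np.linalg.eigh(xx)[1][:, -1]
    return u if u @ y > 0 else -u

e = rng.standard_normal((6, 6))
e = e + e.T                                               # symmetric perturbation direction
t = 1e-6
fd = (g @ top_eigvec(x + t * e) - g @ top_eigvec(x - t * e)) / (2 * t)
print(fd, np.sum(grad * e))
```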
import torch, os
from copy import deepcopy
from datetime import datetime
from utils import generate_random_data, method_mode_explaination
from utils import run_precision_statistics, run_speed_memory_statistics
from utils import visual_speed_memory, visual_precision
# Import forward instances
from ied_forward import EigenAuto, SimultaneousIteration, PowerIteration
# Import backward instances
from ied_backward import AutoBackprop, DDNBackprop, FixedPointBackprop
# Build visualization tools
disp_fnc = lambda x: x.detach().cpu().numpy()
diff_fnc = lambda x, y: disp_fnc((x - y.view(x.shape)).abs().max())
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Build instances
num_iters = 100
uniform_solution_method = 'positive'
# Generate random data
enable_stop_condition = False
enable_symmetric = True
# 'solve' uses less memory, but dLdx is less accurate due to an issue with linalg.solve;
# 'pinverse' is more accurate in dLdx but needs more memory.
backprop_inverse_mode = 'solve' # solve/pinverse
seed = 0
epsilon = 1.0e-13
batch, m = 1, 64
n = 1 # the number of Eigenvalues
dtype = torch.float32
#
save_dir = 'results'
time_string = datetime.now().strftime('%Y%m%d-%H%M%S')
save_dir = os.path.join(save_dir, time_string, f"{str(num_iters)}iters")
os.makedirs(save_dir, exist_ok=True)
g_save_path = save_dir + f"/{backprop_inverse_mode}"
if enable_symmetric: g_save_path += '_symmetric'
else: g_save_path += '_nonsymmetric'
if enable_stop_condition: g_save_path += '_stopTrue'
else: g_save_path += '_stopFalse'
#
obj_AT = EigenAuto(
uniform_solution_method=uniform_solution_method, solver_back=AutoBackprop,
num_eigen_values=n, backprop_inverse_mode=backprop_inverse_mode, enable_symmetric=enable_symmetric)
obj_PI = PowerIteration(
num_iters=num_iters, uniform_solution_method=uniform_solution_method, solver_back=DDNBackprop,
num_eigen_values=n, enable_stop_condition=enable_stop_condition, backprop_inverse_mode=backprop_inverse_mode)
obj_SI = SimultaneousIteration(
num_iters=num_iters, uniform_solution_method=uniform_solution_method, solver_back=DDNBackprop,
num_eigen_values=n, enable_stop_condition=enable_stop_condition, backprop_inverse_mode=backprop_inverse_mode)
obj_PI_IFT = PowerIteration(
num_iters=num_iters, uniform_solution_method=uniform_solution_method, solver_back=FixedPointBackprop,
num_eigen_values=n, enable_stop_condition=enable_stop_condition, backprop_inverse_mode=backprop_inverse_mode)
obj_SI_IFT = SimultaneousIteration(
num_iters=num_iters, uniform_solution_method=uniform_solution_method, solver_back=FixedPointBackprop,
num_eigen_values=n, enable_stop_condition=enable_stop_condition, backprop_inverse_mode=backprop_inverse_mode)
x_org, dLdy = generate_random_data(
batch, m, n, seed, enable_symmetric=enable_symmetric, dtype=dtype, enable_grad_one=True,
distribution_mode='gaussian')
# Get Eigenvalues and Eigenvectors
# ---- Autogradient
x = deepcopy(x_org)
lambd_at, y_at = obj_AT(x, backward_fnc_name='unroll')
# ---- 1. Autogradient: autoback eigh
x.retain_grad()
loss = y_at.sum()
loss.backward()
dLdx_at = x.grad
# ---- 2. Autogradient: get auto dydx then multiply dLdy
x = deepcopy(x_org)
dLdx_at_iter = obj_AT.solver_back.dLdx_auto_iter_fnc(x, y_at, lambd_at, dLdy)
# ---- 3. Autogradient: get auto A, B, H, then DDN dydx and multiply dLdy
x = deepcopy(x_org)
dLdx_at_ddn = obj_AT.solver_back.dLdx_DDN_fnc(x, y_at, lambd_at, dLdy, enable_B=True)
# ---- 4. [AVOID STORING B] This is supposed to be the same as dLdx_at_ddn
dLdx_at_ddn_indiv = 0.0
for i in range(y_at.shape[-1]):
dLdx_at_ddn_indiv += obj_AT.solver_back.dLdx_DDN_fnc(
x, y_at[:, :, i:i + 1], lambd_at[:, i:i + 1], dLdy[:, :, i:i + 1], enable_B=False)
print('dLdx_at_ddn vs dLdx_at_ddn_indiv :', diff_fnc(dLdx_at_ddn, dLdx_at_ddn_indiv))
# ---- Simultaneous Iteration
# ---- 1. Simultaneous Iteration: DDN
x = deepcopy(x_org)
lambd_si, y_si = obj_SI(x, backward_fnc_name='dLdx_DDN_fnc')
loss = y_si.sum()
x.retain_grad()
loss.backward()
dLdx_si = x.grad
# The same as dLdx_si
dLdx_si_ddn = obj_SI.solver_back.dLdx_DDN_fnc(x, y_si, lambd_si, dLdy, enable_B=True)
# ---- 2. [AVOID STORING B] This is supposed to be the same as dLdx_si_ddn
dLdx_si_ddn_indiv = 0.0
for i in range(y_si.shape[-1]):
dLdx_si_ddn_indiv += obj_SI.solver_back.dLdx_DDN_fnc(
x, y_si[:, :, i:i + 1], lambd_si[:, i:i + 1], dLdy[:, :, i:i + 1], enable_B=False)
print('dLdx_si_ddn vs dLdx_si_ddn_indiv :', diff_fnc(dLdx_si_ddn, dLdx_si_ddn_indiv))
# ---- 3. Simultaneous Iteration: unrolling
x = deepcopy(x_org)
lambd_si_unroll, y_si_unroll = obj_SI(x, backward_fnc_name='unroll')
x.retain_grad()
loss = y_si_unroll.sum()
loss.backward()
dLdx_si_unroll = x.grad
# ---- Power Iteration:
# ---- 1. Unrolling
x = deepcopy(x_org)
lambd_pi_unroll, y_pi_unroll = obj_PI(x, backward_fnc_name='unroll')
x.retain_grad()
loss = y_pi_unroll.sum()
loss.backward()
dLdx_pi_unroll = x.grad
# ---- 2. DDN
x = deepcopy(x_org)
lambd_pi, y_pi = obj_PI(x, backward_fnc_name='dLdx_DDN_fnc')
loss = y_pi.sum()
x.retain_grad()
loss.backward()
dLdx_pi = x.grad
# ---- Fixed Point Theorem:
dLdx_pi_ift_auto = obj_PI_IFT.solver_back.dLdx_fnc(x, y_pi, lambd_pi, dLdy, enable_B=False)
dLdx_pi_ift_stru = obj_PI_IFT.solver_back.dLdx_structured_fnc(x, y_pi, lambd_pi, dLdy, enable_B=False)
dLdx_si_ift_auto = obj_SI_IFT.solver_back.dLdx_fnc(x, y_si, lambd_si, dLdy, enable_B=False)
dLdx_si_ift_stru = obj_SI_IFT.solver_back.dLdx_structured_fnc(x, y_si, lambd_si, dLdy, enable_B=False)
# ----
# Zhiwei: torch.eigh() and DDN give symmetric gradients while unrolling does not,
# although the gradients should be symmetric, so we add a patch below.
# Whether x is symmetric or not, unrolling never yields symmetric gradients
# unless we manually symmetrize them as below; this is done only for symmetric x,
# since it makes no sense for asymmetric x.
if enable_symmetric:
dLdx_si_unroll = 0.5 * (dLdx_si_unroll + dLdx_si_unroll.permute(0, 2, 1))
dLdx_pi_unroll = 0.5 * (dLdx_pi_unroll + dLdx_pi_unroll.permute(0, 2, 1))
# Compare solution and gradients
method_mode_explaination()
print('[ Check Eigenvalues ]')
print('Auto :', disp_fnc(lambd_at).flatten())
print('PI :', disp_fnc(lambd_pi).flatten())
print('SI :', disp_fnc(lambd_si).flatten())
print('Max diff :', diff_fnc(lambd_at, lambd_si))
print('')
print('[ Check Eigenvectors ]')
print('Auto :', disp_fnc(y_at.permute(0, 2, 1)).flatten())
print('PI :', disp_fnc(y_pi.permute(0, 2, 1)).flatten())
print('SI :', disp_fnc(y_si.permute(0, 2, 1)).flatten())
print('Max diff :', diff_fnc(y_at, y_si))
#
eigen_gap_at = obj_AT.check_eigen_gap(lambd_at, y_at, x)
eigen_gap_pi = obj_PI.check_eigen_gap(lambd_pi, y_pi, x)
eigen_gap_si = obj_SI.check_eigen_gap(lambd_si, y_si, x)
print('')
print('[ Check Eigen Gaps ]')
print('Auto :', disp_fnc(eigen_gap_at).flatten())
print('PI :', disp_fnc(eigen_gap_pi).flatten())
print('SI :', disp_fnc(eigen_gap_si).flatten())
#
fp_gap_at = obj_AT.check_fixed_point_gap(y_at, x)
fp_gap_pi = obj_PI.check_fixed_point_gap(y_pi, x)
fp_gap_si = obj_SI.check_fixed_point_gap(y_si, x)
print('')
print('[ Check Fixed Point Gaps ]')
print('Auto :', disp_fnc(fp_gap_at).flatten())
print('PI :', disp_fnc(fp_gap_pi).flatten())
print('SI :', disp_fnc(fp_gap_si).flatten())
#
print('')
print('[ Check Gradient Gaps (Max Diff.) ]')
print('dLdx_at vs dLdx_si :', diff_fnc(dLdx_at, dLdx_si))
print('dLdx_at vs dLdx_si_ddn :', diff_fnc(dLdx_at, dLdx_si_ddn))
print('dLdx_at vs dLdx_at_iter :', diff_fnc(dLdx_at, dLdx_at_iter))
print('dLdx_at vs dLdx_at_ddn :', diff_fnc(dLdx_at, dLdx_at_ddn))
print('dLdx_pi vs dLdx_pi_unroll :', diff_fnc(dLdx_pi, dLdx_pi_unroll))
print('dLdx_pi vs dLdx_si :', diff_fnc(dLdx_pi, dLdx_si))
print('dLdx_si vs dLdx_si_unroll :', diff_fnc(dLdx_si, dLdx_si_unroll))
print('dLdx_si vs dLdx_si_ddn :', diff_fnc(dLdx_si, dLdx_si_ddn))
print('dLdx_at vs dLdx_pi_ift_auto:', diff_fnc(dLdx_at, dLdx_pi_ift_auto))
print('dLdx_at vs dLdx_si_ift_auto:', diff_fnc(dLdx_at, dLdx_si_ift_auto))
print('dLdx_at vs dLdx_pi_ift_stru:', diff_fnc(dLdx_at, dLdx_pi_ift_stru))
print('dLdx_at vs dLdx_si_ift_stru:', diff_fnc(dLdx_at, dLdx_si_ift_stru))
del dLdx_at, dLdx_at_iter, dLdx_at_ddn, dLdx_si, dLdx_si_ddn, dLdx_si_unroll, x_org, x, loss, dLdy
torch.cuda.empty_cache()
dLdx_at_ddn vs dLdx_at_ddn_indiv : 1.4551915e-11
dLdx_si_ddn vs dLdx_si_ddn_indiv : 0.0
[ Check Eigenvalues ]
Auto : [101.05782]
PI : [101.05673]
SI : [101.056725]
Max diff : 0.0010986328
[ Check Eigenvectors ]
Auto : [0.12875485 0.1301477 0.13224962 0.11636169 0.13526331 0.11521947 0.12188618 0.13483587 0.12095354 0.11549785 0.13366696 0.13038024 0.12198585 0.1215284 0.14015691 0.11579662 0.11842807 0.12208147 0.12053845 0.13782045 0.12166768 0.1162324 0.13253036 0.11562862 0.12917513 0.12311543 0.12204126 0.12370913 0.11222141 0.11167143 0.1246989 0.12367364 0.12886107 0.13705821 0.118035 0.10938529 0.1294138 0.11872555 0.1348373 0.12000456 0.13498007 0.12891752 0.1170295 0.12522595 0.10646042 0.1321402 0.11618277 0.13953795 0.12810251 0.12542069 0.1277837 0.13592969 0.11688785 0.13464409 0.12372686 0.12716451 0.12047588 0.12865123 0.12061165 0.11782203 0.1311092 0.11846078 0.12677294 0.13404019]
PI : [0.12875395 0.13014698 0.13224885 0.11636101 0.13526262 0.11521886 0.1218853 0.13483502 0.12095293 0.11549731 0.13366619 0.13037927 0.12198532 0.12152781 0.14015584 0.11579592 0.11842736 0.12208086 0.1205376 0.13781972 0.12166704 0.11623182 0.13252953 0.11562788 0.1291745 0.12311484 0.1220405 0.12370848 0.11222087 0.11167082 0.12469833 0.12367295 0.12886041 0.13705751 0.11803435 0.10938465 0.12941329 0.11872499 0.13483651 0.12000401 0.13497937 0.1289169 0.11702888 0.12522534 0.10645989 0.13213956 0.11618212 0.13953719 0.12810177 0.12542002 0.12778303 0.13592893 0.11688729 0.13464339 0.12372625 0.12716393 0.12047534 0.12865065 0.12061101 0.11782142 0.13110845 0.11846016 0.12677234 0.13403955]
SI : [0.1287539 0.130147 0.13224885 0.11636102 0.13526261 0.11521886 0.12188531 0.13483502 0.12095292 0.11549731 0.13366619 0.13037929 0.12198532 0.12152781 0.14015584 0.11579591 0.11842735 0.12208088 0.1205376 0.1378197 0.12166705 0.11623181 0.13252953 0.11562788 0.1291745 0.12311484 0.1220405 0.12370849 0.11222085 0.11167081 0.12469833 0.12367296 0.12886041 0.13705751 0.11803433 0.10938465 0.12941329 0.11872499 0.13483652 0.12000401 0.13497937 0.1289169 0.11702889 0.12522535 0.10645989 0.13213958 0.1161821 0.1395372 0.12810175 0.12542002 0.12778303 0.13592893 0.11688727 0.1346434 0.12372626 0.12716395 0.12047534 0.12865064 0.12061102 0.11782143 0.13110846 0.11846013 0.12677234 0.13403954]
Max diff : 1.0728836e-06
[ Check Eigen Gaps ]
Auto : [0.00018311]
PI : [2.861023e-06]
SI : [5.722046e-06]
[ Check Fixed Point Gaps ]
Auto : [1.0728836e-06]
PI : [2.9802322e-08]
SI : [4.4703484e-08]
[ Check Gradient Gaps (Max Diff.) ]
dLdx_at vs dLdx_si : 0.0012523589
dLdx_at vs dLdx_si_ddn : 0.0012523589
dLdx_at vs dLdx_at_iter : 5.120455e-10
dLdx_at vs dLdx_at_ddn : 5.9236423e-05
dLdx_pi vs dLdx_pi_unroll : 0.0039889836
dLdx_pi vs dLdx_si : 0.0040286533
dLdx_si vs dLdx_si_unroll : 0.0012523588
dLdx_si vs dLdx_si_ddn : 0.0
dLdx_at vs dLdx_pi_ift_auto: 1.6370905e-09
dLdx_at vs dLdx_si_ift_auto: 1.6880222e-09
dLdx_at vs dLdx_pi_ift_stru: 1.553417e-09
dLdx_at vs dLdx_si_ift_stru: 1.8590072e-09
obj_dict = {
'AT': obj_AT,
'PI': obj_PI,
'PI_IFT': obj_PI_IFT,
'SI': obj_SI,
'SI_IFT': obj_SI_IFT
}
batch = 5
data_sizes = [
(batch, 32, 1),
(batch, 64, 1),
(batch, 128, 1),
(batch, 256, 1),
(batch, 512, 1),
(batch, 1024, 1)
]
# ==== Run Precision Statistics
with torch.no_grad():
num_seeds = 100
# * and *_IFT have the same solver solution, so show one only
methods = [
'AT',
'PI',
'SI'
]
mode_dict = {
'AT': ['dLdx_DDN_fnc'],
'PI': ['dLdx_DDN_fnc'],
'PI_IFT': ['dLdx_structured_fnc'],
'SI': ['dLdx_DDN_fnc'],
'SI_IFT': ['dLdx_structured_fnc']
}
enable_legend = True
distribution_mode = 'gaussian' # gaussian/uniform/vonmise/choice + _resnet50
uniform_sample_max = 1.0
choice_max = 10.0
# ResNet50 will create m*m neurons, for m=256, it will cause out-of-memory issue, so just run under 256
if distribution_mode.find('resnet') > -1: data_sizes = [v for v in data_sizes if v[1] <= 128]
for dtype_cur in [torch.float32]:
if num_seeds > 100 and dtype_cur == torch.float64: continue
if distribution_mode.find('resnet') > -1 and dtype_cur == torch.float64: continue
print(f'Start precision statistics {dtype_cur}...')
save_path = f"{g_save_path}_{str(dtype_cur).replace('torch.', '')}_{distribution_mode}_numseeds{num_seeds}"
if distribution_mode == 'uniform': save_path += f'_max{uniform_sample_max}'
precision_info = run_precision_statistics(
num_seeds, methods, data_sizes, mode_dict, obj_dict, enable_symmetric, dtype_cur,
distribution_mode=distribution_mode, uniform_sample_max=uniform_sample_max,
choice_max=choice_max)
enable_legend = visual_precision(precision_info, data_sizes, save_path, enable_legend=enable_legend)
del precision_info
torch.cuda.empty_cache()
print('Done!')
Start precision statistics torch.float32...
[ Method: AT+dLdx_DDN_fnc, Size: (5, 32, 1) ]
[ Method: AT+dLdx_DDN_fnc, Size: (5, 64, 1) ]
[ Method: AT+dLdx_DDN_fnc, Size: (5, 128, 1) ]
[ Method: AT+dLdx_DDN_fnc, Size: (5, 256, 1) ]
[ Method: AT+dLdx_DDN_fnc, Size: (5, 512, 1) ]
[ Method: AT+dLdx_DDN_fnc, Size: (5, 1024, 1) ]
[ Method: PI+dLdx_DDN_fnc, Size: (5, 32, 1) ]
[ Method: PI+dLdx_DDN_fnc, Size: (5, 64, 1) ]
[ Method: PI+dLdx_DDN_fnc, Size: (5, 128, 1) ]
[ Method: PI+dLdx_DDN_fnc, Size: (5, 256, 1) ]
[ Method: PI+dLdx_DDN_fnc, Size: (5, 512, 1) ]
[ Method: PI+dLdx_DDN_fnc, Size: (5, 1024, 1) ]
[ Method: SI+dLdx_DDN_fnc, Size: (5, 32, 1) ]
[ Method: SI+dLdx_DDN_fnc, Size: (5, 64, 1) ]
[ Method: SI+dLdx_DDN_fnc, Size: (5, 128, 1) ]
[ Method: SI+dLdx_DDN_fnc, Size: (5, 256, 1) ]
[ Method: SI+dLdx_DDN_fnc, Size: (5, 512, 1) ]
[ Method: SI+dLdx_DDN_fnc, Size: (5, 1024, 1) ]
Done!
# ==== Run Time and Memory Statistics
print('Start time and memory statistics...')
methods = [
'AT',
'PI',
'PI_IFT',
'SI',
'SI_IFT'
]
mode_dict = {
'AT': ['unroll', 'dLdx_DDN_fnc'],
'PI': ['unroll', 'dLdx_DDN_fnc_B', 'dLdx_DDN_fnc'],
'PI_IFT': ['dLdx_fnc', 'dLdx_structured_fnc'],
'SI': ['unroll', 'dLdx_DDN_fnc_B', 'dLdx_DDN_fnc'],
'SI_IFT': ['dLdx_fnc', 'dLdx_structured_fnc']
}
dtype = torch.float32
save_path = f"{g_save_path}_{str(dtype).replace('torch.', '')}"
num_seeds = 10
cost_info = run_speed_memory_statistics(
num_seeds, methods, data_sizes, mode_dict, obj_dict, enable_symmetric, dtype)
visual_speed_memory(cost_info, data_sizes, save_path)
del cost_info
torch.cuda.empty_cache()
print('Done!')
Start time and memory statistics...
[ Method: AT+unroll, Size: (5, 32, 1) ] !!!CPU memory skipped.
[ Method: AT+unroll, Size: (5, 64, 1) ] !!!CPU memory skipped.
[ Method: AT+unroll, Size: (5, 128, 1) ] !!!CPU memory skipped.
[ Method: AT+unroll, Size: (5, 256, 1) ] !!!CPU memory skipped.
[ Method: AT+unroll, Size: (5, 512, 1) ] !!!CPU time skipped. !!!CPU memory skipped.
[ Method: AT+unroll, Size: (5, 1024, 1) ] !!!CPU time skipped. !!!CPU memory skipped.
[ Method: AT+dLdx_DDN_fnc, Size: (5, 32, 1) ] !!!CPU memory skipped.
[ Method: AT+dLdx_DDN_fnc, Size: (5, 64, 1) ] !!!CPU memory skipped.
[ Method: AT+dLdx_DDN_fnc, Size: (5, 128, 1) ] !!!CPU memory skipped.
[ Method: AT+dLdx_DDN_fnc, Size: (5, 256, 1) ] !!!CPU memory skipped.
[ Method: AT+dLdx_DDN_fnc, Size: (5, 512, 1) ] !!!CPU time skipped. !!!CPU memory skipped.
[ Method: AT+dLdx_DDN_fnc, Size: (5, 1024, 1) ] !!!CPU time skipped. !!!CPU memory skipped.
[ Method: PI+unroll, Size: (5, 32, 1) ] !!!CPU memory skipped.
[ Method: PI+unroll, Size: (5, 64, 1) ] !!!CPU memory skipped.
[ Method: PI+unroll, Size: (5, 128, 1) ] !!!CPU memory skipped.
[ Method: PI+unroll, Size: (5, 256, 1) ] !!!CPU memory skipped.
[ Method: PI+unroll, Size: (5, 512, 1) ] !!!CPU time skipped. !!!CPU memory skipped.
[ Method: PI+unroll, Size: (5, 1024, 1) ] !!!CPU time skipped. !!!CPU memory skipped.
[ Method: PI+dLdx_DDN_fnc_B, Size: (5, 32, 1) ] !!!CPU memory skipped.
[ Method: PI+dLdx_DDN_fnc_B, Size: (5, 64, 1) ] !!!CPU memory skipped.
[ Method: PI+dLdx_DDN_fnc_B, Size: (5, 128, 1) ] !!!CPU memory skipped.
[ Method: PI+dLdx_DDN_fnc_B, Size: (5, 256, 1) ] !!!CPU memory skipped.
[ Method: PI+dLdx_DDN_fnc_B, Size: (5, 512, 1) ] !!!CPU time skipped. !!!CPU memory skipped.
[ Method: PI+dLdx_DDN_fnc_B, Size: (5, 1024, 1) ] !!!CPU time skipped. !!!CPU memory skipped.
[ Method: PI+dLdx_DDN_fnc, Size: (5, 32, 1) ] !!!CPU memory skipped.
[ Method: PI+dLdx_DDN_fnc, Size: (5, 64, 1) ] !!!CPU memory skipped.
[ Method: PI+dLdx_DDN_fnc, Size: (5, 128, 1) ] !!!CPU memory skipped.
[ Method: PI+dLdx_DDN_fnc, Size: (5, 256, 1) ] !!!CPU memory skipped.
[ Method: PI+dLdx_DDN_fnc, Size: (5, 512, 1) ] !!!CPU time skipped. !!!CPU memory skipped.
[ Method: PI+dLdx_DDN_fnc, Size: (5, 1024, 1) ] !!!CPU time skipped. !!!CPU memory skipped.
[ Method: PI_IFT+dLdx_fnc, Size: (5, 32, 1) ] !!!CPU memory skipped.
[ Method: PI_IFT+dLdx_fnc, Size: (5, 64, 1) ] !!!CPU memory skipped.
[ Method: PI_IFT+dLdx_fnc, Size: (5, 128, 1) ] !!!CPU memory skipped.
[ Method: PI_IFT+dLdx_fnc, Size: (5, 256, 1) ] !!!CPU memory skipped.
[ Method: PI_IFT+dLdx_fnc, Size: (5, 512, 1) ] !!!CPU time skipped. !!!CPU memory skipped.
[ Method: PI_IFT+dLdx_fnc, Size: (5, 1024, 1) ] !!!CPU time skipped. !!!CPU time skipped for IFT. !!!CPU memory skipped. !!! GPU memory skipped for IFT+auto with size >= 1024.
[ Method: PI_IFT+dLdx_structured_fnc, Size: (5, 32, 1) ] !!!CPU memory skipped.
[ Method: PI_IFT+dLdx_structured_fnc, Size: (5, 64, 1) ] !!!CPU memory skipped.
[ Method: PI_IFT+dLdx_structured_fnc, Size: (5, 128, 1) ] !!!CPU memory skipped.
[ Method: PI_IFT+dLdx_structured_fnc, Size: (5, 256, 1) ] !!!CPU memory skipped.
[ Method: PI_IFT+dLdx_structured_fnc, Size: (5, 512, 1) ] !!!CPU time skipped. !!!CPU memory skipped.
[ Method: PI_IFT+dLdx_structured_fnc, Size: (5, 1024, 1) ] !!!CPU time skipped. !!!CPU memory skipped.
[ Method: SI+unroll, Size: (5, 32, 1) ] !!!CPU memory skipped.
[ Method: SI+unroll, Size: (5, 64, 1) ] !!!CPU memory skipped.
[ Method: SI+unroll, Size: (5, 128, 1) ] !!!CPU memory skipped.
[ Method: SI+unroll, Size: (5, 256, 1) ] !!!CPU memory skipped.
[ Method: SI+unroll, Size: (5, 512, 1) ] !!!CPU time skipped. !!!CPU memory skipped.
[ Method: SI+unroll, Size: (5, 1024, 1) ] !!!CPU time skipped. !!!CPU memory skipped.
[ Method: SI+dLdx_DDN_fnc_B, Size: (5, 32, 1) ] !!!CPU memory skipped.
[ Method: SI+dLdx_DDN_fnc_B, Size: (5, 64, 1) ] !!!CPU memory skipped.
[ Method: SI+dLdx_DDN_fnc_B, Size: (5, 128, 1) ] !!!CPU memory skipped.
[ Method: SI+dLdx_DDN_fnc_B, Size: (5, 256, 1) ] !!!CPU memory skipped.
[ Method: SI+dLdx_DDN_fnc_B, Size: (5, 512, 1) ] !!!CPU time skipped. !!!CPU memory skipped.
[ Method: SI+dLdx_DDN_fnc_B, Size: (5, 1024, 1) ] !!!CPU time skipped. !!!CPU memory skipped.
[ Method: SI+dLdx_DDN_fnc, Size: (5, 32, 1) ] !!!CPU memory skipped.
[ Method: SI+dLdx_DDN_fnc, Size: (5, 64, 1) ] !!!CPU memory skipped.
[ Method: SI+dLdx_DDN_fnc, Size: (5, 128, 1) ] !!!CPU memory skipped.
[ Method: SI+dLdx_DDN_fnc, Size: (5, 256, 1) ] !!!CPU memory skipped.
[ Method: SI+dLdx_DDN_fnc, Size: (5, 512, 1) ] !!!CPU time skipped. !!!CPU memory skipped.
[ Method: SI+dLdx_DDN_fnc, Size: (5, 1024, 1) ] !!!CPU time skipped. !!!CPU memory skipped.
[ Method: SI_IFT+dLdx_fnc, Size: (5, 32, 1) ] !!!CPU memory skipped.
[ Method: SI_IFT+dLdx_fnc, Size: (5, 64, 1) ] !!!CPU memory skipped.
[ Method: SI_IFT+dLdx_fnc, Size: (5, 128, 1) ] !!!CPU memory skipped.
[ Method: SI_IFT+dLdx_fnc, Size: (5, 256, 1) ] !!!CPU memory skipped.
[ Method: SI_IFT+dLdx_fnc, Size: (5, 512, 1) ] !!!CPU time skipped. !!!CPU memory skipped.
[ Method: SI_IFT+dLdx_fnc, Size: (5, 1024, 1) ] !!!CPU time skipped. !!!CPU time skipped for IFT. !!!CPU memory skipped. !!! GPU memory skipped for IFT+auto with size >= 1024.
[ Method: SI_IFT+dLdx_structured_fnc, Size: (5, 32, 1) ] !!!CPU memory skipped.
[ Method: SI_IFT+dLdx_structured_fnc, Size: (5, 64, 1) ] !!!CPU memory skipped.
[ Method: SI_IFT+dLdx_structured_fnc, Size: (5, 128, 1) ] !!!CPU memory skipped.
[ Method: SI_IFT+dLdx_structured_fnc, Size: (5, 256, 1) ] !!!CPU memory skipped.
[ Method: SI_IFT+dLdx_structured_fnc, Size: (5, 512, 1) ] !!!CPU time skipped. !!!CPU memory skipped.
[ Method: SI_IFT+dLdx_structured_fnc, Size: (5, 1024, 1) ] !!!CPU time skipped. !!!CPU memory skipped.
Done!