Here, we illustrate, using a complex 2×2 matrix $U$ as an example, that for $\Omega = \det(U)$ and a perturbation $\Delta\Omega$, the correct ChainRules.rrule is $\bar\Omega\, \Delta\Omega\, ({U^{-1}})^\dagger$, not $\Omega\, \Delta\Omega\, ({U^{-1}})^\dagger$.
That is, it involves the complex conjugate of the determinant, not the determinant itself.
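Before the symbolic treatment, the claim can be spot-checked numerically. The following is only a sketch: it uses NumPy rather than SymPy, the matrix `U_num`, the value `dOmega`, and the helper `pullback_fd` are arbitrary choices made for illustration, and the pullback is assembled from central finite differences following the real/imaginary-part convention of the ChainRulesCore documentation.

```python
import numpy as np


def pullback_fd(U, dOmega, eps=1e-6):
    """Finite-difference estimate of the pullback of det at U, applied to dOmega.

    Hypothetical helper for illustration; entry (i, j) collects the derivatives
    of Re(det) and Im(det) with respect to Re(U[i, j]) and Im(U[i, j]).
    """
    grad = np.zeros_like(U)
    for i in range(U.shape[0]):
        for j in range(U.shape[1]):
            E = np.zeros_like(U)
            E[i, j] = 1.0
            # Central differences in the real and imaginary direction of U[i, j]
            d_re = (np.linalg.det(U + eps * E) - np.linalg.det(U - eps * E)) / (2 * eps)
            d_im = (np.linalg.det(U + 1j * eps * E) - np.linalg.det(U - 1j * eps * E)) / (2 * eps)
            grad[i, j] = (
                dOmega.real * d_re.real + dOmega.imag * d_re.imag
                + 1j * (dOmega.real * d_im.real + dOmega.imag * d_im.imag)
            )
    return grad


# Arbitrary test values (assumptions, not taken from ChainRules)
U_num = np.array([[1.0 + 2.0j, 0.5 - 1.0j], [-0.3 + 0.7j, 2.0 - 0.5j]])
dOmega = 0.8 - 0.6j

Omega = np.linalg.det(U_num)
inv_dagger = np.linalg.inv(U_num).conj().T
fd = pullback_fd(U_num, dOmega)

# Largest entry-wise deviation of each candidate formula from the estimate
err_new = np.max(np.abs(fd - np.conj(Omega) * dOmega * inv_dagger))
err_old = np.max(np.abs(fd - Omega * dOmega * inv_dagger))
print(err_new, err_old)
```

For this matrix, `err_new` should sit at the level of finite-difference roundoff, while `err_old` should be of order one, because $\det(U)$ has a sizeable imaginary part here.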
from sympy import *


def csym(name, i=None, j=None, part=None):
    """Create a symbolic complex number."""
    if part is None:
        # Assemble the complex symbol from its real and imaginary parts
        return csym(name, i, j, 're') + I * csym(name, i, j, 'im')
    else:
        if i is None or j is None:
            return symbols('{name}^{part}'.format(name=name, part=part), real=True)
        else:
            return symbols('{name}_{i}{j}^{part}'.format(name=name, i=i, j=j, part=part), real=True)


def real(z):
    """Real part of z (also works element-wise on a Matrix)."""
    return (z + z.conjugate()) / 2


def imag(z):
    """Imaginary part of z (also works element-wise on a Matrix)."""
    return (z - z.conjugate()) / (2 * I)


conj = conjugate


def dagger(U):
    """Conjugate transpose of the matrix U."""
    return conj(transpose(U))
def deriv_scalar_matrix(J, M):
    r"""Return the matrix that is the derivative of the real-valued scalar
    $J$ with respect to the real-valued matrix $M$. The $ij$ element of
    the result is $\frac{\partial J}{\partial M_{ij}}$, so the result has
    the same shape as $M$. This is the "denominator layout" of
    https://en.wikipedia.org/wiki/Matrix_calculus#Scalar-by-matrix
    that is, without the transpose of the numerator layout.
    """
    n, m = M.shape
    return Matrix(
        [
            [J.diff(M[i, j]) for j in range(m)]
            for i in range(n)
        ]
    )
def u(i, j, part=None):
    """Complex elements of the matrix U"""
    return csym("u", i, j, part)
U = Matrix([[u(1, 1), u(1, 2)], [u(2, 1), u(2, 2)]])
U
Ω = det(U)
Ω
ΔΩ = csym(r"\Delta\Omega");
ΔΩ
On the left-hand side, we have the general equation for the rrule of a complex function, taken from https://juliadiff.org/ChainRulesCore.jl/dev/maths/complex.html
lhs = (
    real(ΔΩ) * deriv_scalar_matrix(real(Ω), real(U))
    + imag(ΔΩ) * deriv_scalar_matrix(imag(Ω), real(U))
    + I * real(ΔΩ) * deriv_scalar_matrix(real(Ω), imag(U))
    + I * imag(ΔΩ) * deriv_scalar_matrix(imag(Ω), imag(U))
);
lhs
On the right-hand side, we have the formula corresponding to the code at https://github.com/JuliaDiff/ChainRules.jl/blob/9023d898a0b957bd9b3baab6bc38b54822d6963a/src/rulesets/LinearAlgebra/dense.jl#L132
rhs_old = simplify(Ω * ΔΩ * dagger(U.inv())).expand()
rhs_old;
The corrected version uses the complex conjugate of $\Omega$ instead:
rhs_new = simplify(conj(Ω) * ΔΩ * dagger(U.inv())).expand()
rhs_new
diff_old = (lhs - rhs_old).simplify();
diff_old
Note that for a real-valued matrix, $U \in \mathbb{R}^{2 \times 2}$, the old definition works:
diff_old.subs({sym: 0 for sym in imag(U).free_symbols})
(Even without plugging in values, this is clear: for real $U$, $\det U \in \mathbb{R}$, so $\bar\Omega = \Omega$ and the two formulas coincide.)
The new RHS works unconditionally:
(lhs - rhs_new).simplify()
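The appearance of the conjugate can also be motivated by a short derivation. This is a sketch that assumes the rrule convention used for lhs above; the symbols $a$, $b$, $x$, $y$ are introduced only for this argument. Since $\Omega = \det(U)$ is a polynomial in the entries of $U$, it is holomorphic in each $U_{ij}$. Writing $\Omega = a + i b$ and $U_{ij} = x + i y$, the Cauchy-Riemann equations $\frac{\partial a}{\partial y} = -\frac{\partial b}{\partial x}$ and $\frac{\partial b}{\partial y} = \frac{\partial a}{\partial x}$ give

$$
\left(\frac{\partial a}{\partial x} \operatorname{Re}\Delta\Omega + \frac{\partial b}{\partial x} \operatorname{Im}\Delta\Omega\right)
+ i \left(\frac{\partial a}{\partial y} \operatorname{Re}\Delta\Omega + \frac{\partial b}{\partial y} \operatorname{Im}\Delta\Omega\right)
= \left(\frac{\partial a}{\partial x} - i \frac{\partial b}{\partial x}\right) \Delta\Omega
= \overline{\left(\frac{\partial \Omega}{\partial U_{ij}}\right)}\, \Delta\Omega .
$$

With Jacobi's formula, $\frac{\partial \Omega}{\partial U_{ij}} = \Omega\, (U^{-1})_{ji}$, the conjugate becomes $\bar\Omega\, \big(({U^{-1}})^\dagger\big)_{ij}$, which is the $ij$ element of $\bar\Omega\, \Delta\Omega\, ({U^{-1}})^\dagger$ up to the scalar factor $\Delta\Omega$.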
The current tests in ChainRules actually include a check for the rrule of det for a complex matrix. However, the test matrix is a "well-conditioned matrix" defined as follows:
def v(i, j, part=None):
    """Complex elements of the matrix V"""
    return csym("V", i, j, part)
V = Matrix([[v(1, 1), v(1, 2)], [v(2, 1), v(2, 2)]])
W = simplify(V * dagger(V) + eye(2));
W
We find that $\det(W) \in \mathbb{R}$:
imag(det(W))
This is unlike the determinant of an arbitrary matrix:
imag(det(U))
and explains why this bug is not caught by the current tests.
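The same observation can be reproduced with concrete numbers. The sketch below uses NumPy with an arbitrarily chosen matrix `V_num` and cotangent `dOmega` (both are assumptions for illustration, not the values used in the ChainRules test suite). It builds $W = V V^\dagger + 1$ and compares the factor $\Omega$ with $\bar\Omega$.

```python
import numpy as np

# Arbitrary complex matrix standing in for V (an assumption for illustration)
V_num = np.array([[1.0 + 2.0j, 0.5 - 1.0j], [-0.3 + 0.7j, 2.0 - 0.5j]])

# Same construction as the symbolic W above: V V^dagger plus the identity
W_num = V_num @ V_num.conj().T + np.eye(2)

det_W = np.linalg.det(W_num)  # W is Hermitian, so det(W) is real up to roundoff
det_V = np.linalg.det(V_num)  # a generic complex matrix has a complex determinant

# The old and new formulas differ only by the factor det vs. conj(det)
dOmega = 0.8 - 0.6j
inv_dagger = np.linalg.inv(W_num).conj().T
gap_W = np.max(np.abs((det_W - np.conj(det_W)) * dOmega * inv_dagger))

print(det_W.imag, det_V.imag, gap_W)
```

For a matrix constructed like `W_num`, the old and the corrected formula agree to within roundoff, so a test based on such a matrix cannot distinguish them. A test matrix with a non-real determinant, such as `V_num` itself, would.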