#!/usr/bin/env python
# coding: utf-8

# # Reverse-AD for the determinant of a complex matrix

# Using the example of a complex 2×2 matrix $U$, we illustrate that for $\Omega = \det(U)$ and a perturbation $\Delta\Omega$, the correct `ChainRules.rrule` is $\bar\Omega\, \Delta\Omega\, (U^{-1})^\dagger$, not $\Omega\, \Delta\Omega\, (U^{-1})^\dagger$. That is, it involves the complex *conjugate* of the determinant, not the *value* of the determinant.

# In[1]:


from sympy import *


# ## Helper Routines

# In[2]:


def csym(name, i=None, j=None, part=None):
    """Create a symbolic complex number."""
    if part is None:
        return csym(name, i, j, 're') + I * csym(name, i, j, 'im')
    else:
        if i is None or j is None:
            return symbols(
                '{name}^{part}'.format(name=name, part=part), real=True
            )
        else:
            return symbols(
                '{name}_{i}{j}^{part}'.format(name=name, i=i, j=j, part=part),
                real=True,
            )


# In[3]:


def real(z):
    return (z + z.conjugate()) / 2


# In[4]:


def imag(z):
    return (z - z.conjugate()) / (2 * I)


# In[5]:


conj = conjugate


# In[6]:


def dagger(U):
    return conj(transpose(U))


# In[7]:


def deriv_scalar_matrix(J, M):
    r"""Return the matrix that is the derivative of the real-valued scalar
    $J$ with respect to the real-valued matrix $M$.

    The $ij$ element of the result is $\frac{\partial J}{\partial M_{ij}}$,
    i.e. the element-wise ("denominator layout") convention. Note that the
    "numerator layout" convention listed at
    https://en.wikipedia.org/wiki/Matrix_calculus#Scalar-by-matrix
    contains an additional transpose, which is *not* taken here.
    """
    n, m = M.shape
    return Matrix(
        [
            [J.diff(M[i, j]) for j in range(m)]
            for i in range(n)
        ]
    )


# In[8]:


def u(i, j, part=None):
    """Complex elements of the matrix U."""
    return csym("u", i, j, part)


# ## Definitions

# In[9]:


U = Matrix([[u(1, 1), u(1, 2)], [u(2, 1), u(2, 2)]])
U


# In[10]:


Ω = det(U)
Ω


# In[11]:


ΔΩ = csym(r"\Delta\Omega"); ΔΩ


# ## Left-Hand Side

# On the left-hand side, we have the equation for `rrule` from https://juliadiff.org/ChainRulesCore.jl/dev/maths/complex.html, applied elementwise to the entries of $U$:

# In[12]:


lhs = (
    real(ΔΩ) * deriv_scalar_matrix(real(Ω), real(U))
    + imag(ΔΩ) * deriv_scalar_matrix(imag(Ω), real(U))
    + I * real(ΔΩ) * deriv_scalar_matrix(real(Ω), imag(U))
    + I * imag(ΔΩ) * deriv_scalar_matrix(imag(Ω), imag(U))
); lhs


# ## Right-Hand Side

# On the right-hand side, we have the formula corresponding to the code at https://github.com/JuliaDiff/ChainRules.jl/blob/9023d898a0b957bd9b3baab6bc38b54822d6963a/src/rulesets/LinearAlgebra/dense.jl#L132

# In[13]:


rhs_old = simplify(Ω * ΔΩ * dagger(U.inv())).expand()
rhs_old;


# and the corrected version, respectively:

# In[14]:


rhs_new = simplify(conj(Ω) * ΔΩ * dagger(U.inv())).expand()
rhs_new


# ## Check

# In[15]:


diff_old = (lhs - rhs_old).applyfunc(simplify); diff_old


# Note that for a real-valued $U$ (all imaginary parts set to zero), the old definition works:

# In[16]:


diff_old.subs({sym: 0 for sym in imag(U).free_symbols})


# (Even without plugging in values, this is to be expected: for real $U$, $\det U \in \mathbb{R}$, so $\bar\Omega = \Omega$ and the two expressions coincide.)

# The new RHS works unconditionally:

# In[17]:


(lhs - rhs_new).applyfunc(simplify)


# ## A "Well-conditioned" matrix

# The current tests in `ChainRules` actually include a check of the `rrule` for `det` of a complex matrix. However, the test matrix is a "well-conditioned" matrix defined as follows:

# In[18]:


def v(i, j, part=None):
    """Complex elements of the matrix V."""
    return csym("V", i, j, part)


# In[19]:


V = Matrix([[v(1, 1), v(1, 2)], [v(2, 1), v(2, 2)]])


# In[20]:


W = simplify(V * dagger(V) + eye(2)); W


# We find that $\det(W) \in \mathbb{R}$ (as expected, since $W = V V^\dagger + \mathbb{1}$ is Hermitian):

# In[21]:


imag(det(W))


# This is unlike the determinant of an arbitrary matrix:

# In[22]:


imag(det(U))


# and explains why this bug is not caught by the current tests.
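
# As an additional spot check (a supplementary cell, not part of the derivation above), we can substitute arbitrarily chosen numerical values for the entries of $U$ and for $\Delta\Omega$: for a matrix whose determinant is not real, the old RHS then differs from the LHS, while the new RHS still matches it exactly.

# In[23]:


# Arbitrarily chosen (hypothetical) numerical values; any invertible U with a
# non-real determinant would serve. Here det(U) = -9 + 6i.
num = {
    u(1, 1, 're'): 1, u(1, 1, 'im'): 2,
    u(1, 2, 're'): 3, u(1, 2, 'im'): -1,
    u(2, 1, 're'): 0, u(2, 1, 'im'): 1,
    u(2, 2, 're'): 2, u(2, 2, 'im'): 5,
    csym(r"\Delta\Omega", part='re'): 1,
    csym(r"\Delta\Omega", part='im'): 1,
}
(
    (lhs - rhs_old).subs(num).applyfunc(simplify),  # nonzero matrix
    (lhs - rhs_new).subs(num).applyfunc(simplify),  # zero matrix
)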