# Script: low-rank plus sparse (L+S) reconstruction of a dynamic MRI sequence.
# It downloads k-space data and a reference result, builds the system model,
# runs three proximal gradient variants (PGM, FPGM, POGM) via `pogm_restart`,
# and plots the cost per iteration.
#
# Most expensive steps are wrapped in `if !@isdefined(name)` guards so that
# re-running the script in the same session reuses earlier results.

# using Unitful: s
using Plots
using MIRT: Afft, Asense, embed
using MIRT: pogm_restart, poweriter
using MIRTjim: jim, prompt
using FFTW: fft!, bfft!, fftshift!
using LinearMapsAA: LinearMapAA, block_diag, redim, undim
using MAT: matread
import Downloads # todo: use Fetch or DataDeps?
using LinearAlgebra: dot, norm, svd, svdvals, Diagonal, I
using Random: seed!
using StatsBase: mean
using LaTeXStrings

# Plot defaults used by every `plot` / `plot!` call below.
default(markerstrokecolor=:auto, label="")

# `jim` without the interactive prompt (used for sub-panels of a composite figure).
jif(args...; kwargs...) = jim(args...; prompt=false, kwargs...)
isinteractive() ? jim(:prompt, true) : prompt(:draw);


# ---- Data ----

if !@isdefined(data)
    url = "https://github.com/JeffFessler/MIRTdata/raw/main/mri/lin-19-edp/"
    dataurl = url * "cardiac_perf_R8.mat"
    data = matread(Downloads.download(dataurl))
    xinfurl = url * "Xinf.mat"
    Xinf = matread(Downloads.download(xinfurl))["Xinf"]["perf"] # (128,128,40)
end;

pinf = jim(Xinf, L"\mathrm{Converged\ image\ sequence } X_∞")

if !@isdefined(ydata0)
    ydata0 = data["kdata"] # k-space data full of zeros
    ydata0 = permutedims(ydata0, [1, 2, 4, 3]) # (nx,ny,nc,nt)
    ydata0 = ComplexF32.(ydata0)
end
(nx, ny, nc, nt) = size(ydata0)

# k-space axes; defined outside the `samp` guard because `pssum` below
# needs them even when `samp` already exists.
kx = -(nx÷2):(nx÷2-1)
ky = -(ny÷2):(ny÷2-1)

# Sampling pattern: a k-space location counts as sampled where the data is nonzero.
if !@isdefined(samp)
    samp = ydata0[:,:,1,:] .!= 0
    for ic in 2:nc # verify it is same for all coils
        @assert samp == (ydata0[:,:,ic,:] .!= 0)
    end
    psamp = jim(kx, ky, samp, "Sampling patterns for $nt frames";
        xlabel=L"k_x", ylabel=L"k_y")
end
samp_sum = sum(samp, dims=3)
color = cgrad([:blue, :black, :white], [0, 1/2nt, 1]) # note: 1/2nt == 1/(2nt)
pssum = jim(kx, ky, samp_sum; xlabel="kx", ylabel="ky",
    color, clim=(0,nt), title="Number of sampled frames out of $nt")

# Coil sensitivity maps, normalized so their square-root sum-of-squares is 1.
if !@isdefined(smaps)
    smaps_raw = data["b1"] # raw coil sensitivity maps
    jim(smaps_raw, "Raw |coil maps| for $nc coils")
    # reduce over the last dimension and drop it
    sum_last = (f, x) -> selectdim(sum(f, x; dims=ndims(x)), ndims(x), 1)
    ssos_fun = smap -> sqrt.(sum_last(abs2, smap)) # SSoS
    ssos_raw = ssos_fun(smaps_raw)
    smaps = smaps_raw ./ ssos_raw
    ssos = ssos_fun(smaps)
    @assert all(≈(1), ssos)
    pmap = jim(smaps, "Normalized |coil maps| for $nc coils")
end


# ---- Operators ----

TF = Afft((nx,ny,nt), 3; unitary=true) # unitary FFT along 3rd (time) dimension
if false # verify adjoint
    tmp1 = randn(ComplexF32, nx, ny, nt)
    tmp2 = randn(ComplexF32, nx, ny, nt)
    @assert dot(tmp2, TF * tmp1) ≈ dot(TF' * tmp2, tmp1)
    @assert TF' * (TF * tmp1) ≈ tmp1
    (size(TF), TF._odim, TF._idim)
end

tmp = TF * Xinf
ptfft = jim(tmp, "|Temporal FFT of Xinf|")

# One encoding operator per frame, combined block-diagonally over time.
Aotazo = (samp, smaps) ->
    Asense(samp, smaps; unitary=true, fft_forward=false) # Otazo style
A = block_diag([Aotazo(s, smaps) for s in eachslice(samp, dims=3)]...)
#A = ComplexF32(1/sqrt(nx*ny)) * A # match Otazo's scaling
(size(A), A._odim, A._idim)

# Keep only the sampled k-space values of each frame.
if !@isdefined(ydata)
    tmp = reshape(ydata0, :, nc, nt)
    tmp = [tmp[vec(s),:,it] for (it,s) in enumerate(eachslice(samp, dims=3))]
    ydata = cat(tmp..., dims=3) # (nsamp,nc,nt) = (2048,12,40) no "zeros"
end
size(ydata)

# AII = [I I]: maps the stacked pair X = (L, S) to the sum L + S.
tmp = LinearMapAA(I(nx*ny*nt);
    odim=(nx,ny,nt), idim=(nx,ny,nt), T=ComplexF32, prop=(;name="I"))
tmp = kron([1 1], tmp)
AII = redim(tmp; odim=(nx,ny,nt), idim=(nx,ny,nt,2)) # "squeeze" odim

E = A * AII;

# Largest singular value of E (power iteration is skipped; the known value is used).
if false
    (_, σ1E) = poweriter(undim(E)) # 1.413 ≈ √2
else
    σ1E = √2
end

tmp = A * Xinf
scale0 = dot(tmp, ydata) / norm(tmp)^2 # 1.009 ≈ 1


# ---- Initialization ----

L0 = A' * ydata # adjoint (zero-filled)
S0 = zeros(ComplexF32, nx, ny, nt)
X0 = cat(L0, S0, dims=ndims(L0)+1) # (nx, ny, nt, 2) = (128, 128, 40, 2)
M0 = AII * X0 # L0 + S0
pm0 = jim(M0, "|Initial L+S via zero-filled recon|")


# ---- Cost function, gradient and proximal maps ----

scaleL = 130 / 1.2775 # Otazo's stopping St(1) / b1 constant squared
scaleS = 1 / 1.2775; # 1 / b1 constant squared
lambda_L = 0.01 # regularization parameter
lambda_S = 0.01 * scaleS
Lpart = X -> selectdim(X, ndims(X), 1) # extract "L" from X
Spart = X -> selectdim(X, ndims(X), 2) # extract "S" from X
nucnorm(L::AbstractMatrix) = sum(svdvals(L)) # nuclear norm
nucnorm(L::AbstractArray) = nucnorm(reshape(L, :, nt)); # (nx*ny, nt) for L

# Composite cost: data fit + nuclear norm of L + 1-norm of temporal FFT of S.
Fcost = X -> 0.5 * norm(E * X - ydata)^2 +
    lambda_L * scaleL * nucnorm(Lpart(X)) + # note scaleL !
    lambda_S * norm(TF * Spart(X), 1);

f_grad = X -> E' * (E * X - ydata); # gradient of data-fit term

# Lipschitz constant passed to `pogm_restart`.
# NOTE(review): hard-coded rather than computed from σ1E above.
f_L = 2; # σ1E^2

"""
    SVST(X::AbstractArray, β)

Singular value soft-thresholding.

Reshape `X` into a matrix whose columns are indexed by the last dimension of `X`,
shrink each singular value by `β` (clamping at zero), rebuild the matrix from
the singular components that remain positive, and return it with the original
size of `X`.
"""
function SVST(X::AbstractArray, β)
    dims = size(X)
    Xmat = reshape(X, :, dims[end]) # assume time frame is the last dimension
    U, s, V = svd(Xmat)
    sthresh = @. max(s - β, 0)
    keep = findall(>(0), sthresh) # drop components thresholded to zero
    Xlow = U[:,keep] * Diagonal(sthresh[keep]) * V[:,keep]'
    return reshape(Xlow, dims)
end;

soft = (v,c) -> sign(v) * max(abs(v) - c, 0) # soft threshold function
S_prox = (S, β) -> TF' * soft.(TF * S, β) # 1-norm proximal mapping for unitary TF

# Proximal map for the stacked X = (L, S): SVST on L, temporal-FFT soft threshold on S.
g_prox = (X, c) -> cat(dims=ndims(X),
    SVST(Lpart(X), c * lambda_L * scaleL),
    S_prox(Spart(X), c * lambda_S),
);

if false # check functions
    @assert Fcost(X0) isa Real
    tmp = f_grad(X0)
    @assert size(tmp) == size(X0)
    tmp = SVST(Lpart(X0), 1)
    @assert size(tmp) == size(L0)
    tmp = S_prox(S0, 1)
    @assert size(tmp) == size(S0)
    tmp = g_prox(X0, 1)
    @assert size(tmp) == size(X0)
end


# ---- Iterative algorithms ----

niter = 10
fun = (iter, xk, yk, is_restart) -> (Fcost(xk), xk); # logger: (cost, iterate)

if !@isdefined(Mpgm)
    # Alternative, currently disabled: `f_mu = 2/0.99 - f_L` is a trick to match
    # the 0.99 step size in "Lin 1999" (NOTE(review): probably 2019, given the
    # "lin-19-edp" data path - confirm).
    f_mu = 0
    xpgm, out_pgm = pogm_restart(X0, (x) -> 0, f_grad, f_L ;
        f_mu, mom = :pgm, niter, g_prox, fun)
    Mpgm = AII * xpgm
end;

if !@isdefined(Mfpgm)
    xfpgm, out_fpgm = pogm_restart(X0, (x) -> 0, f_grad, f_L ;
        mom = :fpgm, niter, g_prox, fun)
    Mfpgm = AII * xfpgm
end;

if !@isdefined(Mpogm)
    xpogm, out_pogm = pogm_restart(X0, (x) -> 0, f_grad, f_L ;
        mom = :pogm, niter, g_prox, fun)
    Mpogm = AII * xpogm
end;

px = jim(
    jif(Lpart(xpogm), "L"),
    jif(Spart(xpogm), "S"),
    jif(Mpogm, "M=L+S"),
    jif(Xinf, "Minf"),
)


# ---- Convergence: cost per iteration ----

# Each logger entry is (cost, iterate); see `fun` above.
costs = out -> [o[1] for o in out]
nrmsd = out -> [norm(AII*o[2]-Xinf)/norm(Xinf) for o in out]
cost_pgm = costs(out_pgm)
cost_fpgm = costs(out_fpgm)
cost_pogm = costs(out_pogm)

pc = plot(xlabel = "iteration", ylabel = "cost")
plot!(0:niter, cost_pgm, marker=:circle, label="PGM (ISTA)")
plot!(0:niter, cost_fpgm, marker=:square, label="FPGM (FISTA)")
plot!(0:niter, cost_pogm, marker=:star, label="POGM")
# Normalized distance of every logged iterate to the reference `Xinf`,
# computed for each of the three algorithms with the `nrmsd` helper.
nrmsd_pgm, nrmsd_fpgm, nrmsd_pogm = map(nrmsd, (out_pgm, out_fpgm, out_pogm))

# One curve per algorithm; there are niter+1 logged entries (iterations 0..niter).
pd = plot(; xlabel="iteration", ylabel="NRMSD vs Matlab Xinf")
plot!(pd, 0:niter, nrmsd_pgm; marker=:circle, label="PGM (ISTA)")
plot!(pd, 0:niter, nrmsd_fpgm; marker=:square, label="FPGM (FISTA)")
plot!(pd, 0:niter, nrmsd_pogm; marker=:star, label="POGM")