using LinearAlgebra: norm, mul!
using SPECTrecon: SPECTplan, project!, backproject!, psf_gauss, mlem!
using MIRTjim: jim, prompt
using Plots: default; default(markerstrokecolor=:auto)
using ZygoteRules: @adjoint
using Flux: Chain, Conv, SamePad, relu, params, unsqueeze
import Flux # apparently needed for BSON @load
import NNlib
using LinearMapsAA: LinearMapAA
using Distributions: Poisson
using BSON: @load, @save
import BSON # load
using InteractiveUtils: versioninfo
import Downloads # download

# Pause between figures only when running interactively.
isinteractive() ? jim(:prompt, true) : prompt(:draw);

# ---- Synthetic 3D phantom: two uniform boxes (values 1 and 2) ----
nx, ny, nz = 64, 64, 50
T = Float32
xtrue = zeros(T, nx, ny, nz)
xtrue[(1nx÷4):(2nx÷3), 1ny÷5:(3ny÷5), 2nz÷6:(3nz÷6)] .= 1
xtrue[(2nx÷5):(3nx÷5), 1ny÷5:(2ny÷5), 4nz÷6:(5nz÷6)] .= 2

"""
    average(x)

Return the arithmetic mean of all elements of `x`.
"""
average(x) = sum(x) / length(x)

"""
    mid3(x::AbstractArray{T,3})

Return a 2D mosaic of the three central orthogonal slices of the 3D array `x`,
for display:

    [ xy   xz ]
    [ zy'  c  ]

where `xy = x[:,:,mid z]`, `xz = x[:,mid y,:]`, `zy` is the transposed
`x[mid x,:,:]`, and `c` is an `nz × nz` block filled with the mean of `xy`.
"""
function mid3(x::AbstractArray{T,3}) where {T}
    (nx, ny, nz) = size(x)
    # ceil(Int, n/2) is the central index for both odd and even n,
    # and is >= 1 even when n == 1.
    xy = x[:, :, ceil(Int, nz / 2)]
    xz = x[:, ceil(Int, ny / 2), :]
    zy = x[ceil(Int, nx / 2), :, :]'
    return [xy xz; zy fill(average(xy), nz, nz)]
end

jim(mid3(xtrue), "Middle slices of xtrue")

# ---- System model: Gaussian PSF per plane, identical for all views ----
px = 11
psf1 = psf_gauss( ; ny, px)
jim(psf1, "PSF for each of $ny planes")

nview = 60
psfs = repeat(psf1, 1, 1, 1, nview) # same PSFs reused for every view
size(psfs)

dy = 8 # transaxial pixel size in mm
mumap = zeros(T, size(xtrue)) # zero μ-map just for illustration here
plan = SPECTplan(mumap, psfs, dy; T)

forw! = (y, x) -> project!(y, x, plan)
back! = (x, y) -> backproject!(x, y, plan)
idim = (nx, ny, nz)
odim = (nx, nz, nview)
A = LinearMapAA(forw!, back!, (prod(odim), prod(idim)); T, odim, idim)

# ---- Noisy data (skipped when re-running with `ynoisy` already defined) ----
if !@isdefined(ynoisy) # generate (scaled) Poisson data
    ytrue = A * xtrue
    target_mean = 20 # aim for mean of 20 counts per ray
    scale = target_mean / average(ytrue)
    scatter_fraction = 0.1 # 10% uniform scatter for illustration
    # Convert the Float64 literal to T so `background` stays in T (Float32)
    # instead of being promoted to Float64.
    scatter_mean = T(scatter_fraction) * average(ytrue) # uniform for simplicity
    background = fill(scatter_mean, nx, nz, nview)
    ynoisy = rand.(Poisson.(scale * (ytrue + background))) / scale
end
jim(ynoisy, "$nview noisy projection views"; ncol=10)

# ---- ML-EM baseline reconstruction ----
x0 = ones(T, nx, ny, nz) # initial uniform image
niter = 30
if !@isdefined(xhat1)
    xhat1 = copy(x0)
    mlem!(xhat1, x0, ynoisy, background, A; niter)
end;

"""
    nrmse(x)

Return the NRMSE (in %, rounded to 1 decimal) between the `mid3` mosaics
of `x` and of the global `xtrue`.
"""
nrmse(x) = round(100 * norm(mid3(x) - mid3(xtrue)) / norm(mid3(xtrue)); digits=1)

prompt()
# jim(mid3(xhat1), "MLEM NRMSE=$(nrmse(xhat1))%") # display ML-EM reconstructed image

# ---- CNN regularizer: three 3×3×3 conv layers, channels 1 → 4 → 4 → 1 ----
cnn = Chain(
    Conv((3,3,3), 1 => 4, relu; stride = 1, pad = SamePad(), bias = true),
    Conv((3,3,3), 4 => 4, relu; stride = 1, pad = SamePad(), bias = true),
    Conv((3,3,3), 4 => 1; stride = 1, pad = SamePad(), bias = true),
)
paramCount = sum(sum(length, params(layer)) for layer in cnn)

# Differentiable projector wrappers with custom pullbacks: the pullback of
# `A * x` applies the adjoint `A'`, and vice versa.
projectb(x) = A * x
@adjoint projectb(x) = A * x, Δy -> (A' * Δy,)
backprojectb(y) = A' * y
@adjoint backprojectb(y) = A' * y, Δx -> (A * Δx,)

"""
    unsqueeze45(x)

Append singleton 4th and 5th dimensions to `x` (e.g. `nx×ny×nz` →
`nx×ny×nz×1×1`), giving the 5D input that `cnn` is applied to in `bregem`.
"""
function unsqueeze45(x)
    return unsqueeze(unsqueeze(x, 4), 5)
end

"""
    bregem(projectb, backprojectb, y, r, Asum, x, cnn, β; niter = 1)

Backpropagatable regularized EM reconstruction with CNN regularization.

- `projectb`: backpropagatable forward projection
- `backprojectb`: backpropagatable backward projection
- `y`: projections
- `r`: scatters
- `Asum`: `A' * 1`
- `x`: current iterate
- `cnn`: the CNN model (a `Chain` or any function of a 5D array)
- `β`: regularization parameter; must be positive
- `niter`: number of iterations for inner EM

The CNN output `u = cnn(x)` is computed once from the input `x`; each inner
iteration then applies the closed-form update

    x ← max.(0, (-(Asum - β u) + sqrt.((Asum - β u).^2 + 4β x .* e)) / (2β))

with `e = backprojectb(y ./ (projectb(x) + r))`. Returns the final iterate.

Throws `ArgumentError` if `β <= 0` (the update divides by `2β`).
"""
function bregem(
    projectb::Function,
    backprojectb::Function,
    y::AbstractArray,
    r::AbstractArray,
    Asum::AbstractArray,
    x::AbstractArray,
    cnn::Union{Chain,Function},
    β::Real;
    niter::Int = 1,
)
    β > 0 || throw(ArgumentError("β must be positive, got $β"))
    T = eltype(x)
    u = cnn(unsqueeze45(x))[:, :, :, 1, 1] # drop the singleton dims again
    Asumu = Asum - β * u
    Asumu2 = Asumu .^ 2
    halfinvβ = T(1 / (2β)) # loop-invariant 1/(2β) in the image eltype
    for _ in 1:niter
        eterm = backprojectb(y ./ (projectb(x) + r))
        eterm_beta = 4 * β * (x .* eterm)
        x = max.(0, halfinvβ * (-Asumu + sqrt.(Asumu2 + eterm_beta)))
    end
    return x
end

β = 1
Asum = A' * ones(T, nx, nz, nview) # A' * 1

"""
    loss(xrecon, xtrue)

Return the squared error `sum(abs2, x2 - xtrue)`, where `x2` is obtained from
`xrecon` by two successive `bregem` calls (one inner EM iteration each). Uses
the globals `ynoisy`, `background`, `Asum`, `cnn` and `β`.
"""
function loss(xrecon, xtrue)
    xiter1 = bregem(projectb, backprojectb, ynoisy, background, Asum,
        xrecon, cnn, β; niter = 1)
    xiter2 = bregem(projectb, backprojectb, ynoisy, background, Asum,
        xiter1, cnn, β; niter = 1)
    return sum(abs2, xiter2 - xtrue)
end

@show loss(xhat1, xtrue)

# Load a pre-trained CNN when running interactively and a model URL is given.
# NOTE(review): `url` is empty in this copy of the script; set it to the location
# of a BSON file containing a `:cnn` entry to use a trained model.
url = ""
if isinteractive() && !isempty(url)
    tmp = tempname()
    Downloads.download(url, tmp)
    cnn = BSON.load(tmp)[:cnn]
    rm(tmp; force = true) # the model is in memory now; drop the temp file
else
    isinteractive() && @warn "no pre-trained CNN URL given; using a do-nothing CNN"
    cnn = x -> x # fake "do-nothing CNN" for Literate/Documenter version
end

xiter1 = bregem(projectb, backprojectb, ynoisy, background, Asum,
    xhat1, cnn, β; niter = 1)
xiter2 = bregem(projectb, backprojectb, ynoisy, background, Asum,
    xiter1, cnn, β; niter = 1)

clim = (0, 2)
jim(
    jim(mid3(xtrue), "xtrue"; clim),
    jim(mid3(xhat1), "EM recon, NRMSE = $(nrmse(xhat1))%"; clim),
    jim(mid3(xiter1), "Iter 1, NRMSE = $(nrmse(xiter1))%"; clim),
    jim(mid3(xiter2), "Iter 2, NRMSE = $(nrmse(xiter2))%"; clim),
)