using DFTK
using LinearAlgebra

"""
    fixed_point_iteration(F, ρ0, info0; maxiter, tol=1e-10)

Plain (undamped) fixed-point iteration `ρ ← F(ρ)`, usable as the `solver` of
`self_consistent_field`.

Iteration stops after `maxiter` steps, as soon as the step function reports
convergence through an `info.converged` field, or once `norm(F(ρ) - ρ) < tol`.

# Arguments
- `F`:     SCF step function, called as `F(ρ, info)` and returning `(Fρ, info)`.
- `ρ0`:    initial guess density.
- `info0`: initial metadata passed to the first call of `F`.

# Keywords
- `maxiter`: maximal number of iterations to perform.
- `tol`:     threshold on `norm(Fρ - ρ)` below which iteration stops.

# Returns
`(; fixpoint, info)`. `fixpoint` is the input density of the step that triggered
the stop, or the last update `F(ρ)` if `maxiter` was exhausted. `info` is the
metadata returned by the last call of `F`.
"""
function fixed_point_iteration(F, ρ0, info0; maxiter, tol=1e-10)
    ρ = ρ0
    info = info0
    for _ in 1:maxiter
        Fρ, info = F(ρ, info)
        # Respect the convergence flag of the SCF step. `hasproperty` keeps
        # metadata without a `converged` field (e.g. toy step functions) working.
        step_converged = hasproperty(info, :converged) && info.converged
        # Independently, stop once the change is below the tolerance.
        if step_converged || norm(Fρ - ρ) < tol
            break
        end
        ρ = Fρ
    end

    # Return the named tuple DFTK expects from a solver.
    return (; fixpoint=ρ, info)
end;

using AtomsBuilder
using PseudoPotentialData

"""
    aluminium_setup(repeat=1; Ecut=13.0, kgrid=[2, 2, 2])

Return a `PlaneWaveBasis` for cubic bulk aluminium repeated `repeat` times along
the first axis. The model uses the PBE functional, temperature `1e-3`, symmetries
disabled and the `dojo.nc.sr.pbe.v0_4_1.standard.upf` pseudopotential family.
"""
function aluminium_setup(repeat=1; Ecut=13.0, kgrid=[2, 2, 2])
    system = bulk(:Al; cubic=true) * (repeat, 1, 1)
    family = PseudoFamily("dojo.nc.sr.pbe.v0_4_1.standard.upf")
    model  = model_DFT(system; functionals=PBE(), temperature=1e-3,
                       symmetries=false, pseudopotentials=family)
    return PlaneWaveBasis(model; Ecut=Ecut, kgrid=kgrid)
end;

# Single aluminium cell solved with plain fixed-point iteration:
# full steps (damping = 1.0), `SimpleMixing()`, at most 30 iterations.
self_consistent_field(aluminium_setup(1);
                      solver=fixed_point_iteration, damping=1.0,
                      maxiter=30, mixing=SimpleMixing());

"""
    anderson_iteration(F, ρ0, info0; maxiter)

Anderson-accelerated fixed-point iteration for `ρ = F(ρ)`, with the same calling
convention as a `self_consistent_field` solver.

Every iterate `ρᵢ` and its residual `Rᵢ = F(ρᵢ) - ρᵢ` is kept; the history is
never truncated. Each step corrects the plain update `ρ + R`. The coefficients
`β` minimise `‖R + Σᵢ βᵢ (Rᵢ - R)‖` over the stored residuals:

    ρnext = ρ + R + Σᵢ βᵢ (ρᵢ - ρ + Rᵢ - R)

# Arguments
- `F`:     SCF step function, called as `F(ρ, info)` and returning `(Fρ, info)`.
           The returned `info` must have a `converged` field.
- `ρ0`:    initial guess density.
- `info0`: initial metadata passed to the first call of `F`.

# Keywords
- `maxiter`: maximal number of iterations to perform.

# Returns
`(; fixpoint, info)`. `fixpoint` is the density passed to the step that reported
convergence, or the last extrapolated density once `maxiter` is exhausted.
`info` is the latest metadata returned by `F`.
"""
function anderson_iteration(F, ρ0, info0; maxiter)
    ρ, info = ρ0, info0
    past_ρ = []  # flattened previous iterates ρᵢ
    past_R = []  # flattened residuals Rᵢ = F(ρᵢ) - ρᵢ
    for _ in 1:maxiter
        Fρ, info = F(ρ, info)
        info.converged && break

        ρv = vec(ρ)
        Rv = vec(Fρ - ρ)
        ρnew = ρv .+ Rv
        if !isempty(past_R)
            # Least-squares fit of -R by the residual differences Rᵢ - R
            ΔR = hcat(past_R...) .- Rv
            coeffs = -(ΔR \ Rv)
            for (c, ρk, Rk) in zip(coeffs, past_ρ, past_R)
                ρnew .+= c .* (ρk .- ρv .+ Rk .- Rv)
            end
        end

        push!(past_ρ, ρv)
        push!(past_R, Rv)
        ρ = reshape(ρnew, size(ρ0)...)
    end

    # Return the named tuple DFTK expects from a solver.
    return (; fixpoint=ρ, info)
end;

using Plots

# Model susceptibilities χ0(q) as functions of the wave number q.
# `kTF` is a screening wave vector (Thomas–Fermi, going by the name).

# Metal: χ0 does not depend on q, giving ε(q) = 1 + kTF²/q².
χ0_metal(q, kTF=1) = -kTF^2 / (4π)

# Dielectric: with the default C₀ = 1 - εr, ε(q) → εr as q → 0.
function χ0_dielectric(q, εr, C₀=1-εr, kTF=1)
    Cq2 = C₀ * q^2
    return Cq2 / (4π * (1 - Cq2 / kTF^2))
end

χ0_GaAs(q) = χ0_dielectric(q, 14.9)
χ0_SiO2(q) = χ0_dielectric(q, 1.5)

# Dielectric function ε(q) = 1 - 4π/q² χ0(q) for the susceptibility model `χ0`.
ε(χ0, q) = 1 - (4π / q^2) * χ0(q)

# Dielectric function ε(q) of the three model materials.
# `lw` is a per-series attribute. Given only to the initial empty `plot()`, which
# creates no series, it is not reliably applied to the curves added later, so it
# is passed to every `plot!` call instead.
p = plot(xlims=(1e-2, 1.5), ylims=(0, 16), xlabel="q", ylabel="ε(q)")
plot!(p, x -> ε(χ0_metal, x); label="aluminium (Al)",          lw=4)
plot!(p, x -> ε(χ0_GaAs, x);  label="gallium arsenide (GaAs)", lw=4, ls=:dash)
plot!(p, x -> ε(χ0_SiO2, x);  label="silica (SiO₂)",           lw=4, ls=:dashdot)