こーゆことをしたくて,こーゆjuliaコードを書いたのですが,出力は正しくても,計算時間が余分にかかってしまっていて,どう効率的に書き直すか苦戦しています...。(多分,現状は行列積を過剰な回数計算している?) pic.twitter.com/Hc3siXTVJT
— yudai.jl (@physics303) August 15, 2022
using StaticArrays, LinearAlgebra, BenchmarkTools
# 3次元配列を行列の配列に変換する
"""
    MatrixArray()

Singleton tag used only for dispatch. Passing `MatrixArray()` as the first argument (as in
`reconst1(MatrixArray(), matrixarray.(G))`) marks the next argument as already converted
into arrays of matrices, so the method can index `G[k][i]` as a matrix directly.
"""
struct MatrixArray
end
"""
    R = matrixarray(S::AbstractArray{T,3})

Convert the 3-D array `S` into a vector of static matrices, slicing along its second axis:
`R[k][i, j] == S[i, k, j]`.

Every slice has the same size `(size(S, 1), size(S, 3))`. The concrete `SMatrix` type is
therefore built once, and each slice is copied straight from a `view` of `S` instead of
through temporary slice arrays.
"""
function matrixarray(S)
    M = SMatrix{size(S, 1), size(S, 3), eltype(S)}
    return _matrixarray(M, S)
end

# Function barrier: with `M` fixed as a type parameter, the comprehension is compiled for
# that one `SMatrix` type instead of re-deriving (and re-dispatching) it for every slice.
_matrixarray(::Type{M}, S) where {M} = [M(view(S, :, k, :)) for k in axes(S, 2)]
#matrixarray(S) = [S[:, i, :] for i in axes(S, 2)] # without StaticArrays
matrixarray
# 要素ごとに行列積を行う(元のコードを整理した)
"""
    reconst1(G)

Reconstruct the full array from the 3-D cores `G`:
`R[i₁, …, i_K] == tr(G[1][:, i₁, :] * ⋯ * G[K][:, i_K, :])`.

Reference implementation: every entry multiplies its slices from scratch.
"""
reconst1(G) = reconst1(MatrixArray(), matrixarray.(G))

function reconst1(::MatrixArray, G)
    dims = Tuple(length.(G))
    return map(CartesianIndices(dims)) do idx
        # Select slice idx[k] of core k, multiply them in core order, take the trace.
        tr(prod(k -> G[k][idx[k]], eachindex(G)))
    end
end
reconst1 (generic function with 2 methods)
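例えば $\operatorname{tr}\bigl(G^{(1)}(1)\,G^{(2)}(1)\,G^{(3)}(1)\,G^{(4)}(1)\bigr)$ と $\operatorname{tr}\bigl(G^{(1)}(1)\,G^{(2)}(1)\,G^{(3)}(1)\,G^{(4)}(2)\bigr)$ を求めるとき、別個に求めると行列積が6回(3回×2)になるが、共通部分 $G^{(1)}(1)\,G^{(2)}(1)\,G^{(3)}(1)$ を再利用すると4回に減らせる。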
# 行列積の途中結果を再利用する
"""
    reconst2(G)

Same result as `reconst1(G)`. Here all chained products are built at once with `eachprod`,
so partial products shared by many entries are computed once and reused.
"""
reconst2(G) = reconst2(MatrixArray(), matrixarray.(G))

function reconst2(::MatrixArray, G)
    chained = reduce(eachprod, G)   # chained[i₁, …, i_K] == G[1][i₁] * ⋯ * G[K][i_K]
    return map(tr, chained)
end
"""
    C = eachprod(A, B)

Multiply every matrix of `A` by every matrix of `B`:
`C[i..., k...] == A[i...] * B[k...]`, e.g. `C[i, j, k] == A[i, j] * B[k]`.

`A`, `B` and `C` are arrays whose elements are matrices, and
`size(C) == (size(A)..., size(B)...)`.
"""
function eachprod(A, B)
    # Give B one leading singleton axis per dimension of A, so broadcasting forms the
    # outer product.
    Bshifted = expanddim(B, A)
    return A .* Bshifted
end
"""
    Bx = expanddim(B, A)

Reshape `B` (without copying) by prepending `ndims(A)` singleton axes. For example, with
`ndims(A) == 3` and `size(B) == (m, n)`, this gives `reshape(B, (1, 1, 1, m, n))`.
"""
function expanddim(B, A)
    leading = ntuple(_ -> 1, ndims(A))
    return reshape(B, (leading..., size(B)...))
end
expanddim
トレースを取る直前の最後の行列積は(1回分だけだが)省略できる。積 $AB$ を作らなくても、次の恒等式でトレースを直接計算できるためである:
tr(A * B) == sum(transpose(A) .* B) == dot(vec(A'), vec(B))
# 行列積の最終段を省き、トレースを直接求める
"""
    reconst3(G)

Same result as `reconst1(G)`. The cores are split into two halves, and the final matrix
product is skipped: `tr(A * B)` is computed directly from the two halves.
"""
reconst3(G) = reconst3(MatrixArray(), matrixarray.(G))

function reconst3(::MatrixArray, G)
    # `_reconst3` receives the converted cores as separate arguments.
    return _reconst3(G...)
end
"""
    _reconst3(Gs...)

Return the array `T` with `T[i₁, …, i_K] == tr(Gs[1][i₁] * ⋯ * Gs[K][i_K])`, where each
`Gs[k]` is an array of matrices.

The cores are split into two halves of (almost) equal count, and each half is multiplied
out with `eachprod`. The last product is never formed: the two-argument method takes the
trace of each pairwise product directly. Splitting in the middle matters for speed;
folding left-to-right instead (the commented-out variant below) is slower.

With a single core there is nothing to multiply, so the result is `tr.(G1)`. Calling with
no cores throws an `ArgumentError`.
"""
function _reconst3(Gs...)
    h = length(Gs) ÷ 2
    left = reduce(eachprod, Gs[begin:h])
    right = reduce(eachprod, Gs[h+1:end])
    return _reconst3(left, right)
end

# One core: without this method, h == 0 makes `Gs[begin:h]` empty and `reduce` fails.
_reconst3(G1) = tr.(G1)

# No cores: nothing to reconstruct; fail with a clear message.
_reconst3() = throw(ArgumentError("_reconst3: at least one core is required"))
#_reconst3(G1, G2, Gs...) = _reconst3(eachprod(G1, G2), Gs...) # 遅い
function _reconst3(G1, G2)
    # Pair every matrix of G1 with every matrix of G2 (outer layout, as in `eachprod`)
    # and evaluate tr(G1[i] * G2[j]) without forming the product matrix.
    G2x = expanddim(G2, G1)
    return trprod.(G1, G2x)
end
"""
    trprod(A, B)

Return `tr(A * B)` without forming the product `A * B`.

This uses `tr(A * B) == Σᵢⱼ A[j, i] * B[i, j] == dot(vec(A'), vec(B))`. `dot` conjugates
its first argument, which undoes the conjugation in `A'`, so complex entries are also
handled correctly.

Throws `DimensionMismatch` unless `size(A) == reverse(size(B))`. Without this check, two
same-shaped non-square matrices would give equal-length `vec`s and silently produce a
meaningless number.
"""
function trprod(A, B)
    (size(A, 1) == size(B, 2) && size(A, 2) == size(B, 1)) || throw(DimensionMismatch(
        "trprod: size(A) = $(size(A)) is incompatible with size(B) = $(size(B))"))
    return dot(vec(A'), vec(B))
end
trprod
"""
    genG(I, r)

Return random cores `G` with `size(G[k]) == (r[k], I[k], r[k+1])`. The bond dimensions wrap
around, so the last core ends with `r[1]`. Entries are drawn with `randn`.
"""
function genG(I, r)
    rnext = circshift(r, -1)      # rnext[k] == r[k+1], cyclically
    return randn.(r, I, rnext)
end
"""
    mybench(K = 1; I = [3, 4, 5, 6], r = [2, 2, 3, 4])

Benchmark `reconst1`, `reconst2` and `reconst3` on random cores from `genG`.

Both the dimensions `I` and the bond dimensions `r` are first scaled by `K`. The three
results are checked to agree (`≈`) before timing. Prints the scaled sizes and one `@btime`
line per method, then returns `nothing`.
"""
function mybench(K::Int = 1; I = [3, 4, 5, 6], r = [2, 2, 3, 4])
    I = K * I
    r = K * r
    @show I, r
    G = genG(I, r)
    # The timings are only comparable if all three implementations agree.
    @assert reconst1(G) ≈ reconst2(G) ≈ reconst3(G)
    # `G` is interpolated so BenchmarkTools does not time global-variable access.
    @btime reconst1($G)
    @btime reconst2($G)
    @btime reconst3($G)
    return nothing
end
mybench (generic function with 2 methods)
mybench(1)
(I, r) = ([3, 4, 5, 6], [2, 2, 3, 4])
129.500 μs (3353 allocations: 211.48 KiB)
47.000 μs (478 allocations: 47.44 KiB)
45.700 μs (484 allocations: 34.08 KiB)
mybench(2)
(I, r) = ([6, 8, 10, 12], [4, 4, 6, 8])
1.881 ms (46986 allocations: 9.27 MiB)
190.500 μs (913 allocations: 977.22 KiB)
150.300 μs (918 allocations: 159.94 KiB)
mybench(5)
(I, r) = ([15, 20, 25, 30], [10, 10, 15, 20])
436.854 ms (1802205 allocations: 2.08 GiB)
96.159 ms (2213 allocations: 185.85 MiB)
25.709 ms (2218 allocations: 3.60 MiB)
Without StaticArrays (using the commented-out plain-`Matrix` definition of `matrixarray`):
mybench(1)
(I, r) = ([3, 4, 5, 6], [2, 2, 3, 4])
176.700 μs (1110 allocations: 124.03 KiB)
73.400 μs (470 allocations: 52.27 KiB)
33.100 μs (806 allocations: 39.53 KiB)
mybench(2)
(I, r) = ([6, 8, 10, 12], [4, 4, 6, 8])
3.487 ms (17329 allocations: 4.36 MiB)
1.200 ms (6346 allocations: 1.32 MiB)
558.400 μs (11751 allocations: 551.25 KiB)
mybench(5)
(I, r) = ([15, 20, 25, 30], [10, 10, 15, 20])
392.522 ms (675103 allocations: 873.91 MiB)
195.947 ms (232913 allocations: 209.23 MiB)
63.163 ms (451167 allocations: 20.41 MiB)