#Julia言語 Also, even when the results are essentially the same, comparing the case
size.(H) = [(100, 6, 8), (8, 7, 6), (6, 5, 5), (5, 4, 100)]
with the case
size.(K) = [(5, 4, 100), (100, 6, 8), (8, 7, 6), (6, 5, 5)]
the latter is roughly 7 times faster, so this is another point to watch out for. https://t.co/H5pnAFSw43 pic.twitter.com/MhvlCzJkHA
— 黒木玄 Gen Kuroki (@genkuroki) August 22, 2022
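A minimal sketch of the comparison described in the tweet (my reconstruction; only the shapes are taken from the tweet, the rest is assumed): H and K hold the same four 3-tensors, just cyclically shifted, so @tensor's default contraction order starts from different factors and the cost can differ a lot.
using TensorOperations, BenchmarkTools

H = [randn(100, 6, 8), randn(8, 7, 6), randn(6, 5, 5), randn(5, 4, 100)]
K = circshift(H, 1)  # the same tensors, shifted so that K[1] == H[4]

# contract four 3-tensors arranged on a ring, leaving the j-legs open
contract4(G) = @tensor A[j1, j2, j3, j4] :=
    G[1][i1, j1, i2] * G[2][i2, j2, i3] * G[3][i3, j3, i4] * G[4][i4, j4, i1]

@btime contract4($H)  # the arrangement the tweet found to be slower
@btime contract4($K)  # the arrangement the tweet found to be roughly 7x faster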
using TensorOperations, BenchmarkTools
function expr_ijk(N)
is = Symbol.(:i, 1:N)
js = Symbol.(:j, 1:N)
ks = circshift(is, -1)
is, js, ks
end
function expr_Gprod(G, is, js, ks)
N = length(is)
Gs = Expr.(:ref, Ref(G), 1:N)
Gsijk = Expr.(:ref, Gs, is, js, ks)
Expr(:call, :*, Gsijk...)
end
function expr_substitution(A, G, N)
is, js, ks = expr_ijk(N)
Gprod = expr_Gprod(G, is, js, ks)
Aj = Expr(:ref, A, js...)
:($Aj := $Gprod)
end
expr_substitution (generic function with 1 method)
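For illustration (this example is not in the original notebook), the builder assembles the indexed assignment one would otherwise write by hand; for N = 2 it should yield:
expr_substitution(:A, :G, 2)
:(A[j1, j2] := G[1][i1, j1, i2] * G[2][i2, j2, i1])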
"""
# Example
`@multr_simple(G, 4)`
"""
macro multr_simple(G, N)
A, G = esc(gensym(:A)), esc(G)
subst = expr_substitution(A, G, N)
:(@tensor $subst)
end
@multr_simple (macro with 1 method)
genG(is = [100, 8, 6, 5], js = [6, 7, 5, 4]) = randn.(is, js, circshift(is, -1))
G = genG()
@macroexpand1 @multr_simple(G, 4)
:(#= In[3]:9 =# @tensor ($(Expr(:escape, Symbol("##A#312"))))[j1, j2, j3, j4] := (($(Expr(:escape, :G)))[1])[i1, j1, i2] * (($(Expr(:escape, :G)))[2])[i2, j2, i3] * (($(Expr(:escape, :G)))[3])[i3, j3, i4] * (($(Expr(:escape, :G)))[4])[i4, j4, i1])
"""
# Example
`@multr_order(G, (i3, i1, i2, i4))`
"""
macro multr_order(G, S)
A, G = esc(gensym(:A)), esc(G)
N = length(S.args)
subst = expr_substitution(A, G, N)
:(@tensor $subst order=$S)
end
@macroexpand1 @multr_order(G, (i1, i2, i3, i4))
:(#= In[5]:10 =# @tensor ($(Expr(:escape, Symbol("##A#313"))))[j1, j2, j3, j4] := (($(Expr(:escape, :G)))[1])[i1, j1, i2] * (($(Expr(:escape, :G)))[2])[i2, j2, i3] * (($(Expr(:escape, :G)))[3])[i3, j3, i4] * (($(Expr(:escape, :G)))[4])[i4, j4, i1] order = (i1, i2, i3, i4))
# Use the size of each dimension as the cost function
# This is just a first attempt; whether this cost function is the best choice is unknown
"""
# Example
```julia-repl
julia> expr_cost1(:i1, 100)
:(i1 => 100χ)
```
"""
expr_cost1(isym, n) = Expr.(:call, :(=>), isym, Expr.(:call, :*, n, :χ))
function expr_cost(isize, jsize)
N = length(isize)
is, js, _ = expr_ijk(N)
costs = expr_cost1.(Iterators.flatten((is, js)), Iterators.flatten((isize, jsize)))
Expr(:tuple, costs...)
end
expr_cost((100, 8, 6, 5), (6, 7, 5, 4))
:((i1 => 100χ, i2 => 8χ, i3 => 6χ, i4 => 5χ, j1 => 6χ, j2 => 7χ, j3 => 5χ, j4 => 4χ))
"""
# Example
`@multr_opt(G, ((100, 8, 6, 5), (6, 7, 5, 4)))`
"""
macro multr_opt(G, S)
#Core.println("macro: ", S)
A, G = esc(gensym(:A)), esc(G)
isize, jsize = S.args
N = length(isize.args)
cost = expr_cost(isize.args, jsize.args)
subst = expr_substitution(A, G, N)
:(@tensoropt $cost $subst)
end
@macroexpand1 @multr_opt(G, ((100, 8, 6, 5), (6, 7, 5, 4)))
:(#= In[7]:13 =# @tensoropt (i1 => 100χ, i2 => 8χ, i3 => 6χ, i4 => 5χ, j1 => 6χ, j2 => 7χ, j3 => 5χ, j4 => 4χ) ($(Expr(:escape, Symbol("##A#314"))))[j1, j2, j3, j4] := (($(Expr(:escape, :G)))[1])[i1, j1, i2] * (($(Expr(:escape, :G)))[2])[i2, j2, i3] * (($(Expr(:escape, :G)))[3])[i3, j3, i4] * (($(Expr(:escape, :G)))[4])[i4, j4, i1])
# When called from the generated function, S is passed as a Tuple
macro multr_opt(G, S::Tuple)
#Core.println("macro: ", S)
A, G = esc(gensym(:A)), esc(G)
isize, jsize = S
N = length(isize)
cost = expr_cost(isize, jsize)
subst = expr_substitution(A, G, N)
#Core.println("out: ", cost, subst)
:(@tensoropt $cost $subst)
end
#Expr(:macrocall, Symbol("@multr_opt"), nothing, :G, ((100, 8, 6, 5), (6, 7, 5, 4))) |> eval |> size
@multr_opt (macro with 2 methods)
# When the macro is called directly, S arrives as an Expr
# Evaluate S into a tuple and hand the work over to the generated function
# (passing it from one macro to another would leave it as an Expr?)
macro multr_opt(G, S::Expr)
:(multr($(esc(G)), Val($S)))
end
@macroexpand1 @multr_opt(G, ((100, 8, 6, 5), (6, 7, 5, 4)))
:(Main.multr(G, Main.Val(((100, 8, 6, 5), (6, 7, 5, 4)))))
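A small check (my own addition, not from the original) of why the two methods are needed: a macro call written out directly hands over raw syntax, while a value interpolated into a constructed expression, which is what the generated function below does, arrives as a literal tuple.
macro argtype(x)      # report the type of the macro argument itself
    QuoteNode(typeof(x))
end

@argtype ((100, 8, 6, 5), (6, 7, 5, 4))   # Expr: a direct call sees unevaluated syntax
S = ((100, 8, 6, 5), (6, 7, 5, 4))
eval(:(@argtype $S))                      # a Tuple type: an interpolated value stays a value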
function getijsize(G)
N = length(G)
isize = ntuple(i -> size(G[i], 1), N)
jsize = ntuple(i -> size(G[i], 2), N)
(isize, jsize)
end
getijsize(G)
((100, 8, 6, 5), (6, 7, 5, 4))
@generated function multr(G, ::Val{S}) where S
#Core.println("generated: ", S)
:(@multr_opt(G, $S)) # the macro is invoked with S as a tuple, not an Expr
end
multr(G) = multr(G, Val(getijsize(G)))
multr (generic function with 2 methods)
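# Note (my reading of the code above): Val(getijsize(G)) lifts the size tuples into the
# type domain, so the @generated multr expands @tensoropt, and hence runs the
# contraction-order optimization, once per size combination rather than on every call.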
multr(G) ≈ @multr_order(G, (i3, i1, i2, i4)) ≈ @multr_simple(G, 4)
true
# optimized by @tensoropt
# cost = length of each dimension
@benchmark multr($G)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  32.600 μs …  4.580 ms  ┊ GC (min … max): 0.00% … 98.00%
 Time  (median):     34.700 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   37.581 μs ± 75.503 μs  ┊ GC (mean ± σ):  3.41% ± 1.69%
 Memory estimate: 15.38 KiB, allocs estimate: 128.
@optimalcontractiontree((i1 => 100χ, i2 => 8χ, i3 => 6χ, i4 => 5χ, j1 => 6χ, j2 => 7χ, j3 => 5χ, j4 => 4χ),
A[j1, j2, j3, j4] := G[1][i1, j1, i2] * G[2][i2, j2, i3] * G[3][i3, j3, i4] * G[4][i4, j4, i1])
(Any[Any[3, 2], Any[4, 1]], 33600*χ^6 + 104400*χ^5 + 0*χ^4 + 0*χ^3 + 0*χ^2 + 0*χ + 0)
# the order optimized above: the tree [[3, 2], [4, 1]] contracts over i3 first, then i1, then i2 and i4
@benchmark @multr_order($G, (i3, i1, i2, i4))
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  32.500 μs …  4.551 ms  ┊ GC (min … max): 0.00% … 98.22%
 Time  (median):     34.300 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   36.858 μs ± 76.250 μs  ┊ GC (mean ± σ):  3.53% ± 1.70%
 Memory estimate: 15.25 KiB, allocs estimate: 126.
# default order of @tensor
@benchmark @multr_simple($G, 4)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  214.500 μs …  5.036 ms  ┊ GC (min … max): 0.00% … 93.01%
 Time  (median):     222.400 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   258.575 μs ± 109.763 μs  ┊ GC (mean ± σ):  0.35% ± 1.30%
 Memory estimate: 13.31 KiB, allocs estimate: 98.
# maybe the same order as @multr_simple
@benchmark @multr_order($G, (i2, i3, i4, i1))
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  214.400 μs …  4.769 ms  ┊ GC (min … max): 0.00% … 91.25%
 Time  (median):     222.600 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   259.347 μs ± 108.520 μs  ┊ GC (mean ± σ):  0.34% ± 1.30%
 Memory estimate: 13.31 KiB, allocs estimate: 98.