#Julia言語 generated function 版です。マクロ版よりも使い易いと思います。
— 黒木玄 Gen Kuroki (@genkuroki) August 22, 2022
ただし、こういうことに慣れているわけでもないので、不適切な書き方をしている可能性があります。
もしも不適切な書き方をしていたら教えて下さい>Juliaの偉い人達https://t.co/H5pnAFSw43 pic.twitter.com/dGWsIEOOZo
Generated function やマクロ (hygiene など) に間違いがあるかもしれない。
using TensorOperations, BenchmarkTools
function expr_Aj_Gprod(A, G, N)
is = Symbol.(:i, 1:N)
js = Symbol.(:j, 1:N)
ks = circshift(is, -1)
Gs = Expr.(:ref, Ref(G), 1:N)
Gsijk = Expr.(:ref, Gs, is, js, ks)
Gprod = Expr(:call, :*, Gsijk...)
Aj = Expr(:ref, A, js...)
Aj, Gprod
end
expr_Aj_Gprod(:A, :G, 4)
(:(A[j1, j2, j3, j4]), :((G[1])[i1, j1, i2] * (G[2])[i2, j2, i3] * (G[3])[i3, j3, i4] * (G[4])[i4, j4, i1]))
macro multr(G, N)
A, G = esc(gensym(:A)), esc(G)
#Core.println("macro: ", G) # Generated function では初回実行時にのみ呼び出されている
Aj, Gprod = expr_Aj_Gprod(A, G, N)
:(@tensor $Aj := $Gprod)
end
@macroexpand1 @multr(H, 4)
:(#= In[3]:5 =# @tensor ($(Expr(:escape, Symbol("##A#312"))))[j1, j2, j3, j4] := (($(Expr(:escape, :H)))[1])[i1, j1, i2] * (($(Expr(:escape, :H)))[2])[i2, j2, i3] * (($(Expr(:escape, :H)))[3])[i3, j3, i4] * (($(Expr(:escape, :H)))[4])[i4, j4, i1])
# 分岐があらかじめ列挙できる程度の数ならば良い選択肢だと思う
for N in 2:10
@eval multr_pre(G, ::Val{$N}) = @multr(G, $N)
end
multr_pre(G) = multr_pre(G, Val(length(G)))
multr_pre (generic function with 10 methods)
# 呼ばれてから構文を評価する
function multr_live(G, ::Val{N}) where N
eval(
quote
let G = $G # 配列を変数に入れ直したほうがわずかに速い
@multr(G, $N)
end
end
)
end
multr_live(G) = multr_live(G, Val(length(G)))
multr_live (generic function with 2 methods)
# @multr が gensym を含むため pure とは言えない?(引数の型からコンパイル済みのコードへの対応に注目した場合)
# TensorOperations.jl でも gensym を使っているようなので、 generated function を使うなら気にしても仕方がない?
@generated function multr_gen(G, ::Val{N}) where N
quote
#Core.println("generated: ", G) # シンボルではなく、配列が渡されている
@multr(G, $N)
end
end
multr_gen(G) = multr_gen(G, Val(length(G)))
multr_gen (generic function with 2 methods)
genG(is = [3, 4, 5, 6], js = [2, 3, 4, 5]) = randn.(is, js, circshift(is, -1))
# ミス検出のため変数名を G ではなく H にしている
H = genG();
@multr(H, 4) == multr_pre(H) == multr_live(H) == multr_gen(H)
true
@benchmark @multr($H, 4)
BenchmarkTools.Trial: 10000 samples with 1 evaluation. Range (min … max): 16.800 μs … 98.900 μs ┊ GC (min … max): 0.00% … 0.00% Time (median): 17.800 μs ┊ GC (median): 0.00% Time (mean ± σ): 18.360 μs ± 2.097 μs ┊ GC (mean ± σ): 0.00% ± 0.00% ▃▇██▇▅▃▁ ▂▂▃▅█████████▇▆▆▅▆▆▆▅▅▄▄▄▃▄▃▃▁▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂ ▄ 16.8 μs Histogram: frequency by time 22.5 μs < Memory estimate: 8.16 KiB, allocs estimate: 104.
@benchmark multr_pre($H)
BenchmarkTools.Trial: 10000 samples with 1 evaluation. Range (min … max): 17.000 μs … 3.966 ms ┊ GC (min … max): 0.00% … 98.55% Time (median): 18.000 μs ┊ GC (median): 0.00% Time (mean ± σ): 19.059 μs ± 39.600 μs ┊ GC (mean ± σ): 2.05% ± 0.99% ▁▅▇██▇▅▂▁ ▂▃▆█████████▆▅▅▅▆▆▅█▄▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁ ▃ 17 μs Histogram: frequency by time 23.2 μs < Memory estimate: 7.69 KiB, allocs estimate: 98.
@benchmark multr_live($H) # 遅い
BenchmarkTools.Trial: 2167 samples with 1 evaluation. Range (min … max): 1.825 ms … 9.972 ms ┊ GC (min … max): 0.00% … 75.46% Time (median): 2.075 ms ┊ GC (median): 0.00% Time (mean ± σ): 2.299 ms ± 888.111 μs ┊ GC (mean ± σ): 1.35% ± 5.01% ▆▆██▆▅▅▃▂▁ ▁ ██████████▇▆▆▄▄▄▁▄▄▁▄▄▁▁▁▄▁▁▁▁▄▅▄▅▄▁▁▅▆▆▆▆▄▆▄▇▇▆▆▅▅▄▆▆▆▅▅▆▄ █ 1.82 ms Histogram: log(frequency) by time 6.09 ms < Memory estimate: 243.52 KiB, allocs estimate: 4890.
@benchmark multr_gen($H)
BenchmarkTools.Trial: 10000 samples with 1 evaluation. Range (min … max): 17.100 μs … 4.916 ms ┊ GC (min … max): 0.00% … 98.70% Time (median): 18.200 μs ┊ GC (median): 0.00% Time (mean ± σ): 19.288 μs ± 49.047 μs ┊ GC (mean ± σ): 2.52% ± 0.99% ▂▅▇██▆▃▂ ▂ ▁▂▃▆█████████▆█▅▅▅▅▄▄▄▃▃▃▂▂▂▂▃▂▂▂▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃ 17.1 μs Histogram: frequency by time 23.4 μs < Memory estimate: 7.69 KiB, allocs estimate: 98.