これ,深さ2の2つのテンソルGとT間の Multilinear Product っていう積の計算なんですけど,こんなに for loop 重ねるしかないのでしょうか.なんかすごい遅くなるようなことをやっている気がして不安.
— yudai.jl (@physics303) August 5, 2022
定義はこの論文のdef5.を参照した.https://t.co/svPibRSlQ7 pic.twitter.com/IqrlkkGYyZ
"""
    mp1(G, T)

Contract the third axis of `G` with the first axis of `T` using explicit loops
(the "multilinear product" of two 3-way tensors).

For `G` of size `(R, I, Q)` and `T` of size `(Q, J, M)` the result `S` has size
`(R, I*J, M)` with

    S[r, (i-1)*J + j, m] = sum(G[r, i, q] * T[q, j, m] for q in 1:Q)

# Returns
An `Array` whose element type is `promote_type(Float64, eltype(G), eltype(T))`:
`Float64` for integer and real floating-point inputs up to 64 bits, and the wider
type (e.g. `ComplexF64`, `BigFloat`) when an input needs it.

# Throws
- `AssertionError("rank error")` if `size(G, 3) != size(T, 1)`.
"""
function mp1(G, T)
    @assert size(G, 3) == size(T, 1) "rank error"
    R, I, Q = size(G)
    _, J, M = size(T)
    # Keep Float64 as the floor (the historical result type) but widen when the
    # inputs need it, so complex or extended-precision data is not truncated.
    Tout = promote_type(Float64, eltype(G), eltype(T))
    S = zeros(Tout, R, I * J, M)
    # Julia arrays are column-major: keep the first index `r` innermost so that
    # both `G` and `S` are traversed contiguously. For every output element the
    # terms are still added in increasing `q`, starting from zero.
    for m in 1:M, j in 1:J, i in 1:I
        k = (i - 1) * J + j  # merged (i, j) index, `j` varying fastest
        for q in 1:Q
            t = T[q, j, m]
            for r in 1:R
                S[r, k, m] += G[r, i, q] * t
            end
        end
    end
    return S
end
mp1 (generic function with 1 method)
"""
    mp2(G, T)

Compute the same contraction as the loop-based version, but as a single
matrix product followed by an axis permutation.

`G` of size `(R, I, Q)` is flattened to `(R*I, Q)` and `T` of size `(Q, J, M)`
to `(Q, J*M)`. Their product is viewed as `(R, I, J, M)`, the `I` and `J` axes
are swapped, and the two are merged so that the result has size `(R, J*I, M)`
with `S[r, (i-1)*J + j, m] = sum(G[r, i, q] * T[q, j, m] for q in 1:Q)`.

# Throws
- `AssertionError("rank error")` if `size(G, 3) != size(T, 1)`.
"""
function mp2(G, T)
    @assert size(G, 3) == size(T, 1) "rank error"
    nr, ni, _ = size(G)
    nq, nj, nm = size(T)
    lhs = reshape(G, nr * ni, nq)      # rows index (r, i)
    rhs = reshape(T, nq, nj * nm)      # columns index (j, m)
    prod4 = reshape(lhs * rhs, nr, ni, nj, nm)
    # Swap the i and j axes so that, after merging, j varies fastest.
    swapped = permutedims(prod4, (1, 3, 2, 4))
    return reshape(swapped, nr, nj * ni, nm)
end
mp2 (generic function with 1 method)
using BenchmarkTools
# Benchmark both implementations on random tensors and check that they agree.
R, I, Q, J, M = (30, 40, 50, 60, 70)
G = rand(R, I, Q)
T = rand(Q, J, M)
# `$` interpolation keeps global-variable lookup out of the measured time.
@btime mp1($G, $T)
@btime mp2($G, $T)
isapprox(mp1(G, T), mp2(G, T))
265.991 ms (2 allocations: 38.45 MiB)
22.683 ms (13 allocations: 76.90 MiB)
true