Notebook
def index_select_ND(source: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Select the message features from ``source`` addressed by ``index``.

    Every entry of ``index`` is used as a row number into the first dimension
    of ``source``. The result keeps the shape of ``index`` and appends the
    trailing (feature) dimensions of ``source``.

    Args:
        source: Tensor whose first dimension is indexed, e.g. messages of
            shape ``(num_bonds, hidden_size)``.
        index: Tensor of atom or bond indices, e.g. of shape
            ``(num_atoms/num_bonds, max_num_bonds)``.

    Returns:
        Tensor of shape ``index.size() + source.size()[1:]``, e.g.
        ``(num_atoms/num_bonds, max_num_bonds, hidden_size)``.
    """
    index_size = index.size()  # (num_atoms/num_bonds, max_num_bonds)
    suffix_dim = source.size()[1:]  # (hidden_size,)
    final_size = index_size + suffix_dim  # (num_atoms/num_bonds, max_num_bonds, hidden_size)

    # index_select takes a 1-D index, so flatten first. reshape (rather than
    # view) also accepts a non-contiguous index tensor.
    target = source.index_select(dim=0, index=index.reshape(-1))  # (num_atoms/num_bonds * max_num_bonds, hidden_size)

    # Fold the flat result back into the shape of ``index`` plus the feature dims.
    return target.view(final_size)  # (num_atoms/num_bonds, max_num_bonds, hidden_size)


# Three rounds of message passing. Each round computes
#     m(a1 -> a2) = [sum_{a0 \in nei(a1)} m(a0 -> a1)] - m(a2 -> a1)
# where, in terms of the variables below,
#     m(a1 -> a2)                     is the new ``message``,
#     sum_{a0 \in nei(a1)} m(a0 -> a1) is ``a_message`` (= nei_a_message summed),
#     m(a2 -> a1)                     is ``rev_message``.
# NOTE(review): message, a2b, b2a, b2revb, W_h, act_func and input are defined
# elsewhere in the notebook; presumably a2b lists the bonds of each atom, b2a
# maps a bond to an atom, and b2revb maps a bond to its reverse bond - confirm
# against the cell that builds them.
for depth in range(3):
    nei_a_message = index_select_ND(message, a2b)  # num_atoms x max_num_bonds x hidden
    a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
    rev_message = message[b2revb]  # num_bonds x hidden
    message = a_message[b2a] - rev_message  # num_bonds x hidden

    message = W_h(message)
    # ``input`` is the notebook's own variable (it shadows the builtin); it is
    # added back to the transformed message before the activation.
    message = act_func(input + message)  # num_bonds x hidden_size