In this post, we will go through a basic PyTorch tutorial. It summarizes the PyTorch tutorial from Berkeley's CS285 course, "Deep Reinforcement Learning".
This is a PyTorch tutorial for UC Berkeley's CS285. There are already a bunch of great tutorials that you might want to check out, in particular this tutorial. This tutorial covers a lot of the same material. If you're already familiar with PyTorch basics, you might want to skip ahead to the PyTorch Advanced section.
First, let's import some things and define a useful plotting function:
import matplotlib.pyplot as plt
import torch
import numpy as np
def plot(xs, ys, xlim=(-3, 3), ylim=(-3, 3)):
    fig, ax = plt.subplots()
    ax.plot(xs, ys, linewidth=5)
    # ax.set_aspect('equal')
    ax.grid(True, which='both')
    # Draw the x and y axes through the origin
    ax.axhline(y=0, color='k')
    ax.axvline(x=0, color='k')
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
NumPy arrays are objects that let you store and manipulate matrices:
shape = (2, 3)
x = np.zeros(shape)
y = np.ones(shape)
z = x + y
print(x)
print("+")
print(y)
print("=")
print(z)
[[0. 0. 0.]
 [0. 0. 0.]]
+
[[1. 1. 1.]
 [1. 1. 1.]]
=
[[1. 1. 1.]
 [1. 1. 1.]]
print(z.sum())
6.0
print(z[0, 1:])
[1. 1.]
PyTorch is built around tensors, which play a similar role to NumPy arrays. You can do many of the same operations in PyTorch:
x = torch.zeros(shape)
y = torch.ones(shape)
z = x + y
print(x)
print("+")
print(y)
print("=")
print(z)
tensor([[0., 0., 0.],
        [0., 0., 0.]])
+
tensor([[1., 1., 1.],
        [1., 1., 1.]])
=
tensor([[1., 1., 1.],
        [1., 1., 1.]])
Many operations have alternative syntax that accomplishes the same thing:
torch.add(x, y)
tensor([[1., 1., 1.],
        [1., 1., 1.]])
z.min()
tensor(1.)
z[1:, 0]
tensor([1.])
torch.sum(z)
tensor(6.)
Functions that reduce dimensions will, by default, reduce over all dimensions unless a dimension is specified:
torch.sum(z, dim=1)
tensor([3., 3.])
torch.sum(z, dim=0)
tensor([2., 2., 2.])
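A related argument is keepdim=True. It keeps the reduced dimension with size 1 instead of dropping it, which is handy when you want to broadcast the result back against the original tensor, for example to normalize each row. A small sketch (the variable names are our own):

```python
import torch

t = torch.ones((2, 3))

# Reducing over dim=1 removes that dimension: shape (2,)
row_sums = torch.sum(t, dim=1)

# keepdim=True keeps it as a size-1 dimension: shape (2, 1)
row_sums_kept = torch.sum(t, dim=1, keepdim=True)

# The (2, 1) result broadcasts against the (2, 3) tensor,
# so every row is divided by its own sum
normalized = t / row_sums_kept

print(row_sums.shape)       # torch.Size([2])
print(row_sums_kept.shape)  # torch.Size([2, 1])
print(normalized)
```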
Like NumPy, PyTorch will try to broadcast operations between tensors of different shapes:
x = torch.ones((3, 1))
y = torch.ones((1, 3))
z = x + y
print(x)
print("+")
print(y)
print("=")
print(z)
tensor([[1.],
        [1.],
        [1.]])
+
tensor([[1., 1., 1.]])
=
tensor([[2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]])
Operations whose names end with an underscore are in-place: they modify the tensor directly instead of returning a new one. Use them sparingly, as they can easily lead to bugs.
print(z)
z.zero_()
print(z)
tensor([[2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]])
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])
z.add_(5)
print(z)
tensor([[5., 5., 5.],
        [5., 5., 5.],
        [5., 5., 5.]])
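Here is a small sketch of why in-place operations are easy to get wrong. Slicing a tensor returns a view that shares memory with the original, so an in-place operation on the slice silently changes the original as well (the tensors a, b, c, d are just for illustration):

```python
import torch

a = torch.ones(4)
b = a[:2]          # b is a view into a's memory, not a copy

b.zero_()          # in-place: also overwrites the first two entries of a
print(a)           # tensor([0., 0., 1., 1.])

# An out-of-place operation allocates a new tensor and leaves c untouched
c = torch.ones(4)
d = c[:2] * 0
print(c)           # tensor([1., 1., 1., 1.])
```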
x_np = np.random.randn(*shape)
print(x_np)
[[-0.13298262  0.72862306  0.95932965]
 [ 2.24109395  0.53208287 -0.42554932]]
Converting from NumPy to PyTorch is easy:
x = torch.from_numpy(x_np)
print(x)
tensor([[-0.1330,  0.7286,  0.9593],
        [ 2.2411,  0.5321, -0.4255]], dtype=torch.float64)
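By default, NumPy creates floating-point arrays as float64, and torch.from_numpy keeps that dtype. You'll probably want to convert them to float32, since float32 is PyTorch's default floating-point type and the type of most tensors in PyTorch.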
x = torch.from_numpy(x_np).to(torch.float32)
print(x)
tensor([[-0.1330,  0.7286,  0.9593],
        [ 2.2411,  0.5321, -0.4255]])
Converting from PyTorch back to NumPy is also easy:
print(x.numpy())
[[-0.13298263  0.7286231   0.95932966]
 [ 2.2410939   0.53208286 -0.42554933]]
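One detail worth knowing is that torch.from_numpy returns a tensor that shares memory with the NumPy array, while a dtype conversion such as .to(torch.float32) makes a copy. A small sketch (arr, shared, and copied are illustrative names):

```python
import numpy as np
import torch

arr = np.zeros(3)                                   # float64 array
shared = torch.from_numpy(arr)                      # float64 tensor sharing arr's memory
copied = torch.from_numpy(arr).to(torch.float32)    # dtype change -> new memory

arr[0] = 7.0                                        # modify the NumPy array in place

print(shared)   # tensor([7., 0., 0.], dtype=torch.float64)
print(copied)   # tensor([0., 0., 0.])
```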
torch.cuda.is_available()
True
The code below errors out because both tensors need to be on the same device.
device = torch.device("cuda")
x = torch.zeros(shape)
y = torch.ones(shape, device=device)
z = x + y
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-20-7738a425635e> in <module>
      2 x = torch.zeros(shape)
      3 y = torch.ones(shape, device=device)
----> 4 z = x + y

RuntimeError: expected device cpu but got device cuda:0
You can move a tensor to the GPU with the .to() method:
x = x.to(device)
z = x + y
print(z)
This code also errors out, because you can't convert a tensor on the GPU into a NumPy array directly:
z.numpy()  # raises a TypeError: a CUDA tensor has to be copied to CPU memory first
First you need to move it to the CPU:
z_cpu = z.to('cpu')  # z.cpu() does the same
z_cpu.numpy()
array([[1., 1., 1.],
       [1., 1., 1.]], dtype=float32)
The CS285 homework code provides helpers for these conversions in cs285.infrastructure.pytorch_util. Outside the course repository the import fails, because the cs285 package isn't installed:
from cs285.infrastructure import pytorch_util as ptu
ptu.from_numpy(x_np)
ptu.to_numpy(x)
ModuleNotFoundError: No module named 'cs285'
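The actual cs285 helpers aren't shown here. To give a rough idea of what such helpers might look like, here is a hypothetical sketch that reuses the from_numpy/to_numpy names from the cell above. The implementation and the default_device variable are our own assumptions, not the course code. The sketch uses the GPU when one is available and falls back to the CPU otherwise, and it bundles the float32 conversion and the GPU-to-CPU copy discussed above.

```python
import numpy as np
import torch

# Hypothetical module-level device: GPU if present, otherwise CPU
default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def from_numpy(array: np.ndarray) -> torch.Tensor:
    """Convert a NumPy array to a float32 tensor on default_device."""
    return torch.from_numpy(array).float().to(default_device)


def to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Detach from the autograd graph, copy to the CPU, and convert to NumPy."""
    return tensor.detach().to("cpu").numpy()


arr = np.random.randn(2, 3)
t = from_numpy(arr)
print(t.dtype, t.device)
print(to_numpy(t).shape)  # (2, 3)
```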
PyTorch also provides the usual activation functions. Here are ReLU, tanh, and SELU:
xs = torch.linspace(-3, 3, 100)
ys = torch.relu(xs)
plot(xs.numpy(), ys.numpy())
xs = torch.linspace(-3, 3, 100)
ys = torch.tanh(xs)
plot(xs.numpy(), ys.numpy())
xs = torch.linspace(-3, 3, 100)
ys = torch.selu(xs)
plot(xs.numpy(), ys.numpy())
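For reference, ReLU and SELU have simple closed forms: ReLU(x) = max(0, x) and SELU(x) = scale * (max(0, x) + min(0, alpha * (exp(x) - 1))). The constants alpha ≈ 1.6733 and scale ≈ 1.0507 are the values listed in the torch.nn.SELU documentation. The sketch below re-implements both functions by hand and checks them against PyTorch's versions:

```python
import torch

# Fixed SELU constants, as listed in the torch.nn.SELU documentation
ALPHA = 1.6732632423543772848170429916717
SCALE = 1.0507009873554804934193349852946

xs = torch.linspace(-3, 3, 100)

# ReLU: keep positive values, zero out negative ones
relu_manual = torch.clamp(xs, min=0)

# SELU: scaled identity for x > 0, scaled exponential curve for x < 0
selu_manual = SCALE * (torch.clamp(xs, min=0)
                       + torch.clamp(ALPHA * (torch.exp(xs) - 1), max=0))

print(torch.allclose(relu_manual, torch.relu(xs)))             # True
print(torch.allclose(selu_manual, torch.selu(xs), atol=1e-6))  # True
```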
Given some loss function $$L(\vec x, \vec y) = ||2 \vec x + \vec y||_2^2$$ we want to evaluate $$\frac{\partial L}{\partial \vec x}$$ and $$\frac{\partial L}{\partial \vec y}$$
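Before handing this to PyTorch, it helps to work out the answer by hand so that we can check it. Write $$u_i = 2x_i + y_i$$, so that $$L = \sum_i u_i^2$$. The chain rule then gives, component by component:

$$\frac{\partial L}{\partial x_i} = 2u_i \cdot \frac{\partial u_i}{\partial x_i} = 4(2x_i + y_i), \qquad \frac{\partial L}{\partial y_i} = 2u_i \cdot \frac{\partial u_i}{\partial y_i} = 2(2x_i + y_i)$$

For $$\vec x = (1, 2, 3)$$ and $$\vec y = (1, 1, 1)$$, we have $$2\vec x + \vec y = (3, 5, 7)$$. Hence $$\frac{\partial L}{\partial \vec x} = (12, 20, 28)$$ and $$\frac{\partial L}{\partial \vec y} = (6, 10, 14)$$, which is what autograd should report.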
shape = (3, )
x = torch.tensor([1., 2, 3], requires_grad=True)
y = torch.ones(shape, requires_grad=True)
PyTorch makes this easy by having tensors keep track of their data...
x.data
tensor([1., 2., 3.])
...and their gradient:
print(x.grad)
None
However, right now x has no gradient, because we haven't yet defined a loss to differentiate with respect to x.
Below, we define the loss.
loss = ((2 * x + y)**2).sum()
print(loss)
tensor(83., grad_fn=<SumBackward0>)
And we perform back-propagation by calling backward() on it:
loss.backward()
Now we see that the gradients are populated!
print(x.grad)
print(y.grad)
tensor([12., 20., 28.]) tensor([ 6., 10., 14.])
Gradients accumulate, so if you call backward twice...
loss = ((2 * x + y)**2).sum()
loss.backward()
...you'll get twice the gradient.
print(x.grad)
print(y.grad)
tensor([24., 40., 56.]) tensor([12., 20., 28.])
This accumulation makes it easy to add gradients from different losses, which might not even use the same parameters. For example, this loss is only a function of x...
other_loss = (x**2).sum()
other_loss.backward()
...and so only x.grad changes.
print(x.grad)
print(y.grad)
tensor([26., 44., 62.]) tensor([12., 20., 28.])
If you don't specify requires_grad=True, the gradient will always be None:
shape = (3, )
x = torch.tensor([1., 2, 3], requires_grad=True)
y = torch.ones(shape)
loss = ((2 * x + y)**2).sum()
loss.backward()
print(x.grad)
print(y.grad)
tensor([12., 20., 28.]) None
You can also turn on requires_grad after creating a tensor:
shape = (3, )
x = torch.tensor([1., 2, 3], requires_grad=True)
y = torch.ones(shape)
y.requires_grad = True
loss = ((2 * x + y)**2).sum()
loss.backward()
print(x.grad)
print(y.grad)
tensor([12., 20., 28.]) tensor([ 6., 10., 14.])
You can cut a gradient by calling y.detach(), which returns a new tensor with requires_grad=False that shares y's data. Note that detach is not an in-place operation: y itself still requires grad and still receives a gradient.
shape = (3, )
x = torch.tensor([1., 2, 3], requires_grad=True)
y = torch.ones(shape, requires_grad=True)
y_detached = y.detach()
loss = ((2 * x + y)**2).sum()
loss.backward()
print(x.grad)
print(y.grad)
tensor([12., 20., 28.]) tensor([ 6., 10., 14.])
print(y_detached)
tensor([1., 1., 1.])
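By contrast, if the loss is built from y_detached instead of y, no gradient flows back to y at all. A minimal sketch:

```python
import torch

x = torch.tensor([1., 2, 3], requires_grad=True)
y = torch.ones(3, requires_grad=True)
y_detached = y.detach()            # same values, but cut from the graph

loss = ((2 * x + y_detached)**2).sum()
loss.backward()

print(x.grad)                      # tensor([12., 20., 28.])
print(y.grad)                      # None: y never took part in the graph
print(y_detached.requires_grad)    # False
```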
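Gradients are only stored for leaf tensors, that is, tensors you create directly (like x and y) rather than results computed from them. Below, z is computed from x, so it is a non-leaf tensor:
shape = (3, )
x = torch.tensor([1., 2, 3], requires_grad=True)
y = torch.ones(shape, requires_grad=True)
z = 2 * x  # non-leaf: z already requires grad because x does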
loss = ((2 * z + y)**2).sum()
loss.backward()
print(x.grad)
print(y.grad)
tensor([ 40., 72., 104.]) tensor([10., 18., 26.])
z.grad
UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.
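As the warning suggests, if you do want the gradient of an intermediate tensor such as z, call retain_grad() on it before calling backward. Here is a sketch of the same computation:

```python
import torch

x = torch.tensor([1., 2, 3], requires_grad=True)
y = torch.ones(3, requires_grad=True)
z = 2 * x
z.retain_grad()          # keep dL/dz even though z is not a leaf

loss = ((2 * z + y)**2).sum()
loss.backward()

print(z.is_leaf)         # False
print(z.grad)            # tensor([20., 36., 52.]), i.e. 4 * (2z + y)
print(x.grad)            # tensor([ 40.,  72., 104.]), i.e. 2 * z.grad
```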
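Note that torch.from_numpy keeps the NumPy dtype, so an integer array becomes an integer tensor. NumPy's default integer type depends on the platform (and the NumPy version). This was run on Windows, where it came out as int32; on Linux you would get int64.
torch.from_numpy(np.array([1,2]))
tensor([1, 2], dtype=torch.int32)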
Question: is there any difference between p.data.add_(-0.001 * p.grad) and p.data += -0.001 * p.grad?
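One way to answer this is to check empirically. In the sketch below, two stand-in parameters p1 and p2 are updated with the two forms. Both update the parameter's data in place without allocating new memory, and they produce the same values. In this use, += on a tensor performs the same in-place add.

```python
import torch

# Two identical stand-in parameters with a gradient already populated
p1 = torch.nn.Parameter(torch.tensor([1.0, 2.0]))
p1.grad = torch.tensor([10.0, 20.0])
p2 = torch.nn.Parameter(torch.tensor([1.0, 2.0]))
p2.grad = torch.tensor([10.0, 20.0])

ptr1, ptr2 = p1.data_ptr(), p2.data_ptr()

p1.data.add_(-0.001 * p1.grad)   # explicit in-place add
p2.data += -0.001 * p2.grad      # augmented assignment, also in place

print(p1.data, p2.data)                              # same values
print(p1.data_ptr() == ptr1, p2.data_ptr() == ptr2)  # True True: same memory
```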
nn.Modules represent the building blocks of a computation graph.
For example, in typical PyTorch code, each convolution block above is its own module, each fully connected block is a module, and the whole network itself is also a module.
Modules can contain modules within them.
The layers in torch.nn (nn.Linear, nn.Conv2d, and so on) are subclasses of nn.Module.
Below is an example definition of a module:
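import pdb
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    def __init__(self, input_size, output_size):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(input_size, 32)
        self.fc2 = nn.Linear(32, 32)
        self.fc3 = nn.Linear(32, output_size)
    def forward(self, x):
        # Uncomment to pause here in the Python debugger and step through forward
        # pdb.set_trace()
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    def hook(self, gradient):
        # Not used below; registered on a tensor with tensor.register_hook(self.hook),
        # this would double the gradient flowing through that tensor
        return 2*gradient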
The main function that you need to implement is the forward function.
Otherwise, it's a normal Python object:
net = Net(input_size=1, output_size=1)
print(net)
Net( (fc1): Linear(in_features=1, out_features=32, bias=True) (fc2): Linear(in_features=32, out_features=32, bias=True) (fc3): Linear(in_features=32, out_features=1, bias=True) )
Here we create some dummy input. The first dimension will be the batch dimension.
x = torch.linspace(-5, 5, 100).view(100, 1)
print(x.shape)
torch.Size([100, 1])
To evaluate a neural network on some input, pass the input through the module by calling the module directly, as in net(x). In particular, don't call net.forward(x).
y = net(x)
print(y.shape)
torch.Size([100, 1])
(If pdb.set_trace() is active in forward, this call pauses in the debugger instead. There you can step through forward line by line and inspect intermediate values: for example, the input x has shape torch.Size([100, 1]), and the activations after F.relu(self.fc1(x)) have shape torch.Size([100, 32]).)
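One reason to call the module directly rather than net.forward(x): net(x) goes through nn.Module.__call__, which runs any hooks registered on the module around forward, while calling forward directly skips them. Here is a small sketch with a forward hook that just records calls. The calls list, the count_calls function, and the layer and inp names are our own illustration.

```python
import torch
import torch.nn as nn

layer = nn.Linear(1, 1)
calls = []

# A forward hook receives (module, inputs, output) after forward runs
def count_calls(module, inputs, output):
    calls.append(output.shape)

layer.register_forward_hook(count_calls)

inp = torch.zeros(5, 1)
layer(inp)           # goes through __call__: the hook runs
layer.forward(inp)   # bypasses __call__: the hook does not run

print(len(calls))    # 1
```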
Let's visualize what the network looks like:
plot(x.detach().numpy(), y.detach().numpy(), ylim=(-1, 1), xlim=(-5, 5))
The network keeps track of all its parameters and their gradients! We haven't called backward yet, so the gradients are still None:
print(net.fc1.bias.grad)
None
In the __init__ function, any variable that you assign to self that is also a module will automatically be added as a sub-module. The parameters of a module (and all of its sub-modules) can be accessed through the parameters() function:
for p in net.parameters():
    print(p.shape)
torch.Size([32, 1])
torch.Size([32])
torch.Size([32, 32])
torch.Size([32])
torch.Size([1, 32])
torch.Size([1])
WARNING: if you want to have a list of modules, use
def __init__(self, network1, network2):
    super().__init__()
    self.list = nn.ModuleList([network1, network2])
and not
def __init__(self, network1, network2):
    super().__init__()
    self.list = [network1, network2]
In the latter case, network1 and network2 won't be registered as sub-modules, so their parameters won't show up in parameters().
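The difference shows up directly in parameters(). A sketch with two hypothetical container classes:

```python
import torch.nn as nn

class WithModuleList(nn.Module):
    def __init__(self, network1, network2):
        super().__init__()
        self.list = nn.ModuleList([network1, network2])  # registered as sub-modules

class WithPlainList(nn.Module):
    def __init__(self, network1, network2):
        super().__init__()
        self.list = [network1, network2]                 # invisible to nn.Module

with_module_list = WithModuleList(nn.Linear(2, 2), nn.Linear(2, 2))
with_plain_list = WithPlainList(nn.Linear(2, 2), nn.Linear(2, 2))

# Each nn.Linear has a weight and a bias
print(len(list(with_module_list.parameters())))  # 4
print(len(list(with_plain_list.parameters())))   # 0: an optimizer would never see them
```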
The output of the module is just a tensor. We can perform operations on the tensor like before to automatically compute derivatives. For example, below, we minimize the sum-of-squares of the output.
loss = (y**2).sum()
loss.backward()
We can manually update the parameters by adding the gradient times a negative learning rate, and then zero out the gradients to prevent them from accumulating:
for p in net.parameters():
    p.data.add_(-0.001 * p.grad)
    p.grad.data.zero_()
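An equivalent way to write this manual update without touching .data is to wrap it in torch.no_grad(), which tells autograd not to record the update itself. A minimal sketch on a tiny stand-in network (toy_net and inputs are our own names):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
toy_net = nn.Linear(1, 1)
inputs = torch.linspace(-5, 5, 100).view(100, 1)

loss = (toy_net(inputs)**2).sum()
loss.backward()
old_loss = loss.item()

with torch.no_grad():                 # the update is not tracked by autograd
    for p in toy_net.parameters():
        p -= 0.001 * p.grad           # in-place gradient step
        p.grad.zero_()                # reset to prevent accumulation

new_loss = (toy_net(inputs)**2).sum()
print(old_loss, new_loss.item())      # the loss goes down after one step
```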
And we can do this in a loop to train our network!
class Net2(nn.Module):
    def __init__(self, input_size, output_size):
        super(Net2, self).__init__()
        self.fc1 = nn.Linear(input_size, 32)
        self.fc2 = nn.Linear(32, 32)
        self.fc3 = nn.Linear(32, output_size)
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    def hook(self, gradient):
        return 2*gradient
net = Net2(input_size=1, output_size=1)
for _ in range(100):
    y = net(x)
    loss = (y**2).sum()
    loss.backward()
    for p in net.parameters():
        p.data.add_(- 0.001 * p.grad)
        p.grad.data.zero_()
Sure enough, our network learns to output (nearly) zero everywhere:
plot(x.detach().numpy(), y.detach().numpy(), ylim=(-1, 1), xlim=(-5, 5))
PyTorch has a bunch of built-in loss functions, which are just other modules that you can pass your data through. Here we fit sin(x) with the smooth L1 loss:
y_target = torch.sin(x)
loss_fn = nn.SmoothL1Loss()
for _ in range(1000):
    y = net(x)
    loss = loss_fn(y, y_target)
    loss.backward()
    for p in net.parameters():
        p.data.add_(- 0.001 * p.grad)
        p.grad.data.zero_()
plot(x.detach().numpy(), y.detach().numpy(), ylim=(-1, 1), xlim=(-5, 5))
We can use fancier optimizers from the optim package:
from torch import optim
net = Net2(input_size=1, output_size=1)
optimizer = optim.Adam(net.parameters(), lr=1e-3)
x = torch.linspace(-5, 5, 100).view(-1, 1)
y = net(x)
y_target = torch.sin(x)
loss_fn = nn.MSELoss()
Here's the network before training
plot(x.detach().numpy(), y.detach().numpy(), ylim=(-1, 1), xlim=(-5, 5))
And here's how you can use the optimizer to train the network. Note that we call optimizer.zero_grad() before calling loss.backward(), and then we just call optimizer.step(). This step function takes care of updating all the parameters that were passed to the optimizer's constructor.
for _ in range(100):
    y = net(x)
    loss = loss_fn(y, y_target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
And we see that this trained a network quite well
plot(x.detach().numpy(), y.detach().numpy(), ylim=(-1, 1), xlim=(-5, 5))
We're going to train two separate neural networks to solve a prediction task: $$f_\theta(x) \approx y$$. First, we generate the data:
N = 100
d = 1
X = np.random.randn(N, 1)
Y = X * 2 + 3 + np.random.randn(N, 1)
plt.scatter(X, Y)
plt.show()
print(X.shape, Y.shape)
(100, 1) (100, 1)
Next, we convert the data to float32 PyTorch tensors and define the loss function:
X_pt = torch.from_numpy(X).to(torch.float32)
Y_pt = torch.from_numpy(Y).to(torch.float32)
loss_fn = nn.MSELoss()
Define the training loop:
def train(net: nn.Module):
    optimizer = optim.SGD(net.parameters(), lr=1e-2)
    losses = []
    for _ in range(100):
        Y_hat_pt = net(X_pt)
        loss = loss_fn(Y_hat_pt, Y_pt)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Record the scalar loss value for plotting
        losses.append(loss.item())
    return np.array(losses)
Let's test two different networks
linear_network = nn.Linear(1, 1)
linear_losses = train(linear_network)
non_linear_network = Net2(1, 1)
non_linear_losses = train(non_linear_network)
and plot the losses and predictions.
plt.figure()
t = np.arange(len(linear_losses))
plt.plot(t, linear_losses, t, non_linear_losses)
plt.legend(['linear', 'non_linear'])
plt.show()
x_probe = torch.linspace(X.min(), X.max(), 300).reshape(-1, 1)
y_lin = linear_network(x_probe).detach().numpy()
y_non_lin = non_linear_network(x_probe).detach().numpy()
plt.scatter(x_probe.numpy(), y_lin, s=3)
plt.scatter(x_probe.numpy(), y_non_lin, s=3)
plt.scatter(X, Y)
plt.legend(['linear', 'non_linear', 'data'])
plt.show()
PyTorch has a very convenient distributions package.
from torch import distributions
You create distributions by passing the parameters of the distribution.
mean = torch.zeros(1, requires_grad=True)
std = torch.ones(1, requires_grad=True)
gaussian = distributions.Normal(mean, std)
These distributions are instances of the more general Distribution class, which you can read more about here.
print(gaussian)
print(isinstance(gaussian, distributions.Distribution))
Normal(loc: tensor([0.], requires_grad=True), scale: tensor([1.], requires_grad=True)) True
sample = gaussian.sample((1,))
print(sample)
tensor([[0.9333]])
gaussian.log_prob(sample)
tensor([[-1.3545]], grad_fn=<SubBackward0>)
The log probability depends on the parameters of the distribution. So, calling backward on a loss that depends on log_prob will back-propagate gradients into the parameters of the distribution.
NOTE: this won't back-propagate through the samples (the "reparameterization trick") unless you use rsample, which is only implemented for some distributions.
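To see the difference, compare sample() and rsample() on the same Normal. With rsample(), the sample is computed as mean + std * noise, so it stays connected to the graph; with sample(), it does not. A small sketch (mu, sigma, and normal are our own names):

```python
import torch
from torch import distributions

mu = torch.zeros(1, requires_grad=True)
sigma = torch.ones(1, requires_grad=True)
normal = distributions.Normal(mu, sigma)

plain = normal.sample((5,))      # no gradient path back to mu/sigma
reparam = normal.rsample((5,))   # differentiable with respect to mu and sigma

print(plain.requires_grad, reparam.requires_grad)  # False True

reparam.sum().backward()
print(mu.grad)  # tensor([5.]): each of the 5 samples shifts one-for-one with the mean
```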
At 5:18 in the video, the loss is set to -log_prob to maximize the probability of an event. That makes sense to me, because the higher the probability, the smaller the loss. Now, if we want to incorporate a reward into this loss function, I think people usually use loss = -log_prob * reward. But that means the higher the reward, the higher the loss. Is this because we want to push the probability of the event up harder (via a higher loss) when the reward is higher? Or should we have a lower loss with a higher reward (loss = log_prob * reward)?
loss = - gaussian.log_prob(sample).sum()
loss.backward()
mean.grad
tensor([-0.9333])
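On the question above: the first form is the standard one. In policy-gradient methods such as REINFORCE, the surrogate loss is -(log_prob * reward), summed or averaged over samples. Its value on its own isn't meaningful. What matters is its gradient, which raises the log-probability of each sample in proportion to its reward, so high-reward samples get pushed up the most. Using +log_prob * reward would instead lower the probability of rewarded samples. Here is a minimal sketch with two hand-picked samples and made-up rewards:

```python
import torch
from torch import distributions

mu = torch.zeros(1, requires_grad=True)
normal = distributions.Normal(mu, torch.ones(1))

# Made-up data: two fixed samples and their rewards
samples = torch.tensor([[1.0], [-1.0]])
rewards = torch.tensor([[1.0], [0.0]])

# Score-function surrogate loss: -log_prob weighted by reward
loss = -(normal.log_prob(samples) * rewards).sum()
loss.backward()

# d/d(mu) of -log_prob(s) is -(s - mu) / std**2; only the rewarded sample counts
print(mu.grad)  # tensor([-1.])

# A gradient-descent step therefore moves the mean toward the rewarded sample
with torch.no_grad():
    mu -= 0.1 * mu.grad
print(mu)       # tensor([0.1000], requires_grad=True)
```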
The distributions also support batch operations. In this case, all the operations (sample, log_prob, etc.) are batch-wise.
mean = torch.zeros(10)
std = torch.ones(10)
gaussian = distributions.Normal(mean, std)
gaussian
Normal(loc: torch.Size([10]), scale: torch.Size([10]))
sample = gaussian.sample((1,))
print(sample)
tensor([[-1.4607, -0.5094, -1.2975, -1.6653, -0.1941, 1.3166, 0.5912, -0.6193, -0.2724, -0.4517]])
gaussian.log_prob(sample)
tensor([[-1.9857, -1.0487, -1.7606, -2.3055, -0.9378, -1.7857, -1.0937, -1.1107, -0.9560, -1.0210]])
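Here log_prob returns one value per dimension. Sometimes the 10 numbers represent a single 10-dimensional sample, for example a continuous action drawn from a diagonal Gaussian. In that case you usually want the joint log-probability, which is the sum over the last dimension. distributions.Independent reinterprets batch dimensions as event dimensions and does this sum for you. A sketch:

```python
import torch
from torch import distributions

loc = torch.zeros(10)
scale = torch.ones(10)

per_dim = distributions.Normal(loc, scale)        # batch_shape (10,), event_shape ()
joint = distributions.Independent(per_dim, 1)     # batch_shape (), event_shape (10,)

sample = joint.sample((1,))                       # shape (1, 10)

print(per_dim.log_prob(sample).shape)  # torch.Size([1, 10])
print(joint.log_prob(sample).shape)    # torch.Size([1])
print(torch.allclose(joint.log_prob(sample),
                     per_dim.log_prob(sample).sum(-1)))  # True
```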
There are many other distributions, such as the multivariate normal:
mean = torch.zeros(2)
covariance = torch.tensor(
[[1, 0.8],
[0.8, 1]]
)
gaussian = distributions.MultivariateNormal(mean, covariance)
gaussian.sample((1,))
tensor([[-0.3727, -0.0504]])
samples = gaussian.sample((500,))
plt.scatter(samples[:, 0].numpy(), samples[:, 1].numpy())
NOTE: if you want to use a batch of MultivariateNormal distributions, each with its own covariance, you'll need to construct a batch of covariance matrices (i.e., shape [BATCH_SIZE, DIM, DIM]).
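A small sketch of such a batch: four 2-D Gaussians, each with its own covariance. The values are arbitrary; each covariance here is just the matrix from above scaled by 1, 2, 3, or 4.

```python
import torch
from torch import distributions

BATCH_SIZE, DIM = 4, 2

means = torch.zeros(BATCH_SIZE, DIM)
base = torch.tensor([[1.0, 0.8],
                     [0.8, 1.0]])
# One covariance matrix per batch element: shape (BATCH_SIZE, DIM, DIM)
scales = torch.arange(1, BATCH_SIZE + 1).view(-1, 1, 1)
covariances = base.expand(BATCH_SIZE, DIM, DIM) * scales

mvn_batch = distributions.MultivariateNormal(means, covariance_matrix=covariances)

print(mvn_batch.batch_shape, mvn_batch.event_shape)  # torch.Size([4]) torch.Size([2])
samples = mvn_batch.sample()                         # shape (4, 2): one per distribution
print(mvn_batch.log_prob(samples).shape)             # torch.Size([4])
```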
from torch import distributions
Another useful distribution is the categorical distribution.
probs = torch.tensor([0.1, 0.2, 0.7])
dist = distributions.Categorical(probs=probs)
sample = dist.sample([20])
print(sample)
tensor([2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 1, 2, 2, 2])
dist.log_prob(sample)
tensor([-0.3567, -0.3567, -0.3567, -1.6094, -1.6094, -0.3567, -0.3567, -0.3567, -0.3567, -0.3567, -2.3026, -0.3567, -0.3567, -0.3567, -0.3567, -0.3567, -1.6094, -0.3567, -0.3567, -0.3567])
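In practice, a network usually outputs unnormalized logits rather than probabilities. Categorical also accepts logits and normalizes them with a softmax internally. Here is a small sketch showing that the two parameterizations agree:

```python
import torch
from torch import distributions

cat_probs = torch.tensor([0.1, 0.2, 0.7])
logits = torch.log(cat_probs) + 3.0   # adding a constant does not change the softmax

dist_p = distributions.Categorical(probs=cat_probs)
dist_l = distributions.Categorical(logits=logits)

print(torch.allclose(dist_l.probs, dist_p.probs))  # True
actions = torch.tensor([0, 1, 2])
print(dist_l.log_prob(actions))  # approximately tensor([-2.3026, -1.6094, -0.3567])
```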
Typically, your network will output the parameters of a distribution:
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    def __init__(self, input_size, output_size):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(input_size, 32)
        self.fc2 = nn.Linear(32, 32)
        self.fc3 = nn.Linear(32, output_size)
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
mean_network = Net(1, 1)
x = torch.randn(100, 1)
mean = mean_network(x)
# Use the network's output (not the raw input x) as the mean
distribution = distributions.Normal(mean, scale=1)
If you want, your nn.Module can even return a distribution from its forward function!
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributions
class Net(nn.Module):
    def __init__(self, input_size, output_size):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(input_size, 32)
        self.fc2 = nn.Linear(32, 32)
        self.fc3 = nn.Linear(32, output_size)
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        # Use the network output as the mean of a unit-variance Gaussian
        return distributions.Normal(x, scale=1)
distribution_network = Net(1, 1)
x = torch.randn(100, 1)
distribution = distribution_network(x)
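The returned distribution behaves like any other, so you can sample from it and back-propagate through log_prob into the network's weights. The fixed scale=1 above is a simplification. A common variation, shown here as our own assumption rather than something from the original code, is a learnable, input-independent log standard deviation stored as an nn.Parameter. The class name GaussianNet and the variable names below are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributions


class GaussianNet(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 32)
        self.fc2 = nn.Linear(32, output_size)
        # Learnable log-std shared across inputs; exp() keeps the std positive
        self.log_std = nn.Parameter(torch.zeros(output_size))

    def forward(self, x):
        mean = self.fc2(F.relu(self.fc1(x)))
        return distributions.Normal(mean, self.log_std.exp())


policy = GaussianNet(1, 1)
obs = torch.randn(100, 1)
action_dist = policy(obs)

actions = action_dist.sample()                 # shape (100, 1), no gradient
loss = -action_dist.log_prob(actions).mean()   # e.g. a maximum-likelihood style loss
loss.backward()

print(actions.shape)                   # torch.Size([100, 1])
print(policy.log_std.grad is not None) # True: gradients reach the std parameter
```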
gather
This function will be useful for the DQN assignment. It allows you to index into arrays in a batch.
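The index tensor must have the same number of dimensions as the input. In the PyTorch version used for this post, its size also had to match the input's size in every dimension other than dim, so the following raises an error (recent versions only require index.size(d) <= input.size(d) for d != dim):
x = torch.arange(6).reshape(2, 3)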
y = torch.tensor([0, 1]).reshape(2, 1)
print(x)
print(y)
print(torch.gather(x, 0, y))
tensor([[0, 1, 2], [3, 4, 5]]) tensor([[0], [1]])
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-114-b33e25174688> in <module>
      3 print(x)
      4 print(y)
----> 5 print(torch.gather(x, 0, y))

RuntimeError: Size does not match at dimension 1 get 3 vs 1
x = torch.arange(6).reshape(2, 3)
y = torch.tensor([0, 1, 0]).reshape(1, 3)
print(x)
print(y)
print(torch.gather(x, 0, y))
tensor([[0, 1, 2], [3, 4, 5]]) tensor([[0, 1, 0]]) tensor([[0, 4, 2]])
For a 3-D tensor, the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
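In the DQN setting, a typical use is to pick out, for each row of a batch of Q-values, the value of the action that was actually taken. This is the dim == 1 case above. Here is a small sketch with made-up numbers (q_values and actions are illustrative names):

```python
import torch

# Q-values for a batch of 3 states and 2 actions: shape (3, 2)
q_values = torch.tensor([[1.0, 2.0],
                         [3.0, 4.0],
                         [5.0, 6.0]])
# Action taken in each state: shape (3,)
actions = torch.tensor([1, 0, 1])

# gather along dim=1 needs an index with the same number of dims: shape (3, 1)
# out[i][0] = q_values[i][actions[i]]
taken = torch.gather(q_values, 1, actions.unsqueeze(1)).squeeze(1)
print(taken)  # tensor([2., 3., 6.])
```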