class CrossEntropyWithSoftmax(Node):
    def __init__(self, x, y):
        Node.__init__(self, [x, y])

    def _softmax(self, x):
        exp_x = np.exp(x)
        probs = exp_x / np.sum(exp_x, axis=1, keepdims=True)
        return probs

    # omit other functions ...

    def forward(self):
        probs = self._softmax(self.inbound_nodes[0].value)
        y = self.inbound_nodes[1].value
        self.cache[0] = probs
        self.cache[1] = y
        n = probs.shape[0]
        logprobs = -np.log(probs[range(n), y])
        self.value = np.sum(logprobs) / n

    # we know this is a loss so we can be a bit less generic here
    # should have 0 output nodes
    def backward(self):
        assert len(self.outbound_nodes) == 0
        self.gradients = {n: np.zeros_like(n.value) for n in self.inbound_nodes}
        # combined derivative of softmax and cross entropy
        gprobs = np.copy(self.cache[0])
        y = self.cache[1]
        n = gprobs.shape[0]
        gprobs[range(n), y] -= 1
        gprobs /= n
        # leave the gradient for the 2nd node all 0s, we don't care about the gradient
        # for the labels
        self.gradients[self.inbound_nodes[0]] = gprobs
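For context, here's a minimal, self-contained sketch of exercising the node outside the full graph. The Node and Input stubs below are assumptions standing in for the framework's real base classes (which presumably track inbound/outbound nodes and hold the cache); they would need to be defined before the class above.

import numpy as np

class Node:
    def __init__(self, inbound_nodes=None):
        self.inbound_nodes = inbound_nodes or []
        self.outbound_nodes = []
        self.cache = {}
        self.value = None
        self.gradients = {}
        # register this node as an output of its inputs
        for n in self.inbound_nodes:
            n.outbound_nodes.append(self)

class Input(Node):
    def __init__(self):
        Node.__init__(self)

x, y = Input(), Input()
loss = CrossEntropyWithSoftmax(x, y)

x.value = np.array([[2.0, 1.0, 0.1],
                    [0.5, 2.5, 0.2]])  # raw scores, 2 samples x 3 classes
y.value = np.array([0, 1])             # integer class labels

loss.forward()
loss.backward()
print(loss.value)          # averaged cross-entropy loss
print(loss.gradients[x])   # gradient with respect to the scores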
Let $z$ and $x$ be n-element vectors. $x$ is the input to the softmax function, $z$ the output.
$$ z_{i} = \frac {e^{x_{i}}} {\sum_{j=1}^n e^{x_{j}}} $$

For cross entropy we keep the same $z$ and introduce $y$, a one-hot encoding of the categorical target ($[0, 0, \dots, 1, 0, 0]$) with the same length as $z$. The loss is
$$ loss = -\sum_i y_i \log(z_i) $$

It's straightforward to see that every term is 0 except the one at the index where $y_i = 1$. In that sense the one-hot encoding is redundant: instead of an $n$-element vector we can use a single integer, the index where the one-hot encoding would be 1, and simply index the output of the softmax with it.
logprobs = -np.log(probs[range(n), y])
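As a tiny worked example (the numbers are made up for illustration), suppose we have two samples and three classes:

import numpy as np

probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.5, 0.4]])    # softmax output, 2 samples x 3 classes
y = np.array([0, 2])                   # integer labels instead of one-hot rows
n = probs.shape[0]

logprobs = -np.log(probs[range(n), y]) # picks probs[0, 0] and probs[1, 2]
print(logprobs)                        # approximately [0.357, 0.916]
print(np.sum(logprobs) / n)            # the averaged loss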
If we take a look at the softmax function we'll note it maps $n$ inputs to $n$ outputs, and every input plays a role in computing every output. There are two general cases: we either differentiate $z_i$ with respect to $x_i$ (the numerator) or with respect to any of the other inputs $x_j$, $j \ne i$, which only appear in the denominator.
We use the quotient rule to find the derivative of the softmax function. $\sum_{j=1}^n$ is abbreviated as $\sum_j$.
$$ \frac {\partial z_{i} } {\partial x_{i} } = \frac { e^{x_{i}} \sum_j e^{x_{j}} - e^{x_{i}} e^{x_{i}} } { (\sum_j e^{x_{j}})^2 } = \frac { e^{x_{i}} } { \sum_j e^{x_{j}} } - \frac { e^{x_{i}} } { \sum_j e^{x_{j}} } \frac { e^{x_{i}} } { \sum_j e^{x_{j}} } = z_{i} - (z_{i})^2 = z_{i} (1 - z_{i}) $$

$$ \frac {\partial z_{i} } {\partial x_{j} } = -\frac { e^{x_{i}} e^{x_{j}} } { (\sum_j e^{x_{j}})^2 } = -\frac { e^{x_{i}} } { \sum_j e^{x_{j}} } \frac { e^{x_{j}} } { \sum_j e^{x_{j}} } = -z_{i} z_{j} $$

Following all the outputs back to a single input $x_k$,

$$ \frac {\partial z}{\partial x_k} = \sum_{p=1}^n \frac {\partial z_p} {\partial x_k} $$

where, of those $n$ terms, one is the case where $x_k$ is the numerator and the other $n-1$ are the cases where $x_k$ only appears in the denominator. In total we end up with an $n$ by $n$ matrix of derivatives known as the Jacobian.
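As a sanity check, here's a small sketch (names and values are just illustrative) that builds the Jacobian from the two cases above and compares it against finite differences for a single 3-element input:

import numpy as np

def softmax(x):
    exp_x = np.exp(x)
    return exp_x / np.sum(exp_x)

x = np.array([2.0, 1.0, 0.1])
z = softmax(x)

# analytic Jacobian: z_i (1 - z_i) on the diagonal, -z_i z_j off the diagonal
jacobian = np.diag(z) - np.outer(z, z)

# finite-difference approximation, one input at a time
eps = 1e-6
numeric = np.zeros((len(x), len(x)))
for k in range(len(x)):
    bumped = x.copy()
    bumped[k] += eps
    numeric[:, k] = (softmax(bumped) - z) / eps

print(np.allclose(jacobian, numeric, atol=1e-5))  # expect True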
Remember how all the terms in the cross entropy are 0 except the one picked by the target label? As it turns out, that means the gradient flowing back along the output edge to every index of $z$ that isn't picked is 0.
$$ \frac {\partial z}{\partial x_k} = \sum_{p=1}^n \frac {\partial z_p} {\partial x_k} $$

is now either
$$ \frac {\partial z}{\partial x_k} = \frac {\partial z_k} {\partial x_k} = z_k (1 - z_k) \hspace{0.25in} \text{or} \hspace{0.25in} \frac {\partial z}{\partial x_k} = \frac {\partial z_p} {\partial x_k} = -z_p z_k $$

depending on whether the picked output is $z_k$ itself or some other output $z_p$. The derivative of the loss is the easiest bit:
$$ \frac {\partial loss} {\partial z_i} = \frac {-1} {z_i} $$

Putting it all together,
$$ \frac {\partial loss} {\partial x_i} = \frac {\partial z_i} {\partial x_i} \frac {\partial loss} {\partial z_i} = z_i (1 - z_i) \frac {-1} {z_i} = z_i - 1 $$

$$ \frac {\partial loss} {\partial x_j} = \frac {\partial z_i} {\partial x_j} \frac {\partial loss} {\partial z_i} = -z_i z_j \frac {-1} {z_i} = z_j $$

Here it is in code:
gprobs[range(n), y] -= 1
gprobs /= n
where gprobs is a copy of the probabilities cached during the forward pass. Subtracting 1 at the label index turns the picked entries into $z_i - 1$ while the rest stay at $z_j$, and dividing by $n$ accounts for the forward pass averaging the loss over the batch.
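As a final sanity check, here's a short, standalone sketch (again with made-up values) that compares this combined gradient against finite differences of the loss:

import numpy as np

def softmax(x):
    exp_x = np.exp(x)
    return exp_x / np.sum(exp_x, axis=1, keepdims=True)

def cross_entropy(x, y):
    probs = softmax(x)
    n = probs.shape[0]
    return np.sum(-np.log(probs[range(n), y])) / n

x = np.array([[2.0, 1.0, 0.1],
              [0.5, 2.5, 0.2]])   # raw scores, 2 samples x 3 classes
y = np.array([0, 1])              # integer labels
n = x.shape[0]

# analytic gradient: softmax output, minus 1 at the label index, averaged over the batch
gprobs = softmax(x)
gprobs[range(n), y] -= 1
gprobs /= n

# finite-difference gradient for comparison
eps = 1e-6
numeric = np.zeros_like(x)
for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        bumped = x.copy()
        bumped[i, j] += eps
        numeric[i, j] = (cross_entropy(bumped, y) - cross_entropy(x, y)) / eps

print(np.allclose(gprobs, numeric, atol=1e-4))  # expect True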