:label:sec_mlp_scratch
我们已经在数学上描述了多层感知机(MLP),现在让我们尝试自己实现一个多层感知机。为了与我们之前使用softmax回归( :numref:sec_softmax_scratch
)获得的结果进行比较,我们将继续使用Fashion-MNIST图像分类数据集( :numref:sec_fashion_mnist
)。
%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/DataPoints.java
%load ../utils/Training.java
%load ../utils/Accumulator.java
import ai.djl.basicdataset.cv.classification.*;
import org.apache.commons.lang3.ArrayUtils;
int batchSize = 256;
FashionMnist trainIter = FashionMnist.builder()
.optUsage(Dataset.Usage.TRAIN)
.setSampling(batchSize, true)
.optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
.build();
FashionMnist testIter = FashionMnist.builder()
.optUsage(Dataset.Usage.TEST)
.setSampling(batchSize, true)
.optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
.build();
trainIter.prepare();
testIter.prepare();
回想一下,Fashion-MNIST中的每个图像由28×28=784个灰度像素值组成。所有图像共分为10个类别。忽略像素之间的空间结构,我们可以将每个图像视为具有784个输入特征和10个类的简单分类数据集。首先,我们将实现一个具有单隐藏层的多层感知机,它包含256个隐藏单元。注意,我们可以将这两个量都视为超参数。通常,我们选择2的若干次幂作为层的宽度。因为内存在硬件中的分配和寻址方式,这么做往往可以在计算上更高效。
我们用几个NDArray
来表示我们的参数。注意,对于每一层我们都要记录一个权重矩阵和一个偏置向量。跟以前一样,我们要为这些参数的损失的梯度分配内存。
int numInputs = 784;
int numOutputs = 10;
int numHiddens = 256;
NDManager manager = NDManager.newBaseManager();
NDArray W1 = manager.randomNormal(0, 0.01f, new Shape(numInputs, numHiddens), DataType.FLOAT32);
NDArray b1 = manager.zeros(new Shape(numHiddens));
NDArray W2 = manager.randomNormal(0, 0.01f, new Shape(numHiddens, numOutputs), DataType.FLOAT32);
NDArray b2 = manager.zeros(new Shape(numOutputs));
NDList params = new NDList(W1, b1, W2, b2);
for (NDArray param : params) {
param.setRequiresGradient(true);
}
为了确保我们知道一切是如何工作的,我们将使用最大值函数自己实现ReLU激活函数,而不是直接调用内置的relu
函数。
public NDArray relu(NDArray X){
return X.maximum(0f);
}
因为我们忽略了空间结构,所以我们使用reshape
将每个二维图像转换为一个长度为numInputs
的向量。我们只需几行代码就可以实现我们的模型。
public NDArray net(NDArray X) {
X = X.reshape(new Shape(-1, numInputs));
NDArray H = relu(X.dot(W1).add(b1));
return H.dot(W2).add(b2);
}
为了确保数值稳定性,同时由于我们已经从零实现过softmax函数( :numref:sec_softmax_scratch
),因此在这里我们直接使用高级API中的内置函数来计算softmax和交叉熵损失。回想一下我们之前在 :numref:subsec_softmax-implementation-revisited
中对这些复杂问题的讨论。我们鼓励感兴趣的读者查看Loss.SoftmaxCrossEntropyLoss
的源代码,以加深对实现细节的了解。
Loss loss = Loss.softmaxCrossEntropyLoss();
幸运的是,多层感知机的训练过程与softmax回归的训练过程完全相同。可以使用和第三章类似的代码来训练模型(参见 :numref:sec_softmax_scratch
),将迭代周期数设置为10,并将学习率设置为0.1.
int numEpochs = Integer.getInteger("MAX_EPOCH", 10);
float lr = 0.5f;
double[] trainLoss = new double[numEpochs];
double[] trainAccuracy = new double[numEpochs];
double[] testAccuracy = new double[numEpochs];
double[] epochCount = new double[numEpochs];
为了对学习到的模型进行评估,我们将在一些测试数据上应用这个模型。
float epochLoss = 0f;
float accuracyVal = 0f;
for (int epoch = 1; epoch <= numEpochs; epoch++) {
System.out.print("Running epoch " + epoch + "...... ");
// Iterate over dataset
for (Batch batch : trainIter.getData(manager)) {
NDArray X = batch.getData().head();
NDArray y = batch.getLabels().head();
try(GradientCollector gc = Engine.getInstance().newGradientCollector()) {
NDArray yHat = net(X); // net function call
NDArray lossValue = loss.evaluate(new NDList(y), new NDList(yHat));
NDArray l = lossValue.mul(batchSize);
accuracyVal += Training.accuracy(yHat, y);
epochLoss += l.sum().getFloat();
gc.backward(l); // gradient calculation
}
batch.close();
Training.sgd(params, lr, batchSize); // updater
}
trainLoss[epoch-1] = epochLoss/trainIter.size();
trainAccuracy[epoch-1] = accuracyVal/trainIter.size();
epochLoss = 0f;
accuracyVal = 0f;
// testing now
for (Batch batch : testIter.getData(manager)) {
NDArray X = batch.getData().head();
NDArray y = batch.getLabels().head();
NDArray yHat = net(X); // net function call
accuracyVal += Training.accuracy(yHat, y);
}
testAccuracy[epoch-1] = accuracyVal/testIter.size();
epochCount[epoch-1] = epoch;
accuracyVal = 0f;
System.out.println("Finished epoch " + epoch);
}
System.out.println("Finished training!");
String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length];
Arrays.fill(lossLabel, 0, trainLoss.length, "train loss");
Arrays.fill(lossLabel, trainAccuracy.length, trainLoss.length + trainAccuracy.length, "train acc");
Arrays.fill(lossLabel, trainLoss.length + trainAccuracy.length,
trainLoss.length + testAccuracy.length + trainAccuracy.length, "test acc");
Table data = Table.create("Data").addColumns(
DoubleColumn.create("epochCount", ArrayUtils.addAll(epochCount, ArrayUtils.addAll(epochCount, epochCount))),
DoubleColumn.create("loss", ArrayUtils.addAll(trainLoss, ArrayUtils.addAll(trainAccuracy, testAccuracy))),
StringColumn.create("lossLabel", lossLabel)
);
render(LinePlot.create("", data, "epochCount", "loss", "lossLabel"),"text/html");
numHiddens
的值,并查看此超参数的变化对结果有何影响。确定此超参数的最佳值。