%load ../utils/djl-imports %load ../utils/plot-utils import ai.djl.basicdataset.cv.classification.*; import ai.djl.metric.*; import org.apache.commons.lang3.ArrayUtils; Engine.getInstance().setRandomSeed(1111); NDManager manager = NDManager.newBaseManager(); SequentialBlock block = new SequentialBlock(); block .add(Conv2d.builder() .setKernelShape(new Shape(5, 5)) .optPadding(new Shape(2, 2)) .optBias(false) .setFilters(6) .build()) .add(Activation::sigmoid) .add(Pool.avgPool2dBlock(new Shape(5, 5), new Shape(2, 2), new Shape(2, 2))) .add(Conv2d.builder() .setKernelShape(new Shape(5, 5)) .setFilters(16).build()) .add(Activation::sigmoid) .add(Pool.avgPool2dBlock(new Shape(5, 5), new Shape(2, 2), new Shape(2, 2))) // Blocks.batchFlattenBlock() will transform the input of the shape (batch size, channel, // height, width) into the input of the shape (batch size, // channel * height * width) .add(Blocks.batchFlattenBlock()) .add(Linear .builder() .setUnits(120) .build()) .add(Activation::sigmoid) .add(Linear .builder() .setUnits(84) .build()) .add(Activation::sigmoid) .add(Linear .builder() .setUnits(10) .build()); float lr = 0.9f; Model model = Model.newInstance("cnn"); model.setBlock(block); Loss loss = Loss.softmaxCrossEntropyLoss(); Tracker lrt = Tracker.fixed(lr); Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build(); DefaultTrainingConfig config = new DefaultTrainingConfig(loss).optOptimizer(sgd) // Optimizer (loss function) .optDevices(Engine.getInstance().getDevices(1)) // Single GPU .addEvaluator(new Accuracy()) // Model Accuracy .addTrainingListeners(TrainingListener.Defaults.basic()); Trainer trainer = model.newTrainer(config); NDArray X = manager.randomUniform(0f, 1.0f, new Shape(1, 1, 28, 28)); trainer.initialize(X.getShape()); Shape currentShape = X.getShape(); for (int i = 0; i < block.getChildren().size(); i++) { Shape[] newShape = block.getChildren().get(i).getValue().getOutputShapes(new Shape[]{currentShape}); currentShape = newShape[0]; System.out.println(block.getChildren().get(i).getKey() + " layer output : " + currentShape); } int batchSize = 256; int numEpochs = Integer.getInteger("MAX_EPOCH", 10); double[] trainLoss; double[] testAccuracy; double[] epochCount; double[] trainAccuracy; epochCount = new double[numEpochs]; for (int i = 0; i < epochCount.length; i++) { epochCount[i] = (i + 1); } FashionMnist trainIter = FashionMnist.builder() .optUsage(Dataset.Usage.TRAIN) .setSampling(batchSize, true) .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE)) .build(); FashionMnist testIter = FashionMnist.builder() .optUsage(Dataset.Usage.TEST) .setSampling(batchSize, true) .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE)) .build(); trainIter.prepare(); testIter.prepare(); public void trainingChapter6(ArrayDataset trainIter, ArrayDataset testIter, int numEpochs, Trainer trainer) throws IOException, TranslateException { double avgTrainTimePerEpoch = 0; Map evaluatorMetrics = new HashMap<>(); trainer.setMetrics(new Metrics()); EasyTrain.fit(trainer, numEpochs, trainIter, testIter); Metrics metrics = trainer.getMetrics(); trainer.getEvaluators().stream() .forEach(evaluator -> { evaluatorMetrics.put("train_epoch_" + evaluator.getName(), metrics.getMetric("train_epoch_" + evaluator.getName()).stream() .mapToDouble(x -> x.getValue().doubleValue()).toArray()); evaluatorMetrics.put("validate_epoch_" + evaluator.getName(), metrics.getMetric("validate_epoch_" + evaluator.getName()).stream() .mapToDouble(x -> x.getValue().doubleValue()).toArray()); }); avgTrainTimePerEpoch = metrics.mean("epoch"); trainLoss = evaluatorMetrics.get("train_epoch_SoftmaxCrossEntropyLoss"); trainAccuracy = evaluatorMetrics.get("train_epoch_Accuracy"); testAccuracy = evaluatorMetrics.get("validate_epoch_Accuracy"); System.out.printf("loss %.3f," , trainLoss[numEpochs-1]); System.out.printf(" train acc %.3f," , trainAccuracy[numEpochs-1]); System.out.printf(" test acc %.3f\n" , testAccuracy[numEpochs-1]); System.out.printf("%.1f examples/sec \n", trainIter.size() / (avgTrainTimePerEpoch / Math.pow(10, 9))); } trainingChapter6(trainIter, testIter, numEpochs, trainer); String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length]; Arrays.fill(lossLabel, 0, trainLoss.length, "train loss"); Arrays.fill(lossLabel, trainAccuracy.length, trainLoss.length + trainAccuracy.length, "train acc"); Arrays.fill(lossLabel, trainLoss.length + trainAccuracy.length, trainLoss.length + testAccuracy.length + trainAccuracy.length, "test acc"); Table data = Table.create("Data").addColumns( DoubleColumn.create("epoch", ArrayUtils.addAll(epochCount, ArrayUtils.addAll(epochCount, epochCount))), DoubleColumn.create("metrics", ArrayUtils.addAll(trainLoss, ArrayUtils.addAll(trainAccuracy, testAccuracy))), StringColumn.create("lossLabel", lossLabel) ); render(LinePlot.create("", data, "epoch", "metrics", "lossLabel"), "text/html");