This is the third and final tutorial of our beginner tutorial series that will take you through creating, training, and running inference on a neural network. In this tutorial, you will learn how to execute your image classification model for a production system.
In the previous tutorial, you successfully trained your model. Now, you will learn how to implement a Translator to convert between POJOs and NDArrays, as well as a Predictor to run inference.
This tutorial requires the installation of the Java Jupyter Kernel. To install the kernel, see the Jupyter README.
// Add the snapshot repository to get the DJL snapshot artifacts
// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/
// Add the maven dependencies
%maven ai.djl:api:0.4.0
%maven ai.djl:repository:0.4.0
%maven ai.djl:model-zoo:0.4.0
%maven ai.djl.mxnet:mxnet-engine:0.4.0
%maven ai.djl.mxnet:mxnet-model-zoo:0.4.0
%maven org.slf4j:slf4j-api:1.7.26
%maven org.slf4j:slf4j-simple:1.7.26
%maven net.java.dev.jna:jna:5.3.0
// See https://github.com/awslabs/djl/blob/master/mxnet/mxnet-engine/README.md
// for more MXNet library selection options
%maven ai.djl.mxnet:mxnet-native-auto:1.6.0
import java.awt.image.*;
import java.nio.file.*;
import java.util.*;
import java.util.stream.*;
import ai.djl.*;
import ai.djl.basicmodelzoo.basic.*;
import ai.djl.ndarray.*;
import ai.djl.modality.*;
import ai.djl.modality.cv.util.*;
import ai.djl.translate.*;
We will start by loading the image that we want our model to classify.
var img = BufferedImageUtils.fromUrl("https://djl-ai.s3.amazonaws.com/resources/images/0.png");
img
Next, we need to load the model to run inference with. This model should have been saved to the build/mlp
directory when running the previous tutorial.
Path modelDir = Paths.get("build/mlp");
Model model = Model.newInstance();
model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));
model.load(modelDir);
Translator
The Translator is used to encapsulate the pre-processing and post-processing functionality of your application. The input to processInput and the output of processOutput should be single data items, not batches.
Translator<BufferedImage, Classifications> translator = new Translator<BufferedImage, Classifications>() {

    @Override
    public NDList processInput(TranslatorContext ctx, BufferedImage input) {
        // Convert the BufferedImage to a grayscale NDArray
        NDArray array = BufferedImageUtils.toNDArray(ctx.getNDManager(), input, NDImageUtils.Flag.GRAYSCALE);
        // Convert to a float tensor with values scaled to [0, 1]
        return new NDList(NDImageUtils.toTensor(array));
    }

    @Override
    public Classifications processOutput(TranslatorContext ctx, NDList list) {
        // Apply softmax to turn the raw model outputs into probabilities
        NDArray probabilities = list.singletonOrThrow().softmax(0);
        // The class names for MNIST are the digits 0-9
        List<String> indices = IntStream.range(0, 10).mapToObj(String::valueOf).collect(Collectors.toList());
        return new Classifications(indices, probabilities);
    }
};
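To make the post-processing step concrete, here is a plain-Java sketch of what the softmax call in processOutput computes (no DJL required; the logit values below are made up purely for illustration):

```java
import java.util.Arrays;

public class SoftmaxSketch {

    // Numerically stable softmax: subtract the max logit before exponentiating
    static double[] softmax(double[] logits) {
        double max = Arrays.stream(logits).max().orElse(0);
        double[] exps = Arrays.stream(logits).map(v -> Math.exp(v - max)).toArray();
        double sum = Arrays.stream(exps).sum();
        return Arrays.stream(exps).map(v -> v / sum).toArray();
    }

    public static void main(String[] args) {
        // Hypothetical raw outputs (logits) for the 10 digit classes
        double[] logits = {1.2, 0.3, 4.5, 0.1, 0.0, 0.2, 0.7, 0.4, 0.9, 0.1};
        double[] probs = softmax(logits);

        // Probabilities sum to 1; the largest one marks the predicted digit
        int best = 0;
        for (int i = 1; i < probs.length; i++) {
            if (probs[i] > probs[best]) best = i;
        }
        System.out.println("Predicted digit: " + best);
    }
}
```

The Classifications object performs an equivalent pairing of class names with probabilities and can then report the top classes.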
Using the translator, we will create a new Predictor. The predictor is the main class that orchestrates the inference process. During inference, a trained model is used to predict values, often for production use cases. The predictor is not thread-safe, so if you want to run inference from multiple threads, create a separate predictor from the same model for each thread.
var predictor = model.newPredictor(translator);
With our predictor, we can simply call the predict method to run inference. The same predictor can be reused for further inference calls.
var classifications = predictor.predict(img);
classifications
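Because predictors are not thread-safe, the usual pattern for serving inference from multiple threads is one predictor per thread, all created from the same shared model. The sketch below illustrates that threading structure with minimal stand-in Model and Predictor classes (placeholders, not DJL's actual API) so it runs without any dependencies:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class PredictorPerThread {

    // Stand-in for a shared model; a loaded model can safely be shared across threads
    static class FakeModel {
        FakePredictor newPredictor() { return new FakePredictor(); }
    }

    // Stand-in for a predictor, which holds per-call state and is NOT thread-safe
    static class FakePredictor {
        int predict(int input) { return input * 2; } // trivial placeholder "inference"
    }

    public static void main(String[] args) throws Exception {
        FakeModel model = new FakeModel();
        ExecutorService pool = Executors.newFixedThreadPool(4);

        // Each task creates its own predictor from the shared model
        List<Future<Integer>> results = new ArrayList<>();
        for (int i = 0; i < 8; i++) {
            final int input = i;
            results.add(pool.submit(() -> {
                FakePredictor predictor = model.newPredictor(); // one predictor per task
                return predictor.predict(input);
            }));
        }
        for (Future<Integer> f : results) {
            System.out.println(f.get());
        }
        pool.shutdown();
    }
}
```

The key point is that only the model is shared; each worker obtains its own predictor before calling predict.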
Now, you've successfully built a model, trained it, and run inference. Congratulations on finishing the beginner tutorial series. After this, you should explore our other examples and Jupyter notebooks to learn more about DJL.
You can find the complete source code for this tutorial in the examples project.