In a previous tutorial, you successfully trained your model.
Now you will learn how to use that model to classify a handwritten digit image. You will also learn how to implement the Translator
interface, which converts between plain Java objects (POJOs) and NDArray.
This tutorial requires the Java Kernel. To install it, see the README.
// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/
%maven ai.djl:api:0.2.1
%maven ai.djl:repository:0.2.1
%maven ai.djl:model-zoo:0.2.1
%maven ai.djl.mxnet:mxnet-engine:0.2.1
%maven ai.djl.mxnet:mxnet-model-zoo:0.2.1
%maven org.slf4j:slf4j-api:1.7.26
%maven org.slf4j:slf4j-simple:1.7.26
%maven net.java.dev.jna:jna:5.3.0
This tutorial uses the MXNet engine as its backend. MXNet comes in several build flavors, and each one is platform-specific. Please read here for how to select an MXNet engine flavor.
String osName = System.getProperty("os.name");
String classifier = osName.startsWith("Mac") ? "osx-x86_64" : osName.startsWith("Win") ? "win-x86_64" : "linux-x86_64";
%maven ai.djl.mxnet:mxnet-native-mkl:jar:${classifier}:1.6.0-b
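The nested ternary above is compact but easy to misread. As a rough sketch, the same mapping can be written as a plain method. The class name `NativeClassifier` is hypothetical and not part of DJL, and the mapping assumes an x86_64 machine, as the cell above does.

```java
final class NativeClassifier {
    private NativeClassifier() {}

    // Maps an "os.name" value to one of the three classifier strings used above.
    // Assumption: the machine is x86_64; other architectures are not handled.
    static String forOs(String osName) {
        if (osName.startsWith("Mac")) {
            return "osx-x86_64";
        }
        if (osName.startsWith("Win")) {
            return "win-x86_64";
        }
        // Everything else falls through to Linux, exactly like the ternary.
        return "linux-x86_64";
    }

    public static void main(String[] args) {
        System.out.println(forOs(System.getProperty("os.name")));
    }
}
```

Note that any unrecognized operating system is treated as Linux, so an unsupported platform would fail later, when the native library is loaded, rather than here.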
import java.awt.image.*;
import java.nio.file.*;
import java.util.*;
import java.util.stream.*;
import ai.djl.*;
import ai.djl.inference.*;
import ai.djl.ndarray.*;
import ai.djl.ndarray.index.*;
import ai.djl.modality.*;
import ai.djl.modality.cv.*;
import ai.djl.modality.cv.util.*;
import ai.djl.modality.cv.util.NDImageUtils.Flag;
import ai.djl.mxnet.zoo.*;
import ai.djl.translate.*;
import ai.djl.util.*;
import ai.djl.zoo.cv.classification.*;
var img = BufferedImageUtils.fromUrl("https://djl-ai.s3.amazonaws.com/resources/images/0.png");
img
Path modelDir = Paths.get("build/mlp");
Model model = Model.newInstance();
model.setBlock(new Mlp(28, 28));
model.load(modelDir);
Translator
Translator<BufferedImage, Classifications> translator = new Translator<BufferedImage, Classifications>() {

    @Override
    public NDList processInput(TranslatorContext ctx, BufferedImage input) {
        // Convert BufferedImage to NDArray
        NDArray array = BufferedImageUtils.toNDArray(ctx.getNDManager(), input, NDImageUtils.Flag.GRAYSCALE);
        // toTensor rescales pixel values from [0, 255] to [0, 1] and reorders HWC to CHW
        return new NDList(NDImageUtils.toTensor(array));
    }

    @Override
    public Classifications processOutput(TranslatorContext ctx, NDList list) {
        // Turn the model's raw scores into probabilities
        NDArray probabilities = list.singletonOrThrow().softmax(0);
        // Class names are the digits "0" through "9"
        List<String> indices = IntStream.range(0, 10).mapToObj(String::valueOf).collect(Collectors.toList());
        return new Classifications(indices, probabilities);
    }
};
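To make the input side of the translator more concrete, here is a minimal sketch in plain Java, without DJL, of the kind of conversion `processInput` performs: pixel values are rescaled from [0, 255] to [0, 1] and laid out channel-first. The class `GrayscaleTensorSketch` is hypothetical, and it assumes the image is already of type `TYPE_BYTE_GRAY`; DJL's own grayscale conversion may differ in detail.

```java
import java.awt.image.BufferedImage;

final class GrayscaleTensorSketch {
    private GrayscaleTensorSketch() {}

    // Returns pixel intensities scaled to [0, 1], indexed as [channel][row][column].
    // Assumption: the image is TYPE_BYTE_GRAY, so band 0 holds the gray level.
    static float[][][] toTensor(BufferedImage image) {
        int height = image.getHeight();
        int width = image.getWidth();
        float[][][] tensor = new float[1][height][width];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int gray = image.getRaster().getSample(x, y, 0);
                tensor[0][y][x] = gray / 255f;
            }
        }
        return tensor;
    }

    public static void main(String[] args) {
        // A blank 28x28 image with a single white pixel, standing in for a digit.
        BufferedImage image = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
        image.getRaster().setSample(3, 5, 0, 255);
        float[][][] tensor = toTensor(image);
        System.out.println(tensor.length + " x " + tensor[0].length + " x " + tensor[0][0].length);
        System.out.println(tensor[0][5][3]);
    }
}
```

The channel-first layout is what gives the result its leading dimension of 1 for a grayscale image, followed by the 28 rows and 28 columns the Mlp was constructed with.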
var predictor = model.newPredictor(translator);
var classifications = predictor.predict(img);
classifications
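The probabilities shown for each digit come from the softmax applied in `processOutput`. The following plain-Java sketch, independent of DJL, illustrates that step and how the most likely class can be picked from the result. The class `SoftmaxSketch` and the scores in `main` are made up for illustration; they are not output from the trained model.

```java
final class SoftmaxSketch {
    private SoftmaxSketch() {}

    // Converts raw scores into probabilities that sum to 1.
    // Subtracting the maximum first keeps Math.exp from overflowing.
    static double[] softmax(double[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            max = Math.max(max, s);
        }
        double[] probabilities = new double[scores.length];
        double sum = 0.0;
        for (int i = 0; i < scores.length; i++) {
            probabilities[i] = Math.exp(scores[i] - max);
            sum += probabilities[i];
        }
        for (int i = 0; i < probabilities.length; i++) {
            probabilities[i] /= sum;
        }
        return probabilities;
    }

    // Returns the index of the largest value (the first one if there are ties).
    static int argMax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Hypothetical raw scores for the digits 0 through 9.
        double[] scores = {9.0, 0.5, 1.0, 0.2, 0.1, 0.3, 1.5, 0.4, 0.8, 0.6};
        double[] probabilities = softmax(scores);
        int best = argMax(probabilities);
        System.out.println("Most likely digit: " + best);
        System.out.println("Probability: " + probabilities[best]);
    }
}
```

Because softmax preserves the ordering of the scores, the class with the highest raw score is also the class with the highest probability.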
You can find the full source in the examples project.