// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/ %maven ai.djl:api:0.22.1 %maven ai.djl.onnxruntime:onnxruntime-engine:0.22.1 %maven org.slf4j:slf4j-simple:1.7.32 import ai.djl.inference.*; import ai.djl.modality.*; import ai.djl.ndarray.*; import ai.djl.ndarray.types.*; import ai.djl.repository.zoo.*; import ai.djl.translate.*; import java.util.*; public static class IrisFlower { public float sepalLength; public float sepalWidth; public float petalLength; public float petalWidth; public IrisFlower(float sepalLength, float sepalWidth, float petalLength, float petalWidth) { this.sepalLength = sepalLength; this.sepalWidth = sepalWidth; this.petalLength = petalLength; this.petalWidth = petalWidth; } } public static class MyTranslator implements NoBatchifyTranslator { private final List synset; public MyTranslator() { // species name synset = Arrays.asList("setosa", "versicolor", "virginica"); } @Override public NDList processInput(TranslatorContext ctx, IrisFlower input) { float[] data = {input.sepalLength, input.sepalWidth, input.petalLength, input.petalWidth}; NDArray array = ctx.getNDManager().create(data, new Shape(1, 4)); return new NDList(array); } @Override public Classifications processOutput(TranslatorContext ctx, NDList list) { float[] data = list.get(1).toFloatArray(); List probabilities = new ArrayList<>(data.length); for (float f : data) { probabilities.add((double) f); } return new Classifications(synset, probabilities); } } String modelUrl = "https://mlrepo.djl.ai/model/tabular/softmax_regression/ai/djl/onnxruntime/iris_flowers/0.0.1/iris_flowers.zip"; Criteria criteria = Criteria.builder() .setTypes(IrisFlower.class, Classifications.class) .optModelUrls(modelUrl) .optTranslator(new MyTranslator()) .optEngine("OnnxRuntime") // use OnnxRuntime engine by default .build(); ZooModel model = criteria.loadModel(); Predictor predictor = model.newPredictor(); IrisFlower info = new IrisFlower(1.0f, 2.0f, 3.0f, 4.0f); predictor.predict(info);