// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/
%maven ai.djl:api:0.12.0
%maven ai.djl.pytorch:pytorch-engine:0.12.0
%maven ai.djl.pytorch:pytorch-model-zoo:0.12.0
%maven org.slf4j:slf4j-api:1.7.26
%maven org.slf4j:slf4j-simple:1.7.26
%maven net.java.dev.jna:jna:5.3.0

// See https://github.com/deepjavalibrary/djl/blob/master/pytorch/pytorch-engine/README.md
// for more PyTorch library selection options
%maven ai.djl.pytorch:pytorch-native-auto:1.8.1

import java.io.*;
import java.nio.file.*;
import java.util.*;
import java.util.stream.*;

import ai.djl.*;
import ai.djl.ndarray.*;
import ai.djl.ndarray.types.*;
import ai.djl.inference.*;
import ai.djl.translate.*;
import ai.djl.training.util.*;
import ai.djl.repository.zoo.*;
import ai.djl.modality.nlp.*;
import ai.djl.modality.nlp.qa.*;
import ai.djl.modality.nlp.bert.*;

// The question and the passage in which the model will search for the answer.
var question = "When did BBC Japan start broadcasting?";
var resourceDocument = "BBC Japan was a general entertainment Channel.\n" +
        "Which operated between December 2004 and April 2006.\n" +
        "It ceased operations after its Japanese distributor folded.";

QAInput input = new QAInput(question, resourceDocument);

// Tokenize the question and the passage separately to inspect the word pieces.
var tokenizer = new BertTokenizer();
List<String> tokenQ = tokenizer.tokenize(question.toLowerCase());
List<String> tokenA = tokenizer.tokenize(resourceDocument.toLowerCase());

System.out.println("Question Token: " + tokenQ);
System.out.println("Answer Token: " + tokenA);

// Encode question and passage together into the BERT input form:
// tokens, token types (segment ids) and the valid length.
BertToken token = tokenizer.encode(question.toLowerCase(), resourceDocument.toLowerCase());
System.out.println("Encoded tokens: " + token.getTokens());
System.out.println("Encoded token type: " + token.getTokenTypes());
System.out.println("Valid length: " + token.getValidLength());

// Download the bert-base-uncased vocabulary used to map tokens to indices.
DownloadUtils.download(
        "https://djl-ai.s3.amazonaws.com/mlrepo/model/nlp/question_answer/ai/djl/pytorch/bertqa/0.0.1/bert-base-uncased-vocab.txt.gz",
        "build/pytorch/bertqa/vocab.txt",
        new ProgressBar());

var path = Paths.get("build/pytorch/bertqa/vocab.txt");
var vocabulary = SimpleVocabulary.builder()
        .optMinFrequency(1)
        .addFromTextFile(path)
        .optUnknownToken("[UNK]")
        .build();

long index = vocabulary.getIndex("car");
String tokenAtIndex = vocabulary.getToken(2482);
System.out.println("The index of 'car' is " + index);
System.out.println("The token at index 2482 is " + tokenAtIndex);

public class BertTranslator implements Translator<QAInput, String> {

    private List<String> tokens;
    private Vocabulary vocabulary;
    private BertTokenizer tokenizer;

    @Override
    public void prepare(NDManager manager, Model model) throws IOException {
        Path path = Paths.get("build/pytorch/bertqa/vocab.txt");
        vocabulary = SimpleVocabulary.builder()
                .optMinFrequency(1)
                .addFromTextFile(path)
                .optUnknownToken("[UNK]")
                .build();
        tokenizer = new BertTokenizer();
    }

    @Override
    public NDList processInput(TranslatorContext ctx, QAInput input) {
        BertToken token = tokenizer.encode(
                input.getQuestion().toLowerCase(),
                input.getParagraph().toLowerCase());
        // keep the encoded tokens; they are needed again in processOutput
        tokens = token.getTokens();
        NDManager manager = ctx.getNDManager();
        // map the tokens (String) to indices (long)
        long[] indices = tokens.stream().mapToLong(vocabulary::getIndex).toArray();
        long[] attentionMask = token.getAttentionMask().stream().mapToLong(i -> i).toArray();
        long[] tokenType = token.getTokenTypes().stream().mapToLong(i -> i).toArray();
        NDArray indicesArray = manager.create(indices);
        NDArray attentionMaskArray = manager.create(attentionMask);
        NDArray tokenTypeArray = manager.create(tokenType);
        // The order matters: indices, attention mask, token types
        return new NDList(indicesArray, attentionMaskArray, tokenTypeArray);
    }

    @Override
    public String processOutput(TranslatorContext ctx, NDList list) {
        // The model returns start and end logits over the token positions;
        // the answer is the token span between the two argmax positions.
        NDArray startLogits = list.get(0);
        NDArray endLogits = list.get(1);
        int startIdx = (int) startLogits.argMax().getLong();
        int endIdx = (int) endLogits.argMax().getLong();
        return tokens.subList(startIdx, endIdx + 1).toString();
    }

    @Override
    public Batchifier getBatchifier() {
        return Batchifier.STACK;
    }
}
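// Optional post-processing sketch (not part of the original tutorial): processOutput above
// returns the raw sublist of wordpiece tokens, e.g. "[december, 2004]". The hypothetical
// helper below is one way to turn such a token span into a plain string by joining the
// pieces and folding the "##" continuation markers back onto the previous wordpiece.
String joinAnswerTokens(List<String> answerTokens) {
    // "em", "##bed" -> "embed"; ordinary tokens are joined with single spaces
    return String.join(" ", answerTokens).replace(" ##", "");
}

// Example: joinAnswerTokens(List.of("december", "2004")) returns "december 2004".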
// Download the traced BERT QA model.
DownloadUtils.download(
        "https://djl-ai.s3.amazonaws.com/mlrepo/model/nlp/question_answer/ai/djl/pytorch/bertqa/0.0.1/trace_bertqa.pt.gz",
        "build/pytorch/bertqa/bertqa.pt",
        new ProgressBar());

BertTranslator translator = new BertTranslator();

Criteria<QAInput, String> criteria = Criteria.builder()
        .setTypes(QAInput.class, String.class)
        .optModelPath(Paths.get("build/pytorch/bertqa/")) // search in local folder
        .optTranslator(translator)
        .optProgress(new ProgressBar())
        .build();

ZooModel<QAInput, String> model = criteria.loadModel();

String predictResult = null;
QAInput input = new QAInput(question, resourceDocument);

// Create a Predictor and use it to predict the output
try (Predictor<QAInput, String> predictor = model.newPredictor(translator)) {
    predictResult = predictor.predict(input);
}

System.out.println(question);
System.out.println(predictResult);
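// Follow-up sketch: the loaded model can answer further questions about the same passage.
// The second question below is only an illustration (an assumption, not part of the
// original tutorial); closing the model afterwards releases the native PyTorch resources.
try (Predictor<QAInput, String> predictor = model.newPredictor(translator)) {
    QAInput followUp = new QAInput("When did BBC Japan cease operations?", resourceDocument);
    System.out.println(predictor.predict(followUp));
}
model.close();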