# Rank Classification using BERT on Amazon Review dataset

## Introduction

In this tutorial, you learn how to train a rank classification model using transfer learning. We will fine-tune a pretrained DistilBERT model on the Amazon review dataset.

## About the dataset and model

The Amazon Customer Reviews dataset consists of valid reviews from amazon.com across many product categories. We will use the "Digital_Software" category, which contains about 102k valid reviews. For the pre-trained model, we use DistilBERT[1], a lightweight BERT model already trained on Wikipedia text, a much larger text corpus. DistilBERT serves as the base layer, and we add a few classification layers on top to output a rating (1 to 5).

(Figure: Amazon Review example)

We will use the review body (`review_body`) as the data input and the star rating (`star_rating`) as the label.
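To make the input/label choice concrete, here is a minimal, hypothetical sketch that pulls `review_body` and `star_rating` out of one tab-separated row by looking up the column names in the header line. The class `ReviewRowSketch` and the sample row are made up for illustration, and the sample header lists only a few of the real file's columns; the tutorial itself leaves this parsing to `CsvDataset` and Apache Commons CSV.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch: select the model input (review_body) and the label
// (star_rating) from one tab-separated row by looking up column names in the
// header line. Plain string splitting, not CsvDataset or Commons CSV.
final class ReviewRowSketch {

    /** The text and star rating read from one row. */
    static final class Example {
        final String text;
        final int stars;

        Example(String text, int stars) {
            this.text = text;
            this.stars = stars;
        }
    }

    static Example parse(String headerLine, String rowLine) {
        List<String> header = Arrays.asList(headerLine.split("\t", -1));
        String[] row = rowLine.split("\t", -1);
        int textCol = header.indexOf("review_body");
        int labelCol = header.indexOf("star_rating");
        if (textCol < 0 || labelCol < 0 || row.length != header.size()) {
            throw new IllegalArgumentException("row does not match the expected columns");
        }
        return new Example(row[textCol], Integer.parseInt(row[labelCol].trim()));
    }

    public static void main(String[] args) {
        // Made-up header and row with only a few of the real file's columns.
        String header = "review_id\tstar_rating\treview_headline\treview_body";
        String row = "R1\t4\tGood\tWorks well and installs quickly.";
        Example ex = parse(header, row);
        System.out.println(ex.stars + " stars: " + ex.text);
    }
}
```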

## Pre-requisites

This tutorial assumes you have the following knowledge. If you are not familiar with these topics, follow the READMEs and tutorials first:

1. How to set up and run the Java kernel in Jupyter Notebook
2. The basic components of Deep Java Library (DJL), and how to train your first model.

## Getting started

Load the Deep Java Library and its dependencies from Maven. Here, you can choose between MXNet and PyTorch. MXNet is enabled by default; to switch to PyTorch, uncomment the PyTorch dependency and comment out the MXNet one.

// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/

%maven ai.djl:api:0.15.0
%maven ai.djl:basicdataset:0.15.0
%maven org.slf4j:slf4j-simple:1.7.32
%maven ai.djl.mxnet:mxnet-model-zoo:0.15.0

// PyTorch
// %maven ai.djl.pytorch:pytorch-model-zoo:0.15.0


Now let's import the necessary modules:

import ai.djl.*;
import ai.djl.basicdataset.tabular.*;
import ai.djl.basicdataset.utils.*;
import ai.djl.engine.*;
import ai.djl.inference.*;
import ai.djl.metric.*;
import ai.djl.modality.*;
import ai.djl.modality.nlp.*;
import ai.djl.modality.nlp.bert.*;
import ai.djl.ndarray.*;
import ai.djl.ndarray.types.*;
import ai.djl.nn.*;
import ai.djl.nn.core.*;
import ai.djl.nn.norm.*;
import ai.djl.repository.zoo.*;
import ai.djl.training.*;
import ai.djl.training.dataset.*;
import ai.djl.training.evaluator.*;
import ai.djl.training.listener.*;
import ai.djl.training.loss.*;
import ai.djl.training.util.*;
import ai.djl.translate.*;
import java.io.*;
import java.nio.file.*;
import java.util.*;
import org.apache.commons.csv.*;

System.out.println("You are using: " + Engine.getInstance().getEngineName() + " Engine");


## Prepare Dataset

The first step is to prepare the dataset for training. Since the original data is in TSV format, we can use `CsvDataset` as the dataset container. We also need to specify how to preprocess the raw data. For the BERT model, the input text must be tokenized and mapped to indices based on the vocabulary. In DJL, we define an interface called `Featurizer`, which lets you customize the operation applied to each selected row/column of a dataset. In our case, we want to clean and tokenize our sentences, so let's implement a featurizer for customer review sentences.

final class BertFeaturizer implements CsvDataset.Featurizer {

    private final BertFullTokenizer tokenizer;
    private final int maxLength; // the cut-off length

    public BertFeaturizer(BertFullTokenizer tokenizer, int maxLength) {
        this.tokenizer = tokenizer;
        this.maxLength = maxLength;
    }

    /** {@inheritDoc} */
    @Override
    public void featurize(DynamicBuffer buf, String input) {
        Vocabulary vocab = tokenizer.getVocabulary();
        // convert sentence to tokens (toLowerCase for uncased model)
        List<String> tokens = tokenizer.tokenize(input.toLowerCase());
        // trim the tokens to maxLength
        tokens = tokens.size() > maxLength ? tokens.subList(0, maxLength) : tokens;
        // BERT embedding convention "[CLS] Your Sentence [SEP]"
        buf.put(vocab.getIndex("[CLS]"));
        tokens.forEach(token -> buf.put(vocab.getIndex(token)));
        buf.put(vocab.getIndex("[SEP]"));
    }
}
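The featurizer's steps can be traced without DJL. The following hypothetical sketch mirrors `featurize` with two stand-ins that are assumptions, not DJL behavior: a whitespace split instead of `BertFullTokenizer` (which applies WordPiece and may break one word into several sub-word tokens), and a toy vocabulary map. It shows the truncation to `maxLength` and the `[CLS] ... [SEP]` framing, so each sample ends up with at most `maxLength + 2` indices.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

// Hypothetical sketch of the BertFeaturizer steps. A whitespace split stands in
// for BertFullTokenizer, and a small Map stands in for the BERT vocabulary.
final class FeaturizeSketch {

    static List<Long> featurize(String input, Map<String, Long> vocab, int maxLength) {
        long unk = vocab.get("[UNK]");
        List<String> tokens = Arrays.asList(input.toLowerCase().trim().split("\\s+"));
        // keep at most maxLength tokens, as BertFeaturizer does
        if (tokens.size() > maxLength) {
            tokens = tokens.subList(0, maxLength);
        }
        List<Long> ids = new ArrayList<>();
        ids.add(vocab.get("[CLS]")); // sentence start marker
        for (String token : tokens) {
            ids.add(vocab.getOrDefault(token, unk)); // unknown words map to [UNK]
        }
        ids.add(vocab.get("[SEP]")); // sentence end marker
        return ids;
    }

    public static void main(String[] args) {
        Map<String, Long> vocab = Map.of(
                "[PAD]", 0L, "[UNK]", 1L, "[CLS]", 2L, "[SEP]", 3L,
                "works", 4L, "great", 5L);
        // maxLength = 2 keeps "works great" and drops "today"
        System.out.println(featurize("Works great today", vocab, 2)); // prints [2, 4, 5, 3]
    }
}
```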


Once this part is done, we can apply the `BertFeaturizer` to our dataset. We take the `review_body` column and apply the featurizer to it, and we pick `star_rating` as the label, subtracting 1 so that ratings 1 to 5 become class indices 0 to 4. Since we feed the model in batches, and reviews produce token sequences of different lengths, we also need to tell the dataset how to pad shorter samples so that every sample in a batch has the same length. `PaddingStackBatchifier` does this work for you.

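CsvDataset getDataset(int batchSize, BertFullTokenizer tokenizer, int maxLength, int limit) {
    String amazonReview =
            "https://s3.amazonaws.com/amazon-reviews-pds/tsv/amazon_reviews_us_Digital_Software_v1_00.tsv.gz";
    // index of the [PAD] token, used to fill shorter samples in a batch
    float paddingToken = tokenizer.getVocabulary().getIndex("[PAD]");
    return CsvDataset.builder()
            .optCsvUrl(amazonReview) // load from URL
            .setCsvFormat(CSVFormat.TDF.withQuote(null).withHeader()) // TSV loading format
            .setSampling(batchSize, true) // batch size, and shuffle the samples
            .optLimit(limit)
            .addFeature(
                    new CsvDataset.Feature(
                            "review_body", new BertFeaturizer(tokenizer, maxLength)))
            .addLabel(
                    new CsvDataset.Feature(
                            "star_rating", (buf, data) -> buf.put(Float.parseFloat(data) - 1.0f)))
            .optDataBatchifier(
                    PaddingStackBatchifier.builder()
                            .optIncludeValidLengths(false)
                            .addPad(0, 0, (m) -> m.ones(new Shape(1)).mul(paddingToken))
                            .build()) // pad the samples in each batch to a common length
            .build();
}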
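Conceptually, padding a batch means extending shorter index sequences with a padding index until they match the longest one, then stacking them into a 2-D array. The sketch below is a hypothetical plain-Java illustration of that idea and of the label shift from stars 1..5 to classes 0..4; it is not `PaddingStackBatchifier`'s implementation, and the padding index 0 is an assumption for the toy example.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch of pad-then-stack for one batch of token-index rows,
// plus the star_rating -> class index shift. Not DJL's implementation.
final class PadBatchSketch {

    static long[][] padAndStack(List<long[]> rows, long padIndex) {
        int maxLen = 0;
        for (long[] row : rows) {
            maxLen = Math.max(maxLen, row.length);
        }
        long[][] batch = new long[rows.size()][maxLen];
        for (int i = 0; i < rows.size(); i++) {
            long[] row = rows.get(i);
            Arrays.fill(batch[i], padIndex); // start with padding everywhere
            System.arraycopy(row, 0, batch[i], 0, row.length); // copy real tokens to the front
        }
        return batch;
    }

    // Stars 1..5 become class indices 0..4, as in the star_rating label.
    static float toLabel(String starRating) {
        return Float.parseFloat(starRating) - 1.0f;
    }

    public static void main(String[] args) {
        long pad = 0; // assumed [PAD] index for this toy example
        long[][] batch = padAndStack(List.of(new long[] {2, 4, 5, 3}, new long[] {2, 4, 3}), pad);
        System.out.println(Arrays.deepToString(batch)); // prints [[2, 4, 5, 3], [2, 4, 3, 0]]
        System.out.println(toLabel("5")); // prints 4.0
    }
}
```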


Next, we load the pretrained model and prepare the classifier. First, construct the `Criteria` to specify where to load the embedding model (DistilBERT) from, then call `loadModel` to download the embedding with its pre-trained weights. Since this model is built without a classification layer, we need to add classification layers to the end of the model and train them. After you finish modifying the block, set it on the model using `setBlock`.

// MXNet base model
String modelUrls = "https://resources.djl.ai/test-models/distilbert.zip";
if ("PyTorch".equals(Engine.getInstance().getEngineName())) {
    modelUrls = "https://resources.djl.ai/test-models/traced_distilbert_wikipedia_uncased.zip";
}

Criteria<NDList, NDList> criteria = Criteria.builder()
        .optApplication(Application.NLP.WORD_EMBEDDING)
        .setTypes(NDList.class, NDList.class)
        .optModelUrls(modelUrls)
        .optProgress(new ProgressBar())
        .build();
// download the embedding model with its pre-trained weights
ZooModel<NDList, NDList> embedding = criteria.loadModel();


### Create classification layers

Now let's build a simple MLP to classify the ranks. We set the output size of the last fully connected (`Linear`) layer to 5 to get predictions for star ratings 1 to 5. Then all we need to do is load the block into the model. Before the classification layers, we also need to add the text embedding at the front. In our case, we simply create a lambda function that does the following:

1. Convert batch_data (batch size, token indices) into the inputs the embedding model expects; for MXNet, this is batch_data plus max_length (the number of token indices per sample) for each sample
2. Generate the embedding
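Predictor<NDList, NDList> embedder = embedding.newPredictor();
Block classifier = new SequentialBlock()
        // text embedding layer
        .add(
            ndList -> {
                NDArray data = ndList.singletonOrThrow();
                NDList inputs = new NDList();
                long batchSize = data.getShape().get(0);
                float maxLength = data.getShape().get(1);

                if ("PyTorch".equals(Engine.getInstance().getEngineName())) {
                    inputs.add(data.toType(DataType.INT64, false));
                    inputs.add(data.getManager().full(data.getShape(), 1, DataType.INT64));
                    inputs.add(data.getManager().arange(maxLength)
                               .toType(DataType.INT64, false)
                               .broadcast(data.getShape()));
                } else {
                    inputs.add(data);
                    inputs.add(data.getManager().full(new Shape(batchSize), maxLength));
                }
                // run embedding
                try {
                    return embedder.predict(inputs);
                } catch (TranslateException e) {
                    throw new IllegalArgumentException("embedding error", e);
                }
            })
        // classification layer
        .add(Linear.builder().setUnits(768).build()) // pre classifier
        .add(Activation::relu)
        .add(Dropout.builder().optRate(0.2f).build())
        .add(Linear.builder().setUnits(5).build()) // 5 star rating
        .addSingleton(nd -> nd.get(":,0")); // take the [CLS] position as the head
Model model = Model.newInstance("AmazonReviewRatingClassification");
model.setBlock(classifier);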
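Here is a hypothetical plain-array sketch of the extra inputs built next to a padded batch of token indices with shape (batchSize, maxLength). `validLengths` corresponds to step 1 for MXNet: one value per sample, all equal to `maxLength`. `positionIds` (indices `0..maxLength-1` repeated for every row) illustrates the kind of extra input a BERT-style model may take on other engines; treat that part as an assumption, since the exact inputs a traced model expects depend on how it was exported.

```java
import java.util.Arrays;

// Hypothetical sketch of the extra inputs built next to a padded batch of
// token indices with shape (batchSize, maxLength), using plain arrays instead
// of NDArray.
final class EmbeddingInputsSketch {

    // One valid length per sample, all set to maxLength (the padded width of
    // the batch), so padding positions are counted as valid too.
    static float[] validLengths(int batchSize, int maxLength) {
        float[] lengths = new float[batchSize];
        Arrays.fill(lengths, maxLength);
        return lengths;
    }

    // Assumed extra input for models that take position indices: each row
    // holds 0, 1, ..., maxLength - 1.
    static long[][] positionIds(int batchSize, int maxLength) {
        long[][] positions = new long[batchSize][maxLength];
        for (long[] row : positions) {
            for (int j = 0; j < maxLength; j++) {
                row[j] = j;
            }
        }
        return positions;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(validLengths(2, 4))); // prints [4.0, 4.0]
        System.out.println(Arrays.deepToString(positionIds(2, 3))); // prints [[0, 1, 2], [0, 1, 2]]
    }
}
```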


## Start Training

Finally, we can start building our training pipeline to train the model.

### Creating Training and Validation datasets

First, we need to create a vocabulary that maps each token to an index, for example "hello" to its index in the dictionary. Then we pass the vocabulary to the tokenizer used to tokenize the sentences. Finally, we split the dataset according to the desired ratio.

Note: we set the cut-off length to 64, which means only the first 64 tokens of each review are used. You can increase this value to potentially achieve better accuracy, at the cost of longer training.

// Prepare the vocabulary
DefaultVocabulary vocabulary = DefaultVocabulary.builder()
        .addFromTextFile(embedding.getArtifact("vocab.txt")) // tokens shipped with the embedding model
        .optUnknownToken("[UNK]")
        .build();
// Prepare dataset
int maxTokenLength = 64; // cutoff tokens length
int batchSize = 8;
int limit = Integer.MAX_VALUE;
// int limit = 512; // uncomment for quick testing

BertFullTokenizer tokenizer = new BertFullTokenizer(vocabulary, true);
CsvDataset amazonReviewDataset = getDataset(batchSize, tokenizer, maxTokenLength, limit);
// split data with 7:3 train:valid ratio
RandomAccessDataset[] datasets = amazonReviewDataset.randomSplit(7, 3);
RandomAccessDataset trainingSet = datasets[0];
RandomAccessDataset validationSet = datasets[1];
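A ratio split like `randomSplit(7, 3)` can be pictured as shuffling the sample indices and cutting them into consecutive parts proportional to the ratio. This hypothetical sketch does exactly that with a fixed seed; DJL's `randomSplit` may round part sizes or draw randomness differently, so treat it as an illustration of the idea only.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical sketch of a ratio-based random split over sample indices.
final class RandomSplitSketch {

    static List<List<Integer>> split(int size, int[] ratio, long seed) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < size; i++) {
            indices.add(i);
        }
        // shuffle once so each part is a random sample of the data
        Collections.shuffle(indices, new Random(seed));

        int total = 0;
        for (int r : ratio) {
            total += r;
        }
        List<List<Integer>> parts = new ArrayList<>();
        int start = 0;
        for (int k = 0; k < ratio.length; k++) {
            // the last part takes the remainder so no sample is dropped by rounding
            int end = (k == ratio.length - 1)
                    ? size
                    : start + (int) ((long) size * ratio[k] / total);
            parts.add(new ArrayList<>(indices.subList(start, end)));
            start = end;
        }
        return parts;
    }

    public static void main(String[] args) {
        List<List<Integer>> parts = split(10, new int[] {7, 3}, 42L);
        System.out.println("train=" + parts.get(0).size() + " valid=" + parts.get(1).size()); // prints train=7 valid=3
    }
}
```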


### Set up Trainer and training config

Next, we set up the trainer, including the accuracy evaluator and the loss function. The model and training logs will be saved to `build/model`.

SaveModelTrainingListener listener = new SaveModelTrainingListener("build/model");
listener.setSaveModelCallback(
    trainer -> {
        TrainingResult result = trainer.getTrainingResult();
        Model model = trainer.getModel();
        // track for accuracy and loss
        float accuracy = result.getValidateEvaluation("Accuracy");
        model.setProperty("Accuracy", String.format("%.5f", accuracy));
        model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
    });
DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) // loss type
        .addEvaluator(new Accuracy())
        .optDevices(Engine.getInstance().getDevices(1)) // use at most one GPU, or the CPU if none is found
        .addTrainingListeners(TrainingListener.Defaults.logging("build/model"))
        .addTrainingListeners(listener);
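For intuition about what the loss and the evaluator measure, here is a hypothetical plain-Java sketch that computes softmax cross-entropy (the mean of `-log softmax(logits)[label]`) and accuracy (the share of rows whose highest score sits at the label index) for 5-class logits and labels 0..4. It is written from the standard definitions rather than from DJL's `SoftmaxCrossEntropyLoss` and `Accuracy` code, which may differ in details such as reduction or tie handling.

```java
// Hypothetical sketch, written from the standard definitions: softmax
// cross-entropy and accuracy for rows of 5 raw scores (logits) with integer
// labels 0..4 (star rating minus 1).
final class LossAccuracySketch {

    static double softmaxCrossEntropy(double[][] logits, int[] labels) {
        double total = 0;
        for (int i = 0; i < logits.length; i++) {
            // subtract the row max before exponentiating for numerical stability
            double max = Double.NEGATIVE_INFINITY;
            for (double v : logits[i]) {
                max = Math.max(max, v);
            }
            double sumExp = 0;
            for (double v : logits[i]) {
                sumExp += Math.exp(v - max);
            }
            double logProb = logits[i][labels[i]] - max - Math.log(sumExp);
            total += -logProb;
        }
        return total / logits.length; // mean over the batch
    }

    static double accuracy(double[][] logits, int[] labels) {
        int correct = 0;
        for (int i = 0; i < logits.length; i++) {
            int best = 0;
            for (int j = 1; j < logits[i].length; j++) {
                if (logits[i][j] > logits[i][best]) {
                    best = j;
                }
            }
            if (best == labels[i]) {
                correct++;
            }
        }
        return (double) correct / logits.length;
    }

    public static void main(String[] args) {
        double[][] logits = {{2, 0, 0, 0, 0}, {0, 0, 0, 0, 5}};
        int[] labels = {0, 3}; // second row's highest score is at index 4, not 3
        System.out.println("loss=" + softmaxCrossEntropy(logits, labels));
        System.out.println("accuracy=" + accuracy(logits, labels)); // prints accuracy=0.5
    }
}
```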


### Start training

Now we start the training process. Training on a GPU takes approximately 10 minutes; on a CPU, it takes more than 2 hours.

int epoch = 2;

Trainer trainer = model.newTrainer(config);
trainer.setMetrics(new Metrics());
Shape encoderInputShape = new Shape(batchSize, maxTokenLength);
// initialize trainer with proper input shape
trainer.initialize(encoderInputShape);
EasyTrain.fit(trainer, epoch, trainingSet, validationSet);
System.out.println(trainer.getTrainingResult());


### Save the model

model.save(Paths.get("build/model"), "amazon-review.param");


## Verify the model

We can create a predictor from the model to run inference on our own input. First, we create a `Translator` for the model to handle pre-processing and post-processing. As before, we need to tokenize the input sentence, and then we turn the model output into a ranking.

class MyTranslator implements Translator<String, Classifications> {

    private BertFullTokenizer tokenizer;
    private Vocabulary vocab;
    private List<String> ranks;

    public MyTranslator(BertFullTokenizer tokenizer) {
        this.tokenizer = tokenizer;
        vocab = tokenizer.getVocabulary();
        ranks = Arrays.asList("1", "2", "3", "4", "5");
    }

    @Override
    public Batchifier getBatchifier() { return Batchifier.STACK; }

    @Override
    public NDList processInput(TranslatorContext ctx, String input) {
        List<String> tokens = tokenizer.tokenize(input);
        // same "[CLS] tokens [SEP]" layout as BertFeaturizer
        float[] indices = new float[tokens.size() + 2];
        indices[0] = vocab.getIndex("[CLS]");
        for (int i = 0; i < tokens.size(); i++) {
            indices[i + 1] = vocab.getIndex(tokens.get(i));
        }
        indices[indices.length - 1] = vocab.getIndex("[SEP]");
        return new NDList(ctx.getNDManager().create(indices));
    }

    @Override
    public Classifications processOutput(TranslatorContext ctx, NDList list) {
        // softmax over the 5 rank scores gives one probability per rank label
        return new Classifications(ranks, list.singletonOrThrow().softmax(0));
    }
}
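`processOutput` turns the model's 5 scores for one review into a probability per rank label. The hypothetical sketch below does the same with plain arrays: a numerically stable softmax, followed by picking the label with the highest probability, which is what you would read off as the predicted rating. `RankOutputSketch` and its methods are illustrative names, not DJL API.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch of turning one review's 5 scores into per-rank
// probabilities and a predicted rank label.
final class RankOutputSketch {

    static final List<String> RANKS = Arrays.asList("1", "2", "3", "4", "5");

    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double[] probs = new double[logits.length];
        double sum = 0;
        for (int i = 0; i < logits.length; i++) {
            probs[i] = Math.exp(logits[i] - max); // shift by max to avoid overflow
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sum; // normalize so the probabilities add up to 1
        }
        return probs;
    }

    static String predictedRank(double[] logits) {
        double[] probs = softmax(logits);
        int best = 0;
        for (int i = 1; i < probs.length; i++) {
            if (probs[i] > probs[best]) {
                best = i;
            }
        }
        return RANKS.get(best);
    }

    public static void main(String[] args) {
        double[] logits = {0.1, 0.2, 3.0, 0.5, 0.0};
        System.out.println(Arrays.toString(softmax(logits)));
        System.out.println("predicted rank: " + predictedRank(logits)); // prints "predicted rank: 3"
    }
}
```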


Finally, we create a `Predictor` to run inference. Let's try it on a sample customer review:

String review = "It works great, but it takes too long to update itself and slows the system";
Predictor<String, Classifications> predictor = model.newPredictor(new MyTranslator(tokenizer));

predictor.predict(review)