
Rank Classification using BERT on Amazon Review dataset


In this tutorial, you learn how to train a rank classification model using transfer learning. We will use a pretrained DistilBERT model as the base and train a classifier on the Amazon review dataset.

About the dataset and model

The Amazon Customer Review dataset consists of valid reviews from amazon.com. We will use the "Digital_software" category, which consists of 102k valid reviews. For the pretrained model, we use DistilBERT[1], a lightweight BERT model already trained on Wikipedia text corpora, a much larger dataset consisting of millions of texts. DistilBERT serves as the base layer, and we add classification layers on top of it to output a rating from 1 to 5.

Figure: Amazon Review example

We use the review body as the input data and the star rating as the label.
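As a concrete illustration of this input/label pairing, here is a minimal stdlib-only sketch. It assumes a tab-separated header line that names the review_body and star_rating columns, with unquoted fields; the ReviewRow class is hypothetical and not part of DJL. It shifts the 1-5 star rating down by one so the five classes are indexed 0-4, which matches the Float.parseFloat(data) - 1.0f label featurizer in the dataset code.

```java
import java.util.Arrays;
import java.util.List;

/** Hypothetical helper: pulls the input text and 0-based label out of one TSV data row. */
final class ReviewRow {
    final String text;
    final int label; // star_rating 1..5 shifted to class index 0..4

    private ReviewRow(String text, int label) {
        this.text = text;
        this.label = label;
    }

    /** Parses one data line using the column positions found in the header line. */
    static ReviewRow parse(String headerLine, String dataLine) {
        List<String> header = Arrays.asList(headerLine.split("\t", -1));
        String[] fields = dataLine.split("\t", -1);
        int textCol = header.indexOf("review_body");
        int ratingCol = header.indexOf("star_rating");
        if (textCol < 0 || ratingCol < 0) {
            throw new IllegalArgumentException("missing review_body or star_rating column");
        }
        int stars = Integer.parseInt(fields[ratingCol].trim());
        if (stars < 1 || stars > 5) {
            throw new IllegalArgumentException("star_rating out of range: " + stars);
        }
        return new ReviewRow(fields[textCol], stars - 1);
    }

    public static void main(String[] args) {
        String header = "marketplace\tstar_rating\treview_body";
        ReviewRow row = parse(header, "US\t4\tWorks as advertised.");
        System.out.println(row.label + " -> " + row.text); // prints "3 -> Works as advertised."
    }
}
```

The label shift is the only transformation applied to the rating; the review text is passed on untouched for tokenization.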


This tutorial assumes you are familiar with the following topics. If you are not, follow the READMEs and tutorials first:

  1. How to set up and run the Java kernel in Jupyter Notebook
  2. The basic components of Deep Java Library, and how to train your first model

Getting started

Load the Deep Java Library and its dependencies from Maven. Here, you can choose between MXNet and PyTorch. MXNet is enabled by default; to switch to PyTorch, uncomment the PyTorch dependency and comment out the MXNet one.

// %mavenRepo snapshots

%maven ai.djl:api:0.27.0
%maven ai.djl:basicdataset:0.27.0
%maven org.slf4j:slf4j-simple:1.7.36
%maven ai.djl.mxnet:mxnet-model-zoo:0.27.0

// PyTorch
// %maven ai.djl.pytorch:pytorch-model-zoo:0.27.0

Now let's import the necessary modules:

import ai.djl.*;
import ai.djl.basicdataset.tabular.*;
import ai.djl.basicdataset.tabular.utils.*;
import ai.djl.basicdataset.utils.*;
import ai.djl.engine.*;
import ai.djl.inference.*;
import ai.djl.metric.*;
import ai.djl.modality.*;
import ai.djl.modality.nlp.*;
import ai.djl.modality.nlp.bert.*;
import ai.djl.ndarray.*;
import ai.djl.ndarray.types.*;
import ai.djl.nn.*;
import ai.djl.nn.core.*;
import ai.djl.nn.norm.*;
import ai.djl.repository.zoo.*;
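import ai.djl.training.*;
import ai.djl.training.dataset.*;
import ai.djl.training.evaluator.*;
import ai.djl.training.listener.*;
import ai.djl.training.loss.*;
import ai.djl.training.util.*;
import ai.djl.translate.*;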
import java.nio.file.*;
import java.util.*;
import org.apache.commons.csv.*;

System.out.println("You are using: " + Engine.getInstance().getEngineName() + " Engine");

Prepare Dataset

The first step is to prepare the dataset for training. Since the original data is in TSV format, we can use CsvDataset as the dataset container. We also need to specify how we want to preprocess the raw data. For the BERT model, the input text must be tokenized and each token mapped to its index in the model's vocabulary. DJL defines an interface called Featurizer, which lets users customize the operation applied to each selected row/column of a dataset. In our case, we want to clean and tokenize the review sentences, so let's implement a Featurizer for customer review sentences.

final class BertFeaturizer implements Featurizer {

    private final BertFullTokenizer tokenizer;
    private final int maxLength; // the cut-off length

    public BertFeaturizer(BertFullTokenizer tokenizer, int maxLength) {
        this.tokenizer = tokenizer;
        this.maxLength = maxLength;
    }

    /** {@inheritDoc} */
    @Override
    public void featurize(DynamicBuffer buf, String input) {
        Vocabulary vocab = tokenizer.getVocabulary();
        // convert sentence to tokens (toLowerCase for uncased model)
        List<String> tokens = tokenizer.tokenize(input.toLowerCase());
        // trim the tokens to maxLength
        tokens = tokens.size() > maxLength ? tokens.subList(0, maxLength) : tokens;
        // BERT embedding convention "[CLS] Your Sentence [SEP]"
        buf.put(vocab.getIndex("[CLS]"));
        tokens.forEach(token -> buf.put(vocab.getIndex(token)));
        buf.put(vocab.getIndex("[SEP]"));
    }

    /** {@inheritDoc} */
    @Override
    public int dataRequired() {
        throw new IllegalStateException("BertFeaturizer only supports featurize, not deFeaturize");
    }

    /** {@inheritDoc} */
    @Override
    public Object deFeaturize(float[] data) {
        throw new IllegalStateException("BertFeaturizer only supports featurize, not deFeaturize");
    }
}
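To see what featurize produces without loading DJL, the following sketch repeats the same steps on plain Java collections: truncate to maxLength, then wrap as [CLS] ... [SEP]. The ClsSepEncoder class, the toy Map vocabulary, its indices, and the [UNK] fallback are assumptions for illustration; the real indices come from the model's vocabulary. Note that a sequence can be up to maxLength + 2 ids long, because the two special tokens are added after trimming.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

/** Hypothetical helper mirroring the featurize() steps with a plain Map vocabulary. */
final class ClsSepEncoder {
    /**
     * Truncates tokens to maxLength, then wraps them as [CLS] tokens... [SEP].
     * The vocabulary must contain [CLS], [SEP] and [UNK]; unknown tokens map to [UNK].
     */
    static List<Long> encode(List<String> tokens, Map<String, Long> vocab, int maxLength) {
        List<String> kept = tokens.size() > maxLength ? tokens.subList(0, maxLength) : tokens;
        long unk = vocab.get("[UNK]");
        List<Long> ids = new ArrayList<>(kept.size() + 2);
        ids.add(vocab.get("[CLS]"));
        for (String token : kept) {
            ids.add(vocab.getOrDefault(token, unk));
        }
        ids.add(vocab.get("[SEP]"));
        return ids;
    }

    public static void main(String[] args) {
        Map<String, Long> vocab =
                Map.of("[PAD]", 0L, "[UNK]", 1L, "[CLS]", 2L, "[SEP]", 3L, "works", 4L, "great", 5L);
        // "today" is dropped by the cut-off of 2 tokens, not mapped to [UNK]
        List<Long> ids = encode(Arrays.asList("works", "great", "today"), vocab, 2);
        System.out.println(ids); // prints [2, 4, 5, 3]
    }
}
```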

Once this part is done, we can apply the BertFeaturizer to our dataset. We take the review_body column and apply the Featurizer to it, and we pick star_rating as our label. Since we feed the data in batches, we need to tell the dataset to pad shorter token sequences so that all sequences in a batch have the same length. PaddingStackBatchifier does this for you.

CsvDataset getDataset(int batchSize, BertFullTokenizer tokenizer, int maxLength, int limit) {
    String amazonReview = ""; // URL of the Digital_Software review TSV (set before running)
    float paddingToken = tokenizer.getVocabulary().getIndex("[PAD]");
    return CsvDataset.builder()
            .optCsvUrl(amazonReview) // load from Url
            .setCsvFormat(CSVFormat.TDF.withQuote(null).withHeader()) // Setting TSV loading format
            .setSampling(batchSize, true) // make sample size and random access
            .optLimit(limit)
            .addFeature(
                    new Feature(
                            "review_body", new BertFeaturizer(tokenizer, maxLength)))
            .addLabel(
                    new Feature(
                            "star_rating", (buf, data) -> buf.put(Float.parseFloat(data) - 1.0f)))
            .optDataBatchifier(
                    PaddingStackBatchifier.builder()
                            .optIncludeValidLengths(false)
                            .addPad(0, 0, (m) -> m.ones(new Shape(1)).mul(paddingToken))
                            .build()) // define how to pad dataset to a fixed length
            .build();
}
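The padding step can be pictured with a small stdlib sketch. This is not DJL's PaddingStackBatchifier; it is a hypothetical BatchPadder that right-pads every sequence in a batch with the [PAD] index up to the length of the longest sequence, which is the effect needed before the rows can be stacked into one (batch size, length) array.

```java
import java.util.Arrays;
import java.util.List;

/** Hypothetical helper showing what padding a batch to a common length does. */
final class BatchPadder {
    /** Right-pads every sequence with padIndex up to the longest sequence in the batch. */
    static long[][] pad(List<long[]> batch, long padIndex) {
        int width = 0;
        for (long[] seq : batch) {
            width = Math.max(width, seq.length);
        }
        long[][] out = new long[batch.size()][width];
        for (int i = 0; i < batch.size(); i++) {
            long[] seq = batch.get(i);
            Arrays.fill(out[i], padIndex);                   // start with all [PAD]
            System.arraycopy(seq, 0, out[i], 0, seq.length); // copy real ids to the front
        }
        return out;
    }

    public static void main(String[] args) {
        long[][] padded = pad(Arrays.asList(new long[] {2, 4, 3}, new long[] {2, 3}), 0L);
        System.out.println(Arrays.deepToString(padded)); // prints [[2, 4, 3], [2, 3, 0]]
    }
}
```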

Construct your model

We will load our pretrained model and prepare the classification layers. First, construct the criteria that specify where to load the embedding (DistilBERT) from, then call loadModel to download that embedding with its pretrained weights. Since this model is built without a classification layer, we need to add one to the end of the model and train it. After you are done modifying the block, set it back to the model using setBlock.
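To make the role of the added layers concrete, here is a hypothetical toy version in plain Java: it takes the hidden vector at position 0 (the [CLS] token, which the featurizer always puts first), applies one dense layer to get five scores, and maps the best-scoring class index back to a 1-5 star rating. ClsHead, its dimensions, and its weights are made up for illustration; the real model works on 768-dimensional DistilBERT outputs with learned weights. Because a dense layer acts on each sequence position independently, selecting [CLS] before the dense layers (as here) gives the same scores as selecting it afterwards, which is what the nd.get(":,0") step in the classifier block does.

```java
import java.util.Arrays;

/** Hypothetical toy classification head: [CLS] vector -> dense layer -> star rating. */
final class ClsHead {
    /** The [CLS] token sits at sequence position 0. */
    static float[] clsVector(float[][] sequenceOutput) {
        return sequenceOutput[0];
    }

    /** Dense layer: out[j] = b[j] + sum_i x[i] * w[i][j]. */
    static float[] linear(float[] x, float[][] w, float[] b) {
        float[] out = b.clone();
        for (int i = 0; i < x.length; i++) {
            for (int j = 0; j < out.length; j++) {
                out[j] += x[i] * w[i][j];
            }
        }
        return out;
    }

    /** Maps the best class index 0..4 back to a 1..5 star rating. */
    static int predictedStars(float[] scores) {
        int best = 0;
        for (int j = 1; j < scores.length; j++) {
            if (scores[j] > scores[best]) {
                best = j;
            }
        }
        return best + 1;
    }

    public static void main(String[] args) {
        // 3 token positions with hidden size 2 (the real model uses 768)
        float[][] sequence = {{1f, 0f}, {0f, 1f}, {5f, 5f}};
        float[][] w = {{0f, 0f, 0f, 1f, 0f}, {0f, 0f, 2f, 0f, 0f}};
        float[] scores = linear(clsVector(sequence), w, new float[5]);
        // prints "[0.0, 0.0, 0.0, 1.0, 0.0] -> 4 stars"
        System.out.println(Arrays.toString(scores) + " -> " + predictedStars(scores) + " stars");
    }
}
```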

Load the word embedding

We will download the word embedding and load it into memory (this may take a while).

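// MXNet base model
String modelUrls = ""; // location of the MXNet DistilBERT archive (set before running)
if ("PyTorch".equals(Engine.getInstance().getEngineName())) {
    // PyTorch base model
    modelUrls = ""; // location of the PyTorch DistilBERT archive (set before running)
}

Criteria<NDList, NDList> criteria = Criteria.builder()
        .setTypes(NDList.class, NDList.class)
        .optModelUrls(modelUrls)
        .optProgress(new ProgressBar())
        .build();
ZooModel<NDList, NDList> embedding = criteria.loadModel();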

Create classification layers

Next, let's build a simple MLP to classify the ranks. We set the output size of the last fully connected (Linear) layer to 5 to get predictions for 1 to 5 stars, and then load the block into the model. Before applying the classification layers, we also need to add the text embedding to the front. In our case, we create a lambda function that does the following:

  1. Convert batch_data of shape (batch size, token indices) into the inputs the embedding expects: for MXNet, the token indices plus a valid-length vector filled with max_length (the number of token indices); for PyTorch, the token indices plus an attention mask and position ids.
  2. Generate the embedding.
Predictor<NDList, NDList> embedder = embedding.newPredictor();
Block classifier = new SequentialBlock()
        // text embedding layer
        .add(
            ndList -> {
                NDArray data = ndList.singletonOrThrow();
                NDList inputs = new NDList();
                long batchSize = data.getShape().get(0);
                float maxLength = data.getShape().get(1);

                if ("PyTorch".equals(Engine.getInstance().getEngineName())) {
                    // token ids, attention mask and position ids, all INT64
                    inputs.add(data.toType(DataType.INT64, false));
                    inputs.add(data.getManager().full(data.getShape(), 1, DataType.INT64));
                    inputs.add(data.getManager().arange(maxLength)
                               .toType(DataType.INT64, false)
                               .broadcast(data.getShape()));
                } else {
                    // token ids plus one valid length per example
                    inputs.add(data);
                    inputs.add(data.getManager().full(new Shape(batchSize), maxLength));
                }
                // run embedding
                try {
                    return embedder.predict(inputs);
                } catch (TranslateException e) {
                    throw new IllegalArgumentException("embedding error", e);
                }
            })
        // classification layer
        .add(Linear.builder().setUnits(768).build()) // pre classifier
        .add(Activation::relu) // non-linearity between the two dense layers
        .add(Linear.builder().setUnits(5).build()) // 5 star rating
        .addSingleton(nd -> nd.get(":,0")); // Take [CLS] as the head
Model model = Model.newInstance("AmazonReviewRatingClassification");
model.setBlock(classifier);
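The extra embedding inputs built inside the lambda can be mirrored with plain arrays. In this hypothetical EmbeddingInputs sketch, attentionMask and positionIds correspond to the additional INT64 inputs of the PyTorch branch (a tensor of ones, and an arange broadcast to every row), and validLengths corresponds to the MXNet branch's per-example length filled with maxLength. As in the lambda, every position, including [PAD] positions, is marked as valid.

```java
import java.util.Arrays;

/** Hypothetical plain-array mirror of the extra embedding inputs built in the lambda. */
final class EmbeddingInputs {
    /** PyTorch branch: a mask of ones with the same (batch, seqLen) shape as the token ids. */
    static long[][] attentionMask(int batch, int seqLen) {
        long[][] mask = new long[batch][seqLen];
        for (long[] row : mask) {
            Arrays.fill(row, 1L);
        }
        return mask;
    }

    /** PyTorch branch: positions 0..seqLen-1 repeated for every row (like arange + broadcast). */
    static long[][] positionIds(int batch, int seqLen) {
        long[][] ids = new long[batch][seqLen];
        for (long[] row : ids) {
            for (int j = 0; j < seqLen; j++) {
                row[j] = j;
            }
        }
        return ids;
    }

    /** MXNet branch: one valid length per example, each equal to the padded length. */
    static float[] validLengths(int batch, int seqLen) {
        float[] lengths = new float[batch];
        Arrays.fill(lengths, (float) seqLen);
        return lengths;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.deepToString(positionIds(2, 4))); // prints [[0, 1, 2, 3], [0, 1, 2, 3]]
        System.out.println(Arrays.toString(validLengths(2, 4)));    // prints [4.0, 4.0]
    }
}
```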

Start Training

Finally, we can start building our training pipeline to train the model.

Creating Training and Testing dataset

First, we need to create a vocabulary that maps each token to an index, for example "hello" to 1121 (if 1121 is the index of "hello" in the dictionary). Then we pass the vocabulary to the tokenizer that tokenizes the sentences. Finally, we split the dataset according to the given ratio.

Note: we set the cut-off length to 64, which means only the first 64 tokens of each review are used. Increasing this value may improve accuracy, at the cost of longer training.
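BertFullTokenizer splits words into subword units taken from the vocabulary (WordPiece). The sketch below is a simplified, hypothetical greedy longest-match-first splitter, assuming the standard BERT convention that continuation pieces carry a ## prefix and that a word with no matching piece becomes [UNK]. The real tokenizer also runs a basic tokenization pass first (including lowercasing, since it is constructed with true in this tutorial), which this sketch omits.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

/** Hypothetical simplified WordPiece: greedy longest-match-first over a token set. */
final class WordPieceSketch {
    static List<String> split(String word, Set<String> vocab, String unk) {
        List<String> pieces = new ArrayList<>();
        int start = 0;
        while (start < word.length()) {
            int end = word.length();
            String match = null;
            // try the longest remaining substring first, then shrink from the right
            while (start < end) {
                String candidate = word.substring(start, end);
                if (start > 0) {
                    candidate = "##" + candidate; // continuation pieces carry a ## prefix
                }
                if (vocab.contains(candidate)) {
                    match = candidate;
                    break;
                }
                end--;
            }
            if (match == null) {
                return List.of(unk); // no piece fits: the whole word becomes [UNK]
            }
            pieces.add(match);
            start = end;
        }
        return pieces;
    }

    public static void main(String[] args) {
        Set<String> vocab = Set.of("un", "##aff", "##able");
        System.out.println(split("unaffable", vocab, "[UNK]")); // prints [un, ##aff, ##able]
    }
}
```

Each resulting piece is then looked up in the vocabulary to obtain its index.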

// Prepare the vocabulary
// (assumes the embedding archive ships its vocabulary file as "vocab.txt")
DefaultVocabulary vocabulary = DefaultVocabulary.builder()
        .addFromTextFile(embedding.getArtifact("vocab.txt"))
        .optUnknownToken("[UNK]")
        .build();
// Prepare dataset
int maxTokenLength = 64; // cutoff tokens length
int batchSize = 8;
int limit = Integer.MAX_VALUE;
// int limit = 512; // uncomment for quick testing

BertFullTokenizer tokenizer = new BertFullTokenizer(vocabulary, true);
CsvDataset amazonReviewDataset = getDataset(batchSize, tokenizer, maxTokenLength, limit);
// split data with 7:3 train:valid ratio
RandomAccessDataset[] datasets = amazonReviewDataset.randomSplit(7, 3);
RandomAccessDataset trainingSet = datasets[0];
RandomAccessDataset validationSet = datasets[1];
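The 7:3 split divides the shuffled examples into parts proportional to the given ratios. A minimal sketch of that idea, assuming a seeded shuffle and rounding the training share down (DJL's exact shuffling and rounding may differ), is shown below; RatioSplit is a hypothetical helper.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical ratio-based random split over example indices. */
final class RatioSplit {
    /** Shuffles 0..size-1 with a fixed seed and cuts it into train/validation parts by ratio. */
    static List<List<Integer>> split(int size, int trainParts, int validParts, long seed) {
        List<Integer> indices = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            indices.add(i);
        }
        Collections.shuffle(indices, new Random(seed));
        // training share rounded down, e.g. 10 examples at 7:3 -> 7 train, 3 validation
        int trainSize = (int) ((long) size * trainParts / (trainParts + validParts));
        List<List<Integer>> parts = new ArrayList<>();
        parts.add(new ArrayList<>(indices.subList(0, trainSize)));
        parts.add(new ArrayList<>(indices.subList(trainSize, size)));
        return parts;
    }

    public static void main(String[] args) {
        List<List<Integer>> parts = split(10, 7, 3, 42L);
        // prints "7 train, 3 validation"
        System.out.println(parts.get(0).size() + " train, " + parts.get(1).size() + " validation");
    }
}
```

Shuffling before cutting keeps both parts representative of the whole dataset instead of depending on the file order.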

Set up Trainer and training config

Next, we set up our trainer with the loss function and the accuracy evaluator. The training logs and saved model checkpoints go to build/model.

SaveModelTrainingListener listener = new SaveModelTrainingListener("build/model");
listener.setSaveModelCallback(
        trainer -> {
            TrainingResult result = trainer.getTrainingResult();
            Model model = trainer.getModel();
            // track for accuracy and loss
            float accuracy = result.getValidateEvaluation("Accuracy");
            model.setProperty("Accuracy", String.format("%.5f", accuracy));
            model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
        });
DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) // loss type
        .addEvaluator(new Accuracy())
        .optDevices(Engine.getInstance().getDevices(1)) // use at most one GPU (CPU if none)
        .addTrainingListeners(TrainingListener.Defaults.logging("build/model"))
        .addTrainingListeners(listener);
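For reference, here is what the loss and the evaluator compute, written out with their standard definitions: softmax cross-entropy for an integer class label, and accuracy as the fraction of examples whose highest-scoring class equals the label. LossAndAccuracy is a hypothetical stdlib sketch; DJL's Loss.softmaxCrossEntropyLoss() and Accuracy may differ in details such as how results are averaged across batches.

```java
/** Hypothetical stdlib versions of the training loss and the accuracy metric. */
final class LossAndAccuracy {
    /** Softmax cross-entropy for one example: log(sum(exp(z))) - z[label], computed stably. */
    static double crossEntropy(double[] logits, int label) {
        double max = Double.NEGATIVE_INFINITY;
        for (double z : logits) {
            max = Math.max(max, z);
        }
        double sumExp = 0.0;
        for (double z : logits) {
            sumExp += Math.exp(z - max); // subtracting max avoids overflow in exp
        }
        return max + Math.log(sumExp) - logits[label];
    }

    /** Index of the largest score (the first one wins on ties). */
    static int argmax(double[] scores) {
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Fraction of rows whose argmax equals the label. */
    static double accuracy(double[][] logits, int[] labels) {
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (argmax(logits[i]) == labels[i]) {
                correct++;
            }
        }
        return (double) correct / labels.length;
    }

    public static void main(String[] args) {
        double[] uniform = {0, 0, 0, 0, 0};
        System.out.println(crossEntropy(uniform, 2)); // equals Math.log(5), about 1.609
        double[][] batch = {{0, 3, 0, 0, 0}, {2, 0, 0, 0, 0}};
        System.out.println(accuracy(batch, new int[] {1, 3})); // prints 0.5
    }
}
```

With five classes, a model that outputs uniform scores has a loss of log 5, a useful baseline when reading the training logs.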

Start training

Now we start the training process. Training on a GPU takes approximately 10 minutes; on a CPU, it takes more than 2 hours.

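int epoch = 2;

Trainer trainer = model.newTrainer(config);
trainer.setMetrics(new Metrics());
Shape encoderInputShape = new Shape(batchSize, maxTokenLength);
// initialize trainer with proper input shape
trainer.initialize(encoderInputShape);
// train for the given number of epochs, validating after each epoch
EasyTrain.fit(trainer, epoch, trainingSet, validationSet);
System.out.println(trainer.getTrainingResult());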
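To put the training-time estimate in perspective, the number of optimizer steps follows from the dataset size, the 7:3 split, and the batch size. The hypothetical TrainingSteps sketch below assumes the roughly 102k reviews quoted for the Digital_software category, a training share of exactly 70%, and that a final partial batch still counts as one step; all three are approximations.

```java
/** Hypothetical back-of-the-envelope calculator for training steps. */
final class TrainingSteps {
    /** Batches per epoch, counting a final partial batch as one step. */
    static long stepsPerEpoch(long examples, int batchSize) {
        return (examples + batchSize - 1) / batchSize;
    }

    public static void main(String[] args) {
        long reviews = 102_000;            // approximate Digital_software size
        long training = reviews * 7 / 10;  // 7:3 split -> 71,400 training reviews
        int batchSize = 8;
        int epochs = 2;
        long perEpoch = stepsPerEpoch(training, batchSize);
        // prints "8925 steps per epoch, 17850 in total"
        System.out.println(perEpoch + " steps per epoch, " + perEpoch * epochs + " in total");
    }
}
```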

Save the model

model.save(Paths.get("build/model"), "amazon-review.param");

Verify the model

We can create a predictor from the model to run inference on our own input. First, we create a Translator for the model that handles pre-processing and post-processing. As before, we need to tokenize the input sentence, and then turn the model output into a ranking.

class MyTranslator implements Translator<String, Classifications> {

    private final BertFullTokenizer tokenizer;
    private final Vocabulary vocab;
    private final List<String> ranks;

    public MyTranslator(BertFullTokenizer tokenizer) {
        this.tokenizer = tokenizer;
        vocab = tokenizer.getVocabulary();
        ranks = Arrays.asList("1", "2", "3", "4", "5");
    }

    @Override
    public Batchifier getBatchifier() { return Batchifier.STACK; }

    @Override
    public NDList processInput(TranslatorContext ctx, String input) {
        // tokenize and wrap as "[CLS] tokens [SEP]", matching the training input format
        List<String> tokens = tokenizer.tokenize(input);
        float[] indices = new float[tokens.size() + 2];
        indices[0] = vocab.getIndex("[CLS]");
        for (int i = 0; i < tokens.size(); i++) {
            indices[i + 1] = vocab.getIndex(tokens.get(i));
        }
        indices[indices.length - 1] = vocab.getIndex("[SEP]");
        return new NDList(ctx.getNDManager().create(indices));
    }

    @Override
    public Classifications processOutput(TranslatorContext ctx, NDList list) {
        // turn the 5 scores into probabilities over the star ratings
        return new Classifications(ranks, list.singletonOrThrow().softmax(0));
    }
}
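processOutput turns the model's five scores into probabilities with a softmax and pairs them with the rank labels. The hypothetical RankDecoder below shows the same post-processing on a plain double[], plus picking the rank with the highest probability. It uses a numerically stable softmax that subtracts the maximum score first, which does not change the resulting probabilities.

```java
import java.util.Arrays;
import java.util.List;

/** Hypothetical stdlib version of the translator's post-processing. */
final class RankDecoder {
    static final List<String> RANKS = Arrays.asList("1", "2", "3", "4", "5");

    /** Softmax over the scores; subtracting the max keeps exp() from overflowing. */
    static double[] softmax(double[] logits) {
        double max = Arrays.stream(logits).max().orElse(0.0);
        double[] p = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            p[i] = Math.exp(logits[i] - max);
            sum += p[i];
        }
        for (int i = 0; i < p.length; i++) {
            p[i] /= sum;
        }
        return p;
    }

    /** Returns the rank label with the highest probability. */
    static String bestRank(double[] probabilities) {
        int best = 0;
        for (int i = 1; i < probabilities.length; i++) {
            if (probabilities[i] > probabilities[best]) {
                best = i;
            }
        }
        return RANKS.get(best);
    }

    public static void main(String[] args) {
        double[] probs = softmax(new double[] {-1.0, 0.5, 2.0, 1.0, 0.0});
        System.out.println("best rank: " + bestRank(probs)); // prints "best rank: 3"
    }
}
```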

Finally, we can create a Predictor to run inference. Let's try it on an example customer review:

String review = "It works great, but it takes too long to update itself and slows the system";
Predictor<String, Classifications> predictor = model.newPredictor(new MyTranslator(tokenizer));

System.out.println(predictor.predict(review));