

Train your first model

This is the second of our beginner tutorial series that will take you through creating, training, and running inference on a neural network. In this tutorial, you will learn how to train an image classification model that can recognize handwritten digits.


This tutorial requires the installation of the Java Jupyter Kernel. To install the kernel, see the Jupyter README.

// Add the snapshot repository to get the DJL snapshot artifacts
// %mavenRepo snapshots

// Add the maven dependencies
%maven ai.djl:api:0.9.0
%maven ai.djl:basicdataset:0.9.0
%maven ai.djl:model-zoo:0.9.0
%maven ai.djl.mxnet:mxnet-engine:0.9.0
%maven org.slf4j:slf4j-api:1.7.26
%maven org.slf4j:slf4j-simple:1.7.26

// See
// for more MXNet library selection options
%maven ai.djl.mxnet:mxnet-native-auto:1.7.0-backport
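
import java.nio.file.*;

import ai.djl.*;
import ai.djl.basicdataset.*;
import ai.djl.basicmodelzoo.basic.*;
import ai.djl.ndarray.types.*;
import ai.djl.training.*;
import ai.djl.training.dataset.*;
import ai.djl.training.evaluator.*;
import ai.djl.training.listener.*;
import ai.djl.training.loss.*;
import ai.djl.training.util.*;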

Step 1: Prepare MNIST dataset for training

When training a deep learning network, it is important to first understand the dataset.


A Dataset is a collection of sample input/output pairs for the function represented by your neural network. Each individual input/output pair is represented by a Record. A record can contain multiple arrays of inputs or outputs. For example, in an image question-answering dataset, the input is both an image and a question about the image, while the output is the answer to the question.

Because deep learning is highly parallelizable, training is often done not with a single record at a time, but with a Batch of records at a time. This can lead to significant performance gains, especially when working with images.
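To make the idea of a batch concrete, here is a plain-Java sketch (not DJL's NDArray API) that stacks several equally sized records into one row-major array of shape [batchSize, recordLength]. The class and method names are hypothetical and exist only for illustration.

```java
import java.util.Arrays;
import java.util.List;

/** Plain-Java sketch: stacking per-record arrays into one batch array. */
public class BatchStacking {

    /** Stacks records of equal length into one row-major [n, length] array. */
    static float[] stack(List<float[]> records) {
        int length = records.get(0).length;
        float[] batch = new float[records.size() * length];
        for (int r = 0; r < records.size(); r++) {
            float[] record = records.get(r);
            if (record.length != length) {
                throw new IllegalArgumentException("All records must have the same length");
            }
            // Row r of the batch occupies batch[r * length, (r + 1) * length)
            System.arraycopy(record, 0, batch, r * length, length);
        }
        return batch;
    }

    public static void main(String[] args) {
        List<float[]> records = List.of(new float[] {1, 2, 3}, new float[] {4, 5, 6});
        float[] batch = stack(records);
        // Logical shape is [2, 3]; printed here in flattened row-major order
        System.out.println(Arrays.toString(batch)); // [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
    }
}
```

Processing one stacked array instead of many small ones is what lets the hardware work on all records of the batch in parallel.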


The dataset we will be using is MNIST, a database of handwritten digits. Each image is a 28x28 grayscale picture of a single digit from 0 to 9. MNIST is commonly used when getting started with deep learning because it is small and fast to train.
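Because the MLP used later takes a flat vector as input, each 28x28 image ends up as 28 * 28 = 784 numbers. The plain-Java sketch below flattens an image row by row; scaling the 0-255 intensities into [0, 1] is an assumption reflecting common practice, not a claim about exactly what DJL's Mnist class does internally. The class name is hypothetical.

```java
/** Plain-Java sketch: flattening a 28x28 grayscale image into a 784-element vector. */
public class FlattenImage {
    static final int HEIGHT = 28;
    static final int WIDTH = 28;

    /** Flattens row by row and scales 0-255 pixel intensities into [0, 1] (assumed preprocessing). */
    static float[] flatten(int[][] pixels) {
        float[] out = new float[HEIGHT * WIDTH];
        for (int row = 0; row < HEIGHT; row++) {
            for (int col = 0; col < WIDTH; col++) {
                out[row * WIDTH + col] = pixels[row][col] / 255f;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        int[][] image = new int[HEIGHT][WIDTH];
        image[0][1] = 255; // one white pixel in the first row
        float[] vector = flatten(image);
        System.out.println(vector.length); // 784
        System.out.println(vector[1]);     // 1.0
    }
}
```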


Once you understand your dataset, you should create an implementation of the Dataset class. For MNIST, DJL provides a built-in implementation, so you can use it directly.


Next, we decide the parameters for loading data from the dataset. For MNIST, the only parameter we need is the choice of Sampler. The sampler decides which elements of the dataset, and how many, go into each batch when iterating through it. We will have it randomly shuffle the elements and use a batchSize of 32. The batchSize is typically chosen as a power of 2, as large as fits within memory.
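The following plain-Java sketch (not DJL's Sampler implementation) shows what random sampling with a batchSize of 32 amounts to: shuffle all record indices, then cut them into consecutive batches. The class and method names are hypothetical, and 60,000 is the size of the standard MNIST training split. How DJL treats a final partial batch is not addressed here; this sketch simply keeps it.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Plain-Java sketch of a random sampler: shuffles indices, then cuts them into batches. */
public class RandomBatchSampler {

    static List<List<Integer>> batches(int datasetSize, int batchSize, long seed) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < datasetSize; i++) {
            indices.add(i);
        }
        // Shuffle so that the records are visited in a random order
        Collections.shuffle(indices, new Random(seed));

        List<List<Integer>> result = new ArrayList<>();
        for (int start = 0; start < datasetSize; start += batchSize) {
            int end = Math.min(start + batchSize, datasetSize);
            // The last batch is smaller if datasetSize is not a multiple of batchSize
            result.add(new ArrayList<>(indices.subList(start, end)));
        }
        return result;
    }

    public static void main(String[] args) {
        List<List<Integer>> epoch = batches(60_000, 32, 42L);
        System.out.println(epoch.size());        // 1875
        System.out.println(epoch.get(0).size()); // 32
    }
}
```

With 60,000 images and a batchSize of 32, one pass over the data takes 1,875 batches.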

int batchSize = 32;
Mnist mnist = Mnist.builder().setSampling(batchSize, true).build();
mnist.prepare(new ProgressBar());
Downloading: 100% |████████████████████████████████████████|

Step 2: Create your Model

A Model contains a neural network Block along with additional artifacts used for the training process. It possesses additional information about the inputs, outputs, shapes, and data types you will use. Generally, you will use Model once you have fully completed your Block.

In this part of the tutorial, we will use the built-in Multilayer Perceptron Block from the Model Zoo. To learn more, see the previous tutorial: Create Your First Network.

Because images in the MNIST dataset are 28x28 grayscale images, we create an MLP block with an input size of 28 * 28 = 784. The output size is 10 because each image belongs to one of 10 possible classes (the digits 0 to 9). For the hidden layers, we chose new int[] {128, 64} by experimenting with different values.
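To get a feel for the size of this network, the sketch below counts its trainable parameters, assuming every layer is a fully connected layer with one bias per output unit. That is the usual MLP structure, but we have not verified the exact parameter layout of DJL's Mlp block, so treat the result as a back-of-the-envelope estimate. The class name is hypothetical.

```java
/** Counts the trainable parameters of a fully connected MLP (weights plus biases). */
public class MlpParameterCount {

    static long countParameters(int input, int output, int[] hidden) {
        long total = 0;
        int previous = input;
        for (int units : hidden) {
            // A dense layer has a (previous x units) weight matrix plus one bias per unit
            total += (long) previous * units + units;
            previous = units;
        }
        // Final dense layer maps the last hidden layer to the output classes
        total += (long) previous * output + output;
        return total;
    }

    public static void main(String[] args) {
        long count = countParameters(28 * 28, 10, new int[] {128, 64});
        System.out.println(count); // 109386
    }
}
```

The first layer (784 * 128 + 128 = 100,480 parameters) dominates, which is typical when the input is much wider than the hidden layers.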

Model model = Model.newInstance("mlp");
model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));

Step 3: Create a Trainer

Now, you can create a Trainer to train your model. The trainer is the main class that orchestrates the training process. Trainers are usually opened in a try-with-resources block and closed after training is over.
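DJL's Trainer implements AutoCloseable, so in a standalone program you would typically write `try (Trainer trainer = model.newTrainer(config)) { ... }`. The plain-Java sketch below shows the guarantee this gives, using a hypothetical FakeTrainer class in place of the real one: close() runs automatically when the block exits, even if the body throws.

```java
/** Sketch: try-with-resources closes an AutoCloseable resource when the block exits. */
public class TryWithResourcesDemo {
    static int closedCount = 0;

    /** Hypothetical stand-in for a resource such as a trainer that holds native memory. */
    static class FakeTrainer implements AutoCloseable {
        void train() {
            System.out.println("training...");
        }

        @Override
        public void close() {
            closedCount++;
            System.out.println("closed");
        }
    }

    public static void main(String[] args) {
        try (FakeTrainer trainer = new FakeTrainer()) {
            trainer.train();
        } // close() is called here automatically
        System.out.println(closedCount); // 1
    }
}
```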

The trainer takes an existing model and attempts to optimize the parameters inside the model's Block to best match the dataset. Most optimization is based upon Stochastic Gradient Descent (SGD).
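The core of SGD is a simple update rule: each parameter moves a small step against its gradient, w ← w − learningRate · gradient. The "stochastic" part is that the gradient is computed on a randomly sampled batch rather than on the whole dataset. The plain-Java sketch below performs one such update on a parameter vector; the class name and the numbers are illustrative assumptions, not DJL's optimizer code.

```java
import java.util.Arrays;

/** Plain-Java sketch of one SGD update: w <- w - learningRate * gradient. */
public class SgdStep {

    static void step(float[] weights, float[] gradients, float learningRate) {
        for (int i = 0; i < weights.length; i++) {
            // Move each parameter a small step against its gradient to decrease the loss
            weights[i] -= learningRate * gradients[i];
        }
    }

    public static void main(String[] args) {
        float[] w = {1.0f, -2.0f};
        float[] g = {0.5f, -1.0f};
        step(w, g, 0.1f);
        // Approximately [0.95, -1.9], up to float rounding
        System.out.println(Arrays.toString(w));
    }
}
```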

Step 3.1: Set up your training configuration

Before you create your trainer, you will need a training configuration that describes how to train your model.

The following are a few common items you may need to configure for your training:

  • REQUIRED Loss function: A loss function measures how well the model matches the dataset. Because a lower value is better, it is called the "loss" function. The loss is the only required argument to the training configuration.
  • Evaluator function: An evaluator function also measures how well the model matches the dataset. Unlike the loss, evaluators are only there for people to look at and are not used to optimize the model. Since many losses are not intuitive, adding other evaluators such as Accuracy can help you understand how your model is doing. If you know of any useful evaluators, we recommend adding them.
  • Device: The device is the hardware used to train your model, typically either a CPU or a GPU. DJL automatically detects whether a GPU is available; if one is, it trains on a single GPU by default. To train with multiple GPUs, set the devices explicitly: config.optDevices(Device.getDevices(maxNumberOfGPUs)).
  • Initializer: An Initializer sets the initial values of the model's parameters before training. The default initializer is usually fine.
  • Optimizer: The optimizer updates the model parameters to minimize the loss function. There are a variety of optimizers, most of which improve upon basic SGD. When just starting, you can use the default optimizer; later on, customizing the optimizer can result in faster training.
  • Training Listeners: A training listener adds functionality to the training process through a listener interface. This can include showing training progress, stopping early if training becomes undefined (for example, if the loss becomes NaN), or recording performance metrics. DJL offers several ready-made sets of default listeners.
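As a rough illustration of the loss and evaluator described above, the plain-Java sketch below computes softmax cross-entropy and accuracy for one batch. It assumes the network outputs raw scores (logits) and that labels are class indices; DJL's Loss.softmaxCrossEntropyLoss() and Accuracy operate on NDArrays and may differ in details, so this is a sketch of the idea rather than their implementation. The class name is hypothetical.

```java
/** Plain-Java sketch of softmax cross-entropy loss and accuracy for one batch. */
public class LossAndAccuracy {

    /** Softmax over one row of logits, shifted by the max for numerical stability. */
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double sum = 0;
        double[] probs = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            probs[i] = Math.exp(logits[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sum;
        }
        return probs;
    }

    /** Loss: mean of -log(probability assigned to the true class) over the batch. */
    static double softmaxCrossEntropy(double[][] logits, int[] labels) {
        double total = 0;
        for (int r = 0; r < logits.length; r++) {
            total += -Math.log(softmax(logits[r])[labels[r]]);
        }
        return total / logits.length;
    }

    /** Evaluator: fraction of rows whose highest score matches the label. */
    static double accuracy(double[][] logits, int[] labels) {
        int correct = 0;
        for (int r = 0; r < logits.length; r++) {
            int best = 0;
            for (int c = 1; c < logits[r].length; c++) {
                if (logits[r][c] > logits[r][best]) {
                    best = c;
                }
            }
            if (best == labels[r]) {
                correct++;
            }
        }
        return (double) correct / logits.length;
    }

    public static void main(String[] args) {
        double[][] logits = {{2.0, 0.5, 0.1}, {0.2, 0.3, 1.5}};
        int[] labels = {0, 1};
        System.out.println(softmaxCrossEntropy(logits, labels));
        System.out.println(accuracy(logits, labels)); // 0.5
    }
}
```

The loss is smooth and differentiable, which is why the optimizer uses it; accuracy only changes when a prediction flips, which is why it is better suited to human inspection than to optimization.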

DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
    // softmaxCrossEntropyLoss is a standard loss for classification problems
    .addEvaluator(new Accuracy()) // Use accuracy so we humans can understand how accurate the model is
    .addTrainingListeners(TrainingListener.Defaults.logging()); // Log training progress to the console

// Now that we have our training configuration, we should create a new trainer for our model
Trainer trainer = model.newTrainer(config);
[IJava-executor-0] INFO - Training on: cpu().
[IJava-executor-0] INFO - Load MXNet Engine Version 1.7.0 in 0.054 ms.

Step 5: Initialize Training

Before training your model, you have to initialize all of the parameters with starting values. You can use the trainer for this initialization by passing in the input shape.

  • The first axis of the input shape is the batch size. This won't impact the parameter initialization, so you can use 1 here.
  • The second axis of the input shape is the input size of the MLP: the number of pixels in the input image (28 * 28 = 784).
trainer.initialize(new Shape(1, 28 * 28));
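To see why the batch axis does not matter for initialization, consider Xavier (Glorot) uniform initialization, one common scheme that draws each weight from U(−limit, limit) with limit = sqrt(6 / (fanIn + fanOut)). Only the feature size (784) and the layer width enter the computation. We do not claim this is DJL's default initializer; the plain-Java sketch and its class name are illustrative assumptions.

```java
import java.util.Random;

/** Plain-Java sketch: Xavier (Glorot) uniform initialization of a dense layer's weights. */
public class XavierInit {

    /** Returns a fanIn x fanOut matrix with values drawn from U(-limit, limit). */
    static float[][] xavierUniform(int fanIn, int fanOut, long seed) {
        double limit = Math.sqrt(6.0 / (fanIn + fanOut));
        Random random = new Random(seed);
        float[][] weights = new float[fanIn][fanOut];
        for (int i = 0; i < fanIn; i++) {
            for (int j = 0; j < fanOut; j++) {
                // nextDouble() is in [0, 1), so the value is in [-limit, limit)
                weights[i][j] = (float) ((random.nextDouble() * 2 - 1) * limit);
            }
        }
        return weights;
    }

    public static void main(String[] args) {
        // For input shape (batch, 784), only the 784 feature axis shapes the first layer
        float[][] firstLayer = xavierUniform(28 * 28, 128, 0L);
        System.out.println(firstLayer.length + " x " + firstLayer[0].length); // 784 x 128
    }
}
```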

Step 6: Train your model

Now, we can train the model.

// A deep learning model is typically trained in epochs, where each epoch trains the model on each item in the dataset once.
int epoch = 2;

for (int i = 0; i < epoch; ++i) {
    // We iterate through the dataset once during each epoch
    for (Batch batch : trainer.iterateDataset(mnist)) {

        // During trainBatch, we update the loss and evaluators with the results for the training batch.
        EasyTrain.trainBatch(trainer, batch);

        // Now, we update the model parameters based on the results of the latest trainBatch
        trainer.step();

        // We must make sure to close the batch to ensure all the memory associated with the batch is cleared quickly.
        // If the memory isn't closed after each batch, you will very quickly run out of memory on your GPU
        batch.close();
    }

    // Call the end epoch event for the training listeners now that we are done
    trainer.notifyListeners(listener -> listener.onEpoch(trainer));
}
Training:    100% |████████████████████████████████████████| 

[IJava-executor-0] INFO - Epoch 1 finished.

Training:    100% |████████████████████████████████████████| 

[IJava-executor-0] INFO - Epoch 2 finished.
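The loop above has the usual structure: for each epoch, shuffle the data, and for each batch compute the loss and gradients (the role of EasyTrain.trainBatch) and then update the parameters (the role of trainer.step()). The plain-Java sketch below mirrors that structure on a hypothetical toy problem, fitting y = w * x + b with mini-batch SGD on synthetic, noise-free data from y = 2x + 1. The model, data, and class name are assumptions for illustration; this is not how DJL works internally.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Plain-Java sketch of an epoch/batch training loop using mini-batch SGD on a toy linear model. */
public class ToyTrainingLoop {

    static double[] train(int epochs, int batchSize, double learningRate, long seed) {
        // Synthetic dataset drawn from y = 2x + 1 with x in [-1, 1)
        int n = 256;
        double[] xs = new double[n];
        double[] ys = new double[n];
        Random random = new Random(seed);
        for (int i = 0; i < n; i++) {
            xs[i] = random.nextDouble() * 2 - 1;
            ys[i] = 2 * xs[i] + 1;
        }

        double w = 0;
        double b = 0;
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            order.add(i);
        }

        for (int epoch = 0; epoch < epochs; epoch++) {
            Collections.shuffle(order, random); // random sampling of records each epoch
            for (int start = 0; start < n; start += batchSize) {
                int end = Math.min(start + batchSize, n);
                // Like trainBatch: forward pass and gradients of the mean squared error on this batch
                double gradW = 0;
                double gradB = 0;
                for (int k = start; k < end; k++) {
                    int i = order.get(k);
                    double error = (w * xs[i] + b) - ys[i];
                    gradW += 2 * error * xs[i];
                    gradB += 2 * error;
                }
                int size = end - start;
                // Like step: update the parameters from the averaged batch gradients
                w -= learningRate * gradW / size;
                b -= learningRate * gradB / size;
            }
        }
        return new double[] {w, b};
    }

    public static void main(String[] args) {
        double[] params = train(50, 32, 0.1, 7L);
        System.out.printf("w = %.3f, b = %.3f%n", params[0], params[1]);
    }
}
```

After enough epochs, w and b approach 2 and 1. Separating gradient computation from the parameter update, as DJL does with trainBatch and step, makes it easy to accumulate gradients or change the optimizer without touching the rest of the loop.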

Step 7: Save your model

Once your model is trained, you should save it so that it can be reloaded later. You can also add metadata to it, such as training accuracy or the number of epochs trained, that can be used when loading or examining the model.

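Path modelDir = Paths.get("build/mlp");
Files.createDirectories(modelDir);

// Store the number of epochs trained as model metadata
model.setProperty("Epoch", String.valueOf(epoch));

// Save the model parameters to build/mlp using the model name "mlp"
model.save(modelDir, "mlp");

// Display the model summary
model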

Model (
    Name: mlp
    Model location: /home/runner/work/djl/djl/djl_tmp/jupyter/tutorial/build/mlp
    Data Type: float32
    Epoch: 2
)

Now, you've successfully trained a model that can recognize handwritten digits. You'll learn how to apply this model in the next chapter: Run image classification with your model.

You can find the complete source code for this tutorial in the examples project.