Train your first model

This is the second tutorial in our beginner series, which takes you through creating, training, and running inference on a neural network. In this tutorial, you will learn how to train an image classification model that recognizes handwritten digits.

Preparation

This tutorial requires the installation of the Java Jupyter Kernel. To install the kernel, see the Jupyter README.

// Add the snapshot repository to get the DJL snapshot artifacts
// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/

// Add the maven dependencies
%maven ai.djl:api:0.10.0
%maven ai.djl:basicdataset:0.10.0
%maven ai.djl:model-zoo:0.10.0
%maven ai.djl.mxnet:mxnet-engine:0.10.0
%maven org.slf4j:slf4j-api:1.7.26
%maven org.slf4j:slf4j-simple:1.7.26
%maven net.java.dev.jna:jna:5.3.0

// See https://github.com/awslabs/djl/blob/master/mxnet/mxnet-engine/README.md
// for more MXNet library selection options
%maven ai.djl.mxnet:mxnet-native-auto:1.7.0-backport
import java.nio.file.*;

import ai.djl.*;
import ai.djl.basicdataset.cv.classification.Mnist;
import ai.djl.ndarray.types.*;
import ai.djl.training.*;
import ai.djl.training.dataset.*;
import ai.djl.training.initializer.*;
import ai.djl.training.loss.*;
import ai.djl.training.listener.*;
import ai.djl.training.evaluator.*;
import ai.djl.training.optimizer.*;
import ai.djl.training.util.*;
import ai.djl.basicmodelzoo.cv.classification.*;
import ai.djl.basicmodelzoo.basic.*;

Step 1: Prepare MNIST dataset for training

To train, you must create a Dataset class to contain your training data. A dataset is a collection of sample input/output pairs for the function represented by your neural network. Each input/output pair is represented by a Record. A record can have multiple input or output arrays. For example, in an image question-answering dataset, the input is both an image and a question about the image, while the output is the answer to the question.

Because deep learning is highly parallelizable, training is usually done not with a single record at a time but with a Batch of records. This can lead to significant performance gains, especially when working with images.

Sampler

Next, we must decide the parameters for loading data from the dataset. The only parameter we need for MNIST is the choice of Sampler. The sampler decides which elements, and how many of them, go into each batch when iterating through the dataset. We will have it randomly shuffle the elements and use a batchSize of 32. A common heuristic is to choose the largest power of 2 that fits within memory as the batchSize.

int batchSize = 32;
Mnist mnist = Mnist.builder().setSampling(batchSize, true).build();
mnist.prepare(new ProgressBar());
Downloading: 100% |████████████████████████████████████████|
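
To make the role of the sampler more concrete, here is a minimal plain-Java sketch of random batch sampling. It is not DJL's implementation: the class name BatchSamplingSketch, the method randomBatches, and the seed parameter are hypothetical names chosen for illustration, and keeping a smaller final batch is a choice made in this sketch.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Illustrative only: a hypothetical helper, not DJL's sampler. */
final class BatchSamplingSketch {

    /** Shuffles the indices 0..datasetSize-1 and splits them into batches of at most batchSize. */
    static List<List<Integer>> randomBatches(int datasetSize, int batchSize, long seed) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < datasetSize; i++) {
            indices.add(i);
        }
        // A fixed seed keeps the shuffle reproducible for this example
        Collections.shuffle(indices, new Random(seed));

        List<List<Integer>> batches = new ArrayList<>();
        for (int start = 0; start < datasetSize; start += batchSize) {
            int end = Math.min(start + batchSize, datasetSize);
            // The final batch may be smaller than batchSize (an assumption of this sketch)
            batches.add(new ArrayList<>(indices.subList(start, end)));
        }
        return batches;
    }

    public static void main(String[] args) {
        List<List<Integer>> batches = randomBatches(100, 32, 42L);
        System.out.println("Number of batches: " + batches.size());
        System.out.println("Size of last batch: " + batches.get(batches.size() - 1).size());
    }
}
```

Every index appears in exactly one batch, so one pass over the batches visits each record once.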

Step 2: Create your Model

Next we will build a model. A Model contains a neural network Block along with additional artifacts used for the training process. It possesses additional information about the inputs, outputs, shapes, and data types you will use. Generally, you will use the Model once you have fully completed your Block.

In this part of the tutorial, we will use the built-in Multilayer Perceptron Block from the Model Zoo. To learn how to build it from scratch, see the previous tutorial: Create Your First Network.

Because images in the MNIST dataset are 28x28 grayscale images, we will create an MLP block with 28 x 28 input. The output will be 10 because there are 10 possible classes (0 to 9) each image could be. For the hidden layers, we have chosen new int[] {128, 64} by experimenting with different values.

Model model = Model.newInstance("mlp");
model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));
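
As a rough sanity check on the size of this network, the sketch below counts trainable parameters in plain Java. It assumes the MLP is a stack of fully connected layers, each with a weight matrix and a bias vector; the class MlpParameterCount and its method are hypothetical names for illustration, not part of DJL.

```java
/** Illustrative only: counts parameters under the assumption of dense layers with biases. */
final class MlpParameterCount {

    /** Counts weights and biases for the layer chain input -> hidden[0] -> ... -> output. */
    static long countParameters(int input, int output, int[] hidden) {
        long total = 0;
        int previous = input;
        for (int width : hidden) {
            // previous * width weights plus one bias per unit
            total += (long) previous * width + width;
            previous = width;
        }
        total += (long) previous * output + output;
        return total;
    }

    public static void main(String[] args) {
        long count = countParameters(28 * 28, 10, new int[] {128, 64});
        System.out.println("Trainable parameters: " + count); // prints 109386
    }
}
```

Under that assumption, the first layer dominates the count (784 * 128 + 128 = 100480 parameters) because the input is much wider than the hidden layers.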

Step 3: Create a Trainer

Now, you can create a Trainer to train your model. The trainer is the main class that orchestrates the training process. A trainer is usually opened in a try-with-resources block and closed once training is over.

The trainer takes an existing model and attempts to optimize the parameters inside the model's Block to best match the dataset. Most optimization is based upon Stochastic Gradient Descent (SGD).
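
The core update behind gradient descent can be sketched in a few lines of plain Java. This is not DJL's optimizer: SgdStepSketch and its step method are hypothetical names, and the example uses an exact gradient of a toy function, whereas in stochastic gradient descent the gradient is estimated from a randomly chosen batch.

```java
/** Illustrative only: the basic gradient descent update rule. */
final class SgdStepSketch {

    /** Returns params - learningRate * grads, computed element by element. */
    static double[] step(double[] params, double[] grads, double learningRate) {
        double[] updated = new double[params.length];
        for (int i = 0; i < params.length; i++) {
            updated[i] = params[i] - learningRate * grads[i];
        }
        return updated;
    }

    public static void main(String[] args) {
        // Minimize f(w) = (w - 3)^2, whose gradient is 2 * (w - 3)
        double[] w = {0.0};
        for (int i = 0; i < 50; i++) {
            double[] grad = {2.0 * (w[0] - 3.0)};
            w = step(w, grad, 0.1);
        }
        System.out.println("w after 50 steps: " + w[0]); // approaches 3
    }
}
```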

Step 3.1: Setup your training configurations

Before you create your trainer, you will need a training configuration that describes how to train your model.

The following are a few common items you may need to configure your training:

  • REQUIRED Loss function: A loss function measures how well our model matches the dataset. Because a lower value is better, it is called the "loss" function. The Loss is the only required argument to the training configuration.
  • Evaluator function: An evaluator function also measures how well our model matches the dataset. Unlike the loss, evaluators exist only for people to look at and are not used to optimize the model. Since many losses are not very intuitive, adding evaluators such as Accuracy can help you understand how your model is doing. If you know of any useful evaluators, we recommend adding them.
  • Training Listeners: A training listener adds functionality to the training process through a listener interface. This can include showing training progress, stopping early if training becomes undefined, or recording performance metrics. We offer several convenient sets of default listeners.

You can also configure other options such as the Device, Initializer, and Optimizer. See more details.

DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
    //softmaxCrossEntropyLoss is a standard loss for classification problems
    .addEvaluator(new Accuracy()) // Use accuracy so we humans can understand how accurate the model is
    .addTrainingListeners(TrainingListener.Defaults.logging());

// Now that we have our training configuration, we should create a new trainer for our model
Trainer trainer = model.newTrainer(config);
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Training on: cpu().
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Load MXNet Engine Version 1.7.0 in 0.037 ms.
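
To show the difference between the loss and an evaluator, the following plain-Java sketch computes softmax cross-entropy and accuracy by hand. It is an illustration under simplifying assumptions, not DJL's implementation: LossAndAccuracySketch and its methods are hypothetical names, and the inputs are raw class scores (logits) stored in ordinary arrays.

```java
/** Illustrative only: hand-written softmax cross-entropy and accuracy. */
final class LossAndAccuracySketch {

    /** Softmax cross-entropy for one example: -log(softmax(logits)[label]). */
    static double softmaxCrossEntropy(double[] logits, int label) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        // Subtracting the maximum before exponentiating avoids overflow
        double sumExp = 0.0;
        for (double v : logits) {
            sumExp += Math.exp(v - max);
        }
        return max + Math.log(sumExp) - logits[label];
    }

    /** Index of the largest value, used as the predicted class. */
    static int argMax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Fraction of examples whose predicted class equals the label. */
    static double accuracy(double[][] logits, int[] labels) {
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (argMax(logits[i]) == labels[i]) {
                correct++;
            }
        }
        return (double) correct / labels.length;
    }

    public static void main(String[] args) {
        double[][] logits = {{2.0, 1.0}, {0.0, 3.0}, {1.0, 0.0}};
        int[] labels = {0, 1, 1};
        System.out.println("Loss of first example: " + softmaxCrossEntropy(logits[0], labels[0]));
        System.out.println("Accuracy: " + accuracy(logits, labels));
    }
}
```

The loss changes smoothly as the scores change, which is what makes it usable for optimization, while accuracy only changes when a prediction flips from wrong to right or back.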

Step 5: Initialize Training

Before training your model, you have to initialize all of the parameters with starting values. You can use the trainer for this initialization by passing in the input shape.

  • The first axis of the input shape is the batch size. This won't impact the parameter initialization, so you can use 1 here.
  • The second axis of the input shape is the size of the MLP's input: the number of pixels in the input image.

trainer.initialize(new Shape(1, 28 * 28));

Step 6: Train your model

Now, we can train the model.

Training is usually organized into epochs, where each epoch trains the model on each item in the dataset once. This is slightly faster than drawing training items at random.

We will use the EasyTrain class to, as the name promises, make the training easy. If you want to see more details about how the training loop works, see the EasyTrain class or read our Dive into Deep Learning book.

// Deep learning is typically trained in epochs where each epoch trains the model on each item in the dataset once.
int epoch = 2;

EasyTrain.fit(trainer, epoch, mnist, null);
Training:    100% |████████████████████████████████████████| 

[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 1 finished.

Training:    100% |████████████████████████████████████████| 

[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 2 finished.
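
The structure of an epoch-based training loop can be sketched in plain Java by counting parameter updates. This is only an outline of the idea, not the EasyTrain implementation: EpochLoopSketch and its methods are hypothetical names, and the figure of 60,000 records assumes the standard MNIST training split is the one being used.

```java
/** Illustrative only: counts the parameter updates made by an epoch-based loop. */
final class EpochLoopSketch {

    /** Number of batches needed to cover the dataset once, rounding up. */
    static int batchesPerEpoch(int datasetSize, int batchSize) {
        return (datasetSize + batchSize - 1) / batchSize;
    }

    /** Walks the nested epoch/batch loops and counts one update per batch. */
    static int totalUpdates(int epochs, int datasetSize, int batchSize) {
        int updates = 0;
        for (int e = 0; e < epochs; e++) {
            for (int start = 0; start < datasetSize; start += batchSize) {
                // A real loop would run the forward pass, compute the loss,
                // run the backward pass, and apply an optimizer step here.
                updates++;
            }
        }
        return updates;
    }

    public static void main(String[] args) {
        int datasetSize = 60_000; // assumed: standard MNIST training split
        System.out.println("Batches per epoch: " + batchesPerEpoch(datasetSize, 32)); // prints 1875
        System.out.println("Updates in 2 epochs: " + totalUpdates(2, datasetSize, 32)); // prints 3750
    }
}
```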

Step 7: Save your model

Once your model is trained, you should save it so that it can be reloaded later. You can also add metadata to it, such as the training accuracy or the number of epochs trained, which can be used when loading or examining the model.

Path modelDir = Paths.get("build/mlp");
Files.createDirectories(modelDir);

model.setProperty("Epoch", String.valueOf(epoch));

model.save(modelDir, "mlp");

model
Model (
    Name: mlp
    Model location: /home/runner/work/djl/djl/jupyter/tutorial/build/mlp
    Data Type: float32
    Epoch: 2
)

Summary

Now, you've successfully trained a model that can recognize handwritten digits. You'll learn how to apply this model in the next chapter: Run image classification with your model.

You can find the complete source code for this tutorial in the examples project.