Custom CSV Dataset Example
If the provided datasets don't meet your requirements, you can extend one of our dataset classes to create your own custom dataset.
Let's take CSVDataset, which loads a CSV file, as an example.
Step 1: Prerequisites
For this example, we'll use malicious_url_data.csv.
The CSV file has the following format.
URL | isMalicious |
---|---|
sample.url.good.com | 0 |
sample.url.bad.com | 1 |
We'll also use the third-party Apache Commons CSV library to read the CSV file. The code in this example calls CSVFormat.DEFAULT.builder(), which requires version 1.9.0 or later, so include the following dependency:
api group: 'org.apache.commons', name: 'commons-csv', version: '1.9.0'
Step 2: Implementation
In order to extend the dataset, the following dependencies are required:
api "ai.djl:api:0.31.0"
api "ai.djl:basicdataset:0.31.0"
There are four parts we need to implement for CSVDataset.
- Constructor and Builder
First, we need a private field that holds the list of CSVRecord objects read from the CSV file.
The constructor copies that list from the builder into the class field.
For the builder, BaseBuilder already provides everything we need, so we only have to add two methods: self() and build().
In the build() method, we use CSVParser to read every record in the CSV file into the CSVRecord list.
public class CSVDataset extends RandomAccessDataset {

    private final List<CSVRecord> csvRecords;

    private CSVDataset(Builder builder) {
        super(builder);
        csvRecords = builder.csvRecords;
    }
    ...
    public static final class Builder extends BaseBuilder<Builder> {

        List<CSVRecord> csvRecords;

        @Override
        protected Builder self() {
            return this;
        }

        CSVDataset build() throws IOException {
            String csvFilePath = "path/malicious_url_data.csv";
            try (Reader reader = Files.newBufferedReader(Paths.get(csvFilePath));
                    CSVParser csvParser =
                            new CSVParser(
                                    reader,
                                    CSVFormat.DEFAULT
                                            .builder()
                                            .setHeader("url", "isMalicious")
                                            .setSkipHeaderRecord(true)
                                            .setIgnoreHeaderCase(true)
                                            .setTrim(true)
                                            .build())) {
                csvRecords = csvParser.getRecords();
            }
            return new CSVDataset(this);
        }
    }
}
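To see why the builder needs a self() method, it can help to look at the self-typed builder pattern in isolation. The sketch below is a simplified stand-in, not DJL's actual BaseBuilder: the class names SketchBaseBuilder and SketchCsvBuilder and the optCsvFilePath option are hypothetical and exist only for illustration.

```java
// Simplified stand-in for a self-typed base builder; NOT DJL's actual BaseBuilder.
abstract class SketchBaseBuilder<T extends SketchBaseBuilder<T>> {
    int batchSize = 1;
    boolean random;

    // Shared setter returns the concrete builder type via self(),
    // so subclass-specific methods can still be chained after it.
    T setSampling(int batchSize, boolean random) {
        this.batchSize = batchSize;
        this.random = random;
        return self();
    }

    abstract T self();
}

// Concrete builder: only self() plus its own options are needed.
final class SketchCsvBuilder extends SketchBaseBuilder<SketchCsvBuilder> {
    String csvFilePath; // hypothetical option, not part of the original example

    SketchCsvBuilder optCsvFilePath(String csvFilePath) {
        this.csvFilePath = csvFilePath;
        return this;
    }

    @Override
    SketchCsvBuilder self() {
        return this;
    }
}
```

Because setSampling returns T rather than the base type, a call such as new SketchCsvBuilder().setSampling(32, false).optCsvFilePath("data.csv") compiles without a cast.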
- Prepare
As mentioned, this example relies on CSVParser to load the data in build(). When a dataset has to load or generate its data itself, that work belongs in the prepare() method, which stores the result in one of the private fields created earlier. Because prepare() is called every time we call getData(), such datasets typically guard the work with a boolean field called prepared so that the data is loaded only once.
Since this example has nothing left to prepare, we only need to override the method with an empty body.
@Override
public void prepare(Progress progress) {}
There are great examples in our basicdataset folder that show use cases for prepare().
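The load-once guard described above can be sketched without any DJL types. This is a standalone illustration under stated assumptions: the class LazyRecords, its hard-coded rows, and the loadCount counter are hypothetical and only show how a prepared flag keeps repeated calls cheap.

```java
import java.util.ArrayList;
import java.util.List;

// Standalone illustration of the "load once" guard; it does not use the DJL API.
final class LazyRecords {
    private final List<String> records = new ArrayList<>();
    private boolean prepared;
    private int loadCount; // only here to observe how often loading runs

    // Safe to call repeatedly: the loading work runs only on the first call.
    void prepare() {
        if (prepared) {
            return;
        }
        // Stand-in for reading a file or downloading data.
        records.add("sample.url.good.com,0");
        records.add("sample.url.bad.com,1");
        loadCount++;
        prepared = true;
    }

    int size() {
        prepare();
        return records.size();
    }

    int loadCount() {
        return loadCount;
    }
}
```

Calling size() several times triggers prepare() each time, but the records are loaded only once.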
- Getter
The getter returns a Record object that contains the encoded inputs and labels. Here, we use a simple encoding to transform the URL string into an int array and create an NDArray from it. We wrap the arrays in NDList because other tasks may have multiple inputs and labels.
@Override
public Record get(NDManager manager, long index) {
    // get a CSVRecord given an index
    CSVRecord record = csvRecords.get(Math.toIntExact(index));
    NDArray datum = manager.create(encode(record.get("url")));
    NDArray label = manager.create(Float.parseFloat(record.get("isMalicious")));
    return new Record(new NDList(datum), new NDList(label));
}
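The encode helper is not shown in the snippet above. One simple possibility, given here as an assumption that may differ from the helper in CSVDataset.java, is a letter-frequency encoding that counts how often each of the 26 letters appears in the URL. The class name UrlEncoding is hypothetical.

```java
import java.util.Locale;

// Hypothetical helper: a letter-frequency encoding, one possible "simple encoding".
final class UrlEncoding {
    private UrlEncoding() {}

    // Counts occurrences of 'a'..'z' (case-insensitive); all other characters are ignored.
    static int[] encode(String url) {
        int[] counts = new int[26];
        for (char ch : url.toLowerCase(Locale.ROOT).toCharArray()) {
            int index = ch - 'a';
            if (index >= 0 && index < 26) {
                counts[index]++;
            }
        }
        return counts;
    }
}
```

An encoding like this always yields a fixed-length array of 26 ints, so every record produces an input of the same shape regardless of URL length, which keeps batching straightforward.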
- Size
availableSize() returns the number of records that can be read from this dataset.
Here, we can directly use the size of the list.
@Override
public long availableSize() {
    return csvRecords.size();
}
Done! Now, you can use the CSVDataset with the following code snippet:
CSVDataset dataset = new CSVDataset.Builder().setSampling(batchSize, false).build();
for (Batch batch : dataset.getData(model.getNDManager())) {
    // use head to get first NDArray
    batch.getData().head();
    batch.getLabels().head();
    ...
    // don't forget to close the batch in the end
    batch.close();
}
The full example code can be found in CSVDataset.java.