
Classification on Iris dataset with sklearn and DJL

In this notebook, you will run a pre-trained sklearn model on DJL for a general classification task. The model was trained on the Iris flower dataset.

Background

Iris Dataset

The dataset contains 150 records, each with five attributes: sepal length, sepal width, petal length, petal width, and species.

[Figure: Iris setosa, Iris versicolor, Iris virginica]

The images above show the three species of Iris flower.

We will use sepal length, sepal width, petal length, and petal width as the features and species as the label to train the model.
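To make the feature/label split concrete, here is a minimal plain-Java sketch (no DJL involved) of a single Iris record. The four measurements become a float feature vector in the column order of sklearn's `iris.data`, and the species becomes an integer class index in the order of sklearn's `iris.target_names` (0 = setosa, 1 = versicolor, 2 = virginica). The class `IrisRecord` and its methods are hypothetical, for illustration only.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical illustration (not part of DJL or sklearn): one Iris record
// split into a feature vector (x) and an integer class label (y).
public class IrisRecord {

    // Class order used by sklearn's iris.target_names
    static final List<String> SPECIES = List.of("setosa", "versicolor", "virginica");

    final float sepalLength;
    final float sepalWidth;
    final float petalLength;
    final float petalWidth;
    final String species;

    IrisRecord(float sepalLength, float sepalWidth, float petalLength, float petalWidth, String species) {
        this.sepalLength = sepalLength;
        this.sepalWidth = sepalWidth;
        this.petalLength = petalLength;
        this.petalWidth = petalWidth;
        this.species = species;
    }

    // Features, in the same column order as sklearn's iris.data
    float[] features() {
        return new float[] {sepalLength, sepalWidth, petalLength, petalWidth};
    }

    // Label as the integer index sklearn uses in iris.target
    int label() {
        int index = SPECIES.indexOf(species);
        if (index < 0) {
            throw new IllegalArgumentException("Unknown species: " + species);
        }
        return index;
    }

    public static void main(String[] args) {
        // First record of the Iris dataset: 5.1, 3.5, 1.4, 0.2, setosa
        IrisRecord record = new IrisRecord(5.1f, 3.5f, 1.4f, 0.2f, "setosa");
        System.out.println(Arrays.toString(record.features()) + " -> label " + record.label());
    }
}
```

The same feature order is what the translator later packs into the model input, so keeping it consistent between training and inference matters.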

Sklearn Model

You can find more information here. You can use sklearn's built-in Iris dataset to load the data. We then define a RandomForestClassifier to train the model. After that, we convert the model to ONNX format so that DJL can run inference on it. The following code is a sample classification setup using sklearn:

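# Train a model.
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier

# Features: 4 measurements per flower; labels: 0, 1, 2 for the three species
iris = load_iris()
X, y = iris.data, iris.target
# Hold out part of the data for evaluation (25% goes to the test split by default)
X_train, X_test, y_train, y_test = train_test_split(X, y)
# Fit a random forest classifier on the training split
clr = RandomForestClassifier()
clr.fit(X_train, y_train)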

Preparation

This tutorial requires the Java Jupyter kernel. To install the Java kernel, see the README.

These are the dependencies we will use. The ONNX Runtime engine supports only a limited set of NDArray operations; to extend that capability, DJL can run the PyTorch engine alongside ONNX Runtime if you also add it as a dependency. Please find more information here.

// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/

%maven ai.djl:api:0.27.0
%maven ai.djl.onnxruntime:onnxruntime-engine:0.27.0
%maven org.slf4j:slf4j-simple:1.7.36
import ai.djl.inference.*;
import ai.djl.modality.*;
import ai.djl.ndarray.*;
import ai.djl.ndarray.types.*;
import ai.djl.repository.zoo.*;
import ai.djl.translate.*;
import java.util.*;

Step 1: Create a Translator

Inference in machine learning is the process of predicting the output for a given input using a trained model. DJL abstracts away the whole process for ease of use: it loads the model, runs inference on the input, and returns the output. DJL also lets you work with your own user-defined input and output types. The workflow looks like the following:

[Figure: DJL inference workflow, https://github.com/deepjavalibrary/djl/blob/master/examples/docs/img/workFlow.png?raw=true]

The Translator interface encompasses the two white blocks: Pre-processing and Post-processing. The pre-processing component converts the user-defined input objects into an NDList, so that the Predictor in DJL can understand the input and make its prediction. Similarly, the post-processing block receives an NDList as the output from the Predictor. The post-processing block allows you to convert the output from the Predictor to the desired output format.
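To illustrate the division of labor described above, here is a simplified, dependency-free sketch of the pre-processing, prediction, and post-processing chain. It is hypothetical and is NOT DJL's `Translator` API: DJL passes `NDList` objects and a `TranslatorContext`, whereas this sketch uses plain `float[]` arrays and a toy summing "model" so it runs on its own. The names `SimpleTranslator`, `csvSumTranslator`, and `sumModel` are invented for illustration.

```java
import java.util.function.Function;

// Simplified, hypothetical model of the pre-processing -> prediction -> post-processing flow.
// This is NOT DJL's Translator API: DJL passes NDList objects and a TranslatorContext,
// whereas this sketch uses plain float arrays so it runs without any dependency.
public class TranslatorFlowSketch {

    // Hypothetical counterpart of a translator: converts in both directions
    interface SimpleTranslator<I, O> {
        float[] processInput(I input);     // user-defined input -> raw model input
        O processOutput(float[] output);   // raw model output -> user-defined output
    }

    // Plays the role of a predictor: pre-process, run the model, post-process
    static <I, O> O predict(SimpleTranslator<I, O> translator, Function<float[], float[]> model, I input) {
        float[] rawInput = translator.processInput(input);
        float[] rawOutput = model.apply(rawInput);
        return translator.processOutput(rawOutput);
    }

    // Example translator: parses "1,2,3" into floats and formats a single-number result
    static SimpleTranslator<String, String> csvSumTranslator() {
        return new SimpleTranslator<>() {
            @Override
            public float[] processInput(String input) {
                String[] parts = input.split(",");
                float[] values = new float[parts.length];
                for (int i = 0; i < parts.length; i++) {
                    values[i] = Float.parseFloat(parts[i].trim());
                }
                return values;
            }

            @Override
            public String processOutput(float[] output) {
                return "sum=" + output[0];
            }
        };
    }

    // Toy "model" standing in for a real network: sums its inputs
    static float[] sumModel(float[] x) {
        float sum = 0f;
        for (float v : x) {
            sum += v;
        }
        return new float[] {sum};
    }

    public static void main(String[] args) {
        // prints sum=6.0
        System.out.println(predict(csvSumTranslator(), TranslatorFlowSketch::sumModel, "1,2,3"));
    }
}
```

The point of the split is that the model only ever sees raw numeric arrays, while all knowledge of the user's types lives in the translator.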

In our use case, we use a class named IrisFlower as our input type and Classifications as our output type.

public static class IrisFlower {

    public float sepalLength;
    public float sepalWidth;
    public float petalLength;
    public float petalWidth;

    public IrisFlower(float sepalLength, float sepalWidth, float petalLength, float petalWidth) {
        this.sepalLength = sepalLength;
        this.sepalWidth = sepalWidth;
        this.petalLength = petalLength;
        this.petalWidth = petalWidth;
    }
}

Let's create a translator

public static class MyTranslator implements NoBatchifyTranslator<IrisFlower, Classifications> {

    // Species names, in the same order as the model's output classes
    private final List<String> synset;

    public MyTranslator() {
        synset = Arrays.asList("setosa", "versicolor", "virginica");
    }

    // Pre-processing: pack the four measurements into a 1x4 float NDArray
    @Override
    public NDList processInput(TranslatorContext ctx, IrisFlower input) {
        float[] data = {input.sepalLength, input.sepalWidth, input.petalLength, input.petalWidth};
        NDArray array = ctx.getNDManager().create(data, new Shape(1, 4));
        return new NDList(array);
    }

    // Post-processing: index 1 of the model output holds the per-class probabilities
    // (index 0 is the predicted label); wrap them together with the species names
    @Override
    public Classifications processOutput(TranslatorContext ctx, NDList list) {
        float[] data = list.get(1).toFloatArray();
        List<Double> probabilities = new ArrayList<>(data.length);
        for (float f : data) {
            probabilities.add((double) f);
        }
        return new Classifications(synset, probabilities);
    }
}
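The post-processing step can be exercised without DJL. The sketch below repeats the float-to-`List<Double>` conversion from `processOutput` and adds a hypothetical `bestClass` helper that picks the most likely species by argmax, similar in spirit to asking a `Classifications` object for its best result. The probability values in `main` are made up for illustration, not real model output.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Plain-Java sketch of what MyTranslator.processOutput does with the probability output,
// plus a hypothetical bestClass helper that picks the most likely species (argmax).
public class PostProcessSketch {

    // Must list the species in the same order as the model's probability columns
    static final List<String> SYNSET = Arrays.asList("setosa", "versicolor", "virginica");

    // Widen float probabilities to the List<Double> form used by Classifications
    static List<Double> toProbabilities(float[] data) {
        List<Double> probabilities = new ArrayList<>(data.length);
        for (float f : data) {
            probabilities.add((double) f);
        }
        return probabilities;
    }

    // Return the species whose probability is highest
    static String bestClass(List<Double> probabilities) {
        if (probabilities.size() != SYNSET.size()) {
            throw new IllegalArgumentException("Expected " + SYNSET.size() + " probabilities");
        }
        int best = 0;
        for (int i = 1; i < probabilities.size(); i++) {
            if (probabilities.get(i) > probabilities.get(best)) {
                best = i;
            }
        }
        return SYNSET.get(best);
    }

    public static void main(String[] args) {
        float[] raw = {0.05f, 0.25f, 0.70f}; // made-up values for illustration
        System.out.println(bestClass(toProbabilities(raw))); // prints virginica
    }
}
```

Because the synset is matched to probabilities purely by position, a mismatch between its order and the model's class order would silently mislabel every prediction.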

Step 2: Prepare your model

We will load a pretrained sklearn model into DJL. DJL's ModelZoo lets you load models from a variety of locations, such as a remote URL, local files, or the DJL pretrained model zoo. A Criteria object tells the ModelZoo where to find the model and which translator to attach. In this example, we download a compressed ONNX model from S3.

String modelUrl = "https://mlrepo.djl.ai/model/tabular/softmax_regression/ai/djl/onnxruntime/iris_flowers/0.0.1/iris_flowers.zip";
Criteria<IrisFlower, Classifications> criteria = Criteria.builder()
        .setTypes(IrisFlower.class, Classifications.class)
        .optModelUrls(modelUrl)
        .optTranslator(new MyTranslator())
        .optEngine("OnnxRuntime") // explicitly select the ONNX Runtime engine
        .build();
ZooModel<IrisFlower, Classifications> model = criteria.loadModel();

Step 3: Run inference

To run inference, create a Predictor from the model and pass it an input:

Predictor<IrisFlower, Classifications> predictor = model.newPredictor();
IrisFlower info = new IrisFlower(1.0f, 2.0f, 3.0f, 4.0f);
predictor.predict(info);
[
    {"class": "virginica", "probability": 0.73000}
    {"class": "versicolor", "probability": 0.26999}
    {"class": "setosa", "probability": 0.0e+00}
]
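The printed result lists the classes from most to least likely, and the probabilities sum to approximately 1. The dependency-free sketch below reproduces that ranking and adds a simple sanity check on the distribution. `RankSketch`, `ClassProb`, and `sumsToOne` are hypothetical names, not part of DJL's `Classifications` class; the values in `main` are copied from the output above.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Hypothetical helper (not DJL's Classifications class): rank classes by probability,
// most likely first, and check that the probabilities form a distribution.
public class RankSketch {

    record ClassProb(String className, double probability) {}

    // Pair each class with its probability and sort in descending order of probability
    static List<ClassProb> rank(List<String> classes, List<Double> probs) {
        if (classes.size() != probs.size()) {
            throw new IllegalArgumentException("classes and probabilities differ in size");
        }
        List<ClassProb> ranked = new ArrayList<>();
        for (int i = 0; i < classes.size(); i++) {
            ranked.add(new ClassProb(classes.get(i), probs.get(i)));
        }
        ranked.sort(Comparator.comparingDouble(ClassProb::probability).reversed());
        return ranked;
    }

    // True if the probabilities add up to 1 within the given tolerance
    static boolean sumsToOne(List<Double> probs, double tolerance) {
        double sum = 0.0;
        for (double p : probs) {
            sum += p;
        }
        return Math.abs(sum - 1.0) <= tolerance;
    }

    public static void main(String[] args) {
        List<String> classes = List.of("setosa", "versicolor", "virginica");
        List<Double> probs = List.of(0.0, 0.26999, 0.73000); // values from the output above
        for (ClassProb cp : rank(classes, probs)) {
            System.out.println(cp.className() + ": " + cp.probability());
        }
        System.out.println("sums to 1: " + sumsToOne(probs, 1e-4));
    }
}
```

A small tolerance is needed because the model reports float probabilities, so their sum is rarely exactly 1.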