ONNX#

ONNX is an open format built to represent machine learning models. ONNX provides high interoperability among various frameworks and enables machine learning practitioners to maximize models’ performance across different hardware.

Because of this interoperability, we recommend that you check out your framework’s integration with ONNX, as it contains specific recommendations and requirements for that framework.

Compatibility#

BentoML currently supports only ONNX Runtime as the ONNX backend. BentoML requires either onnxruntime>=1.9 or onnxruntime-gpu>=1.9 to be installed.
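For example, a CPU-only environment can be set up with a single pip command (swap in onnxruntime-gpu if you plan to run on CUDA devices):

pip install bentoml onnx "onnxruntime>=1.9"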

Converting models from other frameworks to ONNX format#

First, let’s create a SuperResolution model in PyTorch.

train.py#
import torch
import torch.nn as nn
import torch.nn.init as init

class SuperResolutionNet(nn.Module):
    def __init__(self, upscale_factor, inplace=False):
        super(SuperResolutionNet, self).__init__()

        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))
        return x

    def _initialize_weights(self):
        init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv4.weight)

torch_model = SuperResolutionNet(upscale_factor=3)

For this tutorial, we will use pre-trained weights provided by the PyTorch team. Note that the model was only partially trained and is used for demonstration purposes only.

train.py#
import torch.utils.model_zoo as model_zoo

# Load pretrained model weights
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'

# Initialize model with the pretrained weights
map_location = lambda storage, loc: storage
if torch.cuda.is_available():
    map_location = None
torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))

# set the model to inference mode
torch_model.eval()

Exporting a PyTorch model to ONNX works via tracing or scripting (read more in the official PyTorch documentation). In this tutorial we will export the model via tracing:

train.py#
batch_size = 1
# Tracing input to the model
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)

# Export the model
torch.onnx.export(
   torch_model,
   x,
   "super_resolution.onnx",  # where to save the model (can be a file or file-like object)
   export_params=True,  # store the trained parameter weights inside the model file
   opset_version=10,  # the ONNX version to export the model to
   do_constant_folding=True,  # whether to execute constant folding for optimization
   input_names=["input"],  # the model's input names
   output_names=["output"],  # the model's output names
   dynamic_axes={
      "input": {0: "batch_size"},  # variable length axes
      "output": {0: "batch_size"},
   },
)

Notice that in the arguments to torch.onnx.export(), even though we export the model with an input of batch_size=1, the first dimension is still declared dynamic in the dynamic_axes parameter. The exported model will therefore accept inputs of size [batch_size, 1, 224, 224], where batch_size can vary between inferences.

We can now compute the output using ONNX Runtime’s Python APIs:

import numpy as np
import onnxruntime

def to_numpy(tensor):
    # detach from the autograd graph (x was created with requires_grad=True)
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

ort_session = onnxruntime.InferenceSession("super_resolution.onnx")
# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
# ONNX Runtime returns a list of outputs
ort_outs = ort_session.run(None, ort_inputs)
print(ort_outs[0])
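As a quick sanity check, you can compare the ONNX Runtime output against the original PyTorch model on the same input; the tolerances below are illustrative, not mandated by either library:

# run the original PyTorch model on the same tracing input
torch_out = torch_model(x)

# the two results should agree within numerical tolerance
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
print("The exported model matches the PyTorch output within tolerance.")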

To convert a TensorFlow (Keras) model, first install tf2onnx:

pip install tf2onnx

For this tutorial we will download a pretrained ResNet-50 model:

train.py#
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50

model = ResNet50(weights='imagenet')

Notice that we use None in the TensorSpec to mark the first input dimension as a dynamic batch axis, which means this dimension can accept inputs of arbitrary batch size:

train.py#
import tf2onnx

spec = (tf.TensorSpec((None, 224, 224, 3), tf.float32, name="input"),)
onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature=spec, opset=13)
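To sanity-check the converted model before saving it, you can pass the in-memory ModelProto to ONNX Runtime as serialized bytes; the snippet below is a minimal sketch that feeds random data just to exercise the graph:

import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession(onnx_model.SerializeToString())
dummy = np.random.rand(1, 224, 224, 3).astype(np.float32)  # random stand-in for a preprocessed image
print(sess.run(None, {"input": dummy})[0].shape)  # (1, 1000) class scores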

To convert a scikit-learn model, first install skl2onnx:

pip install skl2onnx

For this tutorial we will train a random forest classifier on the Iris dataset:

train.py#
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier

iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y)
clr = RandomForestClassifier()
clr.fit(X_train, y_train)

Then we can use skl2onnx to export a scikit-learn model to ONNX format:

train.py#
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

initial_type = [('float_input', FloatTensorType([None, 4]))]
model_proto = convert_sklearn(clr, initial_types=initial_type)

Notice that we use None in initial_type to mark the first input dimension as a dynamic batch axis, which means this dimension can accept inputs of arbitrary batch size.
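As with the other frameworks, the converted classifier can be run through ONNX Runtime directly as a quick check; note that model_proto must be serialized first and the input cast to float32:

import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession(model_proto.SerializeToString())
pred_onnx = sess.run(None, {"float_input": X_test.astype(np.float32)})[0]
print(pred_onnx[:5])  # predicted class labels for the first five test samples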

Saving ONNX model with BentoML#

To quickly save any given ONNX model to BentoML’s Model Store, load the exported ONNX model back into the Python session (with onnx.load if it was exported to a file), then call BentoML’s save_model():

train.py#
import onnx
import bentoml

onnx_model = onnx.load("super_resolution.onnx")
signatures = {
    "run": {"batchable": True},
}
bentoml.onnx.save_model("onnx_super_resolution", onnx_model, signatures=signatures)

which results in:

Model(tag="onnx_super_resolution:lwqr7ah5ocv3rea3", path="~/bentoml/models/onnx_super_resolution/lwqr7ah5ocv3rea3/")
train.py#
import bentoml

signatures = {
    "run": {"batchable": True},
}
bentoml.onnx.save_model("onnx_resnet50", onnx_model, signatures=signatures)

which results in:

Model(tag="onnx_resnet50:zavavxh6w2v3rea3", path="~/bentoml/models/onnx_resnet50/zavavxh6w2v3rea3/")
train.py#
import bentoml

signatures = {
    "run": {"batchable": True},
}
bentoml.onnx.save_model("onnx_iris", model_proto, signatures=signatures)

which results in:

Model(tag="onnx_iris:sqixlaqf76vv7ea3", path="~/bentoml/models/onnx_iris/sqixlaqf76vv7ea3/")

The default signature for save_model() is set to {"run": {"batchable": False}}.

This means that by default, BentoML’s Adaptive Batching is disabled when saving an ONNX model. If you want to enable adaptive batching, provide a signature similar to the examples above.

Refer to Model signatures and Batching behaviour for more information.

Note

BentoML internally uses onnxruntime.InferenceSession to run inference. When the original model is converted to the ONNX format and loaded by onnxruntime.InferenceSession, the inference method of the original model becomes the run method of the onnxruntime.InferenceSession. The signatures in the code above refer to this run method, so the only allowed method name in signatures is run.

Building a Service for ONNX#

See also

Building a Service for how to create a prediction service with BentoML.

service.py#
import bentoml

import numpy as np
from PIL import Image as PIL_Image
from PIL import ImageOps
from bentoml.io import Image

runner = bentoml.onnx.get("onnx_super_resolution:latest").to_runner()

svc = bentoml.Service("onnx_super_resolution", runners=[runner])

# for output, we set image io descriptor's pilmode to "L" to denote
# the output is a gray scale image
@svc.api(input=Image(), output=Image(pilmode="L"))
async def sr(img) -> np.ndarray:
    img = img.resize((224, 224))
    gray_img = ImageOps.grayscale(img)
    arr = np.array(gray_img) / 255.0  # convert from 0-255 range to 0.0-1.0 range
    arr = np.expand_dims(arr, (0, 1))  # add batch_size, color_channel dims
    sr_arr = await runner.run.async_run(arr)
    sr_arr = np.squeeze(sr_arr)  # remove batch_size, color_channel dims
    sr_arr = np.uint8(sr_arr * 255)
    return sr_arr
service.py#
import bentoml

import numpy as np
from bentoml.io import Image
from bentoml.io import JSON

runner = bentoml.onnx.get("onnx_resnet50:latest").to_runner()

svc = bentoml.Service("onnx_resnet50", runners=[runner])

@svc.api(input=Image(), output=JSON())
async def predict(img):

    from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions

    img = img.resize((224, 224))
    arr = np.array(img)
    arr = np.expand_dims(arr, axis=0)
    arr = preprocess_input(arr)
    preds = await runner.run.async_run(arr)
    return decode_predictions(preds, top=1)[0]
service.py#
import bentoml

from bentoml.io import JSON
from bentoml.io import NumpyNdarray

runner = bentoml.onnx.get("onnx_iris:latest").to_runner()

svc = bentoml.Service("onnx_iris", runners=[runner])

@svc.api(input=NumpyNdarray(), output=JSON())
async def classify(input_array):
    return await runner.run.async_run(input_array)

Note

In the above example, notice that both run and async_run appear in runner.run.async_run(input_data) inside the inference code. The distinction between run and async_run is as follows:

  1. run refers to onnxruntime.InferenceSession’s run method, which is the ONNX Runtime API for running inference.

  2. async_run refers to BentoML’s runner inference API for invoking a model’s signature asynchronously. In the case of ONNX, the signature name happens to coincide with the InferenceSession method.
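With one of the service.py files above saved in your working directory, you can start a local development server (assuming the file is named service.py and the service object is svc):

bentoml serve service:svc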

When constructing a bentofile.yaml, there are two ways to include ONNX as a dependency: via python (if using pip) or via conda:

python:
  packages:
    - onnx
    - onnxruntime
conda:
  channels:
  - conda-forge
  dependencies:
  - onnx
  - onnxruntime
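For instance, a minimal bentofile.yaml for the super resolution service might look like the following; the service and include entries are assumptions based on the example file names used above:

service: "service:svc"
include:
  - "service.py"
python:
  packages:
    - onnx
    - onnxruntime
    - Pillow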

Using Runners#

See also

Runners for more information on what is a Runner and how to use it.

To test an ONNX Runner locally, access the model via get and convert it to a runner object:

import numpy as np
import bentoml

test_input = np.random.randn(2, 1, 224, 224)

runner = bentoml.onnx.get("onnx_super_resolution").to_runner()

runner.init_local()

runner.run.run(test_input)

Note

You don’t need to cast your input ndarray to np.float32 for runner input.

Similar to load_model, you can customize providers and session_options when creating a runner:

providers=["TensorrtExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"]

bento_model = bentoml.onnx.get("onnx_super_resolution")

runner = bento_model.with_options(providers=providers).to_runner()

runner.init_local()
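A session_options override follows the same pattern. The sketch below assumes with_options also accepts a session_options argument mirroring load_model, and uses onnxruntime.SessionOptions to limit intra-op parallelism:

import onnxruntime as ort

session_options = ort.SessionOptions()
session_options.intra_op_num_threads = 1  # restrict the intra-op thread pool

bento_model = bentoml.onnx.get("onnx_super_resolution")
runner = bento_model.with_options(session_options=session_options).to_runner()
runner.init_local()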

Loading an ONNX model with BentoML for local testing#

Use load_model to verify that the saved model can be loaded properly:

ort_session = bentoml.onnx.load_model("onnx_super_resolution")

Note

BentoML will load an ONNX model back as an onnxruntime.InferenceSession object, which is ready to run inference.

test_input = np.random.randn(2, 1, 224, 224)  # can accept an arbitrary batch size
ort_session.run(None, {"input": test_input.astype(np.float32)})

Note

In the above snippet, we need to explicitly convert the input ndarray to float32, since onnxruntime.InferenceSession expects single-precision floats.

When using Runners, however, BentoML casts the input data automatically.

Dynamic Batch Size#

See also

Adaptive Batching: a general introduction to adaptive batching in BentoML.

When Adaptive Batching is enabled, the exported ONNX model is REQUIRED to accept a dynamic batch size.

Therefore, a dynamic batch axis needs to be specified when the model is exported to the ONNX format.

For PyTorch models, you can achieve this by specifying dynamic_axes when using torch.onnx.export:

torch.onnx.export(
   torch_model,
   x,
   "super_resolution.onnx",  # where to save the model (can be a file or file-like object)
   export_params=True,  # store the trained parameter weights inside the model file
   opset_version=10,  # the ONNX version to export the model to
   do_constant_folding=True,  # whether to execute constant folding for optimization
   input_names=["input"],  # the model's input names
   output_names=["output"],  # the model's output names
   dynamic_axes={
      "input": {0: "batch_size"},  # variable length axes
      "output": {0: "batch_size"},
   },
)

For TensorFlow models, you can achieve this by using None to denote a dynamic batch axis in TensorSpec when using either tf2onnx.convert.from_keras or tf2onnx.convert.from_function:

spec = (tf.TensorSpec((None, 224, 224, 3), tf.float32, name="input"),) # batch_axis = 0
model_proto, _ = tf2onnx.convert.from_keras(model, input_signature=spec, opset=13)

For scikit-learn models, you can achieve this by using None in initial_type to denote a dynamic batch axis when using skl2onnx.convert_sklearn:

initial_type = [('float_input', FloatTensorType([None, 4]))]
model_proto = convert_sklearn(clr, initial_types=initial_type)

Default Execution Providers Settings#

When a CUDA-compatible GPU is available, the BentoML runner will use ["CUDAExecutionProvider", "CPUExecutionProvider"] as the default execution providers.

Otherwise, the runner will use ["CPUExecutionProvider"] as the default providers.
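You can check which execution providers your onnxruntime installation actually supports before overriding the defaults:

import onnxruntime

print(onnxruntime.get_available_providers())
# e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider'] when onnxruntime-gpu is installed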

If onnxruntime-gpu is installed, using TensorrtExecutionProvider may further improve inference performance. You can override the default setting using with_options when creating a runner:

providers = ["TensorrtExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"]

bento_model = bentoml.onnx.get("onnx_super_resolution")

runner = bento_model.with_options(providers=providers).to_runner()