Triton Inference Server#

time expected: 10 minutes

NVIDIA Triton Inference Server is a high-performance, open-source inference server for serving deep learning models. It is optimized to deploy models from multiple deep learning frameworks, including TensorRT, TensorFlow, and ONNX, to a variety of deployment targets and cloud providers. Triton is also designed to maximize hardware utilization through concurrent model execution and efficient batching strategies.

BentoML now supports running Triton Inference Server as a Runner. The following integration guide assumes that readers are familiar with the BentoML architecture. Check out our tutorial if you wish to learn more about BentoML service definitions.

For more information about Triton, please refer to the Triton Inference Server documentation.

The code examples in this guide can also be found in the example folder.

Why Integrate BentoML with Triton Inference Server?#

If you are an existing Triton user, the integration provides simpler ways to add custom logic in Python, deploy distributed multi-model inference graphs, unify model management across different ML frameworks and workflows, and standardize the model packaging format with versioning and collaboration features. If you are an existing BentoML user, the integration improves runner efficiency and throughput under high load thanks to Triton's efficient C++ runtime.

Prerequisites#

Make sure to have at least BentoML 1.0.16:

$ pip install -U "bentoml[triton]"

Note

Triton Inference Server is currently only available in production mode (the default mode) and will not work in development mode (the --development flag).

Additionally, you will need to have Triton Inference Server installed on your system. Refer to Triton's building documentation to set up your environment. The recommended way to run Triton is through a container (Docker/Podman). To pull the latest Triton container for testing, run:

$ docker pull nvcr.io/nvidia/tritonserver:<yy>.<mm>-py3

Note

<yy>.<mm>: the version of Triton you wish to use. For example, at the time of writing, the latest version is 23.01.
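
For instance, to pull the 23.01 release mentioned above:

$ docker pull nvcr.io/nvidia/tritonserver:23.01-py3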

Finally, the example Bento built from the example project with the YOLOv5 model will be referenced throughout this guide.

Note

To develop your own Bento with Triton, you can refer to the example folder for more usage.

Get started with Triton Inference Server#

The Triton Inference Server architecture revolves around a model repository and an inference server. The model repository is a filesystem-based persistent volume that contains the model files and their respective configurations, which define how each model should be loaded and served. The inference server serves those models over either the HTTP/REST or gRPC protocol with various batching strategies.

BentoML provides a simple integration with Triton via Runner:

import bentoml

triton_runner = bentoml.triton.Runner("triton_runner", model_repository="/path/to/model_repository")

The argument model_repository is the path to the model repository that Triton uses to serve the models. Note that model_repository also supports S3 paths:

import bentoml

triton_runner = bentoml.triton.Runner("triton_runner",
                                      model_repository="s3://bucket/path/to/model_repository",
                                      cli_args=["--load-model=torchscript_yolov5s", "--model-control-mode=explicit"]
)

Note

If models are saved on the file system, using the Triton runner requires setting up the model repository explicitly through the include key in the bentofile.yaml.
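
For example, if the model repository lives at ./model_repository in the project, a minimal bentofile.yaml could include it like this (a sketch; the full bentofile.yaml used by the example project is shown later in this guide):

service: service:svc
include:
  - /model_repository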

Note

The cli_args argument is a list of arguments that will be passed to the tritonserver command. For example, the --load-model argument is used to load a specific model from the model repository. See tritonserver --help for all available arguments.

From a developer's perspective, remote invocation of a Triton runner is similar to invoking any other BentoML runner.

Note

By default, bentoml.triton.Runner will run tritonserver with the gRPC protocol. To use the HTTP/REST protocol, pass tritonserver_type="http" to the Runner constructor.

import bentoml

triton_runner = bentoml.triton.Runner("triton_runner", model_repository="/path/to/model_repository", tritonserver_type="http")

Triton Runner Signatures#

Normally in a BentoML Runner, one can access the model signatures directly from the runner's attributes. For example, the model signature predict of an iris_classifier_runner (see service definition) can be accessed as iris_classifier_runner.predict.run.
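
As a refresher, the sketch below shows how a conventional framework runner is invoked through its signature attribute. It assumes an iris classifier saved under the tag iris_clf:latest (as in the BentoML quickstart); the tag and endpoint name are illustrative:

import bentoml
import numpy as np

# A conventional framework runner: model signatures are attributes of the runner itself.
iris_classifier_runner = bentoml.sklearn.get("iris_clf:latest").to_runner()

svc = bentoml.Service("iris_classifier", runners=[iris_classifier_runner])

@svc.api(input=bentoml.io.NumpyNdarray(), output=bentoml.io.NumpyNdarray())
async def classify(input_series: np.ndarray) -> np.ndarray:
    # The "predict" signature is called directly on the runner.
    return await iris_classifier_runner.predict.async_run(input_series)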

However, Triton runner’s attributes represent individual models defined under the model repository. For example, if the model repository has the following structure:

model_repository
β”œβ”€β”€ onnx_mnist
β”‚   β”œβ”€β”€ 1
β”‚   β”‚   └── model.onnx
β”‚   └── config.pbtxt
β”œβ”€β”€ tensorflow_mnist
β”‚   β”œβ”€β”€ 1
β”‚   β”‚   └── model.savedmodel/
β”‚   └── config.pbtxt
└── torchscript_mnist
    β”œβ”€β”€ 1
    β”‚   └── model.pt
    └── config.pbtxt

Then each model can be accessed as triton_runner.onnx_mnist, triton_runner.tensorflow_mnist, or triton_runner.torchscript_mnist, and invoked using either run or async_run.

An example to demonstrate how to call the Triton runner:

import typing as t

import bentoml
import numpy as np
from PIL.Image import Image
from numpy.typing import NDArray

# svc and triton_runner are defined as shown above.
@svc.api(
    input=bentoml.io.Image.from_sample("./data/0.png"), output=bentoml.io.NumpyNdarray()
)
async def bentoml_torchscript_mnist_infer(im: Image) -> NDArray[t.Any]:
    arr = np.array(im) / 255.0
    arr = np.expand_dims(arr, (0, 1)).astype("float32")
    InferResult = await triton_runner.torchscript_mnist.async_run(arr)
    return InferResult.as_numpy("OUTPUT__0")

There are a few things to note here:

  1. Triton runners should only be called within an API function. In other words, invoking triton_runner.torchscript_mnist.async_run in the global scope will not work. This is because Triton is not implemented natively in Python, and hence init_local is not supported:

    triton_runner.init_local()
    
    # TritonRunner 'triton_runner' will not be available for development mode.
    
  2. async_run and run on any Triton runner take either all positional arguments or all keyword arguments. The arguments should follow the same order as the inputs/outputs signatures defined in config.pbtxt.

    For example, if the following config.pbtxt is used for torchscript_mnist:

    platform: "pytorch_libtorch"
    dynamic_batching {}
    input {
     name: "INPUT__0"
     data_type: TYPE_FP32
     dims: -1
     dims: 1
     dims: 28
     dims: 28
    }
    input {
     name: "INPUT__1"
     data_type: TYPE_FP32
     dims: -1
     dims: 1
     dims: 28
     dims: 28
    }
    output {
     name: "OUTPUT__0"
     data_type: TYPE_FP32
     dims: -1
     dims: 10
    }
    output {
     name: "OUTPUT__1"
     data_type: TYPE_FP32
     dims: -1
     dims: 10
    }
    

    Then run or async_run takes either two positional arguments or the two keyword arguments INPUT__0 and INPUT__1:

    # Both are valid
    triton_runner.torchscript_mnist.run(np.zeros((1, 28, 28)), np.zeros((1, 28, 28)))
    
    await triton_runner.torchscript_mnist.async_run(
        INPUT__0=np.zeros((1, 28, 28)), INPUT__1=np.zeros((1, 28, 28))
    )
    

    Mixing positional and keyword arguments will result in an error:

    triton_runner.torchscript_mnist.run(
        np.zeros((1, 28, 28)), INPUT__1=np.zeros((1, 28, 28))
    )
    # throws errors
    
  3. run and async_run return an InferResult object. Regardless of the protocol used, the InferResult object has the following methods:

    • as_numpy(name: str) -> NDArray[T]: returns the result as a numpy array. The argument is the name of the output defined in config.pbtxt.

    • get_output(name: str) -> InferOutputTensor | dict[str, T]: Returns the results as a InferOutputTensor (gRPC) or a dictionary (HTTP). The argument is the name of the output defined in config.pbtxt.

    • get_response(self) -> ModelInferResponse | dict[str, T]: Returns the entire response as a ModelInferResponse (gRPC) or a dictionary (HTTP).

    Using the above config.pbtxt as example, the model consists of two outputs, OUTPUT__0 and OUTPUT__1.

    To get OUTPUT__0 as a numpy array:

    InferResult = triton_runner.torchscript_mnist.run(np.zeros((1, 28, 28)), np.zeros((1, 28, 28)))
    return InferResult.as_numpy("OUTPUT__0")

    To get OUTPUT__1 as a JSON dictionary:

    # gRPC: get_output returns an InferOutputTensor; as_json=True serializes it to a dict
    InferResult = triton_runner.torchscript_mnist.run(np.zeros((1, 28, 28)), np.zeros((1, 28, 28)))
    return InferResult.get_output("OUTPUT__1", as_json=True)

    # HTTP: get_output already returns a dictionary
    InferResult = triton_runner.torchscript_mnist.run(np.zeros((1, 28, 28)), np.zeros((1, 28, 28)))
    return InferResult.get_output("OUTPUT__1")

Additionally, the Triton runner exposes all tritonclient functions.

Supported client APIs

The list below comprises all the model management APIs from tritonclient that are supported by Triton runners:

  • get_model_config

  • get_model_metadata

  • get_model_repository_index

  • is_model_ready

  • is_server_live

  • is_server_ready

  • load_model

  • unload_model

  • infer

  • stream_infer

The following advanced client APIs are also supported:

  • get_cuda_shared_memory_status

  • get_inference_statistics

  • get_log_settings

  • get_server_metadata

  • get_system_shared_memory_status

  • get_trace_settings

  • register_cuda_shared_memory

  • register_system_shared_memory

  • unregister_cuda_shared_memory

  • unregister_system_shared_memory

  • update_log_settings

  • update_trace_settings

Important: All of the client APIs are asynchronous. To use them, make sure to call them within an async @svc.api endpoint. See Synchronous and asynchronous APIs.

service.py#
@svc.api(input=bentoml.io.Text.from_sample("onnx_mnist"), output=bentoml.io.JSON())
async def unload_model(input_model: str):
    await triton_runner.unload_model(input_model)
    return {"unloaded": input_model}

Packaging BentoService with Triton Inference Server#

To build your BentoService with Triton Inference Server, add the following to your bentofile.yaml or use bentoml.bentos.build:

bentofile.yaml#
service: service:svc
include:
  - /model_repository
  - /data/*.png
  - /*.py
exclude:
  - /__pycache__
  - /venv
  - /train.py
  - /build_bento.py
  - /containerize_bento.py
python:
  packages:
    - bentoml[triton]
docker:
  base_image: nvcr.io/nvidia/tritonserver:22.12-py3

Build this Bento with bentoml build:

$ bentoml build

Alternatively, build the Bento programmatically with bentoml.bentos.build:

build_bento.py#
if __name__ == "__main__":
    import bentoml

    bentoml.bentos.build(
        "service:svc",
        include=["/model_repository", "/data/*.png", "service.py"],
        exclude=["/__pycache__", "/venv"],
        docker={"base_image": "nvcr.io/nvidia/tritonserver:22.12-py3"},
    )

Notice that we are using nvcr.io/nvidia/tritonserver:22.12-py3 as our base image. This can be substituted with any other custom base image that has the tritonserver binary available. See Triton’s documentation here to learn more about building/composing custom Triton images.

Important: The provided Triton image from NVIDIA includes Python 3.8. Therefore, if you are developing your Bento with any other Python version, make sure that your service.py is compatible with Python 3.8.
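
Once the Bento is built, it can be containerized with the bentoml containerize CLI (the tag below is illustrative; the example project also ships a containerize_bento.py script):

$ bentoml containerize triton-integration:latest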

Tip

To see all available options for Triton run:

$ docker run --init --rm -p 3000:3000 triton-integration:gpu tritonserver --help

Current Caveats#

At the time of writing, there are a few caveats that you should be aware of when using TritonRunner:

Versioning Policy Limitations#

By default, the model configuration version policy is set to latest(n=1), meaning only the latest version of the model will be loaded into the Triton server.

Currently, TritonRunner only supports the latest policy. If you have multiple versions of the same model in your BentoService, the runner only considers the latest version.
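
For reference, this default behavior corresponds to the following version_policy block in config.pbtxt (a sketch of standard Triton model configuration, not a TritonRunner-specific option; Triton applies latest(n=1) when no policy is specified):

version_policy: { latest { num_versions: 1 } }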

For example, if the model repository has the following structure:

model_repository
β”œβ”€β”€ onnx_mnist
β”‚   β”œβ”€β”€ 1
β”‚   β”‚   └── model.onnx
β”‚   β”œβ”€β”€ 2
β”‚   β”‚   └── model.onnx
β”‚   └── config.pbtxt
...

Then triton_runner.onnx_mnist will reference the latest version of the model (in this case, version 2).

To use a specific version of said model, refer to the example below:

service.py#
from __future__ import annotations

import typing as t

import numpy as np
from tritonclient.grpc.aio import InferInput
from tritonclient.grpc.aio import np_to_triton_dtype
from tritonclient.grpc.aio import InferRequestedOutput

import bentoml

if t.TYPE_CHECKING:
    from PIL.Image import Image
    from numpy.typing import NDArray

# triton runner
triton_runner = bentoml.triton.Runner(
    "triton_runner",
    "./model_repository",
    cli_args=[
        "--load-model=onnx_mnist",
        "--load-model=torchscript_yolov5s",
        "--model-control-mode=explicit",
    ],
)

svc = bentoml.Service("triton-integration", runners=[triton_runner])


@svc.api(
    input=bentoml.io.Image.from_sample("./data/0.png"), output=bentoml.io.NumpyNdarray()
)
async def predict_v1(input_data: Image) -> NDArray[t.Any]:
    arr = np.array(input_data) / 255.0
    arr = np.expand_dims(arr, (0, 1)).astype("float32")
    input_0 = InferInput("input_0", arr.shape, np_to_triton_dtype(arr.dtype))
    input_0.set_data_from_numpy(arr)
    output_0 = InferRequestedOutput("output_0")
    InferResult = await triton_runner.infer(
        "onnx_mnist", inputs=[input_0], model_version="1", outputs=[output_0]
    )
    return InferResult.as_numpy("output_0")

Inference Protocol and Metrics Server#

By default, TritonRunner serves inference over the gRPC protocol.

The HTTP/REST API is disabled by default, though it can be enabled when creating the runner by passing tritonserver_type="http" to the Runner:

import bentoml

triton_runner = bentoml.triton.Runner(
    "http_runner",
    "/path/to/model_repository",
    tritonserver_type="http",
)

Currently, TritonRunner does not support running the metrics server. If you are interested in metrics server support, please open an issue on GitHub.

Additionally, BentoML will allocate a random port for the gRPC/HTTP server; hence, any --grpc-port or --http-port options passed to the Runner's cli_args will be ignored.

Adaptive Batching#

Adaptive batching is a feature supported by BentoML runners that allows for efficient batch size selection during inference. However, it’s important to note that this feature is not compatible with TritonRunner.

TritonRunner is designed as a standalone Triton server, which means that the adaptive batching logic in BentoML runners is not invoked when using TritonRunner.

Fortunately, Triton supports its own solution for efficient batching called dynamic batching. Similar to adaptive batching, dynamic batching also allows for the selection of the optimal batch size during inference. To use dynamic batching in Triton, relevant settings can be specified in the model configuration file.
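
For example, the config.pbtxt shown earlier enables dynamic batching with an empty dynamic_batching {} block. The snippet below is a sketch of a more explicit configuration using standard Triton settings (the values are illustrative, not recommendations):

dynamic_batching {
  preferred_batch_size: [ 4, 8 ]
  max_queue_delay_microseconds: 100
}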

🚧 Help us improve the integration!

This integration is still in its early stages and we are looking for feedback and contributions to make it even better!

If you have any feedback or want to contribute improvements to the Triton Inference Server integration, we would love to see your feature requests and pull requests!

Check out the BentoML development guide and documentation guide to get started.