
TensorRT-LLM (TRT-LLM) rolling batch Mixtral-8x7B deployment guide

In this tutorial, you will use the TensorRT-LLM backend of the SageMaker Large Model Inference (LMI) DLC to deploy Mixtral-8x7B and run inference with it.

Please make sure the following permissions are granted before running the notebook:

  • S3 bucket push access
  • SageMaker access

Step 1: Upgrade the SageMaker SDK and import dependencies

%pip install sagemaker --upgrade  --quiet
import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers

role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
region = sess.boto_region_name  # region of the current SageMaker Studio environment
account_id = sess.account_id()  # AWS account id

Step 2: Prepare model artifacts

The LMI container expects the following artifacts to set up the model:

  • serving.properties (required): Defines the model server settings
  • model.py (optional): A Python file that defines the core inference logic (a minimal sketch follows this list)
  • requirements.txt (optional): Any additional pip packages to install
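This tutorial relies on the built-in TensorRT-LLM handler, so no model.py is shipped. For reference, here is a minimal sketch of the interface a custom model.py would implement, assuming the djl_python Input/Output handler API:

from djl_python import Input, Output

def handle(inputs: Input) -> Output:
    if inputs.is_empty():
        # Warm-up request from the model server; nothing to return
        return None
    data = inputs.get_as_json()
    # ... run custom inference logic on the request payload here ...
    return Output().add_as_json({"generated_text": "..."})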
%%writefile serving.properties
engine=MPI
option.model_id=mistralai/Mixtral-8x7B-v0.1
option.tensor_parallel_degree=8
option.max_rolling_batch_size=32
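In this configuration, engine=MPI selects the TensorRT-LLM rolling batch backend, option.model_id points to the Hugging Face model to download, option.tensor_parallel_degree=8 shards the model across 8 GPUs, and option.max_rolling_batch_size=32 caps how many requests can be batched together at any one time.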
%%sh
mkdir mymodel
mv serving.properties mymodel/
tar czvf mymodel.tar.gz mymodel/
rm -rf mymodel
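The resulting mymodel.tar.gz contains a single mymodel/ directory holding serving.properties; the LMI container unpacks it and reads these settings at startup.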

Step 3: Create the SageMaker endpoint

Getting the container image URI

See the available Large Model Inference DLCs here

image_uri = image_uris.retrieve(
    framework="djl-tensorrtllm",
    region=sess.boto_session.region_name,
    version="0.27.0"
)

Upload the model artifact to S3 and create the SageMaker model

s3_code_prefix = "large-model-lmi/code"
bucket = sess.default_bucket()  # bucket to house artifacts
code_artifact = sess.upload_data("mymodel.tar.gz", bucket, s3_code_prefix)
print(f"S3 Code or Model tar ball uploaded to --- > {code_artifact}")

model = Model(image_uri=image_uri, model_data=code_artifact, role=role)

Create SageMaker endpoint with a specified instance type

instance_type = "ml.g5.48xlarge"
endpoint_name = sagemaker.utils.name_from_base("lmi-model")
print(f"endpoint_name: {endpoint_name}")

model.deploy(initial_instance_count=1,
             instance_type=instance_type,
             endpoint_name=endpoint_name,
             container_startup_health_check_timeout=1800
            )
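Note that ml.g5.48xlarge provides 8 NVIDIA A10G GPUs, matching the tensor_parallel_degree of 8 set above. Because TensorRT-LLM compiles the model engines when the container first loads the model, startup can take a while; the 1800-second health check timeout gives the endpoint room to finish.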

# our requests and responses will be in JSON format, so we specify both the serializer and the deserializer
predictor = sagemaker.Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sess,
    serializer=serializers.JSONSerializer(),
    deserializer=deserializers.JSONDeserializer(),
)

Step 4: Run inference

predictor.predict(
    {"inputs": "The future of Artificial Intelligence is", "parameters": {"max_new_tokens":128, "do_sample":True}}
)
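With rolling batch enabled, the response body is JSON and typically carries the completion under a generated_text key. When you are done experimenting, delete the endpoint and its attached resources to avoid ongoing charges:

sess.delete_endpoint(endpoint_name)
sess.delete_endpoint_config(endpoint_name)
model.delete_model()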