Falcon-40B seq-scheduler rolling-batch deployment guide¶
In this tutorial, you will use the LMI (Large Model Inference) container from the AWS Deep Learning Containers (DLC) to deploy Falcon-40B on SageMaker and run inference against it.
Please make sure the following permissions are granted before running the notebook:
- S3 bucket push access
- SageMaker access
Step 1: Let's bump up SageMaker and import stuff¶
%pip install sagemaker --upgrade --quiet
import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers
role = sagemaker.get_execution_role() # execution role for the endpoint
sess = sagemaker.session.Session() # sagemaker session for interacting with different AWS APIs
region = sess.boto_region_name  # region name of the current SageMaker Studio environment
account_id = sess.account_id() # account_id of the current SageMaker Studio environment
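As a quick sanity check before moving on, you can print the values the session resolved:
# confirm the execution role, region, and account the session resolved
print(f"role: {role}")
print(f"region: {region}, account: {account_id}")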
Step 2: Start preparing model artifacts¶
In the LMI container, we expect some artifacts to help set up the model:
- serving.properties (required): Defines the model server settings
- model.py (optional): A Python file to define the core inference logic (a minimal sketch follows below)
- requirements.txt (optional): Any additional pip packages that need to be installed
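This tutorial relies on the container's built-in handler, so no model.py is shipped. For reference only, a minimal custom-handler sketch for the DJL Serving Python engine could look like the following; the handle entry point and the djl_python Input/Output types are the engine's convention, while the echo logic is a placeholder for real generation code:
# model.py -- minimal custom-handler sketch (illustration only, not used in this tutorial)
from djl_python import Input, Output

def handle(inputs: Input) -> Output:
    if inputs.is_empty():
        # DJL Serving sends an empty request during model-server warmup
        return None
    data = inputs.get_as_json()
    prompt = data.get("inputs", "")
    # real generation logic would go here; this placeholder just echoes the prompt
    return Output().add_as_json({"generated_text": prompt})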
import os
os.environ['MODEL_ID'] = "tiiuae/falcon-40b"
os.environ['HF_TRUST_REMOTE_CODE'] = "TRUE"
model_id = os.getenv('MODEL_ID')
with open('serving.properties', 'w') as f:
    f.write(f"""engine=Python
option.model_id={model_id}
option.tensor_parallel_degree=8
option.dtype=fp16
option.model_loading_timeout=3600
option.trust_remote_code=true
# rolling-batch parameters
option.max_rolling_batch_size=64
option.rolling_batch=scheduler
# seq-scheduler parameters
# limits the max_sparsity in the token sequence caused by padding
option.max_sparsity=0.33
# limits the max number of batch splits, where each split has its own inference call
option.max_splits=3
# other options: contrastive, sample
option.decoding_strategy=greedy
# default: true
option.disable_flash_attn=true
""")
%%sh
mkdir mymodel
mv serving.properties mymodel/
tar czvf mymodel-3-falcon.tar.gz mymodel/
rm -rf mymodel
image_uri = image_uris.retrieve(
framework="djl-deepspeed",
region=sess.boto_session.region_name,
version="0.27.0"
)
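You can echo the resolved container image to confirm the version and region are what you expect:
# confirm which LMI container image was resolved
print(f"image_uri: {image_uri}")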
Step 3: Upload artifact to S3 and create SageMaker model¶
s3_code_prefix = "large-model-lmi/code"
bucket = sess.default_bucket() # bucket to house artifacts
code_artifact = sess.upload_data("mymodel-3-falcon.tar.gz", bucket, s3_code_prefix)
print(f"S3 Code or Model tar ball uploaded to --- > {code_artifact}")
model = Model(image_uri=image_uri, model_data=code_artifact, role=role)
Step 4: Create SageMaker endpoint¶
You need to specify the instance type to use and the endpoint name.
instance_type = "ml.g5.48xlarge"
endpoint_name = sagemaker.utils.name_from_base("lmi-model-3-falcon")
model.deploy(initial_instance_count=1,
             instance_type=instance_type,
             endpoint_name=endpoint_name,
             container_startup_health_check_timeout=3600
             )
# our requests and responses will be in json format, so we specify the serializer and the deserializer
predictor = sagemaker.Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sess,
    serializer=serializers.JSONSerializer(),
    deserializer=deserializers.JSONDeserializer(),
)
Step 5: Test and benchmark the inference¶
First, let's try a request with a deliberately wrong input: do_sample is passed as the string "true" rather than a JSON boolean, so we can see how the endpoint handles malformed parameters.
predictor.predict(
    {"inputs": "def hello_world():", "parameters": {"max_new_tokens": 128, "do_sample": "true"}}
)
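With the parameter corrected to a proper boolean, the same request should generate normally:
# corrected request: do_sample is a real boolean this time
predictor.predict(
    {"inputs": "def hello_world():", "parameters": {"max_new_tokens": 128, "do_sample": True}}
)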
Test¶
This step can be done outside the notebook, in a bash shell terminal. The connection to the server goes through the $SAGEMAKER URL. The awscurl used here is a benchmarking tool that signs requests with your AWS credentials; it can be obtained with:
wget https://github.com/frankfliu/junkyard/releases/download/v0.3.1/awscurl && chmod +x awscurl
For a single request it can be replaced with a normal curl command.
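If you would rather stay inside the notebook, a single request can also be sent through the boto3 SageMaker runtime client; this is a sketch mirroring the benchmark payload below:
import json
import boto3

# single-request alternative to awscurl, using the SageMaker runtime client
runtime = boto3.client("sagemaker-runtime")
response = runtime.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="application/json",
    Body=json.dumps({
        "inputs": "The new movie that got Oscar this year",
        "parameters": {"max_new_tokens": 256, "do_sample": True, "temperature": 0.8, "top_k": 5},
    }),
)
print(response["Body"].read().decode("utf-8"))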
%%sh
export CONCUR=4
# substitute your own endpoint name, as printed by name_from_base above
export SAGEMAKER=https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/lmi-model-3-falcon-2023-11-11-03-17-14-364/invocations
TOKENIZER=tiiuae/falcon-40b ./awscurl -c $CONCUR -N 10 -n sagemaker $SAGEMAKER \
--connect-timeout 660 \
-H "Content-type: application/json" \
-d '{"inputs":"The new movie that got Oscar this year","parameters":{"max_new_tokens":256, "do_sample":true, "temperature":0.8, "top_k":5}}' \
-t -o output-3-$CONCUR.txt; \
mv output-3-$CONCUR.txt.0 /home/ec2-user/SageMaker/output_falcon
Clean up¶
sess.delete_endpoint(endpoint_name)
sess.delete_endpoint_config(endpoint_name)
model.delete_model()