
vLLM Mixtral-8x7B-DPO-AWQ deployment guide

In this tutorial, you will use the vLLM backend of the Large Model Inference (LMI) DLC to deploy Mixtral-8x7B-DPO-AWQ and run inference with it.

Please make sure the following permissions are granted before running the notebook:

  • SageMaker access

Step 1: Upgrade the SageMaker SDK and import dependencies

%pip install sagemaker --upgrade --quiet
import sagemaker
from sagemaker.djl_inference.model import DJLModel

role = sagemaker.get_execution_role()  # execution role for the endpoint
session = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs

Step 2: Start building the SageMaker endpoint

In this step, we will build a SageMaker endpoint from scratch.

Getting the container image URI (optional)

Check out available images: Large Model Inference available DLC

# Choose a specific version of LMI image directly:
# image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124"
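If you want to pin the container version instead of letting the SDK resolve a default image, you can pass the URI through to the model later. A minimal sketch (image_uri is the standard SageMaker Model keyword; the URI is just the example above):

# Pin a specific LMI image instead of the SDK default:
image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124"
# then pass it when creating the model:
# model = DJLModel(model_id=model_id, env=env, role=role, image_uri=image_uri)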

Create SageMaker model

Here we are using the DJLModel class from the SageMaker Python SDK to create the model.

Check out more configuration options.

model_id = "TheBloke/Nous-Hermes-2-Mixtral-8x7B-DPO-AWQ"  # model will be downloaded from the Hugging Face Hub

env = {
    "TENSOR_PARALLEL_DEGREE": "max",   # use all GPUs on the instance
    "OPTION_ROLLING_BATCH": "vllm",    # use vLLM for rolling (continuous) batching
    "OPTION_QUANTIZE": "awq",          # load the AWQ-quantized weights
    "OPTION_MAX_MODEL_LEN": "8192",    # cap the context length to fit GPU memory
}
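These environment variables map to LMI serving.properties settings: an OPTION_ prefix becomes an option. key, and TENSOR_PARALLEL_DEGREE maps to option.tensor_parallel_degree. If you prefer file-based configuration, the equivalent serving.properties would look roughly like this (a sketch; engine=Python is the usual setting for the vLLM rolling batch):

engine=Python
option.rolling_batch=vllm
option.tensor_parallel_degree=max
option.quantize=awq
option.max_model_len=8192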

model = DJLModel(
    model_id=model_id,
    env=env,
    role=role,
)

Create SageMaker endpoint

You need to specify the instance type to use and an endpoint name.

instance_type = "ml.g4dn.12xlarge"
endpoint_name = sagemaker.utils.name_from_base("lmi-model")

predictor = model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    endpoint_name=endpoint_name,
)
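Deployment can take several minutes. If you want to check on the endpoint yourself, a minimal sketch using the standard boto3 describe_endpoint call:

import boto3

sm_client = boto3.client("sagemaker")

# The status moves from "Creating" to "InService" once deployment finishes.
status = sm_client.describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"]
print(f"Endpoint status: {status}")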

Step 3: Run inference

Nous-Hermes-2 Mixtral uses the ChatML prompt format, so the system and user messages are wrapped in <|im_start|> / <|im_end|> tags.

system_message = ""
input_text = "请解释一下AI"  # "Please explain AI"

prompt_template = f'''<|im_start|>system
{system_message}<|im_end|>
<|im_start|>user
{input_text}<|im_end|>
<|im_start|>assistant
'''
parameters = {
    "max_new_tokens": 128,      # maximum number of tokens to generate
    "do_sample": True,          # sample instead of greedy decoding
    "temperature": 0.7,
    "top_p": 0.95,
    "top_k": 40,
    "repetition_penalty": 1.1,
}

Non-streaming

predictor.predict(
    {
        "inputs": prompt_template,
        "parameters": parameters,
    }
)
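The call above returns the decoded JSON body. With the vLLM rolling-batch handler this is typically a dict carrying a generated_text field (an assumption; the exact schema can vary across LMI versions), so you can capture the completion like this:

result = predictor.predict({"inputs": prompt_template, "parameters": parameters})
# Assumption: the handler returns {"generated_text": "..."}.
print(result["generated_text"])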

Streaming

import json

import boto3

from utils.LineIterator import LineIterator  # helper module shipped alongside this notebook

smr_client = boto3.client("sagemaker-runtime")

def get_realtime_response_stream(sagemaker_runtime, endpoint_name, payload):
    response_stream = sagemaker_runtime.invoke_endpoint_with_response_stream(
        EndpointName=endpoint_name,
        Body=json.dumps(payload),
        ContentType="application/json",
    )
    return response_stream

payload = {
    "inputs": prompt_template,
    "parameters": parameters,
    "stream": True,  # <-- request a streamed response
}

def print_response_stream(response_stream):
    event_stream = response_stream.get("Body")
    for line in LineIterator(event_stream):
        print(line, end="")

response_stream = get_realtime_response_stream(smr_client, endpoint_name, payload)
print_response_stream(response_stream)
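If you don't have the utils module, here is a minimal sketch of a LineIterator along the lines of the common AWS samples (an assumption, not necessarily the notebook's exact helper): it buffers the bytes from each PayloadPart event and yields one decoded line at a time.

import io

class LineIterator:
    """Yield decoded lines from a SageMaker response event stream (sketch)."""

    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line.endswith(b"\n"):
                self.read_pos += len(line)
                return line.decode("utf-8")
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if line:  # flush a trailing line with no final newline
                    self.read_pos += len(line)
                    return line.decode("utf-8")
                raise
            if "PayloadPart" in chunk:
                # Append the new bytes, then retry reading a full line.
                self.buffer.seek(0, io.SEEK_END)
                self.buffer.write(chunk["PayloadPart"]["Bytes"])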

Clean up the environment

session.delete_endpoint(endpoint_name)
session.delete_endpoint_config(endpoint_name)
model.delete_model()