vLLM Mixtral-8x7B-DPO-AWQ deployment guide¶
In this tutorial, you will use the vLLM backend of the Large Model Inference (LMI) DLC to deploy Mixtral-8x7B-DPO-AWQ and run inference with it.
Please make sure the following permissions are granted before running the notebook:
- SageMaker access
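If you want to verify access before going further, the quick check below is a minimal sketch that assumes the notebook is running inside a SageMaker notebook instance or Studio, where an execution role is attached:
import boto3
import sagemaker

# Print the execution role and AWS account this notebook will use;
# an error here usually means the SageMaker permissions above are missing.
print(sagemaker.get_execution_role())
print(boto3.client("sts").get_caller_identity()["Account"])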
Step 1: Let's bump up SageMaker and import stuff¶
%pip install sagemaker --upgrade --quiet
import sagemaker
from sagemaker.djl_inference.model import DJLModel
role = sagemaker.get_execution_role() # execution role for the endpoint
session = sagemaker.session.Session() # sagemaker session for interacting with different AWS APIs
Step 2: Start building SageMaker endpoint¶
In this step, we will build a SageMaker endpoint from scratch.
Getting the container image URI (optional)¶
Check out available images: Large Model Inference available DLC
# Choose a specific version of LMI image directly:
# image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124"
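Alternatively, newer versions of the SageMaker Python SDK can resolve the LMI image for your region. The sketch below is a convenience only: the "djl-lmi" framework identifier and the "0.28.0" version string are assumptions that depend on your SDK release, so fall back to a hard-coded URI from the DLC list above if the lookup fails. If you do set image_uri, pass it to DJLModel in the next step (e.g. image_uri=image_uri).
# Optional: resolve the LMI image URI programmatically (assumed framework
# name "djl-lmi" and version "0.28.0"; adjust to match your SDK and region).
image_uri = sagemaker.image_uris.retrieve(
    framework="djl-lmi",
    region=session.boto_session.region_name,
    version="0.28.0",
)
print(image_uri)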
Create SageMaker model¶
Here we are using the SageMaker Python SDK (DJLModel) to create the model.
Check out more configuration options.
model_id = "TheBloke/Nous-Hermes-2-Mixtral-8x7B-DPO-AWQ" # model will be download form Huggingface hub
env = {
"TENSOR_PARALLEL_DEGREE": "max", # use all GPUs on the instance
"OPTION_ROLLING_BATCH": "vllm", # use vllm for rolling batching
"OPTION_QUANTIZE": "awq",
"OPTION_MAX_MODEL_LEN": "8192",
}
model = DJLModel(
model_id=model_id,
env=env,
role=role)
Create SageMaker endpoint¶
You need to specify the instance type to use and the endpoint name.
instance_type = "ml.g4dn.12xlarge"
endpoint_name = sagemaker.utils.name_from_base("lmi-model")
predictor = model.deploy(initial_instance_count=1,
instance_type=instance_type,
endpoint_name=endpoint_name,
)
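Deployment typically takes several minutes. If the kernel restarts after the endpoint is in service, you can re-attach to it instead of redeploying. The sketch below is optional and simply assumes the JSON request/response format used in Step 3.
# Optional: re-attach to an already-running endpoint after a kernel restart.
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

predictor = Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=session,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
)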
Step 3: Run inference¶
system_message=""
input_text = "请解释一下AI"
prompt_template=f'''<|im_start|>system
{system_message}<|im_end|>
<|im_start|>user
{input_text}<|im_end|>
<|im_start|>assistant
'''
parameters = {
    "max_new_tokens": 128,
    "do_sample": True,
    "temperature": 0.7,
    "top_p": 0.95,
    "top_k": 40,
    "repetition_penalty": 1.1,
}
Non-streaming¶
predictor.predict(
    {
        "inputs": prompt_template,
        "parameters": parameters,
    }
)
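With the vLLM rolling batch backend, the response is usually a JSON object whose generated text sits under a generated_text key. That key is an assumption about the default LMI output schema, so the sketch below falls back to printing the raw response if it is absent:
result = predictor.predict({"inputs": prompt_template, "parameters": parameters})
# "generated_text" is the assumed default LMI output key; print the raw
# payload if the schema differs on your container version.
print(result.get("generated_text", result))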
Streaming¶
import json
import boto3

smr_client = boto3.client("sagemaker-runtime")

def get_realtime_response_stream(sagemaker_runtime, endpoint_name, payload):
    response_stream = sagemaker_runtime.invoke_endpoint_with_response_stream(
        EndpointName=endpoint_name,
        Body=json.dumps(payload),
        ContentType="application/json",
    )
    return response_stream
payload = {
    "inputs": prompt_template,
    "parameters": parameters,
    "stream": True,  # <-- enable response streaming
}
from utils.LineIterator import LineIterator

def print_response_stream(response_stream):
    event_stream = response_stream.get("Body")
    for line in LineIterator(event_stream):
        print(line, end="")

response_stream = get_realtime_response_stream(smr_client, endpoint_name, payload)
print_response_stream(response_stream)
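The LineIterator helper imported above lives in the sample repository's utils module. If it is not available in your environment, the following minimal sketch can stand in for it; it assumes the endpoint streams newline-delimited text chunks inside PayloadPart events:
import io

class LineIterator:
    """Minimal stand-in for utils.LineIterator (assumed behavior): buffer
    PayloadPart bytes from the response stream and yield complete lines."""

    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord("\n"):
                self.read_pos += len(line)
                return line[:-1].decode("utf-8")
            # No complete line buffered yet: pull the next event from the stream.
            chunk = next(self.byte_iterator)  # StopIteration ends the iteration
            if "PayloadPart" in chunk:
                self.buffer.seek(0, io.SEEK_END)
                self.buffer.write(chunk["PayloadPart"]["Bytes"])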
Clean up the environment¶
session.delete_endpoint(endpoint_name)
session.delete_endpoint_config(endpoint_name)
model.delete_model()