Llama 2 7B with customized stop reasons¶
In this tutorial, you will deploy a Large Model Inference (LMI) Deep Learning Container (DLC) to SageMaker and run inference against it.
Please make sure the following permissions are granted before running the notebook:
- S3 bucket push access
- SageMaker access
Step 1: Upgrade the SageMaker SDK and import dependencies¶
%pip install sagemaker --upgrade --quiet
import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers
role = sagemaker.get_execution_role() # execution role for the endpoint
sess = sagemaker.session.Session() # sagemaker session for interacting with different AWS APIs
region = sess.boto_region_name  # region name of the current SageMaker Studio environment
account_id = sess.account_id() # account_id of the current SageMaker Studio environment
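As a quick, optional sanity check, you can print what the session resolved (the values are specific to your environment):
print(f"role: {role}")
print(f"region: {region}, account: {account_id}")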
Step 2: Start preparing model artifacts¶
In the LMI container, we expect some artifacts to help set up the model:
- serving.properties (required): Defines the model server settings
- model.py (optional): A Python file that defines the core inference logic
- requirements.txt (optional): Any additional pip packages to install
%%writefile serving.properties
# Use the MPI engine, required for the lmi-dist rolling batcher
engine=MPI
# Hugging Face Hub model to deploy
option.model_id=TheBloke/Llama-2-7B-fp16
# Number of GPUs to shard the model across
option.tensor_parallel_degree=1
# Maximum number of concurrent requests in a rolling batch
option.max_rolling_batch_size=32
option.rolling_batch=lmi-dist
# Stream results back as newline-delimited JSON
option.output_formatter=jsonlines
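The packaging step below also bundles a model.py, whose contents the original notebook does not show. With rolling_batch=lmi-dist you can rely entirely on the container's built-in handler and omit model.py; if you do supply one, DJL Serving uses it as the entry point and expects it to expose a handle() function. The file below is only a minimal placeholder sketch of that convention, not the notebook's actual custom logic:
%%writefile model.py
# Placeholder sketch of a DJL Serving entry point. Replace the body with your
# own inference / stop-reason logic, or delete this file entirely to fall back
# to the container's built-in lmi-dist handler.
from djl_python import Input, Output

def handle(inputs: Input) -> Output:
    # DJL Serving probes the handler with an empty request at startup.
    if inputs.is_empty():
        return None
    data = inputs.get_as_json()
    # Echo the request back; real logic would run generation here.
    return Output().add_as_json({"received": data})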
%%sh
# -p so re-running the cell does not fail if the directory already exists
mkdir -p mymodel
mv serving.properties mymodel/
mv model.py mymodel/
tar czvf mymodel.tar.gz mymodel/
rm -rf mymodel
Step 3: Start building SageMaker endpoint¶
In this step, we will build a SageMaker endpoint from scratch.
Getting the container image URI¶
image_uri = image_uris.retrieve(
framework="djl-deepspeed",
region=sess.boto_session.region_name,
version="0.27.0"
)
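To confirm which container was resolved (the exact repository and tag vary by region and version), print the URI:
print(image_uri)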
Upload the artifact to S3 and create a SageMaker model¶
s3_code_prefix = "large-model-lmi/code"
bucket = sess.default_bucket() # bucket to house artifacts
code_artifact = sess.upload_data("mymodel.tar.gz", bucket, s3_code_prefix)
print(f"S3 code or model tarball uploaded to: {code_artifact}")
model = Model(image_uri=image_uri, model_data=code_artifact, role=role)
Step 4: Create the SageMaker endpoint¶
You need to specify the instance type to use and the endpoint name.
instance_type = "ml.g5.2xlarge"
endpoint_name = sagemaker.utils.name_from_base("lmi-model")
model.deploy(initial_instance_count=1,
instance_type=instance_type,
endpoint_name=endpoint_name,
)
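Deployment blocks until the endpoint is InService, which typically takes several minutes. Before moving on to streaming, you can sanity-check the endpoint with a plain (non-streaming) invocation; note that with output_formatter=jsonlines the body comes back as newline-delimited JSON:
import json
import boto3

sm_client = boto3.client("sagemaker-runtime")
res = sm_client.invoke_endpoint(
    EndpointName=endpoint_name,
    Body=json.dumps({"inputs": "Hello, my name is", "parameters": {"max_new_tokens": 16}}),
    ContentType="application/json",
)
print(res["Body"].read().decode("utf-8"))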
Step 5: Test and benchmark the inference¶
The endpoint streams back newline-delimited JSON, but each PayloadPart event may split a line mid-token. The helper class below buffers the incoming bytes and yields one complete line at a time.
import io

class LineIterator:
    """Iterates over complete lines in a SageMaker response stream."""

    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            # Return the next complete line already in the buffer, if any.
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord('\n'):
                self.read_pos += len(line)
                return line[:-1]
            # Otherwise pull the next event from the stream.
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if self.read_pos < self.buffer.getbuffer().nbytes:
                    continue
                raise
            if 'PayloadPart' not in chunk:
                print(f'Unknown event type: {chunk}')
                continue
            # Append the new bytes to the end of the buffer.
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk['PayloadPart']['Bytes'])
import json

sm_client = boto3.client("sagemaker-runtime")
body = {"inputs": "what is life", "parameters": {"max_new_tokens": 25, "do_sample": True, "details": True}}
resp = sm_client.invoke_endpoint_with_response_stream(
    EndpointName=endpoint_name, Body=json.dumps(body), ContentType="application/json"
)
event_stream = resp['Body']
for line in LineIterator(event_stream):
    data = json.loads(line)
    print(data)
Let's make the model continue generating based on the finish reason. When the stream ends with finish_reason == "length", the model simply ran out of max_new_tokens, so we feed the accumulated text back in as the new prompt and ask for more until generation stops on its own.
def inference(payload):
    resp = sm_client.invoke_endpoint_with_response_stream(
        EndpointName=endpoint_name, Body=json.dumps(payload), ContentType="application/json"
    )
    event_stream = resp['Body']
    text_output = []
    finish_reason = None
    for line in LineIterator(event_stream):
        data = json.loads(line)
        token = data['token']['text']
        text_output.append(token)
        print(token, end='')
        # `details` is only populated on the final line of the stream.
        if data.get('details'):
            finish_reason = data['details']['finish_reason']
    return payload['inputs'] + ''.join(text_output), finish_reason, len(text_output)
payload = {"inputs": "The new movie that got Oscar this year", "parameters": {"max_new_tokens": 128, "do_sample": True, "top_p": 0.9,
           "temperature": 0.8, "repetition_penalty": 1.2, "details": True}}
finish_reason = "length"
print(f"Output: {payload['inputs']}", end='')
total_tokens = 0
total_query = 0
# Keep querying until the model stops for a reason other than hitting max_new_tokens.
while finish_reason == 'length':
    output_text, finish_reason, out_token_len = inference(payload)
    payload['inputs'] = output_text
    total_tokens += out_token_len
    total_query += 1
print(f"\ntotal tokens generated: {total_tokens}, total queries sent: {total_query}")
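For a rough throughput number, you can time a single call to the helper above. This is a minimal wall-clock sketch (network latency included), not a rigorous benchmark:
import time

bench_payload = {"inputs": "Explain streaming inference in one paragraph.",
                 "parameters": {"max_new_tokens": 128, "details": True}}
start = time.perf_counter()
_, _, n_tokens = inference(bench_payload)
elapsed = time.perf_counter() - start
print(f"\n{n_tokens} tokens in {elapsed:.2f}s ({n_tokens / elapsed:.1f} tokens/sec)")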
Clean up the environment¶
sess.delete_endpoint(endpoint_name)
sess.delete_endpoint_config(endpoint_name)
model.delete_model()