This chapter demonstrates the use of TorchServe with Amazon Inferentia hardware and the Neuron SDK. By the end of this tutorial, you will understand how TorchServe can be used to serve a model backed by EC2 Inf1 instances. We will use a pretrained BERT-Base model to determine if one sentence is a paraphrase of another.
Create a new directory named torchserve and copy your compiled bert_neuron_b6.pt file from Chapter 5.5 into it.
cd torchserve
ls
bert_neuron_b6.pt
Install the system requirements for TorchServe.
sudo apt install openjdk-11-jdk
pip install transformers==4.2.0 torchserve==0.3.0 torch-model-archiver==0.3.0
java -version
javac -version
torchserve --version
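Optionally, run a quick sanity check (assuming the Neuron-enabled PyTorch environment from the compilation chapter is still active) to confirm that the Python packages import cleanly:
python -c "import torch, torch_neuron, transformers; print(torch.__version__, transformers.__version__)"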
Create config.json with the following content:
{
  "model_name": "bert-base-cased-finetuned-mrpc",
  "max_length": 128,
  "batch_size": 6
}
Create a Python script named handler_bert.py with the following content:
import os
import json
import sys
import logging
import torch, torch_neuron
from transformers import AutoTokenizer
from abc import ABC
from ts.torch_handler.base_handler import BaseHandler

# one core per worker
os.environ['NEURONCORE_GROUP_SIZES'] = '1'

logger = logging.getLogger(__name__)


class BertEmbeddingHandler(BaseHandler, ABC):
    """
    Handler class for Bert Embedding computations.
    """
    def __init__(self):
        super(BertEmbeddingHandler, self).__init__()
        self.initialized = False

    def initialize(self, ctx):
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        self.device = 'cpu'
        model_dir = properties.get('model_dir')
        serialized_file = self.manifest['model']['serializedFile']
        model_pt_path = os.path.join(model_dir, serialized_file)

        # read the model configuration packaged with the archive via --extra-files
        with open('config.json') as fp:
            config = json.load(fp)
        self.max_length = config['max_length']
        self.batch_size = config['batch_size']
        self.classes = ['not paraphrase', 'paraphrase']

        self.model = torch.jit.load(model_pt_path)
        logger.debug(f'Model loaded from {model_dir}')
        self.model.to(self.device)
        self.model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(config['model_name'])
        self.initialized = True

    def preprocess(self, input_data):
        """
        Tokenization pre-processing
        """
        input_ids = []
        attention_masks = []
        token_type_ids = []
        for row in input_data:
            seq_0 = row['seq_0'].decode('utf-8')
            seq_1 = row['seq_1'].decode('utf-8')
            logger.debug(f'Received text: "{seq_0}", "{seq_1}"')

            inputs = self.tokenizer.encode_plus(
                seq_0,
                seq_1,
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            input_ids.append(inputs['input_ids'])
            attention_masks.append(inputs['attention_mask'])
            token_type_ids.append(inputs['token_type_ids'])

        batch = (torch.cat(input_ids, 0),
                 torch.cat(attention_masks, 0),
                 torch.cat(token_type_ids, 0))
        return batch

    def inference(self, inputs):
        """
        Predict the class of a text using a trained transformer model.
        """
        # sanity check dimensions
        assert(len(inputs) == 3)
        num_inferences = len(inputs[0])
        assert(num_inferences <= self.batch_size)

        # insert padding if we received a partial batch
        padding = self.batch_size - num_inferences
        if padding > 0:
            pad = torch.nn.ConstantPad1d((0, 0, 0, padding), value=0)
            inputs = [pad(x) for x in inputs]
        outputs = self.model(*inputs)[0]

        # only return predictions for the real (non-padded) inputs
        predictions = []
        for i in range(num_inferences):
            prediction = self.classes[outputs[i].argmax().item()]
            predictions.append([prediction])
            logger.debug("Model predicted: '%s'", prediction)
        return predictions

    def postprocess(self, inference_output):
        return inference_output
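Before archiving the handler, you can optionally smoke-test it outside of TorchServe. The sketch below is hypothetical (the SimpleNamespace context fakes only the two fields our initialize() method reads) and assumes you run it on the Inf1 instance, from the torchserve directory containing bert_neuron_b6.pt and config.json:
# local_test.py -- hypothetical standalone check of handler_bert.py
from types import SimpleNamespace
from handler_bert import BertEmbeddingHandler

# fake context exposing only the fields initialize() actually reads
ctx = SimpleNamespace(
    manifest={'model': {'serializedFile': 'bert_neuron_b6.pt'}},
    system_properties={'model_dir': '.'},
)

handler = BertEmbeddingHandler()
handler.initialize(ctx)

# TorchServe delivers request fields as bytes, so we encode the strings here
batch = handler.preprocess([
    {'seq_0': b'The company HuggingFace is based in New York City',
     'seq_1': b"HuggingFace's headquarters are situated in Manhattan"},
])
print(handler.postprocess(handler.inference(batch)))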
Next, we need to associate the handler script with the compiled model using torch-model-archiver. Run the following commands in your terminal (they use jq to read values from config.json):
mkdir model_store
MAX_LENGTH=$(jq '.max_length' config.json)
BATCH_SIZE=$(jq '.batch_size' config.json)
MODEL_NAME=bert-max_length$MAX_LENGTH-batch_size$BATCH_SIZE
torch-model-archiver --model-name "$MODEL_NAME" --version 1.0 --serialized-file ./bert_neuron_b6.pt --handler "./handler_bert.py" --extra-files "./config.json" --export-path model_store
If you modify your model or a dependency, you will need to rerun the archiver command with the -f flag appended to update the archive.
The result of the above will be a .mar file inside the model_store directory.
ls model_store
bert-max_length128-batch_size6.mar
This file is an archive that bundles a fixed version of your model together with its dependencies (e.g. the handler code). The version specified in the torch-model-archiver command can be appended to REST API requests to access a specific version of your model. For example, if your model was hosted locally on port 8080 and named "bert", the latest version would be available at http://localhost:8080/predictions/bert, while version 1.0 would be accessible at http://localhost:8080/predictions/bert/1.0. We will see how to perform inference using this API in Step 4.
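As an illustration (a hypothetical sketch reusing the "bert" name above; our model will actually be registered later under the name produced by the archiver command), the two endpoints would be queried like this:
# hypothetical requests against a model registered as "bert"
curl -X POST http://localhost:8080/predictions/bert --data-urlencode 'seq_0=The sky is blue' --data-urlencode 'seq_1=The sky appears blue'        # latest version
curl -X POST http://localhost:8080/predictions/bert/1.0 --data-urlencode 'seq_0=The sky is blue' --data-urlencode 'seq_1=The sky appears blue'    # version 1.0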
Create a custom config file named torchserve.config with the following content to set some parameters. This file will be used to configure the server at launch when we run torchserve --start.
# bind the inference API to all network interfaces (plain HTTP, no SSL)
inference_address=http://0.0.0.0:8080
default_workers_per_model=1
This will cause TorchServe to bind on all interfaces. For security in real-world applications, you’ll probably want to use port 8443 and enable SSL.
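For reference, here is a hedged sketch of what an SSL-enabled configuration could look like (the key and certificate paths are placeholders; consult the TorchServe configuration documentation for the exact options supported by your version):
# hypothetical SSL-enabled configuration (paths are placeholders)
inference_address=https://0.0.0.0:8443
private_key_file=/path/to/server.key
certificate_file=/path/to/server.pem
default_workers_per_model=1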
It’s time to start the server. Typically we’d want to launch this in a separate console, but for this demo we’ll just redirect output to a file.
torchserve --start --ncs --model-store model_store --ts-config torchserve.config > torchserve.log 2>&1
Verify that the server seems to have started okay.
curl http://127.0.0.1:8080/ping
{
"status": "Healthy"
}
If you get an error when trying to ping the server, you may have tried before the server was fully launched. Check torchserve.log for details.
Use the Management API to instruct TorchServe to load our model.
MAX_BATCH_DELAY=5000 # ms timeout before a partial batch is processed
INITIAL_WORKERS=4 # number of models that will be loaded at launch
curl -X POST "http://localhost:8081/models?url=$MODEL_NAME.mar&batch_size=$BATCH_SIZE&initial_workers=$INITIAL_WORKERS&max_batch_delay=$MAX_BATCH_DELAY"
{
"status": "Model \"bert-max_length128-batch_size6\" Version: 1.0 registered with 4 initial workers"
}
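To confirm that the workers came up, you can also describe the model through the Management API (a hedged sketch; the exact fields in the response vary by TorchServe version):
curl "http://localhost:8081/models/$MODEL_NAME"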
Any additional attempts to configure the model after the initial curl request will cause the server to return a 409 error. To apply any changes, you'll need to stop the server, start it again, and re-register the model (a sketch follows below).
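If you do need a different configuration, one option is to restart and re-register using the commands from this chapter (the Management API also offers a DELETE /models endpoint to unregister a model without a full restart):
torchserve --stop
torchserve --start --ncs --model-store model_store --ts-config torchserve.config > torchserve.log 2>&1
curl -X POST "http://localhost:8081/models?url=$MODEL_NAME.mar&batch_size=$BATCH_SIZE&initial_workers=$INITIAL_WORKERS&max_batch_delay=$MAX_BATCH_DELAY"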
The MAX_BATCH_DELAY is a timeout value that determines how long to wait before processing a partial batch. This is why the handler code needs to check the batch dimension and potentially add padding.
TorchServe will instantiate the number of model handlers indicated by INITIAL_WORKERS, so this value controls how many models we will load onto Inferentia in parallel.
This tutorial was performed on an inf1.2xlarge instance (one Inferentia chip), so there are four NeuronCores available.
If you want to control worker scaling more dynamically, see the docs (a brief sketch of the scale-workers call follows below).
If you attempt to load more models than there are NeuronCores available, one of two things will occur: either the extra models will fit in device memory but performance will suffer, or you will encounter an error on your initial inference. You shouldn't set INITIAL_WORKERS above the number of NeuronCores. However, you may want to use fewer cores if you are using the NeuronCore Pipeline feature.
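As an illustration of dynamic scaling (a hedged sketch based on the TorchServe Management API; check the documentation for your version for the exact parameters), you could later reduce the worker count for the registered model:
# hypothetical example: scale the registered model down to 2 workers
curl -X PUT "http://localhost:8081/models/$MODEL_NAME?min_worker=2&synchronous=true"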
It looks like everything is running successfully at this point, so it’s time for an inference.
Create the infer_bert.py file below on your instance.
import json
import concurrent.futures
import requests

with open('config.json') as fp:
    config = json.load(fp)
max_length = config['max_length']
batch_size = config['batch_size']
name = f'bert-max_length{max_length}-batch_size{batch_size}'

# dispatch requests in parallel
url = f'http://localhost:8080/predictions/{name}'
paraphrase = {'seq_0': "HuggingFace's headquarters are situated in Manhattan",
              'seq_1': "The company HuggingFace is based in New York City"}
not_paraphrase = {'seq_0': paraphrase['seq_0'], 'seq_1': 'This is total nonsense.'}

with concurrent.futures.ThreadPoolExecutor(max_workers=batch_size) as executor:
    def worker_thread(worker_index):
        # we'll send half the requests as not_paraphrase examples for sanity
        data = paraphrase if worker_index < batch_size//2 else not_paraphrase
        response = requests.post(url, data=data)
        print(worker_index, response.json())

    for worker_index in range(batch_size):
        executor.submit(worker_thread, worker_index)
This script will send batch_size requests to our model.
In this example, we are using a model that estimates the probability that one sentence is a paraphrase of another.
The script sends positive examples in the first half of the batch and negative examples in the second half.
Execute the script in your terminal.
python infer_bert.py
1 ['paraphrase']
3 ['not paraphrase']
4 ['not paraphrase']
0 ['paraphrase']
5 ['not paraphrase']
2 ['paraphrase']
We can see that the first three threads (0, 1, and 2) all report paraphrase, as expected. If we instead modify the script to send an incomplete batch and then wait for the timeout to expire, the excess padding results will be discarded (a sketch of such a test follows below).
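To try this yourself, you could run a hypothetical variation of infer_bert.py that submits fewer requests than batch_size; once MAX_BATCH_DELAY expires, the handler pads the missing slots and only the real requests receive predictions:
# partial_batch.py -- hypothetical partial-batch test (fewer requests than batch_size)
import json
import concurrent.futures
import requests

with open('config.json') as fp:
    config = json.load(fp)
batch_size = config['batch_size']
name = f"bert-max_length{config['max_length']}-batch_size{batch_size}"
url = f'http://localhost:8080/predictions/{name}'

data = {'seq_0': "HuggingFace's headquarters are situated in Manhattan",
        'seq_1': 'The company HuggingFace is based in New York City'}

def worker(idx):
    print(idx, requests.post(url, data=data).json())

# send only two requests; the server waits MAX_BATCH_DELAY, the handler pads the
# batch up to batch_size, and the padded (dummy) results are discarded
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
    for i in range(2):
        executor.submit(worker, i)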
Congratulations! By now you should have successfully served a batched model over TorchServe.