Commit de7da4a
Hugging Face Transformer Deployment Tutorial (#49)
* Initial Commit
* Mount model repo so changes reflect, parameter tweaking, README file
* Image name error
* Incorporating review comments. Separate docker and model repo builds, add README, restructure repo
* Tutorial restructuring. Using static model configurations
* Bump triton container and update README
* Remove client script
* Incorporating review comments
* Modify WIP line in vLLM tutorial
* Remove trust_remote_code parameter from falcon model
* Removing Mistral
* Incorporating Feedback
* Change input/output names
* Pre-commit format
* Different perf_analyzer example, config file format fixes
* Deep dive changes to Triton tools section
* Remove unused variable
1 parent af67595 commit de7da4a

File tree

7 files changed: +667 −2 lines changed
Dockerfile

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

FROM nvcr.io/nvidia/tritonserver:23.09-py3
RUN pip install transformers==4.34.0 protobuf==3.20.3 sentencepiece==0.1.99 accelerate==0.23.0 einops==0.6.1
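
A quick way to validate the image above is to exercise the pinned libraries directly. The following is a hedged sketch, not part of the commit: it assumes the snippet runs inside the built container with a GPU visible, and it uses distilgpt2 purely as an arbitrary small stand-in model.

# Sanity-check sketch for the image above (assumptions noted in the
# lead-in; distilgpt2 is a stand-in, not a model used by the tutorial).
import torch
import transformers

print(transformers.__version__)   # expect 4.34.0, per the Dockerfile pin
print(torch.cuda.is_available())  # expect True in the Triton GPU container

# device_map="auto" relies on the accelerate package pinned above.
pipe = transformers.pipeline("text-generation", model="distilgpt2", device_map="auto")
print(pipe("Triton serves", max_new_tokens=8)[0]["generated_text"])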

Quick_Deploy/HuggingFaceTransformers/README.md

Lines changed: 355 additions & 0 deletions
Large diffs are not rendered by default.
falcon7b model.py

Lines changed: 109 additions & 0 deletions

@@ -0,0 +1,109 @@
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# (Same BSD license header as the Dockerfile above.)

import os

# Set the cache location before transformers is imported.
os.environ[
    "TRANSFORMERS_CACHE"
] = "/opt/tritonserver/model_repository/falcon7b/hf_cache"

import json

import numpy as np
import torch
import transformers
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def initialize(self, args):
        self.logger = pb_utils.Logger
        self.model_config = json.loads(args["model_config"])
        self.model_params = self.model_config.get("parameters", {})
        default_hf_model = "tiiuae/falcon-7b"
        default_max_gen_length = "15"
        # Check for a user-specified model name in the model config parameters
        hf_model = self.model_params.get("huggingface_model", {}).get(
            "string_value", default_hf_model
        )
        # Check for a user-specified max length in the model config parameters
        self.max_output_length = int(
            self.model_params.get("max_output_length", {}).get(
                "string_value", default_max_gen_length
            )
        )

        self.logger.log_info(f"Max output length: {self.max_output_length}")
        self.logger.log_info(f"Loading HuggingFace model: {hf_model}...")
        # Assume a tokenizer is available for the same model
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(hf_model)
        self.pipeline = transformers.pipeline(
            "text-generation",
            model=hf_model,
            torch_dtype=torch.float16,
            tokenizer=self.tokenizer,
            device_map="auto",
        )
        self.pipeline.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    def execute(self, requests):
        prompts = []
        for request in requests:
            input_tensor = pb_utils.get_input_tensor_by_name(request, "text_input")
            multi_dim = input_tensor.as_numpy().ndim > 1
            if not multi_dim:
                prompt = input_tensor.as_numpy()[0].decode("utf-8")
                self.logger.log_info(f"Generating sequences for text_input: {prompt}")
                prompts.append(prompt)
            else:
                # Accept dynamically batched inputs: one prompt per row
                num_prompts = input_tensor.as_numpy().shape[0]
                for prompt_index in range(0, num_prompts):
                    prompt = input_tensor.as_numpy()[prompt_index][0].decode("utf-8")
                    prompts.append(prompt)

        batch_size = len(prompts)
        return self.generate(prompts, batch_size)

    def generate(self, prompts, batch_size):
        sequences = self.pipeline(
            prompts,
            max_length=self.max_output_length,
            pad_token_id=self.tokenizer.eos_token_id,
            batch_size=batch_size,
        )
        responses = []
        # Build one response per prompt, each carrying only its own text.
        # Triton requires exactly one response per request, so this assumes
        # each batched request contributed a single prompt.
        for seq in sequences:
            text = seq[0]["generated_text"]
            tensor = pb_utils.Tensor(
                "text_output", np.array([text], dtype=np.object_)
            )
            responses.append(pb_utils.InferenceResponse(output_tensors=[tensor]))

        return responses

    def finalize(self):
        print("Cleaning up...")
falcon7b config.pbtxt

Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
# Triton backend to use
backend: "python"

# Hugging Face model path. Parameters must follow this
# key/value structure
parameters: {
  key: "huggingface_model",
  value: {string_value: "tiiuae/falcon-7b"}
}

# The maximum number of tokens to generate in response
# to our input
parameters: {
  key: "max_output_length",
  value: {string_value: "15"}
}

# Triton should expect as input a single string of set
# length named 'text_input'
input [
  {
    name: "text_input"
    data_type: TYPE_STRING
    dims: [ 1 ]
  }
]

# Triton should expect to respond with a single string
# output of variable length named 'text_output'
output [
  {
    name: "text_output"
    data_type: TYPE_STRING
    dims: [ -1 ]
  }
]
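
With the server running, the text_input/text_output names above map directly onto a client request. Below is a hedged sketch using tritonclient; it assumes the server is reachable at localhost:8000 and that tritonclient[http] is installed, neither of which is part of this commit.

# Client sketch (not part of this commit). TYPE_STRING tensors travel as
# "BYTES" on the wire, so the prompt is sent as a 1-element object array.
import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")

prompt = np.array(["How do I deploy a model on Triton?"], dtype=np.object_)
text_input = httpclient.InferInput("text_input", [1], "BYTES")
text_input.set_data_from_numpy(prompt)

result = client.infer(model_name="falcon7b", inputs=[text_input])
print(result.as_numpy("text_output"))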
persimmon8b model.py

Lines changed: 103 additions & 0 deletions

@@ -0,0 +1,103 @@
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# (Same BSD license header as the Dockerfile above.)

import os

# Set the cache location before transformers is imported.
os.environ[
    "TRANSFORMERS_CACHE"
] = "/opt/tritonserver/model_repository/persimmon8b/hf_cache"

import json

import numpy as np
import torch
import transformers
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def initialize(self, args):
        self.logger = pb_utils.Logger
        self.model_config = json.loads(args["model_config"])
        self.model_params = self.model_config.get("parameters", {})
        default_hf_model = "adept/persimmon-8b-base"
        default_max_gen_length = "15"
        # Check for a user-specified model name in the model config parameters
        hf_model = self.model_params.get("huggingface_model", {}).get(
            "string_value", default_hf_model
        )
        # Check for a user-specified max length in the model config parameters
        self.max_output_length = int(
            self.model_params.get("max_output_length", {}).get(
                "string_value", default_max_gen_length
            )
        )

        self.logger.log_info(f"Max output length: {self.max_output_length}")
        self.logger.log_info(f"Loading HuggingFace model: {hf_model}...")
        # Assume a tokenizer is available for the same model
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(hf_model)
        self.pipeline = transformers.pipeline(
            "text-generation",
            model=hf_model,
            torch_dtype=torch.float16,
            tokenizer=self.tokenizer,
            device_map="auto",
        )

    def execute(self, requests):
        responses = []
        for request in requests:
            # Read the input named "text_input", as declared in config.pbtxt
            input_tensor = pb_utils.get_input_tensor_by_name(request, "text_input")
            prompt = input_tensor.as_numpy()[0].decode("utf-8")

            self.logger.log_info(f"Generating sequences for text_input: {prompt}")
            response = self.generate(prompt)
            responses.append(response)

        return responses

    def generate(self, prompt):
        sequences = self.pipeline(
            prompt,
            max_length=self.max_output_length,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        output_tensors = []
        texts = []
        for i, seq in enumerate(sequences):
            text = seq["generated_text"]
            self.logger.log_info(f"Sequence {i + 1}: {text}")
            texts.append(text)

        tensor = pb_utils.Tensor("text_output", np.array(texts, dtype=np.object_))
        output_tensors.append(tensor)
        response = pb_utils.InferenceResponse(output_tensors=output_tensors)
        return response

    def finalize(self):
        print("Cleaning up...")
persimmon8b config.pbtxt

Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
# Triton backend to use
backend: "python"

# Hugging Face model path. Parameters must follow this
# key/value structure
parameters: {
  key: "huggingface_model",
  value: {string_value: "adept/persimmon-8b-base"}
}

# The maximum number of tokens to generate in response
# to our input
parameters: {
  key: "max_output_length",
  value: {string_value: "15"}
}

# Triton should expect as input a single string of set
# length named 'text_input'
input [
  {
    name: "text_input"
    data_type: TYPE_STRING
    dims: [ 1 ]
  }
]

# Triton should expect to respond with a single string
# output of variable length named 'text_output'
output [
  {
    name: "text_output"
    data_type: TYPE_STRING
    dims: [ -1 ]
  }
]
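
Once both model directories are in place, readiness can be checked before sending prompts. A quick sketch under the same assumptions as the client example above (local server, tritonclient[http] installed):

# Readiness-check sketch (not part of this commit).
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")
for model in ("falcon7b", "persimmon8b"):
    print(model, "ready:", client.is_model_ready(model))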

Quick_Deploy/vLLM/README.md

Lines changed: 1 addition & 2 deletions

@@ -34,8 +34,7 @@ The following tutorial demonstrates how to deploy a simple
 Triton Inference Server using Triton's [Python backend](https://github.com/triton-inference-server/python_backend) and the
 [vLLM](https://github.com/vllm-project/vllm) library.

-*NOTE*: The tutorial is intended to be a reference example only. It is a work in progress with
-[known limitations](#limitations).
+*NOTE*: The tutorial is intended to be a reference example only and has [known limitations](#limitations).

 ## Step 1: Build a Triton Container Image with vLLM