Commit 213f5df

support tp
Signed-off-by: Daniel Afrimi <[email protected]>
1 parent 12311e1

File tree: 1 file changed (+41 lines, -16 lines)

tensorrt_llm/_torch/modules/linear.py (41 additions, 16 deletions)
@@ -47,6 +47,12 @@ class TensorParallelMode(str, enum.Enum):
     def split_dim(cls, mode):
         return 1 if mode == cls.ROW else 0
 
+    # Helper used to shard the corresponding per-channel activation scales,
+    # which are split along the dimension orthogonal to the weights.
+    @classmethod
+    def flip(cls, mode):
+        return cls.ROW if mode == cls.COLUMN else cls.COLUMN
+
 
 def load_weight_shard(
     weight,
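For readers skimming the diff, the new flip() helper simply maps each tensor-parallel mode to the opposite one, so callers can reuse load_weight_shard for tensors that follow the activation layout instead of the weight layout. Below is a minimal self-contained sketch of the idea; it is a simplified re-creation, not the exact class from linear.py, and the enum member values are assumptions.

import enum

class TensorParallelMode(str, enum.Enum):
    # Member values are assumptions for this sketch; only the names matter here.
    COLUMN = 'column'   # weight split along dim 0 (out_features)
    ROW = 'row'         # weight split along dim 1 (in_features)

    @classmethod
    def split_dim(cls, mode):
        return 1 if mode == cls.ROW else 0

    @classmethod
    def flip(cls, mode):
        # The opposite mode selects the split dimension orthogonal to the weight split.
        return cls.ROW if mode == cls.COLUMN else cls.COLUMN

# A ROW-parallel weight splits along dim 1; a per-input-channel activation
# scale (a 1-D tensor) splits along dim 0, which the flipped mode selects.
assert TensorParallelMode.split_dim(TensorParallelMode.ROW) == 1
assert TensorParallelMode.split_dim(TensorParallelMode.flip(TensorParallelMode.ROW)) == 0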
@@ -954,9 +960,16 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
         load_weights_vanilla_helper(module, weights)
 
         device = torch.device('cuda')
-        pre_quant_scale = load_weight_shard(weights[0]['pre_quant_scale'],
-                                            module.tp_size, module.tp_rank,
-                                            module.tp_mode, device)
+
+        pre_quant_scale = load_weight_shard(
+            weights[0]["pre_quant_scale"],
+            module.tp_size,
+            module.tp_rank,
+            # pre_quant_scale applies to activation as opposed to weight, so flip tp_mode the other way around
+            TensorParallelMode.flip(module.tp_mode),
+            device,
+        )
+
         module.pre_quant_scale = Parameter(
             torch.ones((module.in_features, ), dtype=pre_quant_scale.dtype),
             requires_grad=False).to(device=device)
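To make the flipped-mode call above concrete, here is a hedged sketch of the sharding arithmetic; the real load_weight_shard in linear.py handles more cases, and shard() below is a hypothetical stand-in. For a ROW-parallel Linear the weight of shape (out_features, in_features) is split along dim 1, so each rank sees only a slice of the input channels, and the per-channel activation scale of shape (in_features,) has to be split along its dim 0, which is exactly the split dim the flipped (COLUMN) mode selects.

import torch

def shard(tensor, tp_size, tp_rank, split_dim):
    # Hypothetical helper: take tp_rank's contiguous slice along split_dim (assumes an even split).
    width = tensor.shape[split_dim] // tp_size
    return tensor.narrow(split_dim, tp_rank * width, width)

in_features, out_features, tp_size = 8, 4, 2
weight = torch.randn(out_features, in_features)
pre_quant_scale = torch.rand(in_features)

for tp_rank in range(tp_size):
    w_shard = shard(weight, tp_size, tp_rank, split_dim=1)           # ROW mode: split_dim == 1
    s_shard = shard(pre_quant_scale, tp_size, tp_rank, split_dim=0)  # flipped mode: split_dim == 0
    # Each rank's activation-scale shard lines up with the input columns of its weight shard.
    assert s_shard.shape[0] == w_shard.shape[1]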
@@ -1128,9 +1141,14 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]):
         load_weights_vanilla_helper(module, weights)
 
         device = torch.device('cuda')
-        pre_quant_scale = load_weight_shard(weights[0]['pre_quant_scale'],
-                                            module.tp_size, module.tp_rank,
-                                            module.tp_mode, device)
+        pre_quant_scale = load_weight_shard(
+            weights[0]["pre_quant_scale"],
+            module.tp_size,
+            module.tp_rank,
+            # pre_quant_scale applies to activation as opposed to weight, so flip tp_mode the other way around
+            TensorParallelMode.flip(module.tp_mode),
+            device,
+        )
 
         assert pre_quant_scale.dtype == module.dtype
 
@@ -1185,11 +1203,15 @@ def load_weights_fused_qkv_linear(self, module: Linear,
         # NOTE: pre_quant_scale is the same for q, k, v since modelopt checks which layers share the same input and creates an averaged pre_quant_scale
         # Usually when modelopt exports the quantized model, pre_quant_scale is fused into the layer norm (this case is relevant if fusion is disabled - modelopt internal)
         if "pre_quant_scale" in weights[0].keys():
-            pre_quant_scale = load_weight_shard(weights[0]['pre_quant_scale'],
-                                                module.tp_size,
-                                                module.tp_rank,
-                                                module.tp_mode,
-                                                device=torch.device('cuda'))
+
+            pre_quant_scale = load_weight_shard(
+                weights[0]["pre_quant_scale"],
+                module.tp_size,
+                module.tp_rank,
+                # pre_quant_scale applies to activation as opposed to weight, so flip tp_mode the other way around
+                TensorParallelMode.flip(module.tp_mode),
+                torch.device('cuda'),
+            )
 
             module.pre_quant_scale = Parameter(
                 torch.ones((module.in_features, ), dtype=pre_quant_scale.dtype),
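The NOTE above says q, k and v end up with an identical pre_quant_scale because the three projections consume the same input activation. The snippet below is a rough, assumption-laden illustration of that export-time behavior, not modelopt's actual code (average_shared_input_scale is a hypothetical name): layers sharing an input get one per-input-channel scale, e.g. an average of the individually calibrated ones, so the fused QKV linear only needs to load it once.

import torch

def average_shared_input_scale(scales):
    # scales: per-projection activation scales, each of shape (in_features,)
    return torch.stack(scales).mean(dim=0)

in_features = 16
q_scale, k_scale, v_scale = (torch.rand(in_features) for _ in range(3))
shared_scale = average_shared_input_scale([q_scale, k_scale, v_scale])
assert shared_scale.shape == (in_features,)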
@@ -1223,11 +1245,14 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
         copy_weight(module.alpha, alpha)
 
         if "pre_quant_scale" in weights[0].keys():
-            pre_quant_scale = load_weight_shard(weights[0]['pre_quant_scale'],
-                                                module.tp_size,
-                                                module.tp_rank,
-                                                module.tp_mode,
-                                                device=torch.device('cuda'))
+            pre_quant_scale = load_weight_shard(
+                weights[0]["pre_quant_scale"],
+                module.tp_size,
+                module.tp_rank,
+                # pre_quant_scale applies to activation as opposed to weight, so flip tp_mode the other way around
+                TensorParallelMode.flip(module.tp_mode),
+                torch.device('cuda'),
+            )
 
             # NOTE: Create this tensor in load_weights, since not all layers have this tensor and memory is not allocated for it (same as W4A16)
             module.pre_quant_scale = Parameter(
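The NOTE in the hunk above explains the allocation strategy: pre_quant_scale is only materialized inside load_weights when the checkpoint actually ships the tensor, so layers without it pay no memory cost (same as W4A16). Here is a hedged sketch of that pattern under stated assumptions; maybe_load_pre_quant_scale and shard_fn are hypothetical names, not the exact code in linear.py.

import torch
from torch.nn import Parameter

def maybe_load_pre_quant_scale(module, weights, shard_fn):
    # shard_fn stands in for the load_weight_shard call with the flipped tp_mode.
    if "pre_quant_scale" not in weights[0]:
        return  # layer was exported without the tensor; allocate nothing
    scale = shard_fn(weights[0]["pre_quant_scale"])
    module.pre_quant_scale = Parameter(scale.to(device='cuda'),
                                       requires_grad=False)

# Usage: a layer whose checkpoint lacks the key allocates nothing.
class _Layer:  # hypothetical stand-in for Linear
    pass

layer = _Layer()
maybe_load_pre_quant_scale(layer, [{}], shard_fn=lambda t: t)
assert not hasattr(layer, "pre_quant_scale")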
