@@ -47,6 +47,12 @@ class TensorParallelMode(str, enum.Enum):
     def split_dim(cls, mode):
         return 1 if mode == cls.ROW else 0
 
+    # Helper to shard the corresponding per-channel activation scales,
+    # which are sharded along the dimension orthogonal to the weights.
+    @classmethod
+    def flip(cls, mode):
+        return cls.ROW if mode == cls.COLUMN else cls.COLUMN
+
 
 def load_weight_shard(
         weight,
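
For context, here is a minimal standalone sketch of why the activation scale is loaded with the flipped mode. It is not the library's actual load_weight_shard; the shard_pre_quant_scale helper and its signature are illustrative assumptions. The idea: a 1-D pre_quant_scale indexes the activation channels (in_features), i.e. the weight's dim 1, so it must be split only when the weight itself is ROW-parallel.

# Hypothetical sketch, not the actual TensorRT-LLM loading code.
import enum

import torch


class TensorParallelMode(str, enum.Enum):
    COLUMN = 'column'  # weight [out_features, in_features] sharded along dim 0
    ROW = 'row'  # weight sharded along dim 1 (in_features)

    @classmethod
    def split_dim(cls, mode):
        return 1 if mode == cls.ROW else 0

    @classmethod
    def flip(cls, mode):
        return cls.ROW if mode == cls.COLUMN else cls.COLUMN


def shard_pre_quant_scale(scale: torch.Tensor, tp_size: int, tp_rank: int,
                          weight_tp_mode: TensorParallelMode) -> torch.Tensor:
    """Shard a 1-D per-channel activation scale to match the weight shard."""
    # flip() maps ROW -> COLUMN (split_dim 0) and COLUMN -> ROW (split_dim 1).
    # The 1-D scale only has dim 0, so it is split exactly when the weight's
    # in_features dim is sharded, i.e. when the Linear is ROW-parallel.
    if TensorParallelMode.split_dim(TensorParallelMode.flip(weight_tp_mode)) != 0:
        return scale  # COLUMN-parallel: in_features stays whole, so does the scale
    return torch.chunk(scale, tp_size, dim=0)[tp_rank]
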
@@ -954,9 +960,16 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
         load_weights_vanilla_helper(module, weights)
 
         device = torch.device('cuda')
-        pre_quant_scale = load_weight_shard(weights[0]['pre_quant_scale'],
-                                            module.tp_size, module.tp_rank,
-                                            module.tp_mode, device)
+
+        pre_quant_scale = load_weight_shard(
+            weights[0]["pre_quant_scale"],
+            module.tp_size,
+            module.tp_rank,
+            # pre_quant_scale applies to the activation, not the weight, so flip tp_mode
+            TensorParallelMode.flip(module.tp_mode),
+            device,
+        )
+
         module.pre_quant_scale = Parameter(
             torch.ones((module.in_features, ), dtype=pre_quant_scale.dtype),
             requires_grad=False).to(device=device)
@@ -1128,9 +1141,14 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]):
         load_weights_vanilla_helper(module, weights)
 
         device = torch.device('cuda')
-        pre_quant_scale = load_weight_shard(weights[0]['pre_quant_scale'],
-                                            module.tp_size, module.tp_rank,
-                                            module.tp_mode, device)
+        pre_quant_scale = load_weight_shard(
+            weights[0]["pre_quant_scale"],
+            module.tp_size,
+            module.tp_rank,
+            # pre_quant_scale applies to the activation, not the weight, so flip tp_mode
+            TensorParallelMode.flip(module.tp_mode),
+            device,
+        )
 
         assert pre_quant_scale.dtype == module.dtype
 
@@ -1185,11 +1203,15 @@ def load_weights_fused_qkv_linear(self, module: Linear,
         # NOTE: pre_quant_scale is the same for q, k, v since modelopt checks which layers share the same input and creates an averaged pre_quant_scale
         # Usually when modelopt exports the quantized model, pre_quant_scale is fused into the layer norm (this case is relevant when fusing is disabled - modelopt internal)
         if "pre_quant_scale" in weights[0].keys():
-            pre_quant_scale = load_weight_shard(weights[0]['pre_quant_scale'],
-                                                module.tp_size,
-                                                module.tp_rank,
-                                                module.tp_mode,
-                                                device=torch.device('cuda'))
+
+            pre_quant_scale = load_weight_shard(
+                weights[0]["pre_quant_scale"],
+                module.tp_size,
+                module.tp_rank,
+                # pre_quant_scale applies to the activation, not the weight, so flip tp_mode
+                TensorParallelMode.flip(module.tp_mode),
+                torch.device('cuda'),
+            )
 
             module.pre_quant_scale = Parameter(
                 torch.ones((module.in_features, ), dtype=pre_quant_scale.dtype),
@@ -1223,11 +1245,14 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
         copy_weight(module.alpha, alpha)
 
         if "pre_quant_scale" in weights[0].keys():
-            pre_quant_scale = load_weight_shard(weights[0]['pre_quant_scale'],
-                                                module.tp_size,
-                                                module.tp_rank,
-                                                module.tp_mode,
-                                                device=torch.device('cuda'))
+            pre_quant_scale = load_weight_shard(
+                weights[0]["pre_quant_scale"],
+                module.tp_size,
+                module.tp_rank,
+                # pre_quant_scale applies to the activation, not the weight, so flip tp_mode
+                TensorParallelMode.flip(module.tp_mode),
+                torch.device('cuda'),
+            )
 
             # NOTE: Create this tensor in load_weights, since not all layers have this tensor and memory is not allocated for it (same as W4A16)
             module.pre_quant_scale = Parameter(
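
A quick usage check of the hypothetical helper sketched after the first hunk above (the values are purely illustrative): for a ROW-parallel Linear the activation scale is split per rank, while for a COLUMN-parallel one it stays whole, matching the flipped-mode sharding applied in each load path.

scale = torch.arange(8, dtype=torch.float32)  # pretend in_features = 8
row_shard = shard_pre_quant_scale(scale, tp_size=2, tp_rank=1,
                                  weight_tp_mode=TensorParallelMode.ROW)
col_shard = shard_pre_quant_scale(scale, tp_size=2, tp_rank=1,
                                  weight_tp_mode=TensorParallelMode.COLUMN)
assert row_shard.shape[0] == 4  # in_features is sharded, so the scale is too
assert col_shard.shape[0] == 8  # out_features is sharded; the scale stays whole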