From c3efc6032623b04325ddde2d6762e3f1d2c5c304 Mon Sep 17 00:00:00 2001
From: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Date: Wed, 20 Mar 2024 20:44:05 +0000
Subject: [PATCH 1/4] TP-aware optimizations draft

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
---
 .../custom_modeling/flash_llama_modeling.py    |  35 ++++-
 .../utils/gptq/shuffle.py                      |  65 +++++++++
 server/text_generation_server/utils/layers.py  |  14 +-
 .../text_generation_server/utils/weights.py    | 136 +++++++++++++-----
 4 files changed, 210 insertions(+), 40 deletions(-)
 create mode 100644 server/text_generation_server/utils/gptq/shuffle.py

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index d56178ad0..22679d3ce 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -222,12 +222,22 @@ def __init__(
             weights=weights,
             bias=False,
         )
+
+        noshard_o_proj = False
+        if config.quantize == 'gptq':
+            from text_generation_server.utils.layers import IS_TP_AWARE_GPTQ
+            noshard_o_proj = IS_TP_AWARE_GPTQ
+
         self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
             bias=False,
+            noshard=noshard_o_proj,  # Don't shard o_proj weight matrix if TP-aware optimization is desired
         )
+        self.noshard_o_proj = noshard_o_proj
+        self.world_size = weights.process_group.size()
+        self.rank = weights.process_group.rank()
 
     def forward(
         self,
@@ -285,9 +295,19 @@ def forward(
                 1,
                 False,
             )
+        attn_output = attn_output.reshape(-1, self.num_heads * self.head_size)
 
-        return self.o_proj(attn_output.reshape(-1, self.num_heads * self.head_size))
+        # TP-aware Masked Matmul Optimization by zero filling the activation
+        # and multiply with full weight matrix in o_proj
+        if self.noshard_o_proj:
+            shard_size = attn_output.shape[1]
+            assert shard_size*self.world_size == self.o_proj.linear.height
+            zf_attn_output = torch.zeros((attn_output.shape[0], shard_size*self.world_size), dtype=attn_output.dtype, device=attn_output.device)
+            start_idx = self.rank * shard_size
+            zf_attn_output[:, start_idx:start_idx+shard_size] = attn_output
+            attn_output = zf_attn_output
+        return self.o_proj(attn_output)
 
 
 class LlamaMLP(nn.Module):
     def __init__(self, prefix, config, weights):
@@ -303,6 +323,17 @@ def __init__(self, prefix, config, weights):
                 else "none",
             )
         )
+
+        # For TP-aware preshuffle optimization, load the g_idx of down_proj for computing perm
+        # When perm==None the original unoptimized control path is taken
+        perm = None
+        if config.quantize=="gptq":
+            from text_generation_server.utils.layers import IS_TP_AWARE_GPTQ
+            if IS_TP_AWARE_GPTQ:
+                down_proj_g_idx = weights.get_tensor(f"{prefix}.down_proj.g_idx")
+                if down_proj_g_idx is not None:
+                    perm = torch.argsort(down_proj_g_idx)
+
         # Fuse gate and up proj
         self.gate_up_proj = TensorParallelColumnLinear.load_multi(
             config,
@@ -310,12 +341,14 @@ def __init__(self, prefix, config, weights):
             weights=weights,
             dim=0,
             bias=False,
+            col_perm=perm,
         )
         self.down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
             weights=weights,
             bias=False,
+            row_perm=perm,
         )
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
diff --git a/server/text_generation_server/utils/gptq/shuffle.py b/server/text_generation_server/utils/gptq/shuffle.py
new file mode 100644
index 000000000..a21666e43
--- /dev/null
+++ b/server/text_generation_server/utils/gptq/shuffle.py
@@ -0,0 +1,65 @@
+import torch
+
+# Shuffle columns of scales
+def shuffle_and_replace_scales(state_dict, scales_name, col_perm):
+    scales = state_dict[scales_name]
+    assert len(col_perm) == scales.shape[1]
+
+    shuffled_scales = scales[:,col_perm]
+    state_dict[scales_name] = shuffled_scales
+
+def unpack_shuffle_repack_and_replace_qzeros(state_dict, bits, qzeros_name, col_perm):
+    qzeros = state_dict[qzeros_name]
+    mask = 2**bits - 1
+    pack_size = 32 // bits
+    assert len(col_perm) == qzeros.shape[1] * pack_size
+
+    # unpack
+    unpacked_qzeros = torch.zeros((qzeros.shape[0], qzeros.shape[1]*pack_size), dtype=torch.int)
+    for i in range(pack_size):
+        unpacked_qzeros[:, i::pack_size] = (qzeros >> (i*bits)) & (mask)
+
+    # shuffle
+    shuffled_qzeros = unpacked_qzeros[:,col_perm]
+
+    # repack
+    packed_qzeros = torch.zeros_like(qzeros)
+    for i in range(pack_size):
+        packed_qzeros |= (shuffled_qzeros[:, i::pack_size] & mask) << (i*bits)
+
+    state_dict[qzeros_name] = packed_qzeros
+
+def shuffle_and_replace_qweight(state_dict, bits, group_size, qweight_name, g_idx_name=None, next_g_idx_name=None, stable=False):
+    qweight = state_dict[qweight_name]
+
+    # unpack qweight
+    mask = 2**bits - 1
+    pack_size = 32 // bits
+    unpacked_qweight = torch.zeros((qweight.shape[0]*pack_size, qweight.shape[1]), dtype=torch.int)
+    for i in range(pack_size):
+        unpacked_qweight[i::pack_size] = (qweight >> (i*bits)) & (mask)
+
+    # reorder rows conditionally
+    if not (g_idx_name is None):
+        g_idx = state_dict[g_idx_name]
+        row_perm = torch.argsort(g_idx, stable=stable)
+        unpacked_qweight = unpacked_qweight[row_perm]
+
+    # reorder columns conditionally
+    if not (next_g_idx_name is None):
+        next_g_idx = state_dict[next_g_idx_name]
+        col_perm = torch.argsort(next_g_idx, stable=stable)
+        unpacked_qweight = unpacked_qweight[:,col_perm]
+
+    # pack qweight
+    packed_qweight = torch.zeros_like(qweight)
+    for i in range(pack_size):
+        packed_qweight |= (unpacked_qweight[i::pack_size] & mask) << (i*bits)
+
+    # replace qweight with new reordered one in state_dict
+    print(f'replacing {qweight_name}')
+    state_dict[qweight_name] = packed_qweight
+
+    if not (g_idx_name is None):
+        print(f'replacing {g_idx_name}')
+        state_dict[g_idx_name] = torch.arange(0, len(g_idx), dtype=torch.int) // group_size
\ No newline at end of file
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 312f4d5d1..c8ddc550f 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -17,6 +17,8 @@
 HAS_GPTQ_CUDA = False
 GPTQ_CUDA_TYPE = os.getenv("GPTQ_CUDA_TYPE", "exllama").lower()
 GPTQ_CUDA_LINEAR = None
+# TODO: should disable TP-aware GPTQ automatically if deployment is single GPU
+IS_TP_AWARE_GPTQ = (os.getenv("DISABLE_TP_AWARE_GPTQ","False").lower() == "false")
 
 if torch.cuda.is_available():
     try:
@@ -279,13 +281,13 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 
 class TensorParallelColumnLinear(SuperLayer):
     @classmethod
-    def load(cls, config, prefix: str, weights, bias: bool):
-        return cls.load_multi(config, [prefix], weights, bias, dim=0)
+    def load(cls, config, prefix: str, weights, bias: bool, col_perm=None):
+        return cls.load_multi(config, [prefix], weights, bias, dim=0, col_perm=col_perm)
 
     @classmethod
-    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
+    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int, col_perm=None):
         weight = weights.get_multi_weights_col(
-            prefixes, quantize=config.quantize, dim=dim
+            prefixes, quantize=config.quantize, dim=dim, col_perm=col_perm
         )
 
         if bias:
@@ -303,8 +305,8 @@ def __init__(self, linear, process_group):
         self.process_group = process_group
 
     @classmethod
-    def load(cls, config, prefix: str, weights, bias: bool):
-        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
+    def load(cls, config, prefix: str, weights, bias: bool, row_perm=None, noshard=False):
+        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize, row_perm=row_perm, noshard=noshard)
 
         if bias and weights.process_group.rank() == 0:
             # Rank is only on the first rank process
             bias = weights.get_tensor(f"{prefix}.bias")
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 3a53eb360..b4e1590eb 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -10,6 +10,44 @@
 
 QUANTIZE_CONFIG_FILENAME = "quantize_config.json"
 
+def unpack(x, dim, bits=4):
+    return unpack_row(x, bits) if dim == 0 else unpack_col(x, bits)
+
+def unpack_col(x, bits):
+    mask = 2**bits - 1
+    pack_size = 32 // bits
+    unpacked_x = torch.zeros((x.shape[0], x.shape[1]*pack_size), dtype=torch.int)
+    for i in range(pack_size):
+        unpacked_x[:, i::pack_size] = (x >> (i*bits)) & (mask)
+    return unpacked_x
+
+def unpack_row(x, bits):
+    mask = 2**bits - 1
+    pack_size = 32 // bits
+    unpacked_x = torch.zeros((x.shape[0]*pack_size, x.shape[1]), dtype=torch.int)
+    for i in range(pack_size):
+        unpacked_x[i::pack_size] = (x >> (i*bits)) & (mask)
+    return unpacked_x
+
+
+def pack(x, dim, bits=4):
+    return pack_row(x, bits) if dim == 0 else pack_col(x, bits)
+
+def pack_col(x, bits):
+    mask = 2**bits - 1
+    pack_size = 32 // bits
+    packed_x = torch.zeros((x.shape[0], x.shape[1]//pack_size), dtype=torch.int)
+    for i in range(pack_size):
+        packed_x |= (x[:, i::pack_size] & mask) << (i*bits)
+    return packed_x
+
+def pack_row(x, bits):
+    mask = 2**bits - 1
+    pack_size = 32 // bits
+    packed_x = torch.zeros((x.shape[0]//pack_size, x.shape[1]), dtype=torch.int)
+    for i in range(pack_size):
+        packed_x |= (x[i::pack_size] & mask) << (i*bits)
+    return packed_x
 
 class Weights:
     def __init__(
@@ -101,7 +139,7 @@ def get_partial_sharded(self, tensor_name: str, dim: int):
         tensor = tensor.to(device=self.device)
         return tensor
 
-    def get_sharded(self, tensor_name: str, dim: int):
+    def get_sharded(self, tensor_name: str, dim: int, perm=None, packed=False):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         slice_ = f.get_slice(tensor_name)
@@ -110,17 +148,53 @@ def get_sharded(self, tensor_name: str, dim: int):
         assert (
             size % world_size == 0
         ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
-        return self.get_partial_sharded(tensor_name, dim)
+        if perm is None:
+            return self.get_partial_sharded(tensor_name, dim)
+        else:
+            return self.get_shuffle_sharded(tensor_name, dim, perm, packed)
+
+    def get_shuffle_sharded(self, tensor_name: str, dim: int, perm, packed: bool):
+        filename, tensor_name = self.get_filename(tensor_name)
+        world_size = self.process_group.size()
+        rank = self.process_group.rank()
+
+        f = self._get_handle(filename)
+        tensor = f.get_tensor(tensor_name)
+        perm = perm.to(device=tensor.device)
+        size = tensor.shape[dim]
+        block_size = size // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+
+        # TODO: pack-unpack on cuda to speed up this part
+        if dim == 0:
+            if packed:
+                tensor = pack(unpack(tensor, dim)[perm], dim)[start:stop]
+            else:
+                tensor = tensor[perm][start:stop]
+        elif dim == 1:
+            if packed:
+                tensor = pack(unpack(tensor, dim)[:, perm], dim)[:, start:stop]
+            else:
+                tensor = tensor[:, perm][:, start:stop]
+        else:
+            raise NotImplementedError("Let's make that generic when needed")
+        # Special case for gptq which shouldn't convert
+        # u4 which are disguised as int32
+        if tensor.dtype != torch.int32:
+            tensor = tensor.to(dtype=self.dtype)
+        tensor = tensor.to(device=self.device)
+        return tensor
 
-    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
+    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int, col_perm=None):
         if quantize == "gptq":
             try:
-                qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
+                qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1, perm=col_perm, packed=False) for p in prefixes], dim=1)
             except RuntimeError:
                 raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
 
-            qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1)
-            scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1)
+            qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1, perm=col_perm, packed=True) for p in prefixes], dim=1)
+            scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1, perm=col_perm, packed=False) for p in prefixes], dim=1)
             w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
             for w2 in w[1:]:
                 torch.testing.assert_close(w2, w[0])
@@ -141,39 +215,35 @@ def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
             weight = torch.cat(w, dim=dim)
         return weight
 
-    def get_multi_weights_row(self, prefix: str, quantize: str):
+    def get_multi_weights_row(self, prefix: str, quantize: str, row_perm=None, noshard=False):
         if quantize == "gptq":
             bits, groupsize = self._get_gptq_params()
 
-            use_gptq_cuda = bits == 4
-
-            if self.process_group.size() > 1:
-                g_idx = self.get_tensor(f"{prefix}.g_idx")
-                if g_idx is not None:
-                    if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
-                        # Exllama implementation does not support row tensor parallelism with act-order, as
-                        # it would require to reorder input activations that are split unto several GPUs
-                        use_gptq_cuda = False
-
+            from text_generation_server.utils.layers import HAS_GPTQ_CUDA
+            is_preshuffle = (row_perm != None)
+            is_masked_matmul = noshard
+            assert (is_preshuffle != is_masked_matmul
+                    or not (is_preshuffle or is_masked_matmul)), f"TP-aware optimization can't both be enabled at the same time {is_preshuffle=}, {is_masked_matmul=}"
+            use_gptq_cuda = (bits == 4) and HAS_GPTQ_CUDA or (is_preshuffle or is_masked_matmul)
+            if self.process_group.rank == 0:
+                if use_gptq_cuda:
+                    logger.info(f"Using GPTQ cuda kernels for row {prefix}")
+                else:
+                    logger.warning(
+                        "GPTQ cuda kernels (which are faster) could have been used, but are disabled via the DISABLE_EXLLAMA env var,"
+                        " or not currently installed, try using BUILD_EXTENSIONS=True"
+                    )
             try:
-                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
+                qweight = self.get_sharded(f"{prefix}.qweight",
+                                           dim=0,
+                                           perm=row_perm if use_exllama else None,
+                                           packed=True,
+                                           ) if not is_masked_matmul else self.get_tensor(f"{prefix}.qweight")
             except RuntimeError:
                 raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
 
-            from text_generation_server.utils.layers import HAS_GPTQ_CUDA
-            if use_gptq_cuda:
-                use_gptq_cuda = HAS_GPTQ_CUDA
-                if self.process_group.rank == 0:
-                    if use_gptq_cuda:
-                        logger.info(f"Using GPTQ cuda kernels for row {prefix}")
-                    else:
-                        logger.warning(
-                            "GPTQ cuda kernels (which are faster) could have been used, but are disabled via the DISABLE_EXLLAMA env var,"
-                            " or not currently installed, try using BUILD_EXTENSIONS=True"
-                        )
-
             if use_gptq_cuda:
-                if groupsize >= 0:
+                if groupsize >= 0 and not is_masked_matmul:
                     # Exllama reorders the weights in advance and the activations on the fly, thus
                     # the scales and zero-points do not need to be reordered.
                     qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
@@ -183,7 +253,7 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
                 scales = self.get_tensor(f"{prefix}.scales")
 
                 # For tp > 1, at this point we know we do not use act-order
-                if self.process_group.size() == 1:
+                if (self.process_group.size() == 1 or is_masked_matmul) and not is_preshuffle:
                     g_idx = self.get_tensor(f"{prefix}.g_idx")
                 else:
                     g_idx = None
@@ -197,7 +267,7 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
 
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_gptq_cuda)
         else:
-            weight = self.get_sharded(f"{prefix}.weight", dim=1)
+            weight = self.get_sharded(f"{prefix}.weight", dim=1) if not noshard else self.get_tensor(f"{prefix}.weight")
         return weight
 
     def _get_gptq_params(self) -> Tuple[int, int]:

From 58903e5096250dc9c3e97ba54d31b58b0e8debe7 Mon Sep 17 00:00:00 2001
From: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Date: Thu, 21 Mar 2024 17:28:03 +0000
Subject: [PATCH 2/4] fix use_exllama logic error

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
---
 server/text_generation_server/utils/weights.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index b4e1590eb..4f2444e08 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -219,12 +219,13 @@ def get_multi_weights_row(self, prefix: str, quantize: str, row_perm=None, nosha
         if quantize == "gptq":
             bits, groupsize = self._get_gptq_params()
 
-            from text_generation_server.utils.layers import HAS_GPTQ_CUDA
+            from text_generation_server.utils.layers import HAS_GPTQ_CUDA, IS_TP_AWARE_GPTQ
             is_preshuffle = (row_perm != None)
             is_masked_matmul = noshard
             assert (is_preshuffle != is_masked_matmul
                     or not (is_preshuffle or is_masked_matmul)), f"TP-aware optimization can't both be enabled at the same time {is_preshuffle=}, {is_masked_matmul=}"
-            use_gptq_cuda = (bits == 4) and HAS_GPTQ_CUDA or (is_preshuffle or is_masked_matmul)
+
+            use_exllama = (bits == 4) and HAS_GPTQ_CUDA and (IS_TP_AWARE_GPTQ and (is_preshuffle or is_masked_matmul))
             if self.process_group.rank == 0:
                 if use_gptq_cuda:
                     logger.info(f"Using GPTQ cuda kernels for row {prefix}")

From 4f65abc417481b3e9cb5b5db430f5614fe8bbcd4 Mon Sep 17 00:00:00 2001
From: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Date: Mon, 25 Mar 2024 20:59:12 +0000
Subject: [PATCH 3/4] Fix merge bugs

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
---
 .../models/custom_modeling/flash_llama_modeling.py | 2 +-
 server/text_generation_server/utils/weights.py     | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 22679d3ce..6aa4f22dd 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -301,7 +301,7 @@ def forward(
         # and multiply with full weight matrix in o_proj
         if self.noshard_o_proj:
             shard_size = attn_output.shape[1]
-            assert shard_size*self.world_size == self.o_proj.linear.height
+            # assert shard_size*self.world_size == self.o_proj.linear.height
             zf_attn_output = torch.zeros((attn_output.shape[0], shard_size*self.world_size), dtype=attn_output.dtype, device=attn_output.device)
             start_idx = self.rank * shard_size
             zf_attn_output[:, start_idx:start_idx+shard_size] = attn_output
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 4f2444e08..b02de5b4e 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -225,7 +225,7 @@ def get_multi_weights_row(self, prefix: str, quantize: str, row_perm=None, nosha
             assert (is_preshuffle != is_masked_matmul
                     or not (is_preshuffle or is_masked_matmul)), f"TP-aware optimization can't both be enabled at the same time {is_preshuffle=}, {is_masked_matmul=}"
 
-            use_exllama = (bits == 4) and HAS_GPTQ_CUDA and (IS_TP_AWARE_GPTQ and (is_preshuffle or is_masked_matmul))
+            use_gptq_cuda = (bits == 4) and HAS_GPTQ_CUDA and (IS_TP_AWARE_GPTQ and (is_preshuffle or is_masked_matmul))
             if self.process_group.rank == 0:
                 if use_gptq_cuda:
                     logger.info(f"Using GPTQ cuda kernels for row {prefix}")
@@ -237,7 +237,7 @@ def get_multi_weights_row(self, prefix: str, quantize: str, row_perm=None, nosha
             try:
                 qweight = self.get_sharded(f"{prefix}.qweight",
                                            dim=0,
-                                           perm=row_perm if use_exllama else None,
+                                           perm=row_perm if use_gptq_cuda else None,
                                            packed=True,
                                            ) if not is_masked_matmul else self.get_tensor(f"{prefix}.qweight")
             except RuntimeError:

From c7034dfd087872d162c7c07af4c77fdb360ce24d Mon Sep 17 00:00:00 2001
From: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Date: Tue, 26 Mar 2024 14:06:23 +0000
Subject: [PATCH 4/4] Change default to disabling TP-aware optimizations

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
---
 server/text_generation_server/utils/layers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index c8ddc550f..5f9852876 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -18,7 +18,7 @@
 GPTQ_CUDA_TYPE = os.getenv("GPTQ_CUDA_TYPE", "exllama").lower()
 GPTQ_CUDA_LINEAR = None
 # TODO: should disable TP-aware GPTQ automatically if deployment is single GPU
-IS_TP_AWARE_GPTQ = (os.getenv("DISABLE_TP_AWARE_GPTQ","False").lower() == "false")
+IS_TP_AWARE_GPTQ = (os.getenv("ENABLE_TP_AWARE_GPTQ","False").lower() not in ["false", "0"])
 
 if torch.cuda.is_available():
     try:
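
The sketch below is not part of the patch series above; it uses assumed tensor sizes and plain fp32 matmuls (no GPTQ kernels, with all "ranks" simulated in one process) to show why the masked-matmul path added to o_proj in PATCH 1/4 is numerically sound: zero-filling the local activation shard and multiplying by the full, unsharded o_proj weight produces exactly the partial product that ordinary row tensor parallelism computes, so the all-reduce (modeled here as a sum over ranks) gives the same result.

import torch

# Illustrative sketch only (not from the patches above); sizes are assumptions.
torch.manual_seed(0)
world_size, num_tokens, hidden = 4, 3, 16
shard = hidden // world_size

x = torch.randn(num_tokens, hidden)   # full attention output, all shards concatenated
w = torch.randn(hidden, hidden)       # full, unsharded o_proj weight

# Baseline row tensor parallelism: each rank multiplies its activation shard
# by its weight shard, and the partial products are summed by the all-reduce.
baseline = sum(
    x[:, r * shard:(r + 1) * shard] @ w[r * shard:(r + 1) * shard, :]
    for r in range(world_size)
)

# Masked matmul: every rank keeps the full weight, zero-fills the activation
# outside its own shard, and multiplies; the zero columns contribute nothing,
# so the summed (all-reduced) result is identical.
masked = torch.zeros(num_tokens, hidden)
for r in range(world_size):
    zf = torch.zeros(num_tokens, hidden)
    zf[:, r * shard:(r + 1) * shard] = x[:, r * shard:(r + 1) * shard]
    masked += zf @ w

torch.testing.assert_close(masked, baseline)
print("masked matmul matches the row-parallel baseline")

The trade-off is that every rank loads the full o_proj qweight instead of a shard, which is why PATCH 4/4 makes the whole feature opt-in via the ENABLE_TP_AWARE_GPTQ environment variable.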
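
A second standalone sketch, also not part of the patches and using invented shapes with dense stand-ins for the quantized weights, illustrates the preshuffle path in LlamaMLP.__init__: permuting the output columns of the fused gate/up projection with perm = argsort(down_proj.g_idx) and the input rows of down_proj with the same perm cancels out across the two matmuls (an elementwise activation in between commutes with a feature permutation), while the g_idx seen by down_proj becomes monotonic, which appears to be the same property the offline helpers in shuffle.py arrange at the state-dict level.

import torch
import torch.nn.functional as F

# Illustrative sketch only (not from the patches above); shapes and g_idx are assumptions.
torch.manual_seed(0)
hidden, intermediate, group_size = 8, 16, 4

x = torch.randn(3, hidden)
up = torch.randn(hidden, intermediate)    # dense stand-in for one gate/up branch
down = torch.randn(intermediate, hidden)  # dense stand-in for down_proj

# An act-order style g_idx: group assignments appear in scrambled order.
g_idx = torch.randperm(intermediate) // group_size
perm = torch.argsort(g_idx)               # same perm as computed in LlamaMLP.__init__

reference = F.silu(x @ up) @ down
preshuffled = F.silu(x @ up[:, perm]) @ down[perm, :]

torch.testing.assert_close(preshuffled, reference)
# After the shuffle the groups are contiguous, so no activation reordering is needed.
assert torch.all(g_idx[perm][1:] >= g_idx[perm][:-1])
print("preshuffled MLP matches the original")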