diff --git a/examples/llama3_acc.sh b/examples/llama3_acc.sh
index c29061f..530e4cf 100755
--- a/examples/llama3_acc.sh
+++ b/examples/llama3_acc.sh
@@ -3,4 +3,4 @@ set -ex
 
 # FSDP
 # note: this need transformers>=4.41.0
-./examples/run.sh --model ./hf_models/config/llama-3-1b --accelerator acc --gc --mbs 2 --fsdp 8 --max_seq_length 4096 --no_fa
+./examples/run.sh --model ./hf_models/config/llama-3-1b --accelerator acc --gc --mbs 2 --fsdp 8 --max_seq_length 4096 --use_flash_attn
diff --git a/examples/llama_acc.sh b/examples/llama_acc.sh
index 3080336..a416f7b 100755
--- a/examples/llama_acc.sh
+++ b/examples/llama_acc.sh
@@ -2,7 +2,7 @@
 set -ex
 
 # FSDP
-./examples/run.sh --model ./hf_models/config/llama-1b --accelerator acc --gc --mbs 4 --fsdp 4
+./examples/run.sh --model ./hf_models/config/llama-1b --accelerator acc --gc --mbs 4 --fsdp 4 --use_flash_attn
 
 # TP
 # ./examples/run.sh --model ./hf_models/config/llama-1b --accelerator acc --gc --mbs 24 --tp 4
diff --git a/examples/run.sh b/examples/run.sh
index ab396df..57cedcb 100755
--- a/examples/run.sh
+++ b/examples/run.sh
@@ -18,9 +18,9 @@ DP_NUM=1 # data parallelism number
 PP_NUM=1 # pipeline parallelism number
 TP_NUM=1 # tensor parallelism number
 FSDP_NUM=1 # fsdp number
-FLASH_ATTN=1 # enable flash-attn-2
 DATA=./data/wikitext-2-raw-v1.json # data name or path
 MODEL_NAME_OR_PATH="./hf_models/config/llama-1b" # model name or path
+USE_FLASH_ATTN=1
 
 OTHER_ARGS=""
 
@@ -28,7 +28,7 @@ OTHER_ARGS=""
 HELP_STR=("Usage: bash examples/run.sh [-h|--help] [--accelerator {acc, cuda}] [--model MODEL_NAME_OR_PATH] \n"
           "\t[--data DATASET_NAME_OR_PATH] [--mbs MICRO_BATCH_SIZE] [--max_seq_length MAX_SEQ_LENGTH] \n"
           "\t[--num_train_epochs NUM_TRAIN_EPOCHS] [--max_steps MAX_TRAIN_STEPS] [--pp PP_NUM] [--tp TP_NUM] [--fsdp FSDP_NUM] \n"
-          "\t[--ga GRADIENT_ACCUMULATION_STEPS] [--gc] [--bf16] [--fp16] [--fp32] [--no_fa] [--log_interval LOG_INTERVAL] \n"
+          "\t[--ga GRADIENT_ACCUMULATION_STEPS] [--gc] [--bf16] [--fp16] [--fp32] [--use_flash_attn] [--log_interval LOG_INTERVAL] \n"
          "\t[other args for apps/train.py] \n"
          "Examples: \n"
          "\tbash examples/run.sh --accelerator cuda --model ./hf_models/config/llama-7b\n"
@@ -125,8 +125,8 @@ while [[ $# -gt 0 ]]; do
             BF16=0
             shift
             ;;
-        --no_fa)
-            FLASH_ATTN=0
+        --use_flash_attn)
+            ACC_FLASH_ATTN=1
             shift
             ;;
         --log_interval)
@@ -150,6 +150,11 @@ OPTION_ARGS=""
 [[ "$BF16" -eq 1 ]] && OPTION_ARGS+="--bf16 "
 [[ "$FP16" -eq 1 ]] && OPTION_ARGS+="--fp16 "
 
+if [[ "$ACC_FLASH_ATTN" == 1 && ( "$FP16" -eq 1 || "$BF16" -eq 1 ) ]]; then
+    OPTION_ARGS+="--use_flash_attn "
+    export ACC_FLASH_ATTN=1
+fi
+
 if [ "$ACCELERATOR" == "cuda" ]; then
     [ "$PP_NUM" -gt 1 ] && echo "Error: Pipeline Parallelism is not supported for cuda accelerator." && exit 1
     [ "$TP_NUM" -gt 1 ] && echo "Error: Tensor Parallelism is not supported for cuda accelerator." && exit 1
@@ -160,11 +165,6 @@
 if [ "$TP_NUM" -gt "1" ]; then
     export XLA_USE_SPMD=1
 fi
-
-if [[ "$ACCELERATOR" == "acc" && "FLASH_ATTN" -eq 1 && ( "$FP16" -eq 1 || "$BF16" -eq 1 ) ]]; then
-    export ACC_FLASH_ATTN=1
-fi
-
 export XLA_PERSISTENT_CACHE_PATH=./compiled_cache/
 
 MODEL_NAME=$(basename $MODEL_NAME_OR_PATH)
diff --git a/flashmodels/accelerators/acc_baichuan_accelerator.py b/flashmodels/accelerators/acc_baichuan_accelerator.py
index 935e333..14d497d 100644
--- a/flashmodels/accelerators/acc_baichuan_accelerator.py
+++ b/flashmodels/accelerators/acc_baichuan_accelerator.py
@@ -17,7 +17,7 @@ def accelerate_internal(self, model, loader):
             raise NotImplementedError("resume_from_checkpoint.")
 
         config = self.get_config(model)
-        model = ta.accelerate(model, config)
+        model = ta.accelerate(model, config=config)
         return model, loader
 
     def get_config(self, model):
diff --git a/flashmodels/accelerators/acc_gemma_accelerator.py b/flashmodels/accelerators/acc_gemma_accelerator.py
index f062105..db12458 100644
--- a/flashmodels/accelerators/acc_gemma_accelerator.py
+++ b/flashmodels/accelerators/acc_gemma_accelerator.py
@@ -12,7 +12,7 @@ def accelerate(self, model, loader):
 
     def accelerate_internal(self, model, loader):
         config = self.get_config()
-        model = ta.accelerate(model, config)
+        model = ta.accelerate(model, config=config)
         return model, loader
 
     def get_config(self):
diff --git a/flashmodels/accelerators/acc_glm_accelerator.py b/flashmodels/accelerators/acc_glm_accelerator.py
index 6c9801c..7695db0 100644
--- a/flashmodels/accelerators/acc_glm_accelerator.py
+++ b/flashmodels/accelerators/acc_glm_accelerator.py
@@ -17,7 +17,7 @@ def accelerate_internal(self, model, loader):
             raise NotImplementedError("resume_from_checkpoint.")
 
         config = self.get_config(model)
-        model = ta.accelerate(model, config)
+        model = ta.accelerate(model, config=config)
         return model, loader
 
     def get_config(self, model):
diff --git a/flashmodels/accelerators/acc_gpt_accelerator.py b/flashmodels/accelerators/acc_gpt_accelerator.py
index 95a3231..971c680 100644
--- a/flashmodels/accelerators/acc_gpt_accelerator.py
+++ b/flashmodels/accelerators/acc_gpt_accelerator.py
@@ -20,7 +20,7 @@ def accelerate_internal(self, model, loader):
                 raise NotImplementedError("resume_from_checkpoint.")
 
             config = self.get_config(model)
-            model = ta.accelerate(model, config)
+            model = ta.accelerate(model, config=config)
             return model, loader
 
         device = lazy_device()
diff --git a/flashmodels/accelerators/acc_llama_accelerator.py b/flashmodels/accelerators/acc_llama_accelerator.py
index c2e7951..459f5e5 100644
--- a/flashmodels/accelerators/acc_llama_accelerator.py
+++ b/flashmodels/accelerators/acc_llama_accelerator.py
@@ -85,9 +85,7 @@ def accelerate_internal(self, model, loader):
             model = self.tensor_parallel(model)
             return model, loader
 
-        if self.args.pp_num > 1:
-            # Prevent unnecessary model outputs
-            model.model.config.use_cache = False
+        model.model.config.use_cache = False
         # TODO: support this in torchacc
         if self.args.resume_from_checkpoint:
             assert self.args.fsdp_num == self.args.world_size, \
@@ -101,7 +99,7 @@ def accelerate_internal(self, model, loader):
                              self.args.sp)
 
         config = self.get_config(model)
-        model = ta.accelerate(model, config)
+        model = ta.accelerate(model, config=config)
 
         if self.args.tp_num > 1 and self.args.pp_num > 1:
             self.parallel_3d(model._get_underlay_model())
diff --git a/flashmodels/accelerators/acc_olmo_accelerator.py b/flashmodels/accelerators/acc_olmo_accelerator.py
index 8009209..0af0250 100644
--- a/flashmodels/accelerators/acc_olmo_accelerator.py
+++ b/flashmodels/accelerators/acc_olmo_accelerator.py
@@ -17,7 +17,7 @@ def accelerate_internal(self, model, loader):
                 raise NotImplementedError("resume_from_checkpoint.")
 
             config = self.get_config(model)
-            model = ta.accelerate(model, config)
+            model = ta.accelerate(model, config=config)
             return model, loader
         else:
             raise NotImplementedError("Currently, only FSDP is supported.")
diff --git a/flashmodels/accelerators/acc_qwen_accelerator.py b/flashmodels/accelerators/acc_qwen_accelerator.py
index 1a64edf..c1acc0e 100644
--- a/flashmodels/accelerators/acc_qwen_accelerator.py
+++ b/flashmodels/accelerators/acc_qwen_accelerator.py
@@ -37,7 +37,7 @@ def accelerate_internal(self, model, loader):
             raise NotImplementedError("resume_from_checkpoint.")
 
         config = self.get_config(model)
-        model = ta.accelerate(model, config)
+        model = ta.accelerate(model, config=config)
         return model, loader
 
     def get_config(self, model):
diff --git a/flashmodels/arguments.py b/flashmodels/arguments.py
index 3aac6bf..8ad9eb7 100644
--- a/flashmodels/arguments.py
+++ b/flashmodels/arguments.py
@@ -4,7 +4,7 @@
 import torch
 
 from flashmodels.logger import logger
-from flashmodels.patch import patch_amp, patch_gemma, patch_llama, patch_peft
+from flashmodels.patch import patch_gemma, patch_llama, patch_peft
 
 
 def print_args(args):
@@ -16,10 +16,9 @@ def parse():
     parser = argparse.ArgumentParser(description="Flash Models Arguments")
 
     # model args
-    parser.add_argument(
-        "--model_name_or_path",
-        type=str,
-        default="decapoda-research/llama-7b-hf")
+    parser.add_argument("--model_name_or_path",
+                        type=str,
+                        default="decapoda-research/llama-7b-hf")
     parser.add_argument("--cache_dir", type=str, default="./models/")
     parser.add_argument("--max_seq_length", type=int, default=1024)
     parser.add_argument(
@@ -29,97 +28,95 @@ def parse():
         choices=["gpt", "llama", "glm", "baichuan", "qwen", "olmo"])
 
     # dataset args
-    parser.add_argument(
-        "--dataset_name_or_path",
-        type=str,
-        default="./data/wikitext-2-raw-v1.json")
+    parser.add_argument("--dataset_name_or_path",
+                        type=str,
+                        default="./data/wikitext-2-raw-v1.json")
     parser.add_argument("--dataset_config", type=str, default="")
     parser.add_argument("--micro_batch_size", type=int, default=8)
     parser.add_argument("--padding_side", type=str, default="right")
-    parser.add_argument(
-        "--disable_train_sampler",
-        action="store_true",
-        help="Disable Train Sampler")
+    parser.add_argument("--disable_train_sampler",
+                        action="store_true",
+                        help="Disable Train Sampler")
 
     # accelerator args
-    parser.add_argument(
-        "--accelerator",
-        type=str,
-        default="acc",
-        choices=["cuda", "acc", "megatron"],
-        help="accelerator name")
-    parser.add_argument(
-        "--fsdp_num",
-        type=int,
-        default=1,
-        help="Full sharded data parallel Number")
-    parser.add_argument(
-        "--gc",
-        action="store_true",
-        default=False,
-        help="Use gradients checkpoint")
+    parser.add_argument("--accelerator",
+                        type=str,
+                        default="acc",
+                        choices=["cuda", "acc", "megatron"],
+                        help="accelerator name")
+    parser.add_argument("--fsdp_num",
+                        type=int,
+                        default=1,
+                        help="Full sharded data parallel Number")
+    parser.add_argument("--gc",
+                        action="store_true",
+                        default=False,
+                        help="Use gradients checkpoint")
     parser.add_argument(
         "--gc_cnt",
         type=int,
         default=None,
         help="Number of decoder layers for gradient checkpointing")
-    parser.add_argument(
-        "--tp_num", type=int, default=1, help="Tensor Parallel Number")
-    parser.add_argument(
-        "--sp",
-        action="store_true",
-        default=False,
-        help="Use Sequence Parallelism.")
+    parser.add_argument("--tp_num",
+                        type=int,
+                        default=1,
+                        help="Tensor Parallel Number")
+    parser.add_argument("--sp",
+                        action="store_true",
+                        default=False,
+                        help="Use Sequence Parallelism.")
     parser.add_argument(
         "--sp_reshard_after_forward",
         action="store_true",
         default=False,
         help="To reduce memory usage, reshard weight after forward in TP-SP, \
 and perform an extra all-gather in the backward pass")
-    parser.add_argument(
-        "--sp_num",
-        type=int,
-        default=1,
-        help="DeepSpeed Ulysses Sequence \
+    parser.add_argument("--sp_num",
+                        type=int,
+                        default=1,
+                        help="DeepSpeed Ulysses Sequence \
 Parallel Number. ")
-    parser.add_argument(
-        "--dp_num", type=int, default=1, help="Data Parallel Number")
-    parser.add_argument(
-        "--pp_num", type=int, default=1, help="Pipeline Parallel Number")
-    parser.add_argument(
-        "--fp16", action="store_true", help="Run model in fp16 mode.")
-    parser.add_argument(
-        "--bf16", action="store_true", help="Run model in bfloat16 mode.")
-    parser.add_argument(
-        "--force_use_syncfree_adam",
-        action="store_true",
-        help="Force to use \
+    parser.add_argument("--dp_num",
+                        type=int,
+                        default=1,
+                        help="Data Parallel Number")
+    parser.add_argument("--pp_num",
+                        type=int,
+                        default=1,
+                        help="Pipeline Parallel Number")
+    parser.add_argument("--fp16",
+                        action="store_true",
+                        help="Run model in fp16 mode.")
+    parser.add_argument("--bf16",
+                        action="store_true",
+                        help="Run model in bfloat16 mode.")
+    parser.add_argument("--force_use_syncfree_adam",
+                        action="store_true",
+                        help="Force to use \
 syncfree.Adam/AdamW for better tracing peformance.")
-    parser.add_argument(
-        "--use_zero2",
-        action="store_true",
-        help="Use \
+    parser.add_argument("--use_zero2",
+                        action="store_true",
+                        help="Use \
 distributed optimizer(ZeRO2) for SPMD-DP.")
-    parser.add_argument(
-        "--use_zero3",
-        action="store_true",
-        help="Use \
+    parser.add_argument("--use_zero3",
+                        action="store_true",
+                        help="Use \
 ZeRO3 for SPMD-DP.")
 
     # lora
     parser.add_argument("--lora", action="store_true", help="Use lora")
-    parser.add_argument(
-        "--lora_r", type=int, default=8, help="lora attention dimension")
-    parser.add_argument(
-        "--lora_alpha",
-        type=int,
-        default=8,
-        help="lora scaling alpha parameter")
-    parser.add_argument(
-        "--lora_dropout",
-        type=float,
-        default=0.0,
-        help="The dropout probability \
+    parser.add_argument("--lora_r",
+                        type=int,
+                        default=8,
+                        help="lora attention dimension")
+    parser.add_argument("--lora_alpha",
+                        type=int,
+                        default=8,
+                        help="lora scaling alpha parameter")
+    parser.add_argument("--lora_dropout",
+                        type=float,
+                        default=0.0,
+                        help="The dropout probability \
 for Lora layers")
     parser.add_argument(
         "--lora_target_modules",
@@ -131,55 +128,50 @@ def parse():
 
     # training args
     parser.add_argument("--global_rank", type=int, default=0)
-    parser.add_argument(
-        "--resume_from_checkpoint",
-        action="store_true",
-        help="Resume from checkpoint, if true,"
-        " load checkpoint from ckpt_dir")
+    parser.add_argument("--resume_from_checkpoint",
+                        action="store_true",
+                        help="Resume from checkpoint, if true,"
+                        " load checkpoint from ckpt_dir")
     parser.add_argument("--ckpt_dir", type=str, default="")
-    parser.add_argument(
-        "--ckpt_freq",
-        type=int,
-        default=100,
-        help="The checkpoint frequency of local steps.")
-    parser.add_argument(
-        "--profile", action="store_true", help="Open pytorch profiler")
+    parser.add_argument("--ckpt_freq",
+                        type=int,
+                        default=100,
+                        help="The checkpoint frequency of local steps.")
+    parser.add_argument("--profile",
+                        action="store_true",
+                        help="Open pytorch profiler")
     parser.add_argument("--profile_dir", type=str, default="./profile/")
-    parser.add_argument(
-        "--profile_stop_step",
-        type=int,
-        default=10,
-        help="Maximum profiling steps")
+    parser.add_argument("--profile_stop_step",
+                        type=int,
+                        default=10,
+                        help="Maximum profiling steps")
     parser.add_argument("--log_interval", type=int, default=1)
     parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
     parser.add_argument("--max_step", type=int, default=-1)
-    parser.add_argument(
-        "--learning_rate",
-        type=float,
-        default=2e-5,
-        help="The initial learning rate for AdamW.")
-    parser.add_argument(
-        "--weight_decay",
-        type=float,
-        default=0.03,
-        help="Weight decay for AdamW if we apply some.")
-    parser.add_argument(
-        "--adam_beta1",
-        type=float,
-        default=0.9,
-        help="Beta1 for AdamW optimizer")
-    parser.add_argument(
-        "--adam_beta2",
-        type=float,
-        default=0.999,
-        help="Beta2 for AdamW optimizer")
-    parser.add_argument(
-        "--adam_epsilon",
-        type=float,
-        default=1e-8,
-        help="Epsilon for AdamW optimizer.")
-    parser.add_argument(
-        "--max_grad_norm", type=float, default=1.0, help="Max gradient norm.")
+    parser.add_argument("--learning_rate",
+                        type=float,
+                        default=2e-5,
+                        help="The initial learning rate for AdamW.")
+    parser.add_argument("--weight_decay",
+                        type=float,
+                        default=0.03,
+                        help="Weight decay for AdamW if we apply some.")
+    parser.add_argument("--adam_beta1",
+                        type=float,
+                        default=0.9,
+                        help="Beta1 for AdamW optimizer")
+    parser.add_argument("--adam_beta2",
+                        type=float,
+                        default=0.999,
+                        help="Beta2 for AdamW optimizer")
+    parser.add_argument("--adam_epsilon",
+                        type=float,
+                        default=1e-8,
+                        help="Epsilon for AdamW optimizer.")
+    parser.add_argument("--max_grad_norm",
+                        type=float,
+                        default=1.0,
+                        help="Max gradient norm.")
     parser.add_argument(
         "--lr_scheduler_type",
         type=str,
@@ -195,34 +187,34 @@ def parse():
         type=float,
         default=0.0,
         help="Linear warmup over warmup_ratio fraction of total steps.")
-    parser.add_argument(
-        "--warmup_steps",
-        type=int,
-        default=0,
-        help="Linear warmup over warmup_steps.")
+    parser.add_argument("--warmup_steps",
+                        type=int,
+                        default=0,
+                        help="Linear warmup over warmup_steps.")
     parser.add_argument("--num_train_epochs", type=int, default=1)
-    parser.add_argument(
-        "--padding_strategy",
-        type=str,
-        default="max_length",
-        help="tokenizer padding strategy",
-        choices=["max_length", "longest"])
-    parser.add_argument(
-        "--max_train_steps",
-        type=int,
-        default=-1,
-        help="Maximum training steps")
-    parser.add_argument(
-        "--log_loss", action="store_true", help="Print loss when logging steps")
+    parser.add_argument("--padding_strategy",
+                        type=str,
+                        default="max_length",
+                        help="tokenizer padding strategy",
+                        choices=["max_length", "longest"])
+    parser.add_argument("--max_train_steps",
+                        type=int,
+                        default=-1,
+                        help="Maximum training steps")
+    parser.add_argument("--use_flash_attn",
+                        action="store_true",
+                        default=False,
+                        help="Use TriDao FlashAttention2")
+    parser.add_argument("--log_loss",
+                        action="store_true",
+                        help="Print loss when logging steps")
 
     args = parser.parse_args()
 
     if args.lora:
         patch_peft()
 
-    if args.accelerator == "acc":
-        patch_amp()
-    else:
+    if args.accelerator == "cuda":
         torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
 
     args.global_rank = int(os.getenv("RANK", 0))
@@ -271,7 +263,7 @@ def parse():
 
     if args.model_type == "llama" and args.accelerator == 'acc' and (
             args.fp16 or args.bf16):
-        patch_llama()
+        patch_llama(args.use_flash_attn)
 
     if args.model_type == "gemma" and args.accelerator == 'acc':
         patch_gemma()
diff --git a/flashmodels/builder.py b/flashmodels/builder.py
index aaa62bd..b5aca36 100644
--- a/flashmodels/builder.py
+++ b/flashmodels/builder.py
@@ -4,6 +4,7 @@
 import torch
 import torchacc as ta
+
 from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                           get_scheduler)
 
 
@@ -62,7 +63,9 @@ def build_model_from_ckpt(self):
         config = AutoConfig.from_pretrained(
             self.args.model_name_or_path, trust_remote_code=True)
         return self._init_fn(
-            AutoModelForCausalLM.from_config, config, trust_remote_code=True)
+            AutoModelForCausalLM.from_config, config,
+            attn_implementation="flash_attention_2" if self.args.use_flash_attn else "eager",
+            trust_remote_code=True)
 
     def build_model_from_pretrain(self):
         has_weight = False
@@ -78,6 +81,7 @@ def build_model_from_pretrain(self):
             return self._init_fn(
                 AutoModelForCausalLM.from_pretrained,
                 self.args.model_name_or_path,
+                attn_implementation="flash_attention_2" if self.args.use_flash_attn else "eager",
                 cache_dir=self.args.cache_dir,
                 trust_remote_code=True)
         if self.args.local_rank == 0:
diff --git a/flashmodels/patch/__init__.py b/flashmodels/patch/__init__.py
index baf734f..452a833 100644
--- a/flashmodels/patch/__init__.py
+++ b/flashmodels/patch/__init__.py
@@ -1,10 +1,5 @@
 from flashmodels.patch.patch import patch_gemma, patch_llama, patch_lora
 
 
-def patch_amp():
-    import torchacc as ta
-    ta.patch_amp()
-
-
 def patch_peft():
     patch_lora()
diff --git a/flashmodels/patch/patch.py b/flashmodels/patch/patch.py
index f334ffb..4d6d36f 100644
--- a/flashmodels/patch/patch.py
+++ b/flashmodels/patch/patch.py
@@ -5,13 +5,10 @@
 from typing import Any
 
 import torch
+import torchacc.utils.patch as patch
 import transformers
 
 from flashmodels.logger import logger
-from flashmodels.patch.llama_model import (LlamaAttention, LlamaDecoderLayer,
-                                           LlamaMLP, flash_attn_fwd,
-                                           flash_attn_prep_mask,
-                                           make_causal_mask)
 
 
 def rewrite_load():
@@ -34,12 +31,11 @@ def rewrite_load():
     exec(modified, transformers.modeling_utils.__dict__)
 
 
-def patch_llama():
-    transformers.models.llama.modeling_llama._make_causal_mask = make_causal_mask
-    if os.getenv("ACC_FLASH_ATTN", "0") == "1":
-        transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = flash_attn_prep_mask
-        transformers.models.llama.modeling_llama.LlamaAttention.forward = flash_attn_fwd
-    elif os.environ.get("ACC_LLAMA_TP") == "1":
+def patch_llama(use_flash_attn):
+    patch.patch_llama(use_flash_attn)
+    from flashmodels.patch.llama_model import (LlamaAttention,
+                                               LlamaDecoderLayer, LlamaMLP)
+    if os.environ.get("ACC_LLAMA_TP") == "1":
         transformers.models.llama.modeling_llama.LlamaMLP = LlamaMLP
         if os.getenv("XLA_USE_SPMD") == "1":
             # use einsum in linear for SPMD TP/Ulysses.
@@ -50,25 +46,10 @@ def patch_llama():
     if bool(int(os.environ.get("LOW_CPU_MEM_USAGE", "0"))):
         rewrite_load()
 
-    # Set the attention_mask in LlamaAttention to None to match the pattern of FlashAttentionRewriter.
-    def wrap_for_flash_attention(func):
-
-        def wrapper(*args, **kwargs):
-            kwargs["attention_mask"] = None
-            return func(*args, **kwargs)
-
-        return wrapper
-
-    # always attention_mask=None
-    transformers.models.llama.modeling_llama.LlamaAttention.forward = wrap_for_flash_attention(
-        transformers.models.llama.modeling_llama.LlamaAttention.
-        forward)
-
 
 def patch_gemma():
     # Set the attention_mask in GemmaAttention to None to match the pattern of FlashAttentionRewriter.
     def wrap_for_flash_attention(func):
-
         def wrapper(*args, **kwargs):
             kwargs["attention_mask"] = None
             return func(*args, **kwargs)
diff --git a/hf_models/config/llama-1b/config.json b/hf_models/config/llama-1b/config.json
index 04d6487..03476f1 100644
--- a/hf_models/config/llama-1b/config.json
+++ b/hf_models/config/llama-1b/config.json
@@ -1 +1 @@
-{"architectures": ["LLaMAForCausalLM"], "bos_token_id": 0, "eos_token_id": 1, "hidden_act": "silu", "hidden_size": 4096, "intermediate_size": 11008, "initializer_range": 0.02, "max_sequence_length": 2048, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 4, "pad_token_id": -1, "rms_norm_eps": 1e-06, "torch_dtype": "float16", "transformers_version": "4.27.0.dev0", "use_cache": true, "vocab_size": 32000}
\ No newline at end of file
+{"architectures": ["LLaMAForCausalLM"], "bos_token_id": 0, "eos_token_id": 1, "hidden_act": "silu", "hidden_size": 4096, "intermediate_size": 11008, "initializer_range": 0.02, "max_sequence_length": 2048, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 4, "pad_token_id": -1, "rms_norm_eps": 1e-06, "torch_dtype": "float16", "transformers_version": "4.27.0.dev0", "use_cache": true, "vocab_size": 32000}