
adding flash attention and xformer memory efficient through PT SDPA (#97)

Geeta Chauhan, 1 year ago
Commit 3f1fef7a00

10 files changed, 120 insertions(+), 17 deletions(-)
  1. README.md (+9 -1)
  2. configs/training.py (+1 -0)
  3. docs/inference.md (+12 -0)
  4. docs/multi_gpu.md (+8 -1)
  5. inference/chat_completion.py (+13 -0)
  6. inference/inference.py (+18 -5)
  7. llama_finetuning.py (+11 -2)
  8. requirements.txt (+1 -1)
  9. scripts/spellcheck_conf/wordlist.txt (+30 -1)
  10. utils/train_utils.py (+17 -6)

+ 9 - 1
README.md

@@ -107,13 +107,21 @@ torchrun --nnodes 1 --nproc_per_node 4  llama_finetuning.py --enable_fsdp --use_
 
 Here we use FSDP as discussed in the next section which can be used along with PEFT methods. To make use of PEFT methods with FSDP make sure to pass `use_peft` and `peft_method` args along with `enable_fsdp`. Here we are using `BF16` for training.
 
+## Flash Attention and Xformer Memory Efficient Kernels
+
+Setting `use_fast_kernels` will enable the use of Flash Attention or Xformer memory-efficient kernels, depending on the hardware being used. This speeds up the fine-tuning job. The feature is exposed as a one-liner API in the Hugging Face `optimum` library; read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/).
+
+```bash
+torchrun --nnodes 1 --nproc_per_node 4  llama_finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --pure_bf16 --output_dir Path/to/save/PEFT/model --use_fast_kernels
+```
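Under the hood, the flag maps to the `optimum` BetterTransformer API shown later in this diff. A minimal sketch of the transformation on its own (the model path and dtype below are illustrative placeholders):

```python
# Minimal sketch of what --use_fast_kernels enables; path and dtype are placeholders.
import torch
from transformers import LlamaForCausalLM
from optimum.bettertransformer import BetterTransformer

model = LlamaForCausalLM.from_pretrained(
    "path/to/model/7B",            # illustrative model path
    torch_dtype=torch.bfloat16,
)
# One-liner: route attention through PyTorch SDPA, which dispatches to Flash
# Attention or the Xformers-style memory-efficient kernel based on the hardware.
model = BetterTransformer.transform(model)
```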
+
 ### Fine-tuning using FSDP Only
 
 If you are interested in running full parameter fine-tuning without making use of PEFT methods, please use the following command. Make sure to change the `nproc_per_node` to your available GPUs. This has been tested with `BF16` on 8xA100, 40GB GPUs.
 
 ```bash
 
-torchrun --nnodes 1 --nproc_per_node 8  llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
+torchrun --nnodes 1 --nproc_per_node 8  llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --use_fast_kernels
 
 ```
 

+ 1 - 0
configs/training.py

@@ -32,6 +32,7 @@ class train_config:
     dist_checkpoint_root_folder: str="PATH/to/save/FSDP/model" # will be used if using FSDP
     dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
     save_optimizer: bool=False # will be used if using FSDP
+    use_fast_kernels: bool = False # Enable using SDPA from PyTorch Accelerated Transformers, which makes use of Flash Attention and Xformer memory-efficient kernels
 
     
     

+ 12 - 0
docs/inference.md

@@ -34,6 +34,18 @@ The inference folder also includes a chat completion example, that adds built-in
 python inference/chat_completion.py --model_name "PATH/TO/MODEL/7B/" --prompt_file inference/chats.json  --quantization --use_auditnlg
 
 ```
+
+## Flash Attention and Xformer Memory Efficient Kernels
+
+Setting `use_fast_kernels` will enable the use of Flash Attention or Xformer memory-efficient kernels, depending on the hardware being used. This speeds up inference when used with batched inputs. The feature is exposed as a one-liner API in the Hugging Face `optimum` library; read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/).
+
+```bash
+python inference/chat_completion.py --model_name "PATH/TO/MODEL/7B/" --prompt_file inference/chats.json  --quantization --use_auditnlg --use_fast_kernels
+
+python inference/inference.py --model_name <training_config.output_dir> --peft_model <training_config.output_dir> --prompt_file <test_prompt_file> --use_auditnlg --use_fast_kernels
+
+```
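Since the speed-up mainly shows up on batched inputs, here is a minimal, hedged sketch of batched generation once the model has been transformed; the prompts, padding choice, and generation settings are illustrative and not part of the scripts above:

```python
# Illustrative batched generation; assumes `model` has been passed through
# BetterTransformer.transform and `tokenizer` is the loaded LlamaTokenizer.
import torch

prompts = ["Summarize the following text: ...", "Write a haiku about GPUs."]
tokenizer.pad_token = tokenizer.eos_token        # Llama ships without a pad token
batch = tokenizer(prompts, padding=True, return_tensors="pt").to("cuda")

with torch.no_grad():
    outputs = model.generate(**batch, max_new_tokens=100)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```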
+
 ## Loading back FSDP checkpoints
 
 In case you have fine-tuned your model with pure FSDP and saved the checkpoints with "SHARDED_STATE_DICT" as shown [here](../configs/fsdp.py), you can use this converter script to convert the FSDP Sharded checkpoints into HuggingFace checkpoints. This enables you to use the inference script normally as mentioned above.

+ 8 - 1
docs/multi_gpu.md

@@ -44,6 +44,13 @@ The args used in the command above are:
 
 We use `torchrun` here to spawn multiple processes for FSDP.
 
+## Flash Attention and Xformer Memory Efficient Kernels
+
+Setting `use_fast_kernels` will enable the use of Flash Attention or Xformer memory-efficient kernels, depending on the hardware being used. This speeds up the fine-tuning job. The feature is exposed as a one-liner API in the Hugging Face `optimum` library; read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/).
+
+```bash
+torchrun --nnodes 1 --nproc_per_node 4  ../llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model --use_fast_kernels
+```
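Which kernel is actually used depends on the GPUs in the job. As a quick sanity check, the standard `torch.backends.cuda` queries (PyTorch 2.0+, not specific to this repo) report which SDPA backends are enabled on a given machine:

```python
# Report which scaled_dot_product_attention backends PyTorch has enabled
# on this machine (PyTorch 2.0+). Purely informational.
import torch

print("Flash attention kernel enabled:  ", torch.backends.cuda.flash_sdp_enabled())
print("Memory-efficient kernel enabled: ", torch.backends.cuda.mem_efficient_sdp_enabled())
print("Math (fallback) kernel enabled:  ", torch.backends.cuda.math_sdp_enabled())
```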
 
 ### Fine-tuning using FSDP Only
 
@@ -51,7 +58,7 @@ If interested in running full parameter finetuning without making use of PEFT me
 
 ```bash
 
-torchrun --nnodes 1 --nproc_per_node 8  llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --pure_bf16 
+torchrun --nnodes 1 --nproc_per_node 8  llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --pure_bf16 --use_fast_kernels
 
 ```
 

+ 13 - 0
inference/chat_completion.py

@@ -34,6 +34,7 @@ def main(
     enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
     enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
     enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5
+    use_fast_kernels: bool = False, # Enable using SDPA from PyTorch Accelerated Transformers, which makes use of Flash Attention and Xformer memory-efficient kernels
     **kwargs
 ):
     if prompt_file is not None:
@@ -59,6 +60,18 @@ def main(
     model = load_model(model_name, quantization)
     if peft_model:
         model = load_peft_model(model, peft_model)
+    if use_fast_kernels:
+        """
+        Setting 'use_fast_kernels' will enable
+        the use of Flash Attention or Xformer memory-efficient kernels
+        based on the hardware being used. This speeds up inference when used with batched inputs.
+        """
+        try:
+            from optimum.bettertransformer import BetterTransformer
+            model = BetterTransformer.transform(model)   
+        except ImportError:
+            print("Module 'optimum' not found. Please install 'optimum' before proceeding.")
+
     tokenizer = LlamaTokenizer.from_pretrained(model_name)
     tokenizer.add_special_tokens(
         {

+ 18 - 5
inference/inference.py

@@ -32,6 +32,7 @@ def main(
     enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
     enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
     enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5
+    use_fast_kernels: bool = False, # Enable using SDPA from PyTorch Accelerated Transformers, which makes use of Flash Attention and Xformer memory-efficient kernels
     **kwargs
 ):
     if prompt_file is not None:
@@ -51,6 +52,23 @@ def main(
     torch.manual_seed(seed)
     
     model = load_model(model_name, quantization)
+    if peft_model:
+        model = load_peft_model(model, peft_model)
+
+    model.eval()
+    
+    if use_fast_kernels:
+        """
+        Setting 'use_fast_kernels' will enable
+        the use of Flash Attention or Xformer memory-efficient kernels
+        based on the hardware being used. This speeds up inference when used with batched inputs.
+        """
+        try:
+            from optimum.bettertransformer import BetterTransformer
+            model = BetterTransformer.transform(model)    
+        except ImportError:
+            print("Module 'optimum' not found. Please install 'optimum' before proceeding.")
+
     tokenizer = LlamaTokenizer.from_pretrained(model_name)
     tokenizer.add_special_tokens(
         {
@@ -79,11 +97,6 @@ def main(
         print("Skipping the inferece as the prompt is not safe.")
         sys.exit(1)  # Exit the program with an error status
 
-    if peft_model:
-        model = load_peft_model(model, peft_model)
-
-    model.eval()
-
     batch = tokenizer(user_prompt, return_tensors="pt")
     batch = {k: v.to("cuda") for k, v in batch.items()}
     start = time.perf_counter()

+ 11 - 2
llama_finetuning.py

@@ -24,7 +24,6 @@ from transformers import (
     BitsAndBytesConfig
 )
 import torch.distributed as dist
-
 # Unused imports removed
 from utils.train_utils import (
     set_tokenizer_params,
@@ -95,7 +94,17 @@ def main(**kwargs):
         load_in_8bit=True if train_config.quantization else None,
         device_map="auto" if train_config.quantization else None,
     )
-    
+    if train_config.enable_fsdp and train_config.use_fast_kernels:
+        """
+        For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
+        the use of Flash Attention or Xformer memory-efficient kernels
+        based on the hardware being used. This speeds up fine-tuning.
+        """
+        try:
+            from optimum.bettertransformer import BetterTransformer
+            model = BetterTransformer.transform(model) 
+        except ImportError:
+            print("Module 'optimum' not found. Please install 'optimum' before proceeding.")
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
     
     # Prepare the model for int8 training if quantization is enabled
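Note that the transform is applied to the plain Hugging Face model right after `from_pretrained`, i.e. before the script wraps it with FSDP further down in `main`. A hedged sketch of that ordering (the FSDP call below is a trimmed placeholder; the real script passes its own auto-wrap policy, mixed-precision, and sharding settings):

```python
# Ordering sketch only: swap attention kernels first, shard the result afterwards.
# FSDP arguments are omitted placeholders, not the repo's actual wrapping config.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from optimum.bettertransformer import BetterTransformer

model = BetterTransformer.transform(model)   # SDPA-backed attention modules
model = FSDP(model)                          # wrap the already-transformed model
```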

+ 1 - 1
requirements.txt

@@ -13,4 +13,4 @@ transformers>=4.31.0
 sentencepiece
 py7zr
 scipy
-
+optimum

+ 30 - 1
scripts/spellcheck_conf/wordlist.txt

@@ -1090,5 +1090,34 @@ intra
 nightlies
 recenly
 uncomment
+BFloat
+DDP
+LLM
+Xformer
+accuracies
+activations
+anyprecision
+aplaca
+assembels
+boolean
+checkpoining
+defatults
+gradinets
+itermediate
+recommond
+scaler
+sharding
+slurm
+summarization
+theJfleg
+xA
+Jupyter
+LLM
+Xformer
+dataset's
+jupyter
+mutli
+summarization
+xA
 Sanitization
-tokenization
+tokenization

+ 17 - 6
utils/train_utils.py

@@ -5,6 +5,7 @@ import os
 import sys
 from typing import List
 import yaml
+import time
 
 import fire
 import torch
@@ -73,9 +74,12 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     train_loss = []
     val_prep = []
     val_loss =[]
+    epoch_times = []
+    checkpoint_times = []
     results = {}
     best_val_loss = float("inf")
     for epoch in range(train_config.num_epochs):
+        epoch_start_time = time.perf_counter()
         with MemoryTrace() as memtrace:  # track the memory usage
             model.train()
             total_loss = 0.0
@@ -106,7 +110,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         print(f"\n step {step} is completed and loss is {loss.detach().float()}")
                 else:
                     print(f"\n step {step} is completed and loss is {loss.detach().float()}")
-                    
+        epoch_end_time = time.perf_counter()-epoch_start_time
+        epoch_times.append(epoch_end_time)    
         # Reducing total_loss across all devices if there's more than one CUDA device
         if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
@@ -117,6 +122,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         
         train_prep.append(train_perplexity)
         train_loss.append(train_epoch_loss)
+        
         if train_config.enable_fsdp:
             if rank==0:
                 print(f"Max CUDA memory allocated was {memtrace.peak} GB")
@@ -136,6 +142,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
           
         if train_config.run_validation:
             eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, rank, tokenizer)   
+            checkpoint_start_time = time.perf_counter()
             if train_config.save_model and eval_epoch_loss < best_val_loss:
                 if train_config.enable_fsdp:
                     dist.barrier()
@@ -176,7 +183,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         print("=====================================================")                     
                 if train_config.enable_fsdp:
                     dist.barrier()
-            
+            checkpoint_end_time = time.perf_counter() - checkpoint_start_time
+            checkpoint_times.append(checkpoint_end_time)
             if eval_epoch_loss < best_val_loss:
                 best_val_loss = eval_epoch_loss
                 if train_config.enable_fsdp:
@@ -189,10 +197,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         
         if train_config.enable_fsdp:
             if rank==0:
-                print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
+                print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
         else:
-            print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
-            
+            print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
+    avg_epoch_time = sum(epoch_times) / len(epoch_times)
+    avg_checkpoint_time = sum(checkpoint_times) / len(checkpoint_times) if checkpoint_times else 0.0
     avg_train_prep = sum(train_prep)/len(train_prep)
     avg_train_loss = sum(train_loss)/len(train_loss)
     if train_config.run_validation:
@@ -204,7 +213,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     if train_config.run_validation:
         results['avg_eval_prep'] = avg_eval_prep
         results['avg_eval_loss'] = avg_eval_loss
-        
+    results["avg_epoch_time"] = avg_epoch_time
+    results["avg_checkpoint_time"] = avg_checkpoint_time
+    
     #saving the training params including fsdp setting for reference.
     if train_config.enable_fsdp and not train_config.use_peft:
         save_train_params(train_config, fsdp_config, rank)
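
The `train_utils.py` changes follow a single pattern: wall-clock each epoch and each checkpoint save with `time.perf_counter()`, collect the durations, and report the averages in `results`. A stripped-down sketch of that pattern (the loop bodies are elided, and the guard for an empty checkpoint list is an assumption for the no-validation case):

```python
# Stripped-down version of the timing bookkeeping added to train();
# the actual training, evaluation, and checkpointing work is elided.
import time

num_epochs, run_validation = 3, True          # placeholders for train_config fields
epoch_times, checkpoint_times = [], []

for epoch in range(num_epochs):
    epoch_start = time.perf_counter()
    # ... one epoch of training ...
    epoch_times.append(time.perf_counter() - epoch_start)

    if run_validation:
        ckpt_start = time.perf_counter()
        # ... evaluate and possibly save a checkpoint ...
        checkpoint_times.append(time.perf_counter() - ckpt_start)

results = {
    "avg_epoch_time": sum(epoch_times) / len(epoch_times),
    # assumption: guard against division by zero when validation is disabled
    "avg_checkpoint_time": sum(checkpoint_times) / len(checkpoint_times) if checkpoint_times else 0.0,
}
print(results)
```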