
update the fine-tuning script

Hamid Shojanazeri 1 year ago
commit 9a2434f408
1 changed file with 2 additions and 11 deletions

src/llama_recipes/finetuning.py  +2 -11

@@ -94,6 +94,7 @@ def main(**kwargs):
                 load_in_8bit=True if train_config.quantization else None,
                 device_map="auto" if train_config.quantization else None,
                 use_cache=use_cache,
+                attn_implementation="sdpa" if train_config.use_fast_kernels else None,
             )
         else:
             llama_config = LlamaConfig.from_pretrained(train_config.model_name)
@@ -107,18 +108,8 @@ def main(**kwargs):
             load_in_8bit=True if train_config.quantization else None,
             device_map="auto" if train_config.quantization else None,
             use_cache=use_cache,
+            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
         )
-    if train_config.enable_fsdp and train_config.use_fast_kernels:
-        """
-        For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
-        using of Flash Attention or Xformer memory-efficient kernels
-        based on the hardware being used. This would speed up fine-tuning.
-        """
-        try:
-            from optimum.bettertransformer import BetterTransformer
-            model = BetterTransformer.transform(model)
-        except ImportError:
-            print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
 
     # Load the tokenizer and add special tokens
     tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
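
Note (editor's addition, not part of the commit): a minimal sketch of the new code path, assuming transformers >= 4.36, where from_pretrained accepts attn_implementation directly. The load_model helper and the model id below are placeholders rather than code from this repo; "sdpa" routes attention through PyTorch's scaled_dot_product_attention, which covers the flash / memory-efficient kernels the removed BetterTransformer block used to provide.

    from transformers import LlamaForCausalLM

    def load_model(model_name: str, use_fast_kernels: bool = True):
        # Placeholder helper mirroring the pattern in finetuning.py:
        # "sdpa" dispatches to flash / memory-efficient attention kernels
        # when the hardware supports them; None keeps the library default.
        return LlamaForCausalLM.from_pretrained(
            model_name,
            use_cache=False,
            attn_implementation="sdpa" if use_fast_kernels else None,
        )

    model = load_model("meta-llama/Llama-2-7b-hf")  # placeholder model id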