@@ -94,6 +94,7 @@ def main(**kwargs):
                 load_in_8bit=True if train_config.quantization else None,
                 device_map="auto" if train_config.quantization else None,
                 use_cache=use_cache,
+                attn_implementation="sdpa" if train_config.use_fast_kernels else None,
             )
         else:
             llama_config = LlamaConfig.from_pretrained(train_config.model_name)
@@ -107,18 +108,8 @@ def main(**kwargs):
             load_in_8bit=True if train_config.quantization else None,
             device_map="auto" if train_config.quantization else None,
             use_cache=use_cache,
+            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
         )
-    if train_config.enable_fsdp and train_config.use_fast_kernels:
-        """
-        For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
-        using of Flash Attention or Xformer memory-efficient kernels
-        based on the hardware being used. This would speed up fine-tuning.
-        """
-        try:
-            from optimum.bettertransformer import BetterTransformer
-            model = BetterTransformer.transform(model)
-        except ImportError:
-            print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")

     # Load the tokenizer and add special tokens
     tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
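
# Illustrative sketch (not part of the patch): after this change, the fast-kernel
# path is selected by passing attn_implementation="sdpa" directly to
# from_pretrained(), so the separate BetterTransformer.transform() call from
# `optimum` is no longer needed. The checkpoint name below is an assumed example.
from transformers import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # assumed example checkpoint
    use_cache=False,
    attn_implementation="sdpa",  # PyTorch scaled-dot-product-attention kernels
)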