
update with the HF native flash attention

Hamid Shojanazeri committed 1 year ago
commit 8776ceb833
1 changed file with 1 addition and 12 deletions

examples/inference.py (+1, -12)

@@ -79,23 +79,12 @@ def main(
         torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
     
-    model = load_model(model_name, quantization)
+    model = load_model(model_name, quantization, use_fast_kernels)
     if peft_model:
         model = load_peft_model(model, peft_model)
 
     model.eval()
     
-    if use_fast_kernels:
-        """
-        Setting 'use_fast_kernels' will enable
-        using of Flash Attention or Xformer memory-efficient kernels 
-        based on the hardware being used. This would speed up inference when used for batched inputs.
-        """
-        try:
-            from optimum.bettertransformer import BetterTransformer
-            model = BetterTransformer.transform(model)    
-        except ImportError:
-            print("Module 'optimum' not found. Please install 'optimum' before proceeding.")
 
     tokenizer = LlamaTokenizer.from_pretrained(model_name)
     tokenizer.pad_token = tokenizer.eos_token
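For context, a minimal sketch of what load_model could look like after this change, with use_fast_kernels forwarded to Hugging Face's native attention-kernel selection instead of wrapping the model with BetterTransformer. The attn_implementation choice and the other from_pretrained arguments here are assumptions; the actual load_model is defined outside this diff.

    # Sketch only: assumes load_model maps use_fast_kernels onto
    # transformers' native attn_implementation argument.
    from transformers import LlamaForCausalLM

    def load_model(model_name, quantization, use_fast_kernels=False):
        model = LlamaForCausalLM.from_pretrained(
            model_name,
            return_dict=True,
            load_in_8bit=quantization,
            device_map="auto",
            low_cpu_mem_usage=True,
            # PyTorch SDPA dispatches to Flash Attention or memory-efficient
            # kernels when the hardware supports them; otherwise fall back
            # to the default eager implementation.
            attn_implementation="sdpa" if use_fast_kernels else "eager",
        )
        return model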