
Update model loading to use native flash attention

Hamid Shojanazeri committed 1 year ago
commit db8af96ff0
1 changed file with 5 additions and 3 deletions

src/llama_recipes/inference/model_utils.py (+5 −3)

@@ -2,16 +2,18 @@
 # This software may be used and distributed according to the terms of the GNU General Public License version 3.
 
 from peft import PeftModel
-from transformers import LlamaForCausalLM, LlamaConfig
+from transformers import AutoModelForCausalLM, LlamaForCausalLM, LlamaConfig
 
 # Function to load the main model for text generation
-def load_model(model_name, quantization):
-    model = LlamaForCausalLM.from_pretrained(
+def load_model(model_name, quantization, use_fast_kernels):
+    print(f"use_fast_kernels: {use_fast_kernels}")
+    model = AutoModelForCausalLM.from_pretrained(
         model_name,
         return_dict=True,
         load_in_8bit=quantization,
         device_map="auto",
         low_cpu_mem_usage=True,
+        attn_implementation="sdpa" if use_fast_kernels else None,
     )
     return model
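
For reference, a minimal usage sketch of the updated signature (not part of this commit; the checkpoint name is illustrative, and passing "sdpa" assumes a transformers release that supports the attn_implementation argument):

    from llama_recipes.inference.model_utils import load_model

    # use_fast_kernels=True routes attention through PyTorch's native
    # scaled-dot-product attention (SDPA), which dispatches to flash
    # kernels when the hardware and dtype allow it.
    model = load_model(
        model_name="meta-llama/Llama-2-7b-hf",  # illustrative checkpoint
        quantization=False,                     # skip 8-bit loading
        use_fast_kernels=True,
    )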