|
@@ -2,16 +2,18 @@
|
|
|
# This software may be used and distributed according to the terms of the GNU General Public License version 3.
|
|
|
|
|
|
from peft import PeftModel
|
|
|
-from transformers import LlamaForCausalLM, LlamaConfig
|
|
|
+from transformers import AutoModelForCausalLM, LlamaForCausalLM, LlamaConfig
|
|
|
|
|
|
# Function to load the main model for text generation
|
|
|
-def load_model(model_name, quantization):
|
|
|
- model = LlamaForCausalLM.from_pretrained(
|
|
|
+def load_model(model_name, quantization, use_fast_kernels):
|
|
|
+    print(f"use_fast_kernels {use_fast_kernels}")
|
|
|
+ model = AutoModelForCausalLM.from_pretrained(
|
|
|
model_name,
|
|
|
return_dict=True,
|
|
|
load_in_8bit=quantization,
|
|
|
device_map="auto",
|
|
|
low_cpu_mem_usage=True,
|
|
|
+            attn_implementation="sdpa" if use_fast_kernels else None,
|
|
|
)
|
|
|
return model
|
|
|
|