model_utils.py 728 B

12345678910111213141516171819202122
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the GNU General Public License version 3.
  3. from peft import PeftModel
  4. from transformers import LlamaForCausalLM
# Load the base Llama causal language model used for text generation.
  6. def load_model(model_name, quantization):
  7. model = LlamaForCausalLM.from_pretrained(
  8. model_name,
  9. return_dict=True,
  10. load_in_8bit=quantization,
  11. device_map="auto",
  12. low_cpu_mem_usage=True,
  13. )
  14. return model
# Load pretrained PEFT weights on top of an already-loaded base model.
  16. def load_peft_model(model, peft_model):
  17. peft_model = PeftModel.from_pretrained(model, peft_model)
  18. return peft_model