model_utils.py 988 B

123456789101112131415161718192021222324252627282930
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the GNU General Public License version 3.
  3. from peft import PeftModel
  4. from transformers import LlamaForCausalLM, LlamaConfig
  5. # Function to load the main model for text generation
  6. def load_model(model_name, quantization):
  7. model = LlamaForCausalLM.from_pretrained(
  8. model_name,
  9. return_dict=True,
  10. load_in_8bit=quantization,
  11. device_map="auto",
  12. low_cpu_mem_usage=True,
  13. )
  14. return model
  15. # Function to load the PeftModel for performance optimization
  16. def load_peft_model(model, peft_model):
  17. peft_model = PeftModel.from_pretrained(model, peft_model)
  18. return peft_model
  19. # Loading the model from config to load FSDP checkpoints into that
  20. def load_llama_from_config(config_path):
  21. model_config = LlamaConfig.from_pretrained(config_path)
  22. model = LlamaForCausalLM(config=model_config)
  23. return model