# merge_lora_weights.py
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import fire
  4. import torch
  5. from peft import PeftModel
  6. from transformers import LlamaForCausalLM, LlamaTokenizer
  7. def main(base_model: str,
  8. peft_model: str,
  9. output_dir: str):
  10. model = LlamaForCausalLM.from_pretrained(
  11. base_model,
  12. load_in_8bit=False,
  13. torch_dtype=torch.float16,
  14. device_map="auto",
  15. offload_folder="tmp",
  16. )
  17. tokenizer = LlamaTokenizer.from_pretrained(
  18. base_model
  19. )
  20. model = PeftModel.from_pretrained(
  21. model,
  22. peft_model,
  23. torch_dtype=torch.float16,
  24. device_map="auto",
  25. offload_folder="tmp",
  26. )
  27. model = model.merge_and_unload()
  28. model.save_pretrained(output_dir)
  29. tokenizer.save_pretrained(output_dir)
  30. if __name__ == "__main__":
  31. fire.Fire(main)