
fixed arg names

Hamid Shojanazeri, 1 year ago
parent commit 9e3b1b7f01
1 changed file with 9 additions and 9 deletions

inference/checkpoint_converter_fsdp_hf.py (+9 -9)

@@ -20,21 +20,21 @@ sys.path.append(parent_directory)
 from model_checkpointing import load_sharded_model_single_gpu
 
 def main(
-    model_name,
-    save_dir="", # Path to save the HF converted model checkpoints
-    model_path="" # Path/ name of the HF model that include config.json and tokenizer_config.json
+    fsdp_checkpoint_path="", # Path to FSDP Sharded model checkpoints
+    consolidated_model_path="", # Path to save the HF converted model checkpoints
+    HF_model_path="" # Path/name of the HF model that includes config.json and tokenizer_config.json (e.g. meta-llama/Llama-2-7b-chat-hf)
     ):
     #load the HF model definition from config
-    model_def = load_llama_from_config(model_path)
+    model_def = load_llama_from_config(HF_model_path)
     print("model is loaded from config")
     #load the FSDP sharded checkpoints into the model
-    model = load_sharded_model_single_gpu(model_def, model_name)
+    model = load_sharded_model_single_gpu(model_def, fsdp_checkpoint_path)
     print("model is loaded from FSDP checkpoints")
     #load the tokenizer from the HF_model_path
-    tokenizer = LlamaTokenizer.from_pretrained(model_path)
-    tokenizer.save_pretrained(save_dir)
+    tokenizer = LlamaTokenizer.from_pretrained(HF_model_path)
+    tokenizer.save_pretrained(consolidated_model_path)
     #save the FSDP sharded checkpoints in HF format
-    model.save_pretrained(save_dir)
-    print(f"HuggingFace model checkpoints has been saved in {save_dir}")
+    model.save_pretrained(consolidated_model_path)
+    print(f"HuggingFace model checkpoints has been saved in {consolidated_model_path}")
 if __name__ == "__main__":
     fire.Fire(main)
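
For reference, a minimal usage sketch with the renamed arguments; the paths below are placeholders, not values from this commit. Because the script wraps main with fire.Fire, the same keyword arguments are also exposed as command-line flags.

    # Usage sketch (placeholder paths; assumes the script is run from the
    # inference/ directory so the module imports directly).
    # CLI equivalent via fire:
    #   python checkpoint_converter_fsdp_hf.py \
    #       --fsdp_checkpoint_path PATH/to/fsdp_checkpoints \
    #       --consolidated_model_path PATH/to/hf_output \
    #       --HF_model_path meta-llama/Llama-2-7b-chat-hf
    from checkpoint_converter_fsdp_hf import main

    main(
        fsdp_checkpoint_path="PATH/to/fsdp_checkpoints",    # FSDP sharded checkpoints
        consolidated_model_path="PATH/to/hf_output",        # where tokenizer and model are saved
        HF_model_path="meta-llama/Llama-2-7b-chat-hf",      # HF model providing config and tokenizer
    )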