@@ -20,21 +20,21 @@ sys.path.append(parent_directory)
 from model_checkpointing import load_sharded_model_single_gpu
 
 def main(
-    model_name,
-    save_dir="", # Path to save the HF converted model checkpoints
-    model_path="" # Path/ name of the HF model that include config.json and tokenizer_config.json
+    fsdp_checkpoint_path="", # Path to FSDP Sharded model checkpoints
+    consolidated_model_path="", # Path to save the HF converted model checkpoints
+    HF_model_path="" # Path/ name of the HF model that include config.json and tokenizer_config.json (e.g. meta-llama/Llama-2-7b-chat-hf)
     ):
     
     #load the HF model definition from config
-    model_def = load_llama_from_config(model_path)
+    model_def = load_llama_from_config(HF_model_path)
     print("model is loaded from config")
     #load the FSDP sharded checkpoints into the model
-    model = load_sharded_model_single_gpu(model_def, model_name)
+    model = load_sharded_model_single_gpu(model_def, fsdp_checkpoint_path)
     print("model is loaded from FSDP checkpoints")
     #loading the tokenizer form the model_path
-    tokenizer = LlamaTokenizer.from_pretrained(model_path)
-    tokenizer.save_pretrained(save_dir)
+    tokenizer = LlamaTokenizer.from_pretrained(HF_model_path)
+    tokenizer.save_pretrained(consolidated_model_path)
     #save the FSDP sharded checkpoints in HF format
-    model.save_pretrained(save_dir)
-    print(f"HuggingFace model checkpoints has been saved in {save_dir}")
+    model.save_pretrained(consolidated_model_path)
+    print(f"HuggingFace model checkpoints has been saved in {consolidated_model_path}")
 if __name__ == "__main__":
     fire.Fire(main)