Browse Source

Fsdp inference checkpoints (#39)

Geeta Chauhan 1 year ago
parent
commit
0cd5694a14

+ 25 - 0
docs/inference.md

@@ -34,6 +34,31 @@ The inference folder also includes a chat completion example, that adds built-in
 python inference/chat_completion.py --model_name "PATH/TO/MODEL/7B/" --prompt_file inference/chats.json  --quantization --use_auditnlg
 
 ```
+## Loading back FSDP checkpoints
+
+In case you have fine-tuned your model with pure FSDP and saved the checkpoints with "SHARDED_STATE_DICT" as shown [here](../configs/fsdp.py), you can use this converter script to convert the FSDP Sharded checkpoints into HuggingFace checkpoints. This enables you to use the inference script normally as mentioned above.
+**To convert the checkpoint use the following command**:
+
+This is helpful if you have fine-tuned you model using FSDP only as follows:
+
+```bash
+torchrun --nnodes 1 --nproc_per_node 8  llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --pure_bf16 
+```
+Then convert your FSDP checkpoint to HuggingFace checkpoints using:
+```bash
+ python inference/checkpoint_converter_fsdp_hf.py --fsdp_checkpoint_path  PATH/to/FSDP/Checkpoints --consolidated_model_path PATH/to/save/checkpoints --HF_model_path_or_name PATH/or/HF/model_name
+
+ # --HF_model_path_or_name specifies the HF Llama model name or path where it has config.json and tokenizer.json
+ ```
+By default, training parameter are saved in `train_params.yaml` in the path where FSDP checkpoints are saved, in the converter script we frist try to find the HugingFace model name used in the fine-tuning to load the model with configs from there, if not found user need to provide it.
+
+Then run inference using:
+
+```bash
+python inference/inference.py --model_name <training_config.output_dir> --prompt_file <test_prompt_file> 
+
+```
+
 
 ## Other Inference Options
 

+ 1 - 1
inference/README.md

@@ -2,7 +2,7 @@
 
 This folder contains inference examples for Llama 2. So far, we have provided support for three methods of inference:
 
-1. [inference script](inference.py) script provides support for Hugging Face accelerate and PEFT fine tuned models.
+1. [inference script](inference.py) script provides support for Hugging Face accelerate, PEFT and FSDP fine tuned models.
 
 2. [vLLM_inference.py](vLLM_inference.py) script takes advantage of vLLM's paged attention concept for low latency.
 

+ 63 - 0
inference/checkpoint_converter_fsdp_hf.py

@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+# from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+
+import fire
+import torch
+import os
+import sys
+import yaml
+from transformers import LlamaTokenizer
+from model_utils import  load_llama_from_config
+# Get the current file's directory
+current_directory = os.path.dirname(os.path.abspath(__file__))
+
+# Get the parent directory
+parent_directory = os.path.dirname(current_directory)
+
+# Append the parent directory to sys.path
+sys.path.append(parent_directory)
+from model_checkpointing import load_sharded_model_single_gpu
+
+def main(
+    fsdp_checkpoint_path="", # Path to FSDP Sharded model checkpoints
+    consolidated_model_path="", # Path to save the HF converted model checkpoints
+    HF_model_path_or_name="" # Path/ name of the HF model that include config.json and tokenizer_config.json (e.g. meta-llama/Llama-2-7b-chat-hf)
+    ):
+    
+    try:
+        file_name = 'train_params.yaml'
+        # Combine the directory and file name to create the full path
+        train_params_path = os.path.join(fsdp_checkpoint_path, file_name)
+        # Open the file
+        with open(train_params_path, 'r') as file:
+            # Load the YAML data
+            data = yaml.safe_load(file)
+
+            # Access the 'model_name' field
+            HF_model_path_or_name = data.get('model_name')
+
+            print(f"Model name: {HF_model_path_or_name}")
+    except FileNotFoundError:
+        print(f"The file {train_params_path} does not exist.")
+        HF_model_path_or_name = input("Please enter the model name: ")
+        print(f"Model name: {HF_model_path_or_name}")
+    except Exception as e:
+        print(f"An error occurred: {e}")
+        
+        
+    #load the HF model definition from config
+    model_def = load_llama_from_config(HF_model_path_or_name)
+    print("model is loaded from config")
+    #load the FSDP sharded checkpoints into the model
+    model = load_sharded_model_single_gpu(model_def, fsdp_checkpoint_path)
+    print("model is loaded from FSDP checkpoints")
+    #loading the tokenizer form the  model_path
+    tokenizer = LlamaTokenizer.from_pretrained(HF_model_path_or_name)
+    tokenizer.save_pretrained(consolidated_model_path)
+    #save the FSDP sharded checkpoints in HF format
+    model.save_pretrained(consolidated_model_path)
+    print(f"HuggingFace model checkpoints has been saved in {consolidated_model_path}")
+if __name__ == "__main__":
+    fire.Fire(main)

+ 1 - 2
inference/inference.py

@@ -12,8 +12,7 @@ from typing import List
 
 from transformers import LlamaTokenizer
 from safety_utils import get_safety_checker
-from model_utils import load_model, load_peft_model
-
+from model_utils import load_model, load_peft_model, load_llama_from_config
 
 def main(
     model_name,

+ 10 - 2
inference/model_utils.py

@@ -2,7 +2,7 @@
 # This software may be used and distributed according to the terms of the GNU General Public License version 3.
 
 from peft import PeftModel
-from transformers import LlamaForCausalLM
+from transformers import LlamaForCausalLM, LlamaConfig
 
 # Function to load the main model for text generation
 def load_model(model_name, quantization):
@@ -19,4 +19,12 @@ def load_model(model_name, quantization):
 # Function to load the PeftModel for performance optimization
 def load_peft_model(model, peft_model):
     peft_model = PeftModel.from_pretrained(model, peft_model)
-    return peft_model
+    return peft_model
+
+# Loading the model from config to load FSDP checkpoints into that
+def load_llama_from_config(config_path):
+    model_config = LlamaConfig.from_pretrained(config_path) 
+    model = LlamaForCausalLM(config=model_config)
+    return model
+    
+    

+ 1 - 0
model_checkpointing/__init__.py

@@ -10,4 +10,5 @@ from .checkpoint_handler import (
     save_optimizer_checkpoint,
     save_model_and_optimizer_sharded,
     load_model_sharded,
+    load_sharded_model_single_gpu
 )

+ 28 - 0
model_checkpointing/checkpoint_handler.py

@@ -304,3 +304,31 @@ def save_distributed_model_checkpoint(model, rank, cfg, epoch=1):
         save_state_dict(state_dict, writer)
 
         return
+
+def load_sharded_model_single_gpu(model, model_path):
+    
+    dcp.load_state_dict(
+                    state_dict=state_dict_to_load_to,
+                    storage_reader=FsspecReader(path),
+                    no_dist=True,
+                )
+    print(f"Sharded state checkpoint loaded from {load_dir}")
+    
+def load_sharded_model_single_gpu(model,model_path):
+    
+    reader = FileSystemReader(model_path)
+    
+    state_dict = {
+        "model": model.state_dict()
+    }
+    
+    dist_cp.load_state_dict(
+                state_dict=state_dict,
+                storage_reader= FileSystemReader(model_path),
+                no_dist=True,
+            )
+    
+    model.load_state_dict(state_dict["model"])
+    
+    print(f"Sharded state checkpoint loaded from {model_path}")
+    return model

+ 12 - 1
scripts/spellcheck_conf/wordlist.txt

@@ -1067,4 +1067,15 @@ chatGPT
 Llama
 PEFT
 LORA
-FSDP
+FSDP
+AuditNLG
+finetune
+fsdp
+ineference
+lora
+peft
+samsum
+vLLM
+TGI
+vLLM
+vLLM's

+ 44 - 1
utils/train_utils.py

@@ -4,6 +4,7 @@
 import os
 import sys
 from typing import List
+import yaml
 
 import fire
 import torch
@@ -173,7 +174,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         results['avg_eval_prep'] = avg_eval_prep
         results['avg_eval_loss'] = avg_eval_loss
         
-
+    #saving the training params including fsdp setting for reference.
+    if train_config.enable_fsdp and fsdp_config:
+        save_train_params(train_config, fsdp_config, rank)
+        
     return results
 
 def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
@@ -322,3 +326,42 @@ def get_policies(cfg, rank):
             print(f"bFloat16 support not present. Using FP32, and not mixed precision")
     wrapping_policy = get_llama_wrapper()
     return mixed_precision_policy, wrapping_policy
+
+def save_train_params(train_config, fsdp_config, rank):
+    """
+    This function saves the train_config and FSDP config into a train_params.yaml.
+    This will be used by converter script in the inference folder to fetch the HF model name or path.
+    It also would be hepful as a log for future references.
+    """
+    # Convert the train_config and fsdp_config objects to dictionaries, 
+    # converting all values to strings to ensure they can be serialized into a YAML file
+    train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
+    fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
+    # Merge the two dictionaries into one
+    train_params_dict = {**train_config_dict, **fsdp_config_dict}
+    # Construct the folder name (follwoing FSDP checkpointing style) using properties of the train_config object
+    folder_name = (
+    train_config.dist_checkpoint_root_folder
+    + "/"
+    + train_config.dist_checkpoint_folder
+    + "-"
+    + train_config.model_name
+    )
+
+    save_dir = Path.cwd() / folder_name
+    # If the directory does not exist, create it
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    # Convert the dictionary to a YAML string
+    config_yaml = yaml.dump(train_params_dict, indent=4)
+    file_name = os.path.join(save_dir,'train_params.yaml')
+
+    # Check if there's a directory with the same name as the file
+    if os.path.isdir(file_name):
+        print(f"Error: {file_name} is a directory, not a file.")
+    else:
+        # Write the YAML string to the file
+        with open(file_name, 'w') as f:
+            f.write(config_yaml)
+        if rank==0:
+            print(f"training params are saved in {file_name}")