
resolving conflicts

Hamid Shojanazeri 1 year ago
Parent commit 62be60355a

+ 1 - 2
README.md

@@ -39,7 +39,7 @@ Llama 2 is a new technology that carries potential risks with use. Testing condu
 To run the examples, make sure to install the requirements using
 
 ```bash
-
+# python 3.9 or higher recommended
 pip install -r requirements.txt
 
 ```
@@ -55,7 +55,6 @@ Given that the original checkpoint resides under models/7B you can install all r
 ## Install HuggingFace Transformers from source
 pip freeze | grep transformers ## verify it is version 4.31.0 or higher
 
-```bash
 git clone git@github.com:huggingface/transformers.git
 cd transformers
 pip install protobuf

+ 25 - 0
docs/inference.md

@@ -34,6 +34,31 @@ The inference folder also includes a chat completion example, that adds built-in
 python inference/chat_completion.py --model_name "PATH/TO/MODEL/7B/" --prompt_file inference/chats.json  --quantization --use_auditnlg
 
 ```
+## Loading back FSDP checkpoints
+
+In case you have fine-tuned your model with pure FSDP and saved the checkpoints with "SHARDED_STATE_DICT" as shown [here](../configs/fsdp.py), you can use this converter script to convert the FSDP sharded checkpoints into HuggingFace checkpoints. This enables you to use the inference script normally as mentioned above.
+
+This is helpful if you have fine-tuned your model using FSDP only, as follows:
+
+```bash
+torchrun --nnodes 1 --nproc_per_node 8  llama_finetuning.py --enable_fsdp --model_name /path_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --pure_bf16
+```
+**To convert the FSDP checkpoint to HuggingFace checkpoints, use the following command**:
+```bash
+ python inference/checkpoint_converter_fsdp_hf.py --fsdp_checkpoint_path  PATH/to/FSDP/Checkpoints --consolidated_model_path PATH/to/save/checkpoints --HF_model_path_or_name PATH/or/HF/model_name
+
+ # --HF_model_path_or_name specifies the HF Llama model name or path where it has config.json and tokenizer.json
+ ```
+By default, training parameters are saved in `train_params.yaml` in the path where the FSDP checkpoints are saved. The converter script first tries to find the HuggingFace model name used during fine-tuning and loads the model config from there; if it is not found, the user needs to provide it.
+
+Then run inference using:
+
+```bash
+python inference/inference.py --model_name <training_config.output_dir> --prompt_file <test_prompt_file> 
+
+```
+
 
 ## Other Inference Options
 

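As a quick complement to the conversion instructions above, here is a minimal sketch (not part of this commit; paths are placeholders) of how the converted folder can be sanity-checked directly with `transformers` before wiring it into `inference.py`:

```python
# Hypothetical sanity check of the converted checkpoint; the path mirrors --consolidated_model_path.
from transformers import LlamaForCausalLM, LlamaTokenizer

ckpt = "PATH/to/save/checkpoints"
tokenizer = LlamaTokenizer.from_pretrained(ckpt)
model = LlamaForCausalLM.from_pretrained(ckpt, device_map="auto")  # device_map="auto" needs accelerate

prompt = "Summarize the following conversation:\n..."
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```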
+ 1 - 1
inference/README.md

@@ -2,7 +2,7 @@
 
 This folder contains inference examples for Llama 2. So far, we have provided support for three methods of inference:
 
-1. [inference script](inference.py) script provides support for Hugging Face accelerate and PEFT fine tuned models.
+1. The [inference script](inference.py) provides support for Hugging Face accelerate, PEFT, and FSDP fine-tuned models.
 
 2. [vLLM_inference.py](vLLM_inference.py) script takes advantage of vLLM's paged attention concept for low latency.
 

+ 63 - 0
inference/checkpoint_converter_fsdp_hf.py

@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+# from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+
+import fire
+import torch
+import os
+import sys
+import yaml
+from transformers import LlamaTokenizer
+from model_utils import load_llama_from_config
+# Get the current file's directory
+current_directory = os.path.dirname(os.path.abspath(__file__))
+
+# Get the parent directory
+parent_directory = os.path.dirname(current_directory)
+
+# Append the parent directory to sys.path
+sys.path.append(parent_directory)
+from model_checkpointing import load_sharded_model_single_gpu
+
+def main(
+    fsdp_checkpoint_path="", # Path to FSDP Sharded model checkpoints
+    consolidated_model_path="", # Path to save the HF converted model checkpoints
+    HF_model_path_or_name="" # Path/name of the HF model that includes config.json and tokenizer_config.json (e.g. meta-llama/Llama-2-7b-chat-hf)
+    ):
+    
+    try:
+        file_name = 'train_params.yaml'
+        # Combine the directory and file name to create the full path
+        train_params_path = os.path.join(fsdp_checkpoint_path, file_name)
+        # Open the file
+        with open(train_params_path, 'r') as file:
+            # Load the YAML data
+            data = yaml.safe_load(file)
+
+            # Access the 'model_name' field
+            HF_model_path_or_name = data.get('model_name')
+
+            print(f"Model name: {HF_model_path_or_name}")
+    except FileNotFoundError:
+        print(f"The file {train_params_path} does not exist.")
+        HF_model_path_or_name = input("Please enter the model name: ")
+        print(f"Model name: {HF_model_path_or_name}")
+    except Exception as e:
+        print(f"An error occurred: {e}")
+        
+        
+    # load the HF model definition from its config
+    model_def = load_llama_from_config(HF_model_path_or_name)
+    print("model is loaded from config")
+    # load the FSDP sharded checkpoints into the model
+    model = load_sharded_model_single_gpu(model_def, fsdp_checkpoint_path)
+    print("model is loaded from FSDP checkpoints")
+    # load the tokenizer from the HF model path and save it alongside the converted model
+    tokenizer = LlamaTokenizer.from_pretrained(HF_model_path_or_name)
+    tokenizer.save_pretrained(consolidated_model_path)
+    # save the consolidated model in HuggingFace format
+    model.save_pretrained(consolidated_model_path)
+    print(f"HuggingFace model checkpoints have been saved in {consolidated_model_path}")
+if __name__ == "__main__":
+    fire.Fire(main)
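
Because the entry point goes through `fire.Fire(main)`, the same conversion can also be invoked from Python rather than the CLI. A sketch, assuming the script's directory (`inference/`) is on the import path; paths and the model name are placeholders:

```python
# Hypothetical programmatic use of the converter (placeholder paths/model name).
from checkpoint_converter_fsdp_hf import main as convert_fsdp_to_hf

convert_fsdp_to_hf(
    fsdp_checkpoint_path="PATH/to/FSDP/Checkpoints",
    consolidated_model_path="PATH/to/save/checkpoints",
    HF_model_path_or_name="meta-llama/Llama-2-7b-chat-hf",
)
```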

+ 1 - 2
inference/inference.py

@@ -12,8 +12,7 @@ from typing import List
 
 from transformers import LlamaTokenizer
 from safety_utils import get_safety_checker
-from model_utils import load_model, load_peft_model
-
+from model_utils import load_model, load_peft_model, load_llama_from_config
 
 def main(
     model_name,

+ 10 - 2
inference/model_utils.py

@@ -2,7 +2,7 @@
 # This software may be used and distributed according to the terms of the GNU General Public License version 3.
 
 from peft import PeftModel
-from transformers import LlamaForCausalLM
+from transformers import LlamaForCausalLM, LlamaConfig
 
 # Function to load the main model for text generation
 def load_model(model_name, quantization):
@@ -19,4 +19,12 @@ def load_model(model_name, quantization):
 # Function to load the PeftModel for performance optimization
 def load_peft_model(model, peft_model):
     peft_model = PeftModel.from_pretrained(model, peft_model)
-    return peft_model
+    return peft_model
+
+# Load the model definition from a HF config so FSDP checkpoints can be loaded into it
+def load_llama_from_config(config_path):
+    model_config = LlamaConfig.from_pretrained(config_path)
+    model = LlamaForCausalLM(config=model_config)
+    return model

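For context on why this helper builds the model from a config rather than calling `from_pretrained`, a short sketch (the model name is a placeholder): the config-only constructor allocates randomly initialized weights without downloading a checkpoint, which is all the FSDP loader needs, since those weights are immediately overwritten:

```python
# Sketch: config-only construction vs. a full pretrained load (placeholder model name).
from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig.from_pretrained("meta-llama/Llama-2-7b-hf")  # reads config.json only
empty_model = LlamaForCausalLM(config)                            # random weights, no checkpoint download
# full_model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # would fetch all weights
```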
+ 1 - 0
model_checkpointing/__init__.py

@@ -8,4 +8,5 @@ from .checkpoint_handler import (
     save_optimizer_checkpoint,
     save_model_and_optimizer_sharded,
     load_model_sharded,
+    load_sharded_model_single_gpu
 )

+ 18 - 0
model_checkpointing/checkpoint_handler.py

@@ -247,3 +247,21 @@ def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank):
 
     print(f"optimizer shard loaded on rank {rank}")
 
+def load_sharded_model_single_gpu(model, model_path):
+    """Load an FSDP SHARDED_STATE_DICT checkpoint from model_path into model on a single process."""
+    reader = FileSystemReader(model_path)
+
+    state_dict = {
+        "model": model.state_dict()
+    }
+
+    # no_dist=True reads the sharded checkpoint without requiring an initialized process group
+    dist_cp.load_state_dict(
+        state_dict=state_dict,
+        storage_reader=reader,
+        no_dist=True,
+    )
+
+    model.load_state_dict(state_dict["model"])
+
+    print(f"Sharded state checkpoint loaded from {model_path}")
+    return model
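
For orientation, `model_path` here is the same `dist_checkpoint_root_folder/dist_checkpoint_folder-<model_name>` directory the training run wrote. A tiny sketch of what one might expect to see there; the shard file names are illustrative of typical `torch.distributed.checkpoint` output, not guaranteed, and the path is a placeholder:

```python
# Sketch: inspect an FSDP SHARDED_STATE_DICT checkpoint folder before loading it (placeholder path).
import os

fsdp_ckpt_dir = "model_checkpoints/fine-tuned-<model_name>"
print(sorted(os.listdir(fsdp_ckpt_dir)))
# Typically something like ['.metadata', '__0_0.distcp', ..., 'train_params.yaml'] after this commit.
```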

+ 12 - 1
scripts/spellcheck_conf/wordlist.txt

@@ -1067,4 +1067,15 @@ chatGPT
 Llama
 PEFT
 LORA
-FSDP
+FSDP
+AuditNLG
+finetune
+fsdp
+ineference
+lora
+peft
+samsum
+vLLM
+TGI
+vLLM
+vLLM's

+ 44 - 0
utils/train_utils.py

@@ -4,6 +4,7 @@
 import os
 import sys
 from typing import List
+import yaml
 
 import fire
 import torch
@@ -204,6 +205,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         results['avg_eval_prep'] = avg_eval_prep
         results['avg_eval_loss'] = avg_eval_loss
         
+    # save the training params, including the FSDP settings, for reference
+    if train_config.enable_fsdp and fsdp_config:
+        save_train_params(train_config, fsdp_config, rank)
+        
     return results
 
 def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
@@ -360,3 +365,42 @@ def get_policies(cfg, rank):
             print(f"bFloat16 support not present. Using FP32, and not mixed precision")
     wrapping_policy = get_llama_wrapper()
     return mixed_precision_policy, wrapping_policy
+
+def save_train_params(train_config, fsdp_config, rank):
+    """
+    This function saves the train_config and FSDP config into a train_params.yaml.
+    This will be used by the converter script in the inference folder to fetch the HF model name or path.
+    It is also helpful as a log for future reference.
+    """
+    # Convert the train_config and fsdp_config objects to dictionaries, 
+    # converting all values to strings to ensure they can be serialized into a YAML file
+    train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
+    fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
+    # Merge the two dictionaries into one
+    train_params_dict = {**train_config_dict, **fsdp_config_dict}
+    # Construct the folder name (following the FSDP checkpointing style) using properties of the train_config object
+    folder_name = (
+        train_config.dist_checkpoint_root_folder
+        + "/"
+        + train_config.dist_checkpoint_folder
+        + "-"
+        + train_config.model_name
+    )
+
+    save_dir = Path.cwd() / folder_name
+    # If the directory does not exist, create it
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    # Convert the dictionary to a YAML string
+    config_yaml = yaml.dump(train_params_dict, indent=4)
+    file_name = os.path.join(save_dir,'train_params.yaml')
+
+    # Check if there's a directory with the same name as the file
+    if os.path.isdir(file_name):
+        print(f"Error: {file_name} is a directory, not a file.")
+    else:
+        # Write the YAML string to the file
+        with open(file_name, 'w') as f:
+            f.write(config_yaml)
+        if rank==0:
+            print(f"training params are saved in {file_name}")
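
To make the folder-naming convention concrete, a small sketch with hypothetical config values, reproducing the path that `save_train_params` writes to, which is also where the converter script later looks for `train_params.yaml`:

```python
# Sketch: where train_params.yaml ends up for hypothetical config values.
from pathlib import Path

dist_checkpoint_root_folder = "model_checkpoints"
dist_checkpoint_folder = "fine-tuned"
model_name = "meta-llama/Llama-2-7b-hf"

folder_name = dist_checkpoint_root_folder + "/" + dist_checkpoint_folder + "-" + model_name
print(Path.cwd() / folder_name / "train_params.yaml")
# e.g. <cwd>/model_checkpoints/fine-tuned-meta-llama/Llama-2-7b-hf/train_params.yaml
```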