Browse Source

adding fsdp checkpoint loading

Hamid Shojanazeri 1 year ago
parent
commit
25b60ee835

+ 15 - 2
inference/inference.py

@@ -11,8 +11,17 @@ from typing import List
 
 from transformers import LlamaTokenizer
 from safety_utils import get_safety_checker
-from model_utils import load_model, load_peft_model
+from model_utils import load_model, load_peft_model, load_llama_from_config
 
+# Get the current file's directory
+current_directory = os.path.dirname(os.path.abspath(__file__))
+
+# Get the parent directory
+parent_directory = os.path.dirname(current_directory)
+
+# Append the parent directory to sys.path
+sys.path.append(parent_directory)
+from model_checkpointing import load_sharded_model_single_gpu
 
 def main(
     model_name,
@@ -49,7 +58,11 @@ def main(
     # Set the seeds for reproducibility
     torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
-    model = load_model(model_name, quantization)
+    # model = load_model(model_name, quantization)
+    model = load_llama_from_config()
+    loaded_model = load_sharded_model_single_gpu(model, model_name)
+    
+    print("model has been loaded *******************")
 
     tokenizer = LlamaTokenizer.from_pretrained(model_name)
     tokenizer.add_special_tokens(

+ 10 - 2
inference/model_utils.py

@@ -2,7 +2,7 @@
 # This software may be used and distributed according to the terms of the GNU General Public License version 3.
 
 from peft import PeftModel
-from transformers import LlamaForCausalLM
+from transformers import LlamaForCausalLM, LlamaConfig
 
 # Function to load the main model for text generation
 def load_model(model_name, quantization):
@@ -19,4 +19,12 @@ def load_model(model_name, quantization):
 # Function to load the PeftModel for performance optimization
 def load_peft_model(model, peft_model):
     peft_model = PeftModel.from_pretrained(model, peft_model)
-    return peft_model
+    return peft_model
+
+# Loading the model from config to load FSDP checkpoints into that
+def load_llama_from_config():
+    model_config = LlamaConfig.from_pretrained("../../../hf-llama-pr/7B/") 
+    model = LlamaForCausalLM(config=model_config)
+    return model
+    
+    

+ 1 - 0
model_checkpointing/__init__.py

@@ -10,4 +10,5 @@ from .checkpoint_handler import (
     save_optimizer_checkpoint,
     save_model_and_optimizer_sharded,
     load_model_sharded,
+    load_sharded_model_single_gpu
 )

+ 28 - 0
model_checkpointing/checkpoint_handler.py

@@ -304,3 +304,31 @@ def save_distributed_model_checkpoint(model, rank, cfg, epoch=1):
         save_state_dict(state_dict, writer)
 
         return
+
+def load_sharded_model_single_gpu(model, model_path):
+    
+    dcp.load_state_dict(
+                    state_dict=state_dict_to_load_to,
+                    storage_reader=FsspecReader(path),
+                    no_dist=True,
+                )
+    print(f"Sharded state checkpoint loaded from {load_dir}")
+    
+def load_sharded_model_single_gpu(model,model_path,verbose=True):
+    
+    reader = FileSystemReader(model_path)
+
+    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
+      
+        dist_cp.load_state_dict(
+                    state_dict=state_dict_to_load_to,
+                    storage_reader= FileSystemReader(path),
+                    no_dist=True,
+                )
+       
+        print(f"checkpoint after load_state_dict()")
+        ck = checkpoint.keys()
+        print(f" checkpoint key len = {len(ck)} and \n keys =  {ck}")
+        model.load_state_dict(checkpoint)
+    
+    print(f"Sharded state checkpoint loaded from {model_path}")