
adding dist checkpoint handler for single GPU

Hamid Shojanazeri 1 year ago
commit 15053268b4
2 changed files with 22 additions and 18 deletions
  1. inference/inference.py (+4 -4)
  2. model_checkpointing/checkpoint_handler.py (+18 -14)

inference/inference.py (+4 -4)

@@ -59,12 +59,12 @@ def main(
     torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
     # model = load_model(model_name, quantization)
-    model = load_llama_from_config()
-    loaded_model = load_sharded_model_single_gpu(model, model_name)
+    model_config = load_llama_from_config()
+    model = load_sharded_model_single_gpu(model_config, model_name)
     
     print("model has been loaded *******************")
 
-    tokenizer = LlamaTokenizer.from_pretrained(model_name)
+    tokenizer = LlamaTokenizer.from_pretrained("../../../hf-llama-pr/7B/")
     tokenizer.add_special_tokens(
         {
             "eos_token": "</s>",
@@ -97,7 +97,7 @@ def main(
     if peft_model:
         model = load_peft_model(model, peft_model)
 
-    model.eval()
+    # model.eval()
 
     batch = tokenizer(user_prompt, return_tensors="pt")
     batch = {k: v.to("cuda") for k, v in batch.items()}
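
With this change, inference builds the model skeleton from its config and then fills in the weights from a sharded checkpoint, rather than loading pretrained weights via load_model. A minimal sketch of what load_llama_from_config plausibly does, assuming it instantiates an uninitialized LlamaForCausalLM from a LlamaConfig (the config_path parameter is an assumption; the diff calls the function with no arguments):

    from transformers import LlamaConfig, LlamaForCausalLM

    def load_llama_from_config(config_path):
        # Build the architecture only: weights are randomly initialized here
        # and are expected to be overwritten by load_sharded_model_single_gpu().
        model_config = LlamaConfig.from_pretrained(config_path)
        return LlamaForCausalLM(config=model_config)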

model_checkpointing/checkpoint_handler.py (+18 -14)

@@ -317,18 +317,22 @@ def load_sharded_model_single_gpu(model, model_path):
 def load_sharded_model_single_gpu(model,model_path,verbose=True):
     
     reader = FileSystemReader(model_path)
-
-    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
-      
-        dist_cp.load_state_dict(
-                    state_dict=state_dict_to_load_to,
-                    storage_reader= FileSystemReader(path),
-                    no_dist=True,
-                )
-       
-        print(f"checkpoint after load_state_dict()")
-        ck = checkpoint.keys()
-        print(f" checkpoint key len = {len(ck)} and \n keys =  {ck}")
-        model.load_state_dict(checkpoint)
     
-    print(f"Sharded state checkpoint loaded from {model_path}")
+    # with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
+    state_dict = {
+        "model": model.state_dict()
+    }
+    
+    dist_cp.load_state_dict(
+                state_dict=state_dict,
+                storage_reader= FileSystemReader(model_path),
+                no_dist=True,
+            )
+    
+    print(f"checkpoint after load_state_dict()")
+    ck = state_dict["model"].keys()
+    print(f" checkpoint key len = {len(ck)} and \n keys =  {state_dict['model']['model.embed_tokens.weight']}")
+    model.load_state_dict(state_dict["model"])
+    
+    print(f"Sharded state checkpoint loaded from {model_path}")
+    return model
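
For reference, the load pattern the new handler uses, as a self-contained sketch: dist_cp.load_state_dict can read an FSDP-saved sharded checkpoint without an initialized process group when no_dist=True, which is why the FSDP.state_dict_type context manager is commented out on the single-GPU path. The signature shown matches the PyTorch 2.0-era torch.distributed.checkpoint API used in the diff (newer releases rename it to dist_cp.load). The toy model and checkpoint path below are placeholders, and the "model" key is assumed to match the key structure the checkpoint was saved with:

    import torch
    import torch.distributed.checkpoint as dist_cp
    from torch.distributed.checkpoint import FileSystemReader

    def load_sharded_checkpoint_single_process(model, checkpoint_dir):
        # The keys here must line up with what dist_cp.save_state_dict wrote
        # at training time; this checkpoint was saved under a "model" key.
        state_dict = {"model": model.state_dict()}
        dist_cp.load_state_dict(
            state_dict=state_dict,
            storage_reader=FileSystemReader(checkpoint_dir),
            no_dist=True,  # no torch.distributed process group required
        )
        model.load_state_dict(state_dict["model"])
        return model

    if __name__ == "__main__":
        model = torch.nn.Linear(16, 16)  # placeholder for the config-built LLaMA model
        model = load_sharded_checkpoint_single_process(model, "path/to/sharded-checkpoint")

Combined with load_llama_from_config above, this materializes the full model on a single GPU without ever launching a distributed job.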