Hamid Shojanazeri 1 year ago
parent
commit 76a187c4d2

+ 2 - 2
inference/model_utils.py

@@ -22,8 +22,8 @@ def load_peft_model(model, peft_model):
     return peft_model
 
 # Loading the model from config to load FSDP checkpoints into that
-def load_llama_from_config():
-    model_config = LlamaConfig.from_pretrained("../../../hf-llama-pr/7B/") 
+def load_llama_from_config(config_path):
+    model_config = LlamaConfig.from_pretrained(config_path) 
     model = LlamaForCausalLM(config=model_config)
     return model
     

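The change in inference/model_utils.py replaces a hard-coded local checkout path with a config_path argument, so the helper can build an empty LlamaForCausalLM from any config directory before FSDP checkpoint weights are loaded into it. A minimal usage sketch of the updated helper; the import line and the example path are assumptions for illustration, not part of the commit:

from transformers import LlamaConfig, LlamaForCausalLM

def load_llama_from_config(config_path):
    # Build the model skeleton from config only (randomly initialized weights);
    # the actual weights are filled in later by a checkpoint loader.
    model_config = LlamaConfig.from_pretrained(config_path)
    model = LlamaForCausalLM(config=model_config)
    return model

model = load_llama_from_config("path/to/llama-7b-config")  # hypothetical config directory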
+ 2 - 4
model_checkpointing/checkpoint_handler.py

@@ -314,11 +314,10 @@ def load_sharded_model_single_gpu(model, model_path):
                 )
     print(f"Sharded state checkpoint loaded from {load_dir}")
     
-def load_sharded_model_single_gpu(model,model_path,verbose=True):
+def load_sharded_model_single_gpu(model,model_path):
     
     reader = FileSystemReader(model_path)
     
-    # with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
     state_dict = {
         "model": model.state_dict()
     }
@@ -329,9 +328,8 @@ def load_sharded_model_single_gpu(model,model_path,verbose=True):
                 no_dist=True,
             )
     
-    print(f"checkpoint after load_state_dict()")
     ck = state_dict["model"].keys()
-    print(f" checkpoint key len = {len(ck)} and \n keys =  {state_dict['model']['model.embed_tokens.weight']}")
+    print(f" checkpoint key len = {len(ck)} and \n keys =  {state_dict.keys()}")
     model.load_state_dict(state_dict["model"])
     
     print(f"Sharded state checkpoint loaded from {model_path}")