|
@@ -314,11 +314,10 @@ def load_sharded_model_single_gpu(model, model_path):
|
|
|
)
|
|
|
print(f"Sharded state checkpoint loaded from {load_dir}")
|
|
|
|
|
|
-def load_sharded_model_single_gpu(model,model_path,verbose=True):
|
|
|
+def load_sharded_model_single_gpu(model,model_path):
|
|
|
|
|
|
reader = FileSystemReader(model_path)
|
|
|
|
|
|
- # with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
|
|
|
state_dict = {
|
|
|
"model": model.state_dict()
|
|
|
}
|
|
@@ -329,9 +328,8 @@ def load_sharded_model_single_gpu(model,model_path,verbose=True):
|
|
|
no_dist=True,
|
|
|
)
|
|
|
|
|
|
- print(f"checkpoint after load_state_dict()")
|
|
|
ck = state_dict["model"].keys()
|
|
|
- print(f" checkpoint key len = {len(ck)} and \n keys = {state_dict['model']['model.embed_tokens.weight']}")
|
|
|
+ print(f" checkpoint key len = {len(ck)} and \n keys = {state_dict.keys()}")
|
|
|
model.load_state_dict(state_dict["model"])
|
|
|
|
|
|
print(f"Sharded state checkpoint loaded from {model_path}")
|