|
@@ -317,18 +317,22 @@ def load_sharded_model_single_gpu(model, model_path):
|
|
|
def load_sharded_model_single_gpu(model,model_path,verbose=True):
|
|
|
|
|
|
reader = FileSystemReader(model_path)
|
|
|
-
|
|
|
- with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
|
|
|
-
|
|
|
- dist_cp.load_state_dict(
|
|
|
- state_dict=state_dict_to_load_to,
|
|
|
- storage_reader= FileSystemReader(path),
|
|
|
- no_dist=True,
|
|
|
- )
|
|
|
-
|
|
|
- print(f"checkpoint after load_state_dict()")
|
|
|
- ck = checkpoint.keys()
|
|
|
- print(f" checkpoint key len = {len(ck)} and \n keys = {ck}")
|
|
|
- model.load_state_dict(checkpoint)
|
|
|
|
|
|
- print(f"Sharded state checkpoint loaded from {model_path}")
|
|
|
+ # with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
|
|
|
+ state_dict = {
|
|
|
+ "model": model.state_dict()
|
|
|
+ }
|
|
|
+
|
|
|
+ dist_cp.load_state_dict(
|
|
|
+ state_dict=state_dict,
|
|
|
+ storage_reader= FileSystemReader(model_path),
|
|
|
+ no_dist=True,
|
|
|
+ )
|
|
|
+
|
|
|
+ print(f"checkpoint after load_state_dict()")
|
|
|
+ ck = state_dict["model"].keys()
|
|
|
+ print(f" checkpoint key len = {len(ck)} and \n keys = {state_dict['model']['model.embed_tokens.weight']}")
|
|
|
+ model.load_state_dict(state_dict["model"])
|
|
|
+
|
|
|
+ print(f"Sharded state checkpoint loaded from {model_path}")
|
|
|
+ return model
|