|
@@ -65,7 +65,7 @@ def load_model_sharded(model, rank, cfg):
|
|
|
reader = FileSystemReader(load_dir)
|
|
|
|
|
|
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
|
|
|
- checkpoint = model.state_dict()
|
|
|
+ checkpoint = {"model": model.state_dict()}
|
|
|
if rank == 0:
|
|
|
ck = checkpoint.keys()
|
|
|
print(f" checkpoint key len = {len(ck)} and \n keys = {ck}")
|
|
@@ -78,7 +78,7 @@ def load_model_sharded(model, rank, cfg):
|
|
|
print(f"checkpoint after load_state_dict()")
|
|
|
ck = checkpoint.keys()
|
|
|
print(f" checkpoint key len = {len(ck)} and \n keys = {ck}")
|
|
|
- model.load_state_dict(checkpoint)
|
|
|
+ model.load_state_dict(checkpoint["model"])
|
|
|
if rank == 0:
|
|
|
print(f"Sharded state checkpoint loaded from {load_dir}")
|
|
|
|