@@ -328,8 +328,6 @@ def load_sharded_model_single_gpu(model,model_path):
no_dist=True,
)
- ck = state_dict["model"].keys()
- print(f" checkpoint key len = {len(ck)} and \n keys = {state_dict.keys()}")
model.load_state_dict(state_dict["model"])
print(f"Sharded state checkpoint loaded from {model_path}")