@@ -172,14 +172,14 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config)
if train_config.save_optimizer:
model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
- print(" Saving the FSDP model checkpoints qnd optimizer using SHARDED_STATE_DICT")
+ print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
print("=====================================================")
|
|
|
|
|
|
if not train_config.use_peft and train_config.save_optimizer:
model_checkpointing.save_optimizer_checkpoint(
model, optimizer, rank, train_config, epoch=epoch
)
- print(" Saving the FSDP model checkpoints qnd optimizer using FULL_STATE_DICT")
+ print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
print("=====================================================")
if train_config.enable_fsdp:
dist.barrier()
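As the corrected print messages indicate, this hunk covers the two FSDP checkpoint paths in train(): save_model_and_optimizer_sharded saves via SHARDED_STATE_DICT (each rank writes its own shard), while save_optimizer_checkpoint goes through a FULL_STATE_DICT gathered onto rank 0. The sketch below only illustrates the FULL_STATE_DICT path using the public torch.distributed.fsdp API; it is not the repo's model_checkpointing implementation, and the function name and save path are made up for the example.

    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import StateDictType, FullStateDictConfig

    def save_full_state_dict_sketch(model, rank, save_path="model_checkpoint.pt"):
        # Gather a full state dict, offloaded to CPU and materialized only on rank 0,
        # so large models do not exhaust GPU or per-rank host memory.
        cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
            cpu_state = model.state_dict()
        if rank == 0:
            torch.save(cpu_state, save_path)
        # Keep all ranks in step, mirroring the dist.barrier() call at the end of the hunk.
        dist.barrier()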