@@ -137,7 +137,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         if train_config.run_validation:
             eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, rank, tokenizer)
             if train_config.save_model and eval_epoch_loss < best_val_loss:
-                dist.barrier()
+                if train_config.enable_fsdp:
+                    dist.barrier()
                 if train_config.use_peft:
                     if train_config.enable_fsdp:
                         if rank==0:
@@ -173,7 +174,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         )
                         print(" Saving the FSDP model checkpoints qnd optimizer using FULL_STATE_DICT")
                         print("=====================================================")
-                dist.barrier()
+                if train_config.enable_fsdp:
+                    dist.barrier()
 
                 if eval_epoch_loss < best_val_loss:
                     best_val_loss = eval_epoch_loss
@@ -205,7 +207,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         results['avg_eval_prep'] = avg_eval_prep
         results['avg_eval_loss'] = avg_eval_loss
 
-    dist.barrier()
     return results
 
 def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
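The three hunks above share one fix: `dist.barrier()` is a collective operation, and calling it when FSDP is not enabled (so no process group was ever initialized) raises a runtime error on single-GPU runs. The patch therefore wraps the remaining calls in `if train_config.enable_fsdp:` and drops the one just before `return results`. Below is a minimal sketch of the same pattern; `maybe_barrier` is a hypothetical helper for illustration and is not part of the patch.

```python
# Minimal sketch of the guard pattern used in the patch above.
import torch.distributed as dist

def maybe_barrier(enable_fsdp: bool) -> None:
    """Synchronize ranks only when distributed (FSDP) training is in use."""
    # dist.barrier() raises if the default process group was never initialized,
    # which is the single-GPU / non-FSDP case the patch guards against.
    # The is_available()/is_initialized() checks are defensive extras, not part
    # of the patch, which relies only on the enable_fsdp config flag.
    if enable_fsdp and dist.is_available() and dist.is_initialized():
        dist.barrier()

# On a plain single-process run this is a no-op instead of a RuntimeError.
maybe_barrier(enable_fsdp=False)
```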
@@ -285,8 +286,10 @@ def setup_environ_flags(rank):
     """Set environment flags for debugging purposes"""
     os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
     os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
-    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
-    os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
+    # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
+    # This flag helps with CUDA memory fragmentation that can lead to OOM in some cases.
+    # Note this is only available in PyTorch Nightlies (as of July 30 2023).
+    # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
     if rank == 0:
         print(f"--> Running with torch dist debug set to detail")
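The last hunk comments out two environment settings rather than deleting them: `TORCH_DISTRIBUTED_DEBUG=DETAIL` stays available for debugging, and `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` is noted as requiring a PyTorch nightly at the time of the patch. (The unchanged print below the hunk still reports that dist debug is set to detail, even though the assignment is now commented out.) If you want either behavior back, here is a small sketch, assuming you opt in from your own launch code before any CUDA allocation; nothing in it is added to llama-recipes itself.

```python
# Sketch only: re-enabling the commented-out flags at process startup.
import os

# Extra consistency checks and logging for torch.distributed collectives;
# useful while debugging, adds overhead.
os.environ.setdefault("TORCH_DISTRIBUTED_DEBUG", "DETAIL")

# Helps with CUDA memory fragmentation per the patch comment; requires a build
# that supports expandable_segments (PyTorch nightlies as of July 30 2023) and
# must be set before the first CUDA allocation to take effect.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
```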