|
@@ -66,7 +66,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
|
|
|
scaler = ShardedGradScaler()
|
|
|
elif train_config.use_fp16 and not train_config.enable_fsdp:
|
|
|
scaler = torch.cuda.amp.GradScaler()
|
|
|
-
|
|
|
+ if train_config.enable_fsdp:
|
|
|
+ world_size = int(os.environ["WORLD_SIZE"])
|
|
|
train_prep = []
|
|
|
train_loss = []
|
|
|
val_prep = []
|
|
@@ -102,12 +103,18 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
|
|
|
if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
|
|
|
optimizer.step()
|
|
|
optimizer.zero_grad()
|
|
|
-
|
|
|
- print(f"\n step {step} is completed and loss is {loss.detach().float()}")
|
|
|
+ if train_config.enable_fsdp:
|
|
|
+ if rank==0:
|
|
|
+ print(f"\n step {step} is completed and loss is {loss.detach().float()}")
|
|
|
+ else:
|
|
|
+ print(f"\n step {step} is completed and loss is {loss.detach().float()}")
|
|
|
+
|
|
|
# Reducing total_loss across all devices if there's more than one CUDA device
|
|
|
if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
|
|
|
dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
|
|
|
- train_epoch_loss = total_loss / data_set_len
|
|
|
+ train_epoch_loss = total_loss / len(train_dataloader)
|
|
|
+ if train_config.enable_fsdp:
|
|
|
+ train_epoch_loss = train_epoch_loss/world_size
|
|
|
train_perplexity = torch.exp(train_epoch_loss)
|
|
|
|
|
|
train_prep.append(train_perplexity)
|
|
@@ -127,11 +134,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
|
|
|
if train_config.save_model and eval_epoch_loss < best_val_loss:
|
|
|
dist.barrier()
|
|
|
if train_config.use_peft:
|
|
|
-
|
|
|
print(f"we are in the saving the PEFT modules")
|
|
|
model.save_pretrained(train_config.output_dir)
|
|
|
- print(f"PEFT modules are saved in {train_config.output_dir} directory")
|
|
|
-
|
|
|
+ print(f"PEFT modules are saved in {train_config.output_dir} directory")
|
|
|
else:
|
|
|
if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
|
|
|
|
|
@@ -139,16 +144,21 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
|
|
|
model, optimizer, rank, train_config, epoch=epoch
|
|
|
)
|
|
|
elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
|
|
|
- print(" we are about to save the models *******")
|
|
|
+ print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
|
|
|
+ print("=====================================================")
|
|
|
|
|
|
model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config)
|
|
|
if train_config.save_optimizer:
|
|
|
model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
|
|
|
+ print(" Saving the FSDP model checkpoints qnd optimizer using SHARDED_STATE_DICT")
|
|
|
+ print("=====================================================")
|
|
|
|
|
|
if not train_config.use_peft and train_config.save_optimizer:
|
|
|
model_checkpointing.save_optimizer_checkpoint(
|
|
|
model, optimizer, rank, train_config, epoch=epoch
|
|
|
- )
|
|
|
+ )
|
|
|
+ print(" Saving the FSDP model checkpoints qnd optimizer using FULL_STATE_DICT")
|
|
|
+ print("=====================================================")
|
|
|
dist.barrier()
|
|
|
|
|
|
if eval_epoch_loss < best_val_loss:
|
|
@@ -192,6 +202,8 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
|
|
|
|
|
|
Returns: eval_ppl, eval_epoch_loss
|
|
|
"""
|
|
|
+ if train_config.enable_fsdp:
|
|
|
+ world_size = int(os.environ["WORLD_SIZE"])
|
|
|
model.eval()
|
|
|
eval_preds = []
|
|
|
eval_loss = 0.0 # Initialize evaluation loss
|
|
@@ -223,7 +235,9 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
|
|
|
dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
|
|
|
|
|
|
# Compute average loss and perplexity
|
|
|
- eval_epoch_loss = eval_loss / eval_dataset_len
|
|
|
+ eval_epoch_loss = eval_loss / len(eval_dataloader)
|
|
|
+ if train_config.enable_fsdp:
|
|
|
+ eval_epoch_loss = eval_epoch_loss/world_size
|
|
|
eval_ppl = torch.exp(eval_epoch_loss)
|
|
|
|
|
|
# Print evaluation metrics
|