|
@@ -57,9 +57,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
|
|
|
elif train_config.use_fp16 and not train_config.enable_fsdp:
|
|
|
scaler = torch.cuda.amp.GradScaler()
|
|
|
if train_config.enable_fsdp:
|
|
|
- world_size = int(os.environ["WORLD_SIZE"])
|
|
|
+ world_size = int(os.environ["WORLD_SIZE"])
|
|
|
+
|
|
|
|
|
|
-
|
|
|
|
|
|
autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
|
|
|
|
|
@@ -74,12 +74,18 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
|
|
|
train_step_loss = []
|
|
|
val_step_loss = []
|
|
|
val_step_perplexity = []
|
|
|
-
|
|
|
+
|
|
|
epoch_times = []
|
|
|
checkpoint_times = []
|
|
|
results = {}
|
|
|
best_val_loss = float("inf")
|
|
|
+ total_train_steps = 0
|
|
|
for epoch in range(train_config.num_epochs):
|
|
|
+ # stop when the maximum number of training steps is reached
|
|
|
+ if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
|
|
|
+ if not train_config.enable_fsdp or local_rank==0:
|
|
|
+ print("max training steps reached, stopping training, total_train_steps: ", total_train_steps-1)
|
|
|
+ break
|
|
|
epoch_start_time = time.perf_counter()
|
|
|
with MemoryTrace() as memtrace: # track the memory usage
|
|
|
model.train()
|
|
@@ -87,6 +93,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
|
|
|
total_length = len(train_dataloader)//gradient_accumulation_steps
|
|
|
pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
|
|
|
for step, batch in enumerate(train_dataloader):
|
|
|
+ total_train_steps += 1
|
|
|
+ # stop when the maximum number of training steps is reached
|
|
|
+ if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
|
|
|
+ break
|
|
|
for key in batch.keys():
|
|
|
if train_config.enable_fsdp:
|
|
|
if is_xpu_available():
|
|
@@ -98,7 +108,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
|
|
|
if is_xpu_available():
|
|
|
batch[key] = batch[key].to('xpu:0')
|
|
|
else:
|
|
|
- batch[key] = batch[key].to('cuda:0')
|
|
|
+ batch[key] = batch[key].to('cuda:0')
|
|
|
with autocast():
|
|
|
loss = model(**batch).loss
|
|
|
loss = loss / gradient_accumulation_steps
|
|
@@ -133,7 +143,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
|
|
|
optimizer.zero_grad()
|
|
|
pbar.update(1)
|
|
|
|
|
|
- if wandb_run:
|
|
|
+ if wandb_run:
|
|
|
if not train_config.enable_fsdp or rank==0:
|
|
|
wandb_run.log({
|
|
|
'train/epoch': epoch + 1,
|
|
@@ -158,10 +168,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
|
|
|
if train_config.enable_fsdp:
|
|
|
train_epoch_loss = train_epoch_loss/world_size
|
|
|
train_perplexity = torch.exp(train_epoch_loss)
|
|
|
-
|
|
|
+
|
|
|
train_prep.append(float(train_perplexity))
|
|
|
train_loss.append(float(train_epoch_loss))
|
|
|
-
|
|
|
+
|
|
|
if not train_config.enable_fsdp or rank==0:
|
|
|
memtrace.print_stats()
|
|
|
|
|
@@ -231,7 +241,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
|
|
|
print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
|
|
|
else:
|
|
|
print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
|
|
|
-
|
|
|
+
|
|
|
# Saving the results every epoch to plot later
|
|
|
if train_config.save_metrics:
|
|
|
save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
|
|
@@ -279,8 +289,15 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
|
|
|
val_step_loss = []
|
|
|
val_step_perplexity = []
|
|
|
eval_loss = 0.0 # Initialize evaluation loss
|
|
|
+ total_eval_steps = 0
|
|
|
with MemoryTrace() as memtrace:
|
|
|
for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
|
|
|
+ total_eval_steps += 1
|
|
|
+ # stop when the maximum number of eval steps is reached
|
|
|
+ if train_config.max_eval_step > 0 and total_eval_steps > train_config.max_eval_step:
|
|
|
+ if not train_config.enable_fsdp or local_rank==0:
|
|
|
+ print("max eval steps reached, stopping evaluation, total_eval_steps: ", total_eval_steps - 1)
|
|
|
+ break
|
|
|
for key in batch.keys():
|
|
|
if train_config.enable_fsdp:
|
|
|
batch[key] = batch[key].to(local_rank)
|
|
@@ -288,7 +305,7 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
|
|
|
if is_xpu_available():
|
|
|
batch[key] = batch[key].to('xpu:0')
|
|
|
else:
|
|
|
- batch[key] = batch[key].to('cuda:0')
|
|
|
+ batch[key] = batch[key].to('cuda:0')
|
|
|
# Ensure no gradients are computed for this scope to save memory
|
|
|
with torch.no_grad():
|
|
|
# Forward pass and compute loss
|
|
@@ -296,7 +313,7 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
|
|
|
loss = outputs.loss
|
|
|
if train_config.save_metrics:
|
|
|
val_step_loss.append(loss.detach().float().item())
|
|
|
- val_step_perplexity.append(float(torch.exp(loss.detach().float())))
|
|
|
+ val_step_perplexity.append(float(torch.exp(loss.detach().float())))
|
|
|
|
|
|
eval_loss += loss.detach().float()
|
|
|
# Decode predictions and add to evaluation predictions list
|
|
@@ -324,12 +341,12 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
|
|
|
else:
|
|
|
print(f" {eval_ppl=} {eval_epoch_loss=}")
|
|
|
|
|
|
- if wandb_run:
|
|
|
+ if wandb_run:
|
|
|
wandb_run.log({
|
|
|
'eval/perplexity': eval_ppl,
|
|
|
'eval/loss': eval_epoch_loss,
|
|
|
}, commit=False)
|
|
|
-
|
|
|
+
|
|
|
return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity
|
|
|
|
|
|
def freeze_transformer_layers(model, num_layer):
|
|
@@ -410,7 +427,7 @@ def print_model_size(model, config, rank: int = 0) -> None:
|
|
|
def get_policies(cfg, rank):
|
|
|
"""Get the policies for mixed precision and fsdp wrapping"""
|
|
|
|
|
|
-
|
|
|
+
|
|
|
verify_bfloat_support = ((
|
|
|
torch.version.cuda
|
|
|
and torch.cuda.is_bf16_supported()
|