@@ -4,6 +4,7 @@
 import os
 import time
 import yaml
+from contextlib import nullcontext
 from pathlib import Path
 from pkg_resources import packaging
 
@@ -56,7 +57,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     elif train_config.use_fp16 and not train_config.enable_fsdp:
         scaler = torch.cuda.amp.GradScaler()
     if train_config.enable_fsdp:
-        world_size = int(os.environ["WORLD_SIZE"])
+        world_size = int(os.environ["WORLD_SIZE"])
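+    # with fp16 off, nullcontext supplies a no-op context manager, so the "with autocast():" below adds no overhead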
+    autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
+
     train_prep = []
     train_loss = []
     val_prep = []
@@ -71,17 +74,21 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         model.train()
         total_loss = 0.0
         total_length = len(train_dataloader)//gradient_accumulation_steps
-        pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch}", total=total_length)
+        pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
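+        # dynamic_ncols resizes the bar with the terminal; total counts optimizer updates, not raw batches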
         for step, batch in enumerate(train_dataloader):
             for key in batch.keys():
                 if train_config.enable_fsdp:
-                    batch[key] = batch[key].to(local_rank)
+                    if is_xpu_available():
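+                        # on Intel XPU builds, send tensors to this rank's xpu device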
+                        batch[key] = batch[key].to(torch.device(f"xpu:{local_rank}"))
+                    else:
+                        batch[key] = batch[key].to(local_rank)
                 else:
                     if is_xpu_available():
                         batch[key] = batch[key].to('xpu:0')
                     else:
                         batch[key] = batch[key].to('cuda:0')
-            loss = model(**batch).loss
+            with autocast():
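+                # forward pass runs in fp16 under autocast when use_fp16 is set; otherwise nullcontext leaves it in full precision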
+                loss = model(**batch).loss
             loss = loss / gradient_accumulation_steps
             total_loss += loss.detach().float()
             if train_config.use_fp16:
@@ -91,16 +98,17 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     scaler.step(optimizer)
                     scaler.update()
                     optimizer.zero_grad()
-                    pbar.update(step//gradient_accumulation_steps)
+                    pbar.update(1)
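+                    # one optimizer step per tick; the old step//gradient_accumulation_steps argument over-advanced the bar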
             else:
                 # regular backpropagation when fp16 is not used
                 loss.backward()
                 if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                     optimizer.step()
                     optimizer.zero_grad()
-                    pbar.update(step//gradient_accumulation_steps)
-
-            pbar.set_description(f"Training Epoch: {epoch}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
+                    pbar.update(1)
+
+            pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
+        pbar.close()
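+        # closing the bar here lets the next epoch start on a fresh line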
 
         epoch_end_time = time.perf_counter()-epoch_start_time
         epoch_times.append(epoch_end_time)
@@ -195,16 +203,16 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 best_val_loss = eval_epoch_loss
                 if train_config.enable_fsdp:
                     if rank==0:
-                        print(f"best eval loss on epoch {epoch} is {best_val_loss}")
+                        print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
                 else:
-                    print(f"best eval loss on epoch {epoch} is {best_val_loss}")
+                    print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
             val_loss.append(best_val_loss)
             val_prep.append(eval_ppl)
         if train_config.enable_fsdp:
             if rank==0:
-                print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
+                print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
         else:
-            print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
+            print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
     avg_epoch_time = sum(epoch_times)/ len(epoch_times)
     avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
     avg_train_prep = sum(train_prep)/len(train_prep)
@@ -245,7 +253,7 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     eval_preds = []
     eval_loss = 0.0  # Initialize evaluation loss
     with MemoryTrace() as memtrace:
-        for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch")):
+        for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
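+            # same dynamic_ncols fix as the training loop's progress bar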
             for key in batch.keys():
                 if train_config.enable_fsdp:
                     batch[key] = batch[key].to(local_rank)