@@ -35,11 +35,6 @@ from pathlib import Path
 sys.path.append(str(Path(__file__).resolve().parent.parent))
 from policies import bfSixteen, fpSixteen,bfSixteen_mixed, get_llama_wrapper

-scaler = ShardedGradScaler()
-
-
-
-
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
     tokenizer.padding_side = "left"
@@ -67,8 +62,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     Returns: results dictionary containing average training and validation perplexity and loss
     """
     # Create a gradient scaler for fp16
-    scaler = torch.cuda.amp.GradScaler() if train_config.use_fp16 else None
-
+    if train_config.use_fp16 and train_config.enable_fsdp:
+        scaler = ShardedGradScaler()
+    elif train_config.use_fp16 and not train_config.enable_fsdp:
+        scaler = torch.cuda.amp.GradScaler()
+
     train_prep = []
     train_loss = []
     val_prep = []
@@ -85,7 +83,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             for key in batch.keys():
                 if train_config.enable_fsdp:
                     batch[key] = batch[key].to(local_rank)
-                elif not train_config.quantization:
+                else:
                     batch[key] = batch[key].to('cuda')
             outputs = model(**batch)
             loss = outputs.loss
@@ -105,11 +103,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 loss.backward()
                 if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                     optimizer.step()
-                    lr_scheduler.step()
                     optimizer.zero_grad()

-                print(f"\n step {step} is completed and loss is {loss.detach().float()}")
-
+                print(f"\n step {step} is completed and loss is {loss.detach().float()}")
         # Reducing total_loss across all devices if there's more than one CUDA device
         if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
@@ -123,7 +119,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
         print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
         print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
-
+
+        # Update the learning rate as needed
+        lr_scheduler.step()
+
         if train_config.run_validation:
             eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, rank, tokenizer)
             if train_config.save_model and eval_epoch_loss < best_val_loss:
@@ -159,7 +158,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             val_loss.append(best_val_loss)
             val_prep.append(eval_ppl)

+
         print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
+        lr_scheduler.step()

     avg_train_prep = sum(train_prep)/len(train_prep)
     avg_train_loss = sum(train_loss)/len(train_loss)