@@ -4,6 +4,7 @@
 import os
 import time
 import yaml
+from contextlib import nullcontext
 from pathlib import Path
 from pkg_resources import packaging

@@ -54,7 +55,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     elif train_config.use_fp16 and not train_config.enable_fsdp:
         scaler = torch.cuda.amp.GradScaler()
     if train_config.enable_fsdp:
-        world_size = int(os.environ["WORLD_SIZE"])
+        world_size = int(os.environ["WORLD_SIZE"])
+    autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
+
     train_prep = []
     train_loss = []
     val_prep = []
@@ -76,7 +79,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         batch[key] = batch[key].to(local_rank)
                     else:
                         batch[key] = batch[key].to('cuda:0')
-                loss = model(**batch).loss
+                with autocast():
+                    loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
                 total_loss += loss.detach().float()
                 if train_config.use_fp16:
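
For reference, a minimal standalone sketch of the pattern this diff introduces: when `use_fp16` is enabled, the forward pass runs under `torch.cuda.amp.autocast`, and otherwise under `contextlib.nullcontext`, so the loop body stays identical in both modes. The toy model, optimizer, and data below are hypothetical placeholders for illustration, not part of the actual `train()` function.

import torch
from contextlib import nullcontext

use_fp16 = True  # illustrative flag standing in for train_config.use_fp16

# Toy stand-ins for the real model/optimizer (assumes a CUDA device is available).
model = torch.nn.Linear(16, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Pick the forward-pass context manager once: autocast for fp16, a no-op otherwise.
autocast = torch.cuda.amp.autocast if use_fp16 else nullcontext
scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)

for _ in range(3):
    x = torch.randn(8, 16, device="cuda")
    y = torch.randint(0, 2, (8,), device="cuda")

    with autocast():  # mixed precision when use_fp16, full precision otherwise
        loss = torch.nn.functional.cross_entropy(model(x), y)

    # GradScaler applies loss scaling for fp16; with enabled=False it is a pass-through.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

Note that the diff itself only wraps the forward pass; the backward and optimizer steps in `train()` keep their existing fp16/FSDP handling further down in the function.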