@@ -86,6 +86,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if train_config.use_fp16:
                         # if fp16 is enabled, use gradient scaler to handle gradient update
                         scaler.scale(loss).backward()
+                        if train_config.gradient_clipping > 0.0:
+                            torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping)
                         if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                             scaler.step(optimizer)
                             scaler.update()
@@ -94,6 +96,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     else:
                         # regular backpropagation when fp16 is not used
                         loss.backward()
+                        if train_config.gradient_clipping > 0.0:
+                            torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping)
                         if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                             optimizer.step()
                             optimizer.zero_grad()
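
For reference, when gradient clipping is combined with `torch.cuda.amp.GradScaler` (the fp16 path above), the PyTorch AMP documentation recommends calling `scaler.unscale_(optimizer)` first, so the clipping threshold applies to the true gradient norms rather than the scaled ones. Below is a minimal, self-contained sketch of that documented pattern; `model`, `optimizer`, `dataloader`, `loss_fn`, and `clip_threshold` are placeholder names for illustration and are not part of this diff.

```python
import torch

# Minimal sketch of gradient clipping under mixed precision, following the
# pattern from the PyTorch AMP docs. All names below are placeholders.
scaler = torch.cuda.amp.GradScaler()

for step, (inputs, targets) in enumerate(dataloader):
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = loss_fn(model(inputs), targets)

    # Backward on the scaled loss produces scaled gradients.
    scaler.scale(loss).backward()

    # Unscale in place so the clip threshold is compared against the
    # real gradient norms, not the scaled ones.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_threshold)

    # scaler.step() skips optimizer.step() if any gradients are inf/NaN.
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
```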