|
@@ -87,6 +87,12 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
|
|
|
# if fp16 is enabled, use gradient scaler to handle gradient update
|
|
|
scaler.scale(loss).backward()
|
|
|
if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
|
|
|
+ if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
|
|
|
+ scaler.unscale_(optimizer)
|
|
|
+ if train_config.enable_fsdp:
|
|
|
+ model.clip_grad_norm_(train_config.gradient_clipping_threshold)
|
|
|
+ else:
|
|
|
+ torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
|
|
|
scaler.step(optimizer)
|
|
|
scaler.update()
|
|
|
optimizer.zero_grad()
|
|
@@ -95,6 +101,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
|
|
|
# regular backpropagation when fp16 is not used
|
|
|
loss.backward()
|
|
|
if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
|
|
|
+ if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
|
|
|
+ if train_config.enable_fsdp:
|
|
|
+ model.clip_grad_norm_(train_config.gradient_clipping_threshold)
|
|
|
+ else:
|
|
|
+ torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
|
|
|
optimizer.step()
|
|
|
optimizer.zero_grad()
|
|
|
pbar.update(1)
|