@@ -89,7 +89,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                         if train_config.gradient_clipping > 0.0:
                             scaler.unscale_(optimizer)
-                            torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping)
+                            if train_config.enable_fsdp:
+                                model.clip_grad_norm_(train_config.gradient_clipping)
+                            else:
+                                torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping)
                         scaler.step(optimizer)
                         scaler.update()
                         optimizer.zero_grad()
@@ -99,7 +102,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     loss.backward()
                     if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                         if train_config.gradient_clipping > 0.0:
-                            torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping)
+                            if train_config.enable_fsdp:
+                                model.clip_grad_norm_(train_config.gradient_clipping)
+                            else:
+                                torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping)
                         optimizer.step()
                         optimizer.zero_grad()
                         pbar.update(1)
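
The rationale for the added branch, shown as a minimal standalone sketch (the clip_gradients helper and its arguments are illustrative, not part of the repository): under FSDP each rank holds only a shard of the gradients, so the FSDP wrapper's own clip_grad_norm_ must be used to compute the global gradient norm across ranks, whereas torch.nn.utils.clip_grad_norm_ would clip against a per-rank norm.

import torch
import torch.nn as nn

def clip_gradients(model: nn.Module, max_norm: float, use_fsdp: bool) -> torch.Tensor:
    # Hypothetical helper mirroring the branch added in the diff above;
    # `model` is either an FSDP-wrapped module or a plain nn.Module.
    if use_fsdp:
        # FSDP.clip_grad_norm_ aggregates the gradient norm across all ranks
        # before scaling the locally held (sharded) gradients.
        return model.clip_grad_norm_(max_norm)
    # Non-FSDP case: all gradients are local, so the standard utility is correct.
    return torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

Both calls return the total gradient norm, so the same helper can also be used to log the pre-clip norm during training.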