
Refactor gradient clipping feature

gaopengzhi committed 1 year ago
parent commit b1d9efd155
1 changed file with 5 additions and 4 deletions:
  src/llama_recipes/utils/train_utils.py

+ 5 - 4
src/llama_recipes/utils/train_utils.py

@@ -86,9 +86,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 if train_config.use_fp16:
                     # if fp16 is enabled, use gradient scaler to handle gradient update
                     scaler.scale(loss).backward()
-                    if train_config.gradient_clipping > 0.0:
-                        torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping)
                     if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                        if train_config.gradient_clipping > 0.0:
+                            scaler.unscale_(optimizer)
+                            torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping)
                         scaler.step(optimizer)
                         scaler.update()
                         optimizer.zero_grad()
@@ -96,9 +97,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 else:
                     # regular backpropagation when fp16 is not used
                     loss.backward()
-                    if train_config.gradient_clipping > 0.0:
-                        torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping)
                     if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                        if train_config.gradient_clipping > 0.0:
+                            torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping)
                         optimizer.step()
                         optimizer.zero_grad()
                         pbar.update(1)
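
Below is a minimal, self-contained sketch of the pattern this diff adopts: gradients are clipped only at an accumulation boundary, and under fp16 `scaler.unscale_(optimizer)` is called before `torch.nn.utils.clip_grad_norm_` so the clipping threshold applies to the true (unscaled) gradient norm. The toy model, data, and config values below are illustrative stand-ins, not part of llama_recipes.

```python
import torch

# Illustrative stand-ins for the real model, data, and train_config values.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(8, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
gradient_accumulation_steps = 4
gradient_clipping = 1.0  # threshold; values <= 0.0 disable clipping

data = [(torch.randn(4, 8, device=device),
         torch.randint(0, 2, (4,), device=device)) for _ in range(8)]

for step, (x, y) in enumerate(data):
    with torch.autocast(device_type=device, enabled=torch.cuda.is_available()):
        loss = torch.nn.functional.cross_entropy(model(x), y)
    # Scale the loss so the accumulated gradient matches a full batch.
    loss = loss / gradient_accumulation_steps
    scaler.scale(loss).backward()
    if (step + 1) % gradient_accumulation_steps == 0 or step == len(data) - 1:
        if gradient_clipping > 0.0:
            # Undo the fp16 loss scaling before measuring the gradient norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
        scaler.step(optimizer)   # skips the step if inf/nan gradients are found
        scaler.update()
        optimizer.zero_grad()
```

Moving the clip inside the boundary check matters even without fp16: between optimizer steps the gradients are still accumulating, so clipping on every micro-batch would clip a partial sum rather than the full accumulated gradient.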