
Add gradient clipping feature

gaopengzhi · 1 year ago
commit 04befdef69
2 changed files with 5 additions and 0 deletions
  1. src/llama_recipes/configs/training.py (+1 -0)
  2. src/llama_recipes/utils/train_utils.py (+4 -0)

src/llama_recipes/configs/training.py (+1 -0)

@@ -14,6 +14,7 @@ class train_config:
     batching_strategy: str="packing" #alternative: padding
     context_length: int=4096
     gradient_accumulation_steps: int=1
+    gradient_clipping: float = 1.0
     num_epochs: int=3
     num_workers_dataloader: int=1
     lr: float=1e-4
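
The new config field sets the max-norm threshold passed to torch.nn.utils.clip_grad_norm_ in train_utils.py below; since the training loop only clips when the value is greater than 0.0, setting it to 0.0 disables the feature. A minimal sketch of overriding the default, assuming the package is importable as llama_recipes as the paths above suggest:

    from llama_recipes.configs.training import train_config

    cfg = train_config()
    cfg.gradient_clipping = 0.5    # clip the gradient L2 norm at 0.5
    # cfg.gradient_clipping = 0.0  # a non-positive value skips clipping in train()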

src/llama_recipes/utils/train_utils.py (+4 -0)

@@ -86,6 +86,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 if train_config.use_fp16:
                     # if fp16 is enabled, use gradient scaler to handle gradient update
                     scaler.scale(loss).backward()
+                    if train_config.gradient_clipping > 0.0:
+                        torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping)
                     if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                         scaler.step(optimizer)
                         scaler.update()
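
Note on the fp16 path: at the point where clip_grad_norm_ is called in this hunk, the gradients produced by scaler.scale(loss).backward() are still multiplied by the GradScaler's scale factor, so the threshold is compared against scaled norms, and clipping also runs on every micro-batch of an accumulation window. The PyTorch AMP notes recommend unscaling first and clipping only when the optimizer is about to step. A sketch of that alternative pattern (not what this commit does), reusing the names from the surrounding train() loop:

    # Sketch: clip unscaled gradients once per optimizer step (PyTorch AMP pattern).
    scaler.scale(loss).backward()
    if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
        if train_config.gradient_clipping > 0.0:
            scaler.unscale_(optimizer)  # restore gradients to their true scale
            torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping)
        scaler.step(optimizer)  # skips the step if inf/NaN gradients were found
        scaler.update()
        optimizer.zero_grad()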
@@ -94,6 +96,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 else:
                     # regular backpropagation when fp16 is not used
                     loss.backward()
+                    if train_config.gradient_clipping > 0.0:
+                        torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping)
                     if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                         optimizer.step()
                         optimizer.zero_grad()
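
The non-fp16 path likewise clips after every loss.backward() call, so with gradient_accumulation_steps > 1 the partially accumulated gradient is clipped several times before the optimizer steps. A common alternative, shown as a sketch and not part of this commit, clips once per optimizer step so the threshold applies to the fully accumulated gradient:

    # Sketch: clip the accumulated gradient once, right before optimizer.step().
    loss.backward()
    if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
        if train_config.gradient_clipping > 0.0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping)
        optimizer.step()
        optimizer.zero_grad()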