
Merge remote-tracking branch 'upstream/main'

Jeff Tang · 1 year ago · commit 514655d8d3
2 changed files with 13 additions and 0 deletions:
  1. src/llama_recipes/configs/training.py (+2, -0)
  2. src/llama_recipes/utils/train_utils.py (+11, -0)

src/llama_recipes/configs/training.py (+2, -0)

@@ -14,6 +14,8 @@ class train_config:
     batching_strategy: str="packing" #alternative: padding
     context_length: int=4096
     gradient_accumulation_steps: int=1
+    gradient_clipping: bool = False
+    gradient_clipping_threshold: float = 1.0
     num_epochs: int=3
     num_workers_dataloader: int=1
     lr: float=1e-4
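
The two new `train_config` fields keep clipping off by default; a run opts in by setting `gradient_clipping = True` and, if needed, tuning `gradient_clipping_threshold`, which must stay above 0.0 for the clipping branch added below to fire. A minimal sketch of enabling it programmatically, assuming `train_config` is importable from the module shown above:

```python
# Hypothetical usage sketch: only the two field names added in this commit
# come from the diff; the import path and values are illustrative.
from llama_recipes.configs.training import train_config

cfg = train_config()
cfg.gradient_clipping = True               # default is False, i.e. no clipping
cfg.gradient_clipping_threshold = 1.0      # max global gradient norm; must be > 0.0
```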

src/llama_recipes/utils/train_utils.py (+11, -0)

@@ -87,6 +87,12 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     # if fp16 is enabled, use gradient scaler to handle gradient update
                     scaler.scale(loss).backward()
                     if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                        if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
+                            scaler.unscale_(optimizer)
+                            if train_config.enable_fsdp:
+                                model.clip_grad_norm_(train_config.gradient_clipping_threshold)
+                            else:
+                                torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
                         scaler.step(optimizer)
                         scaler.update()
                         optimizer.zero_grad()
@@ -95,6 +101,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     # regular backpropagation when fp16 is not used
                     loss.backward()
                     if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                        if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
+                            if train_config.enable_fsdp:
+                                model.clip_grad_norm_(train_config.gradient_clipping_threshold)
+                            else:
+                                torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
                         optimizer.step()
                         optimizer.zero_grad()
                         pbar.update(1)
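
Both hunks apply the same pattern and differ only in how they interact with the fp16 gradient scaler: with mixed precision the gradients must be unscaled via `scaler.unscale_(optimizer)` before their norm is measured, and the clip must happen before `scaler.step(optimizer)`; under FSDP the sharding-aware `model.clip_grad_norm_` replaces `torch.nn.utils.clip_grad_norm_`. A standalone sketch of the fp16, non-FSDP ordering (model, data, and hyperparameters are placeholders, and a CUDA device is assumed):

```python
import torch

model = torch.nn.Linear(16, 1).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()
clip_threshold = 1.0  # plays the role of train_config.gradient_clipping_threshold

x = torch.randn(8, 16, device="cuda")
y = torch.randn(8, 1, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = torch.nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()
scaler.unscale_(optimizer)                 # bring gradients back to their true scale
torch.nn.utils.clip_grad_norm_(model.parameters(), clip_threshold)
scaler.step(optimizer)                     # skips the step if gradients overflowed
scaler.update()
optimizer.zero_grad()
```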