
Add gradient_clipping and gradient_clipping_threshold parameters

gaopengzhi · 1 year ago · commit e2797abe9b
2 changed files with 8 additions and 7 deletions
  1. src/llama_recipes/configs/training.py (+2 -1)
  2. src/llama_recipes/utils/train_utils.py (+6 -6)

src/llama_recipes/configs/training.py (+2 -1)

@@ -14,7 +14,8 @@ class train_config:
     batching_strategy: str="packing" #alternative: padding
     context_length: int=4096
     gradient_accumulation_steps: int=1
-    gradient_clipping: float = 1.0
+    gradient_clipping: bool = False
+    gradient_clipping_threshold: float = 1.0
     num_epochs: int=3
     num_workers_dataloader: int=1
     lr: float=1e-4
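
With this change, `gradient_clipping` becomes a boolean toggle and the max-norm value moves to the new `gradient_clipping_threshold` field (defaulting to the previous 1.0). A minimal sketch of setting the new fields, assuming the `llama_recipes` package is installed; the values are illustrative:

```python
from llama_recipes.configs.training import train_config

cfg = train_config()                   # defaults as shown in the diff above
cfg.gradient_clipping = True           # clipping is now off unless explicitly enabled
cfg.gradient_clipping_threshold = 1.0  # max gradient norm applied when enabled
print(cfg.gradient_clipping, cfg.gradient_clipping_threshold)
```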

src/llama_recipes/utils/train_utils.py (+6 -6)

@@ -87,12 +87,12 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     # if fp16 is enabled, use gradient scaler to handle gradient update
                     scaler.scale(loss).backward()
                     if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-                        if train_config.gradient_clipping > 0.0:
+                        if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
                             scaler.unscale_(optimizer)
                             if train_config.enable_fsdp:
-                                model.clip_grad_norm_(train_config.gradient_clipping)
+                                model.clip_grad_norm_(train_config.gradient_clipping_threshold)
                             else:
-                                torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping)
+                                torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
                         scaler.step(optimizer)
                         scaler.update()
                         optimizer.zero_grad()
@@ -101,11 +101,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     # regular backpropagation when fp16 is not used
                     loss.backward()
                     if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-                        if train_config.gradient_clipping > 0.0:
+                        if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
                             if train_config.enable_fsdp:
-                                model.clip_grad_norm_(train_config.gradient_clipping)
+                                model.clip_grad_norm_(train_config.gradient_clipping_threshold)
                             else:
-                                torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping)
+                                torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
                         optimizer.step()
                         optimizer.zero_grad()
                         pbar.update(1)
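
For reference, a standalone sketch of what the updated guard does in the non-FSDP, non-fp16 path, using a toy model; the two config fields are stand-in local variables here, not the real `train_config` object:

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

gradient_clipping = True            # stand-in for train_config.gradient_clipping
gradient_clipping_threshold = 1.0   # stand-in for train_config.gradient_clipping_threshold

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()

# Clip only when the flag is on and the threshold is positive, as in the diff
if gradient_clipping and gradient_clipping_threshold > 0.0:
    torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping_threshold)

optimizer.step()
optimizer.zero_grad()
```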