
Support FSDP scenario

When train_config.enable_fsdp is set, gradient clipping now calls FSDP's clip_grad_norm_ instead of torch.nn.utils.clip_grad_norm_; under FSDP the gradients are sharded across ranks, so the standard utility would only see each rank's local shard of the gradient norm.

gaopengzhi 1 year ago
commit bb7c6c1e33
1 changed file with 8 additions and 2 deletions

src/llama_recipes/utils/train_utils.py  +8 -2

@@ -89,7 +89,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                         if train_config.gradient_clipping > 0.0:
                             scaler.unscale_(optimizer)
-                            torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping)
+                            if train_config.enable_fsdp:
+                                model.clip_grad_norm_(train_config.gradient_clipping)
+                            else:
+                                torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping)
                         scaler.step(optimizer)
                         scaler.update()
                         optimizer.zero_grad()
@@ -99,7 +102,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     loss.backward()
                     if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                         if train_config.gradient_clipping > 0.0:
-                            torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping)
+                            if train_config.enable_fsdp:
+                                model.clip_grad_norm_(train_config.gradient_clipping)
+                            else:
+                                torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping)
                         optimizer.step()
                         optimizer.zero_grad()
                         pbar.update(1)
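
For reference, a minimal standalone sketch of the clipping pattern this diff introduces, assuming model is either a plain nn.Module or one wrapped in FullyShardedDataParallel, and that train_config carries the enable_fsdp and gradient_clipping fields used above; the clip_gradients helper itself is hypothetical and not part of the patch:

    import torch

    def clip_gradients(model, train_config):
        # No-op when clipping is disabled in the training config.
        if train_config.gradient_clipping <= 0.0:
            return
        if train_config.enable_fsdp:
            # FSDP shards parameters and gradients across ranks, so the wrapper's
            # own clip_grad_norm_ is used: it reduces the gradient norm over all
            # ranks before scaling the local shards.
            model.clip_grad_norm_(train_config.gradient_clipping)
        else:
            # Non-FSDP case: each process holds the full gradients, so the
            # standard utility computes the correct norm locally.
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), train_config.gradient_clipping
            )

As in the first hunk, when mixed precision with a GradScaler is in use, gradients must be unscaled via scaler.unscale_(optimizer) before this clipping step; otherwise the norm is computed on scaled gradients.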