
fixing scaler for both fsdp and non fsdp (#34)

Geeta Chauhan 1 year ago
parent commit 1e0f8a1fb7
1 changed file with 5 additions and 7 deletions

utils/train_utils.py  +5 −7

@@ -35,11 +35,6 @@ from pathlib import Path
 sys.path.append(str(Path(__file__).resolve().parent.parent))
 from policies import bfSixteen, fpSixteen,bfSixteen_mixed, get_llama_wrapper
 
-scaler = ShardedGradScaler()
-
-
-
-
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
     tokenizer.padding_side = "left"
@@ -67,8 +62,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     Returns: results dictionary containing average training and validation perplexity and loss
     """
     # Create a gradient scaler for fp16
-    scaler = torch.cuda.amp.GradScaler() if train_config.use_fp16 else None
-
+    if train_config.use_fp16 and train_config.enable_fsdp:
+        scaler = ShardedGradScaler()
+    elif train_config.use_fp16 and not train_config.enable_fsdp:
+        scaler = torch.cuda.amp.GradScaler() 
+        
     train_prep = []
     train_loss = []
     val_prep = []
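
For context, here is a minimal sketch of the scaler selection this commit moves inside train(), written as a standalone helper. It assumes the train_config.use_fp16 and train_config.enable_fsdp flags shown in the diff; the explicit None fallback for non-fp16 runs and the loss-scaling calls below are illustrative, not part of this commit.

    import torch
    from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

    def create_grad_scaler(train_config):
        """Pick a gradient scaler matching the training setup (sketch, not the repo's exact code)."""
        if train_config.use_fp16 and train_config.enable_fsdp:
            # FSDP shards gradients across ranks, so the sharded-aware scaler is needed.
            return ShardedGradScaler()
        elif train_config.use_fp16:
            # Plain fp16 training without FSDP uses the standard AMP scaler.
            return torch.cuda.amp.GradScaler()
        # bf16 or full-precision training needs no loss scaling.
        return None

    # Typical use inside the training loop (hedged sketch):
    #   scaler = create_grad_scaler(train_config)
    #   loss = model(**batch).loss
    #   if scaler is not None:
    #       scaler.scale(loss).backward()
    #       scaler.step(optimizer)
    #       scaler.update()
    #   else:
    #       loss.backward()
    #       optimizer.step()

Returning None (or guarding with an else branch) avoids a NameError when use_fp16 is disabled, since the diff as shown only assigns scaler on the two fp16 paths.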