Add missing amp context if use_fp16 is enabled

Matthias Reso 1 year ago
commit 33925f71e6
1 changed file with 6 additions and 2 deletions:
    src/llama_recipes/utils/train_utils.py (+6 −2)

@@ -4,6 +4,7 @@
 import os
 import time
 import yaml
+from contextlib import nullcontext
 from pathlib import Path
 from pkg_resources import packaging
 
@@ -54,7 +55,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     elif train_config.use_fp16 and not train_config.enable_fsdp:
         scaler = torch.cuda.amp.GradScaler() 
     if train_config.enable_fsdp:
-        world_size = int(os.environ["WORLD_SIZE"]) 
+        world_size = int(os.environ["WORLD_SIZE"])
+    autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
+    
     train_prep = []
     train_loss = []
     val_prep = []
@@ -76,7 +79,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         batch[key] = batch[key].to(local_rank)
                     else:
                         batch[key] = batch[key].to('cuda:0')              
-                loss = model(**batch).loss
+                with autocast():
+                    loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
                 total_loss += loss.detach().float()
                 if train_config.use_fp16:
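
For context: mixed-precision (amp) training pairs an autocast context, which runs the forward pass in half precision, with a GradScaler, which rescales the loss so fp16 gradients do not underflow. The scaler was already created in train_utils.py, but without the autocast context the forward pass did not actually run under amp; this commit adds the missing half. Below is a minimal sketch of that pairing, not the repository's code: train_step, model, batch, and optimizer are hypothetical placeholders.

    from contextlib import nullcontext

    import torch

    def train_step(model, batch, optimizer, use_fp16: bool):
        # Mirror the diff: use a no-op context when fp16 is disabled.
        autocast = torch.cuda.amp.autocast if use_fp16 else nullcontext
        scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)

        with autocast():
            loss = model(**batch).loss      # forward pass in fp16 under autocast

        scaler.scale(loss).backward()       # scale loss to avoid fp16 gradient underflow
        scaler.step(optimizer)              # unscales gradients, then calls optimizer.step()
        scaler.update()                     # adjust the scale factor for the next step
        optimizer.zero_grad()
        return loss.detach().float()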