Browse Source

Adding cuda:0 for non-FSDP situations

Hamid Shojanazeri 1 year ago
parent
commit
707af7ea24
1 changed file with 1 addition and 1 deletion
  1. 1 1
      utils/train_utils.py

+ 1 - 1
utils/train_utils.py

@@ -84,7 +84,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if train_config.enable_fsdp:
                         batch[key] = batch[key].to(local_rank)
                     else:
-                        batch[key] = batch[key].to('cuda')       
+                        batch[key] = batch[key].to('cuda:0')       
                 outputs = model(**batch)
                 loss = outputs.loss
                 loss = loss / gradient_accumulation_steps