Browse Source

adding cuda:0 for non-fsdp situations

Hamid Shojanazeri 1 năm trước cách đây
mục cha
commit
707af7ea24
1 tập tin đã thay đổi với 1 bổ sung1 xóa
  1. 1 1
      utils/train_utils.py

+ 1 - 1
utils/train_utils.py

@@ -84,7 +84,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if train_config.enable_fsdp:
                         batch[key] = batch[key].to(local_rank)
                     else:
-                        batch[key] = batch[key].to('cuda')       
+                        batch[key] = batch[key].to('cuda:0')       
                 outputs = model(**batch)
                 loss = outputs.loss
                 loss = loss / gradient_accumulation_steps