
fixing the cuda id

Hamid Shojanazeri 1 year ago
parent
commit a7156dfb5d
1 changed file with 1 addition and 1 deletion

utils/train_utils.py (+1 -1)

@@ -199,7 +199,7 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
                 if train_config.enable_fsdp:
                     batch[key] = batch[key].to(local_rank)
                 else:
-                    batch[key] = batch[key].to('cuda')
+                    batch[key] = batch[key].to('cuda:0')
             # Ensure no gradients are computed for this scope to save memory
             with torch.no_grad():
                 # Forward pass and compute loss
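For context, here is a minimal standalone sketch of the device-placement pattern the changed lines implement: with FSDP enabled, each rank moves its batch tensors to its own device index (`local_rank`); otherwise the tensors are pinned explicitly to `cuda:0` instead of the generic `'cuda'` string. The helper name `move_batch_to_device` is hypothetical and not part of the repository.

```python
import torch

def move_batch_to_device(batch, enable_fsdp, local_rank):
    """Move every tensor in a batch dict to the appropriate CUDA device.

    Hypothetical helper mirroring the evaluation-loop logic in
    utils/train_utils.py; not an actual function in the repo.
    """
    for key in batch.keys():
        if enable_fsdp:
            # Under FSDP, each process owns one device, indexed by local_rank.
            batch[key] = batch[key].to(local_rank)
        else:
            # Single-process path: pin to an explicit device id rather than 'cuda'.
            batch[key] = batch[key].to('cuda:0')
    return batch
```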