@@ -84,7 +84,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 if train_config.enable_fsdp:
                     batch[key] = batch[key].to(local_rank)
                 else:
-                    batch[key] = batch[key].to('cuda')
+                    batch[key] = batch[key].to('cuda:0')
             outputs = model(**batch)
             loss = outputs.loss
             loss = loss / gradient_accumulation_steps
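
Note on the change: in PyTorch, `.to('cuda')` moves a tensor to the *current* CUDA device, which defaults to index 0 but can be changed per process, whereas `.to('cuda:0')` pins it to device 0 explicitly. A minimal standalone sketch of that distinction (hypothetical, not part of this patch; assumes a CUDA machine):

```python
import torch

# Hypothetical sketch, not part of this patch: shows why 'cuda' and 'cuda:0'
# are not interchangeable. .to('cuda') targets the *current* CUDA device
# (index 0 by default, but mutable per process), while .to('cuda:0') always
# targets device 0.
if torch.cuda.is_available():
    t = torch.zeros(1)
    print(t.to('cuda').device)        # cuda:0 -- current device defaults to 0

    if torch.cuda.device_count() > 1:
        torch.cuda.set_device(1)      # change this process's current device
        print(t.to('cuda').device)    # cuda:1 -- follows the current device
        print(t.to('cuda:0').device)  # cuda:0 -- explicit index is unaffected
```

The `enable_fsdp` branch already pins each tensor to `local_rank`; after this change the non-FSDP fallback is equally explicit about which device it uses.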