@@ -85,7 +85,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
for key in batch.keys():
if train_config.enable_fsdp:
batch[key] = batch[key].to(local_rank)
- elif not train_config.quantization:
+ else:
batch[key] = batch[key].to('cuda')
outputs = model(**batch)
loss = outputs.loss