@@ -76,7 +76,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if train_config.enable_fsdp:
                         batch[key] = batch[key].to(local_rank)
                     else:
-                        batch[key] = batch[key].to('cuda:0')
+                        if is_xpu_available():
+                            batch[key] = batch[key].to('xpu:0')
+                        else:
+                            batch[key] = batch[key].to('cuda:0')
                 loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
                 total_loss += loss.detach().float()
@@ -247,7 +250,10 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
                 if train_config.enable_fsdp:
                     batch[key] = batch[key].to(local_rank)
                 else:
-                    batch[key] = batch[key].to('cuda:0')
+                    if is_xpu_available():
+                        batch[key] = batch[key].to('xpu:0')
+                    else:
+                        batch[key] = batch[key].to('cuda:0')
             # Ensure no gradients are computed for this scope to save memory
             with torch.no_grad():
                 # Forward pass and compute loss
@@ -261,7 +267,7 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
             )

     # If there's more than one CUDA device, reduce evaluation loss across all devices
-    if is_xpu_available() and (torch.cuda.device_count() > 1 and train_config.enable_fsdp):
+    if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
     if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
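
For reference, the pattern applied by these hunks can be read outside the diff context as the standalone sketch below. The names _xpu_available, move_batch_to_device, and reduce_eval_loss are illustrative stand-ins introduced here, not functions from the repository; the actual patch inlines this logic in train() and evaluation() and calls the imported is_xpu_available() helper instead of the local guard shown.

import torch
import torch.distributed as dist


def _xpu_available() -> bool:
    # Illustrative stand-in for the is_xpu_available() helper the patch calls:
    # true when this PyTorch build exposes Intel XPU and a device is present.
    return hasattr(torch, "xpu") and torch.xpu.is_available()


def move_batch_to_device(batch, enable_fsdp, local_rank):
    # Mirrors the first two hunks: under FSDP each rank sends tensors to its
    # own device index; otherwise prefer 'xpu:0' when available, else 'cuda:0'.
    for key in batch.keys():
        if enable_fsdp:
            batch[key] = batch[key].to(local_rank)
        elif _xpu_available():
            batch[key] = batch[key].to('xpu:0')
        else:
            batch[key] = batch[key].to('cuda:0')
    return batch


def reduce_eval_loss(eval_loss, enable_fsdp):
    # Mirrors the third hunk, with its two independent checks folded into an
    # if/elif: under multi-device FSDP, sum the evaluation loss across ranks.
    if enable_fsdp:
        if _xpu_available() and torch.xpu.device_count() > 1:
            dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
        elif torch.cuda.device_count() > 1:
            dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
    return eval_loss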