Fix bugs in data loading

abhilash1910 1 year ago
parent
commit
82d3ca6e06
1 changed file with 9 additions and 3 deletions

+ 9 - 3
utils/train_utils.py

@@ -89,7 +89,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if train_config.enable_fsdp:
                         batch[key] = batch[key].to(local_rank)
                     else:
-                        batch[key] = batch[key].to('cuda:0')              
+                        if is_xpu_available():
+                            batch[key] = batch[key].to('xpu:0')
+                        else:
+                            batch[key] = batch[key].to('cuda:0')              
                 loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
                 total_loss += loss.detach().float()
@@ -260,7 +263,10 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
                 if train_config.enable_fsdp:
                     batch[key] = batch[key].to(local_rank)
                 else:
-                    batch[key] = batch[key].to('cuda:0')
+                    if is_xpu_available():
+                        batch[key] = batch[key].to('xpu:0')
+                    else:
+                        batch[key] = batch[key].to('cuda:0')  
             # Ensure no gradients are computed for this scope to save memory
             with torch.no_grad():
                 # Forward pass and compute loss
@@ -274,7 +280,7 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
             )
     
     # If there's more than one CUDA device, reduce evaluation loss across all devices
-    if is_xpu_available() and (torch.cuda.device_count() > 1 and train_config.enable_fsdp):
+    if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
     if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
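
The first two hunks apply the same XPU-or-CUDA branch in both train() and evaluation(); the last hunk fixes the matching guard for the loss reduction, since on an XPU-only machine torch.cuda.device_count() returns 0 and the all_reduce was silently skipped until the condition was switched to torch.xpu.device_count(). Below is a minimal sketch of how the repeated branch could be factored into a single helper. get_default_device is a hypothetical name, not part of this commit, and is_xpu_available is assumed to be the function of that name provided by accelerate.utils:

import torch
from accelerate.utils import is_xpu_available  # assumed import, matching the diff's usage


def get_default_device() -> str:
    """Pick the device string used when FSDP is disabled (sketch, not in the commit)."""
    if is_xpu_available():
        return "xpu:0"   # Intel GPU backend exposed via torch.xpu
    return "cuda:0"      # default CUDA device


# Usage matching the loops in the diff:
# for key in batch:
#     batch[key] = batch[key].to(get_default_device())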