Browse source

enable xpu finetuning and inference

abhilash1910 1 year ago
Parent
Commit
ed7ba999a9
6 changed files with 105 additions and 34 deletions
  1. inference/chat_completion.py (+8 -2)
  2. inference/inference.py (+9 -2)
  3. inference/vLLM_inference.py (+5 -1)
  4. llama_finetuning.py (+14 -4)
  5. utils/memory_utils.py (+33 -14)
  6. utils/train_utils.py (+36 -11)

+ 8 - 2
inference/chat_completion.py

@@ -55,7 +55,10 @@ def main(
 
 
     # Set the seeds for reproducibility
-    torch.cuda.manual_seed(seed)
+    if is_xpu_available():
+        torch.xpu.manual_seed(seed)
+    else:
+        torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
     model = load_model(model_name, quantization)
     if peft_model:
@@ -105,7 +108,10 @@ def main(
                 sys.exit(1)  # Exit the program with an error status
             tokens= torch.tensor(chat).long()
             tokens= tokens.unsqueeze(0)
-            tokens= tokens.to("cuda:0")
+            if is_xpu_available():
+                tokens= tokens.to("xpu:0")
+            else:
+                tokens= tokens.to("cuda:0")
             outputs = model.generate(
                 tokens,
                 max_new_tokens=max_new_tokens,
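
The seed-setting branch above is repeated verbatim in every entry point touched by this commit. A minimal sketch of a hypothetical helper (not part of the commit) that would centralize it:

    # Hypothetical helper, not part of this commit: seed whichever accelerator
    # backend is present (XPU when intel_extension_for_pytorch provides torch.xpu,
    # otherwise CUDA), and always seed the CPU generator as well.
    import torch
    from accelerate.utils import is_xpu_available

    def seed_everything(seed: int) -> None:
        if is_xpu_available():
            torch.xpu.manual_seed(seed)
        elif torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
        torch.manual_seed(seed)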

+ 9 - 2
inference/inference.py

@@ -13,6 +13,7 @@ from typing import List
 from transformers import LlamaTokenizer
 from safety_utils import get_safety_checker
 from model_utils import load_model, load_peft_model, load_llama_from_config
+from accelerate.utils import is_xpu_available
 
 def main(
     model_name,
@@ -48,7 +49,10 @@ def main(
         sys.exit(1)
     
     # Set the seeds for reproducibility
-    torch.cuda.manual_seed(seed)
+    if is_xpu_available():
+        torch.xpu.manual_seed(seed)
+    else:
+        torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
     
     model = load_model(model_name, quantization)
@@ -98,7 +102,10 @@ def main(
         sys.exit(1)  # Exit the program with an error status
 
     batch = tokenizer(user_prompt, return_tensors="pt")
-    batch = {k: v.to("cuda") for k, v in batch.items()}
+    if is_xpu_available():
+        batch = {k: v.to("xpu") for k, v in batch.items()}
+    else:
+        batch = {k: v.to("cuda") for k, v in batch.items()}
     start = time.perf_counter()
     with torch.no_grad():
         outputs = model.generate(
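
Moving inputs follows the same pattern as seeding: pick the device string once, then reuse it. A short sketch, assuming accelerate's is_xpu_available() as imported above; get_device is a hypothetical name, not part of the commit:

    # Hypothetical helper: resolve the target device once instead of branching
    # at every call site; falls back to CPU when no accelerator is present.
    import torch
    from accelerate.utils import is_xpu_available

    def get_device() -> str:
        if is_xpu_available():
            return "xpu"
        if torch.cuda.is_available():
            return "cuda"
        return "cpu"

    # batch = tokenizer(user_prompt, return_tensors="pt")
    # batch = {k: v.to(get_device()) for k, v in batch.items()}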

+ 5 - 1
inference/vLLM_inference.py

@@ -14,8 +14,12 @@ from transformers import (
 )
 from vllm import LLM
 from vllm import LLM, SamplingParams
+from accelerate.utils import is_xpu_available
 
-torch.cuda.manual_seed(42)
+if is_xpu_available():
+    torch.xpu.manual_seed(42)
+else:
+    torch.cuda.manual_seed(42)
 torch.manual_seed(42)
 
 def load_model(model_name, tp_size=1):

+ 14 - 4
llama_finetuning.py

@@ -64,6 +64,7 @@ import torch
 import torch.cuda.nccl as nccl
 import torch.distributed as dist
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from accelerate.utils import is_xpu_available
 
 
 def main(**kwargs):
@@ -71,7 +72,10 @@ def main(**kwargs):
     update_config((train_config, fsdp_config), **kwargs)
 
     # Set the seeds for reproducibility
-    torch.cuda.manual_seed(train_config.seed)
+    if is_xpu_available():
+        torch.xpu.manual_seed(train_config.seed)
+    else:
+        torch.cuda.manual_seed(train_config.seed)
     torch.manual_seed(train_config.seed)
 
     if train_config.enable_fsdp:
@@ -82,7 +86,10 @@ def main(**kwargs):
         world_size = int(os.environ["WORLD_SIZE"])
 
     if torch.distributed.is_initialized():
-        torch.cuda.set_device(rank)
+        if is_xpu_available():
+            torch.xpu.set_device(rank)
+        else:
+            torch.cuda.set_device(rank)
         setup_environ_flags(rank)
     
     # Calculate gradient accumulation steps
@@ -142,13 +149,16 @@ def main(**kwargs):
             auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
             mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
             sharding_strategy=fsdp_config.sharding_strategy,
-            device_id=torch.cuda.current_device(),
+            device_id=torch.xpu.current_device() if is_xpu_available() else torch.cuda.current_device(),
             limit_all_gathers=True,
         )
         if fsdp_config.fsdp_activation_checkpointing:
             policies.apply_fsdp_checkpointing(model)
     elif not train_config.quantization and not train_config.enable_fsdp:
-        model.to("cuda")
+        if is_xpu_available():
+            model.to("xpu:0")
+        else:
+            model.to("cuda")
 
     dataset_config = generate_dataset_config(train_config, kwargs)
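
For FSDP, the commit pins each rank to its accelerator and passes the current device index as device_id. A compact sketch of that pattern (pin_rank_device is a hypothetical name; the rank handling mirrors the code above):

    # Hypothetical helper: pin the calling rank to its XPU or CUDA device and
    # return the device index to hand to FSDP(device_id=...).
    import torch
    from accelerate.utils import is_xpu_available

    def pin_rank_device(rank: int) -> int:
        if is_xpu_available():
            torch.xpu.set_device(rank)
            return torch.xpu.current_device()
        torch.cuda.set_device(rank)
        return torch.cuda.current_device()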
     

+ 33 - 14
utils/memory_utils.py

@@ -8,6 +8,7 @@ import threading
 import numpy as np
 import psutil
 import torch
+from accelerate.utils import is_xpu_available
 
 def byte2gb(x):
     return int(x / 2**30)
@@ -15,9 +16,14 @@ def byte2gb(x):
 class MemoryTrace:
     def __enter__(self):
         gc.collect()
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
-        self.begin = byte2gb(torch.cuda.memory_allocated())
+        if is_xpu_available():
+            torch.xpu.empty_cache()
+            torch.xpu.reset_max_memory_allocated()   # reset the peak gauge to zero
+            self.begin = byte2gb(torch.xpu.memory_allocated())
+        elif torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
+            self.begin = byte2gb(torch.cuda.memory_allocated())
         self.process = psutil.Process()
         self.cpu_begin = byte2gb(self.cpu_mem_used())
         self.peak_monitoring = True
@@ -46,17 +52,30 @@ class MemoryTrace:
         self.peak_monitoring = False
 
         gc.collect()
-        torch.cuda.empty_cache()
-        self.end = byte2gb(torch.cuda.memory_allocated())
-        self.peak = byte2gb(torch.cuda.max_memory_allocated())
-        cuda_info = torch.cuda.memory_stats()
-        self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
-        self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
-        self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
-        self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
-        self.used = byte2gb(self.end - self.begin)
-        self.peaked = byte2gb(self.peak - self.begin)
-        self.max_reserved = byte2gb(torch.cuda.max_memory_reserved())
+        if is_xpu_available():
+            torch.xpu.empty_cache()
+            self.end = byte2gb(torch.xpu.memory_allocated())
+            self.peak = byte2gb(torch.xpu.max_memory_allocated())
+            xpu_info = torch.xpu.memory_stats()
+            self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"])
+            self.xpu_malloc_retires = xpu_info.get("num_alloc_retries", 0)
+            self.m_xpu_ooms = xpu_info.get("num_ooms", 0)
+            self.used = byte2gb(self.end - self.begin)
+            self.peaked = byte2gb(self.peak - self.begin)
+            self.max_reserved = byte2gb(torch.xpu.max_memory_reserved())
+        else:
+            torch.cuda.empty_cache()
+            self.end = byte2gb(torch.cuda.memory_allocated())
+            self.peak = byte2gb(torch.cuda.max_memory_allocated())
+            cuda_info = torch.cuda.memory_stats()
+            self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
+            self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
+            self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
+            self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
+            self.used = byte2gb(self.end - self.begin)
+            self.peaked = byte2gb(self.peak - self.begin)
+            self.max_reserved = byte2gb(torch.cuda.max_memory_reserved())
 
         self.cpu_end = self.cpu_mem_used()
         self.cpu_used = byte2gb(self.cpu_end - self.cpu_begin)
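
MemoryTrace is consumed as a context manager by the training loop; the attributes it exposes (peak, max_reserved, peak_active_gb, used, peaked) are the ones printed in utils/train_utils.py. A usage sketch, assuming the repository root is on sys.path and an accelerator is present; the workload inside the block is illustrative:

    import torch
    from accelerate.utils import is_xpu_available
    from utils.memory_utils import MemoryTrace

    device = "xpu" if is_xpu_available() else "cuda"
    with MemoryTrace() as memtrace:
        x = torch.randn(4096, 4096, device=device)
        y = x @ x  # allocate some accelerator memory inside the traced block
    print(f"peak allocated: {memtrace.peak} GB, max reserved: {memtrace.max_reserved} GB")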

+ 36 - 11
utils/train_utils.py

@@ -36,6 +36,7 @@ from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from pathlib import Path
 sys.path.append(str(Path(__file__).resolve().parent.parent))
 from policies import bfSixteen, fpSixteen,bfSixteen_mixed, get_llama_wrapper
+from accelerate.utils import is_xpu_available, is_ccl_available
 
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
@@ -113,7 +114,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         epoch_end_time = time.perf_counter()-epoch_start_time
         epoch_times.append(epoch_end_time)    
         # Reducing total_loss across all devices if there's more than one CUDA device
-        if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
+        if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
+            dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
+        elif torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
         train_epoch_loss = total_loss / len(train_dataloader)
         if train_config.enable_fsdp:
@@ -125,17 +128,29 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         
         if train_config.enable_fsdp:
             if rank==0:
+                if is_xpu_available():
+                    print(f"Max XPU memory allocated was {memtrace.peak} GB")
+                    print(f"Max XPU memory reserved was {memtrace.max_reserved} GB")
+                    print(f"Peak active XPU memory was {memtrace.peak_active_gb} GB")
+                    print(f"Xpu Malloc retires : {memtrace.cuda_malloc_retires}")
+                else:
+                    print(f"Max CUDA memory allocated was {memtrace.peak} GB")
+                    print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
+                    print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
+                    print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
+                print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
+        else:
+            if is_xpu_available():
+                print(f"Max XPU memory allocated was {memtrace.peak} GB")
+                print(f"Max XPU memory reserved was {memtrace.max_reserved} GB")
+                print(f"Peak active XPU memory was {memtrace.peak_active_gb} GB")
+                print(f"Xpu Malloc retires : {memtrace.cuda_malloc_retires}")
+            else:
                 print(f"Max CUDA memory allocated was {memtrace.peak} GB")
                 print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
                 print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
                 print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
                 print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
-        else:
-            print(f"Max CUDA memory allocated was {memtrace.peak} GB")
-            print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
-            print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
-            print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
-            print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
         
         # Update the learning rate as needed
         lr_scheduler.step()
@@ -259,6 +274,8 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
             )
     
     # If there's more than one CUDA device, reduce evaluation loss across all devices
+    if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
+        dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
     if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
     
@@ -292,7 +309,11 @@ def check_frozen_layers_peft_model(model):
                 
 def setup():
     """Initialize the process group for distributed training"""
-    dist.init_process_group("nccl")
+    if is_ccl_available():
+        # distributed training on xpus
+        dist.init_process_group("ccl")
+    else:
+        dist.init_process_group("nccl")
 
 
 def setup_environ_flags(rank):
@@ -316,7 +337,10 @@ def clear_gpu_cache(rank=None):
     """Clear the GPU cache for all ranks"""
     if rank == 0:
         print(f"Clearing GPU cache for all ranks")
-    torch.cuda.empty_cache()
+    if is_xpu_available():
+        torch.xpu.empty_cache()
+    else:
+        torch.cuda.empty_cache()
 
 
 def get_parameter_dtypes(model):
@@ -348,13 +372,14 @@ def print_model_size(model, config, rank: int = 0) -> None:
 def get_policies(cfg, rank):
     """Get the policies for mixed precision and fsdp wrapping"""
     
-    verify_bfloat_support = (
+    verify_bfloat_support = ((
     torch.version.cuda
     and torch.cuda.is_bf16_supported()
     and packaging.version.parse(torch.version.cuda).release >= (11, 0)
     and dist.is_nccl_available()
     and nccl.version() >= (2, 10)
-    )
+    ) or
+    (is_xpu_available()))
 
 
     mixed_precision_policy = None
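
The backend switch in setup() is what lets the same launch command work on both stacks: "ccl" when the oneCCL bindings are available, "nccl" otherwise. A minimal launch sketch, assuming a torchrun-style environment (LOCAL_RANK set per process) and that importing oneccl_bindings_for_pytorch is what registers the "ccl" backend on XPU systems; setup_distributed is a hypothetical name:

    # Sketch of a device-agnostic process-group setup mirroring setup() and
    # clear_gpu_cache() above.
    import os
    import torch
    import torch.distributed as dist
    from accelerate.utils import is_xpu_available, is_ccl_available

    def setup_distributed() -> int:
        dist.init_process_group("ccl" if is_ccl_available() else "nccl")
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        if is_xpu_available():
            torch.xpu.set_device(local_rank)
        else:
            torch.cuda.set_device(local_rank)
        return local_rank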