Explorar el Código

Feature : Enable Intel GPU/XPU finetuning and inference (#116)

Hamid Shojanazeri hace 1 año
padre
commit
dbfea484c6

+ 9 - 3
examples/chat_completion/chat_completion.py

@@ -13,7 +13,7 @@ from transformers import LlamaTokenizer
 from llama_recipes.inference.chat_utils import read_dialogs_from_file, format_tokens
 from llama_recipes.inference.model_utils import load_model, load_peft_model
 from llama_recipes.inference.safety_utils import get_safety_checker
-
+from accelerate.utils import is_xpu_available
 
 def main(
     model_name,
@@ -55,7 +55,10 @@ def main(
 
 
     # Set the seeds for reproducibility
-    torch.cuda.manual_seed(seed)
+    if is_xpu_available():
+        torch.xpu.manual_seed(seed)
+    else:
+        torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
     model = load_model(model_name, quantization)
     if peft_model:
@@ -105,7 +108,10 @@ def main(
                 sys.exit(1)  # Exit the program with an error status
             tokens= torch.tensor(chat).long()
             tokens= tokens.unsqueeze(0)
-            tokens= tokens.to("cuda:0")
+            if is_xpu_available():
+                tokens= tokens.to("xpu:0")
+            else:
+                tokens= tokens.to("cuda:0")
             outputs = model.generate(
                 input_ids=tokens,
                 max_new_tokens=max_new_tokens,

+ 9 - 2
examples/inference.py

@@ -14,6 +14,7 @@ from transformers import LlamaTokenizer
 from llama_recipes.inference.safety_utils import get_safety_checker, AgentType
 from llama_recipes.inference.model_utils import load_model, load_peft_model
 
+from accelerate.utils import is_xpu_available
 
 def main(
     model_name,
@@ -72,7 +73,10 @@ def main(
         sys.exit(1)  # Exit the program with an error status
 
     # Set the seeds for reproducibility
-    torch.cuda.manual_seed(seed)
+    if is_xpu_available():
+        torch.xpu.manual_seed(seed)
+    else:
+        torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
     
     model = load_model(model_name, quantization)
@@ -97,8 +101,11 @@ def main(
     tokenizer.pad_token = tokenizer.eos_token
     
     batch = tokenizer(user_prompt, padding='max_length', truncation=True, max_length=max_padding_length, return_tensors="pt")
+    if is_xpu_available():
+        batch = {k: v.to("xpu") for k, v in batch.items()}
+    else:
+        batch = {k: v.to("cuda") for k, v in batch.items()}
 
-    batch = {k: v.to("cuda") for k, v in batch.items()}
     start = time.perf_counter()
     with torch.no_grad():
         outputs = model.generate(

+ 5 - 1
examples/vllm/inference.py

@@ -6,9 +6,13 @@ import fire
 import torch
 from vllm import LLM
 from vllm import LLM, SamplingParams
+from accelerate.utils import is_xpu_available
 
+if is_xpu_available():
+    torch.xpu.manual_seed(42)
+else:
+    torch.cuda.manual_seed(42)
 
-torch.cuda.manual_seed(42)
 torch.manual_seed(42)
 
 def load_model(model_name, tp_size=1):

+ 14 - 5
src/llama_recipes/finetuning.py

@@ -44,7 +44,7 @@ from llama_recipes.utils.train_utils import (
     print_model_size,
     get_policies
 )
-
+from accelerate.utils import is_xpu_available
 
 def main(**kwargs):
     # Update the configuration for the training and sharding process
@@ -52,7 +52,10 @@ def main(**kwargs):
     update_config((train_config, fsdp_config), **kwargs)
 
     # Set the seeds for reproducibility
-    torch.cuda.manual_seed(train_config.seed)
+    if is_xpu_available():
+        torch.xpu.manual_seed(train_config.seed)
+    else:
+        torch.cuda.manual_seed(train_config.seed)
     torch.manual_seed(train_config.seed)
     random.seed(train_config.seed)
 
@@ -64,7 +67,10 @@ def main(**kwargs):
         world_size = int(os.environ["WORLD_SIZE"])
 
     if torch.distributed.is_initialized():
-        torch.cuda.set_device(local_rank)
+        if is_xpu_available():
+            torch.xpu.set_device(local_rank)
+        else:
+            torch.cuda.set_device(local_rank)
         clear_gpu_cache(local_rank)
         setup_environ_flags(rank)
 
@@ -148,7 +154,7 @@ def main(**kwargs):
             cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
             mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
             sharding_strategy=fsdp_config.sharding_strategy,
-            device_id=torch.cuda.current_device(),
+            device_id=torch.xpu.current_device() if is_xpu_available() else torch.cuda.current_device(),
             limit_all_gathers=True,
             sync_module_states=train_config.low_cpu_fsdp,
             param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
@@ -157,7 +163,10 @@ def main(**kwargs):
         if fsdp_config.fsdp_activation_checkpointing:
             apply_fsdp_checkpointing(model)
     elif not train_config.quantization and not train_config.enable_fsdp:
-        model.to("cuda")
+        if is_xpu_available():
+            model.to("xpu:0")
+        else:
+            model.to("cuda")
 
     dataset_config = generate_dataset_config(train_config, kwargs)
 

+ 33 - 14
src/llama_recipes/utils/memory_utils.py

@@ -6,6 +6,7 @@ import psutil
 import threading
 
 import torch
+from accelerate.utils import is_xpu_available
 
 def byte2gb(x):
     return int(x / 2**30)
@@ -13,9 +14,14 @@ def byte2gb(x):
 class MemoryTrace:
     def __enter__(self):
         gc.collect()
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
-        self.begin = byte2gb(torch.cuda.memory_allocated())
+        if is_xpu_available():
+            torch.xpu.empty_cache()
+            torch.xpu.reset_max_memory_allocated()   # reset the peak gauge to zero
+            self.begin = byte2gb(torch.xpu.memory_allocated())
+        elif torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
+            self.begin = byte2gb(torch.cuda.memory_allocated())
         self.process = psutil.Process()
         self.cpu_begin = byte2gb(self.cpu_mem_used())
         self.peak_monitoring = True
@@ -44,17 +50,30 @@ class MemoryTrace:
         self.peak_monitoring = False
 
         gc.collect()
-        torch.cuda.empty_cache()
-        self.end = byte2gb(torch.cuda.memory_allocated())
-        self.peak = byte2gb(torch.cuda.max_memory_allocated())
-        cuda_info = torch.cuda.memory_stats()
-        self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
-        self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
-        self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
-        self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
-        self.used = byte2gb(self.end - self.begin)
-        self.peaked = byte2gb(self.peak - self.begin)
-        self.max_reserved = byte2gb(torch.cuda.max_memory_reserved())
+        if is_xpu_available():
+            torch.xpu.empty_cache()
+            self.end = byte2gb(torch.xpu.memory_allocated())
+            self.peak = byte2gb(torch.xpu.max_memory_allocated())
+            xpu_info = torch.xpu.memory_stats()
+            self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"])
+            self.xpu_malloc_retires = xpu_info.get("num_alloc_retries", 0)
+            self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"])
+            self.m_xpu_ooms = xpu_info.get("num_ooms", 0)
+            self.used = byte2gb(self.end - self.begin)
+            self.peaked = byte2gb(self.peak - self.begin)
+            self.max_reserved = byte2gb(torch.xpu.max_memory_reserved())
+        else:
+            torch.cuda.empty_cache()
+            self.end = byte2gb(torch.cuda.memory_allocated())
+            self.peak = byte2gb(torch.cuda.max_memory_allocated())
+            cuda_info = torch.cuda.memory_stats()
+            self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
+            self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
+            self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
+            self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
+            self.used = byte2gb(self.end - self.begin)
+            self.peaked = byte2gb(self.peak - self.begin)
+            self.max_reserved = byte2gb(torch.cuda.max_memory_reserved())
 
         self.cpu_end = self.cpu_mem_used()
         self.cpu_used = byte2gb(self.cpu_end - self.cpu_begin)

+ 50 - 15
src/llama_recipes/utils/train_utils.py

@@ -23,7 +23,7 @@ import json
 from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
 from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
-
+from accelerate.utils import is_xpu_available, is_ccl_available
 
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
@@ -89,9 +89,16 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             for step, batch in enumerate(train_dataloader):
                 for key in batch.keys():
                     if train_config.enable_fsdp:
-                        batch[key] = batch[key].to(local_rank)
+                        if is_xpu_available():
+                            batch[key] = batch[key].to(torch.device(f"xpu:{local_rank}"))
+                        else:
+                            batch[key] = batch[key].to(local_rank)
                     else:
-                        batch[key] = batch[key].to('cuda:0')
+
+                        if is_xpu_available():
+                            batch[key] = batch[key].to('xpu:0')
+                        else:
+                            batch[key] = batch[key].to('cuda:0')              
                 with autocast():
                     loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
@@ -135,7 +142,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         epoch_end_time = time.perf_counter()-epoch_start_time
         epoch_times.append(epoch_end_time)
         # Reducing total_loss across all devices if there's more than one CUDA device
-        if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
+        if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
+            dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
+        elif torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
         train_epoch_loss = total_loss / len(train_dataloader)
         if train_config.enable_fsdp:
@@ -147,16 +156,28 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         
         if train_config.enable_fsdp:
             if rank==0:
+                if is_xpu_available():
+                    print(f"Max XPU memory allocated was {memtrace.peak} GB")
+                    print(f"Max XPU memory reserved was {memtrace.max_reserved} GB")
+                    print(f"Peak active XPU memory was {memtrace.peak_active_gb} GB")
+                    print(f"Xpu Malloc retires : {memtrace.xpu_malloc_retires}")
+                else:
+                    print(f"Max CUDA memory allocated was {memtrace.peak} GB")
+                    print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
+                    print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
+                    print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
+                print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
+        else:
+            if is_xpu_available():
+                print(f"Max XPU memory allocated was {memtrace.peak} GB")
+                print(f"Max XPU memory reserved was {memtrace.max_reserved} GB")
+                print(f"Peak active XPU memory was {memtrace.peak_active_gb} GB")
+                print(f"Xpu Malloc retires : {memtrace.xpu_malloc_retires}")
+            else:
                 print(f"Max CUDA memory allocated was {memtrace.peak} GB")
                 print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
                 print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
                 print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
-                print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
-        else:
-            print(f"Max CUDA memory allocated was {memtrace.peak} GB")
-            print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
-            print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
-            print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
             print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
 
         # Update the learning rate as needed
@@ -279,7 +300,10 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
                 if train_config.enable_fsdp:
                     batch[key] = batch[key].to(local_rank)
                 else:
-                    batch[key] = batch[key].to('cuda:0')
+                    if is_xpu_available():
+                        batch[key] = batch[key].to('xpu:0')
+                    else:
+                        batch[key] = batch[key].to('cuda:0')  
             # Ensure no gradients are computed for this scope to save memory
             with torch.no_grad():
                 # Forward pass and compute loss
@@ -297,6 +321,8 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
             )
 
     # If there's more than one CUDA device, reduce evaluation loss across all devices
+    if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
+        dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
     if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
 
@@ -330,7 +356,11 @@ def check_frozen_layers_peft_model(model):
 
 def setup():
     """Initialize the process group for distributed training"""
-    dist.init_process_group("nccl")
+    if is_ccl_available():
+        # distributed training on xpus
+        dist.init_process_group("ccl")
+    else:
+        dist.init_process_group("nccl")
 
 
 def setup_environ_flags(rank):
@@ -354,7 +384,10 @@ def clear_gpu_cache(rank=None):
     """Clear the GPU cache for all ranks"""
     if rank == 0:
         print(f"Clearing GPU cache for all ranks")
-    torch.cuda.empty_cache()
+    if is_xpu_available():
+        torch.xpu_empty_cache()
+    else:
+        torch.cuda.empty_cache()
 
 
 def get_parameter_dtypes(model):
@@ -386,13 +419,15 @@ def print_model_size(model, config, rank: int = 0) -> None:
 def get_policies(cfg, rank):
     """Get the policies for mixed precision and fsdp wrapping"""
 
-    verify_bfloat_support = (
+    
+    verify_bfloat_support = ((
     torch.version.cuda
     and torch.cuda.is_bf16_supported()
     and packaging.version.parse(torch.version.cuda).release >= (11, 0)
     and dist.is_nccl_available()
     and nccl.version() >= (2, 10)
-    )
+    ) or
+    (is_xpu_available()))
 
 
     mixed_precision_policy = None

+ 83 - 0
utils/memory_utils.py

@@ -0,0 +1,83 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+import gc
+import os
+import sys
+import threading
+
+import numpy as np
+import psutil
+import torch
+from accelerate.utils import is_xpu_available
+
+def byte2gb(x):
+    return int(x / 2**30)
+# This context manager is used to track the peak memory usage of the process
+class MemoryTrace:
+    def __enter__(self):
+        gc.collect()
+        if is_xpu_available():
+            torch.xpu.empty_cache()
+            torch.xpu.reset_max_memory_allocated()   # reset the peak gauge to zero
+            self.begin = byte2gb(torch.xpu.memory_allocated())
+        elif torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
+            self.begin = byte2gb(torch.cuda.memory_allocated())
+        self.process = psutil.Process()
+        self.cpu_begin = byte2gb(self.cpu_mem_used())
+        self.peak_monitoring = True
+        peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
+        peak_monitor_thread.daemon = True
+        peak_monitor_thread.start()
+        return self
+
+    def cpu_mem_used(self):
+        """get resident set size memory for the current process"""
+        return self.process.memory_info().rss
+
+    def peak_monitor_func(self):
+        self.cpu_peak = -1
+
+        while True:
+            self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)
+
+            # can't sleep or will not catch the peak right (this comment is here on purpose)
+            # time.sleep(0.001) # 1msec
+
+            if not self.peak_monitoring:
+                break
+
+    def __exit__(self, *exc):
+        self.peak_monitoring = False
+
+        gc.collect()
+        if is_xpu_available():
+            torch.xpu.empty_cache()
+            self.end = byte2gb(torch.xpu.memory_allocated())
+            self.peak = byte2gb(torch.xpu.max_memory_allocated())
+            xpu_info = torch.xpu.memory_stats()
+            self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"])
+            self.xpu_malloc_retires = xpu_info.get("num_alloc_retries", 0)
+            self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"])
+            self.m_xpu_ooms = xpu_info.get("num_ooms", 0)
+            self.used = byte2gb(self.end - self.begin)
+            self.peaked = byte2gb(self.peak - self.begin)
+            self.max_reserved = byte2gb(torch.xpu.max_memory_reserved())
+        else:
+            torch.cuda.empty_cache()
+            self.end = byte2gb(torch.cuda.memory_allocated())
+            self.peak = byte2gb(torch.cuda.max_memory_allocated())
+            cuda_info = torch.cuda.memory_stats()
+            self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
+            self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
+            self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
+            self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
+            self.used = byte2gb(self.end - self.begin)
+            self.peaked = byte2gb(self.peak - self.begin)
+            self.max_reserved = byte2gb(torch.cuda.max_memory_reserved())
+
+        self.cpu_end = self.cpu_mem_used()
+        self.cpu_used = byte2gb(self.cpu_end - self.cpu_begin)
+        self.cpu_peaked = byte2gb(self.cpu_peak - self.cpu_begin)
+        # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}")