
Merge branch 'main' into ipex_feature

Abhilash Majumder 1 year ago
parent
commit
d5f39914e8

+ 10 - 0
README.md

@@ -125,6 +125,16 @@ torchrun --nnodes 1 --nproc_per_node 8  llama_finetuning.py --enable_fsdp --mode
 
 ```
 
+### Fine-tuning using FSDP on 70B Model
+
+If you are interested in running full-parameter fine-tuning on the 70B model, you can enable `low_cpu_fsdp` mode as in the following command. This option loads the model on rank 0 only before moving it to the devices to construct FSDP, which dramatically reduces CPU memory when loading large models: on an 8-GPU node each rank would otherwise load its own full copy (70B parameters x 4 bytes x 8 ranks, i.e. 2+ TB of CPU memory), whereas with this option only rank 0 holds a copy, about 280 GB for the 70B model. This has been tested with `BF16` on 16 A100 GPUs with 80GB memory.
+
+```bash
+
+torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --low_cpu_fsdp --pure_bf16 --model_name /path_of_model_folder/70B --batch_size_training 1 --micro_batch_size 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
+
+```
+
 ### Multi GPU Multi Node:
 
 ```bash

+ 2 - 1
configs/training.py

@@ -7,7 +7,8 @@ from typing import ClassVar
 @dataclass
 class train_config:
     model_name: str="PATH/to/LLAMA/7B"
-    enable_fsdp: bool= False 
+    enable_fsdp: bool=False
+    low_cpu_fsdp: bool=False
     run_validation: bool=True
     batch_size_training: int=4
     num_epochs: int=3

+ 15 - 0
docs/inference.md

@@ -27,6 +27,21 @@ inference/samsum_prompt.txt
 ...
 ```
 
+**Note**
+Currently, the pad token in the [HuggingFace Llama tokenizer is `None`](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py#L110) by default. We add the padding token as a special token to the tokenizer, which in this case requires resizing the token embeddings as shown below:
+
+```python
+tokenizer.add_special_tokens(
+    {
+        "pad_token": "<PAD>",
+    }
+)
+model.resize_token_embeddings(model.config.vocab_size + 1)
+```
+Padding would be required for batch inference. In this [example](../inference/inference.py), the batch size is 1, so padding is not strictly required; however, we added the code pointer as an example for the batch-inference case.
+
+
 **Chat completion**
 The inference folder also includes a chat completion example, that adds built-in safety features in fine-tuned models to the prompt tokens. To run the example:
 
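To make the note above concrete, here is a minimal, self-contained sketch of the padding workflow it describes (the checkpoint path and prompts are placeholders, not part of this commit):

```python
# Sketch of the pad-token + batch-padding pattern described in docs/inference.md.
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

model_path = "PATH/to/LLAMA/7B"  # placeholder path
tokenizer = LlamaTokenizer.from_pretrained(model_path)
model = LlamaForCausalLM.from_pretrained(model_path)

# Llama ships without a pad token: register one and grow the embedding table by one row.
tokenizer.add_special_tokens({"pad_token": "<PAD>"})
model.resize_token_embeddings(model.config.vocab_size + 1)

# Padding only matters once prompts of different lengths are batched together.
prompts = ["Summarize the following text: ...", "Translate to German: ..."]
batch = tokenizer(prompts, padding="max_length", truncation=True, max_length=64, return_tensors="pt")

with torch.no_grad():
    outputs = model.generate(**batch, max_new_tokens=32, pad_token_id=tokenizer.pad_token_id)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```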

+ 10 - 0
docs/multi_gpu.md

@@ -62,6 +62,16 @@ torchrun --nnodes 1 --nproc_per_node 8  llama_finetuning.py --enable_fsdp --mode
 
 ```
 
+### Fine-tuning using FSDP on 70B Model
+
+If you are interested in running full-parameter fine-tuning on the 70B model, you can enable `low_cpu_fsdp` mode as in the following command. This option loads the model on rank 0 only before moving it to the devices to construct FSDP, which dramatically reduces CPU memory when loading large models: on an 8-GPU node each rank would otherwise load its own full copy (70B parameters x 4 bytes x 8 ranks, i.e. 2+ TB of CPU memory), whereas with this option only rank 0 holds a copy, about 280 GB for the 70B model. This has been tested with `BF16` on 16 A100 GPUs with 80GB memory.
+
+```bash
+
+torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --low_cpu_fsdp --pure_bf16 --model_name /path_of_model_folder/70B --batch_size_training 1 --micro_batch_size 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
+
+```
+
 **Multi GPU multi node**:
 
 Here we use a slurm script to schedule a job with slurm over multiple nodes.

+ 10 - 4
inference/inference.py

@@ -32,7 +32,8 @@ def main(
     length_penalty: int=1, #[optional] Exponential penalty to the length that is used with beam-based generation.
     enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
     enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
-    enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5
+    enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5
+    max_padding_length: int=None, # the max padding length to be used with the tokenizer when padding the prompts
     use_fast_kernels: bool = False, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
     **kwargs
 ):
@@ -80,10 +81,11 @@ def main(
             "pad_token": "<PAD>",
         }
     )
+    model.resize_token_embeddings(model.config.vocab_size + 1)
     
     safety_checker = get_safety_checker(enable_azure_content_safety,
                                         enable_sensitive_topics,
-                                        enable_saleforce_content_safety,
+                                        enable_salesforce_content_safety,
                                         )
 
     # Safety check of the user prompt
@@ -98,10 +100,14 @@ def main(
             if not is_safe:
                 print(method)
                 print(report)
-        print("Skipping the inferece as the prompt is not safe.")
+        print("Skipping the inference as the prompt is not safe.")
         sys.exit(1)  # Exit the program with an error status
+        
+    if peft_model:
+        model = load_peft_model(model, peft_model)
 
-    batch = tokenizer(user_prompt, return_tensors="pt")
+    model.eval()
+    batch = tokenizer(user_prompt, padding='max_length', truncation=True, max_length=max_padding_length, return_tensors="pt")
     if is_xpu_available():
         batch = {k: v.to("xpu") for k, v in batch.items()}
     else:

+ 2 - 2
inference/safety_utils.py

@@ -154,14 +154,14 @@ class AzureSaftyChecker(object):
 # Function to determine which safety checker to use based on the options selected
 def get_safety_checker(enable_azure_content_safety,
                        enable_sensitive_topics,
-                       enable_saleforce_content_safety,
+                       enable_salesforce_content_safety,
                        ):
     safety_checker = []
     if enable_azure_content_safety:
         safety_checker.append(AzureSaftyChecker())
     if enable_sensitive_topics:
         safety_checker.append(AuditNLGSensitiveTopics())
-    if enable_saleforce_content_safety:
+    if enable_salesforce_content_safety:
         safety_checker.append(SalesforceSafetyChecker())
     return safety_checker
 

+ 71 - 65
llama_finetuning.py

@@ -2,71 +2,49 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import os
-import sys
-from typing import List, Union
 
 import fire
 import torch
-import transformers
-from datasets import load_dataset
-import os.path as osp
-from tqdm import tqdm
-
-# Unused imports removed
-from utils import fsdp_auto_wrap_policy
+import torch.distributed as dist
+import torch.optim as optim
+from peft import get_peft_model, prepare_model_for_int8_training
+from pkg_resources import packaging
+from torch.distributed.fsdp import (
+    FullyShardedDataParallel as FSDP,
+)
+from torch.optim.lr_scheduler import StepLR
+from torch.utils.data import DistributedSampler
 from transformers import (
     LlamaForCausalLM,
     LlamaTokenizer,
-    AutoModelForCausalLM,
-    AutoModelForSeq2SeqLM,
-    AutoTokenizer,
+    LlamaConfig,
     default_data_collator,
-    BitsAndBytesConfig
-)
-import torch.distributed as dist
-# Unused imports removed
-from utils.train_utils import (
-    set_tokenizer_params,
-    train,
-    evaluation,
-    freeze_transformer_layers,
-    check_frozen_layers_peft_model,
-    setup,
-    setup_environ_flags,
-    cleanup,
-    clear_gpu_cache,
-    get_parameter_dtypes,
-    print_model_size,
-    get_policies  
 )
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
-from utils.dataset_utils import get_preprocessed_dataset
+import policies
+from configs import fsdp_config, train_config
+from policies import AnyPrecisionAdamW
 
+from utils import fsdp_auto_wrap_policy
 from utils.config_utils import (
     update_config,
     generate_peft_config,
     generate_dataset_config,
 )
-from peft import get_peft_model, TaskType, prepare_model_for_int8_training
-import configs
-from torch.distributed.fsdp import (
-    FullyShardedDataParallel as FSDP,
-    MixedPrecision,
+from utils.dataset_utils import get_preprocessed_dataset
+
+from utils.train_utils import (
+    train,
+    freeze_transformer_layers,
+    setup,
+    setup_environ_flags,
+    clear_gpu_cache,
+    print_model_size,
+    get_policies
 )
-from torch.utils.data import DistributedSampler
-import policies
-from policies import AnyPrecisionAdamW
-from configs import fsdp_config, train_config
-import torch.optim as optim
-from torch.optim.lr_scheduler import StepLR
-from pkg_resources import packaging
-import torch
-import torch.cuda.nccl as nccl
-import torch.distributed as dist
-from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 from accelerate.utils import is_xpu_available
 
-
 def main(**kwargs):
     # Update the configuration for the training and sharding process
     update_config((train_config, fsdp_config), **kwargs)
@@ -90,17 +68,42 @@ def main(**kwargs):
             torch.xpu.set_device(rank)
         else:
             torch.cuda.set_device(rank)
+        clear_gpu_cache(rank)
         setup_environ_flags(rank)
-    
+
     # Calculate gradient accumulation steps
     gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size
-     
+
     # Load the pre-trained model and setup its configuration
-    model = LlamaForCausalLM.from_pretrained(
-        train_config.model_name,
-        load_in_8bit=True if train_config.quantization else None,
-        device_map="auto" if train_config.quantization else None,
-    )
+    if train_config.enable_fsdp and train_config.low_cpu_fsdp:
+        """
+        For FSDP, we can save CPU memory by loading the pretrained model on rank 0 only.
+        This avoids CPU OOM when loading large models like Llama 70B, where the model alone
+        would consume 2+ TB of CPU memory (70B params * 4 bytes * 8 ranks). It adds some
+        communication overhead and currently requires the latest PyTorch nightly.
+        """
+        v = packaging.version.parse(torch.__version__)
+        verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
+        if not verify_latest_nightly:
+            raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
+                            "please install latest nightly.")
+        if rank == 0:
+            model = LlamaForCausalLM.from_pretrained(
+                train_config.model_name,
+                load_in_8bit=True if train_config.quantization else None,
+                device_map="auto" if train_config.quantization else None,
+            )
+        else:
+            llama_config = LlamaConfig.from_pretrained(train_config.model_name)
+            with torch.device("meta"):
+                model = LlamaForCausalLM(llama_config)
+
+    else:
+        model = LlamaForCausalLM.from_pretrained(
+            train_config.model_name,
+            load_in_8bit=True if train_config.quantization else None,
+            device_map="auto" if train_config.quantization else None,
+        )
     if train_config.enable_fsdp and train_config.use_fast_kernels:
         """
         For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
@@ -113,11 +116,11 @@ def main(**kwargs):
         except ImportError:
             print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
-    
+
     # Prepare the model for int8 training if quantization is enabled
     if train_config.quantization:
         model = prepare_model_for_int8_training(model)
-        
+
     # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
     if train_config.enable_fsdp and fsdp_config.pure_bf16:
         model.to(torch.bfloat16)
@@ -126,7 +129,7 @@ def main(**kwargs):
     tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
     tokenizer.add_special_tokens(
             {
-            
+
                 "pad_token": "<PAD>",
             }
         )
@@ -134,16 +137,16 @@ def main(**kwargs):
         peft_config = generate_peft_config(train_config, kwargs)
         model = get_peft_model(model, peft_config)
         model.print_trainable_parameters()
-    
+
     #setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
         if not train_config.use_peft and train_config.freeze_layers:
-            
+
             freeze_transformer_layers(train_config.num_freeze_layers)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
-   
+
         model = FSDP(
             model,
             auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
@@ -151,6 +154,9 @@ def main(**kwargs):
             sharding_strategy=fsdp_config.sharding_strategy,
             device_id=torch.xpu.current_device() if is_xpu_available() else torch.cuda.current_device(),
             limit_all_gathers=True,
+            sync_module_states=train_config.low_cpu_fsdp,
+            param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
+            if train_config.low_cpu_fsdp and rank != 0 else None,
         )
         if fsdp_config.fsdp_activation_checkpointing:
             policies.apply_fsdp_checkpointing(model)
@@ -161,14 +167,14 @@ def main(**kwargs):
             model.to("cuda")
 
     dataset_config = generate_dataset_config(train_config, kwargs)
-    
+
      # Load and preprocess the dataset for training and validation
     dataset_train = get_preprocessed_dataset(
         tokenizer,
         dataset_config,
         split="train",
     )
-    
+
     if not train_config.enable_fsdp or rank == 0:
         print(f"--> Training Set Length = {len(dataset_train)}")
 
@@ -195,7 +201,7 @@ def main(**kwargs):
                 rank=dist.get_rank(),
                 num_replicas=dist.get_world_size(),
             )
-        
+
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
@@ -217,7 +223,7 @@ def main(**kwargs):
             drop_last=True,
             collate_fn=default_data_collator,
         )
-        
+
     # Initialize the optimizer and learning rate scheduler
     if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
         optimizer = AnyPrecisionAdamW(
@@ -239,7 +245,7 @@ def main(**kwargs):
     results = train(
         model,
         train_dataloader,
-        eval_dataloader, 
+        eval_dataloader,
         tokenizer,
         optimizer,
         scheduler,
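The `low_cpu_fsdp` changes above follow a general FSDP pattern: materialize real weights on rank 0 only, build a weightless shell on the meta device on every other rank, then let FSDP broadcast the shards via `sync_module_states`. A minimal sketch of that pattern in isolation (the path is a placeholder; it assumes a recent PyTorch nightly, as the added docstring notes):

```python
# Sketch of the rank0-only loading pattern used by low_cpu_fsdp (placeholder path, assumed setup).
import functools
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import LlamaConfig, LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

model_name = "PATH/to/LLAMA/70B"  # placeholder
if rank == 0:
    # Only rank 0 holds real weights in CPU RAM.
    model = LlamaForCausalLM.from_pretrained(model_name)
else:
    # Every other rank builds an empty shell on the meta device (no parameter memory).
    with torch.device("meta"):
        model = LlamaForCausalLM(LlamaConfig.from_pretrained(model_name))

wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={LlamaDecoderLayer})
model = FSDP(
    model,
    auto_wrap_policy=wrap_policy,
    device_id=torch.cuda.current_device(),
    sync_module_states=True,  # broadcast rank 0's weights into the empty replicas
    param_init_fn=(lambda m: m.to_empty(device=torch.device("cuda"), recurse=False)) if rank != 0 else None,
    limit_all_gathers=True,
)
```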

+ 2 - 2
model_checkpointing/checkpoint_handler.py

@@ -65,7 +65,7 @@ def load_model_sharded(model, rank, cfg):
     reader = FileSystemReader(load_dir)
 
     with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
-        checkpoint = model.state_dict()
+        checkpoint = {"model": model.state_dict()}
         if rank == 0:
             ck = checkpoint.keys()
             print(f" checkpoint key len = {len(ck)} and \n keys =  {ck}")
@@ -78,7 +78,7 @@ def load_model_sharded(model, rank, cfg):
             print(f"checkpoint after load_state_dict()")
             ck = checkpoint.keys()
             print(f" checkpoint key len = {len(ck)} and \n keys =  {ck}")
-        model.load_state_dict(checkpoint)
+        model.load_state_dict(checkpoint["model"])
     if rank == 0:
         print(f"Sharded state checkpoint loaded from {load_dir}")
 
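For context on the two-line change above: the sharded save path in this file (not shown in this hunk) presumably writes the state dict nested under a "model" key, so the load path needs to build its template dict with the same nesting and then unwrap it before calling `model.load_state_dict`. A hedged sketch of that save/load symmetry, assuming the `torch.distributed.checkpoint` API available in PyTorch 2.x:

```python
# Hedged sketch: save and load must agree on the {"model": ...} nesting of the sharded state dict.
import torch.distributed.checkpoint as dist_cp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

def save_sharded(model, save_dir):
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        state_dict = {"model": model.state_dict()}  # nested under "model" at save time
        dist_cp.save_state_dict(state_dict=state_dict,
                                storage_writer=dist_cp.FileSystemWriter(save_dir))

def load_sharded(model, load_dir):
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        checkpoint = {"model": model.state_dict()}  # template with the same nesting
        dist_cp.load_state_dict(state_dict=checkpoint,
                                storage_reader=dist_cp.FileSystemReader(load_dir))
        model.load_state_dict(checkpoint["model"])  # unwrap before loading into the module
```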

+ 1 - 1
utils/train_utils.py

@@ -159,7 +159,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         lr_scheduler.step()
 
         if train_config.run_validation:
-            eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, rank, tokenizer)   
+            eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
             checkpoint_start_time = time.perf_counter()
             if train_config.save_model and eval_epoch_loss < best_val_loss:
                 if train_config.enable_fsdp: