
fix fsdp construction on low_cpu_fsdp

lchu 1 year ago
parent
commit
c19c5c69aa
1 changed file with 18 additions and 16 deletions

llama_finetuning.py  +18 −16

@@ -39,7 +39,7 @@ from utils.train_utils import (
     clear_gpu_cache,
     get_parameter_dtypes,
     print_model_size,
-    get_policies  
+    get_policies
 )
 
 from utils.dataset_utils import get_preprocessed_dataset
@@ -88,10 +88,10 @@ def main(**kwargs):
     if torch.distributed.is_initialized():
         torch.cuda.set_device(rank)
         setup_environ_flags(rank)
-    
+
     # Calculate gradient accumulation steps
     gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size
-     
+
     # Load the pre-trained model and setup its configuration
     if train_config.enable_fsdp and train_config.low_cpu_fsdp:
         # for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
@@ -113,19 +113,20 @@ def main(**kwargs):
             llama_config = LlamaConfig.from_pretrained(train_config.model_name)
             with torch.device("meta"):
                 model = LlamaForCausalLM(llama_config)
+
     else:
         model = LlamaForCausalLM.from_pretrained(
             train_config.model_name,
             load_in_8bit=True if train_config.quantization else None,
             device_map="auto" if train_config.quantization else None,
         )
-    
+
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
-    
+
     # Prepare the model for int8 training if quantization is enabled
     if train_config.quantization:
         model = prepare_model_for_int8_training(model)
-        
+
     # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
     if train_config.enable_fsdp and fsdp_config.pure_bf16:
         model.to(torch.bfloat16)
@@ -134,7 +135,7 @@ def main(**kwargs):
     tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
     tokenizer.add_special_tokens(
             {
-            
+
                 "pad_token": "<PAD>",
             }
         )
@@ -142,11 +143,11 @@ def main(**kwargs):
         peft_config = generate_peft_config(train_config, kwargs)
         model = get_peft_model(model, peft_config)
         model.print_trainable_parameters()
-    
+
     #setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
         if not train_config.use_peft and train_config.freeze_layers:
-            
+
             freeze_transformer_layers(train_config.num_freeze_layers)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
@@ -159,8 +160,9 @@ def main(**kwargs):
             sharding_strategy=fsdp_config.sharding_strategy,
             device_id=torch.cuda.current_device(),
             limit_all_gathers=True,
-            sync_module_states=True,
-            param_init_fn=None if rank == 0 else lambda module: module.to_empty(device=torch.device("cuda"), recurse=False),
+            sync_module_states=True if train_config.low_cpu_fsdp else False,
+            param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
+            if train_config.low_cpu_fsdp and rank != 0 else None,
         )
         if fsdp_config.fsdp_activation_checkpointing:
             policies.apply_fsdp_checkpointing(model)
@@ -168,14 +170,14 @@ def main(**kwargs):
         model.to("cuda")
 
     dataset_config = generate_dataset_config(train_config, kwargs)
-    
+
      # Load and preprocess the dataset for training and validation
     dataset_train = get_preprocessed_dataset(
         tokenizer,
         dataset_config,
         split="train",
     )
-    
+
     if not train_config.enable_fsdp or rank == 0:
         print(f"--> Training Set Length = {len(dataset_train)}")
 
@@ -202,7 +204,7 @@ def main(**kwargs):
                 rank=dist.get_rank(),
                 num_replicas=dist.get_world_size(),
             )
-        
+
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
@@ -224,7 +226,7 @@ def main(**kwargs):
             drop_last=True,
             collate_fn=default_data_collator,
         )
-        
+
     # Initialize the optimizer and learning rate scheduler
     if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
         optimizer = AnyPrecisionAdamW(
@@ -246,7 +248,7 @@ def main(**kwargs):
     results = train(
         model,
         train_dataloader,
-        eval_dataloader, 
+        eval_dataloader,
         tokenizer,
         optimizer,
         scheduler,
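
Below is a minimal, self-contained sketch of the low-CPU FSDP construction pattern this commit fixes, for readers following the diff. It is an illustration, not the repository's code: the helper name build_low_cpu_fsdp_model and the checkpoint id are placeholders made up for the example, a torchrun/NCCL launch is assumed, and the auto-wrap, mixed-precision, and sharding-strategy arguments used in llama_finetuning.py are omitted for brevity.

import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import LlamaConfig, LlamaForCausalLM


def build_low_cpu_fsdp_model(model_name: str, low_cpu_fsdp: bool = True):
    rank = dist.get_rank()

    if low_cpu_fsdp:
        # Only rank 0 materializes real weights on CPU; every other rank builds
        # the module on the meta device, so only one process pays the full
        # host-memory cost of the checkpoint.
        if rank == 0:
            model = LlamaForCausalLM.from_pretrained(model_name)
        else:
            config = LlamaConfig.from_pretrained(model_name)
            with torch.device("meta"):
                model = LlamaForCausalLM(config)
    else:
        # Regular path: every rank loads the full checkpoint.
        model = LlamaForCausalLM.from_pretrained(model_name)

    # The change in this commit: both arguments are gated on low_cpu_fsdp.
    # sync_module_states broadcasts rank 0's real weights during wrapping, and
    # param_init_fn gives the other ranks empty CUDA storage for their meta
    # tensors so that broadcast has somewhere to land.
    model = FSDP(
        model,
        device_id=torch.cuda.current_device(),
        limit_all_gathers=True,
        sync_module_states=low_cpu_fsdp,
        param_init_fn=(
            (lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
            if low_cpu_fsdp and rank != 0
            else None
        ),
    )
    return model


if __name__ == "__main__":
    # Placeholder checkpoint id; the actual script reads train_config.model_name.
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    fsdp_model = build_low_cpu_fsdp_model("meta-llama/Llama-2-7b-hf")

Before this commit, sync_module_states was always True and the to_empty param_init_fn was chosen by rank alone; the diff ties both to train_config.low_cpu_fsdp, so the rank-0 broadcast and meta-tensor materialization are only engaged when the rank-0-only loading path was actually taken.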