
fixing the full state path in checkpoint handler+loss report calculation (#51)

Geeta Chauhan 1 year ago
Parent
Commit
1387b76e11

File diff suppressed because it is too large
+ 13 - 0
docs/FAQ.md


+ 0 - 2
model_checkpointing/__init__.py

@@ -4,8 +4,6 @@
 from .checkpoint_handler import (
     load_model_checkpoint,
     save_model_checkpoint,
-    save_distributed_model_checkpoint,
-    load_distributed_model_checkpoint,
     load_optimizer_checkpoint,
     save_optimizer_checkpoint,
     save_model_and_optimizer_sharded,
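
For reference, the import surface that remains after this change; a minimal sketch using only the names visible in the hunk above:

```python
# Sketch of the public imports after the LOCAL_STATE_DICT helpers were removed.
from model_checkpointing import (
    load_model_checkpoint,
    save_model_checkpoint,
    load_optimizer_checkpoint,
    save_optimizer_checkpoint,
    save_model_and_optimizer_sharded,
)
```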

+ 31 - 98
model_checkpointing/checkpoint_handler.py

@@ -44,7 +44,7 @@ def get_date_of_run():
 fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
 
 
-def load_model_sharded(model, rank, cfg, verbose=True):
+def load_model_sharded(model, rank, cfg):
     # torch.manual_seed(103)
     folder_name = (
         cfg.dist_checkpoint_root_folder
@@ -83,7 +83,7 @@ def load_model_sharded(model, rank, cfg, verbose=True):
         print(f"Sharded state checkpoint loaded from {load_dir}")
 
 
-def save_model_and_optimizer_sharded(model, rank, cfg,optim=None, verbose=True):
+def save_model_and_optimizer_sharded(model, rank, cfg,optim=None):
     """save model and optimizer via sharded_state_dict to save_dir"""
     
     folder_name = (
@@ -142,7 +142,14 @@ def save_model_checkpoint(
     if rank == 0:
         print(f"--> saving model ...")
         # create save path
-        save_dir = Path.cwd() / cfg.checkpoint_folder
+        folder_name = (
+        cfg.dist_checkpoint_root_folder
+        + "/"
+        + cfg.dist_checkpoint_folder
+        + "-"
+        + cfg.model_name
+        )
+        save_dir = Path.cwd() / folder_name
         save_dir.mkdir(parents=True, exist_ok=True)
         save_name = cfg.model_name + "-" + str(epoch) + ".pt"
         save_full_path = str(save_dir) + "/" + save_name
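
This is the core of the path fix: the full-state save now uses the same dist_checkpoint_root_folder/dist_checkpoint_folder-model_name layout as the sharded save. A minimal sketch of the resulting paths, with purely hypothetical config values:

```python
from pathlib import Path

# Hypothetical config values, for illustration only.
class cfg:
    dist_checkpoint_root_folder = "model_checkpoints"
    dist_checkpoint_folder = "fine-tuned"
    model_name = "llama-2-7b"

epoch = 3
folder_name = (
    cfg.dist_checkpoint_root_folder + "/" + cfg.dist_checkpoint_folder + "-" + cfg.model_name
)
save_dir = Path.cwd() / folder_name            # ./model_checkpoints/fine-tuned-llama-2-7b
save_full_path = str(save_dir) + "/" + cfg.model_name + "-" + str(epoch) + ".pt"
print(save_full_path)                          # .../fine-tuned-llama-2-7b/llama-2-7b-3.pt
```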
@@ -150,12 +157,12 @@ def save_model_checkpoint(
         # save model
         torch.save(cpu_state, save_full_path)
 
-        if cfg.verbose:
-            print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")
+        
+        print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")
       
 
 
-def load_model_checkpoint(model, rank, cfg, verbose=True):
+def load_model_checkpoint(model, rank, cfg):
     """load local checkpoint to rank0 cpu
     must be called * before * passing to FSDP"""
 
@@ -178,8 +185,8 @@ def load_model_checkpoint(model, rank, cfg, verbose=True):
     # integrate into loaded model
     model.load_state_dict(model_checkpoint)
 
-    if cfg.verbose:
-        print(f"model checkpoint loaded to rank0 cpu")
+    
+    print(f"model checkpoint loaded to rank0 cpu")
 
 
 def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
@@ -192,15 +199,22 @@ def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
 
     optim_state = FSDP.full_optim_state_dict(model, optimizer)
 
-    if cfg.verbose:
-        print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")
+    
+    print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")
 
     if rank == 0:
-        save_dir = Path.cwd() / cfg.checkpoint_folder
+        folder_name = (
+        cfg.dist_checkpoint_root_folder
+        + "/"
+        + cfg.dist_checkpoint_folder
+        + "-"
+        + cfg.model_name
+        )
+        save_dir = Path.cwd() / folder_name
         save_dir.mkdir(parents=True, exist_ok=True)
 
         opt_save_name = (
-            cfg.optimizer_name + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
+            "optimizer" + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
         )
         opt_save_full_path = save_dir / opt_save_name
 
@@ -211,109 +225,28 @@ def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
         print(f"--> saved {opt_save_full_path} to disk")
 
 
-def load_optimizer_checkpoint(model, optimizer, rank, cfg):
+def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank):
     """load an fsdp optimizer full_state checkpoint using scatter method
     this ensures only rank 0 loads the optimizer state dict and scatters to other ranks
     """
 
-    opt_file_path = Path.cwd() / cfg.checkpoint_folder / cfg.optimizer_checkpoint_file
 
-    if not opt_file_path.is_file():
+    if not optimizer_checkpoint_path.is_file():
         print(
-            f"warning - optimizer checkpoint not present {opt_file_path}. Returning. "
+            f"warning - optimizer checkpoint not present {optimizer_checkpoint_path}. Returning. "
         )
         return
 
     full_osd = None
 
     if rank == 0:
-        full_osd = torch.load(opt_file_path)
-
-        if cfg.verbose:
-            print(f"loaded full osd on rank 0")
+        full_osd = torch.load(optimizer_checkpoint_path)
 
     # called from all ranks, though only rank0 has a valid param for full_osd
     sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model)
 
-    if cfg.verbose:
-        print(f"optimizer shard loaded on rank {rank}")
-
+    print(f"optimizer shard loaded on rank {rank}")
 
-
-def load_distributed_model_checkpoint(model, rank, cfg):
-    if cfg.checkpoint_type == StateDictType.LOCAL_STATE_DICT:
-        print(f"loading distributed checkpoint, rank {rank}...")
-        folder_name = (
-            cfg.dist_checkpoint_root_folder
-            + "/"
-            + cfg.dist_checkpoint_folder
-            + "-"
-            + cfg.model_name
-        )
-
-        checkdir = Path.cwd() / folder_name
-
-        if not checkdir.exists():
-            if rank == 0:
-                print(f"No checkpoint directory found...skipping")
-            return
-
-
-        reader = FileSystemReader(checkdir)
-
-        with FSDP.state_dict_type(
-            model,
-            StateDictType.LOCAL_STATE_DICT,
-        ):
-            state_dict = model.state_dict()
-            load_state_dict(state_dict, reader)
-            model.load_state_dict(state_dict)
-
-        print(f"--> local state loaded on rank {rank}")
-
-        return
-
-
-def save_distributed_model_checkpoint(model, rank, cfg, epoch=1):
-    # distributed checkpoint saving
-
-    # confirm type of checkpoint and save
-    if cfg.checkpoint_type == StateDictType.LOCAL_STATE_DICT:
-        # create writer to current path
-        folder_name = (
-            cfg.dist_checkpoint_root_folder
-            + "/"
-            + cfg.dist_checkpoint_folder
-            + "-"
-            + cfg.model_name
-        )
-        save_dir = Path.cwd() / folder_name
-
-        writer = FileSystemWriter(
-            save_dir,
-        )
-
-        with FSDP.state_dict_type(
-            model,
-            StateDictType.LOCAL_STATE_DICT,
-        ):
-            state_dict = model.state_dict()
-       
-
-        # write out distributed checkpoint
-        save_state_dict(state_dict, writer)
-
-        return
-
-def load_sharded_model_single_gpu(model, model_path):
-    
-    dcp.load_state_dict(
-                    state_dict=state_dict_to_load_to,
-                    storage_reader=FsspecReader(path),
-                    no_dist=True,
-                )
-    print(f"Sharded state checkpoint loaded from {load_dir}")
-    
 def load_sharded_model_single_gpu(model,model_path):
     
     reader = FileSystemReader(model_path)
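
The kept load_sharded_model_single_gpu is truncated here. Following the pattern of the deleted near-duplicate above, a minimal sketch of how such a loader can finish, assuming a recent torch.distributed.checkpoint:

```python
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint import FileSystemReader

def load_sharded_model_single_gpu_sketch(model, model_path):
    """Sketch: load a sharded (distributed) checkpoint into a plain model on one GPU."""
    state_dict = {"model": model.state_dict()}
    dist_cp.load_state_dict(
        state_dict=state_dict,
        storage_reader=FileSystemReader(model_path),
        no_dist=True,          # no process group needed on a single GPU
    )
    model.load_state_dict(state_dict["model"])
    print(f"Sharded state checkpoint loaded from {model_path}")
    return model
```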

+ 12 - 1
scripts/spellcheck_conf/wordlist.txt

@@ -1078,4 +1078,15 @@ samsum
 vLLM
 TGI
 vLLM
-vLLM's
+vLLM's
+OOM
+RTX
+SKU
+TPUs
+checkpointing
+enviroment
+fragmentations
+intra
+nightlies
+recenly
+uncomment

+ 1 - 0
utils/memory_utils.py

@@ -52,6 +52,7 @@ class MemoryTrace:
         cuda_info = torch.cuda.memory_stats()
         self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
         self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
+        self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
         self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
         self.used = byte2gb(self.end - self.begin)
         self.peaked = byte2gb(self.peak - self.begin)
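
MemoryTrace reads these counters from torch.cuda.memory_stats(); a hedged sketch of pulling the same keys outside the class (the GB conversion is assumed to match the repo's byte2gb helper):

```python
import torch

def report_cuda_memory():
    """Sketch: print the CUDA counters MemoryTrace records, converted to GB."""
    if not torch.cuda.is_available():
        return
    stats = torch.cuda.memory_stats()
    byte2gb = lambda num_bytes: round(num_bytes / 2**30, 2)  # assumed conversion
    print("peak active (GB):", byte2gb(stats["active_bytes.all.peak"]))
    print("alloc retries   :", stats.get("num_alloc_retries", 0))
    print("cuda OOMs       :", stats.get("num_ooms", 0))
```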

+ 78 - 39
utils/train_utils.py

@@ -67,7 +67,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         scaler = ShardedGradScaler()
     elif train_config.use_fp16 and not train_config.enable_fsdp:
         scaler = torch.cuda.amp.GradScaler() 
-        
+    if train_config.enable_fsdp:
+        world_size = int(os.environ["WORLD_SIZE"]) 
     train_prep = []
     train_loss = []
     val_prep = []
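
WORLD_SIZE is only present when a launcher such as torchrun exports it, so this read assumes an FSDP run started that way. A minimal sketch of the environment such a launch provides:

```python
import os

# Under `torchrun --nnodes 1 --nproc_per_node 4 ...` the launcher exports these variables.
rank = int(os.environ.get("RANK", 0))              # global rank of this process
local_rank = int(os.environ.get("LOCAL_RANK", 0))  # rank within the node
world_size = int(os.environ.get("WORLD_SIZE", 1))  # total number of processes
print(f"rank {rank}/{world_size}, local rank {local_rank}")
```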
@@ -78,7 +79,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         with MemoryTrace() as memtrace:  # track the memory usage
             model.train()
             total_loss = 0.0
-            data_set_len = 0
             for step, batch in enumerate(tqdm(train_dataloader,colour="blue", desc=f"Training Epoch{epoch}")):
                 for key in batch.keys():
                     if train_config.enable_fsdp:
@@ -88,8 +88,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
                 total_loss += loss.detach().float()
-                first_key = next(iter(batch))
-                data_set_len += len(batch[first_key])
                 if train_config.use_fp16:
                     # if fp16 is enabled, use gradient scaler to handle gradient update
                     scaler.scale(loss).backward()
@@ -103,22 +101,35 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                         optimizer.step()
                         optimizer.zero_grad()
-                        
-                print(f"\n step {step} is completed and loss is {loss.detach().float()}")
+                if train_config.enable_fsdp:
+                    if rank==0:       
+                        print(f"\n step {step} is completed and loss is {loss.detach().float()}")
+                else:
+                    print(f"\n step {step} is completed and loss is {loss.detach().float()}")
+                    
         # Reducing total_loss across all devices if there's more than one CUDA device
         if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
-        train_epoch_loss = total_loss / data_set_len
+        train_epoch_loss = total_loss / len(train_dataloader)
+        if train_config.enable_fsdp:
+            train_epoch_loss = train_epoch_loss/world_size
         train_perplexity = torch.exp(train_epoch_loss)
         
         train_prep.append(train_perplexity)
         train_loss.append(train_epoch_loss)
-        
-        print(f"Max CUDA memory allocated was {memtrace.peak} GB")
-        print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
-        print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
-        print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
-        print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
+        if train_config.enable_fsdp:
+            if rank==0:
+                print(f"Max CUDA memory allocated was {memtrace.peak} GB")
+                print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
+                print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
+                print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
+                print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
+        else:
+            print(f"Max CUDA memory allocated was {memtrace.peak} GB")
+            print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
+            print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
+            print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
+            print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
         
         # Update the learning rate as needed
         lr_scheduler.step()
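
The reported loss is now a mean over dataloader steps, summed across ranks via all_reduce and divided by world_size under FSDP. A worked sketch of that arithmetic with toy numbers (illustrative only):

```python
import torch

# Toy numbers: 2 ranks, 3 steps each.
world_size = 2
per_rank_step_losses = [
    torch.tensor([1.2, 1.0, 0.8]),   # rank 0
    torch.tensor([1.4, 1.1, 0.9]),   # rank 1
]
# After all_reduce(SUM), every rank holds the sum of total_loss over all ranks.
total_loss = sum(losses.sum() for losses in per_rank_step_losses)   # 6.4
train_epoch_loss = total_loss / len(per_rank_step_losses[0])        # / num steps  -> ~2.133
train_epoch_loss = train_epoch_loss / world_size                    # / world size -> ~1.067
train_perplexity = torch.exp(train_epoch_loss)                      # exp(1.067) ~ 2.91
print(float(train_epoch_loss), float(train_perplexity))
```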
@@ -126,42 +137,62 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         if train_config.run_validation:
             eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, rank, tokenizer)   
             if train_config.save_model and eval_epoch_loss < best_val_loss:
-                
-                if  train_config.use_peft:
-                    
-                    print(f"we are in the saving the PEFT modules")
-                    model.save_pretrained(train_config.output_dir)   
-                    print(f"PEFT modules are saved in {train_config.output_dir} directory")
-                    
+                if train_config.enable_fsdp:
+                    dist.barrier()
+                if train_config.use_peft:
+                    if train_config.enable_fsdp:
+                        if rank==0:
+                            print(f"we are about to save the PEFT modules")
+                    else:
+                        print(f"we are about to save the PEFT modules")
+                    model.save_pretrained(train_config.output_dir)  
+                    if train_config.enable_fsdp:
+                        if rank==0: 
+                            print(f"PEFT modules are saved in {train_config.output_dir} directory")
+                    else:
+                        print(f"PEFT modules are saved in {train_config.output_dir} directory")
+                        
                 else:
                     if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
                         
                         model_checkpointing.save_model_checkpoint(
-                            model, optimizer, rank, train_config, epoch=1
+                            model, optimizer, rank, train_config, epoch=epoch
                         )
                     elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
-                        print(" we are about to save the models *******")
+                        print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
+                        print("=====================================================")
                         
                         model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config)
                         if train_config.save_optimizer:
                             model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
+                            print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
+                            print("=====================================================")
 
                     if not train_config.use_peft and  train_config.save_optimizer:
                         model_checkpointing.save_optimizer_checkpoint(
-                            model, optimizer, rank, train_config, epoch=1
-                        )   
-                                
+                            model, optimizer, rank, train_config, epoch=epoch
+                        )
+                        print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
+                        print("=====================================================")                     
+                if train_config.enable_fsdp:
+                    dist.barrier()
             
-            if local_rank == 0 and eval_epoch_loss < best_val_loss:
+            if eval_epoch_loss < best_val_loss:
                 best_val_loss = eval_epoch_loss
-                print(f"best eval loss on epoch {epoch} is {best_val_loss}")
+                if train_config.enable_fsdp:
+                    if rank==0:
+                        print(f"best eval loss on epoch {epoch} is {best_val_loss}")
+                else:
+                    print(f"best eval loss on epoch {epoch} is {best_val_loss}")
             val_loss.append(best_val_loss)
             val_prep.append(eval_ppl)
         
-        
-        print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
-        lr_scheduler.step()
-
+        if train_config.enable_fsdp:
+            if rank==0:
+                print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
+        else:
+            print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
+            
     avg_train_prep = sum(train_prep)/len(train_prep)
     avg_train_loss = sum(train_loss)/len(train_loss)
     if train_config.run_validation:
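
Checkpoint saving is now wrapped in dist.barrier() calls and the log lines are gated on rank 0. A minimal sketch of that pattern in isolation (assumes the process group is already initialized; save_fn is a placeholder):

```python
import torch.distributed as dist

def save_with_barrier(save_fn, rank, enable_fsdp=True):
    """Sketch: sync ranks, run a collective save, sync again; log only on rank 0."""
    if enable_fsdp:
        dist.barrier()        # make sure every rank reached the save point
    save_fn()                 # every rank participates (FSDP state-dict saves are collective)
    if enable_fsdp:
        dist.barrier()        # keep fast ranks from running ahead of the save
    if (not enable_fsdp) or rank == 0:
        print("checkpoint saved")
```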
@@ -175,7 +206,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         results['avg_eval_loss'] = avg_eval_loss
         
     #saving the training params including fsdp setting for reference.
-    if train_config.enable_fsdp and fsdp_config:
+    if train_config.enable_fsdp and not train_config.use_peft:
         save_train_params(train_config, fsdp_config, rank)
         
     return results
@@ -192,10 +223,11 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     
     Returns: eval_ppl, eval_epoch_loss
     """
+    if train_config.enable_fsdp:
+        world_size = int(os.environ["WORLD_SIZE"]) 
     model.eval()
     eval_preds = []
     eval_loss = 0.0  # Initialize evaluation loss
-    eval_dataset_len = 0
     with MemoryTrace() as memtrace:
         for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch")):
             for key in batch.keys():
@@ -209,9 +241,6 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
                 outputs = model(**batch)
                 loss = outputs.loss
                 eval_loss += loss.detach().float()
-                first_key = next(iter(batch))
-                eval_dataset_len+= len(batch[first_key])
-                
             # Decode predictions and add to evaluation predictions list
             preds = torch.argmax(outputs.logits, -1)
             eval_preds.extend(
@@ -223,11 +252,18 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
     
     # Compute average loss and perplexity
-    eval_epoch_loss = eval_loss / eval_dataset_len
+    eval_epoch_loss = eval_loss / len(eval_dataloader)
+    if train_config.enable_fsdp:
+        eval_epoch_loss = eval_epoch_loss/world_size
     eval_ppl = torch.exp(eval_epoch_loss)
     
     # Print evaluation metrics
-    print(f" {eval_ppl=} {eval_epoch_loss=}")
+    if train_config.enable_fsdp:
+        if local_rank==0:
+            print(f" {eval_ppl=} {eval_epoch_loss=}")
+    else:
+        print(f" {eval_ppl=} {eval_epoch_loss=}")
+        
     return eval_ppl, eval_epoch_loss
 
 def freeze_transformer_layers(model, num_layer):
@@ -252,7 +288,10 @@ def setup_environ_flags(rank):
     """Set environment flags for debugging purposes"""
     os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
     os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
-    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
+    # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
+    # This flag will help with CUDA memory fragmentations that can lead into OOM in some cases.
+    # Note this is only available in PyTorch Nightlies (as of July 30 2023)
+    # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True' 
     if rank == 0:
         print(f"--> Running with torch dist debug set to detail")
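
The commented-out allocator flag targets CUDA memory fragmentation; if used, it has to be set before the first CUDA allocation. A hedged sketch, valid only on a PyTorch build that supports expandable_segments (per the nightlies note above):

```python
import os

# Must be exported before the CUDA caching allocator is initialized, i.e. before the
# first CUDA allocation in this process. Illustrative only; verify your PyTorch build
# supports expandable_segments before relying on it.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import torch  # noqa: E402  (imported after the env var on purpose)
print(os.environ["PYTORCH_CUDA_ALLOC_CONF"])
```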