Browse Source

Save memory and fix typos

Andrew Gu 1 year atrás
parent
commit
71fdc4920a

+ 1 - 1
llama_finetuning.py

@@ -134,7 +134,7 @@ def main(**kwargs):
             mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
             sharding_strategy=fsdp_config.sharding_strategy,
             device_id=torch.cuda.current_device(),
-            limit_all_gathers=False,
+            limit_all_gathers=True,
         )
         if fsdp_config.fsdp_activation_checkpointing:
             policies.apply_fsdp_checkpointing(model)

+ 1 - 1
model_checkpointing/checkpoint_handler.py

@@ -212,7 +212,7 @@ def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
 
 
 def load_optimizer_checkpoint(model, optimizer, rank, cfg):
-    """load an fdsp optimizer full_state checkpoint using scatter method
+    """load an fsdp optimizer full_state checkpoint using scatter method
     this ensures only rank 0 loads the optimizer state dict and scatters to other ranks
     """
 

+ 1 - 1
policies/activation_checkpointing_functions.py

@@ -26,7 +26,7 @@ def apply_fsdp_checkpointing(model):
     """apply activation checkpointing to model
     returns None as model is updated directly
     """
-    print(f"--> applying fdsp activation checkpointing...")
+    print(f"--> applying fsdp activation checkpointing...")
 
     apply_activation_checkpointing(
         model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn

+ 2 - 3
utils/train_utils.py

@@ -85,8 +85,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         batch[key] = batch[key].to(local_rank)
                     else:
                         batch[key] = batch[key].to('cuda')       
-                outputs = model(**batch)
-                loss = outputs.loss
+                loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
                 total_loss += loss.detach().float()
                 first_key = next(iter(batch))
@@ -105,7 +104,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         optimizer.step()
                         optimizer.zero_grad()
                         
-                print(f"\n step {step} is completed and loss is {loss.detach().float()}")        
+                print(f"\n step {step} is completed and loss is {loss.detach().float()}")
         # Reducing total_loss across all devices if there's more than one CUDA device
         if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)