Hamid Shojanazeri committed 1 year ago
commit 19089269d3
2 changed files with 11 additions and 3 deletions
  1. src/llama_recipes/finetuning.py (+4, -2)
  2. src/llama_recipes/utils/train_utils.py (+7, -1)

src/llama_recipes/finetuning.py (+4, -2)

@@ -3,8 +3,9 @@
 
 import os
 from pkg_resources import packaging
-
+import gc
 import fire
+
 import torch
 import torch.distributed as dist
 import torch.optim as optim
@@ -44,8 +45,9 @@ from llama_recipes.utils.train_utils import (
     get_policies
 )
 
-
+import gc
 def main(**kwargs):
+    gc.disable()
     # Update the configuration for the training and sharding process
     update_config((train_config, fsdp_config), **kwargs)
 

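Taken together, the two files switch the run to manual garbage collection: finetuning.py disables Python's automatic collector before training starts, and train_utils.py (below) triggers a small collection at a fixed point in each loop iteration, so GC pauses land at predictable places instead of mid-step. (Note the commit adds "import gc" to finetuning.py twice; the second import is redundant but harmless.) A minimal self-contained sketch of the pattern, with illustrative names not taken from the repo:

    import gc

    def training_step(step):
        # Stand-in for real per-batch work: build some cyclic garbage
        # that only the cycle collector (not refcounting) can reclaim.
        a, b = [], []
        a.append(b)
        b.append(a)

    def run(steps=6):
        gc.disable()                 # no automatic collections mid-loop
        try:
            for step in range(steps):
                training_step(step)
                gc.collect(1)        # sweep generations 0 and 1 at a known point
        finally:
            gc.enable()              # restore the default collector on exit

    if __name__ == "__main__":
        run()
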
src/llama_recipes/utils/train_utils.py (+7, -1)

@@ -7,7 +7,7 @@ import yaml
 from pathlib import Path
 from pkg_resources import packaging
 import contextlib
-
+import gc
 
 import torch
 import torch.cuda.nccl as nccl
@@ -100,6 +100,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
             with maybe_run_profiler(train_config) as torch_profiler:
                 for step, batch in enumerate(train_dataloader):
+                    if step > 5:
+                        break
+                    gc.collect(1)
                     for key in batch.keys():
                         if train_config.enable_fsdp:
                             batch[key] = batch[key].to(local_rank)
@@ -285,6 +288,9 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     eval_loss = 0.0  # Initialize evaluation loss
     with MemoryTrace() as memtrace:
         for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
+            if step > 5:
+                break
+            gc.collect(1)
             for key in batch.keys():
                 if train_config.enable_fsdp:
                     batch[key] = batch[key].to(local_rank)
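
In both loops, gc.collect(1) collects generation 1 and everything younger, leaving long-lived generation-2 objects alone, which keeps each manual sweep short; the "step > 5" break caps each loop at six batches, presumably to keep profiling runs brief. A quick standalone way to observe the generational behavior (not repo code):

    import gc

    gc.disable()
    junk = [[i] for i in range(10_000)]   # drive up the allocation counters
    print(gc.get_count())                 # (gen0, gen1, gen2) counters before
    gc.collect(1)                         # collects generations 0 and 1 only
    print(gc.get_count())                 # gen0/gen1 reset; generation 2 was not collected
    gc.enable()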