
adding profiler and flop_counter

Hamid Shojanazeri, 1 year ago
commit 35b394e49f
3 changed files with 69 additions and 36 deletions
  1. src/llama_recipes/configs/training.py (+3 -1)
  2. src/llama_recipes/utils/__init__.py (+2 -1)
  3. src/llama_recipes/utils/train_utils.py (+64 -34)

+ 3 - 1
src/llama_recipes/configs/training.py

@@ -34,4 +34,6 @@ class train_config:
     dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
     save_optimizer: bool=False # will be used if using FSDP
 use_fast_kernels: bool = False # Enable using SDPA from PyTorch Accelerated Transformers, making use of Flash Attention and Xformers memory-efficient kernels
-    flop_counter: bool=True
+    flop_counter: bool=True  # enable FLOP counter
+    profiler: bool=True  # enable the PyTorch profiler
+    profile_output_dir: str="profile_output"
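
For reference, a minimal sketch of how the three new fields above would be set and read; the dataclass name and module path come from this file, while the instantiation pattern and the TensorBoard note are only illustrative.

from llama_recipes.configs.training import train_config

cfg = train_config()
cfg.profiler = True                        # wrap the training loop in torch.profiler
cfg.flop_counter = True                    # count FLOPs on one early training step
cfg.profile_output_dir = "profile_output"  # trace files are written here

# Traces produced by tensorboard_trace_handler can be inspected with the
# torch-tb-profiler plugin, e.g.:  tensorboard --logdir profile_output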

+ 2 - 1
src/llama_recipes/utils/__init__.py

@@ -4,4 +4,5 @@
 from llama_recipes.utils.memory_utils import MemoryTrace
 from llama_recipes.utils.dataset_utils import *
 from llama_recipes.utils.fsdp_utils import fsdp_auto_wrap_policy
-from llama_recipes.utils.train_utils import *
+from llama_recipes.utils.train_utils import *
+from llama_recipes.utils.tflop_counter import *

+ 64 - 34
src/llama_recipes/utils/train_utils.py

@@ -6,6 +6,7 @@ import time
 import yaml
 from pathlib import Path
 from pkg_resources import packaging
+import contextlib
 
 
 import torch
@@ -13,7 +14,7 @@ import torch.cuda.nccl as nccl
 import torch.distributed as dist
 from torch.distributed.fsdp import StateDictType
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
-from torch.utils.flop_counter import FlopCounterMode
+# from torch.utils.flop_counter import FlopCounterMode
 from tqdm import tqdm
 from transformers import LlamaTokenizer
 
@@ -21,7 +22,31 @@ from transformers import LlamaTokenizer
 from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
 from llama_recipes.policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
-
+from llama_recipes.utils.tflop_counter import FlopCounterMode
+
+@contextlib.contextmanager
+def maybe_run_profiler(cfg, *args, **kwargs):
+    use_profiler: bool = cfg.profiler
+
+    if use_profiler:
+        with torch.profiler.profile(
+            activities=[
+                torch.profiler.ProfilerActivity.CPU,
+                torch.profiler.ProfilerActivity.CUDA,
+            ],
+            schedule=torch.profiler.schedule(wait=1, warmup=2, active=3, repeat=1),
+            on_trace_ready=torch.profiler.tensorboard_trace_handler(
+                cfg.profile_output_dir
+            ),
+            profile_memory=True,
+            with_stack=False,
+            record_shapes=True,
+        ) as torch_profiler:
+            yield torch_profiler
+    else:
+        yield None
+
 def get_total_flops(model):
     return (sum([v for _, v in model.flop_counts["Global"].items()]))
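
Taken together, the helpers above are meant to wrap a training step roughly as in the sketch below. This is a minimal sketch assuming generic cfg, model, optimizer, and train_dataloader objects, with rank hard-coded to 0 and gradient accumulation omitted; the explicit torch_profiler.step() call is an assumption on my part, since torch.profiler's wait/warmup/active schedule only advances when step() is called on the profiler object.

# Hypothetical usage of maybe_run_profiler, FlopCounterMode, and get_total_flops:
with maybe_run_profiler(cfg) as torch_profiler:
    for step, batch in enumerate(train_dataloader):
        if cfg.flop_counter and step == 3:
            # count FLOPs for a single representative step
            flop_counter = FlopCounterMode(rank=0)
            with flop_counter:
                loss = model(**batch).loss
                loss.backward()
            TFlops = get_total_flops(flop_counter) / 1e12  # FLOPs of this step, in teraFLOPs
        else:
            loss = model(**batch).loss
            loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if torch_profiler is not None:
            # the profiler schedule only advances when step() is called
            torch_profiler.step()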
 
@@ -73,16 +98,39 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             total_loss = 0.0
             total_length = len(train_dataloader)//gradient_accumulation_steps
             pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
-            for step, batch in enumerate(train_dataloader):
-                for key in batch.keys():
-                    if train_config.enable_fsdp:
-                        batch[key] = batch[key].to(local_rank)
+            with maybe_run_profiler(train_config) as torch_profiler:
+                for step, batch in enumerate(train_dataloader):
+                    for key in batch.keys():
+                        if train_config.enable_fsdp:
+                            batch[key] = batch[key].to(local_rank)
+                        else:
+                            batch[key] = batch[key].to('cuda:0') 
+                    flop_check_done = False 
+                    if train_config.flop_counter and step == 3 and not flop_check_done:
+                        flop_counter = FlopCounterMode(rank=local_rank)
+                        with flop_counter:           
+                            loss = model(**batch).loss
+                            loss = loss / gradient_accumulation_steps
+                            total_loss += loss.detach().float()
+                            if train_config.use_fp16:
+                                # if fp16 is enabled, use gradient scaler to handle gradient update
+                                scaler.scale(loss).backward()
+                                if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                                    scaler.step(optimizer)
+                                    scaler.update()
+                                    optimizer.zero_grad()
+                                    pbar.update(1)
+                            else:
+                                # regular backpropagation when fp16 is not used
+                                loss.backward()
+                                TFlops = get_total_flops(flop_counter) / 1e12
+                                flop_check_done = True
+                                if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                                    optimizer.step()
+                                    optimizer.zero_grad()
+                                    pbar.update(1)
+                        
                     else:
-                        batch[key] = batch[key].to('cuda:0') 
-                flop_check_done = False 
-                if train_config.flop_counter and  step == 3 and not flop_check_done:
-                    flop_counter = FlopCounterMode(rank=local_rank)
-                    with flop_counter:           
                         loss = model(**batch).loss
                         loss = loss / gradient_accumulation_steps
                         total_loss += loss.detach().float()
@@ -101,29 +149,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                                 optimizer.step()
                                 optimizer.zero_grad()
                                 pbar.update(1)
-                    TFlops = get_total_flops(flop_counter) / 1e12
-                    flop_check_done = True
-                else:
-                    loss = model(**batch).loss
-                    loss = loss / gradient_accumulation_steps
-                    total_loss += loss.detach().float()
-                    if train_config.use_fp16:
-                        # if fp16 is enabled, use gradient scaler to handle gradient update
-                        scaler.scale(loss).backward()
-                        if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-                            scaler.step(optimizer)
-                            scaler.update()
-                            optimizer.zero_grad()
-                            pbar.update(1)
-                    else:
-                        # regular backpropagation when fp16 is not used
-                        loss.backward()
-                        if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-                            optimizer.step()
-                            optimizer.zero_grad()
-                            pbar.update(1)
-                pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
-            pbar.close()
+                    pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
+                pbar.close()
                 
         epoch_end_time = time.perf_counter()-epoch_start_time
         epoch_times.append(epoch_end_time)    
@@ -229,6 +256,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         results['avg_eval_loss'] = avg_eval_loss
     results["avg_epoch_time"] = avg_epoch_time
     results["avg_checkpoint_time"] = avg_checkpoint_time
+    if train_config.flop_counter:
+        results["model_flops"] = TFlops
+
     
     #saving the training params including fsdp setting for reference.
     if train_config.enable_fsdp and not train_config.use_peft:
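
As a sanity check on units, a small worked example with an assumed raw count; the division by 1e12 mirrors the TFlops computation in the training loop above, so results["model_flops"] holds teraFLOPs for the single measured step, not a throughput rate.

total_flops = 2_300_000_000_000_000   # assumed output of get_total_flops() for one step
TFlops = total_flops / 1e12           # 2300.0, stored as results["model_flops"]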