
adding flop counter

Hamid Shojanazeri, 1 year ago
Commit d56d5c469d
2 files changed, 47 insertions and 23 deletions
  1. src/llama_recipes/configs/training.py  (+1, -4)
  2. src/llama_recipes/utils/train_utils.py  (+46, -19)

src/llama_recipes/configs/training.py  (+1, -4)

@@ -34,7 +34,4 @@ class train_config:
     dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
     save_optimizer: bool=False # will be used if using FSDP
     use_fast_kernels: bool = False # Enable using SDPA from PyTorch Accelerated Transformers, making use of Flash Attention and Xformers memory-efficient kernels
-
-    
-    
-    
+    flop_counter: bool=True # enables a one-off FLOP measurement during training
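
Where the flag is consumed is shown in the train_utils.py hunk below. As a brief, hedged sketch (only the field names come from the diff; the abbreviated dataclass and the programmatic override are illustrative), the new flag sits alongside the other train_config booleans and can be flipped like any of them:

    from dataclasses import dataclass

    @dataclass
    class train_config:                  # abbreviated mirror of configs/training.py
        use_fast_kernels: bool = False
        flop_counter: bool = True        # new in this commit: one-off FLOP measurement

    cfg = train_config(flop_counter=False)   # opt out of the measurement
    print(cfg.flop_counter)                  # -> False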

src/llama_recipes/utils/train_utils.py  (+46, -19)

@@ -13,6 +13,7 @@ import torch.cuda.nccl as nccl
 import torch.distributed as dist
 from torch.distributed.fsdp import StateDictType
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+from torch.utils.flop_counter import FlopCounterMode
 from tqdm import tqdm
 from transformers import LlamaTokenizer
 
@@ -21,6 +22,8 @@ from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_
 from llama_recipes.policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
 
+def get_total_flops(model):  # 'model' here is the FlopCounterMode instance
+    return sum(v for _, v in model.flop_counts["Global"].items())  # "Global" aggregates all counted ops
 
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
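
For context, here is a minimal sketch of what the helper above consumes, assuming a toy nn.Linear on CPU (layer sizes are illustrative, and display=False stands in for the rank argument used in the diff). FlopCounterMode records per-operator counts per module in its flop_counts dict, with an aggregate under the "Global" key, so summing that entry yields the total FLOPs:

    import torch
    import torch.nn as nn
    from torch.utils.flop_counter import FlopCounterMode

    model = nn.Linear(256, 512)
    x = torch.randn(32, 256)

    counter = FlopCounterMode(display=False)   # display=False suppresses the printed per-module table
    with counter:
        model(x).sum().backward()              # backward FLOPs are counted as well

    # counter.flop_counts is a nested dict, e.g. {"Global": {aten.addmm: ..., aten.mm: ...}, ...}
    total = sum(v for _, v in counter.flop_counts["Global"].items())
    print(f"total FLOPs (forward + backward): {total}")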
@@ -75,26 +78,50 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if train_config.enable_fsdp:
                         batch[key] = batch[key].to(local_rank)
                     else:
-                        batch[key] = batch[key].to('cuda:0')              
-                loss = model(**batch).loss
-                loss = loss / gradient_accumulation_steps
-                total_loss += loss.detach().float()
-                if train_config.use_fp16:
-                    # if fp16 is enabled, use gradient scaler to handle gradient update
-                    scaler.scale(loss).backward()
-                    if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-                        scaler.step(optimizer)
-                        scaler.update()
-                        optimizer.zero_grad()
-                        pbar.update(1)
+                        batch[key] = batch[key].to('cuda:0')
+                flop_check_done = False
+                if train_config.flop_counter and step == 3 and not flop_check_done:
+                    flop_counter = FlopCounterMode(rank=local_rank)
+                    with flop_counter:
+                        loss = model(**batch).loss
+                        loss = loss / gradient_accumulation_steps
+                        total_loss += loss.detach().float()
+                        if train_config.use_fp16:
+                            # if fp16 is enabled, use gradient scaler to handle gradient update
+                            scaler.scale(loss).backward()
+                            if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                                scaler.step(optimizer)
+                                scaler.update()
+                                optimizer.zero_grad()
+                                pbar.update(1)
+                        else:
+                            # regular backpropagation when fp16 is not used
+                            loss.backward()
+                            if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                                optimizer.step()
+                                optimizer.zero_grad()
+                                pbar.update(1)
+                    TFlops = get_total_flops(flop_counter) / 1e12
+                    flop_check_done = True
                 else:
-                    # regular backpropagation when fp16 is not used
-                    loss.backward()
-                    if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-                        optimizer.step()
-                        optimizer.zero_grad()
-                        pbar.update(1)
-
+                    loss = model(**batch).loss
+                    loss = loss / gradient_accumulation_steps
+                    total_loss += loss.detach().float()
+                    if train_config.use_fp16:
+                        # if fp16 is enabled, use gradient scaler to handle gradient update
+                        scaler.scale(loss).backward()
+                        if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                            scaler.step(optimizer)
+                            scaler.update()
+                            optimizer.zero_grad()
+                            pbar.update(1)
+                    else:
+                        # regular backpropagation when fp16 is not used
+                        loss.backward()
+                        if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                            optimizer.step()
+                            optimizer.zero_grad()
+                            pbar.update(1)
                 pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
             pbar.close()
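
Finally, a hedged sketch of the gating logic this hunk adds to train(), in isolation: the stand-in model, optimizer, and dataloader are illustrative, while the step == 3 check and the division by 1e12 mirror the diff. One difference: flop_check_done is initialized once before the loop here, which appears to be the intent; the diff resets it on every step, so there the flag is effectively a no-op and the step == 3 check alone limits the measurement to one step per epoch.

    import contextlib
    import torch
    import torch.nn as nn
    from torch.utils.flop_counter import FlopCounterMode

    model = nn.Sequential(nn.Linear(128, 512), nn.ReLU(), nn.Linear(512, 128))
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    dataloader = [torch.randn(16, 128) for _ in range(8)]    # stand-in for train_dataloader

    flop_check_done = False                                   # set once, before the loop
    for step, batch in enumerate(dataloader):
        measure = (step == 3) and not flop_check_done         # profile a single warmed-up step
        ctx = FlopCounterMode(display=False) if measure else contextlib.nullcontext()
        with ctx:
            loss = model(batch).pow(2).mean()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        if measure:
            TFlops = sum(ctx.flop_counts["Global"].values()) / 1e12
            flop_check_done = True
            print(f"step {step}: ~{TFlops:.4f} TFLOPs for one forward/backward pass")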