@@ -13,6 +13,7 @@ import torch.cuda.nccl as nccl
 import torch.distributed as dist
 from torch.distributed.fsdp import StateDictType
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+from torch.utils.flop_counter import FlopCounterMode
 from tqdm import tqdm
 from transformers import LlamaTokenizer
 
@@ -21,6 +22,9 @@ from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_
 from llama_recipes.policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
 
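+# Sum all FLOP counts that a FlopCounterMode instance has accumulated under its "Global" entry.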
+def get_total_flops(model):
+    return sum(v for _, v in model.flop_counts["Global"].items())
 
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
@@ -75,26 +79,53 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if train_config.enable_fsdp:
                         batch[key] = batch[key].to(local_rank)
                     else:
-                        batch[key] = batch[key].to('cuda:0')
-                loss = model(**batch).loss
-                loss = loss / gradient_accumulation_steps
-                total_loss += loss.detach().float()
-                if train_config.use_fp16:
-                    # if fp16 is enabled, use gradient scaler to handle gradient update
-                    scaler.scale(loss).backward()
-                    if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-                        scaler.step(optimizer)
-                        scaler.update()
-                        optimizer.zero_grad()
-                        pbar.update(1)
+                        batch[key] = batch[key].to('cuda:0')
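+                # FLOP counting: on step 3, run the training step under FlopCounterMode to measure its FLOPs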
+                flop_check_done = False
+                if train_config.flop_counter and step == 3 and not flop_check_done:
+                    flop_counter = FlopCounterMode(rank=local_rank)
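+                    # the forward and backward pass both run inside the counting context, so both contribute to the total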
+                    with flop_counter:
+                        loss = model(**batch).loss
+                        loss = loss / gradient_accumulation_steps
+                        total_loss += loss.detach().float()
+                        if train_config.use_fp16:
+                            # if fp16 is enabled, use gradient scaler to handle gradient update
+                            scaler.scale(loss).backward()
+                            if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                                scaler.step(optimizer)
+                                scaler.update()
+                                optimizer.zero_grad()
+                                pbar.update(1)
+                        else:
+                            # regular backpropagation when fp16 is not used
+                            loss.backward()
+                            if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                                optimizer.step()
+                                optimizer.zero_grad()
+                                pbar.update(1)
+                    TFlops = get_total_flops(flop_counter) / 1e12
+                    flop_check_done = True
                 else:
-                    # regular backpropagation when fp16 is not used
-                    loss.backward()
-                    if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-                        optimizer.step()
-                        optimizer.zero_grad()
-                        pbar.update(1)
-
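+                    # standard training step without FLOP counting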
+                    loss = model(**batch).loss
+                    loss = loss / gradient_accumulation_steps
+                    total_loss += loss.detach().float()
+                    if train_config.use_fp16:
+                        # if fp16 is enabled, use gradient scaler to handle gradient update
+                        scaler.scale(loss).backward()
+                        if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                            scaler.step(optimizer)
+                            scaler.update()
+                            optimizer.zero_grad()
+                            pbar.update(1)
+                    else:
+                        # regular backpropagation when fp16 is not used
+                        loss.backward()
+                        if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                            optimizer.step()
+                            optimizer.zero_grad()
+                            pbar.update(1)
                 pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
             pbar.close()
 