@@ -6,6 +6,7 @@ import time
import yaml
from pathlib import Path
from pkg_resources import packaging
+import contextlib


import torch
@@ -13,7 +14,7 @@ import torch.cuda.nccl as nccl
import torch.distributed as dist
from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
-from torch.utils.flop_counter import FlopCounterMode
+# from torch.utils.flop_counter import FlopCounterMode
from tqdm import tqdm
from transformers import LlamaTokenizer

@@ -21,7 +22,31 @@ from transformers import LlamaTokenizer
from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
from llama_recipes.policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper
from llama_recipes.utils.memory_utils import MemoryTrace
-
+from llama_recipes.utils.tflop_counter import FlopCounterMode
+
+@contextlib.contextmanager
+def maybe_run_profiler(cfg, *args, **kwargs):
+    use_profiler: bool = cfg.profiler
+
+    if use_profiler:
+        with torch.profiler.profile(
+            activities=[
+                torch.profiler.ProfilerActivity.CPU,
+                torch.profiler.ProfilerActivity.CUDA,
+            ],
+            schedule=torch.profiler.schedule(wait=1, warmup=2, active=3, repeat=1),
+            on_trace_ready=torch.profiler.tensorboard_trace_handler(
+                cfg.profile_output_dir
+            ),
+            profile_memory=True,
+            with_stack=False,
+            record_shapes=True,
+        ) as torch_profiler:
+            yield torch_profiler
+    else:
+        torch_profiler = contextlib.nullcontext()
+        yield None
+
def get_total_flops(model):
    return (sum([v for _, v in model.flop_counts["Global"].items()]))

@@ -73,16 +98,39 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
            total_loss = 0.0
            total_length = len(train_dataloader)//gradient_accumulation_steps
            pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
-            for step, batch in enumerate(train_dataloader):
-                for key in batch.keys():
-                    if train_config.enable_fsdp:
-                        batch[key] = batch[key].to(local_rank)
+            with maybe_run_profiler(train_config) as torch_profiler:
+                for step, batch in enumerate(train_dataloader):
+                    for key in batch.keys():
+                        if train_config.enable_fsdp:
+                            batch[key] = batch[key].to(local_rank)
+                        else:
+                            batch[key] = batch[key].to('cuda:0')
+                    flop_check_done = False
+                    if train_config.flop_counter and step == 3 and not flop_check_done:
+                        flop_counter = FlopCounterMode(rank=local_rank)
+                        with flop_counter:
+                            loss = model(**batch).loss
+                            loss = loss / gradient_accumulation_steps
+                            total_loss += loss.detach().float()
+                            if train_config.use_fp16:
+                                # if fp16 is enabled, use gradient scaler to handle gradient update
+                                scaler.scale(loss).backward()
+                                if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                                    scaler.step(optimizer)
+                                    scaler.update()
+                                    optimizer.zero_grad()
+                                    pbar.update(1)
+                            else:
+                                # regular backpropagation when fp16 is not used
+                                loss.backward()
+                        TFlops = get_total_flops(flop_counter) / 1e12
+                        flop_check_done = True
+                        if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                            optimizer.step()
+                            optimizer.zero_grad()
+                            pbar.update(1)
+
                    else:
-                        batch[key] = batch[key].to('cuda:0')
-                flop_check_done = False
-                if train_config.flop_counter and step == 3 and not flop_check_done:
-                    flop_counter = FlopCounterMode(rank=local_rank)
-                    with flop_counter:
                        loss = model(**batch).loss
                        loss = loss / gradient_accumulation_steps
                        total_loss += loss.detach().float()
@@ -101,29 +149,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                                    optimizer.step()
                                    optimizer.zero_grad()
                                    pbar.update(1)
-                    TFlops = get_total_flops(flop_counter) / 1e12
-                    flop_check_done = True
-                else:
-                    loss = model(**batch).loss
-                    loss = loss / gradient_accumulation_steps
-                    total_loss += loss.detach().float()
-                    if train_config.use_fp16:
-                        # if fp16 is enabled, use gradient scaler to handle gradient update
-                        scaler.scale(loss).backward()
-                        if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-                            scaler.step(optimizer)
-                            scaler.update()
-                            optimizer.zero_grad()
-                            pbar.update(1)
-                    else:
-                        # regular backpropagation when fp16 is not used
-                        loss.backward()
-                        if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-                            optimizer.step()
-                            optimizer.zero_grad()
-                            pbar.update(1)
-                pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
-            pbar.close()
+                    pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
+                pbar.close()

        epoch_end_time = time.perf_counter()-epoch_start_time
        epoch_times.append(epoch_end_time)
@@ -229,6 +256,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
        results['avg_eval_loss'] = avg_eval_loss
    results["avg_epoch_time"] = avg_epoch_time
    results["avg_checkpoint_time"] = avg_checkpoint_time
+    if train_config.flop_counter:
+        results["model_flops"]= TFlops
+

    #saving the training params including fsdp setting for reference.
    if train_config.enable_fsdp and not train_config.use_peft:
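
Usage sketch (not part of the patch): the torch.profiler.schedule(wait=1, warmup=2, active=3, repeat=1) configured above only advances when step() is called on the profiler object, so callers typically tick it once per training step. The snippet below is a minimal illustration only; it assumes maybe_run_profiler from this patch is in scope, a CUDA-capable environment (the helper always requests the CUDA activity), and a throwaway SimpleNamespace config and toy model in place of the real training setup. Only the profiler and profile_output_dir attributes read by the helper are set.

from types import SimpleNamespace

import torch

# Hypothetical config object: just the two attributes maybe_run_profiler reads.
cfg = SimpleNamespace(profiler=True, profile_output_dir="./profiler_traces")

# Toy model/optimizer standing in for the real training setup.
model = torch.nn.Linear(512, 512)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

with maybe_run_profiler(cfg) as torch_profiler:
    for step in range(10):
        loss = model(torch.randn(8, 512)).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # Advance the profiler schedule; with wait=1, warmup=2, active=3 a trace is
        # written to cfg.profile_output_dir once the active window completes.
        # When cfg.profiler is False, maybe_run_profiler yields None, hence the guard.
        if torch_profiler is not None:
            torch_profiler.step()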