Browse Source

Revert "Flop counter, profiling and GC (#357)"

This reverts commit 4530d543f8f543d12ee23229e17921dd52fd21e2, reversing
changes made to 98b122e57a8df44f5b88fa9fdab8818cc6e4969f.
Hamid Shojanazeri 1 year ago
parent
commit
162be4c045

+ 0 - 3
src/llama_recipes/configs/training.py

@@ -38,7 +38,4 @@ class train_config:
     dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
     dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
     save_optimizer: bool=False # will be used if using FSDP
     save_optimizer: bool=False # will be used if using FSDP
     use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
     use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
-    flop_counter: bool=True #enable flop counter
-    profiler: bool=True #enable pytorch profiler
-    profile_output_dir: str="profile_output"
     save_metrics: bool = False # saves training metrics to a json file for later plotting
     save_metrics: bool = False # saves training metrics to a json file for later plotting

+ 6 - 5
src/llama_recipes/finetuning.py

@@ -3,10 +3,9 @@
 
 
 import os
 import os
 from pkg_resources import packaging
 from pkg_resources import packaging
-import gc
+
 import fire
 import fire
 import random
 import random
-
 import torch
 import torch
 import torch.optim as optim
 import torch.optim as optim
 from peft import get_peft_model, prepare_model_for_int8_training
 from peft import get_peft_model, prepare_model_for_int8_training
@@ -45,12 +44,9 @@ from llama_recipes.utils.train_utils import (
     print_model_size,
     print_model_size,
     get_policies
     get_policies
 )
 )
-
 from accelerate.utils import is_xpu_available
 from accelerate.utils import is_xpu_available
 
 
 def main(**kwargs):
 def main(**kwargs):
-    gc.disable()
-    gc.collect(1)
     # Update the configuration for the training and sharding process
     # Update the configuration for the training and sharding process
     train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
     train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
     update_config((train_config, fsdp_config), **kwargs)
     update_config((train_config, fsdp_config), **kwargs)
@@ -87,6 +83,11 @@ def main(**kwargs):
         model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
         model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
         overhead and currently requires latest nightly.
         overhead and currently requires latest nightly.
         """
         """
+        v = packaging.version.parse(torch.__version__)
+        verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
+        if not verify_latest_nightly:
+            raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
+                            "please install latest nightly.")
         if rank == 0:
         if rank == 0:
             model = LlamaForCausalLM.from_pretrained(
             model = LlamaForCausalLM.from_pretrained(
                 train_config.model_name,
                 train_config.model_name,

+ 1 - 2
src/llama_recipes/utils/__init__.py

@@ -4,5 +4,4 @@
 from llama_recipes.utils.memory_utils import MemoryTrace
 from llama_recipes.utils.memory_utils import MemoryTrace
 from llama_recipes.utils.dataset_utils import *
 from llama_recipes.utils.dataset_utils import *
 from llama_recipes.utils.fsdp_utils import fsdp_auto_wrap_policy
 from llama_recipes.utils.fsdp_utils import fsdp_auto_wrap_policy
-from llama_recipes.utils.train_utils import *
-from llama_recipes.utils.tflop_counter import *
+from llama_recipes.utils.train_utils import *

+ 0 - 464
src/llama_recipes/utils/tflop_counter.py

@@ -1,464 +0,0 @@
-# Temp copy of Horace Flops Counter.
-# This supports distributed to avoid printing * every GPU.
-# Remove after main file is updated.
-
-import torch
-from torch.utils._pytree import tree_map
-from typing import List, Any, Dict, Optional, Union
-from collections import defaultdict
-from torch.utils._python_dispatch import TorchDispatchMode
-from math import prod
-
-__all__ = ["FlopCounterMode"]
-
-aten = torch.ops.aten
-
-
-def get_shape(i):
-    if isinstance(i, torch.Tensor):
-        return i.shape
-    return i
-
-
-def mm_flop(a_shape, b_shape, *args, out_shape=None, **kwargs) -> int:
-    """
-    Count flops for matmul.
-    """
-    # Inputs should be a list of length 2.
-    # Inputs contains the shapes of two matrices.
-    m, k = a_shape
-    k2, n = b_shape
-    assert k == k2
-    # NB(chilli): Should be 2 * k - 1 technically for FLOPs.
-    return m * n * 2 * k
-
-
-def addmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
-    """
-    Count flops for addmm
-    """
-    return mm_flop(a_shape, b_shape)
-
-
-def bmm_flop(a_shape, b_shape, out_shape=None, **kwargs) -> int:
-    """
-    Count flops for the bmm operation.
-    """
-    # Inputs should be a list of length 2.
-    # Inputs contains the shapes of two tensor.
-    b, m, k = a_shape
-    b2, k2, n = b_shape
-    assert b == b2
-    assert k == k2
-    # NB(chilli): Should be 2 * k - 1 technically for FLOPs.
-    flop = b * m * n * 2 * k
-    return flop
-
-
-def baddbmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
-    """
-    Count flops for the baddbmm operation.
-    """
-    # Inputs should be a list of length 3.
-    # Inputs contains the shapes of three tensors.
-    return bmm_flop(a_shape, b_shape)
-
-
-def conv_flop_count(
-    x_shape: List[int],
-    w_shape: List[int],
-    out_shape: List[int],
-    transposed: bool = False,
-) -> int:
-    """
-    Count flops for convolution. Note only multiplication is
-    counted. Computation for bias are ignored.
-    Flops for a transposed convolution are calculated as
-    flops = (x_shape[2:] * prod(w_shape) * batch_size).
-    Args:
-        x_shape (list(int)): The input shape before convolution.
-        w_shape (list(int)): The filter shape.
-        out_shape (list(int)): The output shape after convolution.
-        transposed (bool): is the convolution transposed
-    Returns:
-        int: the number of flops
-    """
-    batch_size = x_shape[0]
-    conv_shape = (x_shape if transposed else out_shape)[2:]
-    c_out, c_in, *dims = w_shape
-
-    # NB(chilli): I don't think this properly accounts for padding :think:
-    # NB(chilli): Should be 2 * c_in - 1 technically for FLOPs.
-    flop = batch_size * prod(conv_shape) * c_out * prod(dims) * 2 * c_in
-    return flop
-
-
-def conv_flop(
-    x_shape,
-    w_shape,
-    _bias,
-    _stride,
-    _padding,
-    _dilation,
-    transposed,
-    *args,
-    out_shape=None,
-    **kwargs
-) -> int:
-    """
-    Count flops for convolution.
-    """
-    return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)
-
-
-def transpose_shape(shape):
-    return [shape[1], shape[0]] + list(shape[2:])
-
-
-def conv_backward_flop(
-    grad_out_shape,
-    x_shape,
-    w_shape,
-    _bias,
-    _stride,
-    _padding,
-    _dilation,
-    transposed,
-    _output_padding,
-    _groups,
-    output_mask,
-    out_shape,
-) -> int:
-    flop_count = 0
-
-    if output_mask[0]:
-        grad_input_shape = get_shape(out_shape[0])
-        flop_count += conv_flop_count(
-            grad_out_shape, w_shape, grad_input_shape, not transposed
-        )
-    if output_mask[1]:
-        grad_weight_shape = get_shape(out_shape[1])
-        flop_count += conv_flop_count(
-            transpose_shape(x_shape), grad_out_shape, grad_weight_shape, transposed
-        )
-
-    return flop_count
-
-
-def sdpa_flop_count(query_shape, key_shape, value_shape):
-    """
-    Count flops for self-attention.
-    NB: We can assume that value_shape == key_shape
-    """
-    b, h, s_q, d_q = query_shape
-    _b2, _h2, s_k, _d2 = key_shape
-    _b3, _h3, _s3, d_v = value_shape
-    assert (
-        b == _b2 == _b3 and h == _h2 == _h3 and d_q == _d2 and s_k == _s3 and d_q == _d2
-    )
-    total_flops = 0
-    # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
-    total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k))
-    # scores: [b, h, s_q, s_k] @ v: [b, h, s_k, d_v] -> out: [b, h, s_q, d_v]
-    total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_v))
-    return total_flops
-
-
-def sdpa_flop(
-    query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs
-) -> int:
-    """
-    Count flops for self-attention.
-    """
-    # NB: We aren't accounting for causal attention here
-    return sdpa_flop_count(query_shape, key_shape, value_shape)
-
-
-def sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape):
-    total_flops = 0
-    b, h, s_q, d_q = query_shape
-    _b2, _h2, s_k, _d2 = key_shape
-    _b3, _h3, _s3, d_v = value_shape
-    _b4, _h4, _s4, _d4 = grad_out_shape
-    assert b == _b2 == _b3 == _b4 and h == _h2 == _h3 == _h4 and d_q == _d2
-    assert d_v == _d4 and s_k == _s3 and s_q == _s4
-    total_flops = 0
-    # Step 1: We recompute the scores matrix.
-    # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
-    total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k))
-
-    # Step 2: We propagate the gradients through the score @ v operation.
-    # gradOut: [b, h, s_q, d_v] @ v: [b, h, d_v, s_k] -> gradScores: [b, h, s_q, s_k]
-    total_flops += bmm_flop((b * h, s_q, d_v), (b * h, d_v, s_k))
-    # scores: [b, h, s_k, s_q] @ gradOut: [b, h, s_q, d_v] -> gradV: [b, h, s_k, d_v]
-    total_flops += bmm_flop((b * h, s_k, s_q), (b * h, s_q, d_v))
-
-    # Step 3: We propagate th gradients through the k @ v operation
-    # gradScores: [b, h, s_q, s_k] @ k: [b, h, s_k, d_q] -> gradQ: [b, h, s_q, d_q]
-    total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_q))
-    # q: [b, h, d_q, s_q] @ gradScores: [b, h, s_q, s_k] -> gradK: [b, h, d_q, s_k]
-    total_flops += bmm_flop((b * h, d_q, s_q), (b * h, s_q, s_k))
-    return total_flops
-
-
-def sdpa_backward_flop(
-    grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs
-) -> int:
-    """
-    Count flops for self-attention backward.
-    """
-    return sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape)
-
-
-flop_mapping = {
-    aten.mm: mm_flop,
-    aten.addmm: addmm_flop,
-    aten.bmm: bmm_flop,
-    aten.baddbmm: baddbmm_flop,
-    aten.convolution: conv_flop,
-    aten._convolution: conv_flop,
-    aten.convolution_backward: conv_backward_flop,
-    aten._scaled_dot_product_efficient_attention: sdpa_flop,
-    aten._scaled_dot_product_flash_attention: sdpa_flop,
-    aten._scaled_dot_product_efficient_attention_backward: sdpa_backward_flop,
-    aten._scaled_dot_product_flash_attention_backward: sdpa_backward_flop,
-}
-
-
-def normalize_tuple(x):
-    if not isinstance(x, tuple):
-        return (x,)
-    return x
-
-
-# Define the suffixes for different orders of magnitude
-suffixes = ["", "K", "M", "B", "T"]
-
-
-# Thanks BingChat!
-def get_suffix_str(number):
-    # Find the index of the appropriate suffix based on the number of digits
-    # with some additional overflow.
-    # i.e. 1.01B should be displayed as 1001M, not 1.001B
-    index = max(0, min(len(suffixes) - 1, (len(str(number)) - 3) // 3))
-    return suffixes[index]
-
-
-def convert_num_with_suffix(number, suffix):
-    index = suffixes.index(suffix)
-    # Divide the number by 1000^index and format it to two decimal places
-    value = "{:.3f}".format(number / (1000**index))
-    # Return the value and the suffix as a string
-    return value + suffixes[index]
-
-
-class FlopCounterMode(TorchDispatchMode):
-    """
-    ``FlopCounterMode`` is a context manager that counts the number of
-    flops within its context. It does this using a ``TorchDispatchMode``.
-
-    It also supports hierarchical output by passing a module (or list of modules) to FlopCounterMode on construction.
-
-    Example usage
-
-    .. code-block:: python
-
-        mod = ...
-        flop_counter = FlopCounterMode(mod)
-        with flop_counter:
-            mod.sum().backward()
-
-    """
-
-    def __init__(
-        self,
-        mods: Optional[Union[torch.nn.Module, List[torch.nn.Module]]] = None,
-        depth: int = 2,
-        display: bool = True,
-        custom_mapping: Dict[Any, Any] = None,
-        rank=None,
-    ):
-        self.flop_counts: Dict[str, Dict[Any, int]] = defaultdict(
-            lambda: defaultdict(int)
-        )
-        self.depth = depth
-        self.parents = ["Global"]
-        self.display = display
-        self.rank = rank
-
-        if custom_mapping is None:
-            custom_mapping = {}
-        if isinstance(mods, torch.nn.Module):
-            mods = [mods]
-        self.mods = mods
-        if mods is not None:
-            for mod in mods:
-                prefix = type(mod).__name__
-                for name, module in dict(mod.named_modules()).items():
-                    if name == "":
-                        name = prefix
-                    else:
-                        name = ".".join([prefix, name])
-                    module.register_forward_pre_hook(self._enter_module(name))
-                    module.register_forward_hook(self._exit_module(name))
-        self.flop_mapping = {**flop_mapping, **custom_mapping}
-
-    def _enter_module(self, name):
-        def f(module, inputs):
-            inputs = normalize_tuple(inputs)
-            out = self._create_pre_module(name)(*inputs)
-            return out
-
-        return f
-
-    def _exit_module(self, name):
-        def f(module, inputs, outputs):
-            outputs = normalize_tuple(outputs)
-            return self._create_post_module(name)(*outputs)
-
-        return f
-
-    def _create_post_module(self, name):
-        class PushState(torch.autograd.Function):
-            @staticmethod
-            def forward(ctx, *args):
-                assert self.parents[-1] == name
-                self.parents.pop()
-                args = tree_map(
-                    lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args
-                )
-                if len(args) == 1:
-                    return args[0]
-                return args
-
-            @staticmethod
-            def backward(ctx, *grad_outs):
-                self.parents.append(name)
-                return grad_outs
-
-        return PushState.apply
-
-    def _create_pre_module(self, name):
-        class PopState(torch.autograd.Function):
-            @staticmethod
-            def forward(ctx, *args):
-                self.parents.append(name)
-                args = tree_map(
-                    lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args
-                )
-                if len(args) == 1:
-                    return args[0]
-                return args
-
-            @staticmethod
-            def backward(ctx, *grad_outs):
-                assert self.parents[-1] == name
-                self.parents.pop()
-                return grad_outs
-
-        return PopState.apply
-
-    def get_total_flops(self) -> int:
-        return sum(self.flop_counts["Global"].values())
-
-    def get_flop_counts(self) -> Dict[str, Dict[Any, int]]:
-        """Returns the flop counts as a dictionary of dictionaries. The outer
-        dictionary is keyed by module name, and the inner dictionary is keyed by
-        operation name.
-
-        Returns:
-            Dict[str, Dict[Any, int]]: The flop counts as a dictionary.
-        """
-        return dict(self.flop_counts)
-
-    def get_table(self, depth=None):
-        if depth is None:
-            depth = self.depth
-        if depth is None:
-            depth = 999999
-
-        import tabulate
-
-        tabulate.PRESERVE_WHITESPACE = True
-        header = ["Module", "FLOP", "% Total"]
-        values = []
-        global_flops = self.get_total_flops()
-        global_suffix = get_suffix_str(global_flops)
-        is_global_subsumed = False
-
-        def process_mod(mod_name, depth):
-            nonlocal is_global_subsumed
-
-            total_flops = sum(self.flop_counts[mod_name].values())
-
-            is_global_subsumed |= total_flops >= global_flops
-
-            padding = " " * depth
-            values = []
-            values.append(
-                [
-                    padding + mod_name,
-                    convert_num_with_suffix(total_flops, global_suffix),
-                    "{:.2f}%".format(total_flops / global_flops * 100),
-                ]
-            )
-            for k, v in self.flop_counts[mod_name].items():
-                values.append(
-                    [
-                        padding + " - " + str(k),
-                        convert_num_with_suffix(v, global_suffix),
-                        "{:.2f}%".format(v / global_flops * 100),
-                    ]
-                )
-            return values
-
-        for mod in self.flop_counts.keys():
-            if mod == "Global":
-                continue
-            mod_depth = mod.count(".") + 1
-            if mod_depth > depth:
-                continue
-
-            cur_values = process_mod(mod, mod_depth - 1)
-            for value in cur_values:
-                values.append(value)
-
-        # We do a bit of messing around here to only output the "Global" value
-        # if there are any FLOPs in there that aren't already fully contained by
-        # a module.
-        if "Global" in self.flop_counts and not is_global_subsumed:
-            for idx, value in enumerate(values):
-                values[idx][0] = " " + values[idx][0]
-
-            values = process_mod("Global", 0) + values
-
-        if len(values) == 0:
-            values = [["Global", "0", "0%"]]
-
-        return tabulate.tabulate(
-            values, headers=header, colalign=("left", "right", "right")
-        )
-
-    def __enter__(self):
-        self.flop_counts.clear()
-        super().__enter__()
-        return self
-
-    def __exit__(self, *args):
-        if self.display:
-            if self.rank is None or self.rank == 0:
-                print(self.get_table(self.depth))
-        super().__exit__(*args)
-
-    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
-        kwargs = kwargs if kwargs else {}
-        out = func(*args, **kwargs)
-        func_packet = func._overloadpacket
-        if func_packet in self.flop_mapping:
-            flop_count_func = self.flop_mapping[func_packet]
-            args, kwargs, out_shape = tree_map(get_shape, (args, kwargs, out))
-            flop_count = flop_count_func(*args, **kwargs, out_shape=out_shape)  # type: ignore[operator]
-            for par in self.parents:
-                self.flop_counts[par][func_packet] += flop_count
-
-        return out

+ 2 - 96
src/llama_recipes/utils/train_utils.py

@@ -7,10 +7,9 @@ import yaml
 from contextlib import nullcontext
 from contextlib import nullcontext
 from pathlib import Path
 from pathlib import Path
 from pkg_resources import packaging
 from pkg_resources import packaging
-import contextlib
-import gc
 from datetime import datetime
 from datetime import datetime
 
 
+
 import torch
 import torch
 import torch.cuda.nccl as nccl
 import torch.cuda.nccl as nccl
 import torch.distributed as dist
 import torch.distributed as dist
@@ -24,39 +23,8 @@ import json
 from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
 from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
 from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
 from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
 from llama_recipes.utils.memory_utils import MemoryTrace
-
-from llama_recipes.utils.tflop_counter import FlopCounterMode
-
-@contextlib.contextmanager
-def maybe_run_profiler(cfg, *args, **kwargs):
-    use_profiler: bool = cfg.profiler
-    
-    if use_profiler:
-        print(f"profiling is activated and results will be saved in {cfg.profile_output_dir}")
-        with torch.profiler.profile(
-            activities=[
-                torch.profiler.ProfilerActivity.CPU,
-                torch.profiler.ProfilerActivity.CUDA,
-            ],
-            schedule=torch.profiler.schedule(wait=1, warmup=2, active=3, repeat=1),
-            on_trace_ready=torch.profiler.tensorboard_trace_handler(
-                cfg.profile_output_dir
-            ),
-            profile_memory=True,
-            with_stack=False,
-            record_shapes=True,
-        ) as torch_profiler:
-            yield torch_profiler
-    else:
-        torch_profiler = contextlib.nullcontext()
-        yield None
-            
-def get_total_flops(model):
-    return (sum([v for _, v in model.flop_counts["Global"].items()]))
-
 from accelerate.utils import is_xpu_available, is_ccl_available
 from accelerate.utils import is_xpu_available, is_ccl_available
 
 
-
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
     tokenizer.pad_token_id = 0
     tokenizer.padding_side = "left"
     tokenizer.padding_side = "left"
@@ -118,62 +86,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             total_loss = 0.0
             total_loss = 0.0
             total_length = len(train_dataloader)//gradient_accumulation_steps
             total_length = len(train_dataloader)//gradient_accumulation_steps
             pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
             pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
-
-            with maybe_run_profiler(train_config) as torch_profiler:
-                for step, batch in enumerate(train_dataloader):
-                    gc.collect(1)
-                    for key in batch.keys():
-                        if train_config.enable_fsdp:
-                            batch[key] = batch[key].to(local_rank)
-                        else:
-                            batch[key] = batch[key].to('cuda:0') 
-                    flop_check_done = False 
-                    if train_config.flop_counter and  step == 3 and not flop_check_done:
-                        flop_counter = FlopCounterMode(rank=local_rank)
-                        with flop_counter:           
-                            loss = model(**batch).loss
-                            loss = loss / gradient_accumulation_steps
-                            total_loss += loss.detach().float()
-                            if train_config.use_fp16:
-                                # if fp16 is enabled, use gradient scaler to handle gradient update
-                                scaler.scale(loss).backward()
-                                if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-                                    scaler.step(optimizer)
-                                    scaler.update()
-                                    optimizer.zero_grad()
-                                    pbar.update(1)
-                            else:
-                                # regular backpropagation when fp16 is not used
-                                loss.backward()
-                                TFlops = get_total_flops(flop_counter) / 1e12
-                                flop_check_done = True
-                                if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-                                    optimizer.step()
-                                    optimizer.zero_grad()
-                                    pbar.update(1)
-                        
-                    else:
-                        loss = model(**batch).loss
-                        loss = loss / gradient_accumulation_steps
-                        total_loss += loss.detach().float()
-                        if train_config.use_fp16:
-                            # if fp16 is enabled, use gradient scaler to handle gradient update
-                            scaler.scale(loss).backward()
-                            if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-                                scaler.step(optimizer)
-                                scaler.update()
-                                optimizer.zero_grad()
-                                pbar.update(1)
-                        else:
-                            # regular backpropagation when fp16 is not used
-                            loss.backward()
-                            if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-                                optimizer.step()
-                                optimizer.zero_grad()
-                                pbar.update(1)
-                    pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
-                pbar.close()
-
             for step, batch in enumerate(train_dataloader):
             for step, batch in enumerate(train_dataloader):
                 for key in batch.keys():
                 for key in batch.keys():
                     if train_config.enable_fsdp:
                     if train_config.enable_fsdp:
@@ -227,7 +139,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
                     save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
             pbar.close()
             pbar.close()
 
 
-
         epoch_end_time = time.perf_counter()-epoch_start_time
         epoch_end_time = time.perf_counter()-epoch_start_time
         epoch_times.append(epoch_end_time)
         epoch_times.append(epoch_end_time)
         # Reducing total_loss across all devices if there's more than one CUDA device
         # Reducing total_loss across all devices if there's more than one CUDA device
@@ -355,10 +266,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         results['avg_eval_loss'] = avg_eval_loss
         results['avg_eval_loss'] = avg_eval_loss
     results["avg_epoch_time"] = avg_epoch_time
     results["avg_epoch_time"] = avg_epoch_time
     results["avg_checkpoint_time"] = avg_checkpoint_time
     results["avg_checkpoint_time"] = avg_checkpoint_time
-
-    if train_config.flop_counter:
-        results["model_flops"]= TFlops
-       
     if train_config.save_metrics:
     if train_config.save_metrics:
         results["metrics_filename"] = metrics_filename
         results["metrics_filename"] = metrics_filename
 
 
@@ -389,7 +296,6 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     eval_loss = 0.0  # Initialize evaluation loss
     eval_loss = 0.0  # Initialize evaluation loss
     with MemoryTrace() as memtrace:
     with MemoryTrace() as memtrace:
         for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
         for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
-            gc.collect(1)
             for key in batch.keys():
             for key in batch.keys():
                 if train_config.enable_fsdp:
                 if train_config.enable_fsdp:
                     batch[key] = batch[key].to(local_rank)
                     batch[key] = batch[key].to(local_rank)
@@ -595,4 +501,4 @@ def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_
         "val_epoch_perplexity": val_epoch_ppl
         "val_epoch_perplexity": val_epoch_ppl
     }
     }
     with open(output_filename, "w") as f:
     with open(output_filename, "w") as f:
-        json.dump(metrics_data, f)
+        json.dump(metrics_data, f)