# Temp copy of Horace He's FLOP counter.
# This copy supports distributed runs, so the table is printed once instead of
# on every GPU.
# Remove after the main file is updated.
import torch
from torch.utils._pytree import tree_map
from typing import List, Any, Dict, Optional, Union
from collections import defaultdict
from torch.utils._python_dispatch import TorchDispatchMode
from math import prod

__all__ = ["FlopCounterMode"]

aten = torch.ops.aten


def get_shape(i):
    if isinstance(i, torch.Tensor):
        return i.shape
    return i
def mm_flop(a_shape, b_shape, *args, out_shape=None, **kwargs) -> int:
    """
    Count flops for matmul.
    """
    # Inputs should be a list of length 2.
    # Inputs contain the shapes of two matrices.
    m, k = a_shape
    k2, n = b_shape
    assert k == k2
    # NB(chilli): Should be 2 * k - 1 technically for FLOPs.
    return m * n * 2 * k
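
# Worked example (illustrative, not from the original file): a (64, 128) @ (128, 256)
# matmul is counted as 64 * 256 * 2 * 128 = 4,194,304 FLOPs, i.e. two FLOPs per
# multiply-accumulate.
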
def addmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
    """
    Count flops for addmm.
    """
    return mm_flop(a_shape, b_shape)
def bmm_flop(a_shape, b_shape, out_shape=None, **kwargs) -> int:
    """
    Count flops for the bmm operation.
    """
    # Inputs should be a list of length 2.
    # Inputs contain the shapes of two tensors.
    b, m, k = a_shape
    b2, k2, n = b_shape
    assert b == b2
    assert k == k2
    # NB(chilli): Should be 2 * k - 1 technically for FLOPs.
    flop = b * m * n * 2 * k
    return flop


def baddbmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
    """
    Count flops for the baddbmm operation.
    """
    # Inputs should be a list of length 3.
    # Inputs contain the shapes of three tensors.
    return bmm_flop(a_shape, b_shape)
def conv_flop_count(
    x_shape: List[int],
    w_shape: List[int],
    out_shape: List[int],
    transposed: bool = False,
) -> int:
    """
    Count flops for convolution. Note that only multiplications are counted;
    computation for the bias is ignored.
    Flops for a transposed convolution are calculated as
    flops = (prod(x_shape[2:]) * prod(w_shape) * batch_size).
    Args:
        x_shape (list(int)): The input shape before convolution.
        w_shape (list(int)): The filter shape.
        out_shape (list(int)): The output shape after convolution.
        transposed (bool): Whether the convolution is transposed.
    Returns:
        int: the number of flops
    """
    batch_size = x_shape[0]
    conv_shape = (x_shape if transposed else out_shape)[2:]
    c_out, c_in, *dims = w_shape
    # NB(chilli): I don't think this properly accounts for padding :think:
    # NB(chilli): Should be 2 * c_in - 1 technically for FLOPs.
    flop = batch_size * prod(conv_shape) * c_out * prod(dims) * 2 * c_in
    return flop
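
# Worked example (illustrative): a 7x7 convolution taking an (8, 3, 224, 224)
# input to an (8, 64, 112, 112) output with w_shape = (64, 3, 7, 7) is counted
# as 8 * (112 * 112) * 64 * (7 * 7) * 2 * 3 ≈ 1.89e9 FLOPs.
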
def conv_flop(
    x_shape,
    w_shape,
    _bias,
    _stride,
    _padding,
    _dilation,
    transposed,
    *args,
    out_shape=None,
    **kwargs
) -> int:
    """
    Count flops for convolution.
    """
    return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)


def transpose_shape(shape):
    return [shape[1], shape[0]] + list(shape[2:])
def conv_backward_flop(
    grad_out_shape,
    x_shape,
    w_shape,
    _bias,
    _stride,
    _padding,
    _dilation,
    transposed,
    _output_padding,
    _groups,
    output_mask,
    out_shape,
) -> int:
    # output_mask selects (grad_input, grad_weight, grad_bias); the bias
    # gradient is a reduction and is not counted here.
    flop_count = 0
    if output_mask[0]:
        grad_input_shape = get_shape(out_shape[0])
        flop_count += conv_flop_count(
            grad_out_shape, w_shape, grad_input_shape, not transposed
        )
    if output_mask[1]:
        grad_weight_shape = get_shape(out_shape[1])
        flop_count += conv_flop_count(
            transpose_shape(x_shape), grad_out_shape, grad_weight_shape, transposed
        )
    return flop_count
def sdpa_flop_count(query_shape, key_shape, value_shape):
    """
    Count flops for self-attention.
    NB: key and value are assumed to share batch, head, and sequence dims.
    """
    b, h, s_q, d_q = query_shape
    _b2, _h2, s_k, _d2 = key_shape
    _b3, _h3, _s3, d_v = value_shape
    assert b == _b2 == _b3 and h == _h2 == _h3 and d_q == _d2 and s_k == _s3
    total_flops = 0
    # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
    total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k))
    # scores: [b, h, s_q, s_k] @ v: [b, h, s_k, d_v] -> out: [b, h, s_q, d_v]
    total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_v))
    return total_flops
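
# Worked example (illustrative): with b=2, h=16, s_q = s_k = 1024 and
# d_q = d_v = 64, each of the two batched matmuls above counts
# (2 * 16) * 1024 * 1024 * 2 * 64 ≈ 4.3e9 FLOPs, so the attention forward pass
# totals ≈ 8.6e9 FLOPs.
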
def sdpa_flop(
    query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs
) -> int:
    """
    Count flops for self-attention.
    """
    # NB: We aren't accounting for causal attention here
    return sdpa_flop_count(query_shape, key_shape, value_shape)
def sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape):
    b, h, s_q, d_q = query_shape
    _b2, _h2, s_k, _d2 = key_shape
    _b3, _h3, _s3, d_v = value_shape
    _b4, _h4, _s4, _d4 = grad_out_shape
    assert b == _b2 == _b3 == _b4 and h == _h2 == _h3 == _h4 and d_q == _d2
    assert d_v == _d4 and s_k == _s3 and s_q == _s4
    total_flops = 0
    # Step 1: We recompute the scores matrix.
    # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
    total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k))
    # Step 2: We propagate the gradients through the scores @ v operation.
    # gradOut: [b, h, s_q, d_v] @ v: [b, h, d_v, s_k] -> gradScores: [b, h, s_q, s_k]
    total_flops += bmm_flop((b * h, s_q, d_v), (b * h, d_v, s_k))
    # scores: [b, h, s_k, s_q] @ gradOut: [b, h, s_q, d_v] -> gradV: [b, h, s_k, d_v]
    total_flops += bmm_flop((b * h, s_k, s_q), (b * h, s_q, d_v))
    # Step 3: We propagate the gradients through the q @ k operation.
    # gradScores: [b, h, s_q, s_k] @ k: [b, h, s_k, d_q] -> gradQ: [b, h, s_q, d_q]
    total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_q))
    # q: [b, h, d_q, s_q] @ gradScores: [b, h, s_q, s_k] -> gradK: [b, h, d_q, s_k]
    total_flops += bmm_flop((b * h, d_q, s_q), (b * h, s_q, s_k))
    return total_flops
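
# Note (derived from the counts above): when d_q == d_v, the five backward
# matmuls are each the same size as one forward matmul, so SDPA backward is
# roughly 2.5x the FLOPs of the forward pass.
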
def sdpa_backward_flop(
    grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs
) -> int:
    """
    Count flops for self-attention backward.
    """
    return sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape)
flop_mapping = {
    aten.mm: mm_flop,
    aten.addmm: addmm_flop,
    aten.bmm: bmm_flop,
    aten.baddbmm: baddbmm_flop,
    aten.convolution: conv_flop,
    aten._convolution: conv_flop,
    aten.convolution_backward: conv_backward_flop,
    aten._scaled_dot_product_efficient_attention: sdpa_flop,
    aten._scaled_dot_product_flash_attention: sdpa_flop,
    aten._scaled_dot_product_efficient_attention_backward: sdpa_backward_flop,
    aten._scaled_dot_product_flash_attention_backward: sdpa_backward_flop,
}
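
# Example (illustrative, hypothetical): extra operators can be counted by
# passing a ``custom_mapping`` dict to ``FlopCounterMode`` (defined below),
# e.g. reusing ``bmm_flop`` for ``aten.addbmm`` (the additive term is ignored,
# as elsewhere in this file):
#
#     def addbmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
#         return bmm_flop(a_shape, b_shape)
#
#     FlopCounterMode(mod, custom_mapping={aten.addbmm: addbmm_flop})
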
def normalize_tuple(x):
    if not isinstance(x, tuple):
        return (x,)
    return x


# Define the suffixes for different orders of magnitude
suffixes = ["", "K", "M", "B", "T"]


# Thanks BingChat!
def get_suffix_str(number):
    # Find the index of the appropriate suffix based on the number of digits
    # with some additional overflow.
    # i.e. 1.001B should be displayed as 1001M, not 1.001B
    index = max(0, min(len(suffixes) - 1, (len(str(number)) - 3) // 3))
    return suffixes[index]


def convert_num_with_suffix(number, suffix):
    index = suffixes.index(suffix)
    # Divide the number by 1000^index and format it to three decimal places
    value = "{:.3f}".format(number / (1000**index))
    # Return the value and the suffix as a string
    return value + suffixes[index]
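
# Example (illustrative): get_suffix_str(1_001_000_000) returns "M" and
# convert_num_with_suffix(1_001_000_000, "M") returns "1001.000M", so 1.001B is
# displayed as 1001M rather than rounding the suffix up.
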
class FlopCounterMode(TorchDispatchMode):
    """
    ``FlopCounterMode`` is a context manager that counts the number of
    flops within its context. It does this using a ``TorchDispatchMode``.

    It also supports hierarchical output by passing a module (or list of
    modules) to FlopCounterMode on construction.

    Example usage

    .. code-block:: python

        mod = ...
        inp = ...
        flop_counter = FlopCounterMode(mod)
        with flop_counter:
            mod(inp).sum().backward()
    """

    def __init__(
        self,
        mods: Optional[Union[torch.nn.Module, List[torch.nn.Module]]] = None,
        depth: int = 2,
        display: bool = True,
        custom_mapping: Optional[Dict[Any, Any]] = None,
        rank: Optional[int] = None,
    ):
        self.flop_counts: Dict[str, Dict[Any, int]] = defaultdict(
            lambda: defaultdict(int)
        )
        self.depth = depth
        self.parents = ["Global"]
        self.display = display
        self.rank = rank
        if custom_mapping is None:
            custom_mapping = {}
        if isinstance(mods, torch.nn.Module):
            mods = [mods]
        self.mods = mods
        if mods is not None:
            for mod in mods:
                prefix = type(mod).__name__
                for name, module in dict(mod.named_modules()).items():
                    if name == "":
                        name = prefix
                    else:
                        name = ".".join([prefix, name])
                    module.register_forward_pre_hook(self._enter_module(name))
                    module.register_forward_hook(self._exit_module(name))
        self.flop_mapping = {**flop_mapping, **custom_mapping}
    def _enter_module(self, name):
        def f(module, inputs):
            inputs = normalize_tuple(inputs)
            out = self._create_pre_module(name)(*inputs)
            return out

        return f

    def _exit_module(self, name):
        def f(module, inputs, outputs):
            outputs = normalize_tuple(outputs)
            return self._create_post_module(name)(*outputs)

        return f
    def _create_post_module(self, name):
        # Pops ``name`` from the parent stack at the end of the module's
        # forward pass and pushes it back on during backward, so backward
        # flops are attributed to the same module.
        class PushState(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                assert self.parents[-1] == name
                self.parents.pop()
                args = tree_map(
                    lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args
                )
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                self.parents.append(name)
                return grad_outs

        return PushState.apply

    def _create_pre_module(self, name):
        # Pushes ``name`` onto the parent stack at the start of the module's
        # forward pass and pops it during backward.
        class PopState(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                self.parents.append(name)
                args = tree_map(
                    lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args
                )
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                assert self.parents[-1] == name
                self.parents.pop()
                return grad_outs

        return PopState.apply
    def get_total_flops(self) -> int:
        return sum(self.flop_counts["Global"].values())

    def get_flop_counts(self) -> Dict[str, Dict[Any, int]]:
        """Returns the flop counts as a dictionary of dictionaries. The outer
        dictionary is keyed by module name, and the inner dictionary is keyed by
        operation name.

        Returns:
            Dict[str, Dict[Any, int]]: The flop counts as a dictionary.
        """
        return dict(self.flop_counts)
    def get_table(self, depth=None):
        if depth is None:
            depth = self.depth
        if depth is None:
            depth = 999999

        import tabulate

        tabulate.PRESERVE_WHITESPACE = True
        header = ["Module", "FLOP", "% Total"]
        values = []
        global_flops = self.get_total_flops()
        global_suffix = get_suffix_str(global_flops)
        is_global_subsumed = False

        def process_mod(mod_name, depth):
            nonlocal is_global_subsumed

            total_flops = sum(self.flop_counts[mod_name].values())
            is_global_subsumed |= total_flops >= global_flops

            padding = " " * depth
            values = []
            values.append(
                [
                    padding + mod_name,
                    convert_num_with_suffix(total_flops, global_suffix),
                    "{:.2f}%".format(total_flops / global_flops * 100),
                ]
            )
            for k, v in self.flop_counts[mod_name].items():
                values.append(
                    [
                        padding + " - " + str(k),
                        convert_num_with_suffix(v, global_suffix),
                        "{:.2f}%".format(v / global_flops * 100),
                    ]
                )
            return values

        for mod in self.flop_counts.keys():
            if mod == "Global":
                continue
            mod_depth = mod.count(".") + 1
            if mod_depth > depth:
                continue
            cur_values = process_mod(mod, mod_depth - 1)
            for value in cur_values:
                values.append(value)

        # We do a bit of messing around here to only output the "Global" value
        # if there are any FLOPs in there that aren't already fully contained by
        # a module.
        if "Global" in self.flop_counts and not is_global_subsumed:
            for idx, value in enumerate(values):
                values[idx][0] = " " + values[idx][0]
            values = process_mod("Global", 0) + values

        if len(values) == 0:
            values = [["Global", "0", "0%"]]

        return tabulate.tabulate(
            values, headers=header, colalign=("left", "right", "right")
        )
    def __enter__(self):
        self.flop_counts.clear()
        super().__enter__()
        return self

    def __exit__(self, *args):
        if self.display:
            if self.rank is None or self.rank == 0:
                print(self.get_table(self.depth))
        super().__exit__(*args)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}
        out = func(*args, **kwargs)
        func_packet = func._overloadpacket
        if func_packet in self.flop_mapping:
            flop_count_func = self.flop_mapping[func_packet]
            args, kwargs, out_shape = tree_map(get_shape, (args, kwargs, out))
            flop_count = flop_count_func(*args, **kwargs, out_shape=out_shape)  # type: ignore[operator]
            for par in self.parents:
                self.flop_counts[par][func_packet] += flop_count
        return out
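

# Minimal usage sketch (illustrative, not part of the original module): the
# model, shapes, and names below are made up. In a distributed run, pass
# ``rank=torch.distributed.get_rank()`` so that only rank 0 prints the table.
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))
    x = torch.randn(32, 64)

    flop_counter = FlopCounterMode(model, depth=2, display=True, rank=None)
    with flop_counter:
        # Forward + backward; both are counted, and the table prints on exit.
        model(x).sum().backward()

    print("Total FLOPs:", flop_counter.get_total_flops())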