# tflop_counter.py

# Temporary copy of Horace's FLOP counter.
# This copy supports distributed runs, so the table is not printed on every GPU.
# Remove after the main file is updated.
import torch
from torch.utils._pytree import tree_map
from typing import List, Any, Dict, Optional, Union
from collections import defaultdict
from torch.utils._python_dispatch import TorchDispatchMode
from math import prod

__all__ = ["FlopCounterMode"]

aten = torch.ops.aten


def get_shape(i):
    if isinstance(i, torch.Tensor):
        return i.shape
    return i


def mm_flop(a_shape, b_shape, *args, out_shape=None, **kwargs) -> int:
    """
    Count flops for matmul.
    """
    # Inputs should be a list of length 2.
    # Inputs contains the shapes of two matrices.
    m, k = a_shape
    k2, n = b_shape
    assert k == k2
    # NB(chilli): Should be 2 * k - 1 technically for FLOPs.
    return m * n * 2 * k


def addmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
    """
    Count flops for addmm.
    """
    return mm_flop(a_shape, b_shape)


def bmm_flop(a_shape, b_shape, out_shape=None, **kwargs) -> int:
    """
    Count flops for the bmm operation.
    """
    # Inputs should be a list of length 2.
    # Inputs contains the shapes of two tensors.
    b, m, k = a_shape
    b2, k2, n = b_shape
    assert b == b2
    assert k == k2
    # NB(chilli): Should be 2 * k - 1 technically for FLOPs.
    flop = b * m * n * 2 * k
    return flop


def baddbmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
    """
    Count flops for the baddbmm operation.
    """
    # Inputs should be a list of length 3.
    # Inputs contains the shapes of three tensors.
    return bmm_flop(a_shape, b_shape)


def conv_flop_count(
    x_shape: List[int],
    w_shape: List[int],
    out_shape: List[int],
    transposed: bool = False,
) -> int:
    """
    Count flops for convolution. Note that only multiplications are
    counted; computation for the bias is ignored.
    Flops for a transposed convolution are calculated as
    flops = (prod(x_shape[2:]) * prod(w_shape) * batch_size).

    Args:
        x_shape (list(int)): The input shape before convolution.
        w_shape (list(int)): The filter shape.
        out_shape (list(int)): The output shape after convolution.
        transposed (bool): is the convolution transposed
    Returns:
        int: the number of flops
    """
    batch_size = x_shape[0]
    conv_shape = (x_shape if transposed else out_shape)[2:]
    c_out, c_in, *dims = w_shape
    # NB(chilli): I don't think this properly accounts for padding :think:
    # NB(chilli): Should be 2 * c_in - 1 technically for FLOPs.
    flop = batch_size * prod(conv_shape) * c_out * prod(dims) * 2 * c_in
    return flop
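

# As a rough sanity check (the shapes here are hypothetical, not from the
# original file): a 3x3 convolution with x_shape = [8, 3, 32, 32],
# w_shape = [16, 3, 3, 3] and out_shape = [8, 16, 32, 32] (padding keeps the
# spatial size) gives
#     8 * (32 * 32) * 16 * (3 * 3) * 2 * 3 = 7,077,888 FLOPs,
# i.e. conv_flop_count([8, 3, 32, 32], [16, 3, 3, 3], [8, 16, 32, 32]) == 7_077_888.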


def conv_flop(
    x_shape,
    w_shape,
    _bias,
    _stride,
    _padding,
    _dilation,
    transposed,
    *args,
    out_shape=None,
    **kwargs
) -> int:
    """
    Count flops for convolution.
    """
    return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)


def transpose_shape(shape):
    return [shape[1], shape[0]] + list(shape[2:])


def conv_backward_flop(
    grad_out_shape,
    x_shape,
    w_shape,
    _bias,
    _stride,
    _padding,
    _dilation,
    transposed,
    _output_padding,
    _groups,
    output_mask,
    out_shape,
) -> int:
    flop_count = 0
    if output_mask[0]:
        grad_input_shape = get_shape(out_shape[0])
        flop_count += conv_flop_count(
            grad_out_shape, w_shape, grad_input_shape, not transposed
        )
    if output_mask[1]:
        grad_weight_shape = get_shape(out_shape[1])
        flop_count += conv_flop_count(
            transpose_shape(x_shape), grad_out_shape, grad_weight_shape, transposed
        )
    return flop_count


def sdpa_flop_count(query_shape, key_shape, value_shape):
    """
    Count flops for self-attention.
    NB: We can assume that value_shape == key_shape.
    """
    b, h, s_q, d_q = query_shape
    _b2, _h2, s_k, _d2 = key_shape
    _b3, _h3, _s3, d_v = value_shape
    assert b == _b2 == _b3 and h == _h2 == _h3 and d_q == _d2 and s_k == _s3
    total_flops = 0
    # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
    total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k))
    # scores: [b, h, s_q, s_k] @ v: [b, h, s_k, d_v] -> out: [b, h, s_q, d_v]
    total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_v))
    return total_flops
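

# For intuition (hypothetical shapes, not from the original file): with
# b = 1, h = 12, s_q = s_k = 1024 and d_q = d_v = 64, each of the two batched
# matmuls above costs 12 * 1024 * 1024 * 2 * 64 = 1,610,612,736 FLOPs, so
# sdpa_flop_count(...) returns roughly 3.22 GFLOPs for a single forward pass.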


def sdpa_flop(
    query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs
) -> int:
    """
    Count flops for self-attention.
    """
    # NB: We aren't accounting for causal attention here
    return sdpa_flop_count(query_shape, key_shape, value_shape)


def sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape):
    b, h, s_q, d_q = query_shape
    _b2, _h2, s_k, _d2 = key_shape
    _b3, _h3, _s3, d_v = value_shape
    _b4, _h4, _s4, _d4 = grad_out_shape
    assert b == _b2 == _b3 == _b4 and h == _h2 == _h3 == _h4 and d_q == _d2
    assert d_v == _d4 and s_k == _s3 and s_q == _s4
    total_flops = 0
    # Step 1: We recompute the scores matrix.
    # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
    total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k))
    # Step 2: We propagate the gradients through the scores @ v operation.
    # gradOut: [b, h, s_q, d_v] @ v: [b, h, d_v, s_k] -> gradScores: [b, h, s_q, s_k]
    total_flops += bmm_flop((b * h, s_q, d_v), (b * h, d_v, s_k))
    # scores: [b, h, s_k, s_q] @ gradOut: [b, h, s_q, d_v] -> gradV: [b, h, s_k, d_v]
    total_flops += bmm_flop((b * h, s_k, s_q), (b * h, s_q, d_v))
    # Step 3: We propagate the gradients through the q @ k operation.
    # gradScores: [b, h, s_q, s_k] @ k: [b, h, s_k, d_q] -> gradQ: [b, h, s_q, d_q]
    total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_q))
    # q: [b, h, d_q, s_q] @ gradScores: [b, h, s_q, s_k] -> gradK: [b, h, d_q, s_k]
    total_flops += bmm_flop((b * h, d_q, s_q), (b * h, s_q, s_k))
    return total_flops


def sdpa_backward_flop(
    grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs
) -> int:
    """
    Count flops for self-attention backward.
    """
    return sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape)


flop_mapping = {
    aten.mm: mm_flop,
    aten.addmm: addmm_flop,
    aten.bmm: bmm_flop,
    aten.baddbmm: baddbmm_flop,
    aten.convolution: conv_flop,
    aten._convolution: conv_flop,
    aten.convolution_backward: conv_backward_flop,
    aten._scaled_dot_product_efficient_attention: sdpa_flop,
    aten._scaled_dot_product_flash_attention: sdpa_flop,
    aten._scaled_dot_product_efficient_attention_backward: sdpa_backward_flop,
    aten._scaled_dot_product_flash_attention_backward: sdpa_backward_flop,
}
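

# A minimal sketch (not part of the original file) of how this mapping can be
# extended via the ``custom_mapping`` argument of ``FlopCounterMode`` below,
# e.g. to also count ``aten.addbmm``; the counting rule here is an assumption
# that, like ``baddbmm``, only the batched matmul FLOPs matter:
#
#     def addbmm_flop(self_shape, a_shape, b_shape, *args, out_shape=None, **kwargs) -> int:
#         return bmm_flop(a_shape, b_shape)
#
#     counter = FlopCounterMode(mod, custom_mapping={aten.addbmm: addbmm_flop})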


def normalize_tuple(x):
    if not isinstance(x, tuple):
        return (x,)
    return x


# Define the suffixes for different orders of magnitude.
suffixes = ["", "K", "M", "B", "T"]


# Thanks BingChat!
def get_suffix_str(number):
    # Find the index of the appropriate suffix based on the number of digits,
    # with some additional overflow:
    # e.g. 1.001B should be displayed as 1001M, not 1.001B.
    index = max(0, min(len(suffixes) - 1, (len(str(number)) - 3) // 3))
    return suffixes[index]


def convert_num_with_suffix(number, suffix):
    index = suffixes.index(suffix)
    # Divide the number by 1000**index and format it to three decimal places.
    value = "{:.3f}".format(number / (1000**index))
    # Return the value and the suffix as a string.
    return value + suffixes[index]
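

# For example (values chosen here purely for illustration):
#     get_suffix_str(1_001_000_000) returns "M", and
#     convert_num_with_suffix(1_001_000_000, "M") returns "1001.000M",
# so a count just over a billion is reported in millions rather than as 1.001B.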


class FlopCounterMode(TorchDispatchMode):
    """
    ``FlopCounterMode`` is a context manager that counts the number of
    flops within its context. It does this using a ``TorchDispatchMode``.

    It also supports hierarchical output by passing a module (or list of
    modules) to FlopCounterMode on construction. If ``rank`` is given, the
    summary table is only printed on rank 0, which avoids duplicated output
    in distributed runs.

    Example usage

    .. code-block:: python

        mod = ...
        flop_counter = FlopCounterMode(mod)
        with flop_counter:
            mod.sum().backward()
    """

    def __init__(
        self,
        mods: Optional[Union[torch.nn.Module, List[torch.nn.Module]]] = None,
        depth: int = 2,
        display: bool = True,
        custom_mapping: Optional[Dict[Any, Any]] = None,
        rank=None,
    ):
        self.flop_counts: Dict[str, Dict[Any, int]] = defaultdict(
            lambda: defaultdict(int)
        )
        self.depth = depth
        self.parents = ["Global"]
        self.display = display
        self.rank = rank
        if custom_mapping is None:
            custom_mapping = {}
        if isinstance(mods, torch.nn.Module):
            mods = [mods]
        self.mods = mods
        if mods is not None:
            for mod in mods:
                prefix = type(mod).__name__
                for name, module in dict(mod.named_modules()).items():
                    if name == "":
                        name = prefix
                    else:
                        name = ".".join([prefix, name])
                    module.register_forward_pre_hook(self._enter_module(name))
                    module.register_forward_hook(self._exit_module(name))
        self.flop_mapping = {**flop_mapping, **custom_mapping}

    def _enter_module(self, name):
        def f(module, inputs):
            inputs = normalize_tuple(inputs)
            out = self._create_pre_module(name)(*inputs)
            return out

        return f

    def _exit_module(self, name):
        def f(module, inputs, outputs):
            outputs = normalize_tuple(outputs)
            return self._create_post_module(name)(*outputs)

        return f

    def _create_post_module(self, name):
        class PushState(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                assert self.parents[-1] == name
                self.parents.pop()
                args = tree_map(
                    lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args
                )
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                self.parents.append(name)
                return grad_outs

        return PushState.apply

    def _create_pre_module(self, name):
        class PopState(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                self.parents.append(name)
                args = tree_map(
                    lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args
                )
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                assert self.parents[-1] == name
                self.parents.pop()
                return grad_outs

        return PopState.apply

    def get_total_flops(self) -> int:
        return sum(self.flop_counts["Global"].values())

    def get_flop_counts(self) -> Dict[str, Dict[Any, int]]:
        """Returns the flop counts as a dictionary of dictionaries. The outer
        dictionary is keyed by module name, and the inner dictionary is keyed by
        operation name.

        Returns:
            Dict[str, Dict[Any, int]]: The flop counts as a dictionary.
        """
        return dict(self.flop_counts)

    def get_table(self, depth=None):
        if depth is None:
            depth = self.depth
        if depth is None:
            depth = 999999
        import tabulate

        tabulate.PRESERVE_WHITESPACE = True
        header = ["Module", "FLOP", "% Total"]
        values = []
        global_flops = self.get_total_flops()
        global_suffix = get_suffix_str(global_flops)
        is_global_subsumed = False

        def process_mod(mod_name, depth):
            nonlocal is_global_subsumed
            total_flops = sum(self.flop_counts[mod_name].values())
            is_global_subsumed |= total_flops >= global_flops
            padding = " " * depth
            values = []
            values.append(
                [
                    padding + mod_name,
                    convert_num_with_suffix(total_flops, global_suffix),
                    "{:.2f}%".format(total_flops / global_flops * 100),
                ]
            )
            for k, v in self.flop_counts[mod_name].items():
                values.append(
                    [
                        padding + " - " + str(k),
                        convert_num_with_suffix(v, global_suffix),
                        "{:.2f}%".format(v / global_flops * 100),
                    ]
                )
            return values

        for mod in self.flop_counts.keys():
            if mod == "Global":
                continue
            mod_depth = mod.count(".") + 1
            if mod_depth > depth:
                continue
            cur_values = process_mod(mod, mod_depth - 1)
            for value in cur_values:
                values.append(value)
        # We do a bit of messing around here to only output the "Global" value
        # if there are any FLOPs in there that aren't already fully contained by
        # a module.
        if "Global" in self.flop_counts and not is_global_subsumed:
            for idx, value in enumerate(values):
                values[idx][0] = " " + values[idx][0]
            values = process_mod("Global", 0) + values
        if len(values) == 0:
            values = [["Global", "0", "0%"]]
        return tabulate.tabulate(
            values, headers=header, colalign=("left", "right", "right")
        )

    def __enter__(self):
        self.flop_counts.clear()
        super().__enter__()
        return self

    def __exit__(self, *args):
        if self.display:
            if self.rank is None or self.rank == 0:
                print(self.get_table(self.depth))
        super().__exit__(*args)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}
        out = func(*args, **kwargs)
        func_packet = func._overloadpacket
        if func_packet in self.flop_mapping:
            flop_count_func = self.flop_mapping[func_packet]
            args, kwargs, out_shape = tree_map(get_shape, (args, kwargs, out))
            flop_count = flop_count_func(*args, **kwargs, out_shape=out_shape)  # type: ignore[operator]
            for par in self.parents:
                self.flop_counts[par][func_packet] += flop_count
        return out
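

# A minimal usage sketch (not part of the original file): the model, shapes,
# and the use of the LOCAL_RANK environment variable below are illustrative
# assumptions, but the FlopCounterMode calls match the class defined above.
if __name__ == "__main__":
    import os

    import torch.nn as nn

    # In a distributed launch (e.g. torchrun), pass the rank so only rank 0
    # prints the table; in a single-process run this stays 0.
    rank = int(os.environ.get("LOCAL_RANK", 0))

    mod = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))
    flop_counter = FlopCounterMode(mod, depth=2, rank=rank)
    with flop_counter:
        # Forward and backward both run under the dispatch mode, so matmul
        # FLOPs from aten.addmm/aten.mm are counted in both directions.
        mod(torch.randn(32, 64)).sum().backward()

    # Programmatic access, independent of the printed table.
    print("total flops:", flop_counter.get_total_flops())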