model.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import math
from dataclasses import dataclass
from typing import Optional, Tuple

import fairscale.nn.model_parallel.initialize as fs_init
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.layers import (
    ColumnParallelLinear,
    ParallelEmbedding,
    RowParallelLinear,
)
from torch import nn


@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5

    max_batch_size: int = 32
    max_seq_len: int = 2048
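
# Example (added sketch, not part of the original file): a deliberately tiny, hypothetical
# configuration for single-device smoke tests. The field names are the real ModelArgs fields;
# the values are arbitrary and only chosen to keep the model small.
#
#   debug_args = ModelArgs(
#       dim=64,
#       n_layers=2,
#       n_heads=4,
#       n_kv_heads=2,        # fewer KV heads than query heads exercises grouped-query attention
#       vocab_size=256,      # must be set explicitly; the default of -1 is a placeholder
#       max_batch_size=2,
#       max_seq_len=128,
#   )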


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.

        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.

        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.

        """
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
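
# Worked example (added for clarity, not part of the original file): RMSNorm divides each
# feature vector by its root-mean-square and then applies the learned per-dimension gain.
# For x = [3.0, 4.0] and eps ~ 0: mean(x**2) = (9 + 16) / 2 = 12.5, rms = sqrt(12.5) ~ 3.5355,
# so the output is roughly [0.8485, 1.1314] while `weight` is still all ones.
#
#   norm = RMSNorm(dim=2)
#   y = norm(torch.tensor([[3.0, 4.0]]))   # ~ tensor([[0.8485, 1.1314]])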


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and the end index 'end'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Args:
        dim (int): Dimension of the frequency tensor.
        end (int): End index for precomputing frequencies.
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.

    Returns:
        torch.Tensor: Precomputed frequency tensor with complex exponentials.

    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis
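
# Example (added sketch, not part of the original file): for head dimension 8 and 4 positions,
# the result pairs every position t with dim // 2 = 4 rotation angles t / theta**(2i / dim),
# stored as unit-magnitude complex numbers exp(1j * angle).
#
#   fc = precompute_freqs_cis(dim=8, end=4)
#   assert fc.shape == (4, 4) and fc.dtype == torch.complex64
#   assert torch.allclose(fc[0], torch.ones(4, dtype=torch.complex64))   # position 0: no rotation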


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    Reshape frequency tensor for broadcasting it with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    Args:
        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.

    Returns:
        torch.Tensor: Reshaped frequency tensor.

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.

    """
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)
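
# Example (added for clarity, not part of the original file): with a complex query tensor of
# shape (bsz, seqlen, n_heads, head_dim // 2) and freqs_cis of shape (seqlen, head_dim // 2),
# the returned view has shape (1, seqlen, 1, head_dim // 2), which broadcasts against the query
# in the element-wise complex multiplication inside apply_rotary_emb below.
#
#   fc = precompute_freqs_cis(dim=8, end=16)[:5]           # (seqlen=5, head_dim // 2 = 4)
#   xq_ = torch.zeros(2, 5, 3, 4, dtype=torch.complex64)   # (bsz, seqlen, n_heads, head_dim // 2)
#   assert reshape_for_broadcast(fc, xq_).shape == (1, 5, 1, 4)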


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings.
        xk (torch.Tensor): Key tensor to apply rotary embeddings.
        freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.

    """
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
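
# Example (added sketch, not part of the original file): rotary embeddings rotate each
# consecutive (even, odd) pair of query/key features by a position-dependent angle, so the
# shapes are unchanged and per-head vector norms are preserved.
#
#   fc = precompute_freqs_cis(dim=8, end=16)[:5]   # 5 positions, head_dim=8
#   xq = torch.randn(2, 5, 4, 8)                   # (bsz, seqlen, n_heads, head_dim)
#   xk = torch.randn(2, 5, 2, 8)                   # fewer KV heads is fine
#   q_rot, k_rot = apply_rotary_emb(xq, xk, freqs_cis=fc)
#   assert q_rot.shape == xq.shape and k_rot.shape == xk.shape
#   assert torch.allclose(q_rot.norm(dim=-1), xq.norm(dim=-1), atol=1e-5)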


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )
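
# Example (added for clarity, not part of the original file): with 2 KV heads each repeated
# 4 times, the cached keys/values are expanded to match 8 query heads. The result equals
# torch.repeat_interleave(x, repeats=n_rep, dim=2), but the expand avoids copying data until
# the final reshape.
#
#   kv = torch.randn(1, 16, 2, 64)       # (bs, seqlen, n_kv_heads, head_dim)
#   expanded = repeat_kv(kv, n_rep=4)    # (1, 16, 8, 64)
#   assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=4, dim=2))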


class Attention(nn.Module):
    """Multi-head attention module."""

    def __init__(self, args: ModelArgs):
        """
        Initialize the Attention module.

        Args:
            args (ModelArgs): Model configuration parameters.

        Attributes:
            n_kv_heads (int): Number of key and value heads.
            n_local_heads (int): Number of local query heads.
            n_local_kv_heads (int): Number of local key and value heads.
            n_rep (int): Number of repetitions for local heads.
            head_dim (int): Dimension size of each attention head.
            wq (ColumnParallelLinear): Linear transformation for queries.
            wk (ColumnParallelLinear): Linear transformation for keys.
            wv (ColumnParallelLinear): Linear transformation for values.
            wo (RowParallelLinear): Linear transformation for output.
            cache_k (torch.Tensor): Cached keys for attention.
            cache_v (torch.Tensor): Cached values for attention.

        """
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """
        Forward pass of the attention module.

        Args:
            x (torch.Tensor): Input tensor.
            start_pos (int): Starting position for caching.
            freqs_cis (torch.Tensor): Precomputed frequency tensor.
            mask (torch.Tensor, optional): Attention mask tensor.

        Returns:
            torch.Tensor: Output tensor after attention.

        """
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(keys, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        values = values.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)
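
# Usage note (added sketch, not part of the original file): during generation this module is
# called once per decoding step with an advancing start_pos, so the new keys/values are written
# into cache_k / cache_v and attention runs over every cached position. Constructing it assumes
# an initialized fairscale model-parallel group and a CUDA device (the caches call .cuda()).
# With the hypothetical debug_args from above, single-token decoding looks roughly like:
#
#   attn = Attention(debug_args).cuda()   # or set a CUDA default tensor type before construction
#   head_dim = debug_args.dim // debug_args.n_heads
#   fc = precompute_freqs_cis(head_dim, debug_args.max_seq_len).cuda()
#   y0 = attn(torch.randn(1, 1, debug_args.dim).cuda(), start_pos=0, freqs_cis=fc[0:1], mask=None)
#   y1 = attn(torch.randn(1, 1, debug_args.dim).cuda(), start_pos=1, freqs_cis=fc[1:2], mask=None)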


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        """
        Initialize the FeedForward module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
            ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.

        Attributes:
            w1 (ColumnParallelLinear): Linear transformation for the first layer.
            w2 (RowParallelLinear): Linear transformation for the second layer.
            w3 (ColumnParallelLinear): Linear transformation for the third layer.

        """
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )
        self.w2 = RowParallelLinear(
            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
        )
        self.w3 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
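
# Worked example (added for clarity, not part of the original file): with the default 7B-style
# settings (dim=4096, hidden_dim=4 * 4096 = 16384, multiple_of=256, ffn_dim_multiplier=None),
# the SwiGLU hidden size becomes int(2 * 16384 / 3) = 10922, rounded up to the next multiple
# of 256, which gives 11008, the feed-forward width used by Llama 2 7B.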


class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        """
        Initialize a TransformerBlock.

        Args:
            layer_id (int): Identifier for the layer.
            args (ModelArgs): Model configuration parameters.

        Attributes:
            n_heads (int): Number of attention heads.
            dim (int): Dimension size of the model.
            head_dim (int): Dimension size of each attention head.
            attention (Attention): Attention module.
            feed_forward (FeedForward): FeedForward module.
            layer_id (int): Identifier for the layer.
            attention_norm (RMSNorm): Layer normalization applied to the attention input (pre-normalization).
            ffn_norm (RMSNorm): Layer normalization applied to the feed-forward input (pre-normalization).

        """
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """
        Perform a forward pass through the TransformerBlock.

        Args:
            x (torch.Tensor): Input tensor.
            start_pos (int): Starting position for attention caching.
            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
            mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None.

        Returns:
            torch.Tensor: Output tensor after applying attention and feedforward layers.

        """
        h = x + self.attention.forward(
            self.attention_norm(x), start_pos, freqs_cis, mask
        )
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out
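
# Note (added for clarity, not part of the original file): each block uses the pre-normalization
# residual layout
#
#   h   = x + Attention(RMSNorm(x))
#   out = h + FeedForward(RMSNorm(h))
#
# with the feed-forward width derived from 4 * dim as in the worked example above.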


class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        """
        Initialize a Transformer model.

        Args:
            params (ModelArgs): Model configuration parameters.

        Attributes:
            params (ModelArgs): Model configuration parameters.
            vocab_size (int): Vocabulary size.
            n_layers (int): Number of layers in the model.
            tok_embeddings (ParallelEmbedding): Token embeddings.
            layers (torch.nn.ModuleList): List of Transformer blocks.
            norm (RMSNorm): Layer normalization for the model output.
            output (ColumnParallelLinear): Linear layer for final output.
            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.

        """
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = ParallelEmbedding(
            params.vocab_size, params.dim, init_method=lambda x: x
        )

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        )

        self.freqs_cis = precompute_freqs_cis(
            # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2
            # generation of models is 4096. Adding this multiplier instead of using 4096 directly allows for
            # dynamism of token lengths while training or fine-tuning.
            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
        )

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        """
        Perform a forward pass through the Transformer model.

        Args:
            tokens (torch.Tensor): Input token indices.
            start_pos (int): Starting position for attention caching.

        Returns:
            torch.Tensor: Output logits after applying the Transformer model.

        """
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = None
        if seqlen > 1:
            mask = torch.full(
                (seqlen, seqlen), float("-inf"), device=tokens.device
            )
            mask = torch.triu(mask, diagonal=1)

            # When performing key-value caching, we compute the attention scores
            # only for the new sequence. Thus, the matrix of scores is of size
            # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
            # j > cache_len + i, since row i corresponds to token cache_len + i.
            mask = torch.hstack([
                torch.zeros((seqlen, start_pos), device=tokens.device),
                mask
            ]).type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)
        output = self.output(h).float()
        return output
  401. return output