convert_llama_weights_from_hf.py

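"""Convert a Llama checkpoint from Hugging Face format back into the
consolidated ``consolidated.XX.pth`` shard format used by the original Meta
reference implementation.

The output directory must already contain the matching ``params.json``.
"""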
import json
import os
from typing import List, Union

import click
import torch
from tqdm import tqdm
from transformers import LlamaForCausalLM  # @manual
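
# Number of consolidated checkpoint shards written per model size
# (one file per model-parallel rank); "f" sizes are the fine-tuned variants.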
NUM_SHARDS = {
    "7B": 1,
    "7Bf": 1,
    "13B": 2,
    "13Bf": 2,
    "34B": 4,
    "30B": 4,
    "65B": 8,
    "70B": 8,
    "70Bf": 8,
}

def read_json(path):
    with open(path, "r") as f:
        return json.load(f)

def write_model(model_path, model_size, output_base_path):
    dtype = torch.bfloat16
    params_path = os.path.join(output_base_path, "params.json")
    assert os.path.isfile(params_path), f"{params_path} does not exist"
    params = read_json(params_path)
    num_shards = NUM_SHARDS[model_size]
    n_layers = params["n_layers"]
    n_heads = params["n_heads"]
    n_heads_per_shard = n_heads // num_shards
    dim = params["dim"]
    dims_per_head = dim // n_heads
    base = 10000.0
    inv_freq = (
        1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
    ).to(dtype)
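
    # Checkpoints with grouped-query / multi-query attention declare a separate
    # n_kv_heads; older checkpoints use one KV head per query head.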
    if "n_kv_heads" in params:
        num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
        num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
        key_value_dim = dim // num_key_value_heads
    else:  # compatibility with other checkpoints
        num_key_value_heads = n_heads
        num_local_key_value_heads = n_heads_per_shard
        key_value_dim = dim

    model = LlamaForCausalLM.from_pretrained(
        model_path,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )
    loaded = model.state_dict()
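
    # The Hugging Face conversion stores q/k projection rows in a half-split
    # rotary layout; this permutation restores the interleaved ordering used by
    # the original checkpoints (inverse of the permute in the HF converter).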
    # permute for sliced rotary
    def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
        return (
            w.view(n_heads, 2, dim1 // n_heads // 2, dim2)
            .transpose(1, 2)
            .reshape(dim1, dim2)
        )

    state_dict = [{} for _ in range(num_shards)]

    def insert(name: str, tensor: Union[List, torch.Tensor]):
        for i in range(num_shards):
            state_dict[i][name] = (
                tensor[i].clone() if isinstance(tensor, list) else tensor
            )

    def insert_chunk(name: str, tensor: torch.Tensor, dim: int):
        tensors = tensor.chunk(num_shards, dim=dim)
        for i, tensor in enumerate(tensors):
            state_dict[i][name] = tensor.clone()
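
    # Non-layer tensors: token embeddings are split along the embedding
    # dimension, the output head along the vocab dimension, and the final
    # norm is replicated on every shard.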
    insert_chunk("tok_embeddings.weight", loaded["model.embed_tokens.weight"], 1)
    insert("norm.weight", loaded["model.norm.weight"])
    insert_chunk("output.weight", loaded["lm_head.weight"], 0)
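
    # Per-layer conversion: rename each HF parameter to its original Meta name
    # and split it across the model-parallel shards.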
    for layer_i in tqdm(range(n_layers), desc="Converting layers"):
        ts = (
            permute(loaded[f"model.layers.{layer_i}.self_attn.q_proj.weight"])
            .view(n_heads_per_shard * num_shards, dims_per_head, dim)
            .chunk(num_shards, dim=0)
        )
        insert(f"layers.{layer_i}.attention.wq.weight", [t.view(-1, dim) for t in ts])
        ts = (
            permute(
                loaded[f"model.layers.{layer_i}.self_attn.k_proj.weight"],
                num_key_value_heads,
                key_value_dim,
                dim,
            )
            .view(num_local_key_value_heads * num_shards, dims_per_head, dim)
            .chunk(num_shards, dim=0)
        )
        insert(f"layers.{layer_i}.attention.wk.weight", [t.view(-1, dim) for t in ts])
        ts = (
            loaded[f"model.layers.{layer_i}.self_attn.v_proj.weight"]
            .view(num_local_key_value_heads * num_shards, dims_per_head, dim)
            .chunk(num_shards, dim=0)
        )
        insert(f"layers.{layer_i}.attention.wv.weight", [t.view(-1, dim) for t in ts])
        insert_chunk(
            f"layers.{layer_i}.attention.wo.weight",
            loaded[f"model.layers.{layer_i}.self_attn.o_proj.weight"],
            1,
        )
        insert_chunk(
            f"layers.{layer_i}.feed_forward.w1.weight",
            loaded[f"model.layers.{layer_i}.mlp.gate_proj.weight"],
            0,
        )
        insert_chunk(
            f"layers.{layer_i}.feed_forward.w2.weight",
            loaded[f"model.layers.{layer_i}.mlp.down_proj.weight"],
            1,
        )
        insert_chunk(
            f"layers.{layer_i}.feed_forward.w3.weight",
            loaded[f"model.layers.{layer_i}.mlp.up_proj.weight"],
            0,
        )
        insert(
            f"layers.{layer_i}.attention_norm.weight",
            loaded[f"model.layers.{layer_i}.input_layernorm.weight"],
        )
        insert(
            f"layers.{layer_i}.ffn_norm.weight",
            loaded[f"model.layers.{layer_i}.post_attention_layernorm.weight"],
        )
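
    # The rotary inverse-frequency table is replicated on every shard; each
    # shard is then written out as consolidated.{i:02d}.pth.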
    insert("rope.freqs", inv_freq)

    for i in tqdm(range(num_shards), desc="Saving checkpoint shards"):
        torch.save(
            state_dict[i], os.path.join(output_base_path, f"consolidated.{i:02d}.pth")
        )

@click.command()
@click.option(
    "--model-path",
    type=str,
    default="meta-llama/Llama-2-7b-chat-hf",
    help="Model name or path to the model directory.",
)
@click.option(
    "--model-size",
    type=click.Choice(
        [
            "7B",
            "7Bf",
            "13B",
            "13Bf",
            "30B",
            "34B",
            "65B",
            "70B",
            "70Bf",
        ]
    ),
    default="7Bf",
    help="Llama model size; 'f' sizes correspond to the fine-tuned versions.",
)
@click.option(
    "--output-dir",
    type=str,
    required=True,
    help="Directory to save the consolidated Llama weights; must already contain params.json.",
)
def main(model_path: str, model_size: str, output_dir: str):
    """Convert Llama weights from Hugging Face format to consolidated format."""
    write_model(model_path, model_size, output_dir)


if __name__ == "__main__":
    main()
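
# Example invocation (hypothetical output path; it must already contain the
# matching params.json for the chosen model size):
#   python convert_llama_weights_from_hf.py \
#       --model-path meta-llama/Llama-2-7b-chat-hf \
#       --model-size 7Bf \
#       --output-dir /path/to/llama-2-7b-chat
# For 7Bf this writes a single consolidated.00.pth; larger sizes write one
# shard file per model-parallel rank.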