# convert_hf_weights_to_llama.py
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import json
  4. import os
  5. from typing import List, Union
  6. import fire
  7. import torch
  8. from tqdm import tqdm
  9. from transformers import LlamaForCausalLM # @manual
# Number of consolidated checkpoint shards (model-parallel partitions) to emit
# for each Llama model size; matches the sharding of Meta's released checkpoints.
NUM_SHARDS = {
    "7B": 1,
    "13B": 2,
    "34B": 4,
    "30B": 4,
    "65B": 8,
    "70B": 8,
}
  18. def write_model(model_path, model_size, output_base_path):
  19. dtype = torch.bfloat16
  20. params = json.load(open(os.path.join(output_base_path, "params.json"), "r"))
  21. num_shards = NUM_SHARDS[model_size]
  22. n_layers = params["n_layers"]
  23. n_heads = params["n_heads"]
  24. n_heads_per_shard = n_heads // num_shards
  25. dim = params["dim"]
  26. dims_per_head = dim // n_heads
  27. base = 10000.0
  28. inv_freq = (
  29. 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
  30. ).to(dtype)
  31. if "n_kv_heads" in params:
  32. num_key_value_heads = params["n_kv_heads"] # for GQA / MQA
  33. num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
  34. key_value_dim = dim // num_key_value_heads
  35. else: # compatibility with other checkpoints
  36. num_key_value_heads = n_heads
  37. num_local_key_value_heads = n_heads_per_shard
  38. key_value_dim = dim
  39. model = LlamaForCausalLM.from_pretrained(
  40. model_path,
  41. torch_dtype=dtype,
  42. low_cpu_mem_usage=True,
  43. )
  44. loaded = model.state_dict()
  45. # permute for sliced rotary
  46. def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
  47. return (
  48. w.view(n_heads, 2, dim1 // n_heads // 2, dim2)
  49. .transpose(1, 2)
  50. .reshape(dim1, dim2)
  51. )
  52. state_dict = [{} for _ in range(num_shards)]
  53. def insert(name: str, tensor: Union[List, torch.Tensor]):
  54. for i in range(num_shards):
  55. state_dict[i][name] = (
  56. tensor[i].clone() if isinstance(tensor, list) else tensor
  57. )
  58. def insert_chunk(name: str, tensor: torch.Tensor, dim: int):
  59. tensors = tensor.chunk(num_shards, dim=dim)
  60. for i, tensor in enumerate(tensors):
  61. state_dict[i][name] = tensor.clone()
  62. insert_chunk("tok_embeddings.weight", loaded["model.embed_tokens.weight"], 1)
  63. insert("norm.weight", loaded["model.norm.weight"])
  64. insert_chunk("output.weight", loaded["lm_head.weight"], 0)
  65. for layer_i in tqdm(range(n_layers), desc="Converting layers"):
  66. ts = (
  67. permute(loaded[f"model.layers.{layer_i}.self_attn.q_proj.weight"])
  68. .view(n_heads_per_shard * num_shards, dims_per_head, dim)
  69. .chunk(num_shards, dim=0)
  70. )
  71. insert(f"layers.{layer_i}.attention.wq.weight", [t.view(-1, dim) for t in ts])
  72. ts = (
  73. permute(
  74. loaded[f"model.layers.{layer_i}.self_attn.k_proj.weight"],
  75. num_key_value_heads,
  76. key_value_dim,
  77. dim,
  78. )
  79. .view(num_local_key_value_heads * num_shards, dims_per_head, dim)
  80. .chunk(num_shards, dim=0)
  81. )
  82. insert(f"layers.{layer_i}.attention.wk.weight", [t.view(-1, dim) for t in ts])
  83. ts = (
  84. loaded[f"model.layers.{layer_i}.self_attn.v_proj.weight"]
  85. .view(num_local_key_value_heads * num_shards, dims_per_head, dim)
  86. .chunk(num_shards, dim=0)
  87. )
  88. insert(f"layers.{layer_i}.attention.wv.weight", [t.view(-1, dim) for t in ts])
  89. insert_chunk(
  90. f"layers.{layer_i}.attention.wo.weight",
  91. loaded[f"model.layers.{layer_i}.self_attn.o_proj.weight"],
  92. 1,
  93. )
  94. insert_chunk(
  95. f"layers.{layer_i}.feed_forward.w1.weight",
  96. loaded[f"model.layers.{layer_i}.mlp.gate_proj.weight"],
  97. 0,
  98. )
  99. insert_chunk(
  100. f"layers.{layer_i}.feed_forward.w2.weight",
  101. loaded[f"model.layers.{layer_i}.mlp.down_proj.weight"],
  102. 1,
  103. )
  104. insert_chunk(
  105. f"layers.{layer_i}.feed_forward.w3.weight",
  106. loaded[f"model.layers.{layer_i}.mlp.up_proj.weight"],
  107. 0,
  108. )
  109. insert(
  110. f"layers.{layer_i}.attention_norm.weight",
  111. loaded[f"model.layers.{layer_i}.input_layernorm.weight"],
  112. )
  113. insert(
  114. f"layers.{layer_i}.ffn_norm.weight",
  115. loaded[f"model.layers.{layer_i}.post_attention_layernorm.weight"],
  116. )
  117. insert("rope.freqs", inv_freq)
  118. for i in tqdm(range(num_shards), desc="Saving checkpoint shards"):
  119. torch.save(
  120. state_dict[i], os.path.join(output_base_path, f"consolidated.{i:02d}.pth")
  121. )
  122. def main(
  123. model_path: str,
  124. model_size: str,
  125. output_dir: str,
  126. ):
  127. """Convert llama weights from huggingface format to consolidated format.
  128. params:
  129. model_path: model name or path to the model directory.
  130. model_size: Llama model size, one of 7B, 13B, 34B, 30B, 65B, 70B.
  131. output_dir: directory to save Llama weights, should contains params.json.
  132. """
  133. assert model_size in NUM_SHARDS, f"Unknown model size {model_size}"
  134. params_path = os.path.join(output_dir, "params.json")
  135. assert os.path.isfile(params_path), f"{params_path} does not exist"
  136. write_model(model_path, model_size, output_dir)
if __name__ == "__main__":
    # CLI entry point: python-fire turns main()'s parameters into flags.
    fire.Fire(main)