Forráskód Böngészése

Add example conversion script to convert hf to consolidated weight format

Dong Wang 1 éve
szülő
commit
e755ed1d8f

+ 22 - 0
examples/hf_llama_conversion/README.md

@@ -0,0 +1,22 @@
+# Convert huggingface llama weights to official llama consolidated format
+
+This is the reverse conversion for `convert_llama_weights_to_hf.py` script from the transformer package.
+
+## Step 0: Convert to consolidated format
+- Create an output directory for the converted weights, eg `test70Bf`.
+- Copy file params.json from the official llama download into that directory.
+- Run the conversion script. `model-path` can be a huggingface hub model or a local hf model directory.
+```
+python convert_llama_weights_from_hf.py --model-path meta-llama/Llama-2-70b-chat-hf --output-dir test70Bf --model-size 70Bf
+```
+
+## Step 1: Run inference
+Checkout the offical llama inference [repo](https://github.com/facebookresearch/llama). Test using chat or text completion.
+```
+torchrun --nproc_per_node 8 example_chat_completion.py --ckpt_dir ./test70Bf --tokenizer_path ${llama_2_dir}/tokenizer.model
+```
+
+For validation, please compare the converted weights with official llama 2 weights
+```
+python compare_llama_weights.py test70Bf ${llama_2_70b_chat_dir}
+```

+ 48 - 0
examples/hf_llama_conversion/compare_llama_weights.py

@@ -0,0 +1,48 @@
+import gc
+import glob
+import os
+import sys
+
+import torch
+import tqdm
+
+
+def main() -> None:
+    """Compare two llama checkpoint directories"""
+
+    one_files = sorted(glob.glob(os.path.join(sys.argv[1], "consolidated.*.pth")))
+    two_files = sorted(glob.glob(os.path.join(sys.argv[2], "consolidated.*.pth")))
+    assert len(one_files) == len(
+        two_files
+    ), "One directory has {} files while another has {} files.".format(
+        len(one_files), len(two_files)
+    )
+
+    deltas = []
+    for i in tqdm.trange(len(one_files), desc="Comparing shards"):
+        one = torch.load(one_files[i])
+        two = torch.load(two_files[i])
+        assert len(one) == len(
+            two
+        ), "shard should have the same length: {} != {}".format(len(one), len(two))
+
+        for _, (v, w) in enumerate(zip(one.items(), two.items())):
+            assert v[0] == w[0], "{} != {}".format(v[0], w[0])
+            assert v[1].shape == w[1].shape, "tensor {} shape {} != {}".format(
+                v[0], v[1].shape, w[1].shape
+            )
+
+            delta = (v[1] - w[1]).abs().max().item()
+            deltas.append((i, v[0], delta))
+        del one
+        del two
+        gc.collect()
+
+    deltas = sorted(deltas, key=lambda x: x[-1], reverse=True)
+    print("Top 10 largest deltas:")
+    for i, k, v in deltas[:10]:
+        print(f"  shard {i} {k}: {v}")
+
+
+if __name__ == "__main__":
+    main()

+ 191 - 0
examples/hf_llama_conversion/convert_llama_weights_from_hf.py

@@ -0,0 +1,191 @@
+import json
+import os
+from typing import List, Union
+
+import click
+import torch
+from tqdm import tqdm
+from transformers import LlamaForCausalLM  # @manual
+
+NUM_SHARDS = {
+    "7B": 1,
+    "7Bf": 1,
+    "13B": 2,
+    "13Bf": 2,
+    "34B": 4,
+    "30B": 4,
+    "65B": 8,
+    "70B": 8,
+    "70Bf": 8,
+}
+
+
+def read_json(path):
+    with open(path, "r") as f:
+        return json.load(f)
+
+
+def write_model(model_path, model_size, output_base_path):
+    dtype = torch.bfloat16
+
+    params_path = os.path.join(output_base_path, "params.json")
+    assert os.path.isfile(params_path), f"{params_path} does not exist"
+    params = read_json(params_path)
+    num_shards = NUM_SHARDS[model_size]
+    n_layers = params["n_layers"]
+    n_heads = params["n_heads"]
+    n_heads_per_shard = n_heads // num_shards
+    dim = params["dim"]
+    dims_per_head = dim // n_heads
+    base = 10000.0
+    inv_freq = (
+        1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
+    ).to(dtype)
+
+    if "n_kv_heads" in params:
+        num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
+        num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
+        key_value_dim = dim // num_key_value_heads
+    else:  # compatibility with other checkpoints
+        num_key_value_heads = n_heads
+        num_local_key_value_heads = n_heads_per_shard
+        key_value_dim = dim
+
+    model = LlamaForCausalLM.from_pretrained(
+        model_path,
+        torch_dtype=dtype,
+        low_cpu_mem_usage=True,
+    )
+    loaded = model.state_dict()
+
+    # permute for sliced rotary
+    def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
+        return (
+            w.view(n_heads, 2, dim1 // n_heads // 2, dim2)
+            .transpose(1, 2)
+            .reshape(dim1, dim2)
+        )
+
+    state_dict = [{} for _ in range(num_shards)]
+
+    def insert(name: str, tensor: Union[List, torch.Tensor]):
+        for i in range(num_shards):
+            state_dict[i][name] = (
+                tensor[i].clone() if isinstance(tensor, list) else tensor
+            )
+
+    def insert_chunk(name: str, tensor: torch.Tensor, dim: int):
+        tensors = tensor.chunk(num_shards, dim=dim)
+        for i, tensor in enumerate(tensors):
+            state_dict[i][name] = tensor.clone()
+
+    insert_chunk("tok_embeddings.weight", loaded["model.embed_tokens.weight"], 1)
+    insert("norm.weight", loaded["model.norm.weight"])
+    insert_chunk("output.weight", loaded["lm_head.weight"], 0)
+
+    for layer_i in tqdm(range(n_layers), desc="Converting layers"):
+
+        ts = (
+            permute(loaded[f"model.layers.{layer_i}.self_attn.q_proj.weight"])
+            .view(n_heads_per_shard * num_shards, dims_per_head, dim)
+            .chunk(num_shards, dim=0)
+        )
+        insert(f"layers.{layer_i}.attention.wq.weight", [t.view(-1, dim) for t in ts])
+
+        ts = (
+            permute(
+                loaded[f"model.layers.{layer_i}.self_attn.k_proj.weight"],
+                num_key_value_heads,
+                key_value_dim,
+                dim,
+            )
+            .view(num_local_key_value_heads * num_shards, dims_per_head, dim)
+            .chunk(num_shards, dim=0)
+        )
+        insert(f"layers.{layer_i}.attention.wk.weight", [t.view(-1, dim) for t in ts])
+
+        ts = (
+            loaded[f"model.layers.{layer_i}.self_attn.v_proj.weight"]
+            .view(num_local_key_value_heads * num_shards, dims_per_head, dim)
+            .chunk(num_shards, dim=0)
+        )
+        insert(f"layers.{layer_i}.attention.wv.weight", [t.view(-1, dim) for t in ts])
+
+        insert_chunk(
+            f"layers.{layer_i}.attention.wo.weight",
+            loaded[f"model.layers.{layer_i}.self_attn.o_proj.weight"],
+            1,
+        )
+
+        insert_chunk(
+            f"layers.{layer_i}.feed_forward.w1.weight",
+            loaded[f"model.layers.{layer_i}.mlp.gate_proj.weight"],
+            0,
+        )
+
+        insert_chunk(
+            f"layers.{layer_i}.feed_forward.w2.weight",
+            loaded[f"model.layers.{layer_i}.mlp.down_proj.weight"],
+            1,
+        )
+
+        insert_chunk(
+            f"layers.{layer_i}.feed_forward.w3.weight",
+            loaded[f"model.layers.{layer_i}.mlp.up_proj.weight"],
+            0,
+        )
+
+        insert(
+            f"layers.{layer_i}.attention_norm.weight",
+            loaded[f"model.layers.{layer_i}.input_layernorm.weight"],
+        )
+        insert(
+            f"layers.{layer_i}.ffn_norm.weight",
+            loaded[f"model.layers.{layer_i}.post_attention_layernorm.weight"],
+        )
+    insert("rope.freqs", inv_freq)
+
+    for i in tqdm(range(num_shards), desc="Saving checkpoint shards"):
+        torch.save(
+            state_dict[i], os.path.join(output_base_path, f"consolidated.{i:02d}.pth")
+        )
+
+
+@click.command()
+@click.option(
+    "--model-path",
+    type=str,
+    default="meta-llama/Llama-2-7b-chat-hf",
+    help="Model name or path to the model directory.",
+)
+@click.option(
+    "--model-size",
+    type=click.Choice(
+        [
+            "7B",
+            "7Bf",
+            "13B",
+            "13Bf",
+            "30B",
+            "34B",
+            "65B",
+            "70B",
+            "70Bf",
+        ]
+    ),
+    default="7Bf",
+    help="llama model size, f' models correspond to the finetuned versions.",
+)
+@click.option(
+    "--output-dir",
+    type=str,
+    required=True,
+    help="Save Llama weights. Should already contains params.json",
+)
+def main(model_path: str, model_size: str, output_dir: str):
+    """Convert llama huggingface format to consolidated weights."""
+    write_model(model_path, model_size, output_dir)
+
+
+if __name__ == "__main__":
+    main()