
update for comments: use module

Dong Wang, 1 year ago · commit 35acf4934e

examples/hf_llama_conversion/README.md (+4 -4)

@@ -3,20 +3,20 @@
 This is the reverse conversion of the `convert_llama_weights_to_hf.py` script from the transformers package.
 
 ## Step 0: Convert to consolidated format
-- Create an output directory for the converted weights, such as `test70Bf`.
+- Create an output directory for the converted weights, such as `test70B`.
 - Copy the `params.json` file from the official llama download into that directory.
 - Run the conversion script. `model-path` can be a Hugging Face Hub model name or a local HF model directory.
 ```
-python convert_llama_weights_from_hf.py --model-path meta-llama/Llama-2-70b-chat-hf --output-dir test70Bf --model-size 70Bf
+python -m llama_recipes.tools.convert_hf_weights_to_llama --model-path meta-llama/Llama-2-70b-chat-hf --output-dir test70B --model-size 70B
 ```
 
 ## Step 1: Run inference
 Check out the official llama inference [repo](https://github.com/facebookresearch/llama) and test using chat or text completion.
 ```
-torchrun --nproc_per_node 8 example_chat_completion.py --ckpt_dir ./test70Bf --tokenizer_path ${llama_2_dir}/tokenizer.model
+torchrun --nproc_per_node 8 example_chat_completion.py --ckpt_dir ./test70B --tokenizer_path ${llama_2_dir}/tokenizer.model
 ```
 
 To validate, compare the converted weights with the official Llama 2 weights:
 ```
-python compare_llama_weights.py test70Bf ${llama_2_70b_chat_dir}
+python compare_llama_weights.py test70B ${llama_2_70b_chat_dir}
 ```
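For the validation step, here is a minimal sketch of the kind of per-tensor check such a comparison can perform. This is not the repository's `compare_llama_weights.py`; the shard filename pattern `consolidated.NN.pth` and the single-shard default are assumptions.

```
# Minimal sketch, NOT the repo's compare_llama_weights.py.
# Assumes both directories hold shards named consolidated.NN.pth.
import sys
import torch

def compare_shard(dir_a: str, dir_b: str, shard: int = 0) -> None:
    name = f"consolidated.{shard:02d}.pth"
    a = torch.load(f"{dir_a}/{name}", map_location="cpu")
    b = torch.load(f"{dir_b}/{name}", map_location="cpu")
    assert a.keys() == b.keys(), "shards disagree on tensor names"
    for key in sorted(a):
        # Report the largest element-wise deviation for this tensor.
        delta = (a[key].float() - b[key].float()).abs().max().item()
        print(f"{key}: max abs diff = {delta:.6g}")

if __name__ == "__main__":
    compare_shard(sys.argv[1], sys.argv[2])
```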

examples/hf_llama_conversion/convert_llama_weights_from_hf.py (+12 -45)

@@ -2,35 +2,25 @@ import json
 import os
 from typing import List, Union
 
-import click
+import fire
 import torch
 from tqdm import tqdm
 from transformers import LlamaForCausalLM  # @manual
 
 NUM_SHARDS = {
     "7B": 1,
-    "7Bf": 1,
     "13B": 2,
-    "13Bf": 2,
     "34B": 4,
     "30B": 4,
     "65B": 8,
     "70B": 8,
-    "70Bf": 8,
 }
 
 
-def read_json(path):
-    with open(path, "r") as f:
-        return json.load(f)
-
-
 def write_model(model_path, model_size, output_base_path):
     dtype = torch.bfloat16
 
-    params_path = os.path.join(output_base_path, "params.json")
-    assert os.path.isfile(params_path), f"{params_path} does not exist"
-    params = read_json(params_path)
+    params = json.load(open(os.path.join(output_base_path, "params.json"), "r"))
     num_shards = NUM_SHARDS[model_size]
     n_layers = params["n_layers"]
     n_heads = params["n_heads"]
@@ -151,41 +141,18 @@ def write_model(model_path, model_size, output_base_path):
         )
 
 
-@click.command()
-@click.option(
-    "--model-path",
-    type=str,
-    default="meta-llama/Llama-2-7b-chat-hf",
-    help="Model name or path to the model directory.",
-)
-@click.option(
-    "--model-size",
-    type=click.Choice(
-        [
-            "7B",
-            "7Bf",
-            "13B",
-            "13Bf",
-            "30B",
-            "34B",
-            "65B",
-            "70B",
-            "70Bf",
-        ]
-    ),
-    default="7Bf",
-    help="llama model size, f' models correspond to the finetuned versions.",
-)
-@click.option(
-    "--output-dir",
-    type=str,
-    required=True,
-    help="Save Llama weights. Should already contains params.json",
-)
-def main(model_path: str, model_size: str, output_dir: str):
+def main(
+    model_path: str,  # Model name or path to the model directory.
+    model_size: str,  # Llama model size; must be a key of NUM_SHARDS.
+    output_dir: str,  # Output directory; should already contain params.json.
+):
     """Convert llama huggingface format to consolidated weights."""
+    assert model_size in NUM_SHARDS, f"Unknown model size {model_size}"
+    params_path = os.path.join(output_dir, "params.json")
+    assert os.path.isfile(params_path), f"{params_path} does not exist"
+
     write_model(model_path, model_size, output_dir)
 
 
 if __name__ == "__main__":
-    main()
+    fire.Fire(main)
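
Since `main` is now handed to `fire.Fire`, its three parameters become the command-line interface directly; python-fire accepts both hyphenated and underscored flag spellings, and positional arguments work as well. A sample invocation (directory name illustrative):

```
python convert_llama_weights_from_hf.py \
    --model-path meta-llama/Llama-2-7b-chat-hf \
    --model-size 7B \
    --output-dir test7B
```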