@@ -2,35 +2,26 @@ import json
import os
from typing import List, Union

-import click
+import fire
import torch
from tqdm import tqdm
from transformers import LlamaForCausalLM  # @manual

NUM_SHARDS = {
    "7B": 1,
-    "7Bf": 1,
    "13B": 2,
-    "13Bf": 2,
    "34B": 4,
    "30B": 4,
    "65B": 8,
    "70B": 8,
-    "70Bf": 8,
}


-def read_json(path):
-    with open(path, "r") as f:
-        return json.load(f)
-
-
def write_model(model_path, model_size, output_base_path):
    dtype = torch.bfloat16

-    params_path = os.path.join(output_base_path, "params.json")
-    assert os.path.isfile(params_path), f"{params_path} does not exist"
-    params = read_json(params_path)
+    with open(os.path.join(output_base_path, "params.json"), "r") as f:
+        params = json.load(f)
    num_shards = NUM_SHARDS[model_size]
    n_layers = params["n_layers"]
    n_heads = params["n_heads"]
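Note: the inline load above reads params.json from output_base_path. A sketch of what json.load is expected to return for the 7B configuration (illustrative values, not taken from this diff):

    # Illustrative shape of params.json for "7B"; values are assumptions.
    # Only "n_layers" and "n_heads" appear in the lines shown above.
    params = {
        "dim": 4096,
        "n_layers": 32,
        "n_heads": 32,
        "norm_eps": 1e-05,
        "vocab_size": -1,
    }
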
@@ -151,41 +142,18 @@ def write_model(model_path, model_size, output_base_path):
    )


-@click.command()
-@click.option(
-    "--model-path",
-    type=str,
-    default="meta-llama/Llama-2-7b-chat-hf",
-    help="Model name or path to the model directory.",
-)
-@click.option(
-    "--model-size",
-    type=click.Choice(
-        [
-            "7B",
-            "7Bf",
-            "13B",
-            "13Bf",
-            "30B",
-            "34B",
-            "65B",
-            "70B",
-            "70Bf",
-        ]
-    ),
-    default="7Bf",
-    help="llama model size, f' models correspond to the finetuned versions.",
-)
-@click.option(
-    "--output-dir",
-    type=str,
-    required=True,
-    help="Save Llama weights. Should already contains params.json",
-)
-def main(model_path: str, model_size: str, output_dir: str):
+def main(
+    model_path: str,  # Model name or path to the model directory.
+    model_size: str,  # Llama model size; must be a key of NUM_SHARDS.
+    output_dir: str,  # Where to save Llama weights; must already contain params.json.
+):
"""Convert llama huggingface format to consolidated weights."""
|
|
|
+ assert model_size in NUM_SHARDS, f"Unknown model size {model_size}"
|
|
|
+ params_path = os.path.join(output_dir, "params.json")
|
|
|
+ assert os.path.isfile(params_path), f"{params_path} does not exist"
|
|
|
+
|
|
|
    write_model(model_path, model_size, output_dir)


if __name__ == "__main__":
-    main()
+    fire.Fire(main)
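With fire.Fire(main), the keyword arguments of main become command-line flags. A minimal usage sketch (script name and paths are hypothetical):

    # Hypothetical filename and paths, shown for illustration only.
    python convert_hf_to_consolidated.py \
        --model_path meta-llama/Llama-2-7b-hf \
        --model_size 7B \
        --output_dir /path/to/llama-7b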