Forráskód Böngészése

Merging with latest from main

Beto 1 éve
szülő
commit
56937a68d7

+ 1 - 1
README.md

@@ -2,7 +2,7 @@
 
 **[Update Dec. 15, 2023] We added support for Llama Guard as a safety checker for our example inference script and also with standalone inference with an example script and prompt formatting. More details [here](./examples/llama_guard/README.md).**
 
-**[Update Nov. 16, 2023] We recently released a series of Llama 2 demo apps [here](./demo_apps). These apps show how to run Llama (locally, in the cloud, or on-prem), how to ask Llama questions in general or about custom data (PDF, DB, or live), how to integrate Llama with WhatsApp, and how to implement an end-to-end chatbot with RAG (Retrieval Augmented Generation).**
+**[Update Dec 14, 2023] We recently released a series of Llama 2 demo apps [here](./demo_apps). These apps show how to run Llama (locally, in the cloud, or on-prem), how to ask Llama questions in general or about custom data (PDF, DB, or live), how to integrate Llama with WhatsApp and Messenger, and how to implement an end-to-end chatbot with RAG (Retrieval Augmented Generation).**
 
 The 'llama-recipes' repository is a companion to the [Llama 2 model](https://github.com/facebookresearch/llama). The goal of this repository is to provide examples to quickly get started with fine-tuning for domain adaptation and how to run inference for the fine-tuned models. For ease of use, the examples use Hugging Face converted versions of the models. See steps for conversion of the model [here](#model-conversion-to-hugging-face).
 

+ 5 - 1
demo_apps/README.md

@@ -13,6 +13,7 @@ This folder contains a series of Llama 2-powered apps:
 2. Ask Llama questions about structured data in a DB
 3. Ask Llama questions about live data on the web
 4. Build a Llama-enabled WhatsApp chatbot
+5. Build a Llama-enabled Messenger chatbot
 
 We also show how to build quick web UI for Llama 2 demo apps using Streamlit and Gradio.
 
@@ -77,7 +78,10 @@ This demo app shows how to use LangChain and Llama2 to let users ask questions a
 This demo app shows how to perform live data augmented generation tasks with Llama2 and [LlamaIndex](https://github.com/run-llama/llama_index), another leading open-source framework for building LLM apps: it uses the [You.com search API](https://documentation.you.com/quickstart) to get live search result and ask Llama2 about them.
 
 ## [WhatsApp Chatbot](whatsapp_llama2.md): Building a Llama-enabled WhatsApp Chatbot
-This step-by-step tutorial shows how to use the [WhatsApp Business API](https://developers.facebook.com/docs/whatsapp/cloud-api/overview), LangChain and Replicate to build a Llama-enabled WhatsApp chatbot.
+This step-by-step tutorial shows how to use the [WhatsApp Business API](https://developers.facebook.com/docs/whatsapp/cloud-api/overview) to build a Llama-enabled WhatsApp chatbot.
+
+## [Messenger Chatbot](messenger_llama2.md): Building a Llama-enabled Messenger Chatbot
+This step-by-step tutorial shows how to use the [Messenger Platform](https://developers.facebook.com/docs/messenger-platform/overview) to build a Llama-enabled Messenger chatbot.
 
 ## Quick Web UI for Llama2 Chat
 If you prefer to see Llama2 in action in a web UI, instead of the notebooks above, you can try one of the two methods:

+ 45 - 0
demo_apps/llama_messenger.py

@@ -0,0 +1,45 @@
+import langchain
+from langchain.llms import Replicate
+
+from flask import Flask
+from flask import request
+import os
+import requests
+import json
+
+os.environ["REPLICATE_API_TOKEN"] = "<your replicate api token>"
+llama2_13b_chat = "meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d"
+
+llm = Replicate(
+    model=llama2_13b_chat,
+    model_kwargs={"temperature": 0.01, "top_p": 1, "max_new_tokens":500}
+)
+
+app = Flask(__name__)
+
+@app.route('/msgrcvd_pager', methods=['POST', 'GET'])
+def msgrcvd_pager():    
+    message = request.args.get('message')
+    sender = request.args.get('sender')
+    recipient = request.args.get('recipient')
+
+    answer = llm(message)
+    print(message)
+    print(answer)
+
+    url = f"https://graph.facebook.com/v18.0/{recipient}/messages"
+    params = {
+        'recipient': '{"id": ' + sender + '}',
+        'message': json.dumps({'text': answer}),
+        'messaging_type': 'RESPONSE',
+        'access_token': "<your page access token>"
+    }
+    headers = {
+        'Content-Type': 'application/json'
+    }
+    response = requests.post(url, params=params, headers=headers)
+    print(response.status_code)
+    print(response.text)
+
+    return message + "<p/>" + answer
+

BIN
demo_apps/messenger_api_settings.png


A különbségek nem kerülnek megjelenítésre, a fájl túl nagy
+ 194 - 0
demo_apps/messenger_llama2.md


BIN
demo_apps/messenger_llama_arch.jpg


A különbségek nem kerülnek megjelenítésre, a fájl túl nagy
+ 2 - 0
demo_apps/whatsapp_llama2.md


+ 22 - 0
examples/hf_llama_conversion/README.md

@@ -0,0 +1,22 @@
+# Convert Hugging Face llama weights to official llama consolidated format
+
+This is the reverse conversion for `convert_llama_weights_to_hf.py` script from the transformer package.
+
+## Step 0: Convert to consolidated format
+- Create an output directory for the converted weights, such as `test70B`.
+- Copy file params.json from the official llama download into that directory.
+- Run the conversion script. `model-path` can be a Hugging Face hub model or a local hf model directory.
+```
+python -m llama_recipes.tools.convert_hf_weights_to_llama --model-path meta-llama/Llama-2-70b-chat-hf --output-dir test70B --model-size 70B
+```
+
+## Step 1: Run inference
+Checkout the official llama inference [repo](https://github.com/facebookresearch/llama). Test using chat or text completion.
+```
+torchrun --nproc_per_node 8 example_chat_completion.py --ckpt_dir ./test70B --tokenizer_path ${llama_2_dir}/tokenizer.model
+```
+
+For validation, please compare the converted weights with official llama 2 weights
+```
+python compare_llama_weights.py test70B ${llama_2_70b_chat_dir}
+```

+ 48 - 0
examples/hf_llama_conversion/compare_llama_weights.py

@@ -0,0 +1,48 @@
+import gc
+import glob
+import os
+import sys
+
+import torch
+import tqdm
+
+
+def main() -> None:
+    """Compare two llama checkpoint directories"""
+
+    one_files = sorted(glob.glob(os.path.join(sys.argv[1], "consolidated.*.pth")))
+    two_files = sorted(glob.glob(os.path.join(sys.argv[2], "consolidated.*.pth")))
+    assert len(one_files) == len(
+        two_files
+    ), "One directory has {} files while another has {} files.".format(
+        len(one_files), len(two_files)
+    )
+
+    deltas = []
+    for i in tqdm.trange(len(one_files), desc="Comparing shards"):
+        one = torch.load(one_files[i])
+        two = torch.load(two_files[i])
+        assert len(one) == len(
+            two
+        ), "shard should have the same length: {} != {}".format(len(one), len(two))
+
+        for _, (v, w) in enumerate(zip(one.items(), two.items())):
+            assert v[0] == w[0], "{} != {}".format(v[0], w[0])
+            assert v[1].shape == w[1].shape, "tensor {} shape {} != {}".format(
+                v[0], v[1].shape, w[1].shape
+            )
+
+            delta = (v[1] - w[1]).abs().max().item()
+            deltas.append((i, v[0], delta))
+        del one
+        del two
+        gc.collect()
+
+    deltas = sorted(deltas, key=lambda x: x[-1], reverse=True)
+    print("Top 10 largest deltas:")
+    for i, k, v in deltas[:10]:
+        print(f"  shard {i} {k}: {v}")
+
+
+if __name__ == "__main__":
+    main()

+ 3 - 0
scripts/spellcheck_conf/wordlist.txt

@@ -1209,6 +1209,9 @@ venv
 webhook
 webhook's
 whatsapp
+business
+js
+webhooks
 Anyscale
 ADDR
 ckpt

+ 1 - 1
src/llama_recipes/datasets/alpaca_dataset.py

@@ -27,7 +27,7 @@ class InstructionDataset(Dataset):
     def __init__(self, dataset_config, tokenizer, partition="train"):
         self.ann = json.load(open(dataset_config.data_path))
         if partition == "train":
-            self.ann = self.ann
+            self.ann = self.ann[200:]
         else:
             self.ann = self.ann[:200]
 

+ 163 - 0
src/llama_recipes/tools/convert_hf_weights_to_llama.py

@@ -0,0 +1,163 @@
+import json
+import os
+from typing import List, Union
+
+import fire
+import torch
+from tqdm import tqdm
+from transformers import LlamaForCausalLM  # @manual
+
+NUM_SHARDS = {
+    "7B": 1,
+    "13B": 2,
+    "34B": 4,
+    "30B": 4,
+    "65B": 8,
+    "70B": 8,
+}
+
+
+def write_model(model_path, model_size, output_base_path):
+    dtype = torch.bfloat16
+
+    params = json.load(open(os.path.join(output_base_path, "params.json"), "r"))
+    num_shards = NUM_SHARDS[model_size]
+    n_layers = params["n_layers"]
+    n_heads = params["n_heads"]
+    n_heads_per_shard = n_heads // num_shards
+    dim = params["dim"]
+    dims_per_head = dim // n_heads
+    base = 10000.0
+    inv_freq = (
+        1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
+    ).to(dtype)
+
+    if "n_kv_heads" in params:
+        num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
+        num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
+        key_value_dim = dim // num_key_value_heads
+    else:  # compatibility with other checkpoints
+        num_key_value_heads = n_heads
+        num_local_key_value_heads = n_heads_per_shard
+        key_value_dim = dim
+
+    model = LlamaForCausalLM.from_pretrained(
+        model_path,
+        torch_dtype=dtype,
+        low_cpu_mem_usage=True,
+    )
+    loaded = model.state_dict()
+
+    # permute for sliced rotary
+    def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
+        return (
+            w.view(n_heads, 2, dim1 // n_heads // 2, dim2)
+            .transpose(1, 2)
+            .reshape(dim1, dim2)
+        )
+
+    state_dict = [{} for _ in range(num_shards)]
+
+    def insert(name: str, tensor: Union[List, torch.Tensor]):
+        for i in range(num_shards):
+            state_dict[i][name] = (
+                tensor[i].clone() if isinstance(tensor, list) else tensor
+            )
+
+    def insert_chunk(name: str, tensor: torch.Tensor, dim: int):
+        tensors = tensor.chunk(num_shards, dim=dim)
+        for i, tensor in enumerate(tensors):
+            state_dict[i][name] = tensor.clone()
+
+    insert_chunk("tok_embeddings.weight", loaded["model.embed_tokens.weight"], 1)
+    insert("norm.weight", loaded["model.norm.weight"])
+    insert_chunk("output.weight", loaded["lm_head.weight"], 0)
+
+    for layer_i in tqdm(range(n_layers), desc="Converting layers"):
+
+        ts = (
+            permute(loaded[f"model.layers.{layer_i}.self_attn.q_proj.weight"])
+            .view(n_heads_per_shard * num_shards, dims_per_head, dim)
+            .chunk(num_shards, dim=0)
+        )
+        insert(f"layers.{layer_i}.attention.wq.weight", [t.view(-1, dim) for t in ts])
+
+        ts = (
+            permute(
+                loaded[f"model.layers.{layer_i}.self_attn.k_proj.weight"],
+                num_key_value_heads,
+                key_value_dim,
+                dim,
+            )
+            .view(num_local_key_value_heads * num_shards, dims_per_head, dim)
+            .chunk(num_shards, dim=0)
+        )
+        insert(f"layers.{layer_i}.attention.wk.weight", [t.view(-1, dim) for t in ts])
+
+        ts = (
+            loaded[f"model.layers.{layer_i}.self_attn.v_proj.weight"]
+            .view(num_local_key_value_heads * num_shards, dims_per_head, dim)
+            .chunk(num_shards, dim=0)
+        )
+        insert(f"layers.{layer_i}.attention.wv.weight", [t.view(-1, dim) for t in ts])
+
+        insert_chunk(
+            f"layers.{layer_i}.attention.wo.weight",
+            loaded[f"model.layers.{layer_i}.self_attn.o_proj.weight"],
+            1,
+        )
+
+        insert_chunk(
+            f"layers.{layer_i}.feed_forward.w1.weight",
+            loaded[f"model.layers.{layer_i}.mlp.gate_proj.weight"],
+            0,
+        )
+
+        insert_chunk(
+            f"layers.{layer_i}.feed_forward.w2.weight",
+            loaded[f"model.layers.{layer_i}.mlp.down_proj.weight"],
+            1,
+        )
+
+        insert_chunk(
+            f"layers.{layer_i}.feed_forward.w3.weight",
+            loaded[f"model.layers.{layer_i}.mlp.up_proj.weight"],
+            0,
+        )
+
+        insert(
+            f"layers.{layer_i}.attention_norm.weight",
+            loaded[f"model.layers.{layer_i}.input_layernorm.weight"],
+        )
+        insert(
+            f"layers.{layer_i}.ffn_norm.weight",
+            loaded[f"model.layers.{layer_i}.post_attention_layernorm.weight"],
+        )
+    insert("rope.freqs", inv_freq)
+
+    for i in tqdm(range(num_shards), desc="Saving checkpoint shards"):
+        torch.save(
+            state_dict[i], os.path.join(output_base_path, f"consolidated.{i:02d}.pth")
+        )
+
+
+def main(
+    model_path: str,
+    model_size: str,
+    output_dir: str,
+):
+    """Convert llama weights from huggingface format to consolidated format.
+    params:
+    model_path: model name or path to the model directory.
+    model_size: Llama model size, one of 7B, 13B, 34B, 30B, 65B, 70B.
+    output_dir: directory to save Llama weights, should contains params.json.
+    """
+    assert model_size in NUM_SHARDS, f"Unknown model size {model_size}"
+    params_path = os.path.join(output_dir, "params.json")
+    assert os.path.isfile(params_path), f"{params_path} does not exist"
+
+    write_model(model_path, model_size, output_dir)
+
+
+if __name__ == "__main__":
+    fire.Fire(main)