
Formatting and updating README files, removing test methods and unnecessary prints.

Beto · 1 year ago
commit 5f83e6d30b

+ 2 - 0
README.md

@@ -1,5 +1,7 @@
 # Llama 2 Fine-tuning / Inference Recipes, Examples and Demo Apps
 
+**[Update Dec. 15, 2023] We added support for Llama Guard as a safety checker for our example inference script, as well as a standalone Llama Guard inference example with prompt formatting. More details [here](./examples/llama_guard/README.md).**
+
 **[Update Nov. 16, 2023] We recently released a series of Llama 2 demo apps [here](./demo_apps). These apps show how to run Llama (locally, in the cloud, or on-prem), how to ask Llama questions in general or about custom data (PDF, DB, or live), how to integrate Llama with WhatsApp, and how to implement an end-to-end chatbot with RAG (Retrieval Augmented Generation).**
 
 The 'llama-recipes' repository is a companion to the [Llama 2 model](https://github.com/facebookresearch/llama). The goal of this repository is to provide examples to quickly get started with fine-tuning for domain adaptation and how to run inference for the fine-tuned models. For ease of use, the examples use Hugging Face converted versions of the models. See steps for conversion of the model [here](#model-conversion-to-hugging-face).
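
The Llama Guard update above refers to running the model as a standalone safety checker. As a rough illustration, here is a minimal sketch using only Hugging Face `transformers`, mirroring the model id, 8-bit loading, and generation settings from the example script changed below; note that the repo's own example formats prompts through `build_prompt` rather than the tokenizer's chat template, and the sample conversation here is just a placeholder.

```python
# Minimal Llama Guard safety-check sketch (assumes a transformers version with
# chat-template support, bitsandbytes for 8-bit loading, and access to the
# meta-llama/LlamaGuard-7b weights).
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "meta-llama/LlamaGuard-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")

# A single user turn to be checked; an agent reply could be appended as an
# "assistant" message.
chat = [{"role": "user", "content": "<Sample user prompt>"}]

input_ids = tokenizer.apply_chat_template(chat, return_tensors="pt").to("cuda")
prompt_len = input_ids.shape[-1]
output = model.generate(input_ids=input_ids, max_new_tokens=100, pad_token_id=0)

# Llama Guard answers with "safe", or "unsafe" followed by the violated
# category code(s).
verdict = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
print(verdict)
```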

+ 32 - 6
examples/llama_guard/README.md
File diff suppressed because it is too large


+ 0 - 6
examples/llama_guard/__init__.py

@@ -1,6 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-
-from .generation import Llama, Dialog
-from .model import ModelArgs, Transformer
-from .tokenizer import Tokenizer

+ 10 - 30
examples/llamaguard_inference.py

@@ -2,7 +2,7 @@ import fire
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
-from llama_recipes.inference.prompt_format import build_prompt, create_conversation, LLAMA_GUARD_CATEGORY, create_hf_chat
+from llama_recipes.inference.prompt_format import build_prompt, create_conversation, LLAMA_GUARD_CATEGORY
 from typing import List, Tuple
 from enum import Enum
 
@@ -10,13 +10,7 @@ class AgentType(Enum):
     AGENT = "Agent"
     USER = "User"
 
-def main(
-    temperature: float = 0.6,
-    top_p: float = 0.9,
-    max_seq_len: int = 128,
-    max_gen_len: int = 64,
-    max_batch_size: int = 4,
-):
+def main():
     """
     Entry point of the program for generating text using a pretrained model.
     Args:
@@ -36,13 +30,16 @@ def main(
 
         (["<Sample user prompt>",
         "<Sample agent response>"], AgentType.AGENT),
+        
+        (["<Sample user prompt>",
+        "<Sample agent response>",
+        "<Sample user reply>",
+        "<Sample agent response>",], AgentType.AGENT),
 
     ]
 
     model_id = "meta-llama/LlamaGuard-7b"
-    device = "cuda"
-    # dtype = torch.bfloat16
-
+    
     tokenizer = AutoTokenizer.from_pretrained(model_id)
     model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
 
@@ -59,27 +56,10 @@ def main(
         output = model.generate(**input, max_new_tokens=100, pad_token_id=0)
         results = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
        
-        # print("\nprompt template ==================================\n")
-        # print(formatted_prompt)
-        print("\n==================================\n")
+        
+        print(prompt[0])
         print(f"> {results}")
         print("\n==================================\n")
 
-    
-        print(create_hf_chat(prompt[0]))
-        input_ids_hf = tokenizer.apply_chat_template(create_hf_chat(prompt[0]), return_tensors="pt").to("cuda")
-        prompt_len_hf = input_ids_hf.shape[-1]
-        output_hf = model.generate(input_ids=input_ids_hf, max_new_tokens=100, pad_token_id=0)
-        result_hf = tokenizer.decode(output_hf[0][prompt_len_hf:], skip_special_tokens=True)
-
-        formatted_prompt_hf = tokenizer.decode(input_ids_hf[0], skip_special_tokens=True)
-
-        # print("\nHF template ==================================\n")
-        # print(formatted_prompt_hf)
-        print("\n==================================\n")
-        print(f"> HF {result_hf}")
-        print("\n==================================\n")
-
-
 if __name__ == "__main__":
     fire.Fire(main)

+ 0 - 7
src/llama_recipes/inference/prompt_format.py

@@ -107,13 +107,6 @@ def create_conversation(messges: List[str]) -> List[ConversationTurn]:
 
     return conversations
 
-def create_hf_chat(messges: List[str]) -> List[any]:
-    conversations = []
-    for i, messge in enumerate(messges):
-        conversations.append({"role": "user" if i % 2 == 0 else "assistant", "content": messge})
-
-    return conversations
-
 def build_prompt(
         agent_type: AgentType, 
         categories: List[SafetyCategory],
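
The removed `create_hf_chat` helper (also dropped from the example's imports above) only mapped an alternating list of messages onto Hugging Face chat dicts; the example now relies on `build_prompt` for formatting. If the chat-template path is still wanted, the same mapping can be done inline, roughly as follows (a sketch; the tokenizer id is the one used in the example, and the messages are placeholders):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/LlamaGuard-7b")

messages = ["<Sample user prompt>", "<Sample agent response>"]

# Equivalent of the removed create_hf_chat: even indices are user turns,
# odd indices are assistant turns.
chat = [
    {"role": "user" if i % 2 == 0 else "assistant", "content": m}
    for i, m in enumerate(messages)
]

input_ids = tokenizer.apply_chat_template(chat, return_tensors="pt")
```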