Browse Source

Fix quantization config (use BitsAndBytesConfig instead of the deprecated load_in_8bit argument); remove debug prints

Beto 11 months ago
parent
commit
bcdb5b31fe

+ 5 - 3
recipes/responsible_ai/llama_guard/inference.py

@@ -2,7 +2,7 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import fire
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 
 
 from llama_recipes.inference.prompt_format_utils import build_default_prompt, create_conversation, LlamaGuardVersion
@@ -34,7 +34,7 @@ def main(
         raise ValueError(f"Invalid Llama Guard version '{llama_guard_version}'. Valid values are: {', '.join([lgv.name for lgv in LlamaGuardVersion])}") from e
 
     prompts: List[Tuple[List[str], AgentType]] = [
-        (["How can I kill someone without getting caught?"], AgentType.USER),
+        (["<Sample user prompt>"], AgentType.USER),
 
         (["<Sample user prompt>",
         "<Sample agent response>"], AgentType.AGENT),
@@ -46,8 +46,10 @@ def main(
 
     ]
 
+    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+
     tokenizer = AutoTokenizer.from_pretrained(model_id)
-    model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
+    model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config, device_map="auto")
     
     for prompt in prompts:
         formatted_prompt = build_default_prompt(

+ 0 - 2
src/llama_recipes/inference/prompt_format_utils.py

@@ -181,12 +181,10 @@ def build_default_prompt(
         llama_guard_version: LlamaGuardVersion = LlamaGuardVersion.LLAMA_GUARD_2):
     
     if llama_guard_version == LlamaGuardVersion.LLAMA_GUARD_2:
-        print("Llama Guard 2")
         categories = LLAMA_GUARD_2_CATEGORY
         category_short_name_prefix = LLAMA_GUARD_2_CATEGORY_SHORT_NAME_PREFIX
         prompt_template = PROMPT_TEMPLATE_2
     else:
-        print("Llama Guard 1")
         categories = LLAMA_GUARD_1_CATEGORY
         category_short_name_prefix = LLAMA_GUARD_1_CATEGORY_SHORT_NAME_PREFIX
         prompt_template = PROMPT_TEMPLATE_1