Przeglądaj źródła

Fixing tokenizer used for llama 3. Changing quantization configs on safety_utils.

Beto 11 miesięcy temu
rodzic
commit
f63ba19827

+ 3 - 3
recipes/inference/local_inference/inference.py

@@ -10,7 +10,7 @@ import time
 import gradio as gr
 
 import torch
-from transformers import LlamaTokenizer
+from transformers import LlamaTokenizer, AutoTokenizer
 
 from llama_recipes.inference.safety_utils import get_safety_checker, AgentType
 from llama_recipes.inference.model_utils import load_model, load_peft_model
@@ -77,9 +77,9 @@ def main(
     model.eval()
     
 
-    tokenizer = LlamaTokenizer.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
     tokenizer.pad_token = tokenizer.eos_token
-    
+
     batch = tokenizer(user_prompt, padding='max_length', truncation=True, max_length=max_padding_length, return_tensors="pt")
     if is_xpu_available():
         batch = {k: v.to("xpu") for k, v in batch.items()}

+ 5 - 3
src/llama_recipes/inference/safety_utils.py

@@ -157,13 +157,15 @@ class AzureSaftyChecker(object):
 class LlamaGuardSafetyChecker(object):
 
     def __init__(self):
-        from transformers import AutoModelForCausalLM, AutoTokenizer
+        from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+        from llama_recipes.inference.prompt_format_utils import build_default_prompt, create_conversation, LlamaGuardVersion
 
         model_id = "meta-llama/LlamaGuard-7b"
 
+        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+
         self.tokenizer = AutoTokenizer.from_pretrained(model_id)
-        self.model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
-        pass
+        self.model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config, device_map="auto")
 
     def __call__(self, output_text, **kwargs):