|
@@ -157,13 +157,15 @@ class AzureSaftyChecker(object):
|
|
|
class LlamaGuardSafetyChecker(object):
|
|
|
|
|
|
def __init__(self):
|
|
|
- from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
|
|
+ from llama_recipes.inference.prompt_format_utils import build_default_prompt, create_conversation, LlamaGuardVersion
|
|
|
|
|
|
model_id = "meta-llama/LlamaGuard-7b"
|
|
|
|
|
|
+ quantization_config = BitsAndBytesConfig(load_in_8bit=True)
|
|
|
+
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
|
|
|
- self.model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
|
|
|
- pass
|
|
|
+ self.model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config, device_map="auto")
|
|
|
|
|
|
def __call__(self, output_text, **kwargs):
|
|
|
|