@@ -2,7 +2,7 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import fire
 
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 
 from llama_recipes.inference.prompt_format_utils import build_default_prompt, create_conversation, LlamaGuardVersion
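Recent transformers releases deprecate passing load_in_8bit directly to from_pretrained in favor of an explicit BitsAndBytesConfig, which is what motivates the new import. A minimal sketch of the before and after, assuming the script's default checkpoint meta-llama/LlamaGuard-7b as model_id:

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_id = "meta-llama/LlamaGuard-7b"  # assumed: the script's default checkpoint

# Old form, now deprecated in recent transformers releases:
# model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")

# New form adopted later in this diff: quantization options live in a config object.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, quantization_config=quantization_config, device_map="auto")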
@@ -34,7 +34,7 @@ def main(
         raise ValueError(f"Invalid Llama Guard version '{llama_guard_version}'. Valid values are: {', '.join([lgv.name for lgv in LlamaGuardVersion])}") from e
 
     prompts: List[Tuple[List[str], AgentType]] = [
-        (["How can I kill someone without getting caught?"], AgentType.USER),
+        (["<Sample user prompt>"], AgentType.USER),
 
         (["<Sample user prompt>",
         "<Sample agent response>"], AgentType.AGENT),
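Each entry pairs a conversation (alternating user and agent strings) with the role Llama Guard should classify: an AgentType.USER entry screens a lone user message, while an AgentType.AGENT entry screens the model's reply in the context of the preceding turn. A hypothetical filled-in version of the placeholders, with the role enum sketched inline so the snippet stands alone:

from enum import Enum
from typing import List, Tuple

class AgentType(Enum):  # mirrors the role enum this script uses
    AGENT = "Agent"
    USER = "User"

prompts: List[Tuple[List[str], AgentType]] = [
    # Screen a standalone user message.
    (["Tell me about the history of the printing press."], AgentType.USER),
    # Screen the agent's reply, given the preceding user turn.
    (["Tell me about the history of the printing press.",
      "The printing press was developed in the 15th century..."], AgentType.AGENT),
]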
@@ -46,8 +46,10 @@ def main(
     ]
 
+    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+
     tokenizer = AutoTokenizer.from_pretrained(model_id)
-    model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
+    model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config, device_map="auto")
 
     for prompt in prompts:
         formatted_prompt = build_default_prompt(
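The hunk is cut off at the build_default_prompt call. For context, a sketch of how such a loop typically completes, continuing with the names defined above; the argument order of build_default_prompt, the max_new_tokens value, and pad_token_id=0 are assumptions, not taken from the diff:

for prompt in prompts:
    formatted_prompt = build_default_prompt(
        prompt[1],                       # the AgentType being classified (assumed arg order)
        create_conversation(prompt[0]),  # turn list built from the raw strings
        llama_guard_version)

    # Tokenize, generate, and decode only the newly generated tokens.
    inputs = tokenizer([formatted_prompt], return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]
    output = model.generate(**inputs, max_new_tokens=100, pad_token_id=0)
    # Llama Guard answers "safe", or "unsafe" plus the violated category code.
    print(tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True))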