@@ -2,7 +2,7 @@ import fire
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
-from llama_recipes.inference.prompt_format import build_prompt, create_conversation, LLAMA_GUARD_CATEGORY, create_hf_chat
+from llama_recipes.inference.prompt_format import build_prompt, create_conversation, LLAMA_GUARD_CATEGORY
 from typing import List, Tuple
 from enum import Enum
 
@@ -10,13 +10,7 @@ class AgentType(Enum):
     AGENT = "Agent"
     USER = "User"
 
-def main(
-    temperature: float = 0.6,
-    top_p: float = 0.9,
-    max_seq_len: int = 128,
-    max_gen_len: int = 64,
-    max_batch_size: int = 4,
-):
+def main():
     """
     Entry point of the program for generating text using a pretrained model.
     Args:
@@ -36,13 +30,16 @@ def main(
 
         (["<Sample user prompt>",
         "<Sample agent response>"], AgentType.AGENT),
+
+        (["<Sample user prompt>",
+        "<Sample agent response>",
+        "<Sample user reply>",
+        "<Sample agent response>",], AgentType.AGENT),
 
     ]
 
     model_id = "meta-llama/LlamaGuard-7b"
-    device = "cuda"
-    # dtype = torch.bfloat16
-
+
     tokenizer = AutoTokenizer.from_pretrained(model_id)
     model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
 
@@ -59,27 +56,10 @@ def main(
         output = model.generate(**input, max_new_tokens=100, pad_token_id=0)
         results = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
 
-        # print("\nprompt template ==================================\n")
-        # print(formatted_prompt)
-        print("\n==================================\n")
+
+        print(prompt[0])
         print(f"> {results}")
         print("\n==================================\n")
 
-
-        print(create_hf_chat(prompt[0]))
-        input_ids_hf = tokenizer.apply_chat_template(create_hf_chat(prompt[0]), return_tensors="pt").to("cuda")
-        prompt_len_hf = input_ids_hf.shape[-1]
-        output_hf = model.generate(input_ids=input_ids_hf, max_new_tokens=100, pad_token_id=0)
-        result_hf = tokenizer.decode(output_hf[0][prompt_len_hf:], skip_special_tokens=True)
-
-        formatted_prompt_hf = tokenizer.decode(input_ids_hf[0], skip_special_tokens=True)
-
-        # print("\nHF template ==================================\n")
-        # print(formatted_prompt_hf)
-        print("\n==================================\n")
-        print(f"> HF {result_hf}")
-        print("\n==================================\n")
-
-
 if __name__ == "__main__":
     fire.Fire(main)
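
For reference, below is a minimal sketch of the flow the script converges on once this patch is applied, reassembled from the hunks above. The prompt list, model setup, and print statements come straight from the diff; the for-loop over prompts and the build_prompt(agent_type, categories, conversation) call shape are assumptions, since those context lines are not part of the patch.

# Sketch only: loop structure and build_prompt argument order are assumed,
# not shown in this diff.
import fire
from enum import Enum
from typing import List, Tuple

from transformers import AutoTokenizer, AutoModelForCausalLM
from llama_recipes.inference.prompt_format import build_prompt, create_conversation, LLAMA_GUARD_CATEGORY


class AgentType(Enum):
    AGENT = "Agent"
    USER = "User"


def main():
    # Single-turn and multi-turn samples, matching the updated prompts list.
    prompts: List[Tuple[List[str], AgentType]] = [
        (["<Sample user prompt>"], AgentType.USER),
        (["<Sample user prompt>",
          "<Sample agent response>"], AgentType.AGENT),
        (["<Sample user prompt>",
          "<Sample agent response>",
          "<Sample user reply>",
          "<Sample agent response>"], AgentType.AGENT),
    ]

    model_id = "meta-llama/LlamaGuard-7b"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")

    for prompt in prompts:
        # Assumed call shape; the patch only shows that create_hf_chat and the
        # HF chat-template comparison are no longer used.
        formatted_prompt = build_prompt(
            prompt[1],
            LLAMA_GUARD_CATEGORY,
            create_conversation(prompt[0]))

        input = tokenizer([formatted_prompt], return_tensors="pt").to("cuda")
        prompt_len = input["input_ids"].shape[-1]
        output = model.generate(**input, max_new_tokens=100, pad_token_id=0)
        results = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

        print(prompt[0])
        print(f"> {results}")
        print("\n==================================\n")


if __name__ == "__main__":
    fire.Fire(main)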