llamaguard_inference.py

import fire
from transformers import AutoTokenizer, AutoModelForCausalLM

from llama_recipes.inference.prompt_format import build_prompt, create_conversation, LLAMA_GUARD_CATEGORY, create_hf_chat

from typing import List, Tuple
from enum import Enum


class AgentType(Enum):
    AGENT = "Agent"
    USER = "User"

def main(
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_seq_len: int = 128,
    max_gen_len: int = 64,
    max_batch_size: int = 4,
):
    """
    Entry point of the program for generating text using a pretrained model.

    Args:
        temperature (float, optional): The temperature value for controlling randomness in generation.
            Defaults to 0.6.
        top_p (float, optional): The top-p sampling parameter for controlling diversity in generation.
            Defaults to 0.9.
        max_seq_len (int, optional): The maximum sequence length for input prompts. Defaults to 128.
        max_gen_len (int, optional): The maximum length of generated sequences. Defaults to 64.
        max_batch_size (int, optional): The maximum batch size for generating sequences. Defaults to 4.
    """
    prompts: List[Tuple[List[str], AgentType]] = [
        (["<Sample user prompt>"], AgentType.USER),
        (["<Sample user prompt>",
          "<Sample agent response>"], AgentType.AGENT),
    ]

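    # Load the tokenizer and the model; load_in_8bit=True quantizes the weights with
    # bitsandbytes so the 7B model fits in less GPU memory.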
    model_id = "meta-llama/LlamaGuard-7b"
    device = "cuda"
    # dtype = torch.bfloat16

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")

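    # Run each conversation through two formatting paths and print both results:
    # the llama_recipes prompt builder and the tokenizer's built-in chat template.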
    for prompt in prompts:
        formatted_prompt = build_prompt(
            prompt[1],
            LLAMA_GUARD_CATEGORY,
            create_conversation(prompt[0]))

        inputs = tokenizer([formatted_prompt], return_tensors="pt").to(device)
        prompt_len = inputs["input_ids"].shape[-1]
        output = model.generate(**inputs, max_new_tokens=100, pad_token_id=0)
        results = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

        # print("\nprompt template ==================================\n")
        # print(formatted_prompt)
        print("\n==================================\n")
        print(f"> {results}")
        print("\n==================================\n")
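        # Same conversation, this time formatted with the tokenizer's chat template
        # (apply_chat_template) instead of build_prompt, to compare the two prompt formats.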
        print(create_hf_chat(prompt[0]))
        input_ids_hf = tokenizer.apply_chat_template(create_hf_chat(prompt[0]), return_tensors="pt").to(device)
        prompt_len_hf = input_ids_hf.shape[-1]
        output_hf = model.generate(input_ids=input_ids_hf, max_new_tokens=100, pad_token_id=0)
        result_hf = tokenizer.decode(output_hf[0][prompt_len_hf:], skip_special_tokens=True)
        formatted_prompt_hf = tokenizer.decode(input_ids_hf[0], skip_special_tokens=True)

        # print("\nHF template ==================================\n")
        # print(formatted_prompt_hf)
        print("\n==================================\n")
        print(f"> HF {result_hf}")
        print("\n==================================\n")


if __name__ == "__main__":
    fire.Fire(main)
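
# Example invocation (fire maps main()'s keyword arguments to CLI flags):
#   python llamaguard_inference.py --temperature 0.6 --top_p 0.9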