inference.py

import fire
from transformers import AutoTokenizer, AutoModelForCausalLM

from llama_recipes.inference.prompt_format_utils import build_prompt, create_conversation, LLAMA_GUARD_CATEGORY

from typing import List, Tuple
from enum import Enum


class AgentType(Enum):
    AGENT = "Agent"
    USER = "User"


def main():
    """
    Entry point of the program for generating Llama Guard safety assessments.

    Loads meta-llama/LlamaGuard-7b in 8-bit, formats each sample conversation
    with the default Llama Guard safety categories, and prints the model's
    verdict for every prompt.
    """
    # Each entry is a list of conversation turns plus the role to be moderated
    # (the user's last message or the agent's last response).
    prompts: List[Tuple[List[str], AgentType]] = [
        (["<Sample user prompt>"], AgentType.USER),

        (["<Sample user prompt>",
          "<Sample agent response>"], AgentType.AGENT),

        (["<Sample user prompt>",
          "<Sample agent response>",
          "<Sample user reply>",
          "<Sample agent response>"], AgentType.AGENT),
    ]

    model_id = "meta-llama/LlamaGuard-7b"

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")

    for prompt in prompts:
        # Build the Llama Guard prompt from the role to moderate, the safety
        # category definitions, and the conversation turns.
        formatted_prompt = build_prompt(
            prompt[1],
            LLAMA_GUARD_CATEGORY,
            create_conversation(prompt[0]))

        input = tokenizer([formatted_prompt], return_tensors="pt").to("cuda")
        prompt_len = input["input_ids"].shape[-1]
        output = model.generate(**input, max_new_tokens=100, pad_token_id=0)
        # Decode only the newly generated tokens, skipping the prompt itself.
        results = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

        print(prompt[0])
        print(f"> {results}")
        print("\n==================================\n")


if __name__ == "__main__":
    fire.Fire(main)
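
A minimal sketch (not part of the original script) of post-processing the decoded results string. It assumes Llama Guard replies with "safe", or with "unsafe" followed by a comma-separated list of violated category codes (such as "O3") on the next line; treat that format as an assumption about the model's output rather than a guaranteed contract.

from typing import List, Tuple

def parse_verdict(results: str) -> Tuple[bool, List[str]]:
    """Split Llama Guard's decoded reply into (is_safe, violated_category_codes)."""
    lines = results.strip().splitlines()
    is_safe = bool(lines) and lines[0].strip() == "safe"
    # Category codes are assumed to appear on the second line for unsafe replies.
    violated = [c.strip() for c in lines[1].split(",")] if not is_safe and len(lines) > 1 else []
    return is_safe, violated

For example, parse_verdict("unsafe\nO3") would return (False, ["O3"]), while parse_verdict("safe") would return (True, []).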