inference.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import fire
  4. import torch
  5. from vllm import LLM
  6. from vllm import LLM, SamplingParams
  7. from accelerate.utils import is_xpu_available
  8. if is_xpu_available():
  9. torch.xpu.manual_seed(42)
  10. else:
  11. torch.cuda.manual_seed(42)
  12. torch.manual_seed(42)
  13. def load_model(model_name, tp_size=1):
  14. llm = LLM(model_name, tensor_parallel_size=tp_size)
  15. return llm
  16. def main(
  17. model,
  18. max_new_tokens=100,
  19. user_prompt=None,
  20. top_p=0.9,
  21. temperature=0.8
  22. ):
  23. while True:
  24. if user_prompt is None:
  25. user_prompt = input("Enter your prompt: ")
  26. print(f"User prompt:\n{user_prompt}")
  27. print(f"sampling params: top_p {top_p} and temperature {temperature} for this inference request")
  28. sampling_param = SamplingParams(top_p=top_p, temperature=temperature, max_tokens=max_new_tokens)
  29. outputs = model.generate(user_prompt, sampling_params=sampling_param)
  30. print(f"model output:\n {user_prompt} {outputs[0].outputs[0].text}")
  31. user_prompt = input("Enter next prompt (press Enter to exit): ")
  32. if not user_prompt:
  33. break
  34. def run_script(
  35. model_name: str,
  36. peft_model=None,
  37. tp_size=1,
  38. max_new_tokens=100,
  39. user_prompt=None,
  40. top_p=0.9,
  41. temperature=0.8
  42. ):
  43. model = load_model(model_name, tp_size)
  44. main(model, max_new_tokens, user_prompt, top_p, temperature)
  45. if __name__ == "__main__":
  46. fire.Fire(run_script)