vLLM_inference.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. from accelerate import init_empty_weights, load_checkpoint_and_dispatch
  4. import fire
  5. import torch
  6. import os
  7. import sys
  8. from peft import PeftModel, PeftConfig
  9. from transformers import (
  10. LlamaConfig,
  11. LlamaTokenizer,
  12. LlamaForCausalLM
  13. )
  14. from vllm import LLM
  15. from vllm import LLM, SamplingParams
  16. from accelerate.utils import is_xpu_available
  17. if is_xpu_available():
  18. torch.xpu.manual_seed(42)
  19. else:
  20. torch.cuda.manual_seed(42)
  21. torch.manual_seed(42)
  22. def load_model(model_name, tp_size=1):
  23. llm = LLM(model_name, tensor_parallel_size=tp_size)
  24. return llm
  25. def main(
  26. model,
  27. max_new_tokens=100,
  28. user_prompt=None,
  29. top_p=0.9,
  30. temperature=0.8
  31. ):
  32. while True:
  33. if user_prompt is None:
  34. user_prompt = input("Enter your prompt: ")
  35. print(f"User prompt:\n{user_prompt}")
  36. print(f"sampling params: top_p {top_p} and temperature {temperature} for this inference request")
  37. sampling_param = SamplingParams(top_p=top_p, temperature=temperature, max_tokens=max_new_tokens)
  38. outputs = model.generate(user_prompt, sampling_params=sampling_param)
  39. print(f"model output:\n {user_prompt} {outputs[0].outputs[0].text}")
  40. user_prompt = input("Enter next prompt (press Enter to exit): ")
  41. if not user_prompt:
  42. break
  43. def run_script(
  44. model_name: str,
  45. peft_model=None,
  46. tp_size=1,
  47. max_new_tokens=100,
  48. user_prompt=None,
  49. top_p=0.9,
  50. temperature=0.8
  51. ):
  52. model = load_model(model_name, tp_size)
  53. main(model, max_new_tokens, user_prompt, top_p, temperature)
# Script entry point: fire exposes run_script's parameters as CLI flags.
if __name__ == "__main__":
    fire.Fire(run_script)