pretrained_vllm_benchmark.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import csv
import json
import time
import random
import threading
import numpy as np
import requests
import transformers
import torch

# Imports for Azure content safety
from azure.ai.contentsafety import ContentSafetyClient
from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import HttpResponseError
from azure.ai.contentsafety.models import AnalyzeTextOptions
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Tuple

# Predefined inputs
with open('input.jsonl') as input_file:
    prompt_data = json.load(input_file)

with open('parameters.json') as parameters:
    params = json.load(parameters)

MAX_NEW_TOKENS = params["MAX_NEW_TOKENS"]
CONCURRENT_LEVELS = params["CONCURRENT_LEVELS"]
# Replace with your own deployment
MODEL_PATH = params["MODEL_PATH"]
MODEL_HEADERS = params["MODEL_HEADERS"]
SAFE_CHECK = params["SAFE_CHECK"]
# Threshold for tokens per second below which we deem the query to be slow
THRESHOLD_TPS = params["THRESHOLD_TPS"]
# Replace with your own tokenizer
TOKENIZER_PATH = params["TOKENIZER_PATH"]
RANDOM_PROMPT_LENGTH = params["RANDOM_PROMPT_LENGTH"]
TEMPERATURE = params["TEMPERATURE"]
TOP_P = params["TOP_P"]
# Add your model endpoints here and specify the port number. You can acquire an endpoint when creating an on-prem server such as vLLM.
# Group of model endpoints - send balanced requests to each endpoint for batch maximization.
MODEL_ENDPOINTS = params["MODEL_ENDPOINTS"]
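# For reference, parameters.json is expected to provide every key read above. A minimal
# illustrative example follows; all values (model name, endpoint URL, thresholds) are
# placeholders, not recommended settings:
#
# {
#     "MAX_NEW_TOKENS": 256,
#     "CONCURRENT_LEVELS": [1, 2, 4, 8, 16, 32],
#     "MODEL_PATH": "meta-llama/Llama-2-7b-chat-hf",
#     "MODEL_HEADERS": {"Content-Type": "application/json"},
#     "SAFE_CHECK": false,
#     "THRESHOLD_TPS": 7,
#     "TOKENIZER_PATH": "path/to/tokenizer",
#     "RANDOM_PROMPT_LENGTH": 1000,
#     "TEMPERATURE": 0.6,
#     "TOP_P": 0.9,
#     "MODEL_ENDPOINTS": ["http://localhost:8000/v1/chat/completions"]
# }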
# Get the number of GPUs on this instance
if torch.cuda.is_available():
    NUM_GPU = torch.cuda.device_count()
else:
    print("No available GPUs")
    # Fall back to 1 so the per-GPU metrics below do not raise a NameError on CPU-only hosts.
    NUM_GPU = 1

# This tokenizer is downloaded from the Azure model catalog for each specific model. Its main purpose is to decode the responses for token counting.
tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH)

# Select vocabulary tokens that are longer than 2 characters (closer to real words) and ASCII-only (roughly English; not foolproof)
vocab = [token for token in tokenizer.get_vocab().keys() if len(token) > 2 and all(ord(c) < 128 for c in token)]

def generate_random_prompt(num_tokens):
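    """Build a synthetic prompt of roughly `num_tokens` tokens by sampling random vocabulary entries.

    The running string is re-encoded after each appended word, so the final prompt may slightly
    overshoot the requested length.
    """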
    generated_tokens_count = 0
    selected_tokens = ""
    while generated_tokens_count < num_tokens:
        selected_tokens += random.choice(vocab)
        selected_tokens += " "
        generated_tokens_count = len(tokenizer.encode(selected_tokens))
    return selected_tokens

PROMPT = generate_random_prompt(RANDOM_PROMPT_LENGTH)
num_token_input_prompt = len(tokenizer.encode(PROMPT))
print(f"Number of tokens in the input prompt: {num_token_input_prompt}")
# Azure content safety analysis
def analyze_prompt(prompt):
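    """Run the prompt through Azure AI Content Safety and measure the round-trip latency.

    Nothing is returned; callers only rely on the added delay. Print the response or
    `analyze_latency` inside this function if you want to inspect the safety results.
    """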
    start_time = time.time()

    # Obtain credentials
    key = ""  # Add your AZURE_CONTENT_SAFETY_KEY
    endpoint = ""  # Add your AZURE_CONTENT_SAFETY_ENDPOINT

    # Create a content safety client
    client = ContentSafetyClient(endpoint, AzureKeyCredential(key))

    # Create request
    request = AnalyzeTextOptions(text=prompt)

    # Analyze prompt
    try:
        response = client.analyze_text(request)
    except HttpResponseError as e:
        print("Prompt failed due to content safety filtering.")
        if e.error:
            print(f"Error code: {e.error.code}")
            print(f"Error message: {e.error.message}")
            raise
        print(e)
        raise

    analyze_end_time = time.time()
    # The round-trip latency of the Azure content safety check, in ms
    analyze_latency = (analyze_end_time - start_time) * 1000
# Simple round-robin to dispatch requests to different containers
executor_id = 0
lock = threading.Lock()

def generate_text() -> Tuple[float, int]:
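    """Send one chat completion request to the next endpoint in round-robin order.

    Returns a (latency_ms, output_token_count) tuple. The optional content safety check runs
    before and after the request so its round-trip time counts toward the measured latency.
    """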
    headers = MODEL_HEADERS
    payload = {
        "model": MODEL_PATH,
        "messages": [
            {
                "role": "user",
                "content": PROMPT
            }
        ],
        "stream": False,
        "temperature": TEMPERATURE,
        "top_p": TOP_P,
        "max_tokens": MAX_NEW_TOKENS
    }

    start_time = time.time()

    if SAFE_CHECK:
        # Send the prompt for a safety check. The request round-trip adds delay that counts towards the overall throughput measurement.
        # Expect NO return value from this call. If you want to inspect the safety check results, print them inside the function itself.
        analyze_prompt(PROMPT)
        # Or simulate the delay if you don't want to use the Azure Content Safety check. The API round trip is around 0.3-0.4 seconds depending on where you are located. For example: time.sleep(random.uniform(0.3, 0.4))

    global executor_id
    with lock:
        if executor_id != len(MODEL_ENDPOINTS) - 1:
            executor_id += 1
            endpoint_id = executor_id
        else:
            executor_id = 0
            endpoint_id = executor_id
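    # This POST targets an OpenAI-compatible chat completions API (e.g. the /v1/chat/completions
    # route served by a vLLM endpoint); the payload built above follows that API's schema.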
    response = requests.post(MODEL_ENDPOINTS[endpoint_id], headers=headers, json=payload)

    if SAFE_CHECK:
        # Send the prompt for a safety check. The request round-trip adds delay that counts towards the overall throughput measurement.
        # Expect NO return value from this call. If you want to inspect the safety check results, print them inside the function itself.
        analyze_prompt(PROMPT)
        # Or simulate the delay if you don't want to use the Azure Content Safety check. The API round trip is around 0.3-0.4 seconds depending on where you are located. For example: time.sleep(random.uniform(0.3, 0.4))

    end_time = time.time()
    # Convert to ms
    latency = (end_time - start_time) * 1000

    if response.status_code != 200:
        raise ValueError(f"Error: {response.content}")
    output = json.loads(response.content)["choices"][0]["message"]["content"]

    token_count = len(tokenizer.encode(output))
    return latency, token_count
def evaluate_performance(concurrent_requests: int) -> Tuple[float, float, float, float, float, float, float, float, int]:
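    """Fire `concurrent_requests` requests in parallel and aggregate latency/throughput metrics.

    Returns p50/p99 latency (ms), requests per second, overall and per-GPU output and input
    token throughput, the average per-request output tokens/s, and the number of requests
    whose output tokens/s fell below THRESHOLD_TPS.
    """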
    latencies = []
    total_output_tokens = 0
    output_tokens_per_second_each_request = []
    start_time = time.time()

    # Init multi-threaded execution
    with ThreadPoolExecutor(max_workers=concurrent_requests) as executor:
        future_to_req = {executor.submit(generate_text): i for i in range(concurrent_requests)}
        for future in as_completed(future_to_req):
            latency, token_count = future.result()
            latencies.append(latency)
            total_output_tokens += token_count
            # Calculate tokens per second for this request
            tokens_per_sec = token_count / (latency / 1000)
            output_tokens_per_second_each_request.append(tokens_per_sec)

    end_time = time.time()
    total_time = end_time - start_time
    # RPS (requests per second)
    rps = concurrent_requests / total_time
    # Overall tokens per second
    output_tokens_per_second_overall = total_output_tokens / total_time
    input_tokens_per_second_overall = (num_token_input_prompt * concurrent_requests) / total_time
    output_tokens_per_second_per_gpu = output_tokens_per_second_overall / NUM_GPU
    input_tokens_per_second_per_gpu = input_tokens_per_second_overall / NUM_GPU
    p50_latency = np.percentile(latencies, 50)
    p99_latency = np.percentile(latencies, 99)

    # Count the number of requests below the tokens-per-second threshold
    below_threshold_count = sum(1 for tps in output_tokens_per_second_each_request if tps < THRESHOLD_TPS)
    output_tokens_per_second_per_request = sum(output_tokens_per_second_each_request) / len(output_tokens_per_second_each_request)

    return p50_latency, p99_latency, rps, output_tokens_per_second_overall, output_tokens_per_second_per_gpu, input_tokens_per_second_overall, input_tokens_per_second_per_gpu, output_tokens_per_second_per_request, below_threshold_count
# Print results as a markdown table
print("| Number of Concurrent Requests | P50 Latency (ms) | P99 Latency (ms) | RPS | Output Tokens per Second | Output Tokens per Second per GPU | Input Tokens per Second | Input Tokens per Second per GPU | Average Output Tokens per Second per Request | Number of Requests Below Threshold |")
print("|-------------------------------|------------------|------------------|-----|--------------------------|----------------------------------|-------------------------|---------------------------------|----------------------------------------------|------------------------------------|")

# Save to file
csv_file = "performance_metrics.csv"
with open(csv_file, "w", newline='') as f:
    writer = csv.writer(f)
    writer.writerow(["Number of Concurrent Requests", "P50 Latency (ms)", "P99 Latency (ms)", "RPS", "Output Tokens per Second", "Output Tokens per Second per GPU", "Input Tokens per Second", "Input Tokens per Second per GPU", "Average Output Tokens per Second per Request", "Number of Requests Below Threshold"])

    for level in CONCURRENT_LEVELS:
        p50_latency, p99_latency, rps, output_tokens_per_second_overall, output_tokens_per_second_per_gpu, input_tokens_per_second_overall, input_tokens_per_second_per_gpu, output_tokens_per_second_per_request, below_threshold_count = evaluate_performance(level)
        print(f"| {level} | {p50_latency:.2f} | {p99_latency:.2f} | {rps:.2f} | {output_tokens_per_second_overall:.2f} | {output_tokens_per_second_per_gpu:.2f} | {input_tokens_per_second_overall:.2f} | {input_tokens_per_second_per_gpu:.2f} | {output_tokens_per_second_per_request:.2f} | {below_threshold_count} |")
        writer.writerow([level, round(p50_latency, 2), round(p99_latency, 2), round(rps, 2), round(output_tokens_per_second_overall, 2), round(output_tokens_per_second_per_gpu, 2), round(input_tokens_per_second_overall, 2), round(input_tokens_per_second_per_gpu, 2), round(output_tokens_per_second_per_request, 2), below_threshold_count])