Browse Source

Add gradio library for user interface in inference.py (#367)

Hamid Shojanazeri 1 year ago
parent
commit
fa5edb8c47
2 changed files with 47 additions and 12 deletions
  1. 2 0
      benchmarks/inference_throughput/cloud-api/README.md
  2. 45 12
      examples/inference.py

+ 2 - 0
benchmarks/inference_throughput/cloud-api/README.md

@@ -7,7 +7,9 @@ Disclaimer - The purpose of the code is to provide a configurable setup to measu
 # Azure - Getting Started
 To get started, there are certain steps we need to take to deploy the models:
 
+<!-- markdown-link-check-disable -->
 * Register for a valid Azure account with subscription [here](https://azure.microsoft.com/en-us/free/search/?ef_id=_k_CjwKCAiA-P-rBhBEEiwAQEXhH5OHAJLhzzcNsuxwpa5c9EJFcuAjeh6EvZw4afirjbWXXWkiZXmU2hoC5GoQAvD_BwE_k_&OCID=AIDcmm5edswduu_SEM__k_CjwKCAiA-P-rBhBEEiwAQEXhH5OHAJLhzzcNsuxwpa5c9EJFcuAjeh6EvZw4afirjbWXXWkiZXmU2hoC5GoQAvD_BwE_k_&gad_source=1&gclid=CjwKCAiA-P-rBhBEEiwAQEXhH5OHAJLhzzcNsuxwpa5c9EJFcuAjeh6EvZw4afirjbWXXWkiZXmU2hoC5GoQAvD_BwE)
+<!-- markdown-link-check-enable -->
 * Take a quick look on what is the [Azure AI Studio](https://learn.microsoft.com/en-us/azure/ai-studio/what-is-ai-studio?tabs=home) and navigate to the website from the link in the article
 * Follow the demos in the article to create a project and [resource](https://learn.microsoft.com/en-us/azure/azure-resource-manager/management/manage-resource-groups-portal) group, or you can also follow the guide [here](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-llama?tabs=azure-studio)
 * Select Llama models from Model catalog

+ 45 - 12
examples/inference.py

@@ -7,6 +7,7 @@ import fire
 import os
 import sys
 import time
+import gradio as gr
 
 import torch
 from transformers import LlamaTokenizer
@@ -39,18 +40,8 @@ def main(
     use_fast_kernels: bool = False, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
     **kwargs
 ):
-    if prompt_file is not None:
-        assert os.path.exists(
-            prompt_file
-        ), f"Provided Prompt file does not exist {prompt_file}"
-        with open(prompt_file, "r") as f:
-            user_prompt = "\n".join(f.readlines())
-    elif not sys.stdin.isatty():
-        user_prompt = "\n".join(sys.stdin.readlines())
-    else:
-        print("No user prompt provided. Exiting.")
-        sys.exit(1)
 
+  def inference(user_prompt, temperature, top_p, top_k, max_new_tokens, **kwargs,):
     safety_checker = get_safety_checker(enable_azure_content_safety,
                                         enable_sensitive_topics,
                                         enable_salesforce_content_safety,
@@ -126,7 +117,49 @@ def main(
             if not is_safe:
                 print(method)
                 print(report)
-                
+    return output_text
+
+  if prompt_file is not None:
+      assert os.path.exists(
+          prompt_file
+      ), f"Provided Prompt file does not exist {prompt_file}"
+      with open(prompt_file, "r") as f:
+          user_prompt = "\n".join(f.readlines())
+      inference(user_prompt, temperature, top_p, top_k, max_new_tokens)
+  elif not sys.stdin.isatty():
+      user_prompt = "\n".join(sys.stdin.readlines())
+      inference(user_prompt, temperature, top_p, top_k, max_new_tokens)
+  else:
+      gr.Interface(
+        fn=inference,
+        inputs=[
+            gr.components.Textbox(
+                lines=9,
+                label="User Prompt",
+                placeholder="none",
+            ),
+            gr.components.Slider(
+                minimum=0, maximum=1, value=1.0, label="Temperature"
+            ),
+            gr.components.Slider(
+                minimum=0, maximum=1, value=1.0, label="Top p"
+            ),
+            gr.components.Slider(
+                minimum=0, maximum=100, step=1, value=50, label="Top k"
+            ),
+            gr.components.Slider(
+                minimum=1, maximum=2000, step=1, value=200, label="Max tokens"
+            ),
+        ],
+        outputs=[
+            gr.components.Textbox(
+                lines=5,
+                label="Output",
+            )
+        ],
+        title="Llama2 Playground",
+        description="https://github.com/facebookresearch/llama-recipes",
+      ).queue().launch(server_name="0.0.0.0", share=True)
 
 if __name__ == "__main__":
     fire.Fire(main)