Browse Source

Add gradio library for user interface in inference.py

Joone Hur 1 year ago
parent
commit
83efccb22b
1 changed files with 45 additions and 12 deletions
  1. 45 12
      examples/inference.py

+ 45 - 12
examples/inference.py

@@ -7,6 +7,7 @@ import fire
 import os
 import sys
 import time
+import gradio as gr
 
 import torch
 from transformers import LlamaTokenizer
@@ -39,18 +40,8 @@ def main(
     use_fast_kernels: bool = False, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
     **kwargs
 ):
-    if prompt_file is not None:
-        assert os.path.exists(
-            prompt_file
-        ), f"Provided Prompt file does not exist {prompt_file}"
-        with open(prompt_file, "r") as f:
-            user_prompt = "\n".join(f.readlines())
-    elif not sys.stdin.isatty():
-        user_prompt = "\n".join(sys.stdin.readlines())
-    else:
-        print("No user prompt provided. Exiting.")
-        sys.exit(1)
 
+  def inference(user_prompt, temperature, top_p, top_k, max_new_tokens, **kwargs,):
     safety_checker = get_safety_checker(enable_azure_content_safety,
                                         enable_sensitive_topics,
                                         enable_salesforce_content_safety,
@@ -126,7 +117,49 @@ def main(
             if not is_safe:
                 print(method)
                 print(report)
-                
+    return output_text
+
+  if prompt_file is not None:
+      assert os.path.exists(
+          prompt_file
+      ), f"Provided Prompt file does not exist {prompt_file}"
+      with open(prompt_file, "r") as f:
+          user_prompt = "\n".join(f.readlines())
+      inference(user_prompt, temperature, top_p, top_k, max_new_tokens)
+  elif not sys.stdin.isatty():
+      user_prompt = "\n".join(sys.stdin.readlines())
+      inference(user_prompt, temperature, top_p, top_k, max_new_tokens)
+  else:
+      gr.Interface(
+        fn=inference,
+        inputs=[
+            gr.components.Textbox(
+                lines=9,
+                label="User Prompt",
+                placeholder="none",
+            ),
+            gr.components.Slider(
+                minimum=0, maximum=1, value=1.0, label="Temperature"
+            ),
+            gr.components.Slider(
+                minimum=0, maximum=1, value=1.0, label="Top p"
+            ),
+            gr.components.Slider(
+                minimum=0, maximum=100, step=1, value=50, label="Top k"
+            ),
+            gr.components.Slider(
+                minimum=1, maximum=2000, step=1, value=200, label="Max tokens"
+            ),
+        ],
+        outputs=[
+            gr.components.Textbox(
+                lines=5,
+                label="Output",
+            )
+        ],
+        title="Llama2 Playground",
+        description="https://github.com/facebookresearch/llama-recipes",
+      ).queue().launch(server_name="0.0.0.0", share=True)
 
 if __name__ == "__main__":
     fire.Fire(main)