@@ -7,6 +7,7 @@ import fire
 import os
 import sys
 import time
+import gradio as gr
 
 import torch
 from transformers import LlamaTokenizer
@@ -39,18 +40,9 @@ def main(
     use_fast_kernels: bool = False, # Enable using SDPA from PyTorch Accelerated Transformers, making use of Flash Attention and Xformers memory-efficient kernels
     **kwargs
 ):
-    if prompt_file is not None:
-        assert os.path.exists(
-            prompt_file
-        ), f"Provided Prompt file does not exist {prompt_file}"
-        with open(prompt_file, "r") as f:
-            user_prompt = "\n".join(f.readlines())
-    elif not sys.stdin.isatty():
-        user_prompt = "\n".join(sys.stdin.readlines())
-    else:
-        print("No user prompt provided. Exiting.")
-        sys.exit(1)
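+    # Generation logic moves into a nested function so the file, stdin, and Gradio paths below can share it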
+    def inference(user_prompt, temperature, top_p, top_k, max_new_tokens, **kwargs):
         safety_checker = get_safety_checker(enable_azure_content_safety,
                                             enable_sensitive_topics,
                                             enable_salesforce_content_safety,
                                             )
@@ -126,7 +118,52 @@
             if not is_safe:
                 print(method)
                 print(report)
-
+        return output_text
+
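+    # Prompt sources, in priority order: prompt file, then piped stdin, otherwise the Gradio UI below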
+    if prompt_file is not None:
+        assert os.path.exists(
+            prompt_file
+        ), f"Provided Prompt file does not exist {prompt_file}"
+        with open(prompt_file, "r") as f:
+            user_prompt = "\n".join(f.readlines())
+        inference(user_prompt, temperature, top_p, top_k, max_new_tokens)
+    elif not sys.stdin.isatty():
+        user_prompt = "\n".join(sys.stdin.readlines())
+        inference(user_prompt, temperature, top_p, top_k, max_new_tokens)
+    else:
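+        # No prompt supplied: launch an interactive playground instead of exiting with an error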
+        gr.Interface(
+            fn=inference,
+            inputs=[
+                gr.components.Textbox(
+                    lines=9,
+                    label="User Prompt",
+                    placeholder="none",
+                ),
+                gr.components.Slider(
+                    minimum=0, maximum=1, value=1.0, label="Temperature"
+                ),
+                gr.components.Slider(
+                    minimum=0, maximum=1, value=1.0, label="Top p"
+                ),
+                gr.components.Slider(
+                    minimum=0, maximum=100, step=1, value=50, label="Top k"
+                ),
+                gr.components.Slider(
+                    minimum=1, maximum=2000, step=1, value=200, label="Max tokens"
+                ),
+            ],
+            outputs=[
+                gr.components.Textbox(
+                    lines=5,
+                    label="Output",
+                )
+            ],
+            title="Llama2 Playground",
+            description="https://github.com/facebookresearch/llama-recipes",
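+            # queue() enables request queuing; share=True publishes a temporary public gradio.live URL and server_name="0.0.0.0" binds all interfaces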
+        ).queue().launch(server_name="0.0.0.0", share=True)
 
 if __name__ == "__main__":
     fire.Fire(main)