|
@@ -31,7 +31,7 @@ def main(
|
|
|
temperature: float=1.0,
|
|
|
top_k: int=50,
|
|
|
repetition_penalty: float=1.0,
|
|
|
- length_penalty: int=1,
|
|
|
+ length_penalty: int=1,
|
|
|
enable_azure_content_safety: bool=False,
|
|
|
enable_sensitive_topics: bool=False,
|
|
|
enable_salesforce_content_safety: bool=True,
|
|
@@ -98,12 +98,12 @@ def main(
|
|
|
top_k=top_k,
|
|
|
repetition_penalty=repetition_penalty,
|
|
|
length_penalty=length_penalty,
|
|
|
- **kwargs
|
|
|
+ **kwargs
|
|
|
)
|
|
|
e2e_inference_time = (time.perf_counter()-start)*1000
|
|
|
print(f"the inference time is {e2e_inference_time} ms")
|
|
|
output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
-
|
|
|
+
|
|
|
|
|
|
safety_results = [check(output_text, agent_type=AgentType.AGENT, user_prompt=user_prompt) for check in safety_checker]
|
|
|
are_safe = all([r[1] for r in safety_results])
|
|
@@ -156,7 +156,7 @@ def main(
|
|
|
label="Output",
|
|
|
)
|
|
|
],
|
|
|
- title="Llama2 Playground",
|
|
|
+ title="Meta Llama3 Playground",
|
|
|
description="https://github.com/facebookresearch/llama-recipes",
|
|
|
).queue().launch(server_name="0.0.0.0", share=True)
|
|
|
|