@@ -31,7 +31,7 @@ def main(
temperature: float=1.0, # [optional] The value used to modulate the next token probabilities.
top_k: int=50, # [optional] The number of highest probability vocabulary tokens to keep for top-k-filtering.
repetition_penalty: float=1.0, #The parameter for repetition penalty. 1.0 means no penalty.
- length_penalty: int=1, #[optional] Exponential penalty to the length that is used with beam-based generation.
+ length_penalty: int=1, #[optional] Exponential penalty to the length that is used with beam-based generation.
enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5
@@ -98,12 +98,12 @@ def main(
top_k=top_k,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
- **kwargs
+ **kwargs
)
e2e_inference_time = (time.perf_counter()-start)*1000
print(f"the inference time is {e2e_inference_time} ms")
output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
+
# Safety check of the model output
safety_results = [check(output_text, agent_type=AgentType.AGENT, user_prompt=user_prompt) for check in safety_checker]
are_safe = all([r[1] for r in safety_results])
@@ -156,7 +156,7 @@ def main(
label="Output",
)
],
- title="Llama2 Playground",
+ title="Meta Llama3 Playground",
description="https://github.com/facebookresearch/llama-recipes",
).queue().launch(server_name="0.0.0.0", share=True)
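For context, the hunks above touch the sampling parameters that this script forwards to the Hugging Face generate call before decoding and safety-checking the output. Below is a minimal sketch of that flow, assuming a standard transformers causal LM and tokenizer; the checkpoint name and prompt are placeholders, not taken from the patch.

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder checkpoint, not from the patch
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

# Tokenize a sample prompt and move the tensors to the model's device
batch = tokenizer("Tell me about llamas.", return_tensors="pt").to(model.device)

start = time.perf_counter()
with torch.no_grad():
    outputs = model.generate(
        **batch,
        max_new_tokens=100,
        do_sample=True,
        temperature=1.0,         # modulates next-token probabilities
        top_k=50,                # keep the 50 highest-probability tokens
        repetition_penalty=1.0,  # 1.0 means no penalty
        length_penalty=1,        # only used with beam-based generation
    )
e2e_inference_time = (time.perf_counter() - start) * 1000
print(f"the inference time is {e2e_inference_time} ms")

output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(output_text)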