|
@@ -42,13 +42,13 @@ def main(
|
|
|
prompt_file
|
|
|
), f"Provided Prompt file does not exist {prompt_file}"
|
|
|
with open(prompt_file, "r") as f:
|
|
|
- user_prompt = "\n".join(f.readlines())
|
|
|
+ user_prompt = f.read()
|
|
|
elif not sys.stdin.isatty():
|
|
|
user_prompt = "\n".join(sys.stdin.readlines())
|
|
|
else:
|
|
|
print("No user prompt provided. Exiting.")
|
|
|
sys.exit(1)
|
|
|
-
|
|
|
+
|
|
|
# Set the seeds for reproducibility
|
|
|
torch.cuda.manual_seed(seed)
|
|
|
torch.manual_seed(seed)
|
|
@@ -121,7 +121,7 @@ def main(
|
|
|
)
|
|
|
e2e_inference_time = (time.perf_counter()-start)*1000
|
|
|
print(f"the inference time is {e2e_inference_time} ms")
|
|
|
- filling = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
+ filling = tokenizer.batch_decode(outputs[:, batch["input_ids"].shape[1]:], skip_special_tokens=True)[0]
|
|
|
# Safety check of the model output
|
|
|
safety_results = [check(filling) for check in safety_checker]
|
|
|
are_safe = all([r[1] for r in safety_results])
|