
fixing the output format

Hamid Shojanazeri, 1 year ago
commit 2d9f4796e8

+ 3 - 3
inference/code-llama/code_infilling_example.py

@@ -42,13 +42,13 @@ def main(
             prompt_file
         ), f"Provided Prompt file does not exist {prompt_file}"
         with open(prompt_file, "r") as f:
-            user_prompt = "\n".join(f.readlines())
+            user_prompt = f.read()
     elif not sys.stdin.isatty():
         user_prompt = "\n".join(sys.stdin.readlines())
     else:
         print("No user prompt provided. Exiting.")
         sys.exit(1)
-    
+
     # Set the seeds for reproducibility
     torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
@@ -121,7 +121,7 @@ def main(
         )
     e2e_inference_time = (time.perf_counter()-start)*1000
     print(f"the inference time is {e2e_inference_time} ms")
-    filling = tokenizer.decode(outputs[0], skip_special_tokens=True)    
+    filling = tokenizer.batch_decode(outputs[:, batch["input_ids"].shape[1]:], skip_special_tokens=True)[0]
     # Safety check of the model output
     safety_results = [check(filling) for check in safety_checker]
     are_safe = all([r[1] for r in safety_results])
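
The two hunks above fix separate problems. First, `"\n".join(f.readlines())` doubles every line break, because `readlines()` already keeps each line's trailing newline; `f.read()` returns the file contents verbatim. Second, `generate()` for causal LMs returns the prompt tokens followed by the newly generated tokens, so decoding `outputs[0]` echoes the prompt back into the output; slicing past `batch["input_ids"].shape[1]` decodes only the generated filling. A minimal sketch of both behaviors (the token values below are fabricated for illustration, not real model output):

import torch

# readlines() keeps each line's trailing "\n", so joining with "\n"
# injects a blank line between every pair of lines:
lines = ["def f():\n", "    pass\n"]
assert "\n".join(lines) == "def f():\n\n    pass\n"

# generate() output begins with the prompt tokens; slice them off
# before decoding so only the model's filling remains:
prompt_len = 5                               # batch["input_ids"].shape[1]
outputs = torch.arange(16).reshape(2, 8)     # fake (batch, seq_len) ids
generated_only = outputs[:, prompt_len:]     # shape (2, 3): new tokens only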

+ 1 - 2
inference/code-llama/code_infilling_prompt.txt

@@ -1,4 +1,3 @@
-'''def remove_non_ascii(s: str) -> str:
+def remove_non_ascii(s: str) -> str:
     """ <FILL_ME>
     return result
-'''
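
This prompt-file fix matters because the script feeds the file to the model verbatim: with the stray ''' wrappers, the model saw literal quote characters around the code instead of the raw function containing the <FILL_ME> sentinel. A minimal sketch of how the cleaned prompt is consumed, assuming the Hugging Face CodeLlama tokenizer (the checkpoint name is illustrative):

from transformers import AutoTokenizer

# The CodeLlama tokenizer splits raw code on <FILL_ME> into prefix and
# suffix segments for infilling; wrapper quotes would become part of the
# prefix and skew the completion.
tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-hf")
prompt = (
    'def remove_non_ascii(s: str) -> str:\n'
    '    """ <FILL_ME>\n'
    '    return result'
)
batch = tokenizer(prompt, return_tensors="pt")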