hellaswag_utils.py 716 B

123456789101112131415161718192021222324
  1. import datasets
  2. import re
  3. def preprocess(text):
  4. text = text.strip()
  5. # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
  6. text = text.replace(" [title]", ". ")
  7. text = re.sub("\\[.*?\\]", "", text)
  8. text = text.replace(" ", " ")
  9. return text
  10. def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
  11. def _process_doc(doc):
  12. ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
  13. out_doc = {
  14. "query": preprocess(doc["activity_label"] + ": " + ctx),
  15. "choices": [preprocess(ending) for ending in doc["endings"]],
  16. "gold": int(doc["label"]),
  17. }
  18. return out_doc
  19. return dataset.map(_process_doc)