hellaswag_utils.py 881 B

123456789101112131415161718192021222324252627
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import datasets
  4. import re
  5. def preprocess(text):
  6. text = text.strip()
  7. # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
  8. text = text.replace(" [title]", ". ")
  9. text = re.sub("\\[.*?\\]", "", text)
  10. text = text.replace(" ", " ")
  11. return text
  12. def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
  13. def _process_doc(doc):
  14. ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
  15. out_doc = {
  16. "query": preprocess(doc["activity_label"] + ": " + ctx),
  17. "choices": [preprocess(ending) for ending in doc["endings"]],
  18. "gold": int(doc["label"]),
  19. }
  20. return out_doc
  21. return dataset.map(_process_doc)