123456789101112131415161718192021222324 |
- import datasets
- import re
def preprocess(text):
    """Clean a raw HellaSwag text fragment.

    Strips surrounding whitespace, turns the " [title]" marker into a
    sentence break, removes every remaining bracketed tag, and collapses the
    double spaces that those removals leave behind.

    Args:
        text: Raw string to clean.

    Returns:
        The cleaned string.
    """
    text = text.strip()
    # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
    text = text.replace(" [title]", ". ")
    # Non-greedy match so each "[...]" tag is removed on its own instead of
    # swallowing the text between two tags.
    text = re.sub(r"\[.*?\]", "", text)
    # Dropping a tag that sat between two spaces leaves a double space;
    # squeeze it back to one (a single-space -> single-space replace is a no-op).
    text = text.replace("  ", " ")
    return text
def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    """Convert each record of ``dataset`` into query/choices/gold fields.

    Args:
        dataset: Dataset whose records provide the ``ctx_a``, ``ctx_b``,
            ``activity_label``, ``endings`` and ``label`` keys.

    Returns:
        The result of ``dataset.map`` applied with the per-record converter.
    """

    def _to_example(record):
        # Context is the first half followed by the capitalized second half.
        second_half = record["ctx_b"].capitalize()
        context = record["ctx_a"] + " " + second_half

        prompt = record["activity_label"] + ": " + context
        cleaned_endings = list(map(preprocess, record["endings"]))

        return {
            "query": preprocess(prompt),
            "choices": cleaned_endings,
            "gold": int(record["label"]),
        }

    return dataset.map(_to_example)
|