# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# For dataset details visit: https://huggingface.co/datasets/jfleg
# For download and preparation see: recipes/ft_datasets/grammar_dataset/grammar_dataset_process.ipynb

from datasets import load_dataset
from pathlib import Path

from torch.utils.data import Dataset


class grammar(Dataset):
    def __init__(
        self,
        tokenizer,
        csv_name=None,
    ):
        try:
            self.dataset = load_dataset(
                "csv",
                data_files={"train": [csv_name]},  # "eval": "grammar_validation.csv"
                delimiter=",",
            )
        except Exception as e:
            print(
                "Loading of grammar dataset failed! Please see "
                "recipes/ft_datasets/grammar_dataset/grammar_dataset_process.ipynb "
                "for details on how to download the dataset."
            )
            raise e

        self.tokenizer = tokenizer
        self.print_text = False  # set to True to print each input as it is tokenized

    def __len__(self):
        return self.dataset["train"].shape[0]

    def convert_to_features(self, example_batch):
        # Build the correction prompt and tokenize the prompt and the target separately
        if self.print_text:
            print("Input Text: ", example_batch["input"])

        input_ = example_batch["input"]
        target_ = example_batch["target"]

        prompt = f"Correct this to standard English: {input_}\n---\nCorrected: "
        prompt_ids = self.tokenizer.encode(self.tokenizer.bos_token + prompt, add_special_tokens=False)
        label_ids = self.tokenizer.encode(target_ + self.tokenizer.eos_token, add_special_tokens=False)

        # Mask the prompt tokens with -100 so the loss is computed only on the corrected text
        sample = {
            "input_ids": prompt_ids + label_ids,
            "attention_mask": [1] * len(prompt_ids + label_ids),
            "labels": [-100] * len(prompt_ids) + label_ids,
        }

        return sample

    def __getitem__(self, index):
        return self.convert_to_features(self.dataset["train"][int(index)])


def get_dataset(
    dataset_config, tokenizer, csv_name=None
):
    """Wrapper that resolves the training CSV path and returns the tokenized grammar dataset."""
    if csv_name is None:
        currPath = Path.cwd() / "datasets_grammar" / "grammar_train.csv"
        print(f"Loading dataset {currPath}")
        csv_name = str(currPath)

    dataset = grammar(
        tokenizer=tokenizer,
        csv_name=csv_name,
    )

    return dataset