grammar_dataset.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# For dataset details visit: https://huggingface.co/datasets/jfleg
# For download and preparation see: recipes/ft_datasets/grammar_dataset/grammar_dataset_process.ipynb

from datasets import load_dataset
from pathlib import Path

from torch.utils.data import Dataset


class grammar(Dataset):
    def __init__(
        self,
        tokenizer,
        csv_name=None,
    ):
        try:
            self.dataset = load_dataset(
                "csv",
                data_files={"train": [csv_name]},  # "eval": "grammar_validation.csv"},
                delimiter=",",
            )
        except Exception as e:
            print(
                "Loading of grammar dataset failed! Please see "
                "recipes/ft_datasets/grammar_dataset/grammar_dataset_process.ipynb "
                "for details on how to download the dataset."
            )
            raise e

        # self.dataset = load_dataset("wikihow", "all", data_dir="data/", split=type_path)
        # if num_samples:
        #     self.dataset = self.dataset.select(list(range(0, num_samples)))

        self.tokenizer = tokenizer
        self.print_text = False  # print_text

    def __len__(self):
        return self.dataset["train"].shape[0]

    def convert_to_features(self, example_batch):
        # Build the prompt and tokenize the input/target pair.
        if self.print_text:
            # Optional debug output of the raw example.
            print("Input Text: ", example_batch["input"])

        input_ = example_batch["input"]
        target_ = example_batch["target"]

        prompt = f"Correct this to standard English: {input_}\n---\nCorrected: "
        prompt_ids = self.tokenizer.encode(self.tokenizer.bos_token + prompt, add_special_tokens=False)
        label_ids = self.tokenizer.encode(target_ + self.tokenizer.eos_token, add_special_tokens=False)

        # Prompt tokens are masked with -100 so the loss is computed only on the corrected text.
        sample = {
            "input_ids": prompt_ids + label_ids,
            "attention_mask": [1] * len(prompt_ids + label_ids),
            "labels": [-100] * len(prompt_ids) + label_ids,
        }

        return sample

    def __getitem__(self, index):
        return self.convert_to_features(self.dataset["train"][int(index)])


def get_dataset(
    dataset_config, tokenizer, csv_name=None
):
    """Wrapper for loading the working dataset; dataset_config is accepted for interface consistency but is not used here."""
    if csv_name is None:
        currPath = Path.cwd() / "datasets_grammar" / "grammar_train.csv"
        print(f"Loading dataset {currPath}")
        csv_name = str(currPath)

    dataset = grammar(
        tokenizer=tokenizer,
        csv_name=csv_name,
    )

    return dataset
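A minimal usage sketch, separate from the file above: since each sample comes back as unpadded Python lists, batching needs a padding collator. The tokenizer checkpoint, the `pad_collate` helper, the pad-token fallback, and the batch size are illustrative assumptions, not part of the recipe.

# Usage sketch (illustrative; checkpoint name, pad handling, and batch size are assumptions).
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

def pad_collate(batch, pad_id):
    # Right-pad input_ids / attention_mask / labels to the longest sample in the batch.
    max_len = max(len(s["input_ids"]) for s in batch)
    def pad(seq, value):
        return seq + [value] * (max_len - len(seq))
    return {
        "input_ids": torch.tensor([pad(s["input_ids"], pad_id) for s in batch]),
        "attention_mask": torch.tensor([pad(s["attention_mask"], 0) for s in batch]),
        "labels": torch.tensor([pad(s["labels"], -100) for s in batch]),
    }

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumed checkpoint
train_ds = get_dataset(dataset_config=None, tokenizer=tokenizer,
                       csv_name="datasets_grammar/grammar_train.csv")
loader = DataLoader(train_ds, batch_size=4,
                    collate_fn=lambda b: pad_collate(b, tokenizer.pad_token_id or 0))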