grammar_dataset.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# For dataset details visit: https://huggingface.co/datasets/jfleg
# For download and preparation see: recipes/ft_datasets/grammar_dataset/grammar_dataset_process.ipynb

import argparse
import csv
import glob
import os
import json
import time
import logging
import random
import re
from itertools import chain
from pathlib import Path
from string import punctuation

import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset
from datasets import load_dataset

from ft_datasets.utils import ConcatDataset
class grammar(Dataset):
    """Grammar-correction pairs (JFLEG-style) loaded from a CSV; each item is a
    tokenized "Correct this to standard English" prompt for causal-LM fine-tuning."""

    def __init__(
        self,
        tokenizer,
        csv_name=None,
    ):
        try:
            self.dataset = load_dataset(
                "csv",
                data_files={"train": [csv_name]},  # "eval": "grammar_validation.csv"},
                delimiter=",",
            )
        except Exception as e:
            print(
                "Loading of grammar dataset failed! Please see "
                "recipes/ft_datasets/grammar_dataset/grammar_dataset_process.ipynb "
                "for details on how to download the dataset."
            )
            raise e

        # self.dataset = load_dataset("wikihow", "all", data_dir="data/", split=type_path)
        # if num_samples:
        #     self.dataset = self.dataset.select(list(range(0, num_samples)))
        self.tokenizer = tokenizer
        self.print_text = False  # print_text

    def __len__(self):
        return self.dataset["train"].shape[0]

    def convert_to_features(self, example_batch):
        # Build the prompt from the (input, target) pair and tokenize it
        if self.print_text:
            print("Input Text: ", example_batch["input"])

        input_ = example_batch["input"]
        target_ = example_batch["target"]

        prompt = f"Correct this to standard English: {input_}\n---\nCorrected: {target_}"
        sample = self.tokenizer(prompt)

        return sample

    def __getitem__(self, index):
        sample = self.convert_to_features(self.dataset["train"][index])
        source_ids = sample["input_ids"]
        src_mask = sample["attention_mask"]

        return {
            "input_ids": source_ids,
            "attention_mask": src_mask,
            "labels": source_ids.copy(),
        }
def get_dataset(dataset_config, tokenizer, csv_name=None):
    """Cover function for loading the working dataset."""
    if csv_name is None:
        currPath = Path.cwd() / "datasets_grammar" / "grammar_train.csv"
        print(f"Loading dataset {currPath}")
        csv_name = str(currPath)

    dataset = grammar(
        tokenizer=tokenizer,
        csv_name=csv_name,
    )

    # Pack the tokenized samples into fixed-length chunks for training.
    return ConcatDataset(dataset, chunk_size=dataset_config.input_length)
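
# A minimal usage sketch (not part of the original module). It assumes the
# `transformers` package is installed, that a Llama tokenizer is available
# (the model id below is only an example), that grammar_train.csv was prepared
# with the notebook referenced above, and that the config object only needs to
# expose the `input_length` attribute read by get_dataset.
if __name__ == "__main__":
    from dataclasses import dataclass

    from transformers import AutoTokenizer

    @dataclass
    class GrammarConfig:
        # Hypothetical stand-in for the recipe's dataset config; only
        # input_length is used by get_dataset above.
        input_length: int = 512

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumed model id
    train_dataset = get_dataset(
        GrammarConfig(), tokenizer, csv_name="datasets_grammar/grammar_train.csv"
    )

    # ConcatDataset packs the tokenized prompts into chunks of input_length tokens.
    print(f"Number of packed chunks: {len(train_dataset)}")
    print(train_dataset[0]["input_ids"][:10])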