alpaca_dataset.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# For dataset details visit: https://crfm.stanford.edu/2023/03/13/alpaca.html

import copy
import json

import torch
from torch.utils.data import Dataset

PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}
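
# For illustration, a hypothetical Alpaca-style record (made up, not taken from
# the real dataset) and the prompt it yields:
#   {"instruction": "Translate to French.", "input": "Hello", "output": "Bonjour"}
# has a non-empty "input", so "prompt_input" applies and format_map() produces:
#   Below is an instruction that describes a task, paired with an input that
#   provides further context. Write a response that appropriately completes
#   the request.
#
#   ### Instruction:
#   Translate to French.
#
#   ### Input:
#   Hello
#
#   ### Response:
# The training example is this prompt with the output "Bonjour" appended.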

class InstructionDataset(Dataset):
    def __init__(self, dataset_config, tokenizer, partition="train", max_words=30):
        self.ann = json.load(open(dataset_config.data_path))
        # Use the full annotation list for training; cap evaluation at the
        # first 200 examples.
        if partition != "train":
            self.ann = self.ann[:200]

        self.max_words = max_words
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        IGNORE_INDEX = -100  # the default ignore_index of torch.nn.CrossEntropyLoss

        ann = self.ann[index]
        if ann.get("input", "") == "":
            prompt = PROMPT_DICT["prompt_no_input"].format_map(ann)
        else:
            prompt = PROMPT_DICT["prompt_input"].format_map(ann)
        example = prompt + ann["output"]
        prompt = torch.tensor(
            self.tokenizer.encode(prompt), dtype=torch.int64
        )
        example = self.tokenizer.encode(example)
        example.append(self.tokenizer.eos_token_id)
        example = torch.tensor(
            example, dtype=torch.int64
        )
        # Pad to max_words with -1 as a temporary sentinel, or truncate if the
        # sequence is too long.
        padding = self.max_words - example.shape[0]
        if padding > 0:
            example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1))
        elif padding < 0:
            example = example[: self.max_words]
        # Labels start as a copy of the input ids; the prompt portion is masked
        # with the same -1 sentinel so the loss covers only the response tokens.
        labels = copy.deepcopy(example)
        labels[: len(prompt)] = -1
        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        # Replace the sentinels: 0 for input_ids, IGNORE_INDEX for labels.
        example[~example_mask] = 0
        labels[~label_mask] = IGNORE_INDEX
        example_mask = example_mask.float()
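        # Worked example with hypothetical sizes: with max_words=16, a prompt of
        # 10 tokens, and a full sequence of 13 tokens (prompt + output + EOS),
        # padding is 16 - 13 = 3, so positions 13..15 hold the -1 sentinel.
        # After the masking above: input_ids[13:] == 0, attention_mask[13:] == 0.0,
        # labels[:10] == -100 and labels[13:] == -100, so the loss is computed
        # on labels[10:13] (the response tokens and EOS) only.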
        return {
            "input_ids": example,
            "labels": labels,
            "attention_mask": example_mask,
        }
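
# A minimal smoke test for the class above; everything in this block is
# illustrative. The whitespace "tokenizer" stub stands in for a real tokenizer
# (e.g. a SentencePiece or Hugging Face one exposing encode() and eos_token_id),
# and the temporary JSON file stands in for the Alpaca data file referenced by
# dataset_config.data_path.
if __name__ == "__main__":
    import tempfile
    from types import SimpleNamespace

    class _ToyTokenizer:
        """Maps each whitespace-separated token to a small dummy id."""
        eos_token_id = 2

        def encode(self, text):
            return [len(tok) % 29 + 3 for tok in text.split()]

    records = [
        {"instruction": "Say hello.", "input": "", "output": "Hello!"},
        {"instruction": "Translate to French.", "input": "cat", "output": "chat"},
    ]
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump(records, f)
        data_path = f.name

    config = SimpleNamespace(data_path=data_path)
    dataset = InstructionDataset(config, _ToyTokenizer(), partition="train", max_words=64)
    item = dataset[0]
    # Every tensor is padded/truncated to max_words.
    print({key: tuple(value.shape) for key, value in item.items()})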