alpaca_dataset.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# For dataset details visit: https://crfm.stanford.edu/2023/03/13/alpaca.html

import copy
import json

import torch
from torch.utils.data import Dataset
PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}
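
# For a hypothetical record {"instruction": "Name three primary colors.", "input": ""},
# the "prompt_no_input" template above renders as:
#
#   Below is an instruction that describes a task. Write a response that appropriately completes the request.
#
#   ### Instruction:
#   Name three primary colors.
#
#   ### Response: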
class InstructionDataset(Dataset):
    def __init__(self, dataset_config, tokenizer, partition="train", max_words=30):
        self.ann = json.load(open(dataset_config.data_path))
        # Any non-train partition uses a fixed 200-example slice for evaluation.
        if partition != "train":
            self.ann = self.ann[:200]

        self.max_words = max_words
        self.tokenizer = tokenizer
    def __len__(self):
        return len(self.ann)
    def __getitem__(self, index):
        IGNORE_INDEX = -100  # The default ignore_index in CrossEntropyLoss

        ann = self.ann[index]
        if ann.get("input", "") == "":
            prompt = PROMPT_DICT["prompt_no_input"].format_map(ann)
        else:
            prompt = PROMPT_DICT["prompt_input"].format_map(ann)
        example = prompt + ann["output"]
        prompt = torch.tensor(
            self.tokenizer.encode(prompt), dtype=torch.int64
        )
        example = self.tokenizer.encode(example)
        example.append(self.tokenizer.eos_token_id)
        example = torch.tensor(
            example, dtype=torch.int64
        )

        # Pad with a -1 sentinel up to max_words, or truncate; the sentinel
        # marks padding positions until the masks below are built.
        padding = self.max_words - example.shape[0]
        if padding > 0:
            example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1))
        elif padding < 0:
            example = example[: self.max_words]

        # Mask out the prompt tokens in the labels so the loss is computed only
        # on the response, then replace the sentinels: padding becomes token id
        # 0 in the inputs and IGNORE_INDEX in the labels.
        labels = copy.deepcopy(example)
        labels[: len(prompt)] = -1
        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        example[~example_mask] = 0
        labels[~label_mask] = IGNORE_INDEX
        example_mask = example_mask.float()

        return {
            "input_ids": example,
            "labels": labels,
            "attention_mask": example_mask,
        }
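

# Minimal usage sketch. The config class, data path, and tokenizer below are
# assumptions for illustration, not part of this module: any tokenizer exposing
# `encode()` and `eos_token_id` works, and `data_path` should point at an
# Alpaca-format JSON file.
if __name__ == "__main__":
    from dataclasses import dataclass

    from transformers import LlamaTokenizer  # assumed dependency, not imported above

    @dataclass
    class DatasetConfig:
        data_path: str = "alpaca_data.json"  # hypothetical path

    tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    # A larger max_words than the default 30 is typical for Alpaca examples.
    train_ds = InstructionDataset(DatasetConfig(), tokenizer, partition="train", max_words=224)
    sample = train_ds[0]
    print(sample["input_ids"].shape, sample["attention_mask"].sum())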