alpaca_dataset.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# For dataset details visit: https://crfm.stanford.edu/2023/03/13/alpaca.html

import copy
import json

import torch
from torch.utils.data import Dataset

PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}
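
# For illustration, an assumed toy record
# {"instruction": "Name the capital of France.", "output": "Paris."}
# renders "prompt_no_input" as:
#
#   Below is an instruction that describes a task. Write a response that appropriately completes the request.
#
#   ### Instruction:
#   Name the capital of France.
#
#   ### Response: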


class InstructionDataset(Dataset):
    def __init__(self, dataset_config, tokenizer, partition="train", max_words=30):
        self.ann = json.load(open(dataset_config.data_path))
        # Train on the full annotation list; any other partition uses the
        # first 200 records as a held-out set.
        if partition != "train":
            self.ann = self.ann[:200]
        self.max_words = max_words
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        ann = self.ann[index]
        # Pick the template based on whether the record carries an "input" field.
        if ann.get("input", "") == "":
            prompt = PROMPT_DICT["prompt_no_input"].format_map(ann)
        else:
            prompt = PROMPT_DICT["prompt_input"].format_map(ann)
        example = prompt + ann["output"]
        prompt = torch.tensor(self.tokenizer.encode(prompt), dtype=torch.int64)
        example = self.tokenizer.encode(example)
        example.append(self.tokenizer.eos_token_id)
        example = torch.tensor(example, dtype=torch.int64)
        # Pad with -1 up to max_words, or truncate; -1 is a sentinel for
        # positions that get masked out below.
        padding = self.max_words - example.shape[0]
        if padding > 0:
            example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1))
        elif padding < 0:
            example = example[: self.max_words]
        # The loss should cover only the response, so mask the prompt tokens
        # in the labels with -1 as well.
        labels = copy.deepcopy(example)
        labels[: len(prompt)] = -1
        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        # Replace the -1 sentinels with token id 0; real positions are
        # exposed through the attention mask.
        example[~example_mask] = 0
        labels[~label_mask] = 0
        example_mask = example_mask.float()

        return {
            "input_ids": example,
            "labels": labels,
            "attention_mask": example_mask,
        }
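

Below is a minimal usage sketch, not part of the original file. It assumes the tokenizer is a Hugging Face LlamaTokenizer (which provides the encode() and eos_token_id the class relies on); AlpacaConfig, the checkpoint path, and the batch size are illustrative placeholders.

    from dataclasses import dataclass

    from torch.utils.data import DataLoader
    from transformers import LlamaTokenizer

    @dataclass
    class AlpacaConfig:
        data_path: str = "alpaca_data.json"  # placeholder path to the Alpaca JSON

    tokenizer = LlamaTokenizer.from_pretrained("path/to/llama-2-7b")  # placeholder checkpoint
    dataset = InstructionDataset(AlpacaConfig(), tokenizer, partition="train", max_words=512)
    loader = DataLoader(dataset, batch_size=2)

    batch = next(iter(loader))
    print(batch["input_ids"].shape)       # torch.Size([2, 512])
    print(batch["attention_mask"].shape)  # torch.Size([2, 512])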