alpaca_dataset.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# For dataset details visit: https://crfm.stanford.edu/2023/03/13/alpaca.html

import copy
import json

import torch
from torch.utils.data import Dataset

PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}
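
# A quick sketch of how these templates render (the record below is
# illustrative, not from the dataset file): str.format_map fills the
# {instruction} and {input} placeholders straight from the annotation dict.
#
#   ann = {"instruction": "Translate to French.", "input": "cheese", "output": "fromage"}
#   PROMPT_DICT["prompt_input"].format_map(ann)
#   # -> "Below is an instruction that describes a task, ... \n\n"
#   #    "### Instruction:\nTranslate to French.\n\n### Input:\ncheese\n\n### Response:"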

class InstructionDataset(Dataset):
    def __init__(self, dataset_config, tokenizer, partition="train"):
        with open(dataset_config.data_path) as f:
            self.ann = json.load(f)
        # Hold out the first 200 examples for evaluation; train on the rest.
        if partition == "train":
            self.ann = self.ann[200:]
        else:
            self.ann = self.ann[:200]

        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        IGNORE_INDEX = -100  # The default ignore_index in torch.nn.CrossEntropyLoss

        ann = self.ann[index]
        # Pick the template that matches whether the record carries an input field.
        if ann.get("input", "") == "":
            prompt = PROMPT_DICT["prompt_no_input"].format_map(ann)
        else:
            prompt = PROMPT_DICT["prompt_input"].format_map(ann)
        example = prompt + ann["output"]
        prompt = torch.tensor(
            self.tokenizer.encode(prompt), dtype=torch.int64
        )
        example = self.tokenizer.encode(example)
        example.append(self.tokenizer.eos_token_id)
        example = torch.tensor(
            example, dtype=torch.int64
        )
        # Mask the prompt tokens out of the labels so the loss is computed only
        # on the response: first mark them with -1, then replace every negative
        # position with IGNORE_INDEX.
        labels = copy.deepcopy(example)
        labels[: len(prompt)] = -1
        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        example[~example_mask] = 0
        labels[~label_mask] = IGNORE_INDEX
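
        # Toy illustration of the masking above (values are made up): with a
        # 3-token prompt inside a 5-token example,
        #   example = [12, 34, 56, 78, 90]        # eos already appended
        #   labels  = [-1, -1, -1, 78, 90]        # after labels[:3] = -1
        #   labels  = [-100, -100, -100, 78, 90]  # after the IGNORE_INDEX fill
        # example_mask is all-True because real token ids are >= 0, so the
        # attention_mask returned below is all ones.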
        return {
            "input_ids": example.tolist(),
            "labels": labels.tolist(),
            "attention_mask": example_mask.tolist(),
        }
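
# Minimal usage sketch (not part of the original file). The config dataclass,
# the toy tokenizer, and the temp-file dataset below are all illustrative
# stand-ins for the real llama-recipes dataset config and LLaMA tokenizer,
# which exposes the same encode() / eos_token_id surface used above.
if __name__ == "__main__":
    import tempfile
    from dataclasses import dataclass

    # Write a tiny throwaway JSON dataset so the demo runs without the real
    # Alpaca file; 201 records so the train split (rows 200+) is non-empty.
    records = [
        {"instruction": "Translate to French.", "input": "cheese", "output": "fromage"}
    ] * 201
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump(records, f)
        data_path = f.name

    @dataclass
    class DatasetConfig:
        data_path: str

    class ToyTokenizer:
        # Stand-in tokenizer: one id per character, arbitrary eos id.
        eos_token_id = 2

        def encode(self, text):
            return [ord(c) % 30000 for c in text]

    dataset = InstructionDataset(DatasetConfig(data_path), ToyTokenizer(), partition="train")
    sample = dataset[0]
    print(len(sample["input_ids"]), sample["labels"][:12], sample["attention_mask"][:5])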