1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677 |
- import copy
- import json
- import os
- import torch
- from sentencepiece import SentencePieceProcessor
- from torch.utils.data import Dataset
- from typing import List
# Alpaca-style prompt templates, rendered with ``str.format_map`` on a record
# dict. "prompt_input" needs both "{instruction}" and "{input}" keys;
# "prompt_no_input" is for records without an input and only needs
# "{instruction}". Both end at "### Response:" so the target text can be
# appended directly after the rendered prompt.
PROMPT_DICT = dict(
    prompt_input=(
        "Below is an instruction that describes a task, paired with an input "
        "that provides further context. Write a response that appropriately "
        "completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n"
        "### Input:\n{input}\n\n"
        "### Response:"
    ),
    prompt_no_input=(
        "Below is an instruction that describes a task. Write a response "
        "that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n"
        "### Response:"
    ),
)
class InstructionDataset(Dataset):
    """Map-style dataset of instruction-tuning records rendered with ``PROMPT_DICT``.

    The JSON file at ``dataset_config.data_path`` is expected to hold a list of
    record dicts with an ``"instruction"`` key, an ``"output"`` key and an
    optional ``"input"`` key. The first ``EVAL_SIZE`` records form the
    evaluation split and the training split is everything after them, so the
    two partitions never overlap.
    """

    # Label value excluded from the loss; -100 is the default ``ignore_index``
    # of ``torch.nn.CrossEntropyLoss``.
    IGNORE_INDEX = -100
    # Number of leading records reserved for the non-"train" partition.
    EVAL_SIZE = 200

    def __init__(self, dataset_config, tokenizer, partition: str = "train", max_words: int = 30):
        """Load the records for one partition.

        Args:
            dataset_config: Object whose ``data_path`` attribute names the JSON file.
            tokenizer: Object providing ``encode(text)`` (a sequence of token ids)
                and an ``eos_token_id`` attribute.
            partition: ``"train"`` selects the training split; any other value
                selects the first ``EVAL_SIZE`` records for evaluation.
            max_words: Fixed length, in tokens, that every example is padded or
                truncated to.
        """
        # The context manager closes the handle; explicit UTF-8 avoids relying on
        # the platform's default text encoding.
        with open(dataset_config.data_path, encoding="utf-8") as f:
            records = json.load(f)
        if partition == "train":
            # Skip the evaluation records so they are never trained on.
            self.ann = records[self.EVAL_SIZE:]
        else:
            self.ann = records[: self.EVAL_SIZE]
        self.max_words = max_words
        self.tokenizer = tokenizer

    def __len__(self) -> int:
        """Return the number of records in this partition."""
        return len(self.ann)

    def __getitem__(self, index: int) -> dict:
        """Tokenize one record into fixed-length model inputs.

        The sequence is ``prompt + output + EOS``, truncated to ``max_words``
        (truncation can drop the EOS token) and right-padded to ``max_words``.

        Returns:
            A dict with:
              * ``"input_ids"``: int64 tensor of length ``max_words``; padding is 0.
              * ``"labels"``: int64 copy of ``input_ids`` in which the prompt and
                padding positions are ``IGNORE_INDEX``, so only the response
                (and EOS) is supervised.
              * ``"attention_mask"``: float32 tensor, 1.0 on real tokens and
                0.0 on padding.
        """
        ann = self.ann[index]
        # Missing, empty, or null "input" all select the shorter template;
        # otherwise a null input would be rendered literally as "None".
        template = PROMPT_DICT["prompt_input" if ann.get("input") else "prompt_no_input"]
        prompt = template.format_map(ann)
        # Only the prompt's token count is needed, to mask it out of the labels.
        prompt_len = len(self.tokenizer.encode(prompt))

        ids = list(self.tokenizer.encode(prompt + ann["output"]))
        ids.append(self.tokenizer.eos_token_id)
        ids = ids[: self.max_words]
        n_tokens = len(ids)

        input_ids = torch.zeros(self.max_words, dtype=torch.int64)
        input_ids[:n_tokens] = torch.tensor(ids, dtype=torch.int64)

        attention_mask = torch.zeros(self.max_words, dtype=torch.float32)
        attention_mask[:n_tokens] = 1.0

        # Prompt and padding positions keep IGNORE_INDEX so the loss skips them.
        labels = torch.full((self.max_words,), self.IGNORE_INDEX, dtype=torch.int64)
        labels[prompt_len:n_tokens] = input_ids[prompt_len:n_tokens]

        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
        }
|