|
@@ -24,17 +24,14 @@ PROMPT_DICT = {
|
|
|
}
|
|
|
|
|
|
class InstructionDataset(Dataset):
|
|
|
- def __init__(self, dataset_config, tokenizer, partition="train", max_words=30):
|
|
|
+ def __init__(self, dataset_config, tokenizer, partition="train"):
|
|
|
self.ann = json.load(open(dataset_config.data_path))
|
|
|
if partition == "train":
|
|
|
self.ann = self.ann
|
|
|
else:
|
|
|
self.ann = self.ann[:200]
|
|
|
|
|
|
- self.max_words = max_words
|
|
|
- # tokenizer = Tokenizer(model_path=model_path + "./tokenizer.model")
|
|
|
self.tokenizer = tokenizer
|
|
|
- # self.tokenizer1 = tokenizer
|
|
|
|
|
|
def __len__(self):
|
|
|
return len(self.ann)
|
|
@@ -57,8 +54,6 @@ class InstructionDataset(Dataset):
|
|
|
example = torch.tensor(
|
|
|
example, dtype=torch.int64
|
|
|
)
|
|
|
- if example.shape[0] > self.max_words:
|
|
|
- example = example[: self.max_words]
|
|
|
labels = copy.deepcopy(example)
|
|
|
labels[: len(prompt)] = -1
|
|
|
example_mask = example.ge(0)
|