|
@@ -42,6 +42,9 @@ class InstructionDataset(Dataset):
|
|
|
return len(self.ann)
|
|
|
|
|
|
def __getitem__(self, index):
|
|
|
+ IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss
|
|
|
+
|
|
|
+
|
|
|
ann = self.ann[index]
|
|
|
if ann.get("input", "") == "":
|
|
|
prompt = PROMPT_DICT["prompt_no_input"].format_map(ann)
|
|
@@ -66,7 +69,7 @@ class InstructionDataset(Dataset):
|
|
|
example_mask = example.ge(0)
|
|
|
label_mask = labels.ge(0)
|
|
|
example[~example_mask] = 0
|
|
|
- labels[~label_mask] = 0
|
|
|
+ labels[~label_mask] = IGNORE_INDEX
|
|
|
example_mask = example_mask.float()
|
|
|
label_mask = label_mask.float()
|
|
|
|