Explorar el Código

Fix bug in alapaca, set ignore_idx=-100.

luoyifan hace 1 año
padre
commit
405255c284
Se han modificado 1 ficheros con 4 adiciones y 1 borrados
  1. 4 1
      ft_datasets/alpaca_dataset.py

+ 4 - 1
ft_datasets/alpaca_dataset.py

@@ -42,6 +42,9 @@ class InstructionDataset(Dataset):
         return len(self.ann)
 
     def __getitem__(self, index):
+        IGNORE_INDEX = -100  # The default setting in CrossEntropyLoss
+
+
         ann = self.ann[index]
         if ann.get("input", "") == "":
             prompt = PROMPT_DICT["prompt_no_input"].format_map(ann)
@@ -66,7 +69,7 @@ class InstructionDataset(Dataset):
         example_mask = example.ge(0)
         label_mask = labels.ge(0)
         example[~example_mask] = 0
-        labels[~label_mask] = 0
+        labels[~label_mask] = IGNORE_INDEX
         example_mask = example_mask.float()
         label_mask = label_mask.float()