
Remove max_words from alpaca; let the tokenizer deal with truncation

Matthias Reso, 1 year ago
Commit d3015b4c80
1 file changed, 1 insertion(+), 6 deletions(-)

src/llama_recipes/datasets/alpaca_dataset.py (+1, -6)

@@ -24,17 +24,14 @@ PROMPT_DICT = {
 }
 
 class InstructionDataset(Dataset):
-    def __init__(self, dataset_config, tokenizer, partition="train", max_words=30):
+    def __init__(self, dataset_config, tokenizer, partition="train"):
         self.ann = json.load(open(dataset_config.data_path))
         if partition == "train":
             self.ann = self.ann
         else:
             self.ann = self.ann[:200]
 
-        self.max_words = max_words
-        # tokenizer = Tokenizer(model_path=model_path + "./tokenizer.model")
         self.tokenizer = tokenizer
-        # self.tokenizer1 = tokenizer
 
     def __len__(self):
         return len(self.ann)
@@ -57,8 +54,6 @@ class InstructionDataset(Dataset):
         example = torch.tensor(
             example, dtype=torch.int64
         )
-        if example.shape[0] > self.max_words:
-            example = example[: self.max_words]
         labels = copy.deepcopy(example)
         labels[: len(prompt)] = -1
         example_mask = example.ge(0)
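For context, a minimal sketch of what "letting the tokenizer deal with truncation" can look like, assuming a Hugging Face tokenizer. The checkpoint name and max_length value below are illustrative, not taken from this commit:

from transformers import AutoTokenizer

# Illustrative checkpoint name; substitute the model actually used.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

prompt = "Below is an instruction that describes a task. ..."
# Rather than hard-truncating the token tensor to max_words inside the
# dataset's __getitem__, the tokenizer can enforce the length limit
# itself at encode time.
input_ids = tokenizer.encode(prompt, truncation=True, max_length=512)

This keeps the dataset free of a hard-coded word limit (the removed default of max_words=30 would have silently cut most examples) and moves the truncation policy to wherever the tokenizer is configured.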