Browse Source

Remove padding in the Alpaca dataset; remove ConcatDataset wrapping in the grammar dataset

Matthias Reso 1 year ago
parent
commit
be63d9ec39

+ 0 - 2
src/llama_recipes/configs/datasets.py

@@ -9,7 +9,6 @@ class samsum_dataset:
     dataset: str =  "samsum_dataset"
     train_split: str = "train"
     test_split: str = "validation"
-    input_length: int = 2048
     
     
 @dataclass
@@ -17,7 +16,6 @@ class grammar_dataset:
     dataset: str = "grammar_dataset"
     train_split: str = "src/llama_recipes/datasets/grammar_dataset/gtrain_10k.csv" 
     test_split: str = "src/llama_recipes/datasets/grammar_dataset/grammar_validation.csv"
-    input_length: int = 2048
 
     
 @dataclass

+ 1 - 4
src/llama_recipes/datasets/alpaca_dataset.py

@@ -57,10 +57,7 @@ class InstructionDataset(Dataset):
         example = torch.tensor(
             example, dtype=torch.int64
         )
-        padding = self.max_words - example.shape[0]
-        if padding > 0:
-            example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1))
-        elif padding < 0:
+        if example.shape[0] > self.max_words:
             example = example[: self.max_words]
         labels = copy.deepcopy(example)
         labels[: len(prompt)] = -1

+ 1 - 1
src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py

@@ -81,5 +81,5 @@ def get_dataset(
         csv_name=csv_name,
     )
     
-    return ConcatDataset(dataset, chunk_size=dataset_config.input_length)
+    return dataset