Sfoglia il codice sorgente

Fix batching error

Matthias Reso 1 anno fa
parent
commit
ec00a2f722
2 ha cambiato i file con 5 aggiunte e 5 eliminazioni
  1. 1 1
      examples/custom_dataset.py
  2. 4 4
      tests/datasets/test_custom_dataset.py

+ 1 - 1
examples/custom_dataset.py

@@ -27,7 +27,7 @@ def tokenize_dialog(dialog, tokenizer):
     
     combined_tokens = {}  
     for k in dialog_tokens[0].keys():
-        combined_tokens[k] = [list(itertools.chain(*(t[k] for t in dialog_tokens)))]
+        combined_tokens[k] = list(itertools.chain(*(t[k] for t in dialog_tokens)))
     return combined_tokens
 
 

+ 4 - 4
tests/datasets/test_custom_dataset.py

@@ -17,7 +17,7 @@ def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
         "model_name": "decapoda-research/llama-7b-hf", # We use the tokenizer as a surrogate for llama2 tokenizer here
         "custom_dataset.file": "examples/custom_dataset.py",
         "custom_dataset.train_split": "validation",
-        "batch_size_training": 1,
+        "batch_size_training": 2,
         "use_peft": False,
         }
     
@@ -30,9 +30,9 @@ def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
     eval_dataloader = args[2]
     tokenizer = args[3]
     
-    assert len(train_dataloader) == 2241
-    assert len(eval_dataloader) == 2241
-    
+    assert len(train_dataloader) == 226
+    assert len(eval_dataloader) == 2*226
+
     STRING = tokenizer.decode(next(iter(train_dataloader))["input_ids"][0], skip_special_tokens=True)
     EXPECTED_STRING = "[INST] Напиши функцию на языке swift, которая сортирует массив целых чисел, а затем выводит его на экран [/INST] Вот функция, "