test_length_based_batch_sampler.py

import random

import pytest
import torch

from llama_recipes.data.sampler import LengthBasedBatchSampler

SAMPLES = 33

@pytest.fixture
def dataset():
    # Build SAMPLES toy samples: each is a list of 1s whose length stands in
    # for sequence length. Half are "short" (length 1-9), the rest "long"
    # (length 10-20); the whole set is then shuffled.
    dataset = []

    def add_samples(ds, n, a, b):
        for _ in range(n):
            ds.append(random.randint(a, b) * [1])

    add_samples(dataset, SAMPLES // 2, 1, 9)
    add_samples(dataset, (SAMPLES // 2) + (SAMPLES % 2), 10, 20)
    return random.sample(dataset, len(dataset))

@pytest.mark.parametrize("batch_size, drop_last", [(2, False), (8, False), (2, True), (8, True)])
def test_batch_sampler_array(dataset, batch_size, drop_last):
    sampler = LengthBasedBatchSampler(dataset, batch_size, drop_last)

    # With drop_last the trailing partial batch is discarded; otherwise the
    # sampler yields ceil(SAMPLES / batch_size) batches.
    EXPECTED_LENGTH = SAMPLES // batch_size if drop_last else (SAMPLES + batch_size - 1) // batch_size
    assert len(sampler) == EXPECTED_LENGTH

    # Every emitted batch must be homogeneous: all long samples or all short.
    is_long = [len(d) >= 10 for d in dataset]

    def check_batch(batch):
        return all(batch) or not any(batch)

    # Materialize each batch's flags as a list: check_batch iterates its
    # argument twice, which would silently mis-handle a generator.
    assert all(check_batch([is_long[i] for i in b]) for b in sampler)
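
# How a batch sampler like this is typically consumed: passed to a DataLoader
# through its batch_sampler argument. A minimal sketch; pad_collate is an
# illustrative helper assumed here, not part of llama_recipes.
def _example_dataloader_usage(dataset, sampler):
    from torch.utils.data import DataLoader

    def pad_collate(batch):
        # Right-pad every sequence with zeros up to the longest in the batch.
        width = max(len(seq) for seq in batch)
        return torch.tensor([seq + [0] * (width - len(seq)) for seq in batch])

    return DataLoader(dataset, batch_sampler=sampler, collate_fn=pad_collate)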

@pytest.mark.parametrize("batch_size, drop_last", [(2, False), (8, False), (2, True), (8, True)])
def test_batch_sampler_dict(dataset, batch_size, drop_last):
    # Wrap each sample in a dict to mimic tokenized model inputs.
    dist_dataset = [{"input_ids": d, "attention_mask": d} for d in dataset]
    sampler = LengthBasedBatchSampler(dist_dataset, batch_size, drop_last)

    EXPECTED_LENGTH = SAMPLES // batch_size if drop_last else (SAMPLES + batch_size - 1) // batch_size
    assert len(sampler) == EXPECTED_LENGTH

    # Lengths come from the underlying lists; each batch must again be
    # all-long or all-short.
    is_long = [len(d) >= 10 for d in dataset]

    def check_batch(batch):
        return all(batch) or not any(batch)

    assert all(check_batch([is_long[i] for i in b]) for b in sampler)
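
# For context, a minimal sketch of what a length-grouping batch sampler can
# look like. This is an illustrative stand-in under assumed behavior (sort by
# length, then chunk), not the actual llama_recipes implementation, whose
# shuffling and tie-breaking may differ.
class _SketchLengthBatchSampler:
    def __init__(self, data, batch_size, drop_last):
        # Support both the plain-list and dict-shaped datasets used above.
        self.lengths = [len(d["input_ids"]) if isinstance(d, dict) else len(d) for d in data]
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        # Sort indices by sample length so each batch holds similar lengths.
        ids = sorted(range(len(self.lengths)), key=self.lengths.__getitem__)
        batches = [ids[i:i + self.batch_size] for i in range(0, len(ids), self.batch_size)]
        if self.drop_last and batches and len(batches[-1]) < self.batch_size:
            batches.pop()
        return iter(batches)

    def __len__(self):
        n = len(self.lengths)
        return n // self.batch_size if self.drop_last else (n + self.batch_size - 1) // self.batch_size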