test_length_based_batch_sampler.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import random
  4. import pytest
  5. import torch
  6. from llama_recipes.data.sampler import LengthBasedBatchSampler
# Number of synthetic samples built by the ``dataset`` fixture. 33 is not a
# multiple of either parametrized batch size (2 or 8), so the drop_last=False
# cases end with a partial batch and the drop_last=True cases discard samples.
SAMPLES = 33
  8. @pytest.fixture
  9. def dataset():
  10. random.seed(42)
  11. dataset = []
  12. def add_samples(ds, n, a, b):
  13. for _ in range(n):
  14. ds.append(random.randint(a,b) * [1,])
  15. add_samples(dataset, SAMPLES // 2, 1,9)
  16. add_samples(dataset, (SAMPLES // 2) + (SAMPLES % 2), 10,20)
  17. return random.sample(dataset, len(dataset))
  18. @pytest.mark.parametrize("batch_size, drop_last", [(2, False), (8, False), (2, True), (8, True)])
  19. def test_batch_sampler_array(dataset, batch_size, drop_last):
  20. sampler = LengthBasedBatchSampler(dataset, batch_size, drop_last)
  21. EXPECTED_LENGTH = SAMPLES // batch_size if drop_last else (SAMPLES // batch_size) + (SAMPLES % batch_size)
  22. assert len(sampler) == EXPECTED_LENGTH
  23. is_long = [len(d)>=10 for d in dataset]
  24. def check_batch(batch):
  25. return all(batch) or not any(batch)
  26. assert all(check_batch(is_long[i] for i in b) for b in sampler)
  27. @pytest.mark.parametrize("batch_size, drop_last", [(2, False), (8, False), (2, True), (8, True)])
  28. def test_batch_sampler_dict(dataset, batch_size, drop_last):
  29. dist_dataset = [{"input_ids": d, "attention_mask": d} for d in dataset]
  30. sampler = LengthBasedBatchSampler(dist_dataset, batch_size, drop_last)
  31. EXPECTED_LENGTH = SAMPLES // batch_size if drop_last else (SAMPLES // batch_size) + (SAMPLES % batch_size)
  32. assert len(sampler) == EXPECTED_LENGTH
  33. is_long = [len(d)>=10 for d in dataset]
  34. def check_batch(batch):
  35. return all(batch) or not any(batch)
  36. assert all(check_batch(is_long[i] for i in b) for b in sampler)