# test_sampler.py (2.8 KB)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import random
  4. import pytest
  5. import torch
  6. from llama_recipes.data.sampler import LengthBasedBatchSampler
  7. from llama_recipes.data.sampler import DistributedLengthBasedBatchSampler
  8. SAMPLES = 33
  9. @pytest.fixture
  10. def dataset():
  11. random.seed(42)
  12. dataset = []
  13. def add_samples(ds, n, a, b):
  14. for _ in range(n):
  15. ds.append(random.randint(a,b) * [1,])
  16. add_samples(dataset, SAMPLES // 2, 1,9)
  17. add_samples(dataset, (SAMPLES // 2) + (SAMPLES % 2), 10,20)
  18. return random.sample(dataset, len(dataset))
  19. @pytest.mark.parametrize("batch_size, drop_last", [(2, False), (8, False), (2, True), (8, True)])
  20. def test_batch_sampler_array(dataset, batch_size, drop_last):
  21. sampler = LengthBasedBatchSampler(dataset, batch_size, drop_last)
  22. EXPECTED_LENGTH = SAMPLES // batch_size if drop_last else (SAMPLES // batch_size) + (SAMPLES % batch_size)
  23. all_ids = [i for b in sampler for i in b]
  24. assert len(set(all_ids)) == EXPECTED_LENGTH * batch_size if drop_last else len(dataset)
  25. assert len(sampler) == EXPECTED_LENGTH
  26. is_long = [len(d)>=10 for d in dataset]
  27. def check_batch(batch):
  28. return all(batch) or not any(batch)
  29. assert all(check_batch(is_long[i] for i in b) for b in sampler)
  30. @pytest.mark.parametrize("batch_size, drop_last", [(2, False), (8, False), (2, True), (8, True)])
  31. def test_batch_sampler_dict(dataset, batch_size, drop_last):
  32. dist_dataset = [{"input_ids": d, "attention_mask": d} for d in dataset]
  33. sampler = LengthBasedBatchSampler(dist_dataset, batch_size, drop_last)
  34. EXPECTED_LENGTH = SAMPLES // batch_size if drop_last else (SAMPLES // batch_size) + (SAMPLES % batch_size)
  35. assert len(sampler) == EXPECTED_LENGTH
  36. is_long = [len(d)>=10 for d in dataset]
  37. def check_batch(batch):
  38. return all(batch) or not any(batch)
  39. assert all(check_batch(is_long[i] for i in b) for b in sampler)
  40. @pytest.mark.parametrize("batch_size", [2, 8])
  41. def test_dist_batch_sampling(dataset, batch_size):
  42. sampler_1 = DistributedLengthBasedBatchSampler(
  43. dataset,
  44. batch_size=batch_size,
  45. rank=0,
  46. num_replicas=2,
  47. shuffle=False,
  48. )
  49. sampler_2 = DistributedLengthBasedBatchSampler(
  50. dataset,
  51. batch_size=batch_size,
  52. rank=1,
  53. num_replicas=2,
  54. shuffle=False,
  55. )
  56. ids_1 = set(i for b in sampler_1 for i in b)
  57. ids_2 = set(i for b in sampler_2 for i in b)
  58. assert ids_1.isdisjoint(ids_2)
  59. assert len(ids_1)+len(ids_2) > 0
  60. assert len(ids_1)+len(ids_2) == len(dataset) // batch_size * batch_size