# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import math
import random

import pytest

from llama_recipes.data.sampler import LengthBasedBatchSampler

# 33 is deliberately not divisible by either test batch size (2 or 8), so the
# drop_last=False cases end with a partial batch.
SAMPLES = 33


@pytest.fixture
def dataset():
    dataset = []

    def add_samples(ds, n, a, b):
        # Append n samples whose lengths are drawn uniformly from [a, b].
        for _ in range(n):
            ds.append(random.randint(a, b) * [1])

    # Half the samples are "short" (lengths 1-9) and the rest "long" (10-20).
    add_samples(dataset, SAMPLES // 2, 1, 9)
    add_samples(dataset, (SAMPLES // 2) + (SAMPLES % 2), 10, 20)

    # Shuffle so short and long samples are interleaved.
    return random.sample(dataset, len(dataset))
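

# For context: a minimal sketch of the behavior the tests below expect from a
# length-based batch sampler. This is an illustration only, not the actual
# llama_recipes implementation, which may differ (e.g. it may shuffle batches).
class _LengthBatchSamplerSketch:
    def __init__(self, data, batch_size, drop_last):
        self.lengths = [len(d) for d in data]
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        # Order indices by sample length, then cut into consecutive batches so
        # each batch holds samples of similar length.
        order = sorted(range(len(self.lengths)), key=lambda i: self.lengths[i])
        for start in range(0, len(order), self.batch_size):
            batch = order[start:start + self.batch_size]
            if self.drop_last and len(batch) < self.batch_size:
                continue
            yield batch

    def __len__(self):
        full, rem = divmod(len(self.lengths), self.batch_size)
        return full if self.drop_last else full + bool(rem)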


@pytest.mark.parametrize("batch_size, drop_last", [(2, False), (8, False), (2, True), (8, True)])
def test_batch_sampler_array(dataset, batch_size, drop_last):
    sampler = LengthBasedBatchSampler(dataset, batch_size, drop_last)

    # With drop_last the trailing partial batch is discarded; otherwise the
    # sampler must yield ceil(SAMPLES / batch_size) batches.
    expected_length = SAMPLES // batch_size if drop_last else math.ceil(SAMPLES / batch_size)

    assert len(sampler) == expected_length

    # Each batch must be homogeneous: all long (>= 10 tokens) or all short.
    is_long = [len(d) >= 10 for d in dataset]

    def check_batch(batch):
        return all(batch) or not any(batch)

    assert all(check_batch(is_long[i] for i in b) for b in sampler)
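

# Note: in training code a batch sampler like this is typically plugged into a
# DataLoader via its batch_sampler argument, e.g. (sketch, assuming a
# map-style dataset):
#
#     loader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler)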


@pytest.mark.parametrize("batch_size, drop_last", [(2, False), (8, False), (2, True), (8, True)])
def test_batch_sampler_dict(dataset, batch_size, drop_last):
    # Same checks as above, but with each sample wrapped in the feature dict a
    # tokenizer would produce; the sampler is expected to infer sample length
    # from the dict's values.
    dist_dataset = [{"input_ids": d, "attention_mask": d} for d in dataset]

    sampler = LengthBasedBatchSampler(dist_dataset, batch_size, drop_last)

    expected_length = SAMPLES // batch_size if drop_last else math.ceil(SAMPLES / batch_size)

    assert len(sampler) == expected_length

    is_long = [len(d) >= 10 for d in dataset]

    def check_batch(batch):
        return all(batch) or not any(batch)

    assert all(check_batch(is_long[i] for i in b) for b in sampler)