dataset_utils.py 1.0 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. from functools import partial
  4. import torch
  5. from llama_recipes.datasets import (
  6. get_grammar_dataset,
  7. get_alpaca_dataset,
  8. get_samsum_dataset,
  9. )
  10. DATASET_PREPROC = {
  11. "alpaca_dataset": partial(get_alpaca_dataset, max_words=224),
  12. "grammar_dataset": get_grammar_dataset,
  13. "samsum_dataset": get_samsum_dataset,
  14. }
  15. def get_preprocessed_dataset(
  16. tokenizer, dataset_config, split: str = "train"
  17. ) -> torch.utils.data.Dataset:
  18. if not dataset_config.dataset in DATASET_PREPROC:
  19. raise NotImplementedError(f"{dataset_config.dataset} is not (yet) implemented")
  20. def get_split():
  21. return (
  22. dataset_config.train_split
  23. if split == "train"
  24. else dataset_config.test_split
  25. )
  26. return DATASET_PREPROC[dataset_config.dataset](
  27. dataset_config,
  28. tokenizer,
  29. get_split(),
  30. )