dataset_utils.py 1.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import torch
  4. from functools import partial
  5. from ft_datasets import (
  6. get_grammar_dataset,
  7. get_alpaca_dataset,
  8. get_samsum_dataset,
  9. )
  10. from typing import Optional
  11. DATASET_PREPROC = {
  12. "alpaca_dataset": partial(get_alpaca_dataset, max_words=224),
  13. "grammar_dataset": get_grammar_dataset,
  14. "samsum_dataset": get_samsum_dataset,
  15. }
  16. def get_preprocessed_dataset(
  17. tokenizer, dataset_config, split: str = "train"
  18. ) -> torch.utils.data.Dataset:
  19. if not dataset_config.dataset in DATASET_PREPROC:
  20. raise NotImplementedError(f"{dataset_config.dataset} is not (yet) implemented")
  21. def get_split():
  22. return (
  23. dataset_config.train_split
  24. if split == "train"
  25. else dataset_config.test_split
  26. )
  27. return DATASET_PREPROC[dataset_config.dataset](
  28. dataset_config,
  29. tokenizer,
  30. get_split(),
  31. )