samsum_dataset.py 1.0 KB

1234567891011121314151617181920212223242526272829303132
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. # For dataset details visit: https://huggingface.co/datasets/samsum
  4. import datasets
  5. from .utils import Concatenator
  6. def get_preprocessed_samsum(dataset_config, tokenizer, split):
  7. dataset = datasets.load_dataset("samsum", split=split)
  8. prompt = (
  9. f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n{{summary}}{{eos_token}}"
  10. )
  11. def apply_prompt_template(sample):
  12. return {
  13. "text": prompt.format(
  14. dialog=sample["dialogue"],
  15. summary=sample["summary"],
  16. eos_token=tokenizer.eos_token,
  17. )
  18. }
  19. dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
  20. dataset = dataset.map(
  21. lambda sample: tokenizer(sample["text"]),
  22. batched=True,
  23. remove_columns=list(dataset.features),
  24. ).map(Concatenator(), batched=True)
  25. return dataset