custom_dataset.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
# For dataset details visit: https://huggingface.co/datasets/OpenAssistant/oasst1
  4. import copy
  5. import datasets
  6. from llama_recipes.datasets.utils import Concatenator
  7. def get_custom_dataset(dataset_config, tokenizer, split):
  8. dataset = datasets.load_dataset("OpenAssistant/oasst1", split=split)
  9. dataset = dataset.map(lambda sample: {
  10. "message_id": sample["message_id"],
  11. "parent_id": sample["parent_id"],
  12. "text": sample["text"],
  13. },
  14. batched=True,
  15. remove_columns=list(dataset.features),)
  16. # print(ids[0])
  17. p2c = {}
  18. ids2text = {}
  19. root_ids = []
  20. for data in dataset:
  21. if data["parent_id"]:
  22. p2c[data["parent_id"]] = p2c.get(data["parent_id"], []) + [data["message_id"]]
  23. else:
  24. root_ids.append(data["message_id"])
  25. ids2text[data["message_id"]]=data["text"]
  26. def follow(thread, current_id):
  27. thread = copy.copy(thread) + [ids2text[current_id]]
  28. if current_id in p2c:
  29. new_threads = []
  30. for next_id in p2c[current_id]:
  31. new_threads += follow(thread, next_id)
  32. return new_threads
  33. else:
  34. return [thread]
  35. def get_threads_from_root(root_id):
  36. all_threads = []
  37. thread = [ids2text[root_id]]
  38. for cid in p2c[root_id]:
  39. all_threads += follow(thread, cid)
  40. return all_threads
  41. # all_threads = []
  42. # for rid in root_ids:
  43. dataset = dataset.filter(lambda x: x["message_id"] in root_ids)
  44. dataset = dataset.map(lambda x: {"thread": get_threads_from_root(x["message_id"])}, remove_columns=list(dataset.features))
  45. dataset = dataset.map(lambda x: {"thread": [i for row in x["thread"] for i in row]}, batched=True)
  46. print(len(dataset))
  47. from pprint import pprint
  48. pprint(dataset[:10])
  49. return dataset
  50. # threads={}
  51. # prompt = (
  52. # f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n{{summary}}{{eos_token}}"
  53. # )
  54. # def apply_prompt_template(sample):
  55. # return {
  56. # "text": prompt.format(
  57. # dialog=sample["dialogue"],
  58. # summary=sample["summary"],
  59. # eos_token=tokenizer.eos_token,
  60. # )
  61. # }
  62. # dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
  63. # dataset = dataset.map(
  64. # lambda sample: tokenizer(sample["text"]),
  65. # batched=True,
  66. # remove_columns=list(dataset.features),
  67. # ).map(Concatenator(), batched=True)
  68. # return dataset