custom_dataset.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# For dataset details visit: https://huggingface.co/datasets/OpenAssistant/oasst1

import copy
import datasets
import itertools

from llama_recipes.datasets.utils import Concatenator

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"


def tokenize_dialog(dialog, tokenizer):
    # Tokenize each (user, assistant) turn pair into Llama 2's "[INST] ... [/INST] ..." format.
    dialog_tokens = [
        tokenizer(
            f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
        )
        for prompt, answer in zip(dialog[::2], dialog[1::2])
    ]
    # If the dialog ends on a user turn, tokenize it without an assistant reply.
    if len(dialog) % 2:
        dialog_tokens += [tokenizer(
            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
        )]

    # Flatten the per-turn encodings (input_ids, attention_mask, ...) into one sequence per dialog.
    combined_tokens = {}
    for k in dialog_tokens[0].keys():
        combined_tokens[k] = list(itertools.chain(*(t[k] for t in dialog_tokens)))
    return combined_tokens
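
# Example (a minimal sketch, not executed at import time): assuming a Llama 2
# tokenizer from Hugging Face transformers, a two-turn dialog is flattened into
# a single token sequence. The checkpoint name below is only an illustration.
#
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
#   dialog = [
#       {"role": "user", "content": "Hi, what can you do?"},
#       {"role": "assistant", "content": "I can answer questions."},
#   ]
#   tokens = tokenize_dialog(dialog, tok)
#   # tokens["input_ids"] holds the ids of "[INST] Hi, what can you do? [/INST] I can answer questions. "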

def get_custom_dataset(dataset_config, tokenizer, split):
    dataset = datasets.load_dataset("OpenAssistant/oasst1", split=split)

    # Keep only the fields needed to rebuild the conversation trees.
    dataset = dataset.map(lambda sample: {
        "message_id": sample["message_id"],
        "parent_id": sample["parent_id"],
        "text": sample["text"],
        },
        batched=True,
        remove_columns=list(dataset.features),)

    nodes = {}
    messages = {}
    root_ids = []

    # Index children by parent_id and record the root message of each tree.
    for data in dataset:
        if data["parent_id"]:
            nodes[data["parent_id"]] = nodes.get(data["parent_id"], []) + [data["message_id"]]
        else:
            root_ids.append(data["message_id"])
        messages[data["message_id"]] = data["text"]

    def follow(thread, current_id):
        # Depth-first walk: extend the thread with the current message and recurse into its children.
        thread = copy.copy(thread) + [messages[current_id]]
        if current_id in nodes:
            new_threads = []
            for next_id in nodes[current_id]:
                new_threads += follow(thread, next_id)
            return new_threads
        else:
            return [thread]

    def get_threads_from_root(root_id):
        # Expand one conversation tree into all of its root-to-leaf threads.
        all_threads = []
        thread = [messages[root_id]]
        for cid in nodes[root_id]:
            all_threads += follow(thread, cid)
        return all_threads

    dataset = dataset.filter(lambda x: x["message_id"] in root_ids)
    dataset = dataset.map(lambda x: {"thread": get_threads_from_root(x["message_id"])}, remove_columns=list(dataset.features))
    dataset = dataset.map(lambda x: {"thread": [i for row in x["thread"] for i in row]}, batched=True)

    def to_dialog(thread):
        # Alternate user/assistant roles over the flattened thread.
        dialog = []
        for i, content in enumerate(thread):
            dialog.append({
                "role": "user" if i % 2 == 0 else "assistant",
                "content": content,
            })
        return {"dialog": dialog}

    dataset = dataset.map(lambda x: to_dialog(x["thread"]), remove_columns=list(dataset.features))
    dataset = dataset.map(lambda x: tokenize_dialog(x["dialog"], tokenizer), remove_columns=list(dataset.features))
    dataset = dataset.map(Concatenator(), batched=True)

    return dataset
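

# A minimal smoke test, kept as a sketch: it assumes access to a Llama 2
# tokenizer on the Hugging Face Hub (the checkpoint name below is an
# assumption; any compatible tokenizer works) and a network connection to
# download OpenAssistant/oasst1. With llama-recipes this file is normally
# consumed through the finetuning script rather than run directly, e.g.
#   python -m llama_recipes.finetuning --dataset "custom_dataset" \
#       --custom_dataset.file "path/to/custom_dataset.py" ...
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    dataset = get_custom_dataset(dataset_config=None, tokenizer=tokenizer, split="train")
    print(dataset)
    print(tokenizer.decode(dataset[0]["input_ids"][:64]))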