@@ -5,9 +5,32 @@
 import copy
 import datasets
+import itertools
 
 from llama_recipes.datasets.utils import Concatenator
 
+
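+# Llama 2 chat markup tags used to render each turn. B_SYS/E_SYS belong to the
+# same format but are unused in this example as written.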
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
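+# Tokenize a (user, assistant, user, ...) dialog into one Llama 2 chat
+# sequence, rendering each exchange as "[INST] prompt [/INST] answer".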
+def tokenize_dialog(dialog, tokenizer):
+    dialog_tokens = [
+        tokenizer(
+            f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
+        )
+        for prompt, answer in zip(dialog[::2], dialog[1::2])
+    ]
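+    # An odd number of turns means the dialog ends with an unanswered user
+    # prompt; emit it with the instruction tags only.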
+    if len(dialog) % 2:
+        dialog_tokens += [tokenizer(
+            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
+        )]
+
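+    # Merge the per-turn encodings (input_ids, attention_mask, ...) into one
+    # flat sequence per dialog.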
+    combined_tokens = {}
+    for k in dialog_tokens[0].keys():
+        combined_tokens[k] = [list(itertools.chain(*(t[k] for t in dialog_tokens)))]
+    return combined_tokens
+
+
 def get_custom_dataset(dataset_config, tokenizer, split):
     dataset = datasets.load_dataset("OpenAssistant/oasst1", split=split)
@@ -18,73 +41,51 @@ def get_custom_dataset(dataset_config, tokenizer, split):
         },
         batched=True,
         remove_columns=list(dataset.features),)
-
-    # print(ids[0])
 
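+    # Index the conversation tree: map each parent message to its children and
+    # each message id to its text, collecting the conversation roots as we go.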
-    p2c = {}
+    nodes = {}
 
-    ids2text = {}
+    messages = {}
     root_ids = []
 
     for data in dataset:
         if data["parent_id"]:
-            p2c[data["parent_id"]] = p2c.get(data["parent_id"], []) + [data["message_id"]]
+            nodes[data["parent_id"]] = nodes.get(data["parent_id"], []) + [data["message_id"]]
         else:
            root_ids.append(data["message_id"])
-        ids2text[data["message_id"]]=data["text"]
+        messages[data["message_id"]] = data["text"]
 
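+    # Depth-first walk from current_id, extending the running thread and
+    # fanning out over children; returns every root-to-leaf path below it.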
     def follow(thread, current_id):
-        thread = copy.copy(thread) + [ids2text[current_id]]
-        if current_id in p2c:
+        thread = copy.copy(thread) + [messages[current_id]]
+        if current_id in nodes:
             new_threads = []
-            for next_id in p2c[current_id]:
+            for next_id in nodes[current_id]:
                 new_threads += follow(thread, next_id)
             return new_threads
         else:
             return [thread]
 
-
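+    # Expand one conversation root into the list of all threads beneath it.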
     def get_threads_from_root(root_id):
         all_threads = []
-        thread = [ids2text[root_id]]
-        for cid in p2c[root_id]:
+        thread = [messages[root_id]]
+        for cid in nodes[root_id]:
             all_threads += follow(thread, cid)
         return all_threads
-
-
-    # all_threads = []
-    # for rid in root_ids:
-
 
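+    # Keep only root messages, expand each root into its threads, then flatten
+    # so every row holds a single thread (a list of alternating messages).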
     dataset = dataset.filter(lambda x: x["message_id"] in root_ids)
     dataset = dataset.map(lambda x: {"thread": get_threads_from_root(x["message_id"])}, remove_columns=list(dataset.features))
     dataset = dataset.map(lambda x: {"thread": [i for row in x["thread"] for i in row]}, batched=True)
+
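+    # Tag thread positions with chat roles; this assumes turns strictly
+    # alternate, starting with the user (prompter).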
+    def to_dialog(thread):
+        dialog = []
+        for i, content in enumerate(thread):
+            dialog.append({
+                "role": "user" if i % 2 == 0 else "assistant",
+                "content": content,
+            })
+        return {"dialog": dialog}
 
-    print(len(dataset))
-    from pprint import pprint
-    pprint(dataset[:10])
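+    # Convert each thread to a role-tagged dialog, tokenize it, and pack the
+    # results into fixed-length training chunks with Concatenator.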
+    dataset = dataset.map(lambda x: to_dialog(x["thread"]), remove_columns=list(dataset.features))
+    dataset = dataset.map(lambda x: tokenize_dialog(x["dialog"], tokenizer), remove_columns=list(dataset.features))
+    dataset = dataset.map(Concatenator(), batched=True)
 
-    return dataset
-    # threads={}
-
-    # prompt = (
-    #     f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n{{summary}}{{eos_token}}"
-    # )
-
-    # def apply_prompt_template(sample):
-    #     return {
-    #         "text": prompt.format(
-    #             dialog=sample["dialogue"],
-    #             summary=sample["summary"],
-    #             eos_token=tokenizer.eos_token,
-    #         )
-    #     }
-
-    # dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
-
-    # dataset = dataset.map(
-    #     lambda sample: tokenizer(sample["text"]),
-    #     batched=True,
-    #     remove_columns=list(dataset.features),
-    # ).map(Concatenator(), batched=True)
-    return dataset
+    return dataset