
Invalidate labels in dialog dataset to disable loss

Matthias Reso, 1 year ago
commit eafea7b366
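The change relies on the standard label-masking convention used by PyTorch and Hugging Face: positions whose label is -100 are excluded from the loss. A minimal sketch of that mechanism (illustration only, not part of this commit):

    import torch

    # 4 positions over a Llama-sized vocabulary; the first two stand in for prompt tokens.
    logits = torch.randn(4, 32000)
    labels = torch.tensor([-100, -100, 7, 9])

    # CrossEntropyLoss skips targets equal to ignore_index, which defaults to -100,
    # so only the last two positions contribute to the loss.
    loss = torch.nn.CrossEntropyLoss()(logits, labels)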

+ 24 - 24
examples/custom_dataset.py

@@ -14,18 +14,18 @@ B_INST, E_INST = "[INST]", "[/INST]"
 B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
 
 def tokenize_dialog(dialog, tokenizer):
-    dialog_tokens = [
-            tokenizer(
-                f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
-            )
-            for prompt, answer in zip(dialog[::2], dialog[1::2])
-        ]
-    if len(dialog) % 2:    
-        dialog_tokens += [tokenizer(
-            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
-        )]
-    
-    combined_tokens = {}  
+    prompt_tokens = [tokenizer(f"{B_INST} {(prompt['content']).strip()} {E_INST}") for prompt in dialog[::2]]
+    answer_tokens = [tokenizer(f"{answer['content'].strip()} ") for answer in dialog[1::2]]
+    answer_tokens = [{k:v[1:] for k,v in items.items()} for items in answer_tokens]  # drop the leading BOS token the tokenizer adds to each answer
+    dialog_tokens = list(itertools.chain.from_iterable(zip(prompt_tokens, answer_tokens)))
+    # Add labels: convert prompt tokens to -100 so they are ignored by the loss function
+    dialog_tokens = [dict(c, labels=len(c["input_ids"])*[-100] if i % 2 == 0 else c["input_ids"]) for i,c in enumerate(dialog_tokens)]
+
+    if len(dialog) % 2:
+        dialog_tokens += [prompt_tokens[-1]]
+        dialog_tokens[-1] = dict(dialog_tokens[-1], labels=[-100]*len(dialog_tokens[-1]["input_ids"]))
+
+    combined_tokens = {}
     for k in dialog_tokens[0].keys():
         combined_tokens[k] = list(itertools.chain(*(t[k] for t in dialog_tokens)))
     return combined_tokens
@@ -33,7 +33,7 @@ def tokenize_dialog(dialog, tokenizer):
 
 def get_custom_dataset(dataset_config, tokenizer, split):
     dataset = datasets.load_dataset("OpenAssistant/oasst1", split=split)
-    
+
     dataset = dataset.map(lambda sample: {
         "message_id": sample["message_id"],
         "parent_id": sample["parent_id"],
@@ -41,19 +41,19 @@ def get_custom_dataset(dataset_config, tokenizer, split):
         },
         batched=True,
         remove_columns=list(dataset.features),)
-    
+
     nodes = {}
-    
+
     messages = {}
     root_ids = []
-    
+
     for data in dataset:
         if data["parent_id"]:
             nodes[data["parent_id"]] = nodes.get(data["parent_id"], []) + [data["message_id"]]
         else:
             root_ids.append(data["message_id"])
         messages[data["message_id"]]=data["text"]
-           
+
     def follow(thread, current_id):
         thread = copy.copy(thread) + [messages[current_id]]
         if current_id in nodes:
@@ -63,18 +63,18 @@ def get_custom_dataset(dataset_config, tokenizer, split):
             return new_threads
         else:
             return [thread]
-        
+
     def get_threads_from_root(root_id):
         all_threads = []
         thread = [messages[root_id]]
         for cid in nodes[root_id]:
             all_threads += follow(thread, cid)
         return all_threads
-            
+
     dataset = dataset.filter(lambda x: x["message_id"] in root_ids)
     dataset = dataset.map(lambda x: {"thread": get_threads_from_root(x["message_id"])}, remove_columns=list(dataset.features))
     dataset = dataset.map(lambda x: {"thread": [i for row in x["thread"] for i in row]}, batched=True)
-    
+
     def to_dialog(thread):
         dialog = []
         for i, content in enumerate(thread):
@@ -83,9 +83,9 @@ def get_custom_dataset(dataset_config, tokenizer, split):
                 "content": content,
             })
         return {"dialog": dialog}
-            
+
     dataset = dataset.map(lambda x: to_dialog(x["thread"]), remove_columns=list(dataset.features))
     dataset = dataset.map(lambda x: tokenize_dialog(x["dialog"], tokenizer), remove_columns=list(dataset.features))
-    dataset = dataset.map(Concatenator(), batched=True)
-    
-    return dataset
+    dataset = dataset.map(Concatenator(add_labels=False), batched=True)
+
+    return dataset
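
For context, a hypothetical usage sketch of the reworked tokenize_dialog (the tokenizer checkpoint and import path are assumptions, not part of the commit): prompt positions come out masked with -100, while answer positions keep their token ids, so only the assistant replies are scored by the loss.

    from transformers import AutoTokenizer
    from examples.custom_dataset import tokenize_dialog  # import path assumed

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # checkpoint assumed
    dialog = [
        {"role": "prompter", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "Paris."},
    ]
    tokens = tokenize_dialog(dialog, tokenizer)

    # The prompt segment is fully masked ...
    n_prompt = len(tokenizer(f"[INST] {dialog[0]['content'].strip()} [/INST]")["input_ids"])
    assert tokens["labels"][:n_prompt] == [-100] * n_prompt
    # ... and the answer segment keeps its real token ids.
    assert tokens["labels"][n_prompt:] == tokens["input_ids"][n_prompt:]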

+ 13 - 9
src/llama_recipes/datasets/utils.py

@@ -7,10 +7,13 @@ from itertools import chain
 from torch.utils.data import Dataset
 
 class Concatenator(object):
-    def __init__(self, chunk_size=2048):
+    def __init__(self, chunk_size=2048, add_labels=True):
         self.chunk_size=chunk_size
         self.residual = {"input_ids": [], "attention_mask": []}
-        
+        self.add_labels = add_labels
+        if not add_labels:
+            self.residual["labels"] = []  # pre-computed labels are carried over between batches like the other fields
+
     def __call__(self, batch):
         concatenated_samples = {
             k: v + list(chain(*batch[k])) for k, v in self.residual.items()
@@ -35,7 +38,8 @@ class Concatenator(object):
             result = concatenated_samples
             self.residual = {k: [] for k in concatenated_samples.keys()}
 
-        result["labels"] = result["input_ids"].copy()
+        if self.add_labels:
+            result["labels"] = result["input_ids"].copy()
 
         return result
 
@@ -43,24 +47,24 @@ class ConcatDataset(Dataset):
     def __init__(self, dataset, chunk_size=4096):
         self.dataset = dataset
         self.chunk_size = chunk_size
-        
+
         self.samples = []
-        
+
         buffer = {
             "input_ids": [],
             "attention_mask": [],
             "labels": [],
             }
-        
+
         for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
             buffer = {k: v + sample[k] for k,v in buffer.items()}
-            
+
             while len(next(iter(buffer.values()))) > self.chunk_size:
                 self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()})
                 buffer = {k: v[self.chunk_size:] for k,v in buffer.items()}
-                
+
     def __getitem__(self, idx):
         return self.samples[idx]
-    
+
     def __len__(self):
         return len(self.samples)
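
A small hypothetical sketch of the two Concatenator modes (the batch values below are made up and kept short so no chunking is triggered):

    from llama_recipes.datasets.utils import Concatenator

    batch = {
        "input_ids": [[1, 2, 3], [4, 5]],
        "attention_mask": [[1, 1, 1], [1, 1]],
    }

    # Default mode: labels are created as a verbatim copy of the concatenated input_ids.
    out = Concatenator(add_labels=True)(batch)
    assert out["labels"] == out["input_ids"]

    # New mode: pre-masked labels (e.g. from tokenize_dialog) are concatenated as-is.
    masked = dict(batch, labels=[[-100, -100, 3], [-100, 5]])
    out = Concatenator(add_labels=False)(masked)
    assert out["labels"] == [-100, -100, 3, -100, 5]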

+ 3 - 1
tests/datasets/test_custom_dataset.py

@@ -34,10 +34,12 @@ def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
     assert len(eval_dataloader) == 2*226
 
     it = iter(train_dataloader)
-    STRING = tokenizer.decode(next(it)["input_ids"][0], skip_special_tokens=True)
+    entry = next(it)
+    STRING = tokenizer.decode(entry["input_ids"][0], skip_special_tokens=True)
     EXPECTED_STRING = "[INST] Напиши функцию на языке swift, которая сортирует массив целых чисел, а затем выводит его на экран [/INST] Вот функция, "
 
     assert STRING.startswith(EXPECTED_STRING)
+    assert entry["labels"][0][:10].tolist() == 10*[-100]
 
     next(it)
     next(it)
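
With these changes in place, the updated test can be run in the usual way (assuming pytest and the repository's test requirements are installed):

    python -m pytest tests/datasets/test_custom_dataset.py -k test_custom_dataset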