Selaa lähdekoodia

Merge branch 'fix/invalidate_label_for_chat' into feature/length_based_batch_sampling

Matthias Reso 1 vuosi sitten
vanhempi
commit
52c417b7d5

+ 19 - 25
examples/custom_dataset.py

@@ -11,21 +11,16 @@ from llama_recipes.datasets.utils import Concatenator
 
 
 B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
 
 def tokenize_dialog(dialog, tokenizer):
-    dialog_tokens = [
-            tokenizer(
-                f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
-            )
-            for prompt, answer in zip(dialog[::2], dialog[1::2])
-        ]
-    if len(dialog) % 2:    
-        dialog_tokens += [tokenizer(
-            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
-        )]
-    
-    combined_tokens = {}  
+    prompt_tokens = [tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(prompt['content']).strip()} {E_INST}", add_special_tokens=False) for prompt in dialog[::2]]
+    answer_tokens = [tokenizer.encode(f"{answer['content'].strip()} {tokenizer.eos_token}", add_special_tokens=False) for answer in dialog[1::2]]
+    answer_tokens = [{k:v[1:] for k,v in items.items()} for items in answer_tokens]
+    dialog_tokens = list(itertools.chain.from_iterable(zip(prompt_tokens, answer_tokens)))
+    #Add labels, convert prompt token to -100 in order to ignore in loss function
+    dialog_tokens = [dict(c, labels=len(c["input_ids"])*[-100,]if i % 2 == 0 else c["input_ids"]) for i,c in enumerate(dialog_tokens)]
+
+    combined_tokens = {}
     for k in dialog_tokens[0].keys():
         combined_tokens[k] = list(itertools.chain(*(t[k] for t in dialog_tokens)))
     return combined_tokens
@@ -33,7 +28,7 @@ def tokenize_dialog(dialog, tokenizer):
 
 def get_custom_dataset(dataset_config, tokenizer, split):
     dataset = datasets.load_dataset("OpenAssistant/oasst1", split=split)
-    
+
     dataset = dataset.map(lambda sample: {
         "message_id": sample["message_id"],
         "parent_id": sample["parent_id"],
@@ -41,19 +36,19 @@ def get_custom_dataset(dataset_config, tokenizer, split):
         },
         batched=True,
         remove_columns=list(dataset.features),)
-    
+
     nodes = {}
-    
+
     messages = {}
     root_ids = []
-    
+
     for data in dataset:
         if data["parent_id"]:
             nodes[data["parent_id"]] = nodes.get(data["parent_id"], []) + [data["message_id"]]
         else:
             root_ids.append(data["message_id"])
         messages[data["message_id"]]=data["text"]
-           
+
     def follow(thread, current_id):
         thread = copy.copy(thread) + [messages[current_id]]
         if current_id in nodes:
@@ -63,18 +58,18 @@ def get_custom_dataset(dataset_config, tokenizer, split):
             return new_threads
         else:
             return [thread]
-        
+
     def get_threads_from_root(root_id):
         all_threads = []
         thread = [messages[root_id]]
         for cid in nodes[root_id]:
             all_threads += follow(thread, cid)
         return all_threads
-            
+
     dataset = dataset.filter(lambda x: x["message_id"] in root_ids)
     dataset = dataset.map(lambda x: {"thread": get_threads_from_root(x["message_id"])}, remove_columns=list(dataset.features))
     dataset = dataset.map(lambda x: {"thread": [i for row in x["thread"] for i in row]}, batched=True)
-    
+
     def to_dialog(thread):
         dialog = []
         for i, content in enumerate(thread):
@@ -83,9 +78,8 @@ def get_custom_dataset(dataset_config, tokenizer, split):
                 "content": content,
             })
         return {"dialog": dialog}
-            
+
     dataset = dataset.map(lambda x: to_dialog(x["thread"]), remove_columns=list(dataset.features))
     dataset = dataset.map(lambda x: tokenize_dialog(x["dialog"], tokenizer), remove_columns=list(dataset.features))
-    dataset = dataset.map(lambda x: dict(x, labels=x["input_ids"].copy()), remove_columns=list(dataset.features))
-    
-    return dataset
+
+    return dataset

+ 0 - 33
src/llama_recipes/data/concatenator.py

@@ -7,39 +7,6 @@ from itertools import chain
 from torch.utils.data import Dataset
 
 
-class Concatenator(object):
-    def __init__(self, chunk_size=2048):
-        self.chunk_size=chunk_size
-        self.residual = {"input_ids": [], "attention_mask": []}
-
-    def __call__(self, batch):
-        concatenated_samples = {
-            k: v + list(chain(*batch[k])) for k, v in self.residual.items()
-        }
-
-        total_length = len(concatenated_samples[list(concatenated_samples.keys())[0]])
-
-        if total_length >= self.chunk_size:
-            chunk_num = total_length // self.chunk_size
-            result = {
-                k: [
-                    v[i : i + self.chunk_size]
-                    for i in range(0, chunk_num * self.chunk_size, self.chunk_size)
-                ]
-                for k, v in concatenated_samples.items()
-            }
-            self.residual = {
-                k: v[(chunk_num * self.chunk_size) :]
-                for k, v in concatenated_samples.items()
-            }
-        else:
-            result = concatenated_samples
-            self.residual = {k: [] for k in concatenated_samples.keys()}
-
-        result["labels"] = result["input_ids"].copy()
-
-        return result
-
 class ConcatDataset(Dataset):
     def __init__(self, dataset, chunk_size=4096):
         self.dataset = dataset

+ 10 - 3
tests/datasets/test_custom_dataset.py

@@ -38,15 +38,22 @@ def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
     STRING = tokenizer.decode(next(it)["input_ids"][0], skip_special_tokens=True)
     EXPECTED_STRING = "[INST] Who made Berlin [/INST] dunno"
     assert STRING.startswith(EXPECTED_STRING)
-    
-    assert next(it)["input_ids"].size(0) == 4
+
+    # assert next(it)["input_ids"].size(0) == 4
+    # it = iter(train_dataloader)
+    # entry = next(it)
+    # STRING = tokenizer.decode(entry["input_ids"][0], skip_special_tokens=True)
+    # EXPECTED_STRING = "[INST] Напиши функцию на языке swift, которая сортирует массив целых чисел, а затем выводит его на экран [/INST] Вот функция, "
+
+    # assert STRING.startswith(EXPECTED_STRING)
+    # assert entry["labels"][0][:10].tolist() == 10*[-100]
 
     next(it)
     next(it)
     STRING = tokenizer.decode(next(it)["input_ids"][0], skip_special_tokens=True)
     EXPECTED_STRING = "[INST] Implementa el algoritmo `bubble sort` en C. [/INST] xdxdxd"
     assert STRING.startswith(EXPECTED_STRING)
-    
+
     assert "labels" in next(iter(train_dataloader)).keys()
     assert "input_ids" in next(iter(train_dataloader)).keys()
     assert "attention_mask" in next(iter(train_dataloader)).keys()