浏览代码

use new tokenizer_name argument and resize embeddings if required

Rahul A R 11 月之前
父节点
当前提交
f8183b96fe
共有 1 个文件被更改,包括 7 次插入,1 次删除
  1. +7 -1
      src/llama_recipes/finetuning.py

+ 7 - 1
src/llama_recipes/finetuning.py

@@ -137,9 +137,15 @@ def main(**kwargs):
         )
 
     # Load the tokenizer and add special tokens
-    tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
+    tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
     tokenizer.pad_token_id = tokenizer.eos_token_id
 
+    # If there is a mismatch between tokenizer vocab size and embedding matrix, 
+    # throw a warning and then expand the embedding matrix
+    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
+        print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.")
+        model.resize_token_embeddings(len(tokenizer))
+
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
 
     # Prepare the model for int8 training if quantization is enabled