@@ -137,9 +137,15 @@ def main(**kwargs):
     )

     # Load the tokenizer and add special tokens
-    tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
+    tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
     tokenizer.pad_token_id = tokenizer.eos_token_id

+    # If there is a mismatch between tokenizer vocab size and embedding matrix,
+    # throw a warning and then expand the embedding matrix
+    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
+        print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.")
+        model.resize_token_embeddings(len(tokenizer))
+
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

     # Prepare the model for int8 training if quantization is enabled
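
For readers who want to try the new vocab-size guard in isolation, below is a minimal sketch of the same pattern outside the training script. It only assumes the transformers library; AutoTokenizer/AutoModelForCausalLM and the small placeholder model id "sshleifer/tiny-gpt2" are illustrative choices, not the classes or checkpoint used in this repo.

# Minimal sketch of the embedding-resize guard (assumptions: generic Hugging Face
# causal LM; the model id is only a small placeholder for illustration).
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "sshleifer/tiny-gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Adding a special token can grow the tokenizer vocab past the embedding matrix.
tokenizer.add_special_tokens({"pad_token": "<PAD>"})

# Same guard as in the hunk above: resize only when the tokenizer outgrew the embeddings.
if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
    print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.")
    model.resize_token_embeddings(len(tokenizer))

Resizing only on a strict mismatch keeps the change a no-op for checkpoints whose embedding matrix already covers the tokenizer vocabulary.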