extend_tokenizer.py

  1. """
  2. Code borrowed from https://github.com/ymcui/Chinese-LLaMA-Alpaca/blob/main/scripts/merge_tokenizer/merge_tokenizers.py
  3. """
  4. import os
  5. import fire
  6. import re
  7. from transformers import LlamaTokenizer
  8. os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
  9. from huggingface_hub import hf_hub_download
  10. from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
  11. def main(new_tokenizer_path, extended_tokenizer_save_path):
  12. original_tokenizer_path = hf_hub_download(repo_id="meta-llama/Llama-2-7b-chat-hf", filename="tokenizer.model", local_dir="original_tokenizer")
  13. original_tokenizer_spm = sp_pb2_model.ModelProto()
  14. original_tokenizer_spm.ParseFromString(open(original_tokenizer_path, "rb").read())
  15. new_tokenizer_spm = sp_pb2_model.ModelProto()
  16. new_tokenizer_spm.ParseFromString(open(os.path.join(new_tokenizer_path, "tokenizer.model"), "rb").read())
  17. def contains_eng(text):
  18. eng_pattern = re.compile(r"[\u0020-\u007E]+")
  19. return True if eng_pattern.search(text) else False
  20. original_tokenizer_tokenset = set(p.piece for p in original_tokenizer_spm.pieces)
  21. print(f"Number of tokens before merge: {len(original_tokenizer_tokenset)}")
  22. for p in new_tokenizer_spm.pieces:
  23. piece = p.piece
  24. if piece not in original_tokenizer_tokenset and not contains_eng(piece):
  25. new_p = sp_pb2_model.ModelProto().SentencePiece()
  26. new_p.piece = piece
  27. new_p.score = 0
  28. original_tokenizer_spm.pieces.append(new_p)
  29. print(f"Number of tokens after merge: {len(original_tokenizer_spm.pieces)}")
  30. os.makedirs(extended_tokenizer_save_path, exist_ok=True)
  31. with open(os.path.join(extended_tokenizer_save_path, "tokenizer.model"), "wb") as f:
  32. f.write(original_tokenizer_spm.SerializeToString())
  33. tokenizer = LlamaTokenizer(vocab_file=os.path.join(extended_tokenizer_save_path, "tokenizer.model"), legacy=False)
  34. tokenizer.save_pretrained(extended_tokenizer_save_path)
  35. print(f"Tokenizer saved to {extended_tokenizer_save_path}")
  36. # Verify that the extended tokenizer's English vocab matches with that of the original Llama tokenizer
  37. tok1 = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf')
  38. tok2 = LlamaTokenizer.from_pretrained(extended_tokenizer_save_path)
  39. for i in range(len(tok1)):
  40. assert tok1.convert_ids_to_tokens(i) == tok2.convert_ids_to_tokens(i), f"Token mismatch at index {i}."
  41. if __name__ == "__main__":
  42. fire.Fire(main)
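
A minimal invocation sketch, assuming the new SentencePiece tokenizer lives in ./new_tokenizer and the merged output should land in ./extended_tokenizer (both paths are placeholders); python-fire exposes main's two arguments as command-line flags:

    python extend_tokenizer.py --new_tokenizer_path ./new_tokenizer --extended_tokenizer_save_path ./extended_tokenizer

Note that new_tokenizer_path must be a directory containing a tokenizer.model file, matching the os.path.join call in the script.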