
Use bf16 parameters in bf16 mixed prec (#283)

Hamid Shojanazeri 1 year ago
Parent
Commit
acce2d8770
1 changed file with 2 additions and 2 deletions

src/llama_recipes/utils/train_utils.py  +2 -2

@@ -19,7 +19,7 @@ from transformers import LlamaTokenizer
 
 
 from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
-from llama_recipes.policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper
+from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
 
 
@@ -356,7 +356,7 @@ def get_policies(cfg, rank):
         bf16_ready = verify_bfloat_support
 
         if bf16_ready and not cfg.use_fp16:
-            mixed_precision_policy = bfSixteen_mixed
+            mixed_precision_policy = bfSixteen
             if rank == 0:
                 print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
         elif cfg.use_fp16:
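
For context, a minimal sketch of what the two policies typically look like, assuming they follow the standard torch FSDP MixedPrecision pattern (the actual definitions live in llama_recipes/policies/mixed_precision.py and may differ in detail):

    # Sketch only: illustrates the difference the commit title refers to,
    # not the verbatim contents of llama_recipes.policies.
    import torch
    from torch.distributed.fsdp import MixedPrecision

    # bfSixteen_mixed (old choice): parameters kept in fp32, only gradient
    # reduction and buffers run in bf16.
    bfSixteen_mixed = MixedPrecision(
        param_dtype=torch.float32,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    )

    # bfSixteen (new choice): parameters, gradient reduction, and buffers all
    # in bf16 -- i.e. "use bf16 parameters in bf16 mixed prec".
    bfSixteen = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    )

Switching get_policies to bfSixteen therefore lowers the FSDP parameter dtype to bf16 as well, reducing memory use, at the cost of keeping master weights in lower precision.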