
Use bf16 parameters in bf16 mixed prec (#283)

Hamid Shojanazeri, 1 year ago
parent commit acce2d8770
1 file changed, 2 additions and 2 deletions
  1. src/llama_recipes/utils/train_utils.py (+2 -2)

+ 2 - 2
src/llama_recipes/utils/train_utils.py

@@ -19,7 +19,7 @@ from transformers import LlamaTokenizer
 
 
 from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
-from llama_recipes.policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper
+from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
 
 
@@ -356,7 +356,7 @@ def get_policies(cfg, rank):
         bf16_ready = verify_bfloat_support
 
         if bf16_ready and not cfg.use_fp16:
-            mixed_precision_policy = bfSixteen_mixed
+            mixed_precision_policy = bfSixteen
             if rank == 0:
                 print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
         elif cfg.use_fp16:
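
For context, the difference between the two policies is which dtype the FSDP MixedPrecision config uses for the parameters themselves. A minimal sketch of what the two definitions in llama_recipes.policies plausibly look like, assuming the standard torch.distributed.fsdp.MixedPrecision API (the exact values in the repository may differ):

import torch
from torch.distributed.fsdp import MixedPrecision

# bfSixteen_mixed (assumed): parameters stay in fp32, while gradient
# reduction and buffers use bf16.
bfSixteen_mixed = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# bfSixteen (assumed): parameters, gradient reduction, and buffers are
# all kept in bf16, which is the policy this commit switches
# get_policies() to return when bf16 is supported and use_fp16 is off.
bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)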