@@ -19,7 +19,7 @@ from transformers import LlamaTokenizer
 
 
 from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
-from llama_recipes.policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper
+from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
 
 
@@ -356,7 +356,7 @@ def get_policies(cfg, rank):
         bf16_ready = verify_bfloat_support
 
         if bf16_ready and not cfg.use_fp16:
-            mixed_precision_policy = bfSixteen_mixed
+            mixed_precision_policy = bfSixteen
             if rank == 0:
                 print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
        elif cfg.use_fp16:
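
For context, the two identifiers being swapped are FSDP MixedPrecision policies exported by llama_recipes.policies (typically defined in a module like llama_recipes/policies/mixed_precision.py). The minimal sketch below shows what the two policies usually look like; the exact dtype choices are an assumption based on the common definitions in that module, not something the diff itself confirms.

import torch
from torch.distributed.fsdp import MixedPrecision

# bfSixteen_mixed (the old choice): parameters stay in fp32, while gradient
# reductions and buffers use bf16, i.e. classic mixed precision.
# NOTE: assumed definition, see lead-in above.
bfSixteen_mixed = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# bfSixteen (the new choice): parameters, gradient reductions, and buffers are
# all cast to bf16, matching the "using bfSixteen policy" log message above.
# NOTE: assumed definition, see lead-in above.
bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

Whichever policy get_policies() selects is then passed to FSDP via its mixed_precision argument, roughly FSDP(model, mixed_precision=mixed_precision_policy, ...). With param_dtype set to bf16, the unsharded parameters used in forward and backward are cast to bf16 as well, so the switch changes the compute precision, not just the communication and buffer dtypes.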