|
@@ -144,7 +144,7 @@ def main(**kwargs):
|
|
|
|
|
|
# Prepare the model for int8 training if quantization is enabled
|
|
# Prepare the model for int8 training if quantization is enabled
|
|
if train_config.quantization:
|
|
if train_config.quantization:
|
|
- model = prepare_model_for_int8_training(model)
|
|
|
|
|
|
+ model = prepare_model_for_kbit_training(model)
|
|
|
|
|
|
# Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
|
|
# Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
|
|
if train_config.enable_fsdp and fsdp_config.pure_bf16:
|
|
if train_config.enable_fsdp and fsdp_config.pure_bf16:
|