|
@@ -9,7 +9,7 @@ import fire
|
|
|
import random
|
|
|
import torch
|
|
|
import torch.optim as optim
|
|
|
-from peft import get_peft_model, prepare_model_for_int8_training
|
|
|
+from peft import get_peft_model, prepare_model_for_kbit_training
|
|
|
from torch.distributed.fsdp import (
|
|
|
FullyShardedDataParallel as FSDP,
|
|
|
ShardingStrategy
|
|
@@ -144,7 +144,7 @@ def main(**kwargs):
|
|
|
|
|
|
-# Prepare the model for int8 training if quantization is enabled
+# Prepare the model for k-bit (e.g. int8/int4) training if quantization is enabled
|
|
|
if train_config.quantization:
|
|
|
- model = prepare_model_for_int8_training(model)
|
|
|
+ model = prepare_model_for_kbit_training(model)
|
|
|
|
|
|
# Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
|
|
|
if train_config.enable_fsdp and fsdp_config.pure_bf16:
|