|
@@ -9,7 +9,7 @@ import fire
|
|
import random
|
|
import random
|
|
import torch
|
|
import torch
|
|
import torch.optim as optim
|
|
import torch.optim as optim
|
|
-from peft import get_peft_model, prepare_model_for_int8_training
|
|
|
|
|
|
+from peft import get_peft_model, prepare_model_for_kbit_training
|
|
from torch.distributed.fsdp import (
|
|
from torch.distributed.fsdp import (
|
|
FullyShardedDataParallel as FSDP,
|
|
FullyShardedDataParallel as FSDP,
|
|
ShardingStrategy
|
|
ShardingStrategy
|