@@ -12,6 +12,7 @@ from peft import get_peft_model, prepare_model_for_int8_training
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
 )
+from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
 from torch.utils.data import DistributedSampler
 from transformers import (
@@ -144,6 +145,7 @@ def main(**kwargs):
     model = FSDP(
         model,
         auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
+        cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
         mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
         sharding_strategy=fsdp_config.sharding_strategy,
         device_id=torch.cuda.current_device(),
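
For context, here is a minimal, self-contained sketch (not part of this PR) of how `CPUOffload(offload_params=True)` plugs into an FSDP wrap. The `cpu_offload_enabled` flag below is a hypothetical stand-in for the `fsdp_config.fsdp_cpu_offload` field the diff reads; everything else uses public PyTorch FSDP APIs.

```python
# Minimal sketch, assuming a distributed launch via torchrun and at least one GPU.
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload


def wrap_model(model: torch.nn.Module, cpu_offload_enabled: bool) -> FSDP:
    # offload_params=True keeps the sharded parameters (and their gradients)
    # in CPU memory between uses; passing None leaves everything on the GPU.
    return FSDP(
        model,
        cpu_offload=CPUOffload(offload_params=True) if cpu_offload_enabled else None,
        device_id=torch.cuda.current_device(),
    )


if __name__ == "__main__":
    # Launch with: torchrun --nproc_per_node=<num_gpus> this_script.py
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    wrapped = wrap_model(torch.nn.Linear(1024, 1024), cpu_offload_enabled=True)
    print(wrapped)
    dist.destroy_process_group()
```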