@@ -65,6 +65,7 @@ def main(**kwargs):
         setup_environ_flags(rank)
+    use_cache = False if train_config.enable_fsdp else None
     if train_config.enable_fsdp and train_config.low_cpu_fsdp:
         """
         for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
@@ -82,9 +83,11 @@ def main(**kwargs):
                 train_config.model_name,
                 load_in_8bit=True if train_config.quantization else None,
                 device_map="auto" if train_config.quantization else None,
+                use_cache=use_cache,
             )
         else:
             llama_config = LlamaConfig.from_pretrained(train_config.model_name)
+            llama_config.use_cache = use_cache
             with torch.device("meta"):
                 model = LlamaForCausalLM(llama_config)
@@ -93,6 +96,7 @@ def main(**kwargs):
             train_config.model_name,
             load_in_8bit=True if train_config.quantization else None,
             device_map="auto" if train_config.quantization else None,
+            use_cache=use_cache,
         )
     if train_config.enable_fsdp and train_config.use_fast_kernels:
         """
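Read together, the three hunks derive a single use_cache flag and pass it along every path that constructs the model: as a from_pretrained keyword on rank 0 and in the non-FSDP branch, and through the LlamaConfig on the other ranks, which only instantiate the model on the meta device. Below is a minimal sketch of that logic pulled out of the diff, wrapped in a hypothetical load_llama helper for illustration; it assumes the train_config fields visible above (enable_fsdp, low_cpu_fsdp, quantization, model_name) and the rank-0 branch implied by the surrounding context.

# Minimal sketch, not the llama-recipes file itself: the helper name and
# function boundary are hypothetical; the branch structure and keyword
# arguments mirror the diff above.
import torch
from transformers import LlamaConfig, LlamaForCausalLM


def load_llama(train_config, rank):
    # The patch turns the KV cache off whenever FSDP is enabled and leaves
    # the model default (None) in place otherwise.
    use_cache = False if train_config.enable_fsdp else None

    if train_config.enable_fsdp and train_config.low_cpu_fsdp:
        if rank == 0:
            # Rank 0 loads the pretrained weights and receives the flag as a
            # from_pretrained keyword argument.
            model = LlamaForCausalLM.from_pretrained(
                train_config.model_name,
                load_in_8bit=True if train_config.quantization else None,
                device_map="auto" if train_config.quantization else None,
                use_cache=use_cache,
            )
        else:
            # The other ranks only build an empty model on the meta device,
            # so the flag travels through the config object instead.
            llama_config = LlamaConfig.from_pretrained(train_config.model_name)
            llama_config.use_cache = use_cache
            with torch.device("meta"):
                model = LlamaForCausalLM(llama_config)
    else:
        model = LlamaForCausalLM.from_pretrained(
            train_config.model_name,
            load_in_8bit=True if train_config.quantization else None,
            device_map="auto" if train_config.quantization else None,
            use_cache=use_cache,
        )
    return model

Disabling the cache only when FSDP is enabled is the point of the change; passing None in the other case leaves the model's default behavior untouched.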