|
@@ -76,11 +76,6 @@ def main(**kwargs):
|
|
model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
|
|
model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
|
|
overhead and currently requires latest nightly.
|
|
overhead and currently requires latest nightly.
|
|
"""
|
|
"""
|
|
- v = packaging.version.parse(torch.__version__)
|
|
|
|
- verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
|
|
|
|
- if not verify_latest_nightly:
|
|
|
|
- raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
|
|
|
|
- "please install latest nightly.")
|
|
|
|
if rank == 0:
|
|
if rank == 0:
|
|
model = LlamaForCausalLM.from_pretrained(
|
|
model = LlamaForCausalLM.from_pretrained(
|
|
train_config.model_name,
|
|
train_config.model_name,
|