
Remove check for nightlies for low_cpu_fsdp and bump torch version to 2.2 instead

Matthias Reso, 10 months ago
Commit 43cb6a2db4
2 changed files with 6 additions and 12 deletions
  1. requirements.txt (+1 -1)
  2. src/llama_recipes/finetuning.py (+5 -11)
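Since requirements.txt now pins torch>=2.2, the runtime check for a nightly build in finetuning.py is redundant and is dropped below. If an explicit runtime guard is still wanted on top of the pin, a minimal sketch (not part of this commit; the constant and error message are illustrative) could look like:

# Hypothetical runtime guard, not part of this commit: enforce the new
# minimum torch version directly instead of requiring a nightly build.
from packaging import version

import torch

MIN_TORCH = "2.2"  # matches the floor now pinned in requirements.txt

if version.parse(torch.__version__) < version.parse(MIN_TORCH):
    raise RuntimeError(
        f"low_cpu_fsdp requires torch>={MIN_TORCH}, found {torch.__version__}"
    )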

requirements.txt (+1 -1)

@@ -1,4 +1,4 @@
-torch>=2.0.1
+torch>=2.2
 accelerate
 appdirs
 loralib

src/llama_recipes/finetuning.py (+5 -11)

@@ -2,7 +2,6 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import os
-from pkg_resources import packaging
 
 import dataclasses
 import fire
@@ -51,7 +50,7 @@ from llama_recipes.utils.train_utils import (
 from accelerate.utils import is_xpu_available
 
 def setup_wandb(train_config, fsdp_config, **kwargs):
-    try: 
+    try:
         import wandb
     except ImportError:
         raise ImportError(
@@ -97,7 +96,7 @@ def main(**kwargs):
 
     if train_config.use_wandb:
         if not train_config.enable_fsdp or rank==0:
-            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)    
+            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
 
     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
@@ -108,11 +107,6 @@ def main(**kwargs):
         model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
         overhead and currently requires latest nightly.
         """
-        v = packaging.version.parse(torch.__version__)
-        verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
-        if not verify_latest_nightly:
-            raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
-                            "please install latest nightly.")
         if rank == 0:
             model = LlamaForCausalLM.from_pretrained(
                 train_config.model_name,
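For context, the low_cpu_fsdp branch this hunk edits follows the standard FSDP low-memory loading pattern: only rank 0 materializes the pretrained weights, the other ranks build the model on the meta device, and FSDP later broadcasts the real weights when wrapping with sync_module_states=True. A minimal sketch of that pattern, with illustrative names rather than the exact code in finetuning.py:

# Sketch of the low-CPU-memory loading pattern assumed by this hunk.
# Function name and arguments are illustrative, not from the repository.
import torch
from transformers import LlamaConfig, LlamaForCausalLM

def load_model_low_cpu(model_name: str, rank: int):
    if rank == 0:
        # Only rank 0 pays the full CPU-memory cost of the checkpoint.
        return LlamaForCausalLM.from_pretrained(model_name, use_cache=False)
    # Other ranks allocate no real storage; FSDP syncs the weights from
    # rank 0 later (sync_module_states=True on the FSDP wrapper).
    config = LlamaConfig.from_pretrained(model_name)
    with torch.device("meta"):
        return LlamaForCausalLM(config)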
@@ -157,12 +151,12 @@ def main(**kwargs):
         if wandb_run:
             wandb_run.config.update(peft_config)
 
-        
+
     hsdp_device_mesh = None
     if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD:
         hsdp_device_mesh = hsdp_device_mesh(replica_group_size=fsdp_config.replica_group_size, sharding_group_size=fsdp_config.sharding_group_size)
         print("HSDP device mesh is ready")
-        
+
     #setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
         if not train_config.use_peft and train_config.freeze_layers:
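The hsdp_device_mesh call above prepares a 2-D device mesh for HYBRID_SHARD: one mesh dimension replicates the model, the other shards it. With the new torch>=2.2 floor this can be backed by torch.distributed.device_mesh.init_device_mesh; the helper below is an assumed sketch of that approach, not the repository's implementation.

# Hypothetical helper: build a 2-D (replicate x shard) mesh for HSDP.
# Assumes torch.distributed.device_mesh.init_device_mesh, public in torch>=2.2.
from torch.distributed.device_mesh import init_device_mesh

def build_hsdp_device_mesh(replica_group_size: int, sharding_group_size: int):
    if replica_group_size < 1 or sharding_group_size < 1:
        raise ValueError("group sizes must be positive")
    return init_device_mesh(
        "cuda",
        (replica_group_size, sharding_group_size),
        mesh_dim_names=("replicate", "shard"),
    )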
@@ -171,7 +165,7 @@ def main(**kwargs):
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
-        
+
         device_id = 0
         if is_xpu_available():
             device_id = torch.xpu.current_device()