@@ -53,10 +53,8 @@ import configs
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
     MixedPrecision,
-    StateDictType,
 )
 from torch.utils.data import DistributedSampler
-from torch.distributed.fsdp._common_utils import _is_fsdp_flattened
 import policies
 from policies import AnyPrecisionAdamW
 from configs import fsdp_config, train_config
@@ -66,7 +64,6 @@ from pkg_resources import packaging
 import torch
 import torch.cuda.nccl as nccl
 import torch.distributed as dist
-from transformers.models.t5.modeling_t5 import T5Block
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
 
@@ -239,4 +236,4 @@ def main(**kwargs):
     [print(f'Key: {k}, Value: {v}') for k, v in results.items()]
 
 if __name__ == "__main__":
-    fire.Fire(main)
+    fire.Fire(main)