@@ -102,9 +102,9 @@ def main(**kwargs):
         """
         try:
             from optimum.bettertransformer import BetterTransformer
+            model = BetterTransformer.transform(model)
         except ImportError:
             print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
-            model = BetterTransformer.transform(model)
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
 
     # Prepare the model for int8 training if quantization is enabled