Skip to content

Commit

Permalink
remove SharedDDP as it is deprecated (intel#1103)
Browse files Browse the repository at this point in the history
* remove SharedDDP as it is deprecated
  • Loading branch information
lkk12014402 authored Jan 12, 2024
1 parent c5ab7db commit 4e6834a
Show file tree
Hide file tree
Showing 11 changed files with 49 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ function pytest() {
mkdir -p ${coverage_log_dir}
pip install --no-cache-dir protobuf==3.20.0
## install transformers==4.34.1, to work with SharedDPO API
pip install transformers==4.34.1
pip install transformers
cd /intel-extension-for-transformers/tests/CI || exit 1
JOB_NAME=unit_test
ut_log_name=${LOG_DIR}/${JOB_NAME}.log
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,4 +144,4 @@ def _load_sbert_model(
module = module_class.load(module_path)
modules[module_config['name']] = module

return modules
return modules
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,6 @@
"save_strategy=no,\n",
"save_total_limit=2,\n",
"seed=42,\n",
"sharded_ddp=[],\n",
"skip_memory_metrics=True,\n",
"tf32=None,\n",
"torch_compile=False,\n",
Expand Down Expand Up @@ -1526,7 +1525,6 @@
"save_strategy=no,\n",
"save_total_limit=2,\n",
"seed=42,\n",
"sharded_ddp=[],\n",
"skip_memory_metrics=True,\n",
"tf32=None,\n",
"torch_compile=False,\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,6 @@
"save_strategy=no,\n",
"save_total_limit=2,\n",
"seed=42,\n",
"sharded_ddp=[],\n",
"skip_memory_metrics=True,\n",
"tf32=None,\n",
"torch_compile=False,\n",
Expand Down Expand Up @@ -740,7 +739,6 @@
"save_strategy=no,\n",
"save_total_limit=2,\n",
"seed=42,\n",
"sharded_ddp=[],\n",
"skip_memory_metrics=True,\n",
"tf32=None,\n",
"torch_compile=False,\n",
Expand Down Expand Up @@ -1322,7 +1320,6 @@
"save_strategy=no,\n",
"save_total_limit=2,\n",
"seed=42,\n",
"sharded_ddp=[],\n",
"skip_memory_metrics=True,\n",
"tf32=None,\n",
"torch_compile=False,\n",
Expand Down Expand Up @@ -1807,7 +1804,6 @@
"save_strategy=no,\n",
"save_total_limit=2,\n",
"seed=42,\n",
"sharded_ddp=[],\n",
"skip_memory_metrics=True,\n",
"tf32=None,\n",
"torch_compile=False,\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
get_parameter_names,
has_length,
ALL_LAYERNORM_LAYERS,
ShardedDDPOption,
logger,
)
from typing import List, Optional
Expand Down Expand Up @@ -176,7 +175,7 @@ def create_optimizer(self):
"""
if is_sagemaker_mp_enabled():
return super().create_optimizer()
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
if self.is_fsdp_enabled:
return super().create_optimizer()

opt_model = self.model
Expand Down Expand Up @@ -237,27 +236,20 @@ def create_optimizer(self):

optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

if self.sharded_ddp == ShardedDDPOption.SIMPLE:
self.optimizer = OSS(
params=optimizer_grouped_parameters,
optim=optimizer_cls,
**optimizer_kwargs,
)
else:
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes

manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
logger.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
logger.info(f"skipped: {skipped/2**20}M params")
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
logger.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
logger.info(f"skipped: {skipped/2**20}M params")

return self.optimizer

Expand Down Expand Up @@ -297,7 +289,6 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
get_parameter_names,
has_length,
ALL_LAYERNORM_LAYERS,
ShardedDDPOption,
logger,
)
from typing import List, Optional
Expand Down Expand Up @@ -328,7 +319,7 @@ def create_optimizer(self):
"""
if is_sagemaker_mp_enabled():
return super().create_optimizer()
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
if self.is_fsdp_enabled:
return super().create_optimizer()

opt_model = self.model
Expand Down Expand Up @@ -401,27 +392,20 @@ def create_optimizer(self):

# optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

if self.sharded_ddp == ShardedDDPOption.SIMPLE:
self.optimizer = OSS(
params=optimizer_grouped_parameters,
optim=optimizer_cls,
**optimizer_kwargs,
)
else:
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes

manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
logger.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
logger.info(f"skipped: {skipped/2**20}M params")
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes

manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
logger.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
logger.info(f"skipped: {skipped/2**20}M params")

return self.optimizer

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@


class LlavaConfig(MistralConfig):
model_type = "llava"
model_type = "llava_custom"


class LlavaMistralModel(LlavaMetaModel, MistralModel):
Expand Down Expand Up @@ -110,5 +110,5 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_
_inputs['images'] = images
return _inputs

AutoConfig.register("llava", LlavaConfig)
AutoConfig.register("llava_custom", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaMistralForCausalLM)
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
QUANT_CONFIG,
WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
)
from intel_extension_for_transformers.llm.quantization.utils import replace_linear
from transformers.configuration_utils import PretrainedConfig
Expand Down Expand Up @@ -727,6 +728,13 @@ def load_low_bit(cls,
pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
)
is_sharded = True
elif os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant))
):
# Load from a safetensors checkpoint
archive_file = os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
)
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
archive_file = pretrained_model_name_or_path
is_local = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@
from neural_compressor.utils.pytorch import load
from transformers import AutoModel, PretrainedConfig
from transformers.file_utils import add_start_docstrings
from transformers.modeling_utils import no_init_weights
from transformers.models.auto.auto_factory import _get_model_class
from transformers.utils.generic import ContextManagers
from optimum.exporters import TasksManager

from optimum.intel.neural_compressor import INCConfig
Expand Down Expand Up @@ -268,9 +266,7 @@ def _from_pretrained(
decoder = model
else:
model_class = _get_model_class(config, cls.auto_model_class._model_mapping)
init_contexts = [no_init_weights(_enable=True)]
with ContextManagers(init_contexts):
model = model_class(config)
model = model_class(config)

# Load the model from local directory
if os.path.isdir(model_id):
Expand Down
18 changes: 6 additions & 12 deletions intel_extension_for_transformers/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from transformers import __version__, Seq2SeqTrainer, Trainer, PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
from transformers.file_utils import (
from transformers.utils import (
CONFIG_NAME,
WEIGHTS_NAME,
is_torch_tpu_available,
Expand All @@ -67,7 +67,6 @@
)
from transformers.trainer_utils import (
HPSearchBackend,
ShardedDDPOption,
TrainOutput,
EvalLoopOutput,
EvalPrediction,
Expand Down Expand Up @@ -762,7 +761,8 @@ def train(
else:
debug_overflow = DebugUnderflowOverflow(self.model) # noqa

delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
# delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
delay_optimizer_creation = is_sagemaker_mp_enabled()

if not delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
Expand Down Expand Up @@ -1176,9 +1176,7 @@ def training_step(
else:
loss.backward()
else:
if self.do_grad_scaling:
self.scaler.scale(loss).backward()
elif self.use_apex:
if self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
elif NEW_DEEPSPEED_FLAG:
Expand Down Expand Up @@ -1265,9 +1263,7 @@ def training_step_length_adaptive(
else:
loss.backward()
else:
if self.do_grad_scaling:
self.scaler.scale(loss).backward()
elif self.use_apex:
if self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
elif NEW_DEEPSPEED_FLAG:
Expand Down Expand Up @@ -1360,9 +1356,7 @@ def training_step_length_adaptive(
else:
loss.backward()
else:
if self.do_grad_scaling:
self.scaler.scale(loss).backward()
elif self.use_apex:
if self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
elif NEW_DEEPSPEED_FLAG:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
QUANT_CONFIG = "quantization_config.json"
SPARSITY_CONFIG = "sparsity_config.json"
SAFE_WEIGHTS_NAME = "model.safetensors"

torch = LazyImport("torch")

Expand Down
2 changes: 1 addition & 1 deletion tests/CI/test_weight_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def test_auto_model_saving_loading(self):
if isinstance(module, QuantizedLinearQBits):
module_list.append(name)
self.assertTrue(len(module_list) > 0)
model.save_pretrained(self.workspace)
model.save_pretrained(self.workspace, safe_serialization=False)
loaded_model = AutoModelForCausalLM.from_pretrained(self.workspace)
for name, module in loaded_model.named_modules():
if isinstance(module, QuantizedLinearQBits):
Expand Down

0 comments on commit 4e6834a

Please sign in to comment.