diff --git a/examples/regex_freezing/ecd_freezing_with_regex_training.py b/examples/regex_freezing/ecd_freezing_with_regex_training.py
new file mode 100644
index 00000000000..dca1704b0cf
--- /dev/null
+++ b/examples/regex_freezing/ecd_freezing_with_regex_training.py
@@ -0,0 +1,75 @@
+import logging
+import os
+import shutil
+
+import pandas as pd
+import yaml
+from datasets import load_dataset
+
+from ludwig.api import LudwigModel
+
+"""
+To inspect model layers in the terminal, type: "ludwig collect_summary -pm resnet18"
+
+For some models, a Hugging Face token will be necessary.
+Once you obtain one, run `export HUGGING_FACE_HUB_TOKEN=<your token>` in the terminal.
+"""
+
+dataset = load_dataset("beans")
+train_df = pd.DataFrame(
+    {"image_path": [f"train_{i}.jpg" for i in range(len(dataset["train"]))], "label": dataset["train"]["labels"]}
+)
+test_df = pd.DataFrame(
+    {"image_path": [f"test_{i}.jpg" for i in range(len(dataset["test"]))], "label": dataset["test"]["labels"]}
+)
+
+os.makedirs("train_images", exist_ok=True)
+os.makedirs("test_images", exist_ok=True)
+
+for i, img in enumerate(dataset["train"]["image"]):
+    img.save(f"train_images/train_{i}.jpg")
+for i, img in enumerate(dataset["test"]["image"]):
+    img.save(f"test_images/test_{i}.jpg")
+
+train_df["image_path"] = train_df["image_path"].apply(lambda x: os.path.join("train_images", x))
+test_df["image_path"] = test_df["image_path"].apply(lambda x: os.path.join("test_images", x))
+
+train_df.to_csv("beans_train.csv", index=False)
+test_df.to_csv("beans_test.csv", index=False)
+
+
+config = yaml.safe_load(
+    r"""
+input_features:
+  - name: image_path
+    type: image
+    encoder:
+      type: resnet
+      use_pretrained: true
+      trainable: true
+output_features:
+  - name: label
+    type: category
+trainer:
+  epochs: 1
+  batch_size: 5
+  layers_to_freeze_regex: '(layer1\.0\.*|layer2\.0\.*)'
+"""
+)
+
+model = LudwigModel(config, logging_level=logging.INFO)
+train_stats = model.train(dataset="beans_train.csv", skip_save_model=True)
+eval_stats, predictions, output_directory = model.evaluate(dataset="beans_test.csv")
+
+print("Training Statistics: ", train_stats)
+print("Evaluation Statistics: ", eval_stats)
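+
+# Optional sanity check (illustrative; assumes the trained torch module is reachable as
+# `model.model`, as in the tests added by this change): count the parameters the regex froze.
+frozen = [name for name, p in model.model.named_parameters() if not p.requires_grad]
+print(f"Number of frozen parameters: {len(frozen)}")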
+
+shutil.rmtree("train_images")
+shutil.rmtree("test_images")
+os.remove("beans_train.csv")
+os.remove("beans_test.csv")
diff --git a/examples/regex_freezing/llm_freezing_with_regex_training.py b/examples/regex_freezing/llm_freezing_with_regex_training.py
new file mode 100644
index 00000000000..c002038f434
--- /dev/null
+++ b/examples/regex_freezing/llm_freezing_with_regex_training.py
@@ -0,0 +1,68 @@
+import logging
+
+import yaml
+
+from ludwig.api import LudwigModel
+
+"""
+To inspect model layers in the terminal, type: "ludwig collect_summary -pm resnet18"
+
+For some models, a Hugging Face token will be necessary.
+Once you obtain one, run `export HUGGING_FACE_HUB_TOKEN=<your token>` in the terminal.
+"""
+
+config = yaml.safe_load(
+    r"""
+model_type: llm
+base_model: facebook/opt-350m
+
+adapter:
+  type: lora
+
+prompt:
+  template: |
+    ### Instruction:
+    Generate a concise summary of the following text, capturing the main points and conclusions.
+
+    ### Input:
+    {input}
+
+    ### Response:
+
+input_features:
+  - name: prompt
+    type: text
+    preprocessing:
+      max_sequence_length: 256
+
+output_features:
+  - name: output
+    type: text
+    preprocessing:
+      max_sequence_length: 256
+
+trainer:
+  type: finetune
+  layers_to_freeze_regex: (decoder\.layers\.22\.final_layer_norm\.*)
+  learning_rate: 0.0001
+  batch_size: 5
+  gradient_accumulation_steps: 16
+  epochs: 1
+  learning_rate_scheduler:
+    warmup_fraction: 0.01
+
+preprocessing:
+  sample_ratio: 0.1
+
+generation:
+  pad_token_id: 0
+"""
+)
+
+model = LudwigModel(config=config, logging_level=logging.INFO)
+results = model.train(dataset="ludwig://alpaca")
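+
+# Optional sanity check (illustrative): count frozen parameters. Note that with a LoRA
+# adapter, PEFT freezes most base-model weights too, so this counts more than the regex matches.
+frozen = [name for name, p in model.model.named_parameters() if not p.requires_grad]
+print(f"Number of frozen parameters: {len(frozen)}")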
diff --git a/ludwig/collect.py b/ludwig/collect.py
index f3dc6d7e7e3..066edcd191c 100644
--- a/ludwig/collect.py
+++ b/ludwig/collect.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 # ==============================================================================
 import argparse
+import importlib
 import logging
 import os
 import sys
@@ -186,6 +187,54 @@ def print_model_summary(model_path: str, **kwargs) -> None:
         logger.info(name)
 
 
+def pretrained_summary(pretrained_model: str, **kwargs) -> None:
+    """Loads a pretrained model from Hugging Face or torchvision and prints the names of its layers.
+
+    # Inputs
+    :param pretrained_model: (str) name of the model to load (case sensitive).
+
+    # Return
+    :return: (`None`)
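+
+    # Example
+    From the terminal (mirroring the usage shown in examples/regex_freezing): `ludwig collect_summary -pm resnet18`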
" + ), + ) + def update_batch_size_grad_accum(self, num_workers: int): from ludwig.utils.trainer_utils import get_rendered_batch_size_grad_accum diff --git a/ludwig/trainers/trainer.py b/ludwig/trainers/trainer.py index 9edb6507fe5..74bbdd5885b 100644 --- a/ludwig/trainers/trainer.py +++ b/ludwig/trainers/trainer.py @@ -82,6 +82,7 @@ from ludwig.utils.torch_utils import get_torch_device from ludwig.utils.trainer_utils import ( append_metrics, + freeze_layers_regex, get_final_steps_per_checkpoint, get_latest_metrics_dict, get_new_progress_tracker, @@ -171,6 +172,7 @@ def __init__( self._validation_field = config.validation_field self._validation_metric = config.validation_metric self.early_stop = config.early_stop + self.layers_to_freeze_regex = config.layers_to_freeze_regex self.steps_per_checkpoint = config.steps_per_checkpoint self.checkpoints_per_epoch = config.checkpoints_per_epoch self.evaluate_training_set = config.evaluate_training_set @@ -242,6 +244,10 @@ def prepare(self): base_learning_rate *= lr_scale_fn(self.distributed.size()) self.base_learning_rate = base_learning_rate + # Given that regex is supplied, freeze layers + if self.config.layers_to_freeze_regex: + freeze_layers_regex(self.config, self.model) + # We may need to replace the embedding layer when using 8-bit optimizers from bitsandbytes. update_embedding_layer(self.compiled_model, self.config) diff --git a/ludwig/utils/trainer_utils.py b/ludwig/utils/trainer_utils.py index 207d64b8b78..8a9fd779d4f 100644 --- a/ludwig/utils/trainer_utils.py +++ b/ludwig/utils/trainer_utils.py @@ -1,6 +1,7 @@ import logging +import re from collections import defaultdict -from typing import Dict, List, Tuple, TYPE_CHECKING +from typing import Dict, List, Tuple, TYPE_CHECKING, Union try: from typing import Literal @@ -10,7 +11,10 @@ from ludwig.api_annotations import DeveloperAPI from ludwig.constants import AUTO, COMBINED, LOSS from ludwig.models.base import BaseModel +from ludwig.models.ecd import ECD +from ludwig.models.llm import LLM from ludwig.modules.metric_modules import get_best_function +from ludwig.schema.trainer import ECDTrainerConfig, FineTuneTrainerConfig from ludwig.utils.data_utils import save_json from ludwig.utils.metric_utils import TrainerMetric @@ -18,6 +22,7 @@ from ludwig.features.base_feature import OutputFeature from ludwig.schema.trainer import BaseTrainerConfig + logger = logging.getLogger(__name__) @@ -506,3 +511,57 @@ def get_rendered_batch_size_grad_accum(config: "BaseTrainerConfig", num_workers: gradient_accumulation_steps = 1 return batch_size, gradient_accumulation_steps + + +def freeze_layers_regex(config: Union[ECDTrainerConfig, FineTuneTrainerConfig], model: Union[ECD, LLM]) -> None: + """Freezes layers in a model whose names match a specified regular expression pattern. + + This function iterates over all parameters of the model, checking each parameter's name against + the regular expression defined in the configuration object. + If a match is found, the parameter's `requires_grad` attribute is set to False, + effectively freezing the layer for training purposes. + If no matches are found, an error is logged indicating the issue with the regex or the model's layer names. + + Parameters: + - config (Union[ECDTrainerConfig, FineTuneTrainerConfig]): + - model (Union[ECD, LLM]): The model object containing layers and parameters. This could be an instance of either + ECD or LLM classes, which should have a method `named_parameters()` that yields the name and parameter + object of each layer. 
+    """
+    pattern = re.compile(config.layers_to_freeze_regex)
+    matched_layers = set()
+
+    for name, p in model.named_parameters():
+        if pattern.search(name):
+            p.requires_grad = False
+            matched_layers.add(name)
+
+    if matched_layers:
+        logger.info(f"Layers where requires_grad was set to False: {matched_layers}")
+    else:
+        logger.error(f"No regex match for {config.layers_to_freeze_regex}! Check layer names and regex syntax.")
+
+    count_parameters(model)
+
+
+def count_parameters(model) -> None:
+    """Logs the number of trainable parameters remaining after freezing.
+
+    Returns:
+    - None: This function only logs a count; it does not return a value.
+    """
+    total_params = 0
+    for _, parameter in model.named_parameters():
+        if not parameter.requires_grad:
+            continue
+        total_params += parameter.numel()
+
+    logger.info(f"Total Trainable Parameters after freezing: {total_params}")
diff --git a/tests/integration_tests/test_cli.py b/tests/integration_tests/test_cli.py
index 16791ee6e3b..0c0302a2e5c 100644
--- a/tests/integration_tests/test_cli.py
+++ b/tests/integration_tests/test_cli.py
@@ -269,6 +269,26 @@ def test_collect_summary_activations_weights_cli(tmpdir, csv_filename):
     assert _run_ludwig("collect_summary", model=os.path.join(tmpdir, "experiment_run", MODEL_FILE_NAME))
 
 
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "alexnet",
+        "convnext_base",
+        "convnext_large",
+        "convnext_small",
+        "convnext_tiny",
+        "densenet121",
+        "densenet161",
+        "densenet169",
+        "openai-community/gpt2",
+        "facebook/opt-125m",
+    ],
+)
+def test_collect_summary_pretrained_model_cli(model_name):
+    """Test collect_summary with a pretrained torchvision or Hugging Face model."""
+    assert _run_ludwig("collect_summary", pretrained_model=model_name)
+
+
 def test_synthesize_dataset_cli(tmpdir, csv_filename):
     """Test synthesize_data cli."""
     # test depends on default setting of --dataset_size
diff --git a/tests/ludwig/modules/test_regex_freezing.py b/tests/ludwig/modules/test_regex_freezing.py
new file mode 100644
index 00000000000..7ec39544710
--- /dev/null
+++ b/tests/ludwig/modules/test_regex_freezing.py
@@ -0,0 +1,120 @@
+import logging
+import re
+from contextlib import nullcontext as no_error_raised
+
+import pytest
+
+from ludwig.api import LudwigModel
+from ludwig.constants import (
+    BASE_MODEL,
+    BATCH_SIZE,
+    EPOCHS,
+    GENERATION,
+    INPUT_FEATURES,
+    MODEL_LLM,
+    MODEL_TYPE,
+    OUTPUT_FEATURES,
+    TRAINER,
+    TYPE,
+)
+from ludwig.encoders.image.torchvision import TVEfficientNetEncoder
+from ludwig.schema.trainer import ECDTrainerConfig
+from ludwig.utils.misc_utils import set_random_seed
+from ludwig.utils.trainer_utils import freeze_layers_regex
+from tests.integration_tests.utils import category_feature, generate_data, image_feature, text_feature
+
+RANDOM_SEED = 130
+
+
+@pytest.mark.parametrize(
+    "regex",
+    [
+        r"(features\.1.*|features\.2.*|features\.3.*|model\.features\.4\.1\.block\.3\.0\.weight)",
+        r"(features\.1.*|features\.2\.*|features\.3.*)",
+        r"(features\.4\.0\.block|features\.4\.\d+\.block)",
+        r"(features\.5\.*|features\.6\.*|features\.7\.*)",
+        r"(features\.8\.\d+\.weight|features\.8\.\d+\.bias)",
+    ],
+)
+def test_tv_efficientnet_freezing(regex):
+    set_random_seed(RANDOM_SEED)
+
+    pretrained_model = TVEfficientNetEncoder(
+        model_variant="b0", use_pretrained=False, saved_weights_in_checkpoint=True, trainable=True
+    )
+
+    config = ECDTrainerConfig(layers_to_freeze_regex=regex)
+    freeze_layers_regex(config, pretrained_model)
+    for name, param in pretrained_model.named_parameters():
+        if re.search(regex, name):
+            assert not param.requires_grad
+        else:
+            assert param.requires_grad
+
+
+def test_llm_freezing(tmpdir, csv_filename):
+    input_features = [text_feature(name="input", encoder={"type": "passthrough"})]
+    output_features = [text_feature(name="output")]
+
+    train_df = generate_data(input_features, output_features, filename=csv_filename, num_examples=25)
+
+    config = {
+        MODEL_TYPE: MODEL_LLM,
+        BASE_MODEL: "hf-internal-testing/tiny-random-GPTJForCausalLM",
+        INPUT_FEATURES: [text_feature(name="input", encoder={"type": "passthrough"})],
+        OUTPUT_FEATURES: [text_feature(name="output")],
+        TRAINER: {TYPE: "finetune", BATCH_SIZE: 8, EPOCHS: 1, "layers_to_freeze_regex": r"(h\.0\.attn\.*)"},
+        GENERATION: {"pad_token_id": 0},
+    }
+
+    model = LudwigModel(config, logging_level=logging.INFO)
+
+    output_directory: str = str(tmpdir)
+    model.train(dataset=train_df, output_directory=output_directory, skip_save_processed_input=False)
+
+    for name, p in model.model.named_parameters():
+        if "h.0.attn" in name:
+            assert not p.requires_grad
+        else:
+            assert p.requires_grad
+
+
+def test_frozen_tv_training(tmpdir, csv_filename):
+    input_features = [
+        image_feature(tmpdir, encoder={"type": "efficientnet", "use_pretrained": False, "model_variant": "b0"})
+    ]
+    output_features = [category_feature()]
+
+    config = {
+        "input_features": input_features,
+        "output_features": output_features,
+        TRAINER: {
+            "layers_to_freeze_regex": r"(features\.1.*|features\.2\.*|features\.3.*)",
+            "epochs": 1,
+            "train_steps": 1,
+        },
+    }
+
+    training_data_csv_path = generate_data(config["input_features"], config["output_features"], csv_filename)
+    model = LudwigModel(config)
+
+    with no_error_raised():
+        model.experiment(
+            dataset=training_data_csv_path,
+            skip_save_training_description=True,
+            skip_save_training_statistics=True,
+            skip_save_model=True,
+            skip_save_progress=True,
+            skip_save_log=True,
+            skip_save_processed_input=True,
+        )
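+
+
+def test_invalid_regex_config_validation():
+    """Minimal sketch of the config-time guard added in this change: an unparsable pattern should
+    be rejected when the trainer config is built (assumes `ConfigValidationError` is importable
+    from `ludwig.error`, where the schema module raises it from)."""
+    from ludwig.error import ConfigValidationError
+
+    with pytest.raises(ConfigValidationError):
+        ECDTrainerConfig(layers_to_freeze_regex="(unclosed[")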