Support for freezing pretrained vision model layers with regex (#3981)
ethanreidel committed Jun 1, 2024
1 parent 423a82a commit 830c3f0
Showing 9 changed files with 416 additions and 5 deletions.
71 changes: 71 additions & 0 deletions examples/regex_freezing/ecd_freezing_with_regex_training.py
@@ -0,0 +1,71 @@
import logging
import os
import shutil

import pandas as pd
import yaml
from datasets import load_dataset

from ludwig.api import LudwigModel

"""
To inspect model layers in the terminal, run: ludwig collect_summary -pm resnet18
For some models, a Hugging Face token will be necessary.
Once you obtain one, run `export HUGGING_FACE_HUB_TOKEN="<api_token>"` in the terminal.
"""

dataset = load_dataset("beans")
train_df = pd.DataFrame(
{"image_path": [f"train_{i}.jpg" for i in range(len(dataset["train"]))], "label": dataset["train"]["labels"]}
)
test_df = pd.DataFrame(
{"image_path": [f"test_{i}.jpg" for i in range(len(dataset["test"]))], "label": dataset["test"]["labels"]}
)

os.makedirs("train_images", exist_ok=True)
os.makedirs("test_images", exist_ok=True)

for i, img in enumerate(dataset["train"]["image"]):
img.save(f"train_images/train_{i}.jpg")
for i, img in enumerate(dataset["test"]["image"]):
img.save(f"test_images/test_{i}.jpg")

train_df["image_path"] = train_df["image_path"].apply(lambda x: os.path.join("train_images", x))
test_df["image_path"] = test_df["image_path"].apply(lambda x: os.path.join("test_images", x))

train_df.to_csv("beans_train.csv", index=False)
test_df.to_csv("beans_test.csv", index=False)


config = yaml.safe_load(
r"""
input_features:
- name: image_path
type: image
encoder:
type: resnet
use_pretrained: true
trainable: true
output_features:
- name: label
type: category
trainer:
epochs: 1
batch_size: 5
layers_to_freeze_regex: '(layer1\.0\.*|layer2\.0\.*)'
"""
)

model = LudwigModel(config, logging_level=logging.INFO)
train_stats = model.train(dataset="beans_train.csv", skip_save_model=True)
eval_stats, predictions, output_directory = model.evaluate(dataset="beans_test.csv")

print("Training Statistics: ", train_stats)
print("Evaluation Statistics: ", eval_stats)

shutil.rmtree("train_images")
shutil.rmtree("test_images")
os.remove("beans_train.csv")
os.remove("beans_test.csv")
64 changes: 64 additions & 0 deletions examples/regex_freezing/llm_freezing_with_regex_training.py
@@ -0,0 +1,64 @@
import logging

import yaml

from ludwig.api import LudwigModel

"""
To inspect the layers of the base model in the terminal, run: ludwig collect_summary -pm facebook/opt-350m
For some models, a Hugging Face token will be necessary.
Once you obtain one, run `export HUGGING_FACE_HUB_TOKEN="<api_token>"` in the terminal.
"""

config = yaml.safe_load(
r"""
model_type: llm
base_model: facebook/opt-350m
adapter:
type: lora
prompt:
template: |
### Instruction:
Generate a concise summary of the following text, capturing the main points and conclusions.
### Input:
{input}
### Response:
input_features:
- name: prompt
type: text
preprocessing:
max_sequence_length: 256
output_features:
- name: output
type: text
preprocessing:
max_sequence_length: 256
trainer:
type: finetune
layers_to_freeze_regex: (decoder\.layers\.22\.final_layer_norm\.*)
learning_rate: 0.0001
batch_size: 5
gradient_accumulation_steps: 16
epochs: 1
learning_rate_scheduler:
warmup_fraction: 0.01
preprocessing:
sample_ratio: 0.1
generation:
pad_token_id : 0
"""
)

model = LudwigModel(config=config, logging_level=logging.INFO)
results = model.train(dataset="ludwig://alpaca")
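Note (not part of the commit): before writing a regex such as `decoder\.layers\.22\.final_layer_norm\.*`, it can help to confirm the base model's layer names. A sketch that enumerates them without downloading the pretrained weights, assuming transformers is installed:

from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("facebook/opt-350m")
model = AutoModel.from_config(config)  # randomly initialized; only the config file is fetched
for name, _ in model.named_parameters():
    if "final_layer_norm" in name:
        print(name)  # e.g. decoder.layers.22.final_layer_norm.weight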
60 changes: 56 additions & 4 deletions ludwig/collect.py
@@ -14,6 +14,7 @@
# limitations under the License.
# ==============================================================================
import argparse
import importlib
import logging
import os
import sys
@@ -186,6 +187,51 @@ def print_model_summary(model_path: str, **kwargs) -> None:
logger.info(name)


def pretrained_summary(pretrained_model: str, **kwargs) -> None:
"""Loads a pretrained model from Huggingface or Torchvision models and prints names of layers.
# Inputs
:param pretrained_model: (str) name of model to load (case sensitive).
# Return
:return: (`None`)
"""
from transformers import AutoConfig, AutoModel

model = None
# get access token if available
token = os.getenv("HUGGING_FACE_HUB_TOKEN")
if token is None:
logger.info("No token provided. Continuing loading without token access.")
elif not token:
raise ValueError("Invalid token provided. Exiting.")
else:
logger.info("Valid token provided. Proceeding with token access.")

# Try to load from transformers/HF
# TODO -> Fix OOM on large models e.g. llama 3 8B
try:
config = AutoConfig.from_pretrained(pretrained_model, token=token, low_cpu_mem_usage=True)
model = AutoModel.from_config(config=config)
logger.info(f"Loaded {pretrained_model} from Hugging Face Transformers.")
except Exception as e:
logger.error(f"Failed to load {pretrained_model} from Hugging Face Transformers: {e}")

    # Try to load from torchvision models
if model is None:
try:
module = importlib.import_module("torchvision.models")
model = getattr(module, pretrained_model)(weights=None)
except AttributeError:
logger.error(f"{pretrained_model} is not a valid torchvision model.")

if model:
for name, _ in model.named_parameters():
logger.info(name)
else:
logger.error(f"Unable to load the model {pretrained_model} from any known source.")


def cli_collect_activations(sys_argv):
"""Command Line Interface to communicate with the collection of tensors and there are several options that can
    be specified when calling this function:
@@ -374,8 +420,8 @@ def cli_collect_weights(sys_argv):
def cli_collect_summary(sys_argv):
"""Command Line Interface to collecting a summary of the model layers and weights.
-    --m: Input model that is necessary to collect to the tensors, this is a
-    required *option*
+    --m: Input model that is necessary to collect the tensors
+    --pm: Model name to fetch from Hugging Face or torchvision
--v: Verbose: Defines the logging level that the user will be exposed to
"""
parser = argparse.ArgumentParser(
@@ -389,7 +435,10 @@ def cli_collect_summary(sys_argv):
# ----------------
# Model parameters
# ----------------
parser.add_argument("-m", "--model_path", help="model to load", required=True)
parser.add_argument("-m", "--model_path", help="model to load", required=False)
parser.add_argument(
"-pm", "--pretrained_model", help="pretrained model to summarize (torchvision and huggingface)", required=False
)

# ------------------
# Runtime parameters
@@ -416,7 +465,10 @@

print_ludwig("Collect Summary", LUDWIG_VERSION)

-    print_model_summary(**vars(args))
+    if args.model_path:
+        print_model_summary(**vars(args))
+    elif args.pretrained_model and not args.model_path:
+        pretrained_summary(**vars(args))


if __name__ == "__main__":
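Note (not part of the commit): with these changes, `ludwig collect_summary` accepts either `-m <model_path>` or the new `-pm <pretrained_model>` flag (e.g. `ludwig collect_summary -pm resnet18`). The same code path can also be exercised from Python, roughly as follows; this is only an illustration, not a documented API:

import logging

from ludwig.collect import pretrained_summary

logging.basicConfig(level=logging.INFO)  # layer names are emitted via logger.info
pretrained_summary("resnet18")  # tries Hugging Face first, then falls back to torchvision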
9 changes: 9 additions & 0 deletions ludwig/schema/metadata/configs/trainer.yaml
@@ -70,6 +70,15 @@ ecd:
In many large-scale training runs, evaluation is often configured to run on
a sub-epoch time scale, or every few thousand steps.
ui_display_name: Checkpoints per epoch
layers_to_freeze_regex:
default_value_reasoning:
By default no layers will be frozen when fine-tuning a pretrained model.
description_implications:
      Freezing specific layers can improve a pretrained model's performance in a number
      of ways. At a basic level, freezing early layers can prevent overfitting by retaining
      more general features (beneficial for small datasets). It can also reduce computational
      resource use and lower overall training time, since fewer gradients need to be computed.
expected_impact: 1
early_stop:
default_value_reasoning:
Deep learning models are prone to overfitting. It's generally
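Note (not part of the commit): the `description_implications` text above can be illustrated with plain PyTorch. Parameters with `requires_grad = False` receive no gradients at all, which is why freezing both limits overfitting of those weights and removes their backward-pass cost:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
for name, p in model.named_parameters():
    if name.startswith("0."):  # freeze the first Linear layer ("0.weight", "0.bias")
        p.requires_grad = False

loss = model(torch.randn(2, 8)).sum()
loss.backward()
print(model[0].weight.grad)        # None: no gradient computed for the frozen layer
print(model[2].weight.grad.shape)  # torch.Size([4, 16]): unfrozen layer still trains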
20 changes: 20 additions & 0 deletions ludwig/schema/trainer.py
@@ -1,3 +1,4 @@
import re
from abc import ABC
from typing import Optional, Type, Union

@@ -162,6 +163,14 @@ def __post_init__(self):
f"`gradient_accumulation_steps` ({self.gradient_accumulation_steps})."
)

if self.layers_to_freeze_regex:
try:
re.compile(self.layers_to_freeze_regex)
except re.error:
raise ConfigValidationError(
f"`layers_to_freeze_regex` ({self.layers_to_freeze_regex}) must be a valid regular expression."
)

learning_rate: Union[float, str] = schema_utils.OneOfOptionsField(
default=0.001,
allow_none=False,
@@ -444,6 +453,17 @@ def __post_init__(self):
parameter_metadata=TRAINER_METADATA[MODEL_ECD]["enable_gradient_checkpointing"],
)

layers_to_freeze_regex: str = schema_utils.String(
default=None,
allow_none=True,
description=(
"Freeze specific layers based on provided regex. Freezing specific layers can improve a "
"pretrained model's performance in a number of ways. At a basic level, freezing early layers can "
"prevent overfitting by retaining more general features (beneficial for small datasets). Also can "
"reduce computational resource use and lower overall training time due to less gradient calculations. "
),
)

def update_batch_size_grad_accum(self, num_workers: int):
from ludwig.utils.trainer_utils import get_rendered_batch_size_grad_accum

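Note (not part of the commit): the `__post_init__` check above relies on `re.compile` raising `re.error` for malformed patterns, which the schema then surfaces as a `ConfigValidationError`. A minimal illustration of the underlying check:

import re

bad_pattern = r"(layer1\.0\.*"  # unbalanced parenthesis
try:
    re.compile(bad_pattern)
except re.error as err:
    print(f"`layers_to_freeze_regex` ({bad_pattern}) must be a valid regular expression: {err}")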
6 changes: 6 additions & 0 deletions ludwig/trainers/trainer.py
@@ -82,6 +82,7 @@
from ludwig.utils.torch_utils import get_torch_device
from ludwig.utils.trainer_utils import (
append_metrics,
freeze_layers_regex,
get_final_steps_per_checkpoint,
get_latest_metrics_dict,
get_new_progress_tracker,
@@ -171,6 +172,7 @@ def __init__(
self._validation_field = config.validation_field
self._validation_metric = config.validation_metric
self.early_stop = config.early_stop
self.layers_to_freeze_regex = config.layers_to_freeze_regex
self.steps_per_checkpoint = config.steps_per_checkpoint
self.checkpoints_per_epoch = config.checkpoints_per_epoch
self.evaluate_training_set = config.evaluate_training_set
@@ -242,6 +244,10 @@ def prepare(self):
base_learning_rate *= lr_scale_fn(self.distributed.size())
self.base_learning_rate = base_learning_rate

        # If a layer-freezing regex is supplied, freeze the matching layers
if self.config.layers_to_freeze_regex:
freeze_layers_regex(self.config, self.model)

# We may need to replace the embedding layer when using 8-bit optimizers from bitsandbytes.
update_embedding_layer(self.compiled_model, self.config)

61 changes: 60 additions & 1 deletion ludwig/utils/trainer_utils.py
@@ -1,6 +1,7 @@
import logging
import re
from collections import defaultdict
-from typing import Dict, List, Tuple, TYPE_CHECKING
+from typing import Dict, List, Tuple, TYPE_CHECKING, Union

try:
from typing import Literal
@@ -10,14 +11,18 @@
from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import AUTO, COMBINED, LOSS
from ludwig.models.base import BaseModel
from ludwig.models.ecd import ECD
from ludwig.models.llm import LLM
from ludwig.modules.metric_modules import get_best_function
from ludwig.schema.trainer import ECDTrainerConfig, FineTuneTrainerConfig
from ludwig.utils.data_utils import save_json
from ludwig.utils.metric_utils import TrainerMetric

if TYPE_CHECKING:
from ludwig.features.base_feature import OutputFeature
from ludwig.schema.trainer import BaseTrainerConfig


logger = logging.getLogger(__name__)


@@ -506,3 +511,57 @@ def get_rendered_batch_size_grad_accum(config: "BaseTrainerConfig", num_workers:
gradient_accumulation_steps = 1

return batch_size, gradient_accumulation_steps


def freeze_layers_regex(config: Union[ECDTrainerConfig, FineTuneTrainerConfig], model: Union[ECD, LLM]) -> None:
"""Freezes layers in a model whose names match a specified regular expression pattern.
This function iterates over all parameters of the model, checking each parameter's name against
the regular expression defined in the configuration object.
If a match is found, the parameter's `requires_grad` attribute is set to False,
effectively freezing the layer for training purposes.
If no matches are found, an error is logged indicating the issue with the regex or the model's layer names.
Parameters:
- config (Union[ECDTrainerConfig, FineTuneTrainerConfig]):
- model (Union[ECD, LLM]): The model object containing layers and parameters. This could be an instance of either
ECD or LLM classes, which should have a method `named_parameters()` that yields the name and parameter
object of each layer.
Raises:
- re.error: If the regular expression pattern in `config.layers_to_freeze_regex` is invalid, an error is logged
and the function exits.
Returns:
- None: This function does not return any value but modifies the model in-place by freezing certain layers.
"""
pattern = re.compile(config.layers_to_freeze_regex)
matched_layers = set()

for name, p in model.named_parameters():
if re.search(pattern, str(name)):
p.requires_grad = False
matched_layers.add(name)
if matched_layers:
logger.info(f"Layers where requires_grad was set to False: {matched_layers}")
else:
logger.error(f"No regex match for {config.layers_to_freeze_regex}! Check layer names and regex syntax.")

count_parameters(model)


def count_parameters(model) -> None:
"""Counts number of trainable parameters post freezing.
Returns:
- None: This function does not return any value.
"""
total_params = 0
for _, parameter in model.named_parameters():
if not parameter.requires_grad:
continue
params = parameter.numel()

total_params += params

logger.info(f"Total Trainable Parameters after freezing: {total_params}")
