Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable Multi Layer Perceptron (MLP) selection for projector #25

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions heron/models/git_llm/git_gpt_neox/modeling_git_gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
BaseModelOutputWithPooling,
CausalLMOutputWithPast,
)
from transformers.models.git.modeling_git import GitProjection
from heron.models.mlp_adapter import MLPProjection


class GitGPTNeoXConfig(GPTNeoXConfig):
Expand Down Expand Up @@ -95,7 +95,7 @@ def __init__(self, config: GPTNeoXConfig):

# Git modules
self.image_encoder = CLIPVisionModel.from_pretrained(config.vision_model_name)
self.visual_projection = GitProjection(config)
self.visual_projection = MLPProjection(config)

if config.num_image_with_embedding is not None:
self.img_temporal_embedding = nn.ParameterList(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
BaseModelOutputWithPooling,
CausalLMOutputWithPast,
)
from transformers.models.git.modeling_git import GitProjection
from heron.models.mlp_adapter import MLPProjection

from .configuration_japanese_stablelm_alpha import JapaneseStableLMAlphaConfig
from .modeling_japanese_stablelm_alpha import (
Expand Down Expand Up @@ -95,7 +95,7 @@ def __init__(self, config: JapaneseStableLMAlphaConfig):

# Git modules
self.image_encoder = CLIPVisionModel.from_pretrained(config.vision_model_name)
self.visual_projection = GitProjection(config)
self.visual_projection = MLPProjection(config)

if config.num_image_with_embedding is not None:
self.img_temporal_embedding = nn.ParameterList(
Expand Down
4 changes: 2 additions & 2 deletions heron/models/git_llm/git_llama/modeling_git_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
BaseModelOutputWithPooling,
CausalLMOutputWithPast,
)
from transformers.models.git.modeling_git import GitProjection
from heron.models.mlp_adapter import MLPProjection


class GitLlamaConfig(LlamaConfig):
Expand Down Expand Up @@ -95,7 +95,7 @@ def __init__(self, config: LlamaConfig):

# Git modules
self.image_encoder = CLIPVisionModel.from_pretrained(config.vision_model_name)
self.visual_projection = GitProjection(config)
self.visual_projection = MLPProjection(config)

if config.num_image_with_embedding is not None:
self.img_temporal_embedding = nn.ParameterList(
Expand Down
4 changes: 2 additions & 2 deletions heron/models/git_llm/git_mpt/modeling_git_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
BaseModelOutputWithPooling,
CausalLMOutputWithPast,
)
from transformers.models.git.modeling_git import GitProjection
from heron.models.mlp_adapter import MLPProjection


class GitMptConfig(MptConfig):
Expand Down Expand Up @@ -105,7 +105,7 @@ def __init__(self, config: MptConfig):

# Git modules
self.image_encoder = CLIPVisionModel.from_pretrained(config.vision_model_name)
self.visual_projection = GitProjection(config)
self.visual_projection = MLPProjection(config)

if config.num_image_with_embedding is not None:
self.img_temporal_embedding = nn.ParameterList(
Expand Down
4 changes: 2 additions & 2 deletions heron/models/git_llm/git_opt/modeling_git_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
BaseModelOutputWithPooling,
CausalLMOutputWithPast,
)
from transformers.models.git.modeling_git import GitProjection
from heron.models.mlp_adapter import MLPProjection
from transformers.models.opt.modeling_opt import OPTLearnedPositionalEmbedding


Expand Down Expand Up @@ -96,7 +96,7 @@ def __init__(self, config: OPTConfig):

# Git modules
self.image_encoder = CLIPVisionModel.from_pretrained(config.vision_model_name)
self.visual_projection = GitProjection(config)
self.visual_projection = MLPProjection(config)

if config.num_image_with_embedding is not None:
self.img_temporal_embedding = nn.ParameterList(
Expand Down
28 changes: 28 additions & 0 deletions heron/models/mlp_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import re
import torch
import torch.nn as nn
from transformers.models.git.modeling_git import GitProjection

class MLPProjection(GitProjection):
    """Visual projector mapping vision-encoder embeddings into the LLM hidden space.

    The adapter architecture is selected by ``config.mlp_adapter``:

    * ``"mlp{N}x_gelu"`` -- an N-layer MLP with GELU activations between
      linear layers (e.g. ``"mlp2x_gelu"`` gives Linear -> GELU -> Linear),
    * ``"linear"``       -- a single linear projection.

    Both variants project from ``config.vision_config.hidden_size`` to
    ``config.hidden_size`` and end with a LayerNorm.

    Raises:
        ValueError: if ``config.mlp_adapter`` matches neither form.
    """

    # Compiled once at class level; matched on every construction.
    _MLP_GELU_RE = re.compile(r"^mlp(\d+)x_gelu$")

    def __init__(self, config):
        # NOTE: GitProjection.__init__ also builds a `visual_projection`
        # submodule; it is intentionally replaced below by the configured
        # adapter, so only the adapter's parameters remain in use.
        super().__init__(config)
        self.config = config

        # Match once and reuse the result (the original matched twice).
        mlp_gelu_match = self._MLP_GELU_RE.match(config.mlp_adapter)
        if mlp_gelu_match:
            mlp_depth = int(mlp_gelu_match.group(1))
            # First layer projects from the vision hidden size; the
            # remaining (mlp_depth - 1) layers are hidden-to-hidden,
            # each preceded by a GELU.
            modules = [nn.Linear(config.vision_config.hidden_size, config.hidden_size)]
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(config.hidden_size, config.hidden_size))
            modules.append(
                nn.LayerNorm(config.hidden_size, eps=config.vision_config.layer_norm_eps)
            )
        elif config.mlp_adapter == "linear":
            # Plain string comparison; a regex is unnecessary here.
            modules = [
                nn.Linear(config.vision_config.hidden_size, config.hidden_size),
                nn.LayerNorm(config.hidden_size, eps=config.vision_config.layer_norm_eps),
            ]
        else:
            raise ValueError(f'Unknown mlp_adapter name: {config.mlp_adapter}')

        self.visual_projection = nn.Sequential(*modules)

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Project image embeddings through the configured adapter stack."""
        return self.visual_projection(embeddings)
9 changes: 9 additions & 0 deletions heron/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ def load_model(
model_type = model_config["model_type"]
language_model = model_config["language_model_name"]
num_image_with_embedding = model_config["num_image_with_embedding"]
if "mlp_adapter" in model_config:
adapter_name = model_config["mlp_adapter"]
else:
adapter_name = "linear"

# set dtype
if model_config.get("fp16", False):
Expand All @@ -47,6 +51,7 @@ def load_model(
num_image_with_embedding=num_image_with_embedding,
vision_model_name=model_config["vision_model_name"],
)
git_config.mlp_adapter = adapter_name
model = GitOPTForCausalLM.from_pretrained(
language_model, config=git_config, torch_dtype=torch_dtype
)
Expand All @@ -59,6 +64,7 @@ def load_model(
num_image_with_embedding=num_image_with_embedding,
vision_model_name=model_config["vision_model_name"],
)
git_config.mlp_adapter = adapter_name
model = GitLlamaForCausalLM.from_pretrained(
language_model, config=git_config, torch_dtype=torch_dtype
)
Expand All @@ -71,6 +77,7 @@ def load_model(
num_image_with_embedding=num_image_with_embedding,
vision_model_name=model_config["vision_model_name"],
)
git_config.mlp_adapter = adapter_name
model = GitMptForCausalLM.from_pretrained(
language_model, config=git_config, torch_dtype=torch_dtype
)
Expand All @@ -86,6 +93,7 @@ def load_model(
num_image_with_embedding=num_image_with_embedding,
vision_model_name=model_config["vision_model_name"],
)
git_config.mlp_adapter = adapter_name
model = GitJapaneseStableLMAlphaForCausalLM.from_pretrained(
language_model, config=git_config, torch_dtype=torch_dtype
)
Expand All @@ -102,6 +110,7 @@ def load_model(
num_image_with_embedding=num_image_with_embedding,
vision_model_name=model_config["vision_model_name"],
)
git_config.mlp_adapter = adapter_name
model = GitGPTNeoXForCausalLM.from_pretrained(
language_model, config=git_config, torch_dtype=torch_dtype
)
Expand Down