diff --git a/heron/models/git_llm/git_gpt_neox/modeling_git_gpt_neox.py b/heron/models/git_llm/git_gpt_neox/modeling_git_gpt_neox.py index c39dd36..88023d0 100644 --- a/heron/models/git_llm/git_gpt_neox/modeling_git_gpt_neox.py +++ b/heron/models/git_llm/git_gpt_neox/modeling_git_gpt_neox.py @@ -33,7 +33,7 @@ BaseModelOutputWithPooling, CausalLMOutputWithPast, ) -from transformers.models.git.modeling_git import GitProjection +from heron.models.mlp_adapter import MLPProjection class GitGPTNeoXConfig(GPTNeoXConfig): @@ -95,7 +95,7 @@ def __init__(self, config: GPTNeoXConfig): # Git modules self.image_encoder = CLIPVisionModel.from_pretrained(config.vision_model_name) - self.visual_projection = GitProjection(config) + self.visual_projection = MLPProjection(config) if config.num_image_with_embedding is not None: self.img_temporal_embedding = nn.ParameterList( diff --git a/heron/models/git_llm/git_japanese_stablelm_alpha/modeling_git_japanese_stablelm_alpha.py b/heron/models/git_llm/git_japanese_stablelm_alpha/modeling_git_japanese_stablelm_alpha.py index b7ec65a..fa3ada2 100644 --- a/heron/models/git_llm/git_japanese_stablelm_alpha/modeling_git_japanese_stablelm_alpha.py +++ b/heron/models/git_llm/git_japanese_stablelm_alpha/modeling_git_japanese_stablelm_alpha.py @@ -27,7 +27,7 @@ BaseModelOutputWithPooling, CausalLMOutputWithPast, ) -from transformers.models.git.modeling_git import GitProjection +from heron.models.mlp_adapter import MLPProjection from .configuration_japanese_stablelm_alpha import JapaneseStableLMAlphaConfig from .modeling_japanese_stablelm_alpha import ( @@ -95,7 +95,7 @@ def __init__(self, config: JapaneseStableLMAlphaConfig): # Git modules self.image_encoder = CLIPVisionModel.from_pretrained(config.vision_model_name) - self.visual_projection = GitProjection(config) + self.visual_projection = MLPProjection(config) if config.num_image_with_embedding is not None: self.img_temporal_embedding = nn.ParameterList( diff --git 
a/heron/models/git_llm/git_llama/modeling_git_llama.py b/heron/models/git_llm/git_llama/modeling_git_llama.py index 47dcaa0..60b5947 100644 --- a/heron/models/git_llm/git_llama/modeling_git_llama.py +++ b/heron/models/git_llm/git_llama/modeling_git_llama.py @@ -33,7 +33,7 @@ BaseModelOutputWithPooling, CausalLMOutputWithPast, ) -from transformers.models.git.modeling_git import GitProjection +from heron.models.mlp_adapter import MLPProjection class GitLlamaConfig(LlamaConfig): @@ -95,7 +95,7 @@ def __init__(self, config: LlamaConfig): # Git modules self.image_encoder = CLIPVisionModel.from_pretrained(config.vision_model_name) - self.visual_projection = GitProjection(config) + self.visual_projection = MLPProjection(config) if config.num_image_with_embedding is not None: self.img_temporal_embedding = nn.ParameterList( diff --git a/heron/models/git_llm/git_mpt/modeling_git_mpt.py b/heron/models/git_llm/git_mpt/modeling_git_mpt.py index 13dd2e8..4472e09 100644 --- a/heron/models/git_llm/git_mpt/modeling_git_mpt.py +++ b/heron/models/git_llm/git_mpt/modeling_git_mpt.py @@ -42,7 +42,7 @@ BaseModelOutputWithPooling, CausalLMOutputWithPast, ) -from transformers.models.git.modeling_git import GitProjection +from heron.models.mlp_adapter import MLPProjection class GitMptConfig(MptConfig): @@ -105,7 +105,7 @@ def __init__(self, config: MptConfig): # Git modules self.image_encoder = CLIPVisionModel.from_pretrained(config.vision_model_name) - self.visual_projection = GitProjection(config) + self.visual_projection = MLPProjection(config) if config.num_image_with_embedding is not None: self.img_temporal_embedding = nn.ParameterList( diff --git a/heron/models/git_llm/git_opt/modeling_git_opt.py b/heron/models/git_llm/git_opt/modeling_git_opt.py index 453cef3..18c0726 100644 --- a/heron/models/git_llm/git_opt/modeling_git_opt.py +++ b/heron/models/git_llm/git_opt/modeling_git_opt.py @@ -33,7 +33,7 @@ BaseModelOutputWithPooling, CausalLMOutputWithPast, ) -from 
transformers.models.git.modeling_git import GitProjection +from heron.models.mlp_adapter import MLPProjection from transformers.models.opt.modeling_opt import OPTLearnedPositionalEmbedding @@ -96,7 +96,7 @@ def __init__(self, config: OPTConfig): # Git modules self.image_encoder = CLIPVisionModel.from_pretrained(config.vision_model_name) - self.visual_projection = GitProjection(config) + self.visual_projection = MLPProjection(config) if config.num_image_with_embedding is not None: self.img_temporal_embedding = nn.ParameterList( diff --git a/heron/models/mlp_adapter.py b/heron/models/mlp_adapter.py new file mode 100644 index 0000000..4fd20d6 --- /dev/null +++ b/heron/models/mlp_adapter.py @@ -0,0 +1,28 @@ +import re +import torch +import torch.nn as nn +from transformers.models.git.modeling_git import GitProjection + +class MLPProjection(GitProjection): + def __init__(self, config): + super().__init__(config) + self.config = config + + mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', config.mlp_adapter) + if mlp_gelu_match: + mlp_depth = int(mlp_gelu_match.group(1)) + modules = [nn.Linear(config.vision_config.hidden_size, config.hidden_size)] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(config.hidden_size, config.hidden_size)) + modules.append(nn.LayerNorm(config.hidden_size, eps=config.vision_config.layer_norm_eps)) + elif config.mlp_adapter == 'linear': + modules = [nn.Linear(config.vision_config.hidden_size, config.hidden_size)] + modules.append(nn.LayerNorm(config.hidden_size, eps=config.vision_config.layer_norm_eps)) + else: + raise ValueError(f'Unknown mlp_adapter name: {config.mlp_adapter}') + + self.visual_projection = nn.Sequential(*modules) + + def forward(self, embeddings: torch.Tensor) -> torch.Tensor: + return self.visual_projection(embeddings) diff --git a/heron/models/utils.py b/heron/models/utils.py index 3aa0a5c..55baad4 100644 ---
a/heron/models/utils.py +++ b/heron/models/utils.py @@ -31,6 +31,10 @@ def load_model( model_type = model_config["model_type"] language_model = model_config["language_model_name"] num_image_with_embedding = model_config["num_image_with_embedding"] + # Name of the visual-projection adapter (e.g. "mlp2x_gelu" or "linear"). + # Defaults to "linear" -- the original single-linear GIT projection -- + # when the model config does not specify one. + adapter_name = model_config.get("mlp_adapter", "linear") # set dtype if model_config.get("fp16", False): @@ -47,6 +51,7 @@ def load_model( num_image_with_embedding=num_image_with_embedding, vision_model_name=model_config["vision_model_name"], ) + git_config.mlp_adapter = adapter_name model = GitOPTForCausalLM.from_pretrained( language_model, config=git_config, torch_dtype=torch_dtype ) @@ -59,6 +64,7 @@ def load_model( num_image_with_embedding=num_image_with_embedding, vision_model_name=model_config["vision_model_name"], ) + git_config.mlp_adapter = adapter_name model = GitLlamaForCausalLM.from_pretrained( language_model, config=git_config, torch_dtype=torch_dtype ) @@ -71,6 +77,7 @@ def load_model( num_image_with_embedding=num_image_with_embedding, vision_model_name=model_config["vision_model_name"], ) + git_config.mlp_adapter = adapter_name model = GitMptForCausalLM.from_pretrained( language_model, config=git_config, torch_dtype=torch_dtype ) @@ -86,6 +93,7 @@ def load_model( num_image_with_embedding=num_image_with_embedding, vision_model_name=model_config["vision_model_name"], ) + git_config.mlp_adapter = adapter_name model = GitJapaneseStableLMAlphaForCausalLM.from_pretrained( language_model, config=git_config, torch_dtype=torch_dtype ) @@ -102,6 +110,7 @@ def load_model( num_image_with_embedding=num_image_with_embedding, vision_model_name=model_config["vision_model_name"], ) + git_config.mlp_adapter = adapter_name model = GitGPTNeoXForCausalLM.from_pretrained( language_model, config=git_config, torch_dtype=torch_dtype )