From be7ed2b7133680303636106c6ffa1ce7ecc5f763 Mon Sep 17 00:00:00 2001 From: Onely7-nlp Date: Thu, 2 Nov 2023 20:59:25 +0900 Subject: [PATCH 1/2] add mlp_projector --- .../git_gpt_neox/modeling_git_gpt_neox.py | 4 +-- .../modeling_git_japanese_stablelm_alpha.py | 4 +-- .../git_llm/git_llama/modeling_git_llama.py | 4 +-- .../git_llm/git_mpt/modeling_git_mpt.py | 4 +-- .../git_llm/git_opt/modeling_git_opt.py | 4 +-- heron/models/mlp_adapter.py | 29 +++++++++++++++++++ heron/models/utils.py | 9 ++++++ 7 files changed, 48 insertions(+), 10 deletions(-) create mode 100644 heron/models/mlp_adapter.py diff --git a/heron/models/git_llm/git_gpt_neox/modeling_git_gpt_neox.py b/heron/models/git_llm/git_gpt_neox/modeling_git_gpt_neox.py index c39dd36..88023d0 100644 --- a/heron/models/git_llm/git_gpt_neox/modeling_git_gpt_neox.py +++ b/heron/models/git_llm/git_gpt_neox/modeling_git_gpt_neox.py @@ -33,7 +33,7 @@ BaseModelOutputWithPooling, CausalLMOutputWithPast, ) -from transformers.models.git.modeling_git import GitProjection +from heron.models.mlp_adapter import MLPProjection class GitGPTNeoXConfig(GPTNeoXConfig): @@ -95,7 +95,7 @@ def __init__(self, config: GPTNeoXConfig): # Git modules self.image_encoder = CLIPVisionModel.from_pretrained(config.vision_model_name) - self.visual_projection = GitProjection(config) + self.visual_projection = MLPProjection(config) if config.num_image_with_embedding is not None: self.img_temporal_embedding = nn.ParameterList( diff --git a/heron/models/git_llm/git_japanese_stablelm_alpha/modeling_git_japanese_stablelm_alpha.py b/heron/models/git_llm/git_japanese_stablelm_alpha/modeling_git_japanese_stablelm_alpha.py index b7ec65a..fa3ada2 100644 --- a/heron/models/git_llm/git_japanese_stablelm_alpha/modeling_git_japanese_stablelm_alpha.py +++ b/heron/models/git_llm/git_japanese_stablelm_alpha/modeling_git_japanese_stablelm_alpha.py @@ -27,7 +27,7 @@ BaseModelOutputWithPooling, CausalLMOutputWithPast, ) -from transformers.models.git.modeling_git import GitProjection +from heron.models.mlp_adapter import MLPProjection from .configuration_japanese_stablelm_alpha import JapaneseStableLMAlphaConfig from .modeling_japanese_stablelm_alpha import ( @@ -95,7 +95,7 @@ def __init__(self, config: JapaneseStableLMAlphaConfig): # Git modules self.image_encoder = CLIPVisionModel.from_pretrained(config.vision_model_name) - self.visual_projection = GitProjection(config) + self.visual_projection = MLPProjection(config) if config.num_image_with_embedding is not None: self.img_temporal_embedding = nn.ParameterList( diff --git a/heron/models/git_llm/git_llama/modeling_git_llama.py b/heron/models/git_llm/git_llama/modeling_git_llama.py index 47dcaa0..60b5947 100644 --- a/heron/models/git_llm/git_llama/modeling_git_llama.py +++ b/heron/models/git_llm/git_llama/modeling_git_llama.py @@ -33,7 +33,7 @@ BaseModelOutputWithPooling, CausalLMOutputWithPast, ) -from transformers.models.git.modeling_git import GitProjection +from heron.models.mlp_adapter import MLPProjection class GitLlamaConfig(LlamaConfig): @@ -95,7 +95,7 @@ def __init__(self, config: LlamaConfig): # Git modules self.image_encoder = CLIPVisionModel.from_pretrained(config.vision_model_name) - self.visual_projection = GitProjection(config) + self.visual_projection = MLPProjection(config) if config.num_image_with_embedding is not None: self.img_temporal_embedding = nn.ParameterList( diff --git a/heron/models/git_llm/git_mpt/modeling_git_mpt.py b/heron/models/git_llm/git_mpt/modeling_git_mpt.py index 13dd2e8..4472e09 100644 --- a/heron/models/git_llm/git_mpt/modeling_git_mpt.py +++ b/heron/models/git_llm/git_mpt/modeling_git_mpt.py @@ -42,7 +42,7 @@ BaseModelOutputWithPooling, CausalLMOutputWithPast, ) -from transformers.models.git.modeling_git import GitProjection +from heron.models.mlp_adapter import MLPProjection class GitMptConfig(MptConfig): @@ -105,7 +105,7 @@ def __init__(self, config: MptConfig): # Git modules self.image_encoder = CLIPVisionModel.from_pretrained(config.vision_model_name) - self.visual_projection = GitProjection(config) + self.visual_projection = MLPProjection(config) if config.num_image_with_embedding is not None: self.img_temporal_embedding = nn.ParameterList( diff --git a/heron/models/git_llm/git_opt/modeling_git_opt.py b/heron/models/git_llm/git_opt/modeling_git_opt.py index 453cef3..18c0726 100644 --- a/heron/models/git_llm/git_opt/modeling_git_opt.py +++ b/heron/models/git_llm/git_opt/modeling_git_opt.py @@ -33,7 +33,7 @@ BaseModelOutputWithPooling, CausalLMOutputWithPast, ) -from transformers.models.git.modeling_git import GitProjection +from heron.models.mlp_adapter import MLPProjection from transformers.models.opt.modeling_opt import OPTLearnedPositionalEmbedding @@ -96,7 +96,7 @@ def __init__(self, config: OPTConfig): # Git modules self.image_encoder = CLIPVisionModel.from_pretrained(config.vision_model_name) - self.visual_projection = GitProjection(config) + self.visual_projection = MLPProjection(config) if config.num_image_with_embedding is not None: self.img_temporal_embedding = nn.ParameterList( diff --git a/heron/models/mlp_adapter.py b/heron/models/mlp_adapter.py new file mode 100644 index 0000000..e18a026 --- /dev/null +++ b/heron/models/mlp_adapter.py @@ -0,0 +1,29 @@ +import re +import torch +import torch.nn as nn +from transformers.models.git.modeling_git import GitProjection + +class MLPProjection(GitProjection): + def __init__(self, config): + super(MLPProjection, self).__init__(config) + self.config = config + + if hasattr(config, "mlp_adapter"): + mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', config.mlp_adapter) + if mlp_gelu_match: + mlp_depth = int(mlp_gelu_match.group(1)) + modules = [nn.Linear(config.vision_config.hidden_size, config.hidden_size)] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(config.hidden_size, config.hidden_size)) + modules.append(nn.LayerNorm(config.hidden_size, eps=config.vision_config.layer_norm_eps)) + else: + raise ValueError(f'Unknown mlp_adapter name: {config.mlp_adapter}') + else: + modules = [nn.Linear(config.vision_config.hidden_size, config.hidden_size)] + modules.append(nn.LayerNorm(config.hidden_size, eps=config.vision_config.layer_norm_eps)) + + self.visual_projection = nn.Sequential(*modules) + + def forward(self, embeddings: torch.Tensor) -> torch.Tensor: + return self.visual_projection(embeddings) diff --git a/heron/models/utils.py b/heron/models/utils.py index 3aa0a5c..55baad4 100644 --- a/heron/models/utils.py +++ b/heron/models/utils.py @@ -31,6 +31,10 @@ def load_model( model_type = model_config["model_type"] language_model = model_config["language_model_name"] num_image_with_embedding = model_config["num_image_with_embedding"] + if "mlp_adapter" in model_config: + adapter_name = model_config["mlp_adapter"] + else: + adapter_name = "linear" # set dtype if model_config.get("fp16", False): @@ -47,6 +51,7 @@ def load_model( num_image_with_embedding=num_image_with_embedding, vision_model_name=model_config["vision_model_name"], ) + git_config.mlp_adapter = adapter_name model = GitOPTForCausalLM.from_pretrained( language_model, config=git_config, torch_dtype=torch_dtype ) @@ -59,6 +64,7 @@ def load_model( num_image_with_embedding=num_image_with_embedding, vision_model_name=model_config["vision_model_name"], ) + git_config.mlp_adapter = adapter_name model = GitLlamaForCausalLM.from_pretrained( language_model, config=git_config, torch_dtype=torch_dtype ) @@ -71,6 +77,7 @@ def load_model( num_image_with_embedding=num_image_with_embedding, vision_model_name=model_config["vision_model_name"], ) + git_config.mlp_adapter = adapter_name model = GitMptForCausalLM.from_pretrained( language_model, config=git_config, torch_dtype=torch_dtype ) @@ -86,6 +93,7 @@ def load_model( num_image_with_embedding=num_image_with_embedding, vision_model_name=model_config["vision_model_name"], ) + git_config.mlp_adapter = adapter_name model = GitJapaneseStableLMAlphaForCausalLM.from_pretrained( language_model, config=git_config, torch_dtype=torch_dtype ) @@ -102,6 +110,7 @@ def load_model( num_image_with_embedding=num_image_with_embedding, vision_model_name=model_config["vision_model_name"], ) + git_config.mlp_adapter = adapter_name model = GitGPTNeoXForCausalLM.from_pretrained( language_model, config=git_config, torch_dtype=torch_dtype ) From 89c8b71b8d6c1b2ee5e511e9ff5c700ad9a5b675 Mon Sep 17 00:00:00 2001 From: Onely7-nlp Date: Mon, 6 Nov 2023 02:46:33 +0900 Subject: [PATCH 2/2] modify mlp_adapter.py --- heron/models/mlp_adapter.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/heron/models/mlp_adapter.py b/heron/models/mlp_adapter.py index e18a026..4fd20d6 100644 --- a/heron/models/mlp_adapter.py +++ b/heron/models/mlp_adapter.py @@ -8,20 +8,19 @@ def __init__(self, config): super(MLPProjection, self).__init__(config) self.config = config - if hasattr(config, "mlp_adapter"): + if re.match(r'^mlp(\d+)x_gelu$', config.mlp_adapter): mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', config.mlp_adapter) - if mlp_gelu_match: - mlp_depth = int(mlp_gelu_match.group(1)) - modules = [nn.Linear(config.vision_config.hidden_size, config.hidden_size)] - for _ in range(1, mlp_depth): - modules.append(nn.GELU()) - modules.append(nn.Linear(config.hidden_size, config.hidden_size)) - modules.append(nn.LayerNorm(config.hidden_size, eps=config.vision_config.layer_norm_eps)) - else: - raise ValueError(f'Unknown mlp_adapter name: {config.mlp_adapter}') - else: + mlp_depth = int(mlp_gelu_match.group(1)) + modules = [nn.Linear(config.vision_config.hidden_size, config.hidden_size)] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(config.hidden_size, config.hidden_size)) + modules.append(nn.LayerNorm(config.hidden_size, eps=config.vision_config.layer_norm_eps)) + elif re.match(r'^linear$', config.mlp_adapter): modules = [nn.Linear(config.vision_config.hidden_size, config.hidden_size)] modules.append(nn.LayerNorm(config.hidden_size, eps=config.vision_config.layer_norm_eps)) + else: + raise ValueError(f'Unknown mlp_adapter name: {config.mlp_adapter}') self.visual_projection = nn.Sequential(*modules)