
Full-stack PaliGemma
- The decoder now takes embeddings combining text and image instead of token IDs.
- If only text token IDs are passed, forward() is called as before.
- If image embeddings are merged, position IDs are 1-based, not 0-based; the RoPE calculation takes this into account.
- If image embeddings are merged, a diagonal causal mask doesn't work; instead, a mask is passed that only removes unused KV entries (a minimal sketch follows below).
- Updated the verifier API to get pixel_features.
- verifier.ReauthoredModelWrapper.generate() now passes only the new token in each round instead of all tokens (input + generated so far).
- Conversion and unit tests will be added in a follow-up change.

PiperOrigin-RevId: 696901961
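To illustrate the mask change, here is a minimal sketch with toy sizes (embeds_len=4 and kv_cache_max=8 are assumptions for illustration); it mirrors the logic added to decoder.py below:

import torch

embeds_len, kv_cache_max = 4, 8  # toy sizes, for illustration only
mask = torch.zeros(embeds_len, kv_cache_max)
mask[:, embeds_len:] = float("-inf")
print(mask)
# tensor([[0., 0., 0., 0., -inf, -inf, -inf, -inf],
#         [0., 0., 0., 0., -inf, -inf, -inf, -inf],
#         [0., 0., 0., 0., -inf, -inf, -inf, -inf],
#         [0., 0., 0., 0., -inf, -inf, -inf, -inf]])
# Every position in the merged image+text prefix attends to every other prefix
# position (no causal diagonal); only the unused KV cache slots are masked.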
ai-edge-bot authored and copybara-github committed Nov 15, 2024
1 parent ff85c8d commit 14de8c0
Showing 6 changed files with 378 additions and 24 deletions.
48 changes: 43 additions & 5 deletions ai_edge_torch/generative/examples/paligemma/decoder.py
@@ -15,9 +15,11 @@

"""Example of building a decoder of PaliGemma 3B model which is Gemma1."""

from ai_edge_torch.generative.layers import kv_cache as kv_utils
import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.utilities import model_builder
import ai_edge_torch.generative.utilities.loader as loading_utils
import torch

TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    ff_up_proj="language_model.model.layers.{}.mlp.up_proj",
@@ -35,6 +37,41 @@
)


class Decoder(model_builder.DecoderOnlyModel):
  """A decoder of the PaliGemma 3B model, which is Gemma1.

  Besides a tensor of text token IDs, forward() can also take a tensor of
  embeddings which may include text, image, or both.
  """

  @torch.inference_mode
  def forward(
      self,
      tokens: torch.Tensor,
      input_pos: torch.Tensor,
      kv_cache: kv_utils.KVCache,
      input_embeds: torch.Tensor = None,
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
    if input_embeds is None:
      return super().forward(tokens, input_pos, kv_cache)

    rope_pos = input_pos + 1  # PaliGemma positions are 1-based.
    cos, sin = self.rope_cache
    rope = (cos.index_select(0, rope_pos), sin.index_select(0, rope_pos))

    # The first part of input_embeds are image embeddings, so a diagonal
    # causal mask doesn't work here. Mask out only the unused KV entries.
    embeds_len = input_embeds.shape[1]
    mask = torch.zeros(embeds_len, self.config.kv_cache_max)
    mask[:, embeds_len:] = float("-inf")

    return self.forward_with_embeds(
        input_embeds, rope, mask, input_pos, kv_cache
    )


def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
"""Returns the model config for the decoder of a PaliGemma 3B model.
@@ -96,8 +133,9 @@ def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
def build_decoder(
    checkpoint_path: str, **kwargs
) -> model_builder.DecoderOnlyModel:
-  return model_builder.build_decoder_only_model(
-      checkpoint_path=checkpoint_path,
-      config=get_decoder_config(**kwargs),
-      tensor_names=TENSOR_NAMES,
-  )
+  decoder = Decoder(get_decoder_config(**kwargs))
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # Loosen the strictness because only the decoder is being loaded.
+  loader.load(decoder, strict=False)
+  decoder.eval()
+  return decoder
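A hedged usage sketch of the two forward() paths follows; the checkpoint path, token values, and sequence lengths are illustrative assumptions, not part of this change:

import torch
from ai_edge_torch.generative.layers import kv_cache as kv_utils

decoder_model = build_decoder("/path/to/paligemma_checkpoint")  # hypothetical path
kv = kv_utils.KVCache.from_model_config(decoder_model.config)

# Text-only: token IDs take the original code path via super().forward().
tokens = torch.tensor([[10, 20, 30]])
out = decoder_model(tokens, input_pos=torch.arange(3), kv_cache=kv)

# Merged text+image embeddings: the token-embedding lookup is skipped and
# RoPE uses 1-based positions internally.
embeds = torch.zeros(1, 8, decoder_model.config.embedding_dim)
out = decoder_model(
    tokens=None, input_pos=torch.arange(8), kv_cache=kv, input_embeds=embeds
)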
135 changes: 135 additions & 0 deletions ai_edge_torch/generative/examples/paligemma/paligemma.py
@@ -0,0 +1,135 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Example of building a full-stack of PaliGemma model."""

from dataclasses import dataclass

from ai_edge_torch.generative.examples.paligemma import decoder
from ai_edge_torch.generative.examples.paligemma import image_encoder
import ai_edge_torch.generative.layers.kv_cache as kv_utils
import ai_edge_torch.generative.layers.model_config as cfg
import ai_edge_torch.generative.utilities.loader as loading_utils
import torch
from torch import nn

PROJECTION_TENSOR_NAME = "multi_modal_projector.linear"


@dataclass
class PaliGemmaConfig:
  """PaliGemma model configurations."""

  image_encoder_config: cfg.ModelConfig
  decoder_config: cfg.ModelConfig

  image_token_id: int
  image_projection_use_bias: bool = False


class PaliGemma(nn.Module):
  """PaliGemma model from the Edge Generative API."""

  def __init__(self, config: PaliGemmaConfig):
    super().__init__()

    self.image_encoder = image_encoder.SiglipVisionEncoder(
        config.image_encoder_config
    )
    self.image_projection = nn.Linear(
        config.image_encoder_config.embedding_dim,
        config.decoder_config.embedding_dim,
        bias=config.image_projection_use_bias,
    )
    self.decoder = decoder.Decoder(config.decoder_config)
    self.config = config

  @torch.inference_mode
  def forward(
      self,
      tokens: torch.Tensor,
      input_pos: torch.Tensor,
      kv_cache: kv_utils.KVCache,
      pixel_values: torch.Tensor = None,
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
    if pixel_values is None:
      return self.decoder(tokens, input_pos, kv_cache)

    input_embeds = self.decoder.tok_embedding(tokens)

    image_encoded = self.image_encoder(pixel_values=pixel_values)
    image_embeds = self.image_projection(image_encoded)
    if self.config.decoder_config.embedding_scale is not None:
      image_embeds = image_embeds / self.config.decoder_config.embedding_scale

    # Merge image_embeds into input_embeds, as
    # PaliGemmaForConditionalGeneration does.
    image_mask = tokens == self.config.image_token_id
    image_mask = image_mask.unsqueeze(-1).expand_as(input_embeds)
    input_embeds = input_embeds.masked_scatter(image_mask, image_embeds)

    return self.decoder(
        tokens=None,
        input_pos=input_pos,
        kv_cache=kv_cache,
        input_embeds=input_embeds,
    )


def get_model_config() -> PaliGemmaConfig:
  """Returns the model config for a PaliGemma 3B-224 model.

  Returns:
    The model config for a PaliGemma 3B model.
  """
  return PaliGemmaConfig(
      image_encoder_config=image_encoder.get_image_encoder_config(),
      decoder_config=decoder.get_decoder_config(),
      image_projection_use_bias=True,
      image_token_id=257152,
  )


def get_fake_model_config() -> PaliGemmaConfig:
  return PaliGemmaConfig(
      image_encoder_config=image_encoder.get_fake_image_encoder_config(),
      decoder_config=decoder.get_fake_decoder_config(),
      image_projection_use_bias=True,
      image_token_id=257152,
  )


def build_model(checkpoint_path: str) -> PaliGemma:
  config = get_model_config()
  model = PaliGemma(config)
  # Load the parameters of the image encoder.
  loader = loading_utils.ModelLoader(
      checkpoint_path, image_encoder.TENSOR_NAMES
  )
  loader.load(model.image_encoder, strict=False)
  # Load the parameters of the decoder.
  loader = loading_utils.ModelLoader(checkpoint_path, decoder.TENSOR_NAMES)
  loader.load(model.decoder, strict=False)

  # Load the parameters of the image projection.
  loader = loading_utils.ModelLoader(checkpoint_path, None)
  state = loader.get_state()
  converted_state = dict()
  converted_state["weight"] = state.pop(f"{PROJECTION_TENSOR_NAME}.weight")
  if config.image_projection_use_bias:
    converted_state["bias"] = state.pop(f"{PROJECTION_TENSOR_NAME}.bias")
  model.image_projection.load_state_dict(converted_state)

  model.eval()
  return model
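The merge step in forward() can be demonstrated in isolation; the toy shapes and values below are assumptions for illustration:

import torch

image_token_id = 257152
tokens = torch.tensor([[image_token_id, image_token_id, 5, 7]])
text_embeds = torch.zeros(1, 4, 4)  # stand-in for tok_embedding(tokens)
image_embeds = torch.ones(1, 2, 4)  # stand-in for projected image features

image_mask = (tokens == image_token_id).unsqueeze(-1).expand_as(text_embeds)
merged = text_embeds.masked_scatter(image_mask, image_embeds)
# merged[:, :2] now holds the image embeddings and merged[:, 2:] is unchanged,
# mirroring how PaliGemmaForConditionalGeneration replaces image placeholder
# tokens with projected image features.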
134 changes: 134 additions & 0 deletions ai_edge_torch/generative/examples/paligemma/verify.py
@@ -0,0 +1,134 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Verifies the reauthored PaliGemma 3B model."""

import logging
import pathlib
from absl import app
from absl import flags
from ai_edge_torch.generative.examples.paligemma import paligemma
from ai_edge_torch.generative.layers import kv_cache
from ai_edge_torch.generative.utilities import verifier
from PIL import Image
import requests
import torch
import transformers

_IMAGE_URL = flags.DEFINE_string(
    "image_url",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
    "The image URL to encode.",
)
_PROMPTS = flags.DEFINE_string(
    "prompts",
    "Caption en",
    "The input prompts to generate answers.",
)
_MAX_NEW_TOKENS = flags.DEFINE_integer(
    "max_new_tokens",
    30,
    "The maximum number of new tokens to generate.",
)


class ReauthoredPaliGemmaWrapper(verifier.ReauthoredModelWrapper):
  """Reauthored PaliGemma model wrapper."""

  def _init_kv_cache(self):
    return kv_cache.KVCache.from_model_config(self.model.config.decoder_config)


def main(_):
  checkpoint = "google/paligemma-3b-mix-224"
  logging.info("Loading the original model from: %s", checkpoint)
  original_model = (
      transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
  )

  # Locate the cached dir.
  cached_config_file = transformers.utils.cached_file(
      checkpoint, transformers.utils.CONFIG_NAME
  )
  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
  reauthored_model = paligemma.build_model(reauthored_checkpoint)

  logging.info("Loading the processor from: %s", checkpoint)
  # It works only when GemmaTokenizerFast is available. In some environments,
  # use_fast=False doesn't work either if the tokenizer cannot load the
  # sentencepiece model file properly.
  processor = transformers.AutoProcessor.from_pretrained(checkpoint)

  logging.info("Loading the image from: %s", _IMAGE_URL.value)
  image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw)
  inputs = processor(text=_PROMPTS.value, images=image, return_tensors="pt")

  logging.info("Verifying the reauthored model with model.forward()...")
  logging.info("Forwarding the original model...")
  outputs_original = original_model.forward(
      input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"]
  )
  outputs_original = outputs_original.logits
  logging.info("outputs_original: %s", outputs_original)

  logging.info("Forwarding the reauthored model...")
  wrapped_reauthored_model = ReauthoredPaliGemmaWrapper(reauthored_model)
  outputs_reauthored = wrapped_reauthored_model.forward(
      tokens=inputs["input_ids"],
      pixel_values=inputs["pixel_values"],
  )
  logging.info("outputs_reauthored: %s", outputs_reauthored)

  try:
    assert torch.allclose(outputs_original, outputs_reauthored, atol=1e-03)
  except AssertionError as e:
    logging.error("*** FAILED *** verify with forward()")
    raise e
  else:
    logging.info("*** PASSED *** verify with forward()")

  logging.info("Verifying the reauthored model with model.generate()...")
  logging.info("Generating answer with the original model...")
  outputs_original = original_model.generate(
      **inputs, max_new_tokens=_MAX_NEW_TOKENS.value, do_sample=False
  )
  response_original = processor.decode(
      outputs_original[0], skip_special_tokens=True
  )
  logging.info("outputs_from_original_model: [[%s]]", response_original)

  logging.info("Generating answer with the reauthored model...")
  outputs_reauthored = wrapped_reauthored_model.generate(
      prompts=inputs["input_ids"],
      pixel_values=inputs["pixel_values"],
      max_new_tokens=_MAX_NEW_TOKENS.value,
  )
  response_reauthored = processor.decode(
      outputs_reauthored[0], skip_special_tokens=True
  )
  logging.info("outputs from reauthored model: [[%s]]", response_reauthored)

  try:
    assert response_original == response_reauthored
  except AssertionError as e:
    logging.error("*** FAILED *** verify with generate()")
    raise e
  else:
    logging.info("*** PASSED *** verify with generate()")


if __name__ == "__main__":
  app.run(main)
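Per the commit message, generate() now feeds only the newly sampled token in each decode round. Below is a self-contained sketch of that pattern; dummy_step is a hypothetical stand-in for the wrapped model's forward(), not the verifier's actual API:

import torch

def dummy_step(token_ids, input_pos):
  vocab_size = 16  # assumption, for illustration only
  return torch.randn(token_ids.shape[0], token_ids.shape[1], vocab_size)

def generate_sketch(prompts, max_new_tokens):
  tokens = prompts
  # Prefill: the full prompt is passed exactly once.
  logits = dummy_step(tokens, torch.arange(tokens.shape[1]))
  for _ in range(max_new_tokens):
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
    input_pos = torch.tensor([tokens.shape[1]])
    tokens = torch.cat([tokens, next_token], dim=1)
    # Decode: only the new token goes in, with its single position; the KV
    # cache (elided here) carries the history.
    logits = dummy_step(next_token, input_pos)
  return tokens

print(generate_sketch(torch.tensor([[1, 2, 3]]), max_new_tokens=4))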
5 changes: 4 additions & 1 deletion ai_edge_torch/generative/utilities/loader.py
@@ -131,6 +131,9 @@ def __init__(self, file_name: str, names: TensorNames) -> None:
    self._names = names
    self._loader = self._get_loader()

  def get_state(self) -> Dict[str, torch.Tensor]:
    return self._loader(self._file_name)

  def load(
      self, model: torch.nn.Module, strict: bool = True
  ) -> Tuple[List[str], List[str]]:
@@ -150,7 +153,7 @@ def load(
      ValueError: If conversion results in unmapped tensors and strict mode is
        enabled.
    """
-    state = self._loader(self._file_name)
+    state = self.get_state()
    state = state["model_state_dict"] if "model_state_dict" in state else state
    converted_state = dict()
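The new ModelLoader.get_state() exposes the raw checkpoint mapping before any tensor-name conversion; a minimal usage sketch (the checkpoint path is hypothetical, and the tensor name comes from paligemma.py above):

import ai_edge_torch.generative.utilities.loader as loading_utils

loader = loading_utils.ModelLoader("/path/to/checkpoint", None)
state = loader.get_state()  # raw name -> tensor dict, no renaming applied
proj_weight = state["multi_modal_projector.linear.weight"]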
