Pali gemma modeling (#1895)
This PR adds PaliGemma modeling code.

Blog post: https://huggingface.co/blog/paligemma
Transformers PR: huggingface/transformers#30814

Install the latest changes and run with:
```bash
# optionally pre-fetch the weights
text-generation-server download-weights gv-hf/PaliGemma-base-224px-hf

# run TGI
text-generation-launcher --model-id gv-hf/PaliGemma-base-224px-hf
```


A basic example sending various requests:
```python
from huggingface_hub import InferenceClient

client = InferenceClient("http://127.0.0.1:3000")


images = [
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png",
]

prompts = [
    "What animal is in this image?",
    "Name three colors in this image.",
    "What are 10 colors in this image?",
    "Where is the cow standing?",
    "answer en Where is the cow standing?",
    "Is there a bird in the image?",
    "Is ther a cow in the image?",
    "Is there a rabbit in the image?",
    "how many birds are in the image?",
    "how many rabbits are in the image?",
]

for img in images:
    print(f"\nImage: {img.split('/')[-1]}")
    for prompt in prompts:
        # embed the image via TGI's markdown image syntax, then query
        inputs = f"![]({img}){prompt}\n"
        generated_output = client.text_generation(
            inputs, max_new_tokens=30, do_sample=False, stream=False
        )
        print(f"{prompt}\n{generated_output}")

```
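The same request can be sent to the server's `/generate` route directly. A minimal sketch using `requests`, assuming the launcher's default address above and mirroring the payload shape the client builds:
```python
import requests

img = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
prompt = "Where is the cow standing?"

# POST the inputs/parameters payload to TGI's /generate route
response = requests.post(
    "http://127.0.0.1:3000/generate",
    json={
        "inputs": f"![]({img}){prompt}\n",
        "parameters": {"max_new_tokens": 30, "do_sample": False},
    },
)
print(response.json()["generated_text"])
```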

---------

Co-authored-by: Nicolas Patry <[email protected]>
drbh and Narsil authored May 16, 2024
1 parent 6c715f8 commit 40213c9
Showing 23 changed files with 1,148 additions and 157 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yaml
```diff
@@ -27,7 +27,7 @@ jobs:
     runs-on: ubuntu-latest
     env:
       AWS_REGION: us-east-1
-      EC2_AMI_ID: ami-03cfed9ea28f4b002
+      EC2_AMI_ID: ami-0789b6925c11b1fb2
       EC2_INSTANCE_TYPE: g5.12xlarge
       EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
       EC2_SECURITY_GROUP: sg-030175c435ac141d6
```
3 changes: 2 additions & 1 deletion Dockerfile
```diff
@@ -43,7 +43,7 @@ ARG PYTORCH_VERSION=2.3.0
 ARG PYTHON_VERSION=3.10
 # Keep in sync with `server/pyproject.toml
 ARG CUDA_VERSION=12.1
-ARG MAMBA_VERSION=23.3.1-1
+ARG MAMBA_VERSION=24.3.0-0
 ARG CUDA_CHANNEL=nvidia
 ARG INSTALL_CHANNEL=pytorch
 # Automatically set by buildx
@@ -181,6 +181,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
         ca-certificates \
         make \
         curl \
+        git \
         && rm -rf /var/lib/apt/lists/*
 
 # Copy conda with PyTorch installed
```
Binary file added integration-tests/images/cow_beach.png
```diff
@@ -0,0 +1,25 @@
+{
+  "details": {
+    "best_of_sequences": null,
+    "finish_reason": "eos_token",
+    "generated_tokens": 2,
+    "prefill": [],
+    "seed": null,
+    "tokens": [
+      {
+        "id": 54901,
+        "logprob": -0.72753906,
+        "special": false,
+        "text": "beach"
+      },
+      {
+        "id": 1,
+        "logprob": -0.011009216,
+        "special": true,
+        "text": "<eos>"
+      }
+    ],
+    "top_tokens": null
+  },
+  "generated_text": "beach"
+}
```
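The snapshot records per-token ids and logprobs. For reference, a minimal sketch of requesting the same token-level details from a running server with `huggingface_hub` (endpoint address assumed from the example above; `details=True` returns the finish reason, token ids, and logprobs):
```python
from huggingface_hub import InferenceClient

client = InferenceClient("http://127.0.0.1:3000")

cow = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
out = client.text_generation(
    f"![]({cow})Where is the cow standing?\n",
    max_new_tokens=20,
    details=True,
)
print(out.generated_text)        # e.g. "beach"
for tok in out.details.tokens:   # per-token id / logprob / text
    print(tok.id, tok.logprob, tok.text)
```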
39 changes: 39 additions & 0 deletions integration-tests/models/test_flash_pali_gemma.py
```diff
@@ -0,0 +1,39 @@
+import pytest
+import requests
+import io
+import base64
+
+
+@pytest.fixture(scope="module")
+def flash_pali_gemma_handle(launcher):
+    with launcher(
+        "google/paligemma-3b-pt-224",
+        num_shard=1,
+        revision="float16",
+        max_input_length=4000,
+        max_total_tokens=4096,
+    ) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def flash_pali_gemma(flash_pali_gemma_handle):
+    await flash_pali_gemma_handle.health(300)
+    return flash_pali_gemma_handle.client
+
+
+def get_cow_beach():
+    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
+        encoded_string = base64.b64encode(image_file.read())
+        return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
+    cow = get_cow_beach()
+    inputs = f"![]({cow})Where is the cow standing?\n"
+    response = await flash_pali_gemma.generate(inputs, max_new_tokens=20)
+
+    assert response.generated_text == "beach"
+    assert response == response_snapshot
```
21 changes: 19 additions & 2 deletions router/src/config.rs
```diff
@@ -100,15 +100,13 @@ impl LlavaNext {
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
-#[serde(tag = "model_type")]
 #[serde(rename_all = "snake_case")]
 pub struct ClipVisionModel {
     image_size: usize,
     patch_size: usize,
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
-#[serde(tag = "model_type")]
 #[serde(rename_all = "snake_case")]
 pub struct Idefics2 {}
 
@@ -118,6 +116,24 @@ impl Idefics2 {
     }
 }
 
+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub struct PaliTextConfig {
+    num_image_tokens: usize,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub struct Paligemma {
+    text_config: PaliTextConfig,
+}
+
+impl Paligemma {
+    pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
+        self.text_config.num_image_tokens
+    }
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(tag = "model_type")]
 #[serde(rename_all = "snake_case")]
@@ -140,6 +156,7 @@ pub enum Config {
     Phi3,
     Llama,
     Baichuan,
+    Paligemma(Paligemma),
     Gemma,
     Cohere,
     Drbx,
```
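Note that `Paligemma::get_number_of_features` ignores the image's height and width: every image maps to the fixed `num_image_tokens` slot count from the text config, unlike LLaVA-Next, where the count depends on resolution. A minimal Python sketch of that contract, with 256 as an illustrative placeholder rather than a value read from the real config:
```python
from dataclasses import dataclass

@dataclass
class PaliTextConfig:
    num_image_tokens: int

@dataclass
class Paligemma:
    text_config: PaliTextConfig

    def get_number_of_features(self, height: int, width: int) -> int:
        # height/width are accepted for interface parity but unused:
        # every image occupies a fixed number of token slots
        return self.text_config.num_image_tokens

config = Paligemma(PaliTextConfig(num_image_tokens=256))
assert config.get_number_of_features(224, 224) == config.get_number_of_features(448, 896)
```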
24 changes: 24 additions & 0 deletions router/src/validation.rs
```diff
@@ -544,6 +544,30 @@ fn prepare_input(
             inputs = modified_inputs;
             tokenizer_query
         }
+        Some(Config::Paligemma(config)) => {
+            let mut modified_inputs = String::with_capacity(inputs.len());
+            let mut tokenizer_query = String::with_capacity(inputs.len());
+            let mut start = 0;
+            for chunk in RE.find_iter(&inputs) {
+                let chunk_start = chunk.start();
+                let chunk_end = chunk.end();
+                if chunk_start != start {
+                    modified_inputs.push_str(&inputs[start..chunk_start]);
+                    tokenizer_query.push_str(&inputs[start..chunk_start]);
+                }
+                let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
+                let slots = config.get_number_of_features(height, width);
+                tokenizer_query.push_str(&"<image>".repeat(slots));
+                modified_inputs.push_str(&image_uri);
+                start = chunk_end;
+            }
+            if start != inputs.len() - 1 {
+                modified_inputs.push_str(&inputs[start..]);
+                tokenizer_query.push_str(&inputs[start..]);
+            }
+            inputs = modified_inputs;
+            tokenizer_query
+        }
         Some(Config::Idefics2(config)) => {
             let mut modified_inputs = String::with_capacity(inputs.len());
             let mut tokenizer_query = String::with_capacity(inputs.len());
```
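The new `Paligemma` arm rewrites prompts the same way the other multimodal arms do: each `![](uri)` image marker stays in the model inputs, while the tokenizer query receives `<image>` repeated once per slot. Below is a minimal Python sketch of that rewriting, assuming a simplified regex and a fixed slot count; it mirrors, not reproduces, the Rust above:
```python
import re

# matches TGI's markdown-style image syntax: ![](<uri>)
IMAGE_RE = re.compile(r"!\[\]\([^)]*\)")

def prepare_input(inputs: str, num_image_tokens: int = 256) -> tuple[str, str]:
    """Split the prompt into (modified_inputs, tokenizer_query)."""
    modified_inputs, tokenizer_query = [], []
    start = 0
    for match in IMAGE_RE.finditer(inputs):
        # copy the plain text preceding the image marker to both outputs
        modified_inputs.append(inputs[start:match.start()])
        tokenizer_query.append(inputs[start:match.start()])
        # the image URI stays in the inputs; the tokenizer instead sees
        # one <image> placeholder per slot
        modified_inputs.append(match.group())
        tokenizer_query.append("<image>" * num_image_tokens)
        start = match.end()
    # copy any trailing text after the last image marker
    modified_inputs.append(inputs[start:])
    tokenizer_query.append(inputs[start:])
    return "".join(modified_inputs), "".join(tokenizer_query)

inputs, query = prepare_input("![](https://example.com/cow.png)Where is the cow standing?\n")
```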
