-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This PR adds PaliGemma modeling code.

Blog post: https://huggingface.co/blog/paligemma
Transformers PR: huggingface/transformers#30814

Install the latest changes and run with:

```bash
# get the weights
# text-generation-server download-weights gv-hf/PaliGemma-base-224px-hf

# run TGI
text-generation-launcher --model-id gv-hf/PaliGemma-base-224px-hf
```

Basic example sending various requests:

```python
from huggingface_hub import InferenceClient

client = InferenceClient("http://127.0.0.1:3000")

images = [
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png",
]

prompts = [
    "What animal is in this image?",
    "Name three colors in this image.",
    "What are 10 colors in this image?",
    "Where is the cow standing?",
    "answer en Where is the cow standing?",
    "Is there a bird in the image?",
    "Is there a cow in the image?",
    "Is there a rabbit in the image?",
    "how many birds are in the image?",
    "how many rabbits are in the image?",
]

for img in images:
    print(f"\nImage: {img.split('/')[-1]}")
    for prompt in prompts:
        inputs = f"![]({img}){prompt}\n"
        json_data = {
            "inputs": inputs,
            "parameters": {
                "max_new_tokens": 30,
                "do_sample": False,
            },
        }
        generated_output = client.text_generation(prompt, max_new_tokens=30, stream=False)
        print([f"{prompt}\n{generated_output}"])
```

---------

Co-authored-by: Nicolas Patry <[email protected]>
- Loading branch information
Showing
23 changed files
with
1,148 additions
and
157 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
25 changes: 25 additions & 0 deletions
25
integration-tests/models/__snapshots__/test_flash_pali_gemma/test_flash_pali_gemma.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
{ | ||
"details": { | ||
"best_of_sequences": null, | ||
"finish_reason": "eos_token", | ||
"generated_tokens": 2, | ||
"prefill": [], | ||
"seed": null, | ||
"tokens": [ | ||
{ | ||
"id": 54901, | ||
"logprob": -0.72753906, | ||
"special": false, | ||
"text": "beach" | ||
}, | ||
{ | ||
"id": 1, | ||
"logprob": -0.011009216, | ||
"special": true, | ||
"text": "<eos>" | ||
} | ||
], | ||
"top_tokens": null | ||
}, | ||
"generated_text": "beach" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import pytest | ||
import requests | ||
import io | ||
import base64 | ||
|
||
|
||
@pytest.fixture(scope="module")
def flash_pali_gemma_handle(launcher):
    """Yield the handle produced by launching ``google/paligemma-3b-pt-224``.

    The fixture is module-scoped and keeps the ``launcher`` context open
    while the handle is in use; the context is exited when the fixture is
    torn down.
    """
    launch_options = {
        "num_shard": 1,
        "revision": "float16",
        "max_input_length": 4000,
        "max_total_tokens": 4096,
    }
    with launcher("google/paligemma-3b-pt-224", **launch_options) as server:
        yield server
|
||
|
||
@pytest.fixture(scope="module")
async def flash_pali_gemma(flash_pali_gemma_handle):
    """Await the handle's health check, then return its ``client``."""
    server = flash_pali_gemma_handle
    # NOTE(review): 300 is passed straight to health(); presumably a timeout
    # in seconds -- confirm against the launcher handle implementation.
    await server.health(300)
    return server.client
|
||
|
||
def get_cow_beach(image_path="integration-tests/images/cow_beach.png"):
    """Return the contents of a PNG file as a base64 ``data:`` URI.

    Args:
        image_path: Path of the image to encode. Defaults to the cow/beach
            test image; a relative path is resolved against the current
            working directory by ``open``.

    Returns:
        A string of the form ``data:image/png;base64,<payload>`` where
        ``<payload>`` is the base64 encoding of the file's bytes.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    with open(image_path, "rb") as image_file:
        payload = base64.b64encode(image_file.read())
    # base64 output is pure ASCII, so decoding to text cannot fail.
    return f"data:image/png;base64,{payload.decode('utf-8')}"
|
||
|
||
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
    """Ask where the cow is standing and compare the reply to the snapshot."""
    image_uri = get_cow_beach()
    # The image is embedded as a markdown image link directly before the question.
    prompt = f"![]({image_uri})Where is the cow standing?\n"
    response = await flash_pali_gemma.generate(prompt, max_new_tokens=20)

    assert response.generated_text == "beach"
    assert response == response_snapshot
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.