
Multi modality fix #3283


Open
wants to merge 3 commits into base: main
63 changes: 35 additions & 28 deletions server/text_generation_server/models/custom_modeling/mllama.py
@@ -710,34 +710,41 @@ def forward(
         # )
         if SYSTEM == "ipex":
             attn_output = torch.empty_like(query_states)
-            ipex.llm.functional.varlen_attention(
-                (
-                    query_states.contiguous()
-                    if query_states.device.type == "xpu"
-                    else query_states
-                ),
-                (
-                    key_states.contiguous()
-                    if key_states.device.type == "xpu"
-                    else key_states
-                ),
-                (
-                    value_states.contiguous()
-                    if value_states.device.type == "xpu"
-                    else value_states
-                ),
-                attn_output,
-                cu_seqlen_q,
-                cu_seqlen_k,
-                max_q,
-                max_k,
-                0.0,
-                self.softmax_scale,
-                False,
-                causal,
-                False,
-                None,
-            )
+            if query_states.device.type == "xpu":
+                ipex.llm.functional.varlen_attention(
+                    query_states.contiguous(),
+                    key_states.contiguous(),
+                    value_states.contiguous(),
+                    attn_output,
+                    cu_seqlen_q,
+                    cu_seqlen_k,
+                    None,
+                    max_q,
+                    max_k,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
+            else:
+                ipex.llm.functional.varlen_attention(
+                    query_states,
+                    key_states,
+                    value_states,
+                    attn_output,
+                    cu_seqlen_q,
+                    cu_seqlen_k,
+                    max_q,
+                    max_k,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
         else:
             attn_output = flash_attn_2_cuda.varlen_fwd(
                 query_states,
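All three attention changes in this PR follow the same pattern: when the tensors live on an XPU device, ipex.llm.functional.varlen_attention is called with contiguous tensors and one additional positional argument (passed as None here), while the CPU path keeps the previous argument list. A minimal sketch of that dispatch as a shared helper is shown below; the helper name ipex_varlen_attention is hypothetical and the argument layout is copied from the calls in this diff, so treat it as an illustration rather than the PR's implementation.

import intel_extension_for_pytorch as ipex  # import name assumed, as used elsewhere in this repo
import torch


def ipex_varlen_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    cu_seqlen_q: torch.Tensor,
    cu_seqlen_k: torch.Tensor,
    max_q: int,
    max_k: int,
    softmax_scale: float,
    causal: bool,
) -> torch.Tensor:
    # Hypothetical helper (not part of this PR) mirroring the two branches above.
    if query.device.type == "xpu":
        # XPU path: contiguous tensors plus one extra positional argument,
        # passed as None in this PR.
        ipex.llm.functional.varlen_attention(
            query.contiguous(), key.contiguous(), value.contiguous(), out,
            cu_seqlen_q, cu_seqlen_k, None, max_q, max_k,
            0.0, softmax_scale, False, causal, False, None,
        )
    else:
        # CPU path: unchanged argument list, tensors used as-is.
        ipex.llm.functional.varlen_attention(
            query, key, value, out,
            cu_seqlen_q, cu_seqlen_k, max_q, max_k,
            0.0, softmax_scale, False, causal, False, None,
        )
    return out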
51 changes: 35 additions & 16 deletions server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
@@ -460,22 +460,41 @@ def forward(
         # execute flash attention
         if SYSTEM == "ipex":
             attn_output = torch.empty_like(query)
-            ipex.llm.functional.varlen_attention(
-                (query.contiguous() if query.device.type == "xpu" else query),
-                (key.contiguous() if key.device.type == "xpu" else key),
-                (value.contiguous() if value.device.type == "xpu" else value),
-                attn_output,
-                cu_seqlens,
-                cu_seqlens,
-                max_seqlen,
-                max_seqlen,
-                0.0,
-                self.softmax_scale,
-                False,
-                causal,
-                False,
-                None,
-            )
+            if query.device.type == "xpu":
+                ipex.llm.functional.varlen_attention(
+                    query.contiguous(),
+                    key.contiguous(),
+                    value.contiguous(),
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    None,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
+            else:
+                ipex.llm.functional.varlen_attention(
+                    query,
+                    key,
+                    value,
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
         else:
             attn_output = flash_attn_2_cuda.varlen_fwd(
                 query,
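The Qwen2.5-VL vision attention receives the same two-branch dispatch. Because the vision tower runs packed self-attention, cu_seqlens and max_seqlen are passed for both the query and key sides. With the hypothetical helper sketched under the mllama.py diff above, this call site would reduce to roughly the following (illustration only, not part of the PR):

attn_output = torch.empty_like(query)
ipex_varlen_attention(
    query, key, value, attn_output,
    cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
    self.softmax_scale, causal,
)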
51 changes: 35 additions & 16 deletions server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -130,22 +130,41 @@ def forward(
         # execute flash attention
         if SYSTEM == "ipex":
             attn_output = torch.empty_like(query)
-            ipex.llm.functional.varlen_attention(
-                (query.contiguous() if query.device.type == "xpu" else query),
-                (key.contiguous() if key.device.type == "xpu" else key),
-                (value.contiguous() if value.device.type == "xpu" else value),
-                attn_output,
-                cu_seqlens,
-                cu_seqlens,
-                max_seqlen,
-                max_seqlen,
-                0.0,
-                self.softmax_scale,
-                False,
-                causal,
-                False,
-                None,
-            )
+            if query.device.type == "xpu":
+                ipex.llm.functional.varlen_attention(
+                    query.contiguous(),
+                    key.contiguous(),
+                    value.contiguous(),
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    None,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
+            else:
+                ipex.llm.functional.varlen_attention(
+                    query,
+                    key,
+                    value,
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
         else:
             attn_output = flash_attn_2_cuda.varlen_fwd(
                 query,
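qwen2_vl.py gets the identical change as qwen2_5_vl.py above. Writing the call out in two explicit branches, rather than keeping the earlier per-tensor inline conditionals, is what the fix requires: the XPU and CPU paths now differ in the argument list itself (the extra None positional on XPU), not merely in whether the input tensors are made contiguous.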
3 changes: 2 additions & 1 deletion server/text_generation_server/models/mllama_causal_lm.py
@@ -59,7 +59,7 @@ def concatenate(cls, batches):
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]):
         assert self.image_indices is not None
-        batch = super().filter(request_ids)
+        batch = super(VlmCausalLMBatch, self).filter(request_ids)
         assert self.image_indices is not None
         indices = []
         for i, request_id in enumerate(request_ids):
@@ -85,6 +85,7 @@ def filter(self, request_ids: List[int]):
             ]
         else:
            batch.cross_attention_states = None
+        batch.pixel_values = None
         return batch
 
     @classmethod
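In filter(), switching from the implicit super().filter(request_ids) to super(VlmCausalLMBatch, self).filter(request_ids) starts the method lookup after VlmCausalLMBatch in the MRO, so that class's own filter override is bypassed and the next base class performs the request filtering; pixel_values is then cleared explicitly on the filtered batch. A minimal, self-contained sketch of that dispatch behaviour follows; the class names are placeholders, not the repo's actual hierarchy.

class BaseBatch:
    def filter(self, request_ids):
        print("BaseBatch.filter")
        return self


class VlmBatch(BaseBatch):
    def filter(self, request_ids):
        print("VlmBatch.filter")  # VLM-generic handling we want to skip
        return super().filter(request_ids)


class MllamaBatch(VlmBatch):
    def filter(self, request_ids):
        # super().filter(...) would run VlmBatch.filter first;
        # super(VlmBatch, self).filter(...) skips it and dispatches to
        # BaseBatch.filter, analogous to the diff above.
        batch = super(VlmBatch, self).filter(request_ids)
        batch.pixel_values = None  # cleared explicitly, as the PR now does
        return batch


MllamaBatch().filter([1, 2])  # prints only "BaseBatch.filter"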