diff --git a/server/text_generation_server/models/custom_modeling/mllama.py b/server/text_generation_server/models/custom_modeling/mllama.py
index be0a4b5d7c3..d4289015bea 100644
--- a/server/text_generation_server/models/custom_modeling/mllama.py
+++ b/server/text_generation_server/models/custom_modeling/mllama.py
@@ -710,34 +710,41 @@ def forward(
         # )
         if SYSTEM == "ipex":
             attn_output = torch.empty_like(query_states)
-            ipex.llm.functional.varlen_attention(
-                (
-                    query_states.contiguous()
-                    if query_states.device.type == "xpu"
-                    else query_states
-                ),
-                (
-                    key_states.contiguous()
-                    if key_states.device.type == "xpu"
-                    else key_states
-                ),
-                (
-                    value_states.contiguous()
-                    if value_states.device.type == "xpu"
-                    else value_states
-                ),
-                attn_output,
-                cu_seqlen_q,
-                cu_seqlen_k,
-                max_q,
-                max_k,
-                0.0,
-                self.softmax_scale,
-                False,
-                causal,
-                False,
-                None,
-            )
+            if query_states.device.type == "xpu":
+                ipex.llm.functional.varlen_attention(
+                    query_states.contiguous(),
+                    key_states.contiguous(),
+                    value_states.contiguous(),
+                    attn_output,
+                    cu_seqlen_q,
+                    cu_seqlen_k,
+                    None,
+                    max_q,
+                    max_k,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
+            else:
+                ipex.llm.functional.varlen_attention(
+                    query_states,
+                    key_states,
+                    value_states,
+                    attn_output,
+                    cu_seqlen_q,
+                    cu_seqlen_k,
+                    max_q,
+                    max_k,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
         else:
             attn_output = flash_attn_2_cuda.varlen_fwd(
                 query_states,
diff --git a/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
index e2fc60b198f..231d02b539c 100644
--- a/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
+++ b/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
@@ -460,22 +460,41 @@ def forward(
         # execute flash attention
         if SYSTEM == "ipex":
             attn_output = torch.empty_like(query)
-            ipex.llm.functional.varlen_attention(
-                (query.contiguous() if query.device.type == "xpu" else query),
-                (key.contiguous() if key.device.type == "xpu" else key),
-                (value.contiguous() if value.device.type == "xpu" else value),
-                attn_output,
-                cu_seqlens,
-                cu_seqlens,
-                max_seqlen,
-                max_seqlen,
-                0.0,
-                self.softmax_scale,
-                False,
-                causal,
-                False,
-                None,
-            )
+            if query.device.type == "xpu":
+                ipex.llm.functional.varlen_attention(
+                    query.contiguous(),
+                    key.contiguous(),
+                    value.contiguous(),
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    None,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
+            else:
+                ipex.llm.functional.varlen_attention(
+                    query,
+                    key,
+                    value,
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
         else:
             attn_output = flash_attn_2_cuda.varlen_fwd(
                 query,
diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
index 75f718bdf97..855eaa6a5f1 100644
--- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py
+++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -130,22 +130,41 @@ def forward(
         # execute flash attention
         if SYSTEM == "ipex":
             attn_output = torch.empty_like(query)
-            ipex.llm.functional.varlen_attention(
-                (query.contiguous() if query.device.type == "xpu" else query),
-                (key.contiguous() if key.device.type == "xpu" else key),
-                (value.contiguous() if value.device.type == "xpu" else value),
-                attn_output,
-                cu_seqlens,
-                cu_seqlens,
-                max_seqlen,
-                max_seqlen,
-                0.0,
-                self.softmax_scale,
-                False,
-                causal,
-                False,
-                None,
-            )
+            if query.device.type == "xpu":
+                ipex.llm.functional.varlen_attention(
+                    query.contiguous(),
+                    key.contiguous(),
+                    value.contiguous(),
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    None,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
+            else:
+                ipex.llm.functional.varlen_attention(
+                    query,
+                    key,
+                    value,
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
         else:
             attn_output = flash_attn_2_cuda.varlen_fwd(
                 query,
diff --git a/server/text_generation_server/models/mllama_causal_lm.py b/server/text_generation_server/models/mllama_causal_lm.py
index af9a811cf45..a9ecef76488 100644
--- a/server/text_generation_server/models/mllama_causal_lm.py
+++ b/server/text_generation_server/models/mllama_causal_lm.py
@@ -59,7 +59,7 @@ def concatenate(cls, batches):
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]):
         assert self.image_indices is not None
-        batch = super().filter(request_ids)
+        batch = super(VlmCausalLMBatch, self).filter(request_ids)
         assert self.image_indices is not None
         indices = []
         for i, request_id in enumerate(request_ids):
@@ -85,6 +85,7 @@ def filter(self, request_ids: List[int]):
             ]
         else:
             batch.cross_attention_states = None
+            batch.pixel_values = None
         return batch
 
     @classmethod