ModernBERT inference fails on CPU: ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) #35388

Open
umarbutler opened this issue Dec 21, 2024 · 2 comments

@umarbutler
Contributor

System Info

  • transformers version: 4.48.0.dev0
  • Platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35
  • Python version: 3.12.5
  • Huggingface_hub version: 0.25.1
  • Safetensors version: 0.4.5
  • Accelerate version: 0.34.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.5.1+cu124 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA GeForce RTX 4090

Who can help?

@Rocketknight1 @arthu

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Running the code below, which is taken verbatim from the ModernBERT README on Hugging Face except for the added device='cpu' argument, raises ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?):

import torch
from transformers import pipeline
from pprint import pprint

pipe = pipeline(
    "fill-mask",
    model="answerdotai/ModernBERT-base",
    torch_dtype=torch.bfloat16,
    device='cpu',
)

input_text = "He walked to the [MASK]."
results = pipe(input_text)
pprint(results)

Here is the full traceback of the error:

ValueError                                Traceback (most recent call last)
Cell In[1], line 13
      5 pipe = pipeline(
      6     "fill-mask",
      7     model="answerdotai/ModernBERT-base",
      8     torch_dtype=torch.bfloat16,
      9     device='cpu',
     10 )
     12 input_text = "He walked to the [MASK]."
---> 13 results = pipe(input_text)
     14 pprint(results)

File ~/dev/.venv/lib/python3.12/site-packages/transformers/pipelines/fill_mask.py:270, in FillMaskPipeline.__call__(self, inputs, **kwargs)
    248 def __call__(self, inputs, **kwargs):
    249     """
    250     Fill the masked token in the text(s) given as inputs.
    251 
   (...)
    268         - **token_str** (str) -- The predicted token (to replace the masked one).
    269     """
--> 270     outputs = super().__call__(inputs, **kwargs)
    271     if isinstance(inputs, list) and len(inputs) == 1:
    272         return outputs[0]

File ~/dev/.venv/lib/python3.12/site-packages/transformers/pipelines/base.py:1301, in Pipeline.__call__(self, inputs, num_workers, batch_size, *args, **kwargs)
   1293     return next(
   1294         iter(
   1295             self.get_iterator(
   (...)
   1298         )
   1299     )
   1300 else:
-> 1301     return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)

File ~/dev/.venv/lib/python3.12/site-packages/transformers/pipelines/base.py:1308, in Pipeline.run_single(self, inputs, preprocess_params, forward_params, postprocess_params)
   1306 def run_single(self, inputs, preprocess_params, forward_params, postprocess_params):
   1307     model_inputs = self.preprocess(inputs, **preprocess_params)
-> 1308     model_outputs = self.forward(model_inputs, **forward_params)
   1309     outputs = self.postprocess(model_outputs, **postprocess_params)
   1310     return outputs

File ~/dev/.venv/lib/python3.12/site-packages/transformers/pipelines/base.py:1208, in Pipeline.forward(self, model_inputs, **forward_params)
   1206     with inference_context():
   1207         model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
-> 1208         model_outputs = self._forward(model_inputs, **forward_params)
   1209         model_outputs = self._ensure_tensor_on_device(model_outputs, device=torch.device("cpu"))
   1210 else:

File ~/dev/.venv/lib/python3.12/site-packages/transformers/pipelines/fill_mask.py:127, in FillMaskPipeline._forward(self, model_inputs)
    126 def _forward(self, model_inputs):
--> 127     model_outputs = self.model(**model_inputs)
    128     model_outputs["input_ids"] = model_inputs["input_ids"]
    129     return model_outputs

File ~/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/dev/.venv/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py:1059, in ModernBertForMaskedLM.forward(self, input_ids, attention_mask, sliding_window_mask, position_ids, labels, indices, cu_seqlens, max_seqlen, batch_size, seq_len, output_attentions, output_hidden_states, return_dict, **kwargs)
   1054         with torch.no_grad():
   1055             input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
   1056                 inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
   1057             )
-> 1059 outputs = self.model(
   1060     input_ids,
   1061     attention_mask=attention_mask,
   1062     sliding_window_mask=sliding_window_mask,
   1063     position_ids=position_ids,
   1064     indices=indices,
   1065     cu_seqlens=cu_seqlens,
   1066     max_seqlen=max_seqlen,
   1067     batch_size=batch_size,
   1068     seq_len=seq_len,
   1069     output_attentions=output_attentions,
   1070     output_hidden_states=output_hidden_states,
   1071     return_dict=return_dict,
   1072 )
   1073 last_hidden_state = outputs[0]
   1075 if self.sparse_prediction and labels is not None:
   1076     # flatten labels and output first

File ~/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/dev/.venv/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py:913, in ModernBertModel.forward(self, input_ids, attention_mask, sliding_window_mask, position_ids, indices, cu_seqlens, max_seqlen, batch_size, seq_len, output_attentions, output_hidden_states, return_dict)
    902     layer_outputs = self._gradient_checkpointing_func(
    903         encoder_layer.__call__,
    904         hidden_states,
   (...)
    910         output_attentions,
    911     )
    912 else:
--> 913     layer_outputs = encoder_layer(
    914         hidden_states,
    915         attention_mask=attention_mask,
    916         sliding_window_mask=sliding_window_mask,
    917         position_ids=position_ids,
    918         cu_seqlens=cu_seqlens,
    919         max_seqlen=max_seqlen,
    920         output_attentions=output_attentions,
    921     )
    922 hidden_states = layer_outputs[0]
    923 if output_attentions and len(layer_outputs) > 1:

File ~/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/dev/.venv/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py:529, in ModernBertEncoderLayer.forward(self, hidden_states, attention_mask, sliding_window_mask, position_ids, cu_seqlens, max_seqlen, output_attentions)
    519 def forward(
    520     self,
    521     hidden_states: torch.Tensor,
   (...)
    527     output_attentions: Optional[bool] = False,
    528 ) -> torch.Tensor:
--> 529     attn_outputs = self.attn(
    530         self.attn_norm(hidden_states),
    531         attention_mask=attention_mask,
    532         sliding_window_mask=sliding_window_mask,
    533         position_ids=position_ids,
    534         cu_seqlens=cu_seqlens,
    535         max_seqlen=max_seqlen,
    536         output_attentions=output_attentions,
    537     )
    538     hidden_states = hidden_states + attn_outputs[0]
    539     mlp_output = (
    540         self.compiled_mlp(hidden_states)
    541         if self.config.reference_compile
    542         else self.mlp(self.mlp_norm(hidden_states))
    543     )

File ~/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/dev/.venv/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py:487, in ModernBertAttention.forward(self, hidden_states, output_attentions, **kwargs)
    484 else:
    485     qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim)
--> 487 attn_outputs = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation](
    488     self,
    489     qkv=qkv,
    490     rotary_emb=self.rotary_emb,
    491     local_attention=self.local_attention,
    492     bs=bs,
    493     dim=self.all_head_size,
    494     output_attentions=output_attentions,
    495     **kwargs,
    496 )
    497 hidden_states = attn_outputs[0]
    498 hidden_states = self.out_drop(self.Wo(hidden_states))

File ~/dev/.venv/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py:349, in flash_attention_forward(module, qkv, rotary_emb, cu_seqlens, max_seqlen, local_attention, bs, dim, target_dtype, **_kwargs)
    336 def flash_attention_forward(
    337     module: "ModernBertAttention",
    338     qkv: torch.Tensor,
   (...)
    347 ) -> Tuple[torch.Tensor]:
    348     # (total_seqlen, 3, nheads, headdim)
--> 349     qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
    351     convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
    352     if convert_dtype:
    353         # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
    354         # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)

File ~/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/dev/.venv/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py:178, in ModernBertUnpaddedRotaryEmbedding.forward(self, qkv, cu_seqlens, max_seqlen)
    175 if max_seqlen is not None:
    176     self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
--> 178 qkv = apply_rotary_unpadded(
    179     qkv,
    180     self._cos_cached,
    181     self._sin_cached,
    182     cu_seqlens=cu_seqlens,
    183     max_seqlen=max_seqlen,
    184 )
    186 return qkv

File ~/dev/.venv/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py:136, in apply_rotary_unpadded(qkv, cos, sin, cu_seqlens, max_seqlen)
    113 def apply_rotary_unpadded(
    114     qkv,
    115     cos,
   (...)
    118     max_seqlen: Optional[int] = None,
    119 ):
    120     """
    121     Arguments:
    122         qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV.
   (...)
    134     Apply rotary embedding to the first rotary_dim of x.
    135     """
--> 136     return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)

File ~/dev/.venv/lib/python3.12/site-packages/torch/autograd/function.py:575, in Function.apply(cls, *args, **kwargs)
    572 if not torch._C._are_functorch_transforms_active():
    573     # See NOTE: [functorch vjp and autograd interaction]
    574     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 575     return super().apply(*args, **kwargs)  # type: ignore[misc]
    577 if not is_setup_ctx_defined:
    578     raise RuntimeError(
    579         "In order to use an autograd.Function with functorch transforms "
    580         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    581         "staticmethod. For more details, please see "
    582         "https://pytorch.org/docs/main/notes/extending.func.html"
    583     )

File ~/dev/.venv/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py:75, in ApplyRotaryEmbUnpad.forward(ctx, qkv, cos, sin, cu_seqlens, max_seqlen)
     71 # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
     72 # we get the same tensor
     73 # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
     74 qk = qkv[:, :2].view(total_nnz, -1, headdim)
---> 75 apply_rotary(
     76     qk,
     77     cos,
     78     sin,
     79     seqlen_offsets=0,
     80     cu_seqlens=cu_seqlens,
     81     max_seqlen=max_seqlen,
     82     interleaved=False,
     83     inplace=True,
     84 )
     86 ctx.save_for_backward(cos, sin, cu_seqlens)
     87 ctx.max_seqlen = max_seqlen

File ~/dev/.venv/lib/python3.12/site-packages/flash_attn/ops/triton/rotary.py:202, in apply_rotary(x, cos, sin, seqlen_offsets, cu_seqlens, max_seqlen, interleaved, inplace, conjugate)
    199 # Need this, otherwise Triton tries to launch from cuda:0 and we get
    200 # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
    201 with torch.cuda.device(x.device.index):
--> 202     rotary_kernel[grid](
    203         output,  # data ptrs
    204         x,
    205         cos,
    206         sin,
    207         cu_seqlens,
    208         seqlen_offsets,
    209         seqlen,  # shapes
    210         rotary_dim,
    211         seqlen_ro,
    212         output.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
    213         output.stride(-3),  # seqlen_stride or total_seqlen_stride
    214         output.stride(-2),  # nheads_stride
    215         output.stride(-1),  # headdim_stride
    216         x.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
    217         x.stride(-3),  # seqlen stride or total_seqlen_stride
    218         x.stride(-2),  # nheads stride
    219         x.stride(-1),  # headdim stride
    220         BLOCK_K,
    221         isinstance(seqlen_offsets, torch.Tensor),
    222         is_varlen,
    223         interleaved,
    224         conjugate,
    225         BLOCK_M,
    226     )
    227 return output

File ~/dev/.venv/lib/python3.12/site-packages/triton/runtime/jit.py:345, in KernelInterface.__getitem__.<locals>.<lambda>(*args, **kwargs)
    339 def __getitem__(self, grid) -> T:
    340     """
    341     A JIT function is launched with: fn[grid](*args, **kwargs).
    342     Hence JITFunction.__getitem__ returns a callable proxy that
    343     memorizes the grid.
    344     """
--> 345     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/triton/runtime/jit.py:691, in JITFunction.run(self, grid, warmup, *args, **kwargs)
    689     # launch kernel
    690     launch_metadata = kernel.launch_metadata(grid, stream, *non_constexpr_vals)
--> 691     kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
    692                self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook, *non_constexpr_vals)
    693 return kernel

File ~/dev/.venv/lib/python3.12/site-packages/triton/backends/nvidia/driver.py:365, in CudaLauncher.__call__(self, *args, **kwargs)
    364 def __call__(self, *args, **kwargs):
--> 365     self.launch(*args, **kwargs)

ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)

Expected behavior

The pipeline runs on the CPU and prints the fill-mask predictions without raising an error.
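A possible workaround, sketched under assumptions: the traceback shows ModernBertAttention.forward dispatching on self.config._attn_implementation (line 487) into flash_attention_forward, whose rotary step calls flash_attn's Triton kernel, and that kernel is launched through Triton's CUDA launcher. Selecting a non-flash attention backend when running on CPU should therefore keep inference off the Triton path. The helpers below are hypothetical (not part of transformers), and they assume "sdpa" is an accepted attn_implementation value for ModernBERT in this transformers version.

```python
import importlib.util


def flash_attn_installed() -> bool:
    """Return True if the flash_attn package is importable (checked without importing it)."""
    return importlib.util.find_spec("flash_attn") is not None


def choose_attn_implementation(device: str, has_flash_attn: bool) -> str:
    """Hypothetical helper: pick an attention backend that can run on `device`.

    flash_attention_2 depends on CUDA/Triton kernels, so it is only chosen for
    CUDA devices when flash_attn is installed; any other device (e.g. "cpu")
    falls back to PyTorch's scaled dot-product attention ("sdpa").
    """
    device_type = device.split(":", 1)[0].lower()
    if device_type == "cuda" and has_flash_attn:
        return "flash_attention_2"
    return "sdpa"


if __name__ == "__main__":
    # On CPU this selects "sdpa" whether or not flash_attn is installed.
    print(choose_attn_implementation("cpu", flash_attn_installed()))
```

The chosen value could then be passed through the pipeline factory's model_kwargs argument, which is forwarded to the model's from_pretrained, e.g. pipeline("fill-mask", model="answerdotai/ModernBERT-base", torch_dtype=torch.bfloat16, device="cpu", model_kwargs={"attn_implementation": choose_attn_implementation("cpu", flash_attn_installed())}).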

@umarbutler
Contributor Author

Related AnswerDotAI/ModernBERT#152

@umarbutler
Contributor Author

Hint: the traceback includes these lines from flash_attn/ops/triton/rotary.py:

    199 # Need this, otherwise Triton tries to launch from cuda:0 and we get
    200 # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
    201 with torch.cuda.device(x.device.index):
--> 202     rotary_kernel[grid](
    203         output,  # data ptrs
    204         x,
    205         cos,
    206         sin,
    207         cu_seqlens,
    208         seqlen_offsets,
    209         seqlen,  # shapes
    210         rotary_dim,
    211         seqlen_ro,
    212         output.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
    213         output.stride(-3),  # seqlen_stride or total_seqlen_stride
    214         output.stride(-2),  # nheads_stride
    215         output.stride(-1),  # headdim_stride
    216         x.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
    217         x.stride(-3),  # seqlen stride or total_seqlen_stride
    218         x.stride(-2),  # nheads stride
    219         x.stride(-1),  # headdim stride
    220         BLOCK_K,
    221         isinstance(seqlen_offsets, torch.Tensor),
    222         is_varlen,
    223         interleaved,
    224         conjugate,
    225         BLOCK_M,
    226     )
    227 return output

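So it seems this error is already known, but the fix didn't work? As the comment says, the guard only stops Triton from launching on cuda:0 when x lives on a different GPU. For a CPU tensor, x.device.index is None, and the traceback shows the kernel still being launched through Triton's CudaLauncher with a CPU pointer.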
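For context on what a CPU fallback for this kernel would have to compute: ApplyRotaryEmbUnpad calls apply_rotary with interleaved=False (line 82 of its frame), i.e. the non-interleaved "rotate-half" rotary embedding, in which dimension i of the first half is rotated together with dimension i + d/2. The pure-Python sketch below illustrates that math for a single head vector. It is not flash_attn's or transformers' code; it assumes the whole head dimension is rotated and uses a frequency base of 10000.0 as an assumed default (the real value comes from the model config).

```python
import math
from typing import List


def rotary_non_interleaved(x: List[float], position: int, base: float = 10000.0) -> List[float]:
    """Apply a non-interleaved ("rotate-half") rotary embedding to one head vector.

    Dimension i of the first half is paired with dimension i + d/2 and the pair is
    rotated by the angle position * base ** (-2i / d). Assumes rotary_dim == len(x),
    and the base value is an illustrative default, not read from any model config.
    """
    d = len(x)
    if d % 2 != 0:
        raise ValueError("head dimension must be even")
    half = d // 2
    out = list(x)
    for i in range(half):
        inv_freq = base ** (-2.0 * i / d)  # per-pair rotation frequency
        angle = position * inv_freq
        cos, sin = math.cos(angle), math.sin(angle)
        x1, x2 = x[i], x[i + half]
        out[i] = x1 * cos - x2 * sin
        out[i + half] = x1 * sin + x2 * cos
    return out
```

For example, rotary_non_interleaved([1.0, 0.0], position=1) rotates the single pair by one radian, giving approximately [0.5403, 0.8415] (cos 1, sin 1). Because each pair is rotated by an angle proportional to its position, the dot product of a rotated query and key depends only on their relative offset, which is the property the Triton kernel preserves on GPU.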
