diff --git a/stubs/tensorflow/tensorflow/__init__.pyi b/stubs/tensorflow/tensorflow/__init__.pyi index 3a356392a6aa..3124c2c6e556 100644 --- a/stubs/tensorflow/tensorflow/__init__.pyi +++ b/stubs/tensorflow/tensorflow/__init__.pyi @@ -385,6 +385,13 @@ def squeeze( ) -> Tensor: ... @overload def squeeze(input: RaggedTensor, axis: int | tuple[int, ...] | list[int], name: str | None = None) -> RaggedTensor: ... +def split( + value: TensorCompatible, + num_or_size_splits: int | TensorCompatible, + axis: int | Tensor = 0, + num: int | None = None, + name: str | None = "split", +) -> list[Tensor]: ... def tensor_scatter_nd_update( tensor: TensorCompatible, indices: TensorCompatible, updates: TensorCompatible, name: str | None = None ) -> Tensor: ...