Infer wp.array DType from argument in constructor#1539
Conversation
📝 WalkthroughWalkthroughThis PR upgrades type annotations for array-creation functions to use generic ChangesGeneric DType array creation API
Estimated code review effort🎯 2 (Simple) | ⏱️ ~12 minutes 🚥 Pre-merge checks | ✅ 5✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
| def zeros( | ||
| shape: int | tuple[int, ...] | list[int] | None = None, | ||
| dtype: type = float, | ||
| dtype: type[DType] | None = None, |
There was a problem hiding this comment.
The default was changed to None.
This is to avoid a type error, because the concrete default float would pin the free DType TypeVar.
There is precedent to this (wp.full already uses this pattern).
This changes the behavior of the code. wp.zeros(10, dtype=None) is now a legal "use the default". Before it raised a TypeError in type_size_in_bytes.
There was a problem hiding this comment.
It's better to have float set as the default parameter value since the expectations are clear from the signature but, other than that, I don't think it's too much of an issue since it's a common Python idiom to consider None as a way to express “no value has been passed” and define the default behaviour inside the function's body.
|
@shi-eric could you take a look into this and let me know if there's interest in fixing this? If so, I'd work on a proper PR. Would be interested in your opinion on how to add a test for this. |
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@warp/__init__.pyi`:
- Around line 1725-1734: The zeros() stub must be split into two overloads so
the return dtype binds when dtype is provided and defaults to warp.float32 when
omitted: add one overload for zeros(..., dtype: None = None) -> array[float32]
(or the module's float32 DType symbol) and a generic overload zeros(..., dtype:
type[DType]) -> array[DType]; keep other params (shape, device, requires_grad,
pinned, retain_grad, **kwargs) identical and ensure the TypeVar DType and array
type are the same symbols used elsewhere in the stub to allow proper type
inference for zeros(... ) and zeros(..., dtype=...).
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yml
Review profile: CHILL
Plan: Enterprise
Run ID: 63b4be49-162b-44d6-abc6-a4005f43e54b
📒 Files selected for processing (3)
warp/__init__.pyiwarp/_src/context.pywarp/_src/types.py
| def zeros( | ||
| shape: int | tuple[int, ...] | list[int] | None = None, | ||
| dtype: type = float, | ||
| dtype: type[DType] | None = None, | ||
| device: DeviceLike = None, | ||
| requires_grad: _builtins.bool = False, | ||
| pinned: _builtins.bool = False, | ||
| retain_grad: _builtins.bool = False, | ||
| **kwargs, | ||
| ) -> array: | ||
| ) -> array[DType]: | ||
| """Return a zero-initialized array.""" |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Verify whether DType is inferred for omitted dtype vs explicit dtype from the current stub.
# Expected:
# - reveal_type(wp.zeros(10)) should be concrete default dtype array (not Any)
# - reveal_type(wp.zeros(10, dtype=wp.float32)) should be array[float32]
cat > /tmp/verify_zeros_typing.py <<'PY'
import warp as wp
a = wp.zeros(10)
b = wp.zeros(10, dtype=wp.float32)
reveal_type(a)
reveal_type(b)
PY
# Run with your project type checker (mypy/pyright) in CI/local env.
# Example (mypy):
# mypy /tmp/verify_zeros_typing.pyRepository: NVIDIA/warp
Length of output: 37
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# 1) Inspect the current zeros stubs
echo "== zeros definitions =="
rg -n "def zeros\\(" warp/__init__.pyi || true
echo
echo "== zeros snippet (around line ~1725) =="
sed -n '1700,1760p' warp/__init__.pyi
echo
echo "== dtype aliases referenced by zeros snippet =="
rg -n "class DType|DType\\b|float32\\b" warp/__init__.pyi | head -n 80
# 2) Create a reveal_type snippet
cat > /tmp/verify_zeros_typing.py <<'PY'
import warp as wp
a = wp.zeros(10)
b = wp.zeros(10, dtype=wp.float32)
reveal_type(a)
reveal_type(b)
PY
# 3) Check if type checkers exist and run them
echo
echo "== type checker availability =="
command -v mypy || true
command -v pyright || true
echo
echo "== running pyright (if available) =="
if command -v pyright >/dev/null 2>&1; then
# Ensure pyright can see the local stub package
pyright /tmp/verify_zeros_typing.py --pythonpath . || true
else
echo "pyright not available"
fi
echo
echo "== running mypy (if available) =="
if command -v mypy >/dev/null 2>&1 || python -m mypy --version >/dev/null 2>&1; then
# Make mypy find the local warp package directory (with __init__.pyi)
MYPYPATH="$PWD" python -m mypy --hide-error-context --show-traceback /tmp/verify_zeros_typing.py || true
else
echo "mypy not available"
fiRepository: NVIDIA/warp
Length of output: 6199
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== locate zeros implementation (python) =="
rg -n "def zeros\\(" -S warp || true
echo
echo "== locate zeros callsites / dtype default =="
# Search for "zeros(" and "dtype" near it
rg -n "zeros\\(" warp | head -n 50 || true
echo
echo "== search for the core zeros implementation in _src/ =="
rg -n "zeros\\b" warp/_src | head -n 50 || true
echo
echo "== try to open likely implementation files =="
# List candidate files containing "zeros" to inspect around dtype defaults
fd -t f "*/**/*.py" warp | head -n 200 >/tmp/pyfiles.txt || true
python3 - <<'PY'
import subprocess, re, os, pathlib, sys
root="warp"
# Use ripgrep to find files with 'zeros' keyword
cmd=["rg","-l","\\bzeros\\b",root]
out=subprocess.check_output(cmd, text=True).splitlines()
for p in out[:25]:
print(p)
PY
echo
echo "== show first matches around zeros in candidate files =="
# Show contexts for zeros definitions if any
for f in $(rg -l "def zeros\\(" warp || true); do
echo "--- $f ---"
rg -n "def zeros\\(" "$f" -n -C 3 || true
done
echo
echo "== show contexts for 'dtype' default in zeros implementations =="
# Broad heuristic: within files mentioning zeros, show nearby 'dtype' and 'float' references
for f in $(rg -l "\\bzeros\\b" warp || true); do
if rg -q "zeros\\b" "$f"; then
# Only display if file is reasonably small for context scanning
lines=$(wc -l < "$f")
if [ "$lines" -le 600 ]; then
echo "--- $f ---"
rg -n "zeros\\b|dtype\\b|float\\b" "$f" -C 2 || true
fi
fi
doneRepository: NVIDIA/warp
Length of output: 50675
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# show zeros implementation signature + dtype default handling
sed -n '7635,7705p' warp/_src/context.py
echo "---- dtype handling (search within function) ----"
python3 - <<'PY'
import re, pathlib
p = pathlib.Path("warp/_src/context.py")
text = p.read_text()
# naive slice around function def
start = text.find("def zeros(")
print("function start idx", start)
sub = text[start:start+6000]
# print lines containing "dtype" and likely defaults
for i,line in enumerate(sub.splitlines(),1):
if "dtype" in line or "float" in line or "DType" in line or "float32" in line:
if any(k in line for k in ["dtype", "float32", "float", "default", "None", "type"]):
print(f"{i:04d}: {line}")
PY
echo "---- exact dtype default line(s) ----"
rg -n "def zeros\\(|dtype\\s*=|if\\s+dtype\\s+is\\s+None|dtype\\s+is\\s+None|float32|float\\b" -n warp/_src/context.py | head -n 60Repository: NVIDIA/warp
Length of output: 7298
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== header: find float/float32 dtype aliases =="
sed -n '1,120p' warp/__init__.pyi | sed -n '1,120p'
echo
echo "== locate DType definition area =="
sed -n '230,320p' warp/__init__.pyi
echo
echo "== locate array generic + dtype typing =="
rg -n "class Array\\b|class FabricArray\\b|DType\\b\\s*=\\s*TypeVar" warp/__init__.pyi | head -n 50
echo
echo "== show around warp.float / float32 declarations =="
rg -n "(^|\\b)(float32|float\\b|int32|int\\b)\\s*=" warp/__init__.pyi | head -n 80
echo
echo "== show type alias for float if present =="
# Try to find "float =" or "float: "
rg -n "^[[:space:]]*float[[:space:]]*=" warp/__init__.pyi || true
rg -n "^[[:space:]]*float\\b.*:" warp/__init__.pyi || true
echo
echo "== show Array[...] return typing helper =="
rg -n "def zeros\\(" warp/__init__.pyi | head -n 20Repository: NVIDIA/warp
Length of output: 10632
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== zeros related empty stub =="
rg -n "def empty\\(" warp/__init__.pyi | head -n 20
# show area around empty signature
line=$(rg -n "def empty\\(" warp/__init__.pyi | head -n 1 | cut -d: -f1)
start=$((line-10))
end=$((line+60))
sed -n "${start},${end}p" warp/__init__.pyi
echo
echo "== zeros_like/ones default dtype typing (if present) =="
rg -n "def (ones|full|empty_like|zeros_like|ones_like|full_like)\\(" warp/__init__.pyi | head -n 50Repository: NVIDIA/warp
Length of output: 86
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== stub block: warp/__init__.pyi zeros overloads =="
sed -n '1715,1760p' warp/__init__.pyi
echo
echo "== runtime: warp/_src/context.py empty() dtype mapping for python float =="
# locate empty() definition
rg -n "def empty\\(" warp/_src/context.py | head -n 20
line_empty=$(rg -n "def empty\\(" warp/_src/context.py | head -n 1 | cut -d: -f1)
start=$((line_empty-5))
end=$((line_empty+120))
sed -n "${start},${end}p" warp/_src/context.py
echo
echo "== runtime: dtype=float mapping if dtype is None => float (zeros) => empty() =="
rg -n "if dtype is None:|dtype = float|if dtype is float|dtype == float" warp/_src/context.py | head -n 80
echo
echo "== locate typing Scalar definition used by __init__.pyi =="
# Scalar is imported from warp._src.types; search for its definition in that file
rg -n "^(Scalar\\s*=|class Scalar\\b|from typing import .*Scalar)" warp/_src/types.py warp/_src/types.pyi 2>/dev/null || true
rg -n "\\bScalar\\b\\s*=" warp/_src/types.py warp/_src/types.pyi 2>/dev/null || true
echo
echo "== show relevant snippet around Scalar (small window) =="
# If a Scalar definition exists in types.py, print around the first match
scalar_line=$(rg -n "Scalar\\b\\s*=" warp/_src/types.py 2>/dev/null | head -n 1 | cut -d: -f1 || true)
if [ -n "${scalar_line:-}" ]; then
start=$((scalar_line-20))
end=$((scalar_line+60))
sed -n "${start},${end}p" warp/_src/types.py
else
echo "No direct 'Scalar=' found in warp/_src/types.py; searching for 'Scalar' typing constructs..."
rg -n "Scalar\\b" warp/_src/types.py | head -n 40
fiRepository: NVIDIA/warp
Length of output: 8692
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== locate type_to_warp mapping =="
rg -n "def type_to_warp\\b|type_to_warp\\b" warp/_src/types.py
echo
echo "== show mapping logic around type_to_warp =="
line=$(rg -n "def type_to_warp\\b" warp/_src/types.py | head -n 1 | cut -d: -f1)
start=$((line-30))
end=$((line+120))
sed -n "${start},${end}p" warp/_src/types.py
echo
echo "== look for python float -> float32 mapping explicitly =="
rg -n "float32\\b.*(float\\b|python float)|\\bfloat\\b.*float32\\b" warp/_src/types.py | head -n 50
echo
echo "== show native_scalar_types and any mapping usage =="
rg -n "native_scalar_types|float32\\b|python\\b.*float" warp/_src/types.py | head -n 80Repository: NVIDIA/warp
Length of output: 11202
Fix zeros() stub to bind return dtype when dtype is omitted
warp/_src/context.py::zeros() defaults dtype to Python float when omitted, and warp/_src/types.py::type_to_warp() maps that to warp.float32; however the current stub keeps dtype as type[DType] | None = None, so wp.zeros(10) can’t bind DType and the return often falls back to an imprecise type (e.g., array[Any]/unknown). Split overloads so omitted dtype returns array[float32] and explicit dtype binds DType.
Suggested stub shape
`@over`
def zeros(
shape: int | tuple[int, ...] | list[int] | None = None,
- dtype: type[DType] | None = None,
+ dtype: None = None,
device: DeviceLike = None,
requires_grad: _builtins.bool = False,
pinned: _builtins.bool = False,
retain_grad: _builtins.bool = False,
**kwargs,
-) -> array[DType]:
+) -> array[float32]: ...
+
+@over
+def zeros(
+ shape: int | tuple[int, ...] | list[int] | None = None,
+ dtype: type[DType] = ...,
+ device: DeviceLike = None,
+ requires_grad: _builtins.bool = False,
+ pinned: _builtins.bool = False,
+ retain_grad: _builtins.bool = False,
+ **kwargs,
+) -> array[DType]:📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def zeros( | |
| shape: int | tuple[int, ...] | list[int] | None = None, | |
| dtype: type = float, | |
| dtype: type[DType] | None = None, | |
| device: DeviceLike = None, | |
| requires_grad: _builtins.bool = False, | |
| pinned: _builtins.bool = False, | |
| retain_grad: _builtins.bool = False, | |
| **kwargs, | |
| ) -> array: | |
| ) -> array[DType]: | |
| """Return a zero-initialized array.""" | |
| `@over` | |
| def zeros( | |
| shape: int | tuple[int, ...] | list[int] | None = None, | |
| dtype: None = None, | |
| device: DeviceLike = None, | |
| requires_grad: _builtins.bool = False, | |
| pinned: _builtins.bool = False, | |
| retain_grad: _builtins.bool = False, | |
| **kwargs, | |
| ) -> array[float32]: ... | |
| `@over` | |
| def zeros( | |
| shape: int | tuple[int, ...] | list[int] | None = None, | |
| dtype: type[DType] = ..., | |
| device: DeviceLike = None, | |
| requires_grad: _builtins.bool = False, | |
| pinned: _builtins.bool = False, | |
| retain_grad: _builtins.bool = False, | |
| **kwargs, | |
| ) -> array[DType]: | |
| """Return a zero-initialized array.""" |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@warp/__init__.pyi` around lines 1725 - 1734, The zeros() stub must be split
into two overloads so the return dtype binds when dtype is provided and defaults
to warp.float32 when omitted: add one overload for zeros(..., dtype: None =
None) -> array[float32] (or the module's float32 DType symbol) and a generic
overload zeros(..., dtype: type[DType]) -> array[DType]; keep other params
(shape, device, requires_grad, pinned, retain_grad, **kwargs) identical and
ensure the TypeVar DType and array type are the same symbols used elsewhere in
the stub to allow proper type inference for zeros(... ) and zeros(...,
dtype=...).
Greptile SummaryThis draft PR improves static type inference for
Confidence Score: 4/5Safe to merge as a draft foundation — all runtime behaviour is preserved by explicit None-guards and no logic paths are altered. The changes are purely annotation-level for all functions except the addition of The Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["Caller: wp.zeros(10, dtype=wp.float32)"] --> B{"dtype provided?"}
B -- "Yes: type[DType]" --> C["DType bound to float32"]
B -- "No / None" --> D["DType unbound → Any\n(runtime: dtype = float)"]
C --> E["Return: array[float32] ✓"]
D --> F["Return: array[Any]\n(same as before)"]
G["Caller: wp.array([1,2,3], dtype=wp.int32)"] --> H["array.__init__\ndtype: type[DType] = Any"]
H --> I["DType bound to int32"]
I --> J["Return: array[int32] ✓"]
K["Caller: from_numpy(arr)"] --> L{"dtype provided?"}
L -- "Yes: type[DType]" --> M["DType bound → array[DType] ✓"]
L -- "No / None" --> N["Runtime: infers from arr.dtype\nStatic: array[Any]"]
Reviews (1): Last reviewed commit: "Infer wp.array DType from argument in co..." | Re-trigger Greptile |
| def zeros( | ||
| shape: int | tuple[int, ...] | list[int] | None = None, | ||
| dtype: type = float, | ||
| dtype: type[DType] | None = None, | ||
| device: DeviceLike = None, | ||
| requires_grad: bool = False, | ||
| pinned: bool = False, | ||
| retain_grad: bool = False, | ||
| **kwargs, | ||
| ) -> warp.array: | ||
| ) -> warp.array[DType]: |
There was a problem hiding this comment.
Default dtype not statically reflected when omitted
Changing the parameter default from dtype: type = float to dtype: type[DType] | None = None means type checkers bind DType only when a concrete dtype is passed. When the user calls wp.zeros(10) (omitting dtype), DType remains unresolved and the return type degrades to array[Any] — the same situation mypy strict mode complained about before. The intended default of float / wp.float32 is only applied at runtime via the if dtype is None guard, invisible to the type checker.
Adding @typing.overload variants (one for dtype=None returning array[float32] and one for dtype: type[DType] returning array[DType]) would give fully typed results in both cases. This pattern is also needed for ones and empty, which receive the same treatment here.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| def from_numpy( | ||
| arr: np.ndarray, | ||
| dtype: type | None = None, | ||
| dtype: type[DType] | None = None, | ||
| shape: Sequence[int] | None = None, | ||
| device: DeviceLike | None = None, | ||
| requires_grad: bool = False, | ||
| retain_grad: bool = False, | ||
| ) -> warp.array: | ||
| ) -> warp.array[DType]: |
There was a problem hiding this comment.
from_numpy with omitted dtype still returns array[Any]
When dtype is not supplied, the runtime successfully infers it from the numpy array's .dtype (resolving to e.g. float32, vec3, etc.), but the type checker sees an unbound DType and resolves the return as array[Any]. There is no static way to bridge the numpy dtype → Warp dtype mapping without explicit @overloads or a different approach — so this is a known limitation — but it's worth noting that the annotation currently overpromises for the no-dtype path: callers relying on inferred types from from_numpy(arr) won't benefit from this PR.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
|
Thanks @thomasbbrunner! At first glance, I'm not seeing any obvious drawback to have these changes added, so let's give it a try. Could you please:
See https://nvidia.github.io/warp/stable/user_guide/contribution_guide.html for details. As for the tests you suggested, we currently don't cover typing behaviour with our tests, but @shi-eric suggested extending the |
Description
NOTE: this is a draft. Opening this PR for comments/feedback.
Currently,
wp.arrayconstructors & factories (zeros/ones/empty/full/from_numpy/...) infer the type asarray[Any, int]because their signatures don't expose the dtype as a generic parameter.For instance, notice how the
DTypeof theclass array(Array[DType, NDim])is alwaysAny:This causes
mypyto complain when in strict mode:error: Need type annotation for "a" [var-annotated]To fix this, users must either explicitly set the type or ignore the error. Either fix is a workaround and imo suboptimal.
Here, I'm proposing to set the array's
DTypeaccording to thedtypeargument in the constructor. With this, the type is correctly set:There are no open issues specifically describing this. This is somewhat of an extension to #549.
Is there interest in fixing this? If so, I'd update this MR to:
Checklist
Unreleasedsection.Summary by CodeRabbit
zeros,ones,empty,full,from_numpy) with support for generic type parameters, enabling better type inference and IDE support.