Skip to content

Infer wp.array DType from argument in constructor#1539

Draft
thomasbbrunner wants to merge 1 commit into
NVIDIA:mainfrom
thomasbbrunner:tbrunner/mypy-array-constructors
Draft

Infer wp.array DType from argument in constructor#1539
thomasbbrunner wants to merge 1 commit into
NVIDIA:mainfrom
thomasbbrunner:tbrunner/mypy-array-constructors

Conversation

@thomasbbrunner

@thomasbbrunner thomasbbrunner commented Jun 8, 2026

Copy link
Copy Markdown
Contributor

Description

NOTE: this is a draft. Opening this PR for comments/feedback.

Currently, wp.array constructors & factories (zeros/ones/empty/full/from_numpy/...) infer the type as array[Any, int] because their signatures don't expose the dtype as a generic parameter.

For instance, notice how the DType of the class array(Array[DType, NDim]) is always Any:

a = wp.array([1, 2, 3], dtype=wp.int32)
reveal_type(a)
# output: array[Any, int]

f = wp.full(10, 7, dtype=wp.float32)
reveal_type(f)
# output: array[Any, int]

This causes mypy to complain when in strict mode: error: Need type annotation for "a" [var-annotated]

To fix this, users must either explicitly set the type or ignore the error. Either fix is a workaround and imo suboptimal.

Here, I'm proposing to set the array's DType according to the dtype argument in the constructor. With this, the type is correctly set:

a = wp.array([1, 2, 3], dtype=wp.int32)
reveal_type(a)
# output: array[int32, int]

f = wp.full(10, 7, dtype=wp.float32)
reveal_type(f)
# output: array[float32, int]

There are no open issues specifically describing this. This is somewhat of an extension to #549.

Is there interest in fixing this? If so, I'd update this MR to:

  • add tests for this (potentially calling mypy to assert that the revealed types match the expectations), and,
  • extend this pattern to other constructors (zeros_like/*_like, from_torch/paddle/jax/dlpack).

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • CHANGELOG.md is updated for any user-facing changes under the Unreleased section.

Summary by CodeRabbit

  • Refactor
    • Improved type annotations for array creation functions (zeros, ones, empty, full, from_numpy) with support for generic type parameters, enabling better type inference and IDE support.

@copy-pr-bot

copy-pr-bot Bot commented Jun 8, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai

coderabbitai Bot commented Jun 8, 2026

Copy link
Copy Markdown

Review Change Stack

📝 Walkthrough

Walkthrough

This PR upgrades type annotations for array-creation functions to use generic DType parameters. The DType type variable is imported, warp.array.__init__ annotation is updated, and zeros, ones, full, empty, and from_numpy signatures now accept dtype: type[DType] | None and return array[DType]. Runtime defaults are added where needed, and the .pyi stub is synchronized.

Changes

Generic DType array creation API

Layer / File(s) Summary
Type foundation: DType import and array class annotation
warp/_src/context.py, warp/_src/types.py
DType is imported and used to annotate warp.array.__init__'s dtype parameter as type[DType], enabling generic typing for dependent functions.
zeros function: generic signature, runtime default, and stub
warp/_src/context.py, warp/__init__.pyi
zeros() signature now accepts dtype: type[DType] | None = None and returns array[DType]; when dtype is None at runtime it defaults to float; the .pyi stub is updated to match.
ones function: generic signature and runtime default
warp/_src/context.py
ones() signature updated to dtype: type[DType] | None = None returning array[DType]; explicit dtype = float assignment added when dtype is None.
full, empty, and from_numpy functions: generic signatures
warp/_src/context.py
full(), empty(), and from_numpy() signatures updated to use dtype: type[DType] | None and return array[DType]; empty() also adds explicit dtype = float when dtype is None.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately describes the main change—updating warp.array constructors to infer DType from dtype arguments—which aligns with modifications across multiple files (init.pyi, context.py, types.py).
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands and usage tips.

Comment thread warp/_src/context.py
def zeros(
shape: int | tuple[int, ...] | list[int] | None = None,
dtype: type = float,
dtype: type[DType] | None = None,

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default was changed to None.

This is to avoid a type error, because the concrete default float would pin the free DType TypeVar.

There is precedent to this (wp.full already uses this pattern).

This changes the behavior of the code. wp.zeros(10, dtype=None) is now a legal "use the default". Before it raised a TypeError in type_size_in_bytes.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's better to have float set as the default parameter value since the expectations are clear from the signature but, other than that, I don't think it's too much of an issue since it's a common Python idiom to consider None as a way to express “no value has been passed” and define the default behaviour inside the function's body.

@thomasbbrunner

Copy link
Copy Markdown
Contributor Author

@shi-eric could you take a look into this and let me know if there's interest in fixing this? If so, I'd work on a proper PR.

Would be interested in your opinion on how to add a test for this.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@warp/__init__.pyi`:
- Around line 1725-1734: The zeros() stub must be split into two overloads so
the return dtype binds when dtype is provided and defaults to warp.float32 when
omitted: add one overload for zeros(..., dtype: None = None) -> array[float32]
(or the module's float32 DType symbol) and a generic overload zeros(..., dtype:
type[DType]) -> array[DType]; keep other params (shape, device, requires_grad,
pinned, retain_grad, **kwargs) identical and ensure the TypeVar DType and array
type are the same symbols used elsewhere in the stub to allow proper type
inference for zeros(... ) and zeros(..., dtype=...).
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Enterprise

Run ID: 63b4be49-162b-44d6-abc6-a4005f43e54b

📥 Commits

Reviewing files that changed from the base of the PR and between c1007f1 and 4861b4a.

📒 Files selected for processing (3)
  • warp/__init__.pyi
  • warp/_src/context.py
  • warp/_src/types.py

Comment thread warp/__init__.pyi
Comment on lines 1725 to 1734
def zeros(
shape: int | tuple[int, ...] | list[int] | None = None,
dtype: type = float,
dtype: type[DType] | None = None,
device: DeviceLike = None,
requires_grad: _builtins.bool = False,
pinned: _builtins.bool = False,
retain_grad: _builtins.bool = False,
**kwargs,
) -> array:
) -> array[DType]:
"""Return a zero-initialized array."""

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify whether DType is inferred for omitted dtype vs explicit dtype from the current stub.
# Expected:
# - reveal_type(wp.zeros(10)) should be concrete default dtype array (not Any)
# - reveal_type(wp.zeros(10, dtype=wp.float32)) should be array[float32]
cat > /tmp/verify_zeros_typing.py <<'PY'
import warp as wp

a = wp.zeros(10)
b = wp.zeros(10, dtype=wp.float32)

reveal_type(a)
reveal_type(b)
PY

# Run with your project type checker (mypy/pyright) in CI/local env.
# Example (mypy):
# mypy /tmp/verify_zeros_typing.py

Repository: NVIDIA/warp

Length of output: 37


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# 1) Inspect the current zeros stubs
echo "== zeros definitions =="
rg -n "def zeros\\(" warp/__init__.pyi || true

echo
echo "== zeros snippet (around line ~1725) =="
sed -n '1700,1760p' warp/__init__.pyi

echo
echo "== dtype aliases referenced by zeros snippet =="
rg -n "class DType|DType\\b|float32\\b" warp/__init__.pyi | head -n 80

# 2) Create a reveal_type snippet
cat > /tmp/verify_zeros_typing.py <<'PY'
import warp as wp

a = wp.zeros(10)
b = wp.zeros(10, dtype=wp.float32)

reveal_type(a)
reveal_type(b)
PY

# 3) Check if type checkers exist and run them
echo
echo "== type checker availability =="
command -v mypy || true
command -v pyright || true

echo
echo "== running pyright (if available) =="
if command -v pyright >/dev/null 2>&1; then
  # Ensure pyright can see the local stub package
  pyright /tmp/verify_zeros_typing.py --pythonpath . || true
else
  echo "pyright not available"
fi

echo
echo "== running mypy (if available) =="
if command -v mypy >/dev/null 2>&1 || python -m mypy --version >/dev/null 2>&1; then
  # Make mypy find the local warp package directory (with __init__.pyi)
  MYPYPATH="$PWD" python -m mypy --hide-error-context --show-traceback /tmp/verify_zeros_typing.py || true
else
  echo "mypy not available"
fi

Repository: NVIDIA/warp

Length of output: 6199


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== locate zeros implementation (python) =="
rg -n "def zeros\\(" -S warp || true

echo
echo "== locate zeros callsites / dtype default =="
# Search for "zeros(" and "dtype" near it
rg -n "zeros\\(" warp | head -n 50 || true

echo
echo "== search for the core zeros implementation in _src/ =="
rg -n "zeros\\b" warp/_src | head -n 50 || true

echo
echo "== try to open likely implementation files =="
# List candidate files containing "zeros" to inspect around dtype defaults
fd -t f "*/**/*.py" warp | head -n 200 >/tmp/pyfiles.txt || true
python3 - <<'PY'
import subprocess, re, os, pathlib, sys
root="warp"
# Use ripgrep to find files with 'zeros' keyword
cmd=["rg","-l","\\bzeros\\b",root]
out=subprocess.check_output(cmd, text=True).splitlines()
for p in out[:25]:
    print(p)
PY

echo
echo "== show first matches around zeros in candidate files =="
# Show contexts for zeros definitions if any
for f in $(rg -l "def zeros\\(" warp || true); do
  echo "--- $f ---"
  rg -n "def zeros\\(" "$f" -n -C 3 || true
done

echo
echo "== show contexts for 'dtype' default in zeros implementations =="
# Broad heuristic: within files mentioning zeros, show nearby 'dtype' and 'float' references
for f in $(rg -l "\\bzeros\\b" warp || true); do
  if rg -q "zeros\\b" "$f"; then
    # Only display if file is reasonably small for context scanning
    lines=$(wc -l < "$f")
    if [ "$lines" -le 600 ]; then
      echo "--- $f ---"
      rg -n "zeros\\b|dtype\\b|float\\b" "$f" -C 2 || true
    fi
  fi
done

Repository: NVIDIA/warp

Length of output: 50675


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# show zeros implementation signature + dtype default handling
sed -n '7635,7705p' warp/_src/context.py

echo "---- dtype handling (search within function) ----"
python3 - <<'PY'
import re, pathlib
p = pathlib.Path("warp/_src/context.py")
text = p.read_text()
# naive slice around function def
start = text.find("def zeros(")
print("function start idx", start)
sub = text[start:start+6000]
# print lines containing "dtype" and likely defaults
for i,line in enumerate(sub.splitlines(),1):
    if "dtype" in line or "float" in line or "DType" in line or "float32" in line:
        if any(k in line for k in ["dtype", "float32", "float", "default", "None", "type"]):
            print(f"{i:04d}: {line}")
PY

echo "---- exact dtype default line(s) ----"
rg -n "def zeros\\(|dtype\\s*=|if\\s+dtype\\s+is\\s+None|dtype\\s+is\\s+None|float32|float\\b" -n warp/_src/context.py | head -n 60

Repository: NVIDIA/warp

Length of output: 7298


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== header: find float/float32 dtype aliases =="
sed -n '1,120p' warp/__init__.pyi | sed -n '1,120p'

echo
echo "== locate DType definition area =="
sed -n '230,320p' warp/__init__.pyi

echo
echo "== locate array generic + dtype typing =="
rg -n "class Array\\b|class FabricArray\\b|DType\\b\\s*=\\s*TypeVar" warp/__init__.pyi | head -n 50

echo
echo "== show around warp.float / float32 declarations =="
rg -n "(^|\\b)(float32|float\\b|int32|int\\b)\\s*=" warp/__init__.pyi | head -n 80

echo
echo "== show type alias for float if present =="
# Try to find "float =" or "float: "
rg -n "^[[:space:]]*float[[:space:]]*=" warp/__init__.pyi || true
rg -n "^[[:space:]]*float\\b.*:" warp/__init__.pyi || true

echo
echo "== show Array[...] return typing helper =="
rg -n "def zeros\\(" warp/__init__.pyi | head -n 20

Repository: NVIDIA/warp

Length of output: 10632


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== zeros related empty stub =="
rg -n "def empty\\(" warp/__init__.pyi | head -n 20

# show area around empty signature
line=$(rg -n "def empty\\(" warp/__init__.pyi | head -n 1 | cut -d: -f1)
start=$((line-10))
end=$((line+60))
sed -n "${start},${end}p" warp/__init__.pyi

echo
echo "== zeros_like/ones default dtype typing (if present) =="
rg -n "def (ones|full|empty_like|zeros_like|ones_like|full_like)\\(" warp/__init__.pyi | head -n 50

Repository: NVIDIA/warp

Length of output: 86


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== stub block: warp/__init__.pyi zeros overloads =="
sed -n '1715,1760p' warp/__init__.pyi

echo
echo "== runtime: warp/_src/context.py empty() dtype mapping for python float =="
# locate empty() definition
rg -n "def empty\\(" warp/_src/context.py | head -n 20
line_empty=$(rg -n "def empty\\(" warp/_src/context.py | head -n 1 | cut -d: -f1)
start=$((line_empty-5))
end=$((line_empty+120))
sed -n "${start},${end}p" warp/_src/context.py

echo
echo "== runtime: dtype=float mapping if dtype is None => float (zeros) => empty() =="
rg -n "if dtype is None:|dtype = float|if dtype is float|dtype == float" warp/_src/context.py | head -n 80

echo
echo "== locate typing Scalar definition used by __init__.pyi =="
# Scalar is imported from warp._src.types; search for its definition in that file
rg -n "^(Scalar\\s*=|class Scalar\\b|from typing import .*Scalar)" warp/_src/types.py warp/_src/types.pyi 2>/dev/null || true
rg -n "\\bScalar\\b\\s*=" warp/_src/types.py warp/_src/types.pyi 2>/dev/null || true

echo
echo "== show relevant snippet around Scalar (small window) =="
# If a Scalar definition exists in types.py, print around the first match
scalar_line=$(rg -n "Scalar\\b\\s*=" warp/_src/types.py 2>/dev/null | head -n 1 | cut -d: -f1 || true)
if [ -n "${scalar_line:-}" ]; then
  start=$((scalar_line-20))
  end=$((scalar_line+60))
  sed -n "${start},${end}p" warp/_src/types.py
else
  echo "No direct 'Scalar=' found in warp/_src/types.py; searching for 'Scalar' typing constructs..."
  rg -n "Scalar\\b" warp/_src/types.py | head -n 40
fi

Repository: NVIDIA/warp

Length of output: 8692


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== locate type_to_warp mapping =="
rg -n "def type_to_warp\\b|type_to_warp\\b" warp/_src/types.py

echo
echo "== show mapping logic around type_to_warp =="
line=$(rg -n "def type_to_warp\\b" warp/_src/types.py | head -n 1 | cut -d: -f1)
start=$((line-30))
end=$((line+120))
sed -n "${start},${end}p" warp/_src/types.py

echo
echo "== look for python float -> float32 mapping explicitly =="
rg -n "float32\\b.*(float\\b|python float)|\\bfloat\\b.*float32\\b" warp/_src/types.py | head -n 50

echo
echo "== show native_scalar_types and any mapping usage =="
rg -n "native_scalar_types|float32\\b|python\\b.*float" warp/_src/types.py | head -n 80

Repository: NVIDIA/warp

Length of output: 11202


Fix zeros() stub to bind return dtype when dtype is omitted

warp/_src/context.py::zeros() defaults dtype to Python float when omitted, and warp/_src/types.py::type_to_warp() maps that to warp.float32; however the current stub keeps dtype as type[DType] | None = None, so wp.zeros(10) can’t bind DType and the return often falls back to an imprecise type (e.g., array[Any]/unknown). Split overloads so omitted dtype returns array[float32] and explicit dtype binds DType.

Suggested stub shape
 `@over`
 def zeros(
     shape: int | tuple[int, ...] | list[int] | None = None,
-    dtype: type[DType] | None = None,
+    dtype: None = None,
     device: DeviceLike = None,
     requires_grad: _builtins.bool = False,
     pinned: _builtins.bool = False,
     retain_grad: _builtins.bool = False,
     **kwargs,
-) -> array[DType]:
+) -> array[float32]: ...
+
+@over
+def zeros(
+    shape: int | tuple[int, ...] | list[int] | None = None,
+    dtype: type[DType] = ...,
+    device: DeviceLike = None,
+    requires_grad: _builtins.bool = False,
+    pinned: _builtins.bool = False,
+    retain_grad: _builtins.bool = False,
+    **kwargs,
+) -> array[DType]:
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def zeros(
shape: int | tuple[int, ...] | list[int] | None = None,
dtype: type = float,
dtype: type[DType] | None = None,
device: DeviceLike = None,
requires_grad: _builtins.bool = False,
pinned: _builtins.bool = False,
retain_grad: _builtins.bool = False,
**kwargs,
) -> array:
) -> array[DType]:
"""Return a zero-initialized array."""
`@over`
def zeros(
shape: int | tuple[int, ...] | list[int] | None = None,
dtype: None = None,
device: DeviceLike = None,
requires_grad: _builtins.bool = False,
pinned: _builtins.bool = False,
retain_grad: _builtins.bool = False,
**kwargs,
) -> array[float32]: ...
`@over`
def zeros(
shape: int | tuple[int, ...] | list[int] | None = None,
dtype: type[DType] = ...,
device: DeviceLike = None,
requires_grad: _builtins.bool = False,
pinned: _builtins.bool = False,
retain_grad: _builtins.bool = False,
**kwargs,
) -> array[DType]:
"""Return a zero-initialized array."""
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@warp/__init__.pyi` around lines 1725 - 1734, The zeros() stub must be split
into two overloads so the return dtype binds when dtype is provided and defaults
to warp.float32 when omitted: add one overload for zeros(..., dtype: None =
None) -> array[float32] (or the module's float32 DType symbol) and a generic
overload zeros(..., dtype: type[DType]) -> array[DType]; keep other params
(shape, device, requires_grad, pinned, retain_grad, **kwargs) identical and
ensure the TypeVar DType and array type are the same symbols used elsewhere in
the stub to allow proper type inference for zeros(... ) and zeros(...,
dtype=...).

@greptile-apps

greptile-apps Bot commented Jun 8, 2026

Copy link
Copy Markdown

Greptile Summary

This draft PR improves static type inference for wp.array and its factory functions (zeros, ones, empty, full, from_numpy) by threading a DType TypeVar through their signatures, so type checkers resolve wp.zeros(10, dtype=wp.float32) as array[float32] rather than array[Any].

  • array.__init__ gains dtype: type[DType] (replacing dtype: Any), allowing constructor calls with an explicit dtype to produce a properly parameterised array[DType].
  • Factory functions in context.py switch from a concrete default (dtype=float) to dtype: type[DType] | None = None with a runtime if dtype is None: dtype = float guard; their return annotations are updated from bare warp.array to warp.array[DType].
  • The zeros merged stub in __init__.pyi is updated to match; ones, full, empty, and from_numpy flow through their re-export declarations and don't need a manual stub change.

Confidence Score: 4/5

Safe to merge as a draft foundation — all runtime behaviour is preserved by explicit None-guards and no logic paths are altered.

The changes are purely annotation-level for all functions except the addition of if dtype is None: dtype = float guards, which exactly replicate the previous default. The main open question is whether callers who omit dtype get the expected static type (array[float32]) — they don't; unbound DType resolves to array[Any], so strict-mode users still need explicit annotations in that case. This is a known limitation the PR author acknowledges, not a regression.

The from_numpy function in context.py warrants a second look: its no-dtype path silently returns array[Any] to the type checker even though the runtime correctly resolves the type, which could mislead callers expecting full inference from this PR.

Important Files Changed

Filename Overview
warp/_src/types.py Changes array.__init__ parameter annotation from dtype: Any = Any to dtype: type[DType] = Any, connecting the constructor's dtype argument to the class-level TypeVar so type checkers infer array[int32] instead of array[Any] when an explicit dtype is passed. No runtime changes.
warp/_src/context.py Updates zeros, ones, full, empty, and from_numpy to use `dtype: type[DType]
warp/init.pyi Updates only the zeros merged stub to `dtype: type[DType]

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["Caller: wp.zeros(10, dtype=wp.float32)"] --> B{"dtype provided?"}
    B -- "Yes: type[DType]" --> C["DType bound to float32"]
    B -- "No / None" --> D["DType unbound → Any\n(runtime: dtype = float)"]
    C --> E["Return: array[float32] ✓"]
    D --> F["Return: array[Any]\n(same as before)"]

    G["Caller: wp.array([1,2,3], dtype=wp.int32)"] --> H["array.__init__\ndtype: type[DType] = Any"]
    H --> I["DType bound to int32"]
    I --> J["Return: array[int32] ✓"]

    K["Caller: from_numpy(arr)"] --> L{"dtype provided?"}
    L -- "Yes: type[DType]" --> M["DType bound → array[DType] ✓"]
    L -- "No / None" --> N["Runtime: infers from arr.dtype\nStatic: array[Any]"]
Loading

Reviews (1): Last reviewed commit: "Infer wp.array DType from argument in co..." | Re-trigger Greptile

Comment thread warp/_src/context.py
Comment on lines 7648 to +7656
def zeros(
shape: int | tuple[int, ...] | list[int] | None = None,
dtype: type = float,
dtype: type[DType] | None = None,
device: DeviceLike = None,
requires_grad: bool = False,
pinned: bool = False,
retain_grad: bool = False,
**kwargs,
) -> warp.array:
) -> warp.array[DType]:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Default dtype not statically reflected when omitted

Changing the parameter default from dtype: type = float to dtype: type[DType] | None = None means type checkers bind DType only when a concrete dtype is passed. When the user calls wp.zeros(10) (omitting dtype), DType remains unresolved and the return type degrades to array[Any] — the same situation mypy strict mode complained about before. The intended default of float / wp.float32 is only applied at runtime via the if dtype is None guard, invisible to the type checker.

Adding @typing.overload variants (one for dtype=None returning array[float32] and one for dtype: type[DType] returning array[DType]) would give fully typed results in both cases. This pattern is also needed for ones and empty, which receive the same treatment here.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment thread warp/_src/context.py
Comment on lines 8006 to +8013
def from_numpy(
arr: np.ndarray,
dtype: type | None = None,
dtype: type[DType] | None = None,
shape: Sequence[int] | None = None,
device: DeviceLike | None = None,
requires_grad: bool = False,
retain_grad: bool = False,
) -> warp.array:
) -> warp.array[DType]:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 from_numpy with omitted dtype still returns array[Any]

When dtype is not supplied, the runtime successfully infers it from the numpy array's .dtype (resolving to e.g. float32, vec3, etc.), but the type checker sees an unbound DType and resolves the return as array[Any]. There is no static way to bridge the numpy dtype → Warp dtype mapping without explicit @overloads or a different approach — so this is a known limitation — but it's worth noting that the annotation currently overpromises for the no-dtype path: callers relying on inferred types from from_numpy(arr) won't benefit from this PR.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

@christophercrouzet

Copy link
Copy Markdown
Member

Thanks @thomasbbrunner!

At first glance, I'm not seeing any obvious drawback to have these changes added, so let's give it a try.

Could you please:

  • Open a new issue.
  • Wrap up the other changes you planned on adding.
  • Add a CHANGELOG entry.
  • Reference the issue in this PR's title, the commit message, and the CHANGELOG.
  • Edit your commit message to include the DCO.

See https://nvidia.github.io/warp/stable/user_guide/contribution_guide.html for details.

As for the tests you suggested, we currently don't cover typing behaviour with our tests, but @shi-eric suggested extending the type-check-stubs CI job with a new step that would check the reveal_type() output (maybe using a python -c inline command if it's small enough?).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants