Basic Nx.block implementation#1709

Draft
Chapaman wants to merge 20 commits into elixir-nx:main from Chapaman:block_implementation

Conversation

@Chapaman
Contributor

@Chapaman Chapaman commented Mar 20, 2026

Extensible computation blocks (Nx.block/4): structs, required backend.block/4, no backend_args

Nx.block/4 is the Nx API for named blocks: a %Nx.Block.*{} holds static options, args is a list of tensor inputs (and optional trailing keyword fragments where the API needs them), output is the result template (Nx.Tensor or tuple of tensors), and fun is the default implementation fn struct, ...tensors -> ... end.

Dispatch: Nx.block/4 resolves a common backend from tensor leaves in args, validates fun arity (1 + length(args)), and calls only:

backend.block(struct, output, args, fun)

There is no second hop through Nx.Shared.optional/4, no function_exported?(backend, op, …) before block, and no Nx.Block.backend_args/2 — backends receive the same (struct, output, args, fun) contract and interpret structs / tensors themselves.
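The single dispatch path described above can be sketched in plain Elixir, without depending on Nx. This is an illustration under stated assumptions, not the PR's code: `BlockSketch`, `Negate`, and `DemoBackend` are hypothetical names, and `resolve_backend!/1` merely stands in for `Nx.Shared.list_impl!/1` by reading a `:backend` key from the first input.

```elixir
defmodule BlockSketch do
  # Hypothetical block struct; the PR's real structs live under Nx.Block.*.
  defmodule Negate do
    defstruct []
  end

  # Hypothetical backend using the default portable behaviour:
  # run the fallback function with the struct prepended to the args.
  defmodule DemoBackend do
    def block(struct, _output, args, fun), do: apply(fun, [struct | args])
  end

  # Mirrors the described contract: validate the default fun's arity,
  # resolve one backend, then make exactly one backend.block/4 call.
  def block(struct, args, output, fun) when is_struct(struct) and is_list(args) do
    expected = 1 + length(args)

    if not is_function(fun, expected) do
      raise ArgumentError, "expected a default function of arity #{expected}"
    end

    backend = resolve_backend!(args)
    backend.block(struct, output, args, fun)
  end

  # Assumption: each input is a map that carries its backend module.
  defp resolve_backend!([%{backend: backend} | _rest]), do: backend
end

# Usage: the default fun receives the struct first, then each input.
input = %{backend: BlockSketch.DemoBackend, value: 3}

BlockSketch.block(struct(BlockSketch.Negate), [input], input, fn _struct, t ->
  %{t | value: -t.value}
end)
```

Because the arity check happens before dispatch, a mismatched default function fails early regardless of which backend would have handled the block.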

Earlier call sites that used Nx.Shared.optional/4 (or similar) for things like logical_not and phase now use block with Nx.Block.LogicalNot, Nx.Block.Phase, etc.

Changes

Nx.block/4 (nx/lib/nx.ex)

  • Validates default fun arity.
  • backend = Nx.Shared.list_impl!(args) then backend.block(struct, output, args, fun) (single path).

Nx.Backend (nx/lib/nx/backend.ex)

  • @callback block(struct, output :: tensor | tuple, args :: [term], fun) :: tensor | tuple is required.
  • Default portable behaviour: apply(fun, [struct | args]).
  • Former optional per-op surface is folded into block/4 or dropped from the behaviour where ops are only reached via blocks; @optional_callbacks / duplicate @impl annotations are cleaned up accordingly (see Torchx).

Nx.Block (nx/lib/nx/block.ex)

  • Nx.Block.name/1 maps struct modules to stable atoms (evaluator / cache / tooling).
  • backend_args/2 removed — no separate argument reshaping layer; options live on structs, tensors in args.
  • Structs for logical_not, phase, all_close, cumulatives, take, take_along_axis, top_k, fft2 / ifft2, LinAlg (QR, Eigh, SVD, LU, Cholesky, Solve, Determinant, …) with @derive {Nx.Container, …} where fields cross defn boundaries.

Call sites and LinAlg

  • Nx.logical_not, phase, all_close, cumulatives, take, take_along_axis, top_k, fft2 / ifft2, and LinAlg entry points use Nx.block/4 with the matching %Nx.Block.*{} and tensor args.
  • LinAlg helpers use struct(Nx.Block.*, opts) and list args as before; defn code consumes %Nx.Block.*{} and struct fields instead of ad-hoc maps where applicable.

Defn / backends

  • Nx.Defn.Expr: block/3 (public) / block/4 (@impl Nx.Backend) build op: :block expressions.
  • Nx.Defn.Evaluator: evaluates :block by building a default fun whose arity matches the tensor prefix of args, then backend.block(struct, out, param_prefix, fun) (no backend_args).
  • EXLA (exla/lib/exla/defn.ex): :block compilation / cache keys use struct + call args in a single contract (no separate backend_args + extras split).
  • Torchx (torchx/lib/torchx/backend.ex): block/4 pattern-matches on Nx.Block.* (QR, LU, eigh, solve, SVD, TakeAlongAxis, FFT2, IFFT2, LogicalNot, …) for native paths, otherwise apply(fun, [struct | args]); standalone fft2/ifft2 behaviour impls removed where redundant; stray @impl removed for ops no longer on Nx.Backend.
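As a rough illustration of the backend side described in the list above, a backend can pattern-match on the block struct for the operations it supports natively and fall back to the default function otherwise. The module and struct names below (`BackendSketch`, `LogicalNot`, `TopK`) are hypothetical stand-ins operating on plain lists, not the Torchx implementation.

```elixir
defmodule BackendSketch do
  # Hypothetical block structs standing in for Nx.Block.* modules.
  defmodule LogicalNot do
    defstruct []
  end

  defmodule TopK do
    defstruct k: 1
  end

  # Native path: this backend handles LogicalNot itself and ignores the default fun.
  def block(%LogicalNot{}, _output, [values], _fun) do
    Enum.map(values, fn v -> if v == 0, do: 1, else: 0 end)
  end

  # Every other block falls back to the default implementation.
  def block(struct, _output, args, fun), do: apply(fun, [struct | args])
end

# Usage: TopK has no native clause, so the supplied default fun runs.
BackendSketch.block(struct(BackendSketch.TopK, k: 2), nil, [[7, 1, 9]], fn %{k: k}, values ->
  values |> Enum.sort(:desc) |> Enum.take(k)
end)
```

Options such as `k` travel on the struct, so the default function reads them from its first argument rather than from a trailing keyword list.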

Tests & tooling

  • nx/test/nx/optional_test.exs: backends exercise block/4; assertions updated where defaults run on BinaryBackend.
  • Torchx: fft2 / ifft2 doctests excluded in torchx/nx_doctest_test.exs (inspect / signed-zero differences vs BinaryBackend); torchx/nx_block_test.exs adds numerical checks vs BinaryBackend.
  • Nx.all_close/3: tolerance uses a safe abs(b) for integer b (e.g. via f64) so signed-min overflow does not break assert_all_close (relevant once all_close goes through %Nx.Block.AllClose{}).
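The signed-min overflow mentioned in the last item can be reproduced without Nx. The sketch below simulates signed 8-bit wraparound in plain Elixir (whose integers do not overflow on their own) and applies the usual `abs(a - b) <= atol + rtol * abs(b)` closeness check. `AbsSketch` and all of its functions are hypothetical helpers for illustration only; the PR's actual fix is only described above as going "e.g. via f64".

```elixir
defmodule AbsSketch do
  import Bitwise

  # Simulates two's-complement wraparound for a signed 8-bit integer.
  def wrap_s8(x) do
    y = band(x, 0xFF)
    if y >= 128, do: y - 256, else: y
  end

  # abs computed within s8: the minimum value -128 wraps back to -128.
  def abs_s8(x), do: wrap_s8(abs(x))

  # "Safe" abs: convert to float first so the magnitude is representable.
  def abs_via_float(x), do: abs(x * 1.0)

  # Closeness check: |a - b| <= atol + rtol * |b|, with a pluggable abs for b.
  def close?(a, b, abs_fun, rtol \\ 1.0e-5, atol \\ 1.0e-8) do
    abs(a - b) <= atol + rtol * abs_fun.(b)
  end
end

# With the wrapping abs the tolerance goes negative, so equal values fail:
AbsSketch.close?(-128, -128, &AbsSketch.abs_s8/1)
# => false

# With the float-based abs the same comparison passes:
AbsSketch.close?(-128, -128, &AbsSketch.abs_via_float/1)
# => true
```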

@Chapaman Chapaman closed this Mar 20, 2026
@Chapaman Chapaman reopened this Mar 20, 2026
@josevalim
Contributor

> Should we soft-deprecate Nx.Backend.optional/3 (and related patterns) in docs first, keep behavior, and migrate call sites gradually?

It is already an optional callback, so I would simply stop calling it and remove it from the behaviour description.

nx/lib/nx.ex Outdated
Comment on lines +6916 to +6919
defp cumulative_op_block(:cumulative_sum), do: %Nx.Block.CumulativeSum{}
defp cumulative_op_block(:cumulative_product), do: %Nx.Block.CumulativeProduct{}
defp cumulative_op_block(:cumulative_min), do: %Nx.Block.CumulativeMin{}
defp cumulative_op_block(:cumulative_max), do: %Nx.Block.CumulativeMax{}
Contributor

Let's move this closer to the usage

nx/lib/nx.ex Outdated
- Nx.Shared.optional(op, [tensor, [axis: axis, reverse: reverse]], tensor, fn tensor, opts ->
+ block = cumulative_op_block(op)
+
+ Nx.block(block, [tensor, [axis: axis, reverse: reverse]], tensor, fn ^block, tensor, opts ->
Contributor

Suggested change
- Nx.block(block, [tensor, [axis: axis, reverse: reverse]], tensor, fn ^block, tensor, opts ->
+ block(block, [tensor, [axis: axis, reverse: reverse]], tensor, fn ^block, tensor, opts ->

given that we're in the Nx module already, we can omit the Nx. prefix

@impl true
- def optional(function_name, args, default_impl) do
+ def block(struct, args, fun) do
+ function_name = Nx.Block.name(struct)
Contributor

We can remove this and change the case below to match on the structs directly

Comment on lines 61 to 68
mps_unsupported = [
:lu,
:eigh,
:solve,
:determinant,
:cholesky,
:matrix_power
]
Contributor

This should also check on the struct name instead


@impl true
- def optional(function_name, args, default_impl) do
+ def block(struct, args, fun) do
Contributor

Suggested change
- def block(struct, args, fun) do
+ def block(%block_name{} = struct, args, fun) do

Contributor

Then use block_name to match instead of the function_name conversion
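A minimal sketch of what this suggestion could look like, assuming hypothetical block structs `LU` and `QR` and a hypothetical `@unsupported` list keyed by struct module rather than by atom name. It is not the Torchx code under review, only an illustration of binding the struct's module in the function head and matching on it.

```elixir
defmodule MatchSketch do
  # Hypothetical block structs standing in for Nx.Block.* modules.
  defmodule LU do
    defstruct []
  end

  defmodule QR do
    defstruct []
  end

  # Hypothetical list of blocks with no native path, keyed by struct module.
  @unsupported [LU]

  # %block_name{} binds the struct's module, so no name conversion is needed.
  def block(%block_name{} = struct, args, fun) when block_name in @unsupported do
    {:default, apply(fun, [struct | args])}
  end

  def block(%block_name{} = struct, args, fun) do
    case block_name do
      QR -> {:native, :qr}
      _other -> {:default, apply(fun, [struct | args])}
    end
  end
end

# Usage: LU is listed as unsupported, so the default fun runs.
MatchSketch.block(struct(MatchSketch.LU), [:tensor], fn _struct, _tensor -> :fallback end)
```

Matching on the module also lets an unsupported-device list such as `mps_unsupported` hold struct modules directly.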

Comment on lines +379 to 387
case apply(fun, [struct | params ++ opts]) do
%{data: %{context: context}} = res ->
- expr(res, context, :optional, [expr(res, context, name, in_args), res, fun])
+ expr(res, context, :block, [expr(res, context, name, in_args), res, fun])

t when is_tuple(t) ->
context = elem(t, 0).data.context
out = expr(tuple_out(tuple_size(t)), context, name, in_args)
- tuple(expr(out, context, :optional, [out, t, fun]), Tuple.to_list(t))
+ tuple(expr(out, context, :block, [out, t, fun]), Tuple.to_list(t))
end
Contributor

As we discussed, I think the general shape of the expression needs less indirection. We want the expr to generally be (op: :block, args: [struct, args, fun]), with maybe this nice-to-have of also passing down to the compiler the expected output shape

@Chapaman Chapaman force-pushed the block_implementation branch from d36c848 to c350ab9 Compare March 26, 2026 00:11
nx/lib/nx.ex Outdated
Comment on lines +6905 to +6915
cond do
function_exported?(backend, :block, 3) ->
backend.block(struct, args, fun)

function_exported?(backend, name, length(backend_args) + 1) ->
apply(backend, name, [output | backend_args])

true ->
apply(fun, [struct | args])
end
end
Contributor Author

@josevalim we have three main approaches here

  1. Current: keep optional callbacks with the same contract. This requires the Nx.Block.backend_args trick or some protocol implementation.
  2. Change the optional callbacks' contracts to receive their structs as the first argument and no opts (removes the backend_args trick).
  3. @polvalente's preferred approach: remove optional callbacks altogether, so backends have to deal with def block directly, which won't be optional.

Contributor

Option 3. I think we always call block/3; we just make it easier to dispatch to the default implementation.

Contributor

Yeah, let's move forward with that then. This will let us get rid of the backend_args separation in Nx.Block since now backends will adapt to the new (struct, ...args) contract

)
- when evec_type_kind != :c and eval_type_kind != :c do
+ when op == :block and evec_type_kind != :c and eval_type_kind != :c do
+ tensor = hd(in_args)
Contributor

Let's pattern match on the args as it was done before!

) do
axis = opts[:axis]
)
when op == :block do
Contributor

Let's move it to pattern matching!

cache
)
- when evec_type_kind != :c and eval_type_kind != :c do
+ when op == :block and evec_type_kind != :c and eval_type_kind != :c do
Contributor

Let's move the op == :block back into a pattern matching!

Comment on lines +726 to +727
when op == :block do
tensor = hd(in_args)
Contributor

Same as above, pattern matching!

Comment on lines +741 to +742
when op == :block do
tensor = hd(in_args)
Contributor

Same as above, pattern matching!

Comment on lines +760 to +761
when op == :block do
tensor = hd(in_args)
Contributor

Same as above, pattern matching!
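The repeated request in these comments is a general Elixir idiom rather than anything specific to EXLA. A minimal sketch of the difference, using a hypothetical `GuardSketch` module and a plain map in place of the real expression struct:

```elixir
defmodule GuardSketch do
  # Guard style: op is checked in a guard and the first argument is
  # extracted with hd/1 inside the body.
  def with_guard(%{op: op, args: in_args}) when op == :block do
    tensor = hd(in_args)
    {:block, tensor}
  end

  # Pattern-matching style: both the op and the shape of the args are
  # stated in the function head, so the body needs no extraction step.
  def with_match(%{op: :block, args: [tensor | _rest]}), do: {:block, tensor}
  def with_match(%{op: other}), do: {:other, other}
end

# Usage: both styles return the same result for a :block expression.
GuardSketch.with_guard(%{op: :block, args: [:t, :u]})
GuardSketch.with_match(%{op: :block, args: [:t, :u]})
```

With the head-matching form, a `:block` expression with an empty argument list simply does not match that clause, instead of raising from `hd/1` inside the body.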

Comment on lines +779 to +780
{call_args, cache} = Enum.map_reduce(in_args, cache, &recur_operator(&1, state, &2))
key = computation_key(op, [struct | call_args ++ rest])
Contributor

Because you are mapping over in_args, it feels like opts will be included twice? Do we even have opts now? Aren't those all struct fields now?

defp block_default_fun(expr, state, expr_cache, 8),
do: fn struct, a, b, c, d, e, f, g, h ->
block_apply_default(expr, state, expr_cache, struct, [a, b, c, d, e, f, g, h])
end
Contributor

We already have logic for this in Nx.Defn.Compiler.fun
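The exact signature of the helper the reviewer refers to is not shown in this thread, so the following is only a standalone sketch of the general technique: generating one clause per arity at compile time instead of hand-writing `block_default_fun/4` for each arity. `AritySketch.build/2` is a hypothetical name, and the upper bound of 8 is chosen to mirror the snippet above.

```elixir
defmodule AritySketch do
  # Generates build/2 clauses for arities 0..8 at compile time.
  # Each clause returns an anonymous function of that arity which
  # forwards its arguments to the callback as a single list.
  for arity <- 0..8 do
    vars = Macro.generate_arguments(arity, __MODULE__)

    def build(unquote(arity), callback) when is_function(callback, 1) do
      fn unquote_splicing(vars) -> callback.(unquote(vars)) end
    end
  end
end

# Usage: a default fun for a block with two tensor inputs has arity 1 + 2.
default_fun = AritySketch.build(3, fn [struct | tensors] -> {struct, tensors} end)
default_fun.(:struct, :a, :b)
# => {:struct, [:a, :b]}
```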

Contributor

@josevalim josevalim left a comment

I did one pass. The current code is a bit mixed in style. My understanding is that in_args would not have opts anymore, because the opts would be fields in the struct.

Furthermore, the structs are all values, never containers, so we shouldn't need to implement the Nx.Container protocol for them.

3 participants