It is already an optional callback, so I would simply stop calling it and remove it from the behaviour description. |
nx/lib/nx.ex
Outdated
| defp cumulative_op_block(:cumulative_sum), do: %Nx.Block.CumulativeSum{} | ||
| defp cumulative_op_block(:cumulative_product), do: %Nx.Block.CumulativeProduct{} | ||
| defp cumulative_op_block(:cumulative_min), do: %Nx.Block.CumulativeMin{} | ||
| defp cumulative_op_block(:cumulative_max), do: %Nx.Block.CumulativeMax{} |
Let's move this closer to the usage
nx/lib/nx.ex
Outdated
| Nx.Shared.optional(op, [tensor, [axis: axis, reverse: reverse]], tensor, fn tensor, opts -> | ||
| block = cumulative_op_block(op) | ||
|
|
||
| Nx.block(block, [tensor, [axis: axis, reverse: reverse]], tensor, fn ^block, tensor, opts -> |
| Nx.block(block, [tensor, [axis: axis, reverse: reverse]], tensor, fn ^block, tensor, opts -> | |
| block(block, [tensor, [axis: axis, reverse: reverse]], tensor, fn ^block, tensor, opts -> |
given that we're in the Nx module already, we can omit the Nx. prefix
torchx/lib/torchx/backend.ex
Outdated
| @impl true | ||
| def optional(function_name, args, default_impl) do | ||
| def block(struct, args, fun) do | ||
| function_name = Nx.Block.name(struct) |
We can remove this and change the case below to match on the structs directly
| mps_unsupported = [ | ||
| :lu, | ||
| :eigh, | ||
| :solve, | ||
| :determinant, | ||
| :cholesky, | ||
| :matrix_power | ||
| ] |
This should also check on the struct name instead
torchx/lib/torchx/backend.ex
Outdated
|
|
||
| @impl true | ||
| def optional(function_name, args, default_impl) do | ||
| def block(struct, args, fun) do |
| def block(struct, args, fun) do | |
| def block(%block_name{} = struct, args, fun) do |
Then use block_name to match instead of the function_name conversion
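The suggestion above can be illustrated with a small, self-contained sketch. It does not use the real Nx or Torchx code: `MatchSketch`, its nested structs, and the `@unsupported` list are hypothetical stand-ins for `Nx.Block.*` structs and the `mps_unsupported` list, and the 3-arity `block` mirrors the signature shown in the diff at this point of the review.

```elixir
defmodule MatchSketch do
  # Hypothetical stand-ins for the %Nx.Block.*{} structs discussed above.
  defmodule QR, do: defstruct(mode: :reduced)
  defmodule LU, do: defstruct([])
  defmodule Phase, do: defstruct([])

  # Hypothetical list of blocks without a native path on some device,
  # keyed by struct module instead of by function-name atom.
  @unsupported [LU]

  # Unsupported blocks go straight to the default implementation.
  def block(%block_name{} = struct, args, fun) when block_name in @unsupported do
    apply(fun, [struct | args])
  end

  # %block_name{} binds the struct's module, so no name conversion is needed.
  def block(%block_name{} = struct, args, fun) do
    case block_name do
      QR -> {:native, :qr, struct.mode}
      _ -> apply(fun, [struct | args])
    end
  end
end
```

Matching on the module keeps the dispatch table and the "unsupported" list in the same vocabulary (struct modules), which is what both review comments in this thread ask for.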
| case apply(fun, [struct | params ++ opts]) do | ||
| %{data: %{context: context}} = res -> | ||
| expr(res, context, :optional, [expr(res, context, name, in_args), res, fun]) | ||
| expr(res, context, :block, [expr(res, context, name, in_args), res, fun]) | ||
|
|
||
| t when is_tuple(t) -> | ||
| context = elem(t, 0).data.context | ||
| out = expr(tuple_out(tuple_size(t)), context, name, in_args) | ||
| tuple(expr(out, context, :optional, [out, t, fun]), Tuple.to_list(t)) | ||
| tuple(expr(out, context, :block, [out, t, fun]), Tuple.to_list(t)) | ||
| end |
As we discussed, I think the general shape of the expression needs less indirection. We want the expr to generally be (op: :block, args: [struct, args, fun]), with the nice-to-have of also passing the expected output shape down to the compiler.
d36c848 to c350ab9
…t argument and updated deprecated :optional-based inspect/graph tests
nx/lib/nx.ex
Outdated
| cond do | ||
| function_exported?(backend, :block, 3) -> | ||
| backend.block(struct, args, fun) | ||
|
|
||
| function_exported?(backend, name, length(backend_args) + 1) -> | ||
| apply(backend, name, [output | backend_args]) | ||
|
|
||
| true -> | ||
| apply(fun, [struct | args]) | ||
| end | ||
| end |
@josevalim we have three main approaches here:
- Current: keep optional callbacks with the same contract. This requires the Nx.Block.backend_args trick or some protocol implementation.
- Change the optional callbacks' contracts to receive their structs as the first arg, with no opts (removes the backend_args trick).
- @polvalente's preferred approach: remove optional callbacks altogether, so backends have to deal with def block directly, which won't be optional.
Option 3. I think we always call block/3, we just make it easier to dispatch to the default implementation.
Yeah, let's move forward with that then. This will let us get rid of the backend_args separation in Nx.Block since now backends will adapt to the new (struct, ...args) contract
| ) | ||
| when evec_type_kind != :c and eval_type_kind != :c do | ||
| when op == :block and evec_type_kind != :c and eval_type_kind != :c do | ||
| tensor = hd(in_args) |
Let's pattern match on the args as it was done before!
| ) do | ||
| axis = opts[:axis] | ||
| ) | ||
| when op == :block do |
Let's move it to pattern matching!
| cache | ||
| ) | ||
| when evec_type_kind != :c and eval_type_kind != :c do | ||
| when op == :block and evec_type_kind != :c and eval_type_kind != :c do |
Let's move the op == :block back into pattern matching!
| when op == :block do | ||
| tensor = hd(in_args) |
Same as above, pattern matching!
| when op == :block do | ||
| tensor = hd(in_args) |
Same as above, pattern matching!
| when op == :block do | ||
| tensor = hd(in_args) |
Same as above, pattern matching!
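For readers following the repeated "pattern matching" requests, here is a minimal sketch of the difference. Everything in it is hypothetical: `PatternSketch.Expr` is a simplified stand-in for an expression node with `op` and `args` fields, `Cumsum` stands in for a block struct, and the `[struct, in_args, fun]` argument shape is an assumption taken from the earlier comment about the desired expression shape.

```elixir
defmodule PatternSketch do
  # Hypothetical, simplified expression node and block struct.
  defmodule Expr, do: defstruct([:op, :args])
  defmodule Cumsum, do: defstruct(axis: 0)

  # Guard style, as in the diff: the op is checked in a guard and the
  # tensor is extracted with hd/1 inside the body.
  def lower_guard(%Expr{op: op, args: [%Cumsum{axis: axis}, in_args, _fun]})
      when op == :block do
    tensor = hd(in_args)
    {:cumsum, tensor, axis}
  end

  # Pattern-matching style, as requested: the op, the struct fields and the
  # single tensor argument are all destructured in the function head.
  def lower_match(%Expr{op: :block, args: [%Cumsum{axis: axis}, [tensor], _fun]}) do
    {:cumsum, tensor, axis}
  end
end
```

Both clauses return the same result for a well-formed node; the second one additionally fails to match (rather than silently taking the head) when the argument list does not contain exactly one tensor.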
| {call_args, cache} = Enum.map_reduce(in_args, cache, &recur_operator(&1, state, &2)) | ||
| key = computation_key(op, [struct | call_args ++ rest]) |
Because you are mapping over in_args, it feels like opts will be included twice? Do we even have opts now? Aren't those all struct fields now?
| defp block_default_fun(expr, state, expr_cache, 8), | ||
| do: fn struct, a, b, c, d, e, f, g, h -> | ||
| block_apply_default(expr, state, expr_cache, struct, [a, b, c, d, e, f, g, h]) | ||
| end |
We already have logic for this in Nx.Defn.Compiler.fun
josevalim
left a comment
I did one pass. The current code is a bit mixed in style. My understanding is that in_args would not have opts anymore, because the opts would be fields in the struct.
Furthermore, the structs are all values, never containers, so we shouldn't need to implement the Nx.Container protocol for them.
# Extensible computation blocks (`Nx.block/4`): structs, required `backend.block/4`, no `backend_args`

`Nx.block/4` is the Nx API for named blocks: a `%Nx.Block.*{}` holds static options, `args` is a list of tensor inputs (and optional trailing keyword fragments where the API needs them), `output` is the result template (`Nx.Tensor` or tuple of tensors), and `fun` is the default implementation `fn struct, ...tensors -> ... end`.

Dispatch: `Nx.block/4` resolves a common backend from tensor leaves in `args`, validates `fun` arity (`1 + length(args)`), and calls only `backend.block(struct, output, args, fun)`.

There is no second hop through `Nx.Shared.optional/4`, no `function_exported?(backend, op, …)` before `block`, and no `Nx.Block.backend_args/2`; backends receive the same `(struct, output, args, fun)` contract and interpret structs / tensors themselves.

Earlier call sites that used `Nx.Shared.optional/4` (or similar) for things like `logical_not` and `phase` now use `block` with `Nx.Block.LogicalNot`, `Nx.Block.Phase`, etc.

## Changes

### `Nx.block/4` (`nx/lib/nx.ex`)

- Validates `fun` arity.
- `backend = Nx.Shared.list_impl!(args)` then `backend.block(struct, output, args, fun)` (single path).

### `Nx.Backend` (`nx/lib/nx/backend.ex`)

- `@callback block(struct, output :: tensor | tuple, args :: [term], fun) :: tensor | tuple` is required.
- `apply(fun, [struct | args])`.
- `block/4` or dropped from the behaviour where ops are only reached via blocks; `@optional_callbacks` / duplicate `@impl`s are cleaned up accordingly (see Torchx).

### `Nx.Block` (`nx/lib/nx/block.ex`)

- `Nx.Block.name/1` maps struct modules to stable atoms (evaluator / cache / tooling).
- `backend_args/2` removed: no separate argument reshaping layer; options live on structs, tensors in `args`.
- `logical_not`, `phase`, `all_close`, cumulatives, `take`, `take_along_axis`, `top_k`, `fft2` / `ifft2`, LinAlg (`QR`, `Eigh`, `SVD`, `LU`, `Cholesky`, `Solve`, `Determinant`, …) with `@derive {Nx.Container, …}` where fields cross `defn` boundaries.

### Call sites and LinAlg

- `Nx.logical_not`, `phase`, `all_close`, cumulatives, `take`, `take_along_axis`, `top_k`, `fft2` / `ifft2`, and LinAlg entry points use `Nx.block/4` with the matching `%Nx.Block.*{}` and tensor `args`.
- `struct(Nx.Block.*, opts)` and list `args` as before; `defn` code consumes `%Nx.Block.*{}` and struct fields instead of ad-hoc maps where applicable.

### Defn / backends

- `Nx.Defn.Expr`: `block/3` (public) / `block/4` (`@impl Nx.Backend`) build `op: :block` expressions.
- `Nx.Defn.Evaluator`: evaluates `:block` by building a default `fun` whose arity matches the tensor prefix of `args`, then `backend.block(struct, out, param_prefix, fun)` (no `backend_args`).
- (`exla/lib/exla/defn.ex`): `:block` compilation / cache keys use `struct` + call args in a single contract (no separate `backend_args` + `extras` split).
- (`torchx/lib/torchx/backend.ex`): `block/4` pattern-matches on `Nx.Block.*` (QR, LU, eigh, solve, SVD, `TakeAlongAxis`, `FFT2`, `IFFT2`, `LogicalNot`, …) for native paths, otherwise `apply(fun, [struct | args])`; standalone `fft2` / `ifft2` behaviour impls removed where redundant; stray `@impl` removed for ops no longer on `Nx.Backend`.

### Tests & tooling

- `nx/test/nx/optional_test.exs`: backends exercise `block/4`; assertions updated where defaults run on `BinaryBackend`.
- `fft2` / `ifft2` doctests excluded in `torchx/nx_doctest_test.exs` (inspect / signed-zero vs `BinaryBackend`); `torchx/nx_block_test.exs` adds numerical checks vs `BinaryBackend`.
- `Nx.all_close/3`: tolerance uses a safe `abs(b)` for integer `b` (e.g. via `f64`) so signed-min overflow does not break `assert_all_close` (relevant once `all_close` goes through `%Nx.Block.AllClose{}`).
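
The single-path dispatch described in this PR description can be sketched without Nx at all. The module below is an illustration under stated assumptions, not the code in the PR: `BlockSketch`, `Negate`, `DefaultBackend`, and `pick_backend/1` are hypothetical names, tensors are modelled as plain maps tagged with a `:backend` key, and backend resolution is reduced to "first argument wins" (the description says the real code uses `Nx.Shared.list_impl!(args)`).

```elixir
defmodule BlockSketch do
  # Hypothetical block struct; the real ones live under Nx.Block.
  defmodule Negate, do: defstruct([])

  # Hypothetical backend that always runs the default implementation.
  defmodule DefaultBackend do
    def block(struct, _output, args, fun), do: apply(fun, [struct | args])
  end

  # Stand-in for Nx.block/4 with the (struct, args, output, fun) argument order.
  def block(%_{} = struct, args, output, fun) when is_list(args) do
    # The default implementation takes the struct plus one argument per entry in args.
    arity = 1 + length(args)

    unless is_function(fun, arity) do
      raise ArgumentError, "expected a default implementation of arity #{arity}"
    end

    # Single path: resolve the backend, then always call backend.block/4.
    backend = pick_backend(args)
    backend.block(struct, output, args, fun)
  end

  # Simplification: tensors are plain maps tagged with their backend,
  # and the first argument decides which backend is used.
  defp pick_backend([%{backend: backend} | _]), do: backend
end
```

A backend with a native implementation would add clauses that match specific structs before the catch-all; the caller never needs `function_exported?/3` because `block/4` is always present.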