From d72680a8e8fe0d087d8d41d1f0774eee87db303c Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Fri, 8 Aug 2025 12:12:16 -0300 Subject: [PATCH 1/4] feat: add initial draft --- exla/lib/exla/backend.ex | 11 +++++++++ exla/lib/exla/defn.ex | 17 +++++++++++++ exla/lib/exla/mlir/value.ex | 13 ++++++++++ nx/lib/nx/backend.ex | 10 ++++++++ nx/lib/nx/binary_backend.ex | 5 ++++ nx/lib/nx/defn/evaluator.ex | 47 ++++++++++++++++++++++++++++++++++++ nx/lib/nx/defn/expr.ex | 18 ++++++++++++++ nx/lib/nx/defn/tree.ex | 13 ++++++++++ nx/lib/nx/shared.ex | 21 ++++++++++++++++ torchx/lib/torchx/backend.ex | 5 ++++ 10 files changed, 160 insertions(+) diff --git a/exla/lib/exla/backend.ex b/exla/lib/exla/backend.ex index 50747de4bb0..539cc911ffb 100644 --- a/exla/lib/exla/backend.ex +++ b/exla/lib/exla/backend.ex @@ -329,6 +329,17 @@ defmodule EXLA.Backend do jit([], wrapper_fun, tensors, [List.to_tuple(tensors)]) end + @impl true + def elixir_call(name, args, fun) do + {tensors, rest} = Enum.split_while(args, &is_struct(&1, Nx.Tensor)) + + wrapper_fun = fn tensors -> + Nx.Defn.Expr.elixir_call(name, Tuple.to_list(tensors) ++ rest, fun) + end + + jit([], wrapper_fun, tensors, [List.to_tuple(tensors)]) + end + binary_ops = [:add, :subtract, :multiply, :pow, :remainder, :divide, :atan2, :min, :max, :quotient] ++ [:bitwise_and, :bitwise_or, :bitwise_xor, :left_shift, :right_shift] ++ diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 413c38ce45a..b899be577ac 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -546,6 +546,23 @@ defmodule EXLA.Defn do end end + defp cached_recur_operator(:elixir_call, %T{data: %Expr{args: args}}, state, cache) do + [call, expr, _callback] = args + %{data: %{args: in_args}} = call + + {args, opts} = Enum.split_while(in_args, &(not is_list(&1))) + {_opts, _ignored} = {opts, nil} + + {operands, cache} = Enum.map_reduce(args, cache, &recur_operator(&1, state, &2)) + + out_typespecs = container_to_typespecs(expr) + + # Emit a generic custom call that the EXLA runtime can bind to Erlang/Elixir. + results = Value.custom_call(state.builder, "nx_elixir_custom_call", operands, out_typespecs) + + {wrap_tuple_result(results, expr), cache} + end + defp cached_recur_operator( :lu, %T{ diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index f955e672006..69c63e0bb52 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -840,6 +840,19 @@ defmodule EXLA.MLIR.Value do |> one!() end + def custom_call( + %Function{} = func, + call_target_name, + operands, + out_typespecs, + extra_attrs \\ [] + ) do + result_types = typespecs_to_mlir_types(out_typespecs) + attributes = [call_target_name: attr_string(call_target_name), api_version: attr_i32(4), has_side_effect: attr_boolean(true)] + attributes = attributes ++ extra_attrs + op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) + end + def get_typespec(value) do EXLA.NIF.mlir_get_typespec(value.ref) end diff --git a/nx/lib/nx/backend.ex b/nx/lib/nx/backend.ex index 3c463ba2373..df0abdcc3ef 100644 --- a/nx/lib/nx/backend.ex +++ b/nx/lib/nx/backend.ex @@ -142,6 +142,15 @@ defmodule Nx.Backend do """ @callback optional(atom, [term], fun) :: tensor + @doc """ + Invoked to execute a generic Elixir callback from within defn. + + The backend may choose how to execute it. For example, EXLA can lower + to a custom_call that interacts with Erlang/Elixir via C; pure CPU + backends may call the function directly. 
+ """ + @callback elixir_call(atom, [term], fun) :: tensor + @callback qr({q :: tensor, r :: tensor}, tensor, keyword) :: tensor @callback cholesky(out :: tensor, tensor) :: tensor @callback eigh({eigenvals :: tensor, eigenvecs :: tensor}, tensor, keyword) :: tensor @@ -162,6 +171,7 @@ defmodule Nx.Backend do @optional_callbacks [ optional: 3, + elixir_call: 3, solve: 3, determinant: 2, logical_not: 2, diff --git a/nx/lib/nx/binary_backend.ex b/nx/lib/nx/binary_backend.ex index 974d558b0de..75eb88e8dde 100644 --- a/nx/lib/nx/binary_backend.ex +++ b/nx/lib/nx/binary_backend.ex @@ -2658,4 +2658,9 @@ defmodule Nx.BinaryBackend do defp bitstring_copy(bitstring, n) do for _ <- 1..n, into: <<>>, do: bitstring end + + @impl true + def elixir_call(_name, args, fun) when is_function(fun) do + apply(fun, args) + end end diff --git a/nx/lib/nx/defn/evaluator.ex b/nx/lib/nx/defn/evaluator.ex index d028ec6a63c..86b9b0b0190 100644 --- a/nx/lib/nx/defn/evaluator.ex +++ b/nx/lib/nx/defn/evaluator.ex @@ -175,6 +175,28 @@ defmodule Nx.Defn.Evaluator do Map.put(cache, [:optional | id], optional_expr_cache) end + defp compute_cache(:elixir_call, %{data: %Expr{args: args, id: id}}, state, cache) do + [call, expr, _callback] = args + %{data: %{args: call_args_in, op: call_name}} = call + + {call_args, opts} = Enum.split_while(call_args_in, &(not is_list(&1))) + + cache = Enum.reduce(call_args, cache, &compute_cache(&1, state, &2)) + key = computation_key(call_name, call_args ++ opts) + + {optional_expr_cache, cache} = + case cache do + %{^key => optional_expr_cache} -> + {optional_expr_cache, cache} + + %{} -> + optional_expr_cache = {expr, init_compute_cache(expr, state)} + {optional_expr_cache, Map.put(cache, key, optional_expr_cache)} + end + + Map.put(cache, [:optional | id], optional_expr_cache) + end + defp compute_cache(:cond, %{data: %Expr{args: [clauses, last], id: id}}, state, cache) do %{parent_ids: parent_ids, current_ids: current_ids} = state @@ -431,6 +453,31 @@ defmodule Nx.Defn.Evaluator do end end + defp eval_apply( + :elixir_call, + %{data: %Expr{args: [call, out, _callback], id: id}}, + state, + caches + ) do + {args, caches} = Tree.apply_args(call, caches, &eval(&1, state, &2)) + backend = Nx.Shared.list_impl!(args) + + if function_exported?(backend, call.data.op, length(args) + 1) do + out = + case call do + %{type: {:tuple, _}} -> out + _ -> call + end + + {apply(backend, call.data.op, [out | args]), caches} + else + params = Enum.map(args, &fn -> &1 end) + {{expr, optional_cache}, caches} = pop_cache!(caches, [:optional | id]) + {res, _} = composite_eval(expr, %{state | params: params}, [optional_cache]) + {res, caches} + end + end + defp eval_apply(op, %{vectorized_axes: [_ | _]} = ans, _state, _caches) do raise "unexpected vectorized axes in evaluator for operation #{inspect(op)}: #{inspect(ans)}" end diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex index 782e4a07fd7..eb457fe01ee 100644 --- a/nx/lib/nx/defn/expr.ex +++ b/nx/lib/nx/defn/expr.ex @@ -41,6 +41,8 @@ defmodule Nx.Defn.Expr do * `attach_token(token(%Nx.Defn.Token{}), expr)` + * `elixir_call(name, args, fun)` + `defn` compilers must handle said nodes accordingly. 
""" @@ -384,6 +386,22 @@ defmodule Nx.Defn.Expr do end end + @impl true + def elixir_call(name, in_args, fun) do + {args, opts} = Enum.split_while(in_args, &(not is_list(&1))) + params = Enum.with_index(args, ¶meter/2) + + case apply(fun, params ++ opts) do + %{data: %{context: context}} = res -> + expr(res, context, :elixir_call, [expr(res, context, name, in_args), res, fun]) + + t when is_tuple(t) -> + context = elem(t, 0).data.context + out = expr(tuple_out(tuple_size(t)), context, name, in_args) + tuple(expr(out, context, :elixir_call, [out, t, fun]), Tuple.to_list(t)) + end + end + ## Nx.Defn AST callbacks @doc false diff --git a/nx/lib/nx/defn/tree.ex b/nx/lib/nx/defn/tree.ex index 582b9d46892..9ad840542bb 100644 --- a/nx/lib/nx/defn/tree.ex +++ b/nx/lib/nx/defn/tree.ex @@ -192,6 +192,19 @@ defmodule Nx.Defn.Tree do {[call, expr, callback], acc} end + def apply_args(%T{data: %Expr{op: :elixir_call, args: args}}, type, acc, fun) do + [call, expr, callback] = args + {call, acc} = fun.(call, acc) + + {expr, acc} = + case type do + :all -> Composite.traverse(expr, acc, fun) + :scope -> {expr, acc} + end + + {[call, expr, callback], acc} + end + def apply_args(%T{data: %Expr{op: :token, args: [token]}}, _type, acc, fun) do {hooks, acc} = Enum.map_reduce(token.hooks, acc, fn %{expr: expr} = token, acc -> diff --git a/nx/lib/nx/shared.ex b/nx/lib/nx/shared.ex index 156748859c4..5285dc4f0ee 100644 --- a/nx/lib/nx/shared.ex +++ b/nx/lib/nx/shared.ex @@ -583,6 +583,27 @@ defmodule Nx.Shared do "expected default implementation to match template #{inspect(right)}, got: #{inspect(left)}" end + @doc false + def elixir_call(output, function_name, args, default_impl) + when is_atom(function_name) and is_list(args) and is_function(default_impl) do + arity = length(args) + 1 + backend = list_impl!(args) + + cond do + function_exported?(backend, function_name, arity) -> + apply(backend, function_name, [output | args]) + + function_exported?(backend, :elixir_call, 3) -> + backend.elixir_call(function_name, args, default_impl) + |> ensure_optional_compatible!(output) + + true -> + default_impl + |> apply(args) + |> ensure_optional_compatible!(output) + end + end + @doc false def raise_complex_not_supported(function, arity) do raise ArgumentError, "Nx.#{function}/#{arity} does not support complex inputs" diff --git a/torchx/lib/torchx/backend.ex b/torchx/lib/torchx/backend.ex index 77308ce94de..1a02fc78dd7 100644 --- a/torchx/lib/torchx/backend.ex +++ b/torchx/lib/torchx/backend.ex @@ -1825,4 +1825,9 @@ defmodule Torchx.Backend do raise "operation #{unquote(fun)} is not supported on Torchx.Backend" end end + + @impl true + def elixir_call(_name, args, fun) when is_function(fun) do + apply(fun, args) + end end From da7d7e44d7c4ef4e91d0897711cab09d1399b5a9 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Fri, 8 Aug 2025 13:54:12 -0300 Subject: [PATCH 2/4] evaluator mode working --- elixir_call.exs | 44 +++++++++++++++++++++++++++++++++++ exla/lib/exla/backend.ex | 10 ++------ nx/lib/nx/backend.ex | 2 +- nx/lib/nx/binary_backend.ex | 2 +- nx/lib/nx/defn/evaluator.ex | 45 ++++++++++-------------------------- nx/lib/nx/defn/expr.ex | 20 ++++++++-------- nx/lib/nx/defn/tree.ex | 20 ++++++++-------- nx/lib/nx/shared.ex | 10 +++----- torchx/lib/torchx/backend.ex | 2 +- 9 files changed, 85 insertions(+), 70 deletions(-) create mode 100644 elixir_call.exs diff --git a/elixir_call.exs b/elixir_call.exs new file mode 100644 index 00000000000..a70c76dd9db --- /dev/null 
+++ b/elixir_call.exs @@ -0,0 +1,44 @@ +Mix.install([{:exla, path: "exla"}, {:pythonx, "~> 0.4"}]) + +Pythonx.uv_init(""" +[project] +name = "project" +version = "0.0.0" +requires-python = "==3.13.*" +dependencies = [ + "numpy==2.2.2" +] +""") + +Nx.global_default_backend(EXLA.Backend) +t = Nx.iota({10}) + +elixir_fun = fn t, opts -> + input = Nx.to_flat_list(t) + + {res, _ctx} = + Pythonx.eval( + """ + import numpy as np + arr = np.array(input) + + c = np.cos(arr) + offset + + list(c) + """, + %{"input" => input, "offset" => opts[:value]} + ) + + Nx.tensor(Pythonx.decode(res)) +end + +jit_fun = fn t -> + s = Nx.size(t) + + out = + Nx.Shared.elixir_call(%{t | type: Nx.Type.to_floating(t.type)}, [t, [value: 10]], elixir_fun) + + Nx.negate(out) +end + +dbg(Nx.Defn.jit_apply(jit_fun, [t])) diff --git a/exla/lib/exla/backend.ex b/exla/lib/exla/backend.ex index 539cc911ffb..cc8aef1bfdd 100644 --- a/exla/lib/exla/backend.ex +++ b/exla/lib/exla/backend.ex @@ -330,14 +330,8 @@ defmodule EXLA.Backend do end @impl true - def elixir_call(name, args, fun) do - {tensors, rest} = Enum.split_while(args, &is_struct(&1, Nx.Tensor)) - - wrapper_fun = fn tensors -> - Nx.Defn.Expr.elixir_call(name, Tuple.to_list(tensors) ++ rest, fun) - end - - jit([], wrapper_fun, tensors, [List.to_tuple(tensors)]) + def elixir_call(_out, args, fun) do + apply(fun, args) end binary_ops = diff --git a/nx/lib/nx/backend.ex b/nx/lib/nx/backend.ex index df0abdcc3ef..f8556ce308d 100644 --- a/nx/lib/nx/backend.ex +++ b/nx/lib/nx/backend.ex @@ -149,7 +149,7 @@ defmodule Nx.Backend do to a custom_call that interacts with Erlang/Elixir via C; pure CPU backends may call the function directly. """ - @callback elixir_call(atom, [term], fun) :: tensor + @callback elixir_call(out :: tensor | tuple, [term], fun) :: tensor @callback qr({q :: tensor, r :: tensor}, tensor, keyword) :: tensor @callback cholesky(out :: tensor, tensor) :: tensor diff --git a/nx/lib/nx/binary_backend.ex b/nx/lib/nx/binary_backend.ex index 75eb88e8dde..478f2760171 100644 --- a/nx/lib/nx/binary_backend.ex +++ b/nx/lib/nx/binary_backend.ex @@ -2660,7 +2660,7 @@ defmodule Nx.BinaryBackend do end @impl true - def elixir_call(_name, args, fun) when is_function(fun) do + def elixir_call(_out, args, fun) when is_function(fun) do apply(fun, args) end end diff --git a/nx/lib/nx/defn/evaluator.ex b/nx/lib/nx/defn/evaluator.ex index 86b9b0b0190..601653df8bc 100644 --- a/nx/lib/nx/defn/evaluator.ex +++ b/nx/lib/nx/defn/evaluator.ex @@ -176,25 +176,12 @@ defmodule Nx.Defn.Evaluator do end defp compute_cache(:elixir_call, %{data: %Expr{args: args, id: id}}, state, cache) do - [call, expr, _callback] = args - %{data: %{args: call_args_in, op: call_name}} = call - - {call_args, opts} = Enum.split_while(call_args_in, &(not is_list(&1))) - - cache = Enum.reduce(call_args, cache, &compute_cache(&1, state, &2)) - key = computation_key(call_name, call_args ++ opts) - - {optional_expr_cache, cache} = - case cache do - %{^key => optional_expr_cache} -> - {optional_expr_cache, cache} - - %{} -> - optional_expr_cache = {expr, init_compute_cache(expr, state)} - {optional_expr_cache, Map.put(cache, key, optional_expr_cache)} - end + [in_args, _fun] = args - Map.put(cache, [:optional | id], optional_expr_cache) + Enum.reduce(in_args, cache, fn + t, cache when is_list(t) -> cache + t, cache -> compute_cache(t, state, cache) + end) end defp compute_cache(:cond, %{data: %Expr{args: [clauses, last], id: id}}, state, cache) do @@ -455,26 +442,18 @@ defmodule Nx.Defn.Evaluator do defp eval_apply( 
          :elixir_call,
-         %{data: %Expr{args: [call, out, _callback], id: id}},
+         %{data: %Expr{args: [in_args, fun], id: id}} = expr,
          state,
          caches
        ) do
-    {args, caches} = Tree.apply_args(call, caches, &eval(&1, state, &2))
-    backend = Nx.Shared.list_impl!(args)
-
-    if function_exported?(backend, call.data.op, length(args) + 1) do
-      out =
-        case call do
-          %{type: {:tuple, _}} -> out
-          _ -> call
-        end
+    {tensor_args, opts} = Enum.split_while(in_args, &(not is_list(&1)))
+    {evaluated_tensors, caches} = Enum.map_reduce(tensor_args, caches, &eval(&1, state, &2))
+    backend = Nx.Shared.list_impl!(evaluated_tensors)
 
-      {apply(backend, call.data.op, [out | args]), caches}
+    if backend == Nx.Defn.Expr do
+      {expr, caches}
     else
-      params = Enum.map(args, &fn -> &1 end)
-      {{expr, optional_cache}, caches} = pop_cache!(caches, [:optional | id])
-      {res, _} = composite_eval(expr, %{state | params: params}, [optional_cache])
-      {res, caches}
+      {apply(fun, evaluated_tensors ++ opts), caches}
     end
   end
diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex
index eb457fe01ee..1d488df8883 100644
--- a/nx/lib/nx/defn/expr.ex
+++ b/nx/lib/nx/defn/expr.ex
@@ -387,18 +387,18 @@ defmodule Nx.Defn.Expr do
   end
 
   @impl true
-  def elixir_call(name, in_args, fun) do
-    {args, opts} = Enum.split_while(in_args, &(not is_list(&1)))
-    params = Enum.with_index(args, &parameter/2)
+  def elixir_call(out, in_args, fun) do
+    {tensor_args, _opts} = Enum.split_while(in_args, &(not is_list(&1)))
+    [%T{data: %Expr{context: context}} | _] = Enum.map(tensor_args, &to_expr/1)
 
-    case apply(fun, params ++ opts) do
-      %{data: %{context: context}} = res ->
-        expr(res, context, :elixir_call, [expr(res, context, name, in_args), res, fun])
+    case out do
+      t when is_struct(t, Nx.Tensor) ->
+        expr(t, context, :elixir_call, [in_args, fun])
 
-      t when is_tuple(t) ->
-        context = elem(t, 0).data.context
-        out = expr(tuple_out(tuple_size(t)), context, name, in_args)
-        tuple(expr(out, context, :elixir_call, [out, t, fun]), Tuple.to_list(t))
+      tuple when is_tuple(tuple) ->
+        out_template = tuple_out(tuple_size(tuple))
+        expr_node = expr(out_template, context, :elixir_call, [in_args, fun])
+        tuple(expr_node, Tuple.to_list(tuple))
     end
   end
 
diff --git a/nx/lib/nx/defn/tree.ex b/nx/lib/nx/defn/tree.ex
index 9ad840542bb..bd741346e7f 100644
--- a/nx/lib/nx/defn/tree.ex
+++ b/nx/lib/nx/defn/tree.ex
@@ -193,16 +193,18 @@ end
   def apply_args(%T{data: %Expr{op: :elixir_call, args: args}}, type, acc, fun) do
-    [call, expr, callback] = args
-    {call, acc} = fun.(call, acc)
-
-    {expr, acc} =
-      case type do
-        :all -> Composite.traverse(expr, acc, fun)
-        :scope -> {expr, acc}
-      end
+    [in_args, callback] = args
+
+    {in_args, acc} =
+      Enum.map_reduce(in_args, acc, fn t, acc ->
+        if is_list(t) do
+          {t, acc}
+        else
+          Composite.traverse(t, acc, fun)
+        end
+      end)
 
-    {[call, expr, callback], acc}
+    {[in_args, callback], acc}
   end
 
   def apply_args(%T{data: %Expr{op: :token, args: [token]}}, _type, acc, fun) do
     {hooks, acc} =
       Enum.map_reduce(token.hooks, acc, fn %{expr: expr} = token, acc ->
diff --git a/nx/lib/nx/shared.ex b/nx/lib/nx/shared.ex
index 5285dc4f0ee..e30b699fc5f 100644
--- a/nx/lib/nx/shared.ex
+++ b/nx/lib/nx/shared.ex
@@ -584,21 +584,17 @@ end
   @doc false
-  def elixir_call(output, function_name, args, default_impl)
-      when is_atom(function_name) and is_list(args) and is_function(default_impl) do
+  def elixir_call(output, args, fun) when is_list(args) and is_function(fun) do
     arity = length(args) + 1
     backend = list_impl!(args)
 
     cond do
-      function_exported?(backend, function_name, arity) ->
-        apply(backend, function_name, [output | args])
-
       function_exported?(backend, :elixir_call, 3) ->
-        backend.elixir_call(function_name, args, default_impl)
+        backend.elixir_call(output, args, fun)
         |> ensure_optional_compatible!(output)
 
       true ->
-        default_impl
+        fun
         |> apply(args)
         |> ensure_optional_compatible!(output)
     end
   end
 
diff --git a/torchx/lib/torchx/backend.ex b/torchx/lib/torchx/backend.ex
index 1a02fc78dd7..eb813e1e264 100644
--- a/torchx/lib/torchx/backend.ex
+++ b/torchx/lib/torchx/backend.ex
@@ -1827,7 +1827,7 @@ defmodule Torchx.Backend do
   end
 
   @impl true
-  def elixir_call(_name, args, fun) when is_function(fun) do
+  def elixir_call(_out, args, fun) when is_function(fun) do
     apply(fun, args)
   end
 end

From fc9c28ca0ade383c308bdccfba37b81702770320 Mon Sep 17 00:00:00 2001
From: Paulo Valente <16843419+polvalente@users.noreply.github.com>
Date: Sat, 9 Aug 2025 14:37:20 -0300
Subject: [PATCH 3/4] test: add tests

---
 elixir_call.exs                               | 44 ------------
 exla/lib/exla/backend.ex                      |  5 --
 exla/lib/exla/defn.ex                         | 21 ++-----
 exla/test/exla/defn/elixir_call_test.exs      | 61 +++++++++++++++++++
 nx/lib/nx.ex                                  | 55 +++++++++++++++++
 nx/lib/nx/binary_backend.ex                   |  5 --
 nx/lib/nx/defn/evaluator.ex                   |  4 +-
 nx/lib/nx/defn/tree.ex                        |  2 +-
 nx/lib/nx/shared.ex                           | 17 ------
 .../nx/defn/elixir_call_evaluator_test.exs    | 49 +++++++++++++++
 torchx/lib/torchx/backend.ex                  |  5 --
 torchx/mix.exs                                |  4 +-
 torchx/test/torchx/defn/elixir_call_test.exs  | 51 ++++++++++++++++
 13 files changed, 225 insertions(+), 98 deletions(-)
 delete mode 100644 elixir_call.exs
 create mode 100644 exla/test/exla/defn/elixir_call_test.exs
 create mode 100644 nx/test/nx/defn/elixir_call_evaluator_test.exs
 create mode 100644 torchx/test/torchx/defn/elixir_call_test.exs

diff --git a/elixir_call.exs b/elixir_call.exs
deleted file mode 100644
index a70c76dd9db..00000000000
--- a/elixir_call.exs
+++ /dev/null
@@ -1,44 +0,0 @@
-Mix.install([{:exla, path: "exla"}, {:pythonx, "~> 0.4"}])
-
-Pythonx.uv_init("""
-[project]
-name = "project"
-version = "0.0.0"
-requires-python = "==3.13.*"
-dependencies = [
-  "numpy==2.2.2"
-]
-""")
-
-Nx.global_default_backend(EXLA.Backend)
-t = Nx.iota({10})
-
-elixir_fun = fn t, opts ->
-  input = Nx.to_flat_list(t)
-
-  {res, _ctx} =
-    Pythonx.eval(
-      """
-      import numpy as np
-      arr = np.array(input)
-
-      c = np.cos(arr) + offset
-
-      list(c)
-      """,
-      %{"input" => input, "offset" => opts[:value]}
-    )
-
-  Nx.tensor(Pythonx.decode(res))
-end
-
-jit_fun = fn t ->
-  s = Nx.size(t)
-
-  out =
-    Nx.Shared.elixir_call(%{t | type: Nx.Type.to_floating(t.type)}, [t, [value: 10]], elixir_fun)
-
-  Nx.negate(out)
-end
-
-dbg(Nx.Defn.jit_apply(jit_fun, [t]))
diff --git a/exla/lib/exla/backend.ex b/exla/lib/exla/backend.ex
index cc8aef1bfdd..50747de4bb0 100644
--- a/exla/lib/exla/backend.ex
+++ b/exla/lib/exla/backend.ex
@@ -329,11 +329,6 @@ defmodule EXLA.Backend do
     jit([], wrapper_fun, tensors, [List.to_tuple(tensors)])
   end
 
-  @impl true
-  def elixir_call(_out, args, fun) do
-    apply(fun, args)
-  end
-
   binary_ops =
     [:add, :subtract, :multiply, :pow, :remainder, :divide, :atan2, :min, :max, :quotient] ++
       [:bitwise_and, :bitwise_or, :bitwise_xor, :left_shift, :right_shift] ++
diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex
index b899be577ac..7a7b3865a81 100644
--- a/exla/lib/exla/defn.ex
+++ b/exla/lib/exla/defn.ex
@@ -546,23 +546,6 @@ defmodule EXLA.Defn do
     end
   end
 
-  defp cached_recur_operator(:elixir_call, %T{data: %Expr{args: args}}, state, cache) do
-    [call, expr, _callback] = args
-    %{data: %{args: in_args}} = call
-
-    {args, opts} = Enum.split_while(in_args, &(not is_list(&1)))
-    {_opts, _ignored} = {opts, nil}
-
-    {operands, cache} = Enum.map_reduce(args, cache, &recur_operator(&1, state, &2))
-
-    out_typespecs = container_to_typespecs(expr)
-
-    # Emit a generic custom call that the EXLA runtime can bind to Erlang/Elixir.
-    results = Value.custom_call(state.builder, "nx_elixir_custom_call", operands, out_typespecs)
-
-    {wrap_tuple_result(results, expr), cache}
-  end
-
   defp cached_recur_operator(
          :lu,
          %T{
@@ -1226,6 +1209,10 @@
     EXLA.Lib.argsort(state.builder, tensor, dimension, stable, comp, ans.type)
   end
 
+  defp to_operator(:elixir_call, _, _, _) do
+    raise "Nx.elixir_call/3 is not supported yet. Use Nx.Defn.Evaluator as your compiler."
+  end
+
   defp fft(exla_op, [%Value{} = tensor, opts], %{type: type} = ans, state) do
     n = opts[:length]
     axis = opts[:axis]
diff --git a/exla/test/exla/defn/elixir_call_test.exs b/exla/test/exla/defn/elixir_call_test.exs
new file mode 100644
index 00000000000..add051a3f6e
--- /dev/null
+++ b/exla/test/exla/defn/elixir_call_test.exs
@@ -0,0 +1,61 @@
+defmodule EXLA.Defn.ElixirCallEvaluatorTest do
+  use ExUnit.Case, async: true
+  import Nx.Defn
+  import Nx.Testing
+
+  setup do
+    Nx.Defn.default_options(compiler: Nx.Defn.Evaluator)
+    Nx.default_backend(EXLA.Backend)
+    :ok
+  end
+
+  defn add_offset(x) do
+    out = %{x | type: Nx.Type.to_floating(x.type)}
+
+    Nx.elixir_call(out, [x, [offset: 10.0]], fn t, opts ->
+      Nx.add(Nx.as_type(t, :f32), opts[:offset])
+    end)
+  end
+
+  test "elixir_call with single output" do
+    x = Nx.iota({5})
+    y = add_offset(x)
+
+    expected = Nx.add(Nx.as_type(x, :f32), 10.0)
+    assert_equal(y, expected)
+  end
+
+  defn split_and_sum(x) do
+    fx = Nx.as_type(x, :f32)
+
+    out0 = fx
+    out1 = fx
+    out_template = {out0, out1}
+
+    {a, b} =
+      Nx.elixir_call(out_template, [fx], fn t ->
+        {Nx.multiply(t, 2.0), Nx.add(t, 1.0)}
+      end)
+
+    Nx.add(a, b)
+  end
+
+  test "elixir_call with tuple output" do
+    x = Nx.tensor([1, 2, 3])
+    y = split_and_sum(x)
+
+    fx = Nx.as_type(x, :f32)
+    expected = Nx.add(Nx.multiply(fx, 2.0), Nx.add(fx, 1.0))
+    assert_equal(y, expected)
+  end
+
+  test "fails when using EXLA compiler" do
+    x = Nx.tensor([1, 2, 3])
+
+    assert_raise RuntimeError,
+                 "Nx.elixir_call/3 is not supported yet. Use Nx.Defn.Evaluator as your compiler.",
+                 fn ->
+                   EXLA.jit_apply(&split_and_sum/1, [x])
+                 end
+  end
+end
diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex
index 091372d0057..715a149286f 100644
--- a/nx/lib/nx.ex
+++ b/nx/lib/nx.ex
@@ -2196,6 +2196,61 @@ defmodule Nx do
     list
   end
 
+  @doc """
+  Invokes an Elixir function from within defn.
+
+  This function allows integrating arbitrary Elixir code into `defn` graphs.
+  It receives an output template (a tensor or a tuple of tensors) that
+  specifies the expected shapes, types, and names of the result, a list of
+  arguments to pass to the Elixir function, and the function itself.
+
+  Inside `defn`, this builds an expression node understood by compilers.
+  Outside `defn` or on backends without special support, it executes `fun`
+  directly and validates the result matches the template.
+ """ + @doc type: :backend + def elixir_call(output, args, fun) when is_list(args) and is_function(fun) do + {:arity, arity} = Function.info(fun, :arity) + num_args = length(args) + + if arity != num_args do + raise ArgumentError, + "expected #{arity} arguments, got #{num_args}" + end + + backend = Nx.Shared.list_impl!(args) + + cond do + function_exported?(backend, :elixir_call, 3) -> + output + |> backend.elixir_call(args, fun) + |> ensure_call_compatible!(output) + + true -> + fun + |> apply(args) + |> ensure_call_compatible!(output) + end + end + + defp ensure_call_compatible!(left, right) when tuple_size(left) == tuple_size(right) do + [Tuple.to_list(left), Tuple.to_list(right)] + |> Enum.zip_with(fn [l, r] -> ensure_call_compatible!(l, r) end) + + left + end + + defp ensure_call_compatible!( + %{shape: shape, type: type, names: names} = left, + %{shape: shape, type: type, names: names} + ), + do: left + + defp ensure_call_compatible!(left, right) do + raise ArgumentError, + "expected the elixir_call function to match the given output template #{inspect(right)}, got: #{inspect(left)}" + end + defp chunk([], data, type) do match_types [type] do <> = data diff --git a/nx/lib/nx/binary_backend.ex b/nx/lib/nx/binary_backend.ex index 478f2760171..974d558b0de 100644 --- a/nx/lib/nx/binary_backend.ex +++ b/nx/lib/nx/binary_backend.ex @@ -2658,9 +2658,4 @@ defmodule Nx.BinaryBackend do defp bitstring_copy(bitstring, n) do for _ <- 1..n, into: <<>>, do: bitstring end - - @impl true - def elixir_call(_out, args, fun) when is_function(fun) do - apply(fun, args) - end end diff --git a/nx/lib/nx/defn/evaluator.ex b/nx/lib/nx/defn/evaluator.ex index 601653df8bc..c913f4ec3c4 100644 --- a/nx/lib/nx/defn/evaluator.ex +++ b/nx/lib/nx/defn/evaluator.ex @@ -175,7 +175,7 @@ defmodule Nx.Defn.Evaluator do Map.put(cache, [:optional | id], optional_expr_cache) end - defp compute_cache(:elixir_call, %{data: %Expr{args: args, id: id}}, state, cache) do + defp compute_cache(:elixir_call, %{data: %Expr{args: args}}, state, cache) do [in_args, _fun] = args Enum.reduce(in_args, cache, fn @@ -442,7 +442,7 @@ defmodule Nx.Defn.Evaluator do defp eval_apply( :elixir_call, - %{data: %Expr{args: [in_args, fun], id: id}} = expr, + %{data: %Expr{args: [in_args, fun]}} = expr, state, caches ) do diff --git a/nx/lib/nx/defn/tree.ex b/nx/lib/nx/defn/tree.ex index bd741346e7f..733a131e4f6 100644 --- a/nx/lib/nx/defn/tree.ex +++ b/nx/lib/nx/defn/tree.ex @@ -192,7 +192,7 @@ defmodule Nx.Defn.Tree do {[call, expr, callback], acc} end - def apply_args(%T{data: %Expr{op: :elixir_call, args: args}}, type, acc, fun) do + def apply_args(%T{data: %Expr{op: :elixir_call, args: args}}, _type, acc, fun) do [in_args, callback] = args {in_args, acc} = diff --git a/nx/lib/nx/shared.ex b/nx/lib/nx/shared.ex index e30b699fc5f..156748859c4 100644 --- a/nx/lib/nx/shared.ex +++ b/nx/lib/nx/shared.ex @@ -583,23 +583,6 @@ defmodule Nx.Shared do "expected default implementation to match template #{inspect(right)}, got: #{inspect(left)}" end - @doc false - def elixir_call(output, args, fun) when is_list(args) and is_function(fun) do - arity = length(args) + 1 - backend = list_impl!(args) - - cond do - function_exported?(backend, :elixir_call, 3) -> - backend.elixir_call(output, args, fun) - |> ensure_optional_compatible!(output) - - true -> - fun - |> apply(args) - |> ensure_optional_compatible!(output) - end - end - @doc false def raise_complex_not_supported(function, arity) do raise ArgumentError, "Nx.#{function}/#{arity} does not support complex 
inputs" diff --git a/nx/test/nx/defn/elixir_call_evaluator_test.exs b/nx/test/nx/defn/elixir_call_evaluator_test.exs new file mode 100644 index 00000000000..92fad6b431e --- /dev/null +++ b/nx/test/nx/defn/elixir_call_evaluator_test.exs @@ -0,0 +1,49 @@ +defmodule Nx.Defn.ElixirCallEvaluatorTest do + use ExUnit.Case, async: true + import Nx.Defn + + setup do + Nx.Defn.default_options(compiler: Nx.Defn.Evaluator) + :ok + end + + defn add_offset(x) do + out = %{x | type: Nx.Type.to_floating(x.type)} + + Nx.elixir_call(out, [x, [offset: 10.0]], fn t, opts -> + Nx.add(Nx.as_type(t, :f32), opts[:offset]) + end) + end + + test "elixir_call with single output" do + x = Nx.iota({5}) + y = add_offset(x) + + expected = Nx.add(Nx.as_type(x, :f32), 10.0) + assert Nx.all_close(y, expected) |> Nx.to_number() == 1 + end + + defn split_and_sum(x) do + fx = Nx.as_type(x, :f32) + + out0 = fx + out1 = fx + out_template = {out0, out1} + + {a, b} = + Nx.elixir_call(out_template, [fx], fn t -> + {Nx.multiply(t, 2.0), Nx.add(t, 1.0)} + end) + + Nx.add(a, b) + end + + test "elixir_call with tuple output" do + x = Nx.tensor([1, 2, 3]) + y = split_and_sum(x) + + fx = Nx.as_type(x, :f32) + expected = Nx.add(Nx.multiply(fx, 2.0), Nx.add(fx, 1.0)) + assert expected == y + end +end diff --git a/torchx/lib/torchx/backend.ex b/torchx/lib/torchx/backend.ex index eb813e1e264..77308ce94de 100644 --- a/torchx/lib/torchx/backend.ex +++ b/torchx/lib/torchx/backend.ex @@ -1825,9 +1825,4 @@ defmodule Torchx.Backend do raise "operation #{unquote(fun)} is not supported on Torchx.Backend" end end - - @impl true - def elixir_call(_out, args, fun) when is_function(fun) do - apply(fun, args) - end end diff --git a/torchx/mix.exs b/torchx/mix.exs index fa5531e5411..5174e59cdb8 100644 --- a/torchx/mix.exs +++ b/torchx/mix.exs @@ -41,8 +41,8 @@ defmodule Torchx.MixProject do defp deps do [ - {:nx, "~> 0.10.0"}, - # {:nx, path: "../nx"}, + # {:nx, "~> 0.10.0"}, + {:nx, path: "../nx"}, {:ex_doc, "~> 0.29", only: :docs} ] end diff --git a/torchx/test/torchx/defn/elixir_call_test.exs b/torchx/test/torchx/defn/elixir_call_test.exs new file mode 100644 index 00000000000..9c504fa6c8b --- /dev/null +++ b/torchx/test/torchx/defn/elixir_call_test.exs @@ -0,0 +1,51 @@ +defmodule Torchx.Defn.ElixirCallEvaluatorTest do + use ExUnit.Case, async: true + import Nx.Defn + import Nx.Testing + + setup do + Nx.Defn.default_options(compiler: Nx.Defn.Evaluator) + Nx.default_backend(Torchx.Backend) + :ok + end + + defn add_offset(x) do + out = %{x | type: Nx.Type.to_floating(x.type)} + + Nx.elixir_call(out, [x, [offset: 10.0]], fn t, opts -> + Nx.add(Nx.as_type(t, :f32), opts[:offset]) + end) + end + + test "elixir_call with single output" do + x = Nx.iota({5}) + y = add_offset(x) + + expected = Nx.add(Nx.as_type(x, :f32), 10.0) + assert_equal(y, expected) + end + + defn split_and_sum(x) do + fx = Nx.as_type(x, :f32) + + out0 = fx + out1 = fx + out_template = {out0, out1} + + {a, b} = + Nx.elixir_call(out_template, [fx], fn t -> + {Nx.multiply(t, 2.0), Nx.add(t, 1.0)} + end) + + Nx.add(a, b) + end + + test "elixir_call with tuple output" do + x = Nx.tensor([1, 2, 3]) + y = split_and_sum(x) + + fx = Nx.as_type(x, :f32) + expected = Nx.add(Nx.multiply(fx, 2.0), Nx.add(fx, 1.0)) + assert_equal(y, expected) + end +end From 25300b7d1960a8ac2136eb94f3fef37a9e7bc52a Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sat, 9 Aug 2025 14:41:54 -0300 Subject: [PATCH 4/4] fix grad --- exla/lib/exla/mlir/value.ex | 13 
------------- nx/lib/nx/defn/grad.ex | 4 ++++ 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 69c63e0bb52..f955e672006 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -840,19 +840,6 @@ defmodule EXLA.MLIR.Value do |> one!() end - def custom_call( - %Function{} = func, - call_target_name, - operands, - out_typespecs, - extra_attrs \\ [] - ) do - result_types = typespecs_to_mlir_types(out_typespecs) - attributes = [call_target_name: attr_string(call_target_name), api_version: attr_i32(4), has_side_effect: attr_boolean(true)] - attributes = attributes ++ extra_attrs - op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) - end - def get_typespec(value) do EXLA.NIF.mlir_get_typespec(value.ref) end diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index 2941889f989..8c72d0fed03 100644 --- a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -122,6 +122,10 @@ defmodule Nx.Defn.Grad do acc end + defp parents_args(:elixir_call, _expr, _id, acc, _parent_vectorized_names) do + acc + end + defp parents_args( :optional, %{data: %{args: [call, _expr, callback]}} = t,
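
For reference, a minimal end-to-end usage sketch of the API these patches settle on, distilled from the tests added in PATCH 3/4 and assuming the Nx.Defn.Evaluator compiler; the anonymous functions and values below are illustrative only, not part of the patches themselves.

    Nx.Defn.default_options(compiler: Nx.Defn.Evaluator)

    fun = fn x ->
      # Output template: same shape as x, promoted to a floating-point type.
      out = %{x | type: Nx.Type.to_floating(x.type)}

      # The callback runs as ordinary Elixir during evaluation; the trailing
      # keyword list is passed through to it untouched.
      Nx.elixir_call(out, [x, [offset: 10.0]], fn t, opts ->
        Nx.add(Nx.as_type(t, :f32), opts[:offset])
      end)
    end

    Nx.Defn.jit_apply(fun, [Nx.iota({5})])
    #=> f32[5] tensor [10.0, 11.0, 12.0, 13.0, 14.0]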