From 75bbdd61f2b22cf1206918dfe6b0c6ce7582d7ff Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 27 Aug 2025 10:03:05 -0400 Subject: [PATCH 1/2] try restore softmax tpu --- ext/ReactantNNlibExt/Implementations.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ext/ReactantNNlibExt/Implementations.jl b/ext/ReactantNNlibExt/Implementations.jl index 9c416fd133..9eec348e7e 100644 --- a/ext/ReactantNNlibExt/Implementations.jl +++ b/ext/ReactantNNlibExt/Implementations.jl @@ -12,11 +12,11 @@ function NNlib.softmax!(out::AnyTracedRArray{T,N}, x::AbstractArray; dims=1) whe diff = exp.(x .- max_) # TOOD: re-enable conditional once https://github.com/EnzymeAD/Reactant.jl/issues/1581 # fixed - # @trace if all(isfinite, max_) - @. out = diff - # else - # @. out = ifelse(isinf(max_), ifelse(isinf(x), T(1), T(0)), diff) - # end + @trace if all(isfinite, max_) + @. out = diff + else + @. out = ifelse(isinf(max_), ifelse(isinf(x), T(1), T(0)), diff) + end out ./= sum(out; dims) return out end From 319f935c745801256a92946ffd5fd9ef64f516aa Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 27 Aug 2025 11:33:39 -0400 Subject: [PATCH 2/2] Set DUMP_MLIR_ALWAYS to true Enable MLIR dumping for debugging purposes. --- ext/ReactantNNlibExt/Implementations.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ext/ReactantNNlibExt/Implementations.jl b/ext/ReactantNNlibExt/Implementations.jl index 9eec348e7e..5640dcc92c 100644 --- a/ext/ReactantNNlibExt/Implementations.jl +++ b/ext/ReactantNNlibExt/Implementations.jl @@ -6,6 +6,8 @@ for (jlop, hloop) in ( @eval $(jlop)(x::TracedRNumber) = @opcall $(hloop)(x) end +Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = true + function NNlib.softmax!(out::AnyTracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N} x = T.(Reactant.materialize_traced_array(x)) max_ = maximum(x; dims)