Skip to content

Commit 17901eb

Browse files
fix: make scalarize more consistent with old behavior
1 parent 2aeba4c commit 17901eb

File tree

1 file changed

+16
-13
lines changed

1 file changed

+16
-13
lines changed

src/substitute.jl

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -293,13 +293,26 @@ head. The returned function is passed `f`, the expression with `f` as the head,
293293
"""
294294
scalarization_function(@nospecialize(_)) = _default_scalarize
295295

296+
scalarization_function(::Union{typeof(+), typeof(-), typeof(*), typeof(/), typeof(\), typeof(^), typeof(LinearAlgebra.norm), typeof(map), typeof(mapreduce), typeof(broadcast)}) = _default_scalarize_array
297+
298+
function _default_scalarize_array(f, x::BasicSymbolic{T}, ::Val{toplevel}) where {T, toplevel}
299+
@nospecialize f
300+
args = arguments(x)
301+
if toplevel && f !== broadcast
302+
f(map(unwrap_const, args)...)
303+
else
304+
f(map(unwrap_const scalarize, args)...)
305+
end
306+
end
307+
296308
function _default_scalarize(f, x::BasicSymbolic{T}, ::Val{toplevel}) where {T, toplevel}
297309
@nospecialize f
298310

299-
f isa BasicSymbolic{T} && return collect(x)
311+
sh = shape(x)
312+
_is_array_shape(sh) && return [x[idx] for idx in eachindex(x)]
300313

301314
args = arguments(x)
302-
if toplevel && f !== broadcast
315+
if toplevel
303316
f(map(unwrap_const, args)...)
304317
else
305318
f(map(unwrap_const scalarize, args)...)
@@ -310,17 +323,7 @@ function scalarize(x::BasicSymbolic{T}, ::Val{toplevel} = Val{false}()) where {T
310323
sh = shape(x)
311324
sh isa Unknown && return x
312325
@match x begin
313-
BSImpl.Const(; val) => begin
314-
if _is_array_shape(sh)
315-
if val isa SparseMatrixCSC
316-
return val
317-
else
318-
Const{T}.(val)
319-
end
320-
else
321-
x
322-
end
323-
end
326+
BSImpl.Const(;) => return x
324327
BSImpl.Sym(;) => _is_array_shape(sh) ? [x[idx] for idx in eachindex(x)] : x
325328
BSImpl.ArrayOp(; output_idx, expr, term, ranges, reduce) => begin
326329
term === nothing || return scalarize(term, Val{toplevel}())

0 commit comments

Comments
 (0)