@@ -293,13 +293,26 @@ head. The returned function is passed `f`, the expression with `f` as the head,
293
293
"""
294
294
scalarization_function (@nospecialize (_)) = _default_scalarize
295
295
296
+ scalarization_function (:: Union{typeof(+), typeof(-), typeof(*), typeof(/), typeof(\), typeof(^), typeof(LinearAlgebra.norm), typeof(map), typeof(mapreduce), typeof(broadcast)} ) = _default_scalarize_array
297
+
298
+ function _default_scalarize_array (f, x:: BasicSymbolic{T} , :: Val{toplevel} ) where {T, toplevel}
299
+ @nospecialize f
300
+ args = arguments (x)
301
+ if toplevel && f != = broadcast
302
+ f (map (unwrap_const, args)... )
303
+ else
304
+ f (map (unwrap_const ∘ scalarize, args)... )
305
+ end
306
+ end
307
+
296
308
function _default_scalarize (f, x:: BasicSymbolic{T} , :: Val{toplevel} ) where {T, toplevel}
297
309
@nospecialize f
298
310
299
- f isa BasicSymbolic{T} && return collect (x)
311
+ sh = shape (x)
312
+ _is_array_shape (sh) && return [x[idx] for idx in eachindex (x)]
300
313
301
314
args = arguments (x)
302
- if toplevel && f != = broadcast
315
+ if toplevel
303
316
f (map (unwrap_const, args)... )
304
317
else
305
318
f (map (unwrap_const ∘ scalarize, args)... )
@@ -310,17 +323,7 @@ function scalarize(x::BasicSymbolic{T}, ::Val{toplevel} = Val{false}()) where {T
310
323
sh = shape (x)
311
324
sh isa Unknown && return x
312
325
@match x begin
313
- BSImpl. Const (; val) => begin
314
- if _is_array_shape (sh)
315
- if val isa SparseMatrixCSC
316
- return val
317
- else
318
- Const {T} .(val)
319
- end
320
- else
321
- x
322
- end
323
- end
326
+ BSImpl. Const (;) => return x
324
327
BSImpl. Sym (;) => _is_array_shape (sh) ? [x[idx] for idx in eachindex (x)] : x
325
328
BSImpl. ArrayOp (; output_idx, expr, term, ranges, reduce) => begin
326
329
term === nothing || return scalarize (term, Val {toplevel} ())
0 commit comments