Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@ using Base.Cartesian
using Base: linearindices, tail, OneTo, to_shape,
_msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache,
nullable_returntype, null_safe_op, hasvalue, isoperator
import Base: broadcast, broadcast!
export broadcast_getindex, broadcast_setindex!, dotview, @__dot__
import Base: broadcast, broadcast!, @_pure_meta
export broadcast_getindex, broadcast_setindex!, dotview, @__dot__, isfusing

const ScalarType = Union{Type{Any}, Type{Nullable}}

# containers can override isfusing to disable broadcast fusion for specific container types
isfusing(args...) = (@_pure_meta; true)

## Broadcasting utilities ##
# fallbacks for some special cases
@inline broadcast(f, x::Number...) = f(x...)
Expand Down
29 changes: 19 additions & 10 deletions src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -1727,7 +1727,9 @@
oldarg))
fargs args)))
(let ,fbody ,@(reverse (fuse-lets fargs args '()))))))
(define (dot-to-fuse e) ; convert e == (. f (tuple args)) to (fuse f args)
; convert e == (. f (tuple args)) to (fuse f args),
; recursively for nested calls, fusing if fuse? is true.
(define (dot-to-fuse e fuse?)
(define (make-fuse f args) ; check for nested (fuse f args) exprs and combine
(define (split-kwargs args) ; return (cons keyword-args positional-args) extracted from args
(define (sk args kwargs pargs)
Expand All @@ -1742,8 +1744,8 @@
(let* ((kws.args (split-kwargs args))
(kws (car kws.args))
(args (cdr kws.args)) ; fusing occurs on positional args only
(args_ (map dot-to-fuse args)))
(if (anyfuse? args_)
(args_ (map (lambda (e) (dot-to-fuse e fuse?)) args)))
(if (and fuse? (anyfuse? args_))
`(fuse ,(fuse-funcs (to-lambda f args kws) args_) ,(fuse-args args_))
`(fuse ,(to-lambda f args kws) ,args_))))
(if (and (pair? e) (eq? (car e) '|.|))
Expand Down Expand Up @@ -1801,15 +1803,22 @@
(cons farg new-fargs) (cons arg new-args) renames varfarg vararg))))))
(cf (cdadr f) args '() '() '() '() '()))
e)) ; (not (fuse? e))
(let ((e (compress-fuse (dot-to-fuse rhs))) ; an expression '(fuse func args) if expr is a dot call
(lhs-view (ref-to-view lhs))) ; x[...] expressions on lhs turn in to view(x, ...) to update x in-place
; convert fuse expressions to ordinary broadcast calls, or broadcast! if lhs != null:
(define (to-broadcast lhs e)
(if (fuse? e)
(let ((bargs (map (lambda (e) (to-broadcast '() e)) (caddr e))))
(if (null? lhs)
`(call (top broadcast) ,(from-lambda (cadr e)) ,@bargs)
`(call (top broadcast!) ,(from-lambda (cadr e)) ,(ref-to-view lhs) ,@bargs)))
(if (null? lhs)
(expand-forms `(call (top broadcast) ,(from-lambda (cadr e)) ,@(caddr e)))
(expand-forms `(call (top broadcast!) ,(from-lambda (cadr e)) ,lhs-view ,@(caddr e))))
(if (null? lhs)
(expand-forms e)
(expand-forms `(call (top broadcast!) (top identity) ,lhs-view ,e))))))
e
`(call (top broadcast!) (top identity) ,(ref-to-view lhs) ,e))))
(let ((e (compress-fuse (dot-to-fuse rhs #t))) ; an expression '(fuse func args) if expr is a dot call
(e0 (dot-to-fuse rhs #f))) ; e without fusion
(if (fuse? e)
(expand-forms `(if (call (top isfusing) ,@(caddr e))
,(to-broadcast lhs e) ,(to-broadcast lhs e0)))
(expand-forms (to-broadcast lhs e)))))

(define (expand-where body var)
(let* ((bounds (analyze-typevar var))
Expand Down
2 changes: 1 addition & 1 deletion test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ end

# make sure scalars are inlined, which causes f.(x,scalar) to lower to a "thunk"
import Base.Meta: isexpr
@test isexpr(expand(:(f.(x,y))), :call)
@test isexpr(expand(:(f.(x,y))), :body)
@test isexpr(expand(:(f.(x,1))), :thunk)
@test isexpr(expand(:(f.(x,1.0))), :thunk)
@test isexpr(expand(:(f.(x,$π))), :thunk)
Expand Down