@@ -497,33 +497,53 @@ end
497497import  Base. Broadcast:  BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown, ArrayConflict
498498using  Base. Broadcast:  combine_styles
499499
500- struct  StructArrayStyle{S, N} <:  AbstractArrayStyle{N}  end 
500+ @static  if  fieldcount (Base. Broadcast. Broadcasted) ==  4 
501+     struct  StructArrayStyle{N, S} <:  AbstractArrayStyle{N} 
502+         style:: S 
503+         StructArrayStyle {N} (style) where  {N} =  new {N, typeof(style)} (style)
504+     end 
505+     StructArrayStyle {N} (style:: StructArrayStyle ) where  {N} =  StructArrayStyle {N} (style. style)
506+     parent_style (s:: BroadcastStyle ) =  s
507+     parent_style (s:: StructArrayStyle ) =  s. style
508+     style (bc:: Broadcasted ) =  bc. style
509+     const  broadcasted =  Broadcasted
510+ else 
511+     struct  StructArrayStyle{N, S} <:  AbstractArrayStyle{N} 
512+         StructArrayStyle {N} (style) where  {N} =  new {N, typeof(style)} ()
513+     end 
514+     StructArrayStyle {N} (style:: StructArrayStyle{M, S} ) where  {N, M, S} =  StructArrayStyle {N} (S ())
515+     parent_style (s:: BroadcastStyle ) =  s
516+     parent_style (:: StructArrayStyle{N, S} ) where  {N} =  S ()
517+     style (:: Broadcasted{Style} ) where  {Style} =  Style ()
518+     broadcasted (s, f, args, axes) =  Broadcasted {typeof(s)} (f, args, axes)
519+ end 
520+ StructArrayStyle {N, S} () where  {N, S} =  StructArrayStyle {N} (S ())
521+ parent_style (bc:: Broadcasted ) =  parent_style (style (bc))
522+ ofstyle (s, bc:: Broadcasted ) =  broadcasted (s, bc. f, bc. args, bc. axes)
501523
502524#  Here we define the dimension tracking behavior of StructArrayStyle
503- function  StructArrayStyle {S, M } (:: Val{N} ) where  {S, M, N}
525+ function  StructArrayStyle {M, S } (:: Val{N} ) where  {S, M, N}
504526    T =  S <:  AbstractArrayStyle{M}  ?  typeof (S (Val {N} ())) :  S
505-     return  StructArrayStyle {T, N } ()
527+     return  StructArrayStyle {N, T } ()
506528end 
507529
508530#  StructArrayStyle is a wrapped style.
509531#  Here we try our best to resolve style conflict.
510- function  BroadcastStyle (b:: AbstractArrayStyle{M} , a:: StructArrayStyle{S, N } ) where  {S, N, M}
532+ function  BroadcastStyle (b:: AbstractArrayStyle{M} , a:: StructArrayStyle{N, S } ) where  {S, N, M}
511533    N′ =  M ===  Any ||  N ===  Any ?  Any :  max (M, N)
512-     S′ =  Broadcast. result_style (S (), b)
513-     return  S′ isa  StructArrayStyle ?  typeof (S′)(Val {N′} ()) :  StructArrayStyle {typeof(S′), N′} ()
534+     return  StructArrayStyle {N′} (Broadcast. result_style (parent_style (a), b))
514535end 
515536BroadcastStyle (:: StructArrayStyle , :: DefaultArrayStyle ) =  Unknown ()
516537
517538@inline  combine_style_types (:: Type{A} , args... ) where  {A<: AbstractArray } = 
518539    combine_style_types (BroadcastStyle (A), args... )
519540@inline  combine_style_types (s:: BroadcastStyle , :: Type{A} , args... ) where  {A<: AbstractArray } = 
520541    combine_style_types (Broadcast. result_style (s, BroadcastStyle (A)), args... )
521- combine_style_types (:: StructArrayStyle{S} ) where  {S} =  S () #  avoid nested StructArrayStyle
522542combine_style_types (s:: BroadcastStyle ) =  s
523543
524544Base. @pure  cst (:: Type{SA} ) where  {SA} =  combine_style_types (array_types (SA). parameters... )
525545
526- BroadcastStyle (:: Type{SA} ) where  {SA<: StructArray } =  StructArrayStyle {typeof(cst(SA)),  ndims(SA)} ()
546+ BroadcastStyle (:: Type{SA} ) where  {SA<: StructArray } =  StructArrayStyle {ndims(SA)} (cst (SA) )
527547
528548""" 
529549    always_struct_broadcast(style::BroadcastStyle) 
@@ -551,8 +571,8 @@ See also [`always_struct_broadcast`](@ref).
551571""" 
552572try_struct_copy (bc:: Broadcasted ) =  copy (bc)
553573
554- function  Base. copy (bc:: Broadcasted{StructArrayStyle{S, N}}  )  where  {S, N} 
555-     if  always_struct_broadcast (S ( ))
574+ function  Base. copy (bc:: Broadcasted{<: StructArrayStyle}  ) 
575+     if  always_struct_broadcast (parent_style (bc ))
556576        return  invoke (copy, Tuple{Broadcasted}, bc)
557577    else 
558578        return  try_struct_copy (replace_structarray (bc))
@@ -567,55 +587,49 @@ an equivalent one without it. This is not a must if the root `BroadcastStyle`
567587supports `AbstractArray`. But some `BroadcastStyle` limits the input array types,  
568588e.g. `StaticArrayStyle`, thus we have to omit all `StructArray`. 
569589""" 
570- function  replace_structarray (bc:: Broadcasted{Style}  )  where  {Style} 
590+ function  replace_structarray (bc:: Broadcasted ) 
571591    args =  replace_structarray_args (bc. args)
572-     Style′  =  parent_style (Style () )
573-     return  Broadcasted {Style′} ( bc. f, args, bc. axes)
592+     style  =  parent_style (bc )
593+     return  broadcasted (style,  bc. f, args, bc. axes)
574594end 
575595function  replace_structarray (A:: StructArray )
576596    f =  Instantiator (eltype (A))
577597    args =  Tuple (components (A))
578-     Style  =  typeof ( combine_styles (args... ) )
579-     return  Broadcasted {Style} ( f, args, axes (A))
598+     style  =  combine_styles (args... )
599+     return  broadcasted (style,  f, args, axes (A))
580600end 
581601replace_structarray (@nospecialize (A)) =  A
582602
583603replace_structarray_args (args:: Tuple ) =  (replace_structarray (args[1 ]), replace_structarray_args (tail (args))... )
584604replace_structarray_args (:: Tuple{} ) =  ()
585605
586- parent_style (@nospecialize (x)) =  typeof (x)
587- parent_style (:: StructArrayStyle{S, N} ) where  {S, N} =  S
588- parent_style (:: StructArrayStyle{S, N} ) where  {N, S<: AbstractArrayStyle{N} } =  S
589- parent_style (:: StructArrayStyle{S, N} ) where  {S<: AbstractArrayStyle{Any} , N} =  S
590- parent_style (:: StructArrayStyle{S, N} ) where  {S<: AbstractArrayStyle , N} =  typeof (S (Val (N)))
591- 
592606#  `instantiate` and `_axes` might be overloaded for static axes.
593- function  Broadcast. instantiate (bc:: Broadcasted{Style} ) where  {Style <:  StructArrayStyle }
594-     Style′ =  parent_style (Style ())
595-     bc′ =  Broadcast. instantiate (convert (Broadcasted{Style′}, bc))
596-     return  convert (Broadcasted{Style}, bc′)
607+ function  Broadcast. instantiate (bc:: Broadcasted{<:StructArrayStyle} )
608+     bc′ =  Broadcast. instantiate (ofstyle (parent_style (bc), bc))
609+     return  ofstyle (style (bc), bc′)
597610end 
598611
599- function  Broadcast. _axes (bc:: Broadcasted{Style} , :: Nothing ) where  {Style <:  StructArrayStyle }
600-     Style′ =  parent_style (Style ())
601-     return  Broadcast. _axes (convert (Broadcasted{Style′}, bc), nothing )
612+ function  Broadcast. _axes (bc:: Broadcasted{<:StructArrayStyle} , :: Nothing )
613+     return  Broadcast. _axes (ofstyle (parent_style (bc), bc), nothing )
602614end 
603615
604616#  Here we use `similar` defined for `S` to build the dest Array.
605- function  Base. similar (bc:: Broadcasted{StructArrayStyle{S, N}}  , :: Type{ElType} ) where  {S, N,  ElType}
606-     bc′ =  convert (Broadcasted{S} , bc)
617+ function  Base. similar (bc:: Broadcasted{<: StructArrayStyle}  , :: Type{ElType} ) where  {ElType}
618+     bc′ =  ofstyle ( parent_style (bc) , bc)
607619    return  isnonemptystructtype (ElType) ?  buildfromschema (T ->  similar (bc′, T), ElType) :  similar (bc′, ElType)
608620end 
609621
610622#  Unwrapper to recover the behaviour defined by parent style.
611- @inline  function  Base. copyto! (dest:: AbstractArray , bc:: Broadcasted{StructArrayStyle{S, N}} ) where  {S, N}
612-     bc′ =  always_struct_broadcast (S ()) ?  convert (Broadcasted{S}, bc) :  replace_structarray (bc)
623+ @inline  function  Base. copyto! (dest:: AbstractArray , bc:: Broadcasted{<:StructArrayStyle} )
624+     ps =  parent_style (bc)
625+     bc′ =  always_struct_broadcast (ps) ?  ofstyle (ps, bc) :  replace_structarray (bc)
613626    return  copyto! (dest, bc′)
614627end 
615628
616- @inline  function  Broadcast. materialize! (:: StructArrayStyle{S} , dest, bc:: Broadcasted ) where  {S}
617-     bc′ =  always_struct_broadcast (S ()) ?  bc :  replace_structarray (bc)
618-     return  Broadcast. materialize! (S (), dest, bc′)
629+ @inline  function  Broadcast. materialize! (s:: StructArrayStyle , dest, bc:: Broadcasted )
630+     ps =  parent_style (s)
631+     bc′ =  always_struct_broadcast (ps) ?  bc :  replace_structarray (bc)
632+     return  Broadcast. materialize! (ps, dest, bc′)
619633end 
620634
621635#  for aliasing analysis during broadcast
0 commit comments