@@ -78,6 +78,7 @@ pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapp
7878/// allowing them to be conveniently passed to user-defined or wrapper
7979/// functions. The struct is declared in [`Writer::write_type_defs`].
8080pub ( crate ) const EXTERNAL_TEXTURE_WRAPPER_STRUCT : & str = "NagaExternalTextureWrapper" ;
/// Name of the MSL wrapper function generated for cooperative matrix
/// multiply-add expressions. Calls to it are emitted for
/// `Expression::CooperativeMultiplyAdd`, and its definition is written by
/// `write_wrapped_cooperative_multiply_add`.
81+ pub ( crate ) const COOPERATIVE_MULTIPLY_ADD_FUNCTION : & str = "NagaCooperativeMultiplyAdd" ;
8182
8283/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
8384///
@@ -483,6 +484,12 @@ enum WrappedFunction {
483484 ImageQuerySize {
484485 class : crate :: ImageClass ,
485486 } ,
487+ CooperativeMultiplyAdd {
488+ columns : crate :: CooperativeSize ,
489+ rows : crate :: CooperativeSize ,
490+ intermediate : crate :: CooperativeSize ,
491+ scalar : crate :: Scalar ,
492+ } ,
486493}
487494
488495pub struct Writer < W > {
@@ -543,14 +550,6 @@ impl crate::Scalar {
543550 }
544551}
545552
546- impl crate :: CooperativeScalar {
547- const fn to_msl_name ( self ) -> & ' static str {
548- match self {
549- Self :: F32 => "float" ,
550- }
551- }
552- }
553-
554553const fn separate ( need_separator : bool ) -> & ' static str {
555554 if need_separator {
556555 ","
@@ -2842,12 +2841,14 @@ impl<W: Write> Writer<W> {
28422841 }
28432842 write ! ( self . out, "}}" ) ?;
28442843 }
2845- crate :: Expression :: MulAdd { a, b, c } => {
2846- self . put_expression ( a, context, false ) ?;
2847- write ! ( self . out, " * " ) ?;
2848- self . put_expression ( b, context, false ) ?;
2849- write ! ( self . out, " + " ) ?;
2850- self . put_expression ( c, context, false ) ?;
2844+ crate :: Expression :: CooperativeMultiplyAdd { a, b, c } => {
2845+ write ! ( self . out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}(" ) ?;
2846+ self . put_expression ( a, context, true ) ?;
2847+ write ! ( self . out, ", " ) ?;
2848+ self . put_expression ( b, context, true ) ?;
2849+ write ! ( self . out, ", " ) ?;
2850+ self . put_expression ( c, context, true ) ?;
2851+ write ! ( self . out, ")" ) ?;
28512852 }
28522853 }
28532854 Ok ( ( ) )
@@ -4230,6 +4231,49 @@ impl<W: Write> Writer<W> {
42304231 }
42314232 writeln ! ( self . out, ");" ) ?;
42324233 }
4234+ crate :: Statement :: CooperativeLoadStore {
4235+ store,
4236+ target,
4237+ pointer,
4238+ stride,
4239+ row_major,
4240+ } => {
4241+ let op_str = if store { "store" } else { "load" } ;
4242+ write ! ( self . out, "{level}{NAMESPACE}::simdgroup_{op_str}(" ) ?;
4243+ self . put_expression ( target, & context. expression , true ) ?;
4244+ write ! ( self . out, ", " ) ?;
4245+ self . put_expression ( pointer, & context. expression , true ) ?;
4246+ if stride. is_some ( ) || row_major {
4247+ write ! ( self . out, ", " ) ?;
4248+ match stride {
4249+ Some ( expression) => {
4250+ self . put_expression ( expression, & context. expression , true ) ?;
4251+ }
4252+ None => {
4253+ let default_stride = match * context. expression . resolve_type ( target)
4254+ {
4255+ crate :: TypeInner :: CooperativeMatrix {
4256+ columns, rows, ..
4257+ } => {
4258+ if row_major {
4259+ columns as u32
4260+ } else {
4261+ rows as u32
4262+ }
4263+ }
4264+ _ => 0 ,
4265+ } ;
4266+ write ! ( self . out, "{default_stride}" ) ?;
4267+ }
4268+ }
4269+ }
4270+ if row_major {
4271+ let matrix_origin = "0" ;
4272+ let transpose = true ;
4273+ write ! ( self . out, ", {matrix_origin}, {transpose}" ) ?;
4274+ }
4275+ writeln ! ( self . out, ");" ) ?;
4276+ }
42334277 }
42344278 }
42354279
@@ -6286,6 +6330,62 @@ template <typename A>
62866330 Ok ( ( ) )
62876331 }
62886332
6333+ fn write_wrapped_cooperative_multiply_add (
6334+ & mut self ,
6335+ module : & crate :: Module ,
6336+ func_ctx : & back:: FunctionCtx ,
6337+ a : Handle < crate :: Expression > ,
6338+ b : Handle < crate :: Expression > ,
6339+ ) -> BackendResult {
6340+ let ( a_c, a_r, scalar) = match * func_ctx. resolve_type ( a, & module. types ) {
6341+ crate :: TypeInner :: CooperativeMatrix {
6342+ columns,
6343+ rows,
6344+ scalar,
6345+ ..
6346+ } => ( columns, rows, scalar) ,
6347+ _ => unreachable ! ( ) ,
6348+ } ;
6349+ let ( b_c, b_r) = match * func_ctx. resolve_type ( b, & module. types ) {
6350+ crate :: TypeInner :: CooperativeMatrix { columns, rows, .. } => ( columns, rows) ,
6351+ _ => unreachable ! ( ) ,
6352+ } ;
6353+ let wrapped = WrappedFunction :: CooperativeMultiplyAdd {
6354+ columns : b_c,
6355+ rows : a_r,
6356+ intermediate : a_c,
6357+ scalar,
6358+ } ;
6359+ if !self . wrapped_functions . insert ( wrapped) {
6360+ return Ok ( ( ) ) ;
6361+ }
6362+ let scalar_name = match scalar. width {
6363+ 2 => "half" ,
6364+ 4 => "float" ,
6365+ 8 => "double" ,
6366+ _ => unreachable ! ( ) ,
6367+ } ;
6368+ writeln ! (
6369+ self . out,
6370+ "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{" ,
6371+ b_c as u32 , a_r as u32 , a_c as u32 , a_r as u32 , b_c as u32 , b_r as u32 , b_c as u32 , a_r as u32 ,
6372+ ) ?;
6373+ let l1 = back:: Level ( 1 ) ;
6374+ writeln ! (
6375+ self . out,
6376+ "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} d;" ,
6377+ b_c as u32 , a_r as u32
6378+ ) ?;
6379+ writeln ! (
6380+ self . out,
6381+ "{l1}{NAMESPACE}::simdgroup_multiply_accumulate(d,a,b,c);"
6382+ ) ?;
6383+ writeln ! ( self . out, "{l1}return d;" ) ?;
6384+ writeln ! ( self . out, "}}" ) ?;
6385+ writeln ! ( self . out) ?;
6386+ Ok ( ( ) )
6387+ }
6388+
62896389 pub ( super ) fn write_wrapped_functions (
62906390 & mut self ,
62916391 module : & crate :: Module ,
@@ -6360,6 +6460,9 @@ template <typename A>
63606460 crate :: Expression :: ImageQuery { image, query } => {
63616461 self . write_wrapped_image_query ( module, func_ctx, image, query) ?;
63626462 }
6463+ crate :: Expression :: CooperativeMultiplyAdd { a, b, c : _ } => {
6464+ self . write_wrapped_cooperative_multiply_add ( module, func_ctx, a, b) ?;
6465+ }
63636466 _ => { }
63646467 }
63656468 }
0 commit comments