1- use  proc_macro2:: TokenStream ; 
1+ use  proc_macro2:: { Span ,   TokenStream } ; 
22use  proc_macro2_diagnostics:: { Diagnostic ,  SpanDiagnosticExt } ; 
33use  quote:: { format_ident,  quote} ; 
4+ use  std:: cell:: OnceCell ; 
45use  syn:: spanned:: Spanned ; 
56use  syn:: { 
67    parenthesized,  parse_macro_input,  token,  Attribute ,  Data ,  DeriveInput ,  Expr ,  Field ,  Fields , 
7-     FieldsNamed ,  Ident ,  Lit ,  LitInt ,  Path ,  Type ,  Variant , 
8+     FieldsNamed ,  Ident ,  Lit ,  LitInt ,  Meta ,   Path ,  Type ,  Variant , 
89} ; 
910
1011type  Result < T >  = std:: result:: Result < T ,  Diagnostic > ; 
1112
13+ enum  FieldKind  { 
14+     Ptr ( Type ,  usize ) , 
15+     Ty ( Type ) , 
16+ } 
17+ 
18+ impl  FieldKind  { 
19+     fn  ty ( & self )  -> & Type  { 
20+         match  self  { 
21+             FieldKind :: Ptr ( ty,  _)  | FieldKind :: Ty ( ty)  => & ty, 
22+         } 
23+     } 
24+ } 
25+ 
1226struct  AbstractField  { 
13-     ty :  Type , 
14-     width :  Option < Type > , 
27+     kind :  FieldKind , 
1528    ident :  Ident , 
16-     named :   bool , 
29+     name :   Option < String > , 
1730} 
1831
1932impl  AbstractField  { 
20-     fn  from_field ( field :  Field )  -> Result < Self >  { 
33+     fn  from_field ( field :  Field ,   parent_name :   & Ident ,   pointer_width :   Option < usize > )  -> Result < Self >  { 
2134        let  Some ( ident)  = field. ident  else  { 
2235            return  Err ( field. span ( ) . error ( "field must be named" ) ) ; 
2336        } ; 
24-         let  named = field. attrs . iter ( ) . any ( |attr| attr. path ( ) . is_ident ( "named" ) ) ; 
25-         let  width = field
26-             . attrs 
27-             . iter ( ) 
28-             . find ( |attr| attr. path ( ) . is_ident ( "width" ) ) ; 
29-         if  let  Type :: Ptr ( ty)  = field. ty  { 
30-             if  let  Some ( attr)  = width { 
31-                 if  let  Expr :: Lit ( expr)  = & attr. meta . require_name_value ( ) ?. value  { 
32-                     if  let  Lit :: Str ( lit_str)  = & expr. lit  { 
33-                         return  Ok ( Self  { 
34-                             ty :  * ty. elem , 
35-                             width :  Some ( lit_str. parse ( ) ?) , 
36-                             ident, 
37-                             named, 
38-                         } ) ; 
39-                     } 
40-                 } 
37+         let  kind = match  field. ty  { 
38+             Type :: Ptr ( ty)  => { 
39+                 let  Some ( width)  = pointer_width else  { 
40+                     return  Err ( parent_name. span ( ) . error ( 
41+                         // broken up to make rustfmt happy 
42+                         "types containing pointer fields must be \  
43+                           decorated with `#[binja(pointer_width = <int>)]`", 
44+                     ) ) ; 
45+                 } ; 
46+                 FieldKind :: Ptr ( * ty. elem ,  width) 
4147            } 
42-             Err ( ident. span ( ) 
43-                 . error ( "pointer field must have explicit `#[width = \" <type>\" ]` attribute, for example: `u64`" ) ) 
44-         }  else  { 
45-             match  width { 
46-                 Some ( attr)  => Err ( attr
47-                     . span ( ) 
48-                     . error ( "`#[width]` attribute can only be applied to pointer fields" ) ) , 
49-                 None  => Ok ( Self  { 
50-                     ty :  field. ty , 
51-                     width :  None , 
52-                     ident, 
53-                     named, 
54-                 } ) , 
55-             } 
56-         } 
48+             _ => FieldKind :: Ty ( field. ty ) , 
49+         } ; 
50+         let  name = find_binja_attr ( & field. attrs ) ?
51+             . map ( |attr| match  attr. kind  { 
52+                 BinjaAttrKind :: PointerWidth ( _)  => Err ( attr. span . error ( 
53+                     // broken up to make rustfmt happy 
54+                     "invalid attribute, expected either \  
55+                      `#[binja(named)]` or `#[binja(name = \" ...\" )]`", 
56+                 ) ) , 
57+                 BinjaAttrKind :: Named ( Some ( name) )  => Ok ( name) , 
58+                 BinjaAttrKind :: Named ( None )  => { 
59+                     let  ty = kind. ty ( ) ; 
60+                     Ok ( quote ! ( #ty) . to_string ( ) ) 
61+                 } 
62+             } ) 
63+             . transpose ( ) ?; 
64+         Ok ( Self  {  kind,  ident,  name } ) 
5765    } 
5866
5967    fn  resolved_ty ( & self )  -> TokenStream  { 
60-         let  ty = & self . ty ; 
68+         let  ty = self . kind . ty ( ) ; 
6169        let  mut  resolved = quote !  {  <#ty as  :: binaryninja:: types:: AbstractType >:: resolve_type( )  } ; 
62-         if  self . named  { 
70+         if  let   Some ( name )  =  & self . name  { 
6371            resolved = quote !  { 
64-                 :: binaryninja:: types:: Type :: named_type_from_type( 
65-                     stringify!( #ty) , 
66-                     & #resolved
67-                 ) 
68-             } ; 
72+                 :: binaryninja:: types:: Type :: named_type_from_type( #name,  & #resolved) 
73+             } 
6974        } 
70-         if  let  Some ( width)  = & self . width  { 
75+         if  let  FieldKind :: Ptr ( _ ,   width)  = self . kind  { 
7176            resolved = quote !  { 
72-                 :: binaryninja:: types:: Type :: pointer_of_width( 
73-                     & #resolved, 
74-                     :: std:: mem:: size_of:: <#width>( ) , 
75-                     false , 
76-                     false , 
77-                     None 
78-                 ) 
77+                 :: binaryninja:: types:: Type :: pointer_of_width( & #resolved,  #width,  false ,  false ,  None ) 
7978            } 
8079        } 
8180        resolved
8281    } 
8382} 
8483
84+ #[ derive( Debug ) ]  
85+ struct  BinjaAttr  { 
86+     kind :  BinjaAttrKind , 
87+     span :  Span , 
88+ } 
89+ 
90+ #[ derive( Debug ) ]  
91+ enum  BinjaAttrKind  { 
92+     PointerWidth ( usize ) , 
93+     Named ( Option < String > ) , 
94+ } 
95+ 
96+ fn  find_binja_attr ( attrs :  & [ Attribute ] )  -> Result < Option < BinjaAttr > >  { 
97+     let  binja_attr = OnceCell :: new ( ) ; 
98+ 
99+     let  set_attr = |attr :  BinjaAttr | { 
100+         let  span = attr. span ; 
101+         binja_attr
102+             . set ( attr) 
103+             . map_err ( |_| span. error ( "conflicting `#[binja(...)]` attributes" ) ) 
104+     } ; 
105+ 
106+     for  attr in  attrs { 
107+         let  Some ( ident)  = attr. path ( ) . get_ident ( )  else  { 
108+             continue ; 
109+         } ; 
110+         if  ident == "binja"  { 
111+             let  meta = attr. parse_args :: < Meta > ( ) ?; 
112+             let  meta_ident = meta. path ( ) . require_ident ( ) ?; 
113+             if  meta_ident == "pointer_width"  { 
114+                 let  value = & meta. require_name_value ( ) ?. value ; 
115+                 if  let  Expr :: Lit ( expr)  = & value { 
116+                     if  let  Lit :: Int ( val)  = & expr. lit  { 
117+                         set_attr ( BinjaAttr  { 
118+                             kind :  BinjaAttrKind :: PointerWidth ( val. base10_parse ( ) ?) , 
119+                             span :  attr. span ( ) , 
120+                         } ) ?; 
121+                         continue ; 
122+                     } 
123+                 } 
124+                 return  Err ( value. span ( ) . error ( "expected integer literal" ) ) ; 
125+             }  else  if  meta_ident == "name"  { 
126+                 let  value = & meta. require_name_value ( ) ?. value ; 
127+                 if  let  Expr :: Lit ( expr)  = & value { 
128+                     if  let  Lit :: Str ( lit)  = & expr. lit  { 
129+                         set_attr ( BinjaAttr  { 
130+                             kind :  BinjaAttrKind :: Named ( Some ( lit. value ( ) ) ) , 
131+                             span :  attr. span ( ) , 
132+                         } ) ?; 
133+                         continue ; 
134+                     } 
135+                 } 
136+                 return  Err ( value. span ( ) . error ( r#"expected string literal"# ) ) ; 
137+             }  else  if  meta_ident == "named"  { 
138+                 meta. require_path_only ( ) ?; 
139+                 set_attr ( BinjaAttr  { 
140+                     kind :  BinjaAttrKind :: Named ( None ) , 
141+                     span :  attr. span ( ) , 
142+                 } ) ?; 
143+             }  else  { 
144+                 return  Err ( meta
145+                     . span ( ) 
146+                     . error ( format ! ( "unrecognized property `{meta_ident}`" ) ) ) ; 
147+             } 
148+         } 
149+     } 
150+     Ok ( binja_attr. into_inner ( ) ) 
151+ } 
152+ 
85153struct  Repr  { 
86154    c :  bool , 
87155    packed :  Option < Option < LitInt > > , 
@@ -90,7 +158,7 @@ struct Repr {
90158} 
91159
92160impl  Repr  { 
93-     fn  from_attrs ( attrs :  Vec < Attribute > )  -> Result < Self >  { 
161+     fn  from_attrs ( attrs :  & [ Attribute ] )  -> Result < Self >  { 
94162        let  mut  c = false ; 
95163        let  mut  packed = None ; 
96164        let  mut  align = None ; 
@@ -99,15 +167,7 @@ impl Repr {
99167            let  Some ( ident)  = attr. path ( ) . get_ident ( )  else  { 
100168                continue ; 
101169            } ; 
102-             if  ident == "named"  { 
103-                 return  Err ( attr
104-                     . span ( ) 
105-                     . error ( "`#[named]` attribute can only be applied to fields" ) ) ; 
106-             }  else  if  ident == "width"  { 
107-                 return  Err ( attr
108-                     . span ( ) 
109-                     . error ( "`#[width]` attribute can only be applied to pointer fields" ) ) ; 
110-             }  else  if  ident == "repr"  { 
170+             if  ident == "repr"  { 
111171                attr. parse_nested_meta ( |meta| { 
112172                    if  let  Some ( ident)  = meta. path . get_ident ( )  { 
113173                        if  ident == "C"  { 
@@ -153,7 +213,7 @@ fn ident_in_list<const N: usize>(ident: &Ident, list: [&'static str; N]) -> bool
153213    list. iter ( ) . any ( |id| ident == id) 
154214} 
155215
156- #[ proc_macro_derive( AbstractType ,  attributes( named ,  width ) ) ]  
216+ #[ proc_macro_derive( AbstractType ,  attributes( binja ) ) ]  
157217pub  fn  abstract_type_derive ( input :  proc_macro:: TokenStream )  -> proc_macro:: TokenStream  { 
158218    let  input = parse_macro_input ! ( input as  DeriveInput ) ; 
159219    match  impl_abstract_type ( input)  { 
@@ -163,7 +223,18 @@ pub fn abstract_type_derive(input: proc_macro::TokenStream) -> proc_macro::Token
163223} 
164224
165225fn  impl_abstract_type ( ast :  DeriveInput )  -> Result < TokenStream >  { 
166-     let  repr = Repr :: from_attrs ( ast. attrs ) ?; 
226+     let  repr = Repr :: from_attrs ( & ast. attrs ) ?; 
227+     let  width = find_binja_attr ( & ast. attrs ) ?
228+         . map ( |attr| match  attr. kind  { 
229+             BinjaAttrKind :: PointerWidth ( width)  => Ok ( width) , 
230+             BinjaAttrKind :: Named ( Some ( _) )  => Err ( attr
231+                 . span 
232+                 . error ( r#"`#[binja(name = "...")] is only supported on fields"# ) ) , 
233+             BinjaAttrKind :: Named ( None )  => Err ( attr
234+                 . span 
235+                 . error ( "`#[binja(named)]` is only supported on fields" ) ) , 
236+         } ) 
237+         . transpose ( ) ?; 
167238
168239    if  !ast. generics . params . is_empty ( )  { 
169240        return  Err ( ast. generics . span ( ) . error ( "type must not be generic" ) ) ; 
@@ -173,7 +244,7 @@ fn impl_abstract_type(ast: DeriveInput) -> Result<TokenStream> {
173244    match  ast. data  { 
174245        Data :: Struct ( s)  => match  s. fields  { 
175246            Fields :: Named ( fields)  => { 
176-                 impl_abstract_structure_type ( ident,  fields,  repr,  StructureKind :: Struct ) 
247+                 impl_abstract_structure_type ( ident,  fields,  repr,  width ,   StructureKind :: Struct ) 
177248            } 
178249            Fields :: Unnamed ( _)  => Err ( s
179250                . fields 
@@ -184,7 +255,9 @@ fn impl_abstract_type(ast: DeriveInput) -> Result<TokenStream> {
184255                . error ( "unit structs are unsupported; provide at least one named field" ) ) , 
185256        } , 
186257        Data :: Enum ( e)  => impl_abstract_enum_type ( ident,  e. variants ,  repr) , 
187-         Data :: Union ( u)  => impl_abstract_structure_type ( ident,  u. fields ,  repr,  StructureKind :: Union ) , 
258+         Data :: Union ( u)  => { 
259+             impl_abstract_structure_type ( ident,  u. fields ,  repr,  width,  StructureKind :: Union ) 
260+         } 
188261    } 
189262} 
190263
@@ -197,6 +270,7 @@ fn impl_abstract_structure_type(
197270    name :  Ident , 
198271    fields :  FieldsNamed , 
199272    repr :  Repr , 
273+     pointer_width :  Option < usize > , 
200274    kind :  StructureKind , 
201275)  -> Result < TokenStream >  { 
202276    if  !repr. c  { 
@@ -210,21 +284,25 @@ fn impl_abstract_structure_type(
210284    let  abstract_fields = fields
211285        . named 
212286        . into_iter ( ) 
213-         . map ( AbstractField :: from_field) 
287+         . map ( |field|  AbstractField :: from_field ( field ,   & name ,  pointer_width ) ) 
214288        . collect :: < Result < Vec < _ > > > ( ) ?; 
215289    let  layout_name = format_ident ! ( "__{name}_layout" ) ; 
216290    let  field_wrapper = format_ident ! ( "__{name}_field_wrapper" ) ; 
217291    let  layout_fields = abstract_fields
218292        . iter ( ) 
219293        . map ( |field| { 
220294            let  ident = & field. ident ; 
221-             let  layout_ty = field. width . as_ref ( ) . unwrap_or ( & field. ty ) ; 
222-             quote !  { 
223-                 #ident:  #field_wrapper<
224-                     [ u8 ;  <#layout_ty as  :: binaryninja:: types:: AbstractType >:: SIZE ] , 
225-                     {  <#layout_ty as  :: binaryninja:: types:: AbstractType >:: ALIGN  } , 
226-                 >
227-             } 
295+             let  ( size,  align)  = match  & field. kind  { 
296+                 FieldKind :: Ptr ( _,  width)  => { 
297+                     let  align = width. next_power_of_two ( ) ; 
298+                     ( quote !  {  #width } ,  quote !  {  #align } ) 
299+                 } 
300+                 FieldKind :: Ty ( ty)  => ( 
301+                     quote !  {  <#ty as  :: binaryninja:: types:: AbstractType >:: SIZE  } , 
302+                     quote !  {  {  <#ty as  :: binaryninja:: types:: AbstractType >:: ALIGN  }  } , 
303+                 ) , 
304+             } ; 
305+             quote !  {  #ident:  #field_wrapper<[ u8 ;  #size] ,  #align> } 
228306        } ) 
229307        . collect :: < Vec < _ > > ( ) ; 
230308    let  args = abstract_fields
0 commit comments