11use  std:: { 
2-     env,  fs, 
2+     env, 
3+     ffi:: CString , 
4+     fs, 
35    io:: { Read ,  Write } , 
6+     mem:: MaybeUninit , 
7+     os:: raw:: c_int, 
48    path:: { Path ,  PathBuf } , 
9+     ptr:: addr_of_mut, 
510    sync:: atomic:: { AtomicBool ,  Ordering } , 
611} ; 
712
@@ -11,6 +16,7 @@ use ptx_builder::{
1116    builder:: { BuildStatus ,  Builder ,  MessageFormat ,  Profile } , 
1217    error:: { BuildErrorKind ,  Error ,  Result } , 
1318} ; 
19+ use  ptx_compiler:: sys:: size_t; 
1420
1521use  super :: utils:: skip_kernel_compilation; 
1622
@@ -56,6 +62,7 @@ pub fn link_kernel(tokens: TokenStream) -> TokenStream {
5662
5763    let  LinkKernelConfig  { 
5864        kernel, 
65+         kernel_hash, 
5966        args, 
6067        crate_name, 
6168        crate_path, 
@@ -192,6 +199,119 @@ pub fn link_kernel(tokens: TokenStream) -> TokenStream {
192199        kernel_ptx. replace_range ( type_layout_start..type_layout_end,  "" ) ; 
193200    } 
194201
202+     let  mut  compiler = MaybeUninit :: uninit ( ) ; 
203+     let  r = unsafe  { 
204+         ptx_compiler:: sys:: nvPTXCompilerCreate ( 
205+             compiler. as_mut_ptr ( ) , 
206+             kernel_ptx. len ( )  as  size_t , 
207+             kernel_ptx. as_ptr ( ) . cast ( ) , 
208+         ) 
209+     } ; 
210+     emit_call_site_warning ! ( "PTX compiler create result {}" ,  r) ; 
211+     let  compiler = unsafe  {  compiler. assume_init ( )  } ; 
212+ 
213+     let  mut  major = 0 ; 
214+     let  mut  minor = 0 ; 
215+     let  r = unsafe  { 
216+         ptx_compiler:: sys:: nvPTXCompilerGetVersion ( addr_of_mut ! ( major) ,  addr_of_mut ! ( minor) ) 
217+     } ; 
218+     emit_call_site_warning ! ( "PTX version result {}" ,  r) ; 
219+     emit_call_site_warning ! ( "PTX compiler version {}.{}" ,  major,  minor) ; 
220+ 
221+     let  kernel_name = if  specialisation. is_empty ( )  { 
222+         format ! ( "{kernel_hash}_kernel" ) 
223+     }  else  { 
224+         format ! ( 
225+             "{kernel_hash}_kernel_{:016x}" , 
226+             seahash:: hash( specialisation. as_bytes( ) ) 
227+         ) 
228+     } ; 
229+ 
230+     let  options = vec ! [ 
231+         CString :: new( "--entry" ) . unwrap( ) , 
232+         CString :: new( kernel_name) . unwrap( ) , 
233+         CString :: new( "--verbose" ) . unwrap( ) , 
234+         CString :: new( "--warn-on-double-precision-use" ) . unwrap( ) , 
235+         CString :: new( "--warn-on-local-memory-usage" ) . unwrap( ) , 
236+         CString :: new( "--warn-on-spills" ) . unwrap( ) , 
237+     ] ; 
238+     let  options_ptrs = options. iter ( ) . map ( |o| o. as_ptr ( ) ) . collect :: < Vec < _ > > ( ) ; 
239+ 
240+     let  r = unsafe  { 
241+         ptx_compiler:: sys:: nvPTXCompilerCompile ( 
242+             compiler, 
243+             options_ptrs. len ( )  as  c_int , 
244+             options_ptrs. as_ptr ( ) . cast ( ) , 
245+         ) 
246+     } ; 
247+     emit_call_site_warning ! ( "PTX compile result {}" ,  r) ; 
248+ 
249+     let  mut  info_log_size = 0 ; 
250+     let  r = unsafe  { 
251+         ptx_compiler:: sys:: nvPTXCompilerGetInfoLogSize ( compiler,  addr_of_mut ! ( info_log_size) ) 
252+     } ; 
253+     emit_call_site_warning ! ( "PTX info log size result {}" ,  r) ; 
254+     #[ allow( clippy:: cast_possible_truncation) ]  
255+     let  mut  info_log:  Vec < u8 >  = Vec :: with_capacity ( info_log_size as  usize ) ; 
256+     if  info_log_size > 0  { 
257+         let  r = unsafe  { 
258+             ptx_compiler:: sys:: nvPTXCompilerGetInfoLog ( compiler,  info_log. as_mut_ptr ( ) . cast ( ) ) 
259+         } ; 
260+         emit_call_site_warning ! ( "PTX info log content result {}" ,  r) ; 
261+         #[ allow( clippy:: cast_possible_truncation) ]  
262+         unsafe  { 
263+             info_log. set_len ( info_log_size as  usize ) ; 
264+         } 
265+     } 
266+     let  info_log = String :: from_utf8_lossy ( & info_log) ; 
267+ 
268+     let  mut  error_log_size = 0 ; 
269+     let  r = unsafe  { 
270+         ptx_compiler:: sys:: nvPTXCompilerGetErrorLogSize ( compiler,  addr_of_mut ! ( error_log_size) ) 
271+     } ; 
272+     emit_call_site_warning ! ( "PTX error log size result {}" ,  r) ; 
273+     #[ allow( clippy:: cast_possible_truncation) ]  
274+     let  mut  error_log:  Vec < u8 >  = Vec :: with_capacity ( error_log_size as  usize ) ; 
275+     if  error_log_size > 0  { 
276+         let  r = unsafe  { 
277+             ptx_compiler:: sys:: nvPTXCompilerGetErrorLog ( compiler,  error_log. as_mut_ptr ( ) . cast ( ) ) 
278+         } ; 
279+         emit_call_site_warning ! ( "PTX error log content result {}" ,  r) ; 
280+         #[ allow( clippy:: cast_possible_truncation) ]  
281+         unsafe  { 
282+             error_log. set_len ( error_log_size as  usize ) ; 
283+         } 
284+     } 
285+     let  error_log = String :: from_utf8_lossy ( & error_log) ; 
286+ 
287+     // Ensure the compiler is not dropped 
288+     let  mut  compiler = MaybeUninit :: new ( compiler) ; 
289+     let  r = unsafe  {  ptx_compiler:: sys:: nvPTXCompilerDestroy ( compiler. as_mut_ptr ( ) )  } ; 
290+     emit_call_site_warning ! ( "PTX compiler destroy result {}" ,  r) ; 
291+ 
292+     if  !info_log. is_empty ( )  { 
293+         emit_call_site_warning ! ( "PTX compiler info log:\n {}" ,  info_log) ; 
294+     } 
295+     if  !error_log. is_empty ( )  { 
296+         let  mut  max_lines = kernel_ptx. chars ( ) . filter ( |c| * c == '\n' ) . count ( )  + 1 ; 
297+         let  mut  indent = 0 ; 
298+         while  max_lines > 0  { 
299+             max_lines /= 10 ; 
300+             indent += 1 ; 
301+         } 
302+ 
303+         abort_call_site ! ( 
304+             "PTX compiler error log:\n {}\n PTX source:\n {}" , 
305+             error_log, 
306+             kernel_ptx
307+                 . lines( ) 
308+                 . enumerate( ) 
309+                 . map( |( i,  l) | format!( "{:indent$}| {l}" ,  i + 1 ) ) 
310+                 . collect:: <Vec <_>>( ) 
311+                 . join( "\n " ) 
312+         ) ; 
313+     } 
314+ 
195315    ( quote !  {  const  PTX_STR :  & ' static  str  = #kernel_ptx;  #( #type_layouts) *  } ) . into ( ) 
196316} 
197317
0 commit comments