11use std:: {
2- env, fs,
2+ env,
3+ ffi:: CString ,
4+ fs,
35 io:: { Read , Write } ,
6+ mem:: MaybeUninit ,
7+ os:: raw:: c_int,
48 path:: { Path , PathBuf } ,
9+ ptr:: addr_of_mut,
510 sync:: atomic:: { AtomicBool , Ordering } ,
611} ;
712
@@ -11,6 +16,7 @@ use ptx_builder::{
1116 builder:: { BuildStatus , Builder , MessageFormat , Profile } ,
1217 error:: { BuildErrorKind , Error , Result } ,
1318} ;
19+ use ptx_compiler:: sys:: size_t;
1420
1521use super :: utils:: skip_kernel_compilation;
1622
@@ -56,6 +62,7 @@ pub fn link_kernel(tokens: TokenStream) -> TokenStream {
5662
5763 let LinkKernelConfig {
5864 kernel,
65+ kernel_hash,
5966 args,
6067 crate_name,
6168 crate_path,
@@ -192,6 +199,119 @@ pub fn link_kernel(tokens: TokenStream) -> TokenStream {
192199 kernel_ptx. replace_range ( type_layout_start..type_layout_end, "" ) ;
193200 }
194201
202+ let mut compiler = MaybeUninit :: uninit ( ) ;
203+ let r = unsafe {
204+ ptx_compiler:: sys:: nvPTXCompilerCreate (
205+ compiler. as_mut_ptr ( ) ,
206+ kernel_ptx. len ( ) as size_t ,
207+ kernel_ptx. as_ptr ( ) . cast ( ) ,
208+ )
209+ } ;
210+ emit_call_site_warning ! ( "PTX compiler create result {}" , r) ;
211+ let compiler = unsafe { compiler. assume_init ( ) } ;
212+
213+ let mut major = 0 ;
214+ let mut minor = 0 ;
215+ let r = unsafe {
216+ ptx_compiler:: sys:: nvPTXCompilerGetVersion ( addr_of_mut ! ( major) , addr_of_mut ! ( minor) )
217+ } ;
218+ emit_call_site_warning ! ( "PTX version result {}" , r) ;
219+ emit_call_site_warning ! ( "PTX compiler version {}.{}" , major, minor) ;
220+
221+ let kernel_name = if specialisation. is_empty ( ) {
222+ format ! ( "{kernel_hash}_kernel" )
223+ } else {
224+ format ! (
225+ "{kernel_hash}_kernel_{:016x}" ,
226+ seahash:: hash( specialisation. as_bytes( ) )
227+ )
228+ } ;
229+
230+ let options = vec ! [
231+ CString :: new( "--entry" ) . unwrap( ) ,
232+ CString :: new( kernel_name) . unwrap( ) ,
233+ CString :: new( "--verbose" ) . unwrap( ) ,
234+ CString :: new( "--warn-on-double-precision-use" ) . unwrap( ) ,
235+ CString :: new( "--warn-on-local-memory-usage" ) . unwrap( ) ,
236+ CString :: new( "--warn-on-spills" ) . unwrap( ) ,
237+ ] ;
238+ let options_ptrs = options. iter ( ) . map ( |o| o. as_ptr ( ) ) . collect :: < Vec < _ > > ( ) ;
239+
240+ let r = unsafe {
241+ ptx_compiler:: sys:: nvPTXCompilerCompile (
242+ compiler,
243+ options_ptrs. len ( ) as c_int ,
244+ options_ptrs. as_ptr ( ) . cast ( ) ,
245+ )
246+ } ;
247+ emit_call_site_warning ! ( "PTX compile result {}" , r) ;
248+
249+ let mut info_log_size = 0 ;
250+ let r = unsafe {
251+ ptx_compiler:: sys:: nvPTXCompilerGetInfoLogSize ( compiler, addr_of_mut ! ( info_log_size) )
252+ } ;
253+ emit_call_site_warning ! ( "PTX info log size result {}" , r) ;
254+ #[ allow( clippy:: cast_possible_truncation) ]
255+ let mut info_log: Vec < u8 > = Vec :: with_capacity ( info_log_size as usize ) ;
256+ if info_log_size > 0 {
257+ let r = unsafe {
258+ ptx_compiler:: sys:: nvPTXCompilerGetInfoLog ( compiler, info_log. as_mut_ptr ( ) . cast ( ) )
259+ } ;
260+ emit_call_site_warning ! ( "PTX info log content result {}" , r) ;
261+ #[ allow( clippy:: cast_possible_truncation) ]
262+ unsafe {
263+ info_log. set_len ( info_log_size as usize ) ;
264+ }
265+ }
266+ let info_log = String :: from_utf8_lossy ( & info_log) ;
267+
268+ let mut error_log_size = 0 ;
269+ let r = unsafe {
270+ ptx_compiler:: sys:: nvPTXCompilerGetErrorLogSize ( compiler, addr_of_mut ! ( error_log_size) )
271+ } ;
272+ emit_call_site_warning ! ( "PTX error log size result {}" , r) ;
273+ #[ allow( clippy:: cast_possible_truncation) ]
274+ let mut error_log: Vec < u8 > = Vec :: with_capacity ( error_log_size as usize ) ;
275+ if error_log_size > 0 {
276+ let r = unsafe {
277+ ptx_compiler:: sys:: nvPTXCompilerGetErrorLog ( compiler, error_log. as_mut_ptr ( ) . cast ( ) )
278+ } ;
279+ emit_call_site_warning ! ( "PTX error log content result {}" , r) ;
280+ #[ allow( clippy:: cast_possible_truncation) ]
281+ unsafe {
282+ error_log. set_len ( error_log_size as usize ) ;
283+ }
284+ }
285+ let error_log = String :: from_utf8_lossy ( & error_log) ;
286+
287+ // Ensure the compiler is not dropped
288+ let mut compiler = MaybeUninit :: new ( compiler) ;
289+ let r = unsafe { ptx_compiler:: sys:: nvPTXCompilerDestroy ( compiler. as_mut_ptr ( ) ) } ;
290+ emit_call_site_warning ! ( "PTX compiler destroy result {}" , r) ;
291+
292+ if !info_log. is_empty ( ) {
293+ emit_call_site_warning ! ( "PTX compiler info log:\n {}" , info_log) ;
294+ }
295+ if !error_log. is_empty ( ) {
296+ let mut max_lines = kernel_ptx. chars ( ) . filter ( |c| * c == '\n' ) . count ( ) + 1 ;
297+ let mut indent = 0 ;
298+ while max_lines > 0 {
299+ max_lines /= 10 ;
300+ indent += 1 ;
301+ }
302+
303+ abort_call_site ! (
304+ "PTX compiler error log:\n {}\n PTX source:\n {}" ,
305+ error_log,
306+ kernel_ptx
307+ . lines( )
308+ . enumerate( )
309+ . map( |( i, l) | format!( "{:indent$}| {l}" , i + 1 ) )
310+ . collect:: <Vec <_>>( )
311+ . join( "\n " )
312+ ) ;
313+ }
314+
195315 ( quote ! { const PTX_STR : & ' static str = #kernel_ptx; #( #type_layouts) * } ) . into ( )
196316}
197317
0 commit comments